290 lines
10 KiB
Python
290 lines
10 KiB
Python
import json
|
||
import os
|
||
import logging
|
||
import sqlite3
|
||
import os
|
||
import time
|
||
import toml
|
||
from pathlib import Path
|
||
from http import HTTPStatus
|
||
from datetime import datetime
|
||
|
||
# Logging setup: records at INFO and above are formatted as
# "time - logger name - level - message" and handed to two handlers,
# a file handler writing "chat_app.log" and a default stream handler.
logging.basicConfig(
    handlers=[
        logging.FileHandler("chat_app.log"),
        logging.StreamHandler(),
    ],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def configure_wal(cursor, cache_size=-50000, busy_timeout=5000):
    """Switch the database behind ``cursor`` to WAL journaling and tune it.

    Args:
        cursor: An open ``sqlite3`` cursor; every PRAGMA is executed on it.
        cache_size: Value passed to ``PRAGMA cache_size``.
        busy_timeout: Value passed to ``PRAGMA busy_timeout`` (milliseconds).

    Returns:
        None. Failures are logged at ERROR level and not re-raised, so a
        caller can keep using the connection with whatever settings applied.
    """
    logger = logging.getLogger(__name__)

    try:
        # PRAGMA arguments cannot be bound as SQL parameters, so coerce them
        # to integers before interpolating them into the statement text.
        cache_size = int(cache_size)
        busy_timeout = int(busy_timeout)

        # 1. Enable WAL mode; the PRAGMA reports the journal mode in effect.
        cursor.execute("PRAGMA journal_mode=WAL;")
        mode = cursor.fetchone()[0].lower()

        # 2. Core settings.
        cursor.execute("PRAGMA synchronous=NORMAL;")
        cursor.execute(f"PRAGMA busy_timeout={busy_timeout};")
        cursor.execute(f"PRAGMA cache_size={cache_size};")
        cursor.execute("PRAGMA wal_autocheckpoint=1000;")

        # 3. Additional tuning.
        # journal_size_limit is expressed in bytes. The previous value 32768
        # was only 32 KiB although it was documented as a 32MB WAL limit.
        cursor.execute(f"PRAGMA journal_size_limit={32 * 1024 * 1024};")  # 32 MiB
        cursor.execute("PRAGMA mmap_size=268435456;")  # 256 MiB memory map

        if mode == "wal":
            logger.debug(f"WAL配置成功: cache_size={cache_size} busy_timeout={busy_timeout}ms")
        else:
            # The database did not switch to WAL; report the mode it is in.
            logger.warning(f"WAL配置部分成功,当前模式: {mode}")
    except (sqlite3.Error, ValueError, TypeError) as e:
        # ValueError/TypeError come from the int() coercion of bad arguments;
        # they are reported the same way as SQLite errors instead of raising.
        logger.error(f"WAL配置错误: {str(e)}")
|
||
|
||
class ConfigManager:
    """Index the TOML files under a configuration directory and load them by key."""

    def __init__(self, config_path="config"):
        """Remember ``config_path`` and build the key -> file path index.

        Args:
            config_path: Directory that is scanned recursively for ``*.toml``.
        """
        self.config = {}
        self.config_path = config_path
        self.build_config_dict()

    def build_config_dict(self) -> dict[str, str]:
        """Scan ``config_path`` for ``*.toml`` files and index them.

        A file directly inside ``config_path`` is keyed by its stem (name
        without extension). A file in a sub-directory is keyed by the name of
        its immediate parent directory, so several files in one directory
        share a key and the last one found wins.

        Returns:
            The mapping of key to absolute file path, also stored on
            ``self.config``.
        """
        root = Path(self.config_path)
        config_dict = {}
        for config_file in root.rglob("*.toml"):
            if not config_file.is_file():
                continue

            # Empty parent name means the file sits directly in the root.
            parent_name = config_file.relative_to(root).parent.name
            key = parent_name or config_file.stem

            config_dict[key] = str(config_file.absolute())

        self.config = config_dict
        return config_dict

    def load_config(self, name="config"):
        """Load and parse the TOML file registered under ``name``.

        Args:
            name: Key in ``self.config`` identifying the file to load.

        Returns:
            The parsed TOML as a dict, or ``{}`` when the key is unknown, the
            file no longer exists, or the file is not valid TOML.
        """
        # .get() so an unknown key falls back to {} instead of raising
        # KeyError before the "file missing" check can run.
        path = self.config.get(name)
        if path is None or not os.path.exists(path):
            return {}
        with open(path, 'r', encoding='utf-8') as f:
            try:
                return toml.load(f)
            except toml.TomlDecodeError:
                # Best-effort: an unparseable file is treated as empty.
                return {}

    def save_config(self, key=None, value=None):
        """Optionally set ``self.config[key] = value``, then dump ``self.config`` as TOML.

        Args:
            key: Entry to update; ignored unless ``value`` is also given.
            value: New value for ``key``.
        """
        if key is not None and value is not None:
            # Both key and value supplied: update that single entry first.
            self.config[key] = value
        # NOTE(review): this writes to ``self.config_path``, which the
        # constructor treats as a directory (default "config"); opening a
        # directory for writing fails. Confirm the intended target file.
        with open(self.config_path, 'w', encoding='utf-8') as f:
            toml.dump(self.config, f)

    def update_config(self, config_dict):
        """Merge ``config_dict`` into ``self.config`` and persist via ``save_config``."""
        self.config.update(config_dict)
        self.save_config()
|
||
|
||
class MainDatabase:
    """Wrapper around the SQLite database file at ``db_path``."""

    def __init__(self, db_path):
        """Store ``db_path`` and initialise the database file.

        Args:
            db_path: Filesystem path of the SQLite database.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the parent directory, open the database and apply WAL settings."""
        # dirname is "" for a bare file name; os.makedirs("") would raise.
        parent_dir = os.path.dirname(self.db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            configure_wal(cursor)
            # NOTE(review): the schema statement is empty in the source, so
            # this executes nothing; presumably a placeholder for table DDL.
            cursor.execute("""
            """)
            conn.commit()
        finally:
            # Always release the connection, even if initialisation fails.
            conn.close()
|
||
# Module-level configuration index built from the default "config" directory.
configm = ConfigManager()

# System prompt inserted as the first message of every new chat database:
# taken from [app].system_content of the "config" TOML file, with a built-in
# fallback when the section or the key is absent.
basecontent = (
    configm.load_config("config")
    .get("app", {})
    .get("system_content", "你是一个qq助手,名叫”the real“")
)
|
||
|
||
class ChatDatabase:
    """SQLite-backed message store for a single database file.

    Every operation opens its own connection, applies the WAL settings via
    ``configure_wal`` and closes the connection again before returning.
    """

    def __init__(self, db_path):
        """Store ``db_path`` and make sure the schema exists.

        Args:
            db_path: Filesystem path of the SQLite database.
        """
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the ``messages`` table and indexes; seed the system message.

        When the table is empty, one ``system`` row holding the module-level
        ``basecontent`` is inserted.
        """
        # dirname is "" for a bare file name; os.makedirs("") would raise.
        parent_dir = os.path.dirname(self.db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            configure_wal(cursor)

            # Message table.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    role TEXT NOT NULL,         -- user/assistant/system
                    content TEXT NOT NULL,      -- message body
                    sender_id TEXT,             -- sender identifier
                    timestamp REAL NOT NULL     -- creation time
                )
            """)

            # Indexes for the two columns load_messages filters/sorts on.
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_sender ON messages(sender_id)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON messages(timestamp)")

            # An empty table means a freshly created database.
            cursor.execute("SELECT COUNT(*) FROM messages")
            count = cursor.fetchone()[0]
            if count == 0:
                # Seed the conversation with the initial system message.
                timestamp = datetime.now().timestamp()
                cursor.execute("""
                    INSERT INTO messages (role, content, sender_id, timestamp)
                    VALUES (?, ?, ?, ?)
                """, ('system', basecontent, None, timestamp))
                logger.info(f"初始化系统消息已添加到数据库: {self.db_path}")

            conn.commit()
        finally:
            # Close even when a statement raises, so the handle is not leaked.
            conn.close()

    def save_message(self, role, content, sender_id=None):
        """Insert one message row stamped with the current time.

        Args:
            role: Value for the ``role`` column.
            content: Value for the ``content`` column.
            sender_id: Value for the ``sender_id`` column; may be ``None``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            configure_wal(cursor)
            timestamp = datetime.now().timestamp()
            cursor.execute("""
                INSERT INTO messages (role, content, sender_id, timestamp)
                VALUES (?, ?, ?, ?)
            """, (role, content, sender_id, timestamp))
            conn.commit()
        finally:
            conn.close()

    def load_messages(self, limit=10, sender_id=None):
        """Return up to ``limit`` messages ordered by ascending timestamp.

        Args:
            limit: Maximum number of rows to return.
            sender_id: When truthy, only rows with this ``sender_id`` are
                returned.

        Returns:
            A list of dicts with the keys ``role``, ``content``,
            ``sender_id`` and ``timestamp``.
        """
        query = "SELECT role, content, sender_id, timestamp FROM messages"
        params = []

        if sender_id:
            query += " WHERE sender_id = ?"
            params.append(sender_id)

        # NOTE(review): ascending order combined with LIMIT yields the oldest
        # rows, not the most recent ones; confirm this is what callers want.
        query += " ORDER BY timestamp LIMIT ?"
        params.append(limit)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            configure_wal(cursor)
            cursor.execute(query, params)
            rows = cursor.fetchall()
        finally:
            conn.close()

        # Convert the row tuples into message dictionaries.
        return [
            {
                'role': row[0],
                'content': row[1],
                'sender_id': row[2],
                'timestamp': row[3],
            }
            for row in rows
        ]
|
||
|
||
class ChatManager:
    """Route chat persistence to one ``ChatDatabase`` file per user or group."""

    def __init__(self):
        """Set up the storage directories under ``databases/chats``."""
        chats_root = os.path.join("databases", "chats")
        self.base_dir = chats_root
        self.user_dir = os.path.join(chats_root, "user")
        self.group_dir = os.path.join(chats_root, "group")
        # Make sure both storage directories exist.
        for directory in (self.user_dir, self.group_dir):
            os.makedirs(directory, exist_ok=True)

    def get_user_db(self, user_id):
        """Return a ``ChatDatabase`` for the private chat file ``<user_id>.db``."""
        return ChatDatabase(os.path.join(self.user_dir, f"{user_id}.db"))

    def get_group_db(self, group_id):
        """Return a ``ChatDatabase`` for the group chat file ``<group_id>.db``."""
        return ChatDatabase(os.path.join(self.group_dir, f"{group_id}.db"))

    def save_private_message(self, user, role, content):
        """Store a private message in the user's database, tagged with ``user.user_id``."""
        self.get_user_db(user.user_id).save_message(
            role, content, sender_id=user.user_id
        )

    def load_private_messages(self, user, limit=100):
        """Load up to ``limit`` messages from the user's private database."""
        return self.get_user_db(user.user_id).load_messages(limit)

    def save_group_message(self, group, role, content, sender_id=None):
        """Store a message in the group's database with the given ``sender_id``."""
        self.get_group_db(group.group_id).save_message(
            role, content, sender_id=sender_id
        )

    def load_group_messages(self, group, limit=100):
        """Load up to ``limit`` messages from the group's database."""
        return self.get_group_db(group.group_id).load_messages(limit)

    def load_user_group_messages(self, user, group, limit=10):
        """Load up to ``limit`` messages sent by ``user`` in ``group``."""
        return self.get_group_db(group.group_id).load_messages(
            limit, sender_id=user.user_id
        )
|
||
|
||
|
||
|
||
# 使用示例
|
||
if __name__ == "__main__":
    from modules import user_modules as chater

    # Build the chat manager.
    chat_manager = ChatManager()

    # Create two users and one group (basic information only).
    user1 = chater.Qquser("12345")
    user2 = chater.Qquser("67890")
    group = chater.Qqgroup("1001")

    # Store the private conversation with user1.
    for speaker_role, text in (
        ('user', '你好,我想问个问题'),
        ('assistant', '请说,我会尽力回答'),
    ):
        chat_manager.save_private_message(user1, speaker_role, text)

    # Store the group conversation; the assistant reply has no sender id.
    for speaker_role, text, sender in (
        ('user', '大家好,我是张三!', user1.user_id),
        ('user', '大家好,我是李四!', user2.user_id),
        ('assistant', '欢迎加入群聊!', None),
    ):
        chat_manager.save_group_message(group, speaker_role, text, sender_id=sender)

    # Print the private chat history.
    private_messages = chat_manager.load_private_messages(user1)
    print(f"{user1.nickname}的私聊记录:")
    for msg in private_messages:
        if msg['role'] == 'user':
            role = "用户"
        else:
            role = "AI助手"
        print(f"{role}: {msg['content']}")

    # Print the complete group chat history.
    group_messages = chat_manager.load_group_messages(group)
    print(f"\n{group.nickname}的群聊记录:")
    for msg in group_messages:
        speaker = msg['sender_id'] if msg['role'] == 'user' else "AI助手"
        print(f"{speaker}: {msg['content']}")

    # Print only user1's messages within the group.
    user1_messages = chat_manager.load_user_group_messages(user1, group)
    print(f"\n{user1.nickname}在{group.nickname}中的消息:")
    for msg in user1_messages:
        print(f"{msg['content']}")

    # Show the discovered configuration file index.
    config = ConfigManager()
    print(config.config)