161 lines
4.8 KiB
Python
161 lines
4.8 KiB
Python
"""
|
||
对话记录持久化服务。
|
||
|
||
负责 biz.ai_conversations 和 biz.ai_messages 两张表的 CRUD。
|
||
所有 8 个 AI 应用的每次调用都通过本服务记录对话和消息。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
from app.database import get_connection
|
||
|
||
# Module-level logger named after this module.
# NOTE(review): nothing in this file currently logs through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ConversationService:
|
||
"""AI 对话记录持久化服务。"""
|
||
|
||
def create_conversation(
|
||
self,
|
||
user_id: int | str,
|
||
nickname: str,
|
||
app_id: str,
|
||
site_id: int,
|
||
source_page: str | None = None,
|
||
source_context: dict | None = None,
|
||
) -> int:
|
||
"""创建对话记录,返回 conversation_id。
|
||
|
||
系统自动调用时 user_id 为 'system'。
|
||
"""
|
||
conn = get_connection()
|
||
try:
|
||
with conn.cursor() as cur:
|
||
cur.execute(
|
||
"""
|
||
INSERT INTO biz.ai_conversations
|
||
(user_id, nickname, app_id, site_id, source_page, source_context)
|
||
VALUES (%s, %s, %s, %s, %s, %s)
|
||
RETURNING id
|
||
""",
|
||
(
|
||
str(user_id),
|
||
nickname,
|
||
app_id,
|
||
site_id,
|
||
source_page,
|
||
json.dumps(source_context, ensure_ascii=False) if source_context else None,
|
||
),
|
||
)
|
||
row = cur.fetchone()
|
||
conn.commit()
|
||
return row[0]
|
||
except Exception:
|
||
conn.rollback()
|
||
raise
|
||
finally:
|
||
conn.close()
|
||
|
||
def add_message(
|
||
self,
|
||
conversation_id: int,
|
||
role: str,
|
||
content: str,
|
||
tokens_used: int | None = None,
|
||
) -> int:
|
||
"""添加消息记录,返回 message_id。"""
|
||
conn = get_connection()
|
||
try:
|
||
with conn.cursor() as cur:
|
||
cur.execute(
|
||
"""
|
||
INSERT INTO biz.ai_messages
|
||
(conversation_id, role, content, tokens_used)
|
||
VALUES (%s, %s, %s, %s)
|
||
RETURNING id
|
||
""",
|
||
(conversation_id, role, content, tokens_used),
|
||
)
|
||
row = cur.fetchone()
|
||
conn.commit()
|
||
return row[0]
|
||
except Exception:
|
||
conn.rollback()
|
||
raise
|
||
finally:
|
||
conn.close()
|
||
|
||
def get_conversations(
|
||
self,
|
||
user_id: int | str,
|
||
site_id: int,
|
||
page: int = 1,
|
||
page_size: int = 20,
|
||
) -> list[dict]:
|
||
"""查询用户历史对话列表,按 created_at 降序,分页。"""
|
||
offset = (page - 1) * page_size
|
||
conn = get_connection()
|
||
try:
|
||
with conn.cursor() as cur:
|
||
cur.execute(
|
||
"""
|
||
SELECT id, user_id, nickname, app_id, site_id,
|
||
source_page, source_context, created_at
|
||
FROM biz.ai_conversations
|
||
WHERE user_id = %s AND site_id = %s
|
||
ORDER BY created_at DESC
|
||
LIMIT %s OFFSET %s
|
||
""",
|
||
(str(user_id), site_id, page_size, offset),
|
||
)
|
||
columns = [desc[0] for desc in cur.description]
|
||
rows = cur.fetchall()
|
||
return [
|
||
_row_to_dict(columns, row)
|
||
for row in rows
|
||
]
|
||
finally:
|
||
conn.close()
|
||
|
||
def get_messages(
|
||
self,
|
||
conversation_id: int,
|
||
) -> list[dict]:
|
||
"""查询对话的所有消息,按 created_at 升序。"""
|
||
conn = get_connection()
|
||
try:
|
||
with conn.cursor() as cur:
|
||
cur.execute(
|
||
"""
|
||
SELECT id, conversation_id, role, content,
|
||
tokens_used, created_at
|
||
FROM biz.ai_messages
|
||
WHERE conversation_id = %s
|
||
ORDER BY created_at ASC
|
||
""",
|
||
(conversation_id,),
|
||
)
|
||
columns = [desc[0] for desc in cur.description]
|
||
rows = cur.fetchall()
|
||
return [
|
||
_row_to_dict(columns, row)
|
||
for row in rows
|
||
]
|
||
finally:
|
||
conn.close()
|
||
|
||
|
||
def _row_to_dict(columns: list[str], row: tuple) -> dict:
    """Pair column names with one row's values, rendering datetimes as ISO-8601.

    Extra values or extra names beyond the shorter of the two are ignored;
    a repeated column name keeps its last value.
    """
    return {
        name: (value.isoformat() if isinstance(value, datetime) else value)
        for name, value in zip(columns, row)
    }
|