""" 对话记录持久化服务。 负责 biz.ai_conversations 和 biz.ai_messages 两张表的 CRUD。 所有 8 个 AI 应用的每次调用都通过本服务记录对话和消息。 """ from __future__ import annotations import json import logging from datetime import datetime from app.database import get_connection logger = logging.getLogger(__name__) class ConversationService: """AI 对话记录持久化服务。""" def create_conversation( self, user_id: int | str, nickname: str, app_id: str, site_id: int, source_page: str | None = None, source_context: dict | None = None, ) -> int: """创建对话记录,返回 conversation_id。 系统自动调用时 user_id 为 'system'。 """ conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ INSERT INTO biz.ai_conversations (user_id, nickname, app_id, site_id, source_page, source_context) VALUES (%s, %s, %s, %s, %s, %s) RETURNING id """, ( str(user_id), nickname, app_id, site_id, source_page, json.dumps(source_context, ensure_ascii=False) if source_context else None, ), ) row = cur.fetchone() conn.commit() return row[0] except Exception: conn.rollback() raise finally: conn.close() def add_message( self, conversation_id: int, role: str, content: str, tokens_used: int | None = None, ) -> int: """添加消息记录,返回 message_id。""" conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ INSERT INTO biz.ai_messages (conversation_id, role, content, tokens_used) VALUES (%s, %s, %s, %s) RETURNING id """, (conversation_id, role, content, tokens_used), ) row = cur.fetchone() conn.commit() return row[0] except Exception: conn.rollback() raise finally: conn.close() def get_conversations( self, user_id: int | str, site_id: int, page: int = 1, page_size: int = 20, ) -> list[dict]: """查询用户历史对话列表,按 created_at 降序,分页。""" offset = (page - 1) * page_size conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT id, user_id, nickname, app_id, site_id, source_page, source_context, created_at FROM biz.ai_conversations WHERE user_id = %s AND site_id = %s ORDER BY created_at DESC LIMIT %s OFFSET %s """, (str(user_id), site_id, page_size, offset), ) columns = [desc[0] for desc in cur.description] rows = cur.fetchall() return [ _row_to_dict(columns, row) for row in rows ] finally: conn.close() def get_messages( self, conversation_id: int, ) -> list[dict]: """查询对话的所有消息,按 created_at 升序。""" conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT id, conversation_id, role, content, tokens_used, created_at FROM biz.ai_messages WHERE conversation_id = %s ORDER BY created_at ASC """, (conversation_id,), ) columns = [desc[0] for desc in cur.description] rows = cur.fetchall() return [ _row_to_dict(columns, row) for row in rows ] finally: conn.close() def _row_to_dict(columns: list[str], row: tuple) -> dict: """将数据库行转换为 dict,处理特殊类型序列化。""" result = {} for col, val in zip(columns, row): if isinstance(val, datetime): result[col] = val.isoformat() else: result[col] = val return result