""" CHAT 模块业务逻辑层。 封装对话管理、消息持久化、referenceCard 组装、标题生成等核心逻辑。 路由层(xcx_chat.py)调用本服务完成 CHAT-1/2/3/4 端点的业务处理。 表依赖: - biz.ai_conversations — 对话会话(含 context_type/context_id/title/last_message 扩展字段) - biz.ai_messages — 消息记录(含 reference_card 扩展字段) - fdw_etl.v_dim_member — 会员信息(通过 ETL 直连) - fdw_etl.v_dws_member_consumption_summary / v_dwd_assistant_service_log — 消费指标 ⚠️ P5 PRD 合规: - app_id 固定为 'app1_chat' - 用户消息发送时即写入 ai_messages(role=user) - 流式完成后完整 assistant 回复写入 ai_messages(role=assistant),含 tokens_used """ from __future__ import annotations import json import logging import os from datetime import datetime from decimal import Decimal from typing import Any from fastapi import HTTPException, status from app.ai.bailian_client import BailianClient from app.database import get_connection from app.services import fdw_queries logger = logging.getLogger(__name__) APP_ID = "app1_chat" # 对话复用时限(天) _REUSE_DAYS = 3 class ChatService: """CHAT 模块业务逻辑。""" # ------------------------------------------------------------------ # CHAT-1: 对话历史列表 # ------------------------------------------------------------------ def get_chat_history( self, user_id: int, site_id: int, page: int, page_size: int, ) -> tuple[list[dict], int]: """查询对话历史列表,返回 (items, total)。 按 last_message_at 倒序,JOIN v_dim_member 获取 customerName。 仅返回 app_id='app1_chat' 的对话。 """ offset = (page - 1) * page_size conn = get_connection() try: with conn.cursor() as cur: # 总数 cur.execute( """ SELECT COUNT(*) FROM biz.ai_conversations WHERE user_id = %s AND site_id = %s AND app_id = %s """, (str(user_id), site_id, APP_ID), ) total = cur.fetchone()[0] # 分页列表 cur.execute( """ SELECT id, title, context_type, context_id, last_message, last_message_at, created_at FROM biz.ai_conversations WHERE user_id = %s AND site_id = %s AND app_id = %s ORDER BY COALESCE(last_message_at, created_at) DESC LIMIT %s OFFSET %s """, (str(user_id), site_id, APP_ID, page_size, offset), ) columns = [desc[0] for desc in cur.description] rows = cur.fetchall() finally: conn.close() # 组装结果,尝试获取 customerName items: list[dict] = [] # 收集需要查询姓名的 customer context_id customer_ids: list[int] = [] raw_items: list[dict] = [] for row in rows: item = dict(zip(columns, row)) raw_items.append(item) if item.get("context_type") == "customer" and item.get("context_id"): try: customer_ids.append(int(item["context_id"])) except (ValueError, TypeError): pass # 批量查询客户姓名(FDW 降级:查询失败返回空映射) name_map: dict[int, str] = {} if customer_ids: try: biz_conn = get_connection() try: info_map = fdw_queries.get_member_info(biz_conn, site_id, customer_ids) for mid, info in info_map.items(): name_map[mid] = info.get("nickname") or "" finally: biz_conn.close() except Exception: logger.warning("查询客户姓名失败,降级为空", exc_info=True) for item in raw_items: customer_name: str | None = None if item.get("context_type") == "customer" and item.get("context_id"): try: customer_name = name_map.get(int(item["context_id"])) except (ValueError, TypeError): pass # 生成标题 title = self.generate_title( title=item.get("title"), customer_name=customer_name, conversation_id=item["id"], ) ts = item.get("last_message_at") or item.get("created_at") items.append({ "id": item["id"], "title": title, "customer_name": customer_name, "last_message": item.get("last_message"), "timestamp": ts.isoformat() if isinstance(ts, datetime) else str(ts) if ts else "", "unread_count": 0, }) return items, total # ------------------------------------------------------------------ # 对话复用 / 创建 # ------------------------------------------------------------------ def get_or_create_session( self, user_id: int, site_id: int, context_type: str, context_id: str | None, ) -> int: """按入口上下文查找或创建对话,返回 chat_id。 复用规则: - context_type='task': 同一 taskId 始终复用(无时限) - context_type='customer'/'coach': 最后消息 ≤ 3 天复用,> 3 天新建 - context_type='general': 始终新建 """ # general 入口始终新建 if context_type == "general": return self._create_session(user_id, site_id, context_type, context_id) conn = get_connection() try: with conn.cursor() as cur: if context_type == "task": # task 入口:始终复用(无时限) cur.execute( """ SELECT id FROM biz.ai_conversations WHERE user_id = %s AND site_id = %s AND context_type = 'task' AND context_id = %s ORDER BY created_at DESC LIMIT 1 """, (str(user_id), site_id, context_id), ) elif context_type in ("customer", "coach"): # customer/coach 入口:3 天时限复用 cur.execute( """ SELECT id FROM biz.ai_conversations WHERE user_id = %s AND site_id = %s AND context_type = %s AND context_id = %s AND last_message_at > NOW() - INTERVAL '3 days' ORDER BY last_message_at DESC LIMIT 1 """, (str(user_id), site_id, context_type, context_id), ) else: # 未知类型,新建 return self._create_session(user_id, site_id, context_type, context_id) row = cur.fetchone() if row: return row[0] finally: conn.close() # 未找到可复用对话,新建 return self._create_session(user_id, site_id, context_type, context_id) def _create_session( self, user_id: int, site_id: int, context_type: str, context_id: str | None, ) -> int: """创建新对话记录,返回 conversation_id。""" conn = get_connection() try: with conn.cursor() as cur: # 查询用户昵称 cur.execute( "SELECT nickname FROM auth.users WHERE id = %s", (user_id,), ) row = cur.fetchone() nickname = row[0] if row and row[0] else "" cur.execute( """ INSERT INTO biz.ai_conversations (user_id, nickname, app_id, site_id, context_type, context_id) VALUES (%s, %s, %s, %s, %s, %s) RETURNING id """, (str(user_id), nickname, APP_ID, site_id, context_type, context_id), ) new_id = cur.fetchone()[0] conn.commit() return new_id except Exception: conn.rollback() raise finally: conn.close() # ------------------------------------------------------------------ # CHAT-2: 消息列表 # ------------------------------------------------------------------ def get_messages( self, chat_id: int, user_id: int, site_id: int, page: int, page_size: int, ) -> tuple[list[dict], int, int]: """查询消息列表,返回 (messages, total, chat_id)。 验证 chat_id 归属当前用户,按 created_at 正序。 """ self._verify_ownership(chat_id, user_id, site_id) offset = (page - 1) * page_size conn = get_connection() try: with conn.cursor() as cur: cur.execute( "SELECT COUNT(*) FROM biz.ai_messages WHERE conversation_id = %s", (chat_id,), ) total = cur.fetchone()[0] cur.execute( """ SELECT id, role, content, created_at, reference_card FROM biz.ai_messages WHERE conversation_id = %s ORDER BY created_at ASC LIMIT %s OFFSET %s """, (chat_id, page_size, offset), ) columns = [desc[0] for desc in cur.description] rows = cur.fetchall() finally: conn.close() messages = [] for row in rows: item = dict(zip(columns, row)) ref_card = item.get("reference_card") # reference_card 可能是 dict(psycopg2 自动解析 jsonb)或 str if isinstance(ref_card, str): try: ref_card = json.loads(ref_card) except (json.JSONDecodeError, TypeError): ref_card = None created_at = item["created_at"] messages.append({ "id": item["id"], "role": item["role"], "content": item["content"], "created_at": created_at.isoformat() if isinstance(created_at, datetime) else str(created_at), "reference_card": ref_card, }) return messages, total, chat_id # ------------------------------------------------------------------ # CHAT-3: 发送消息(同步回复) # ------------------------------------------------------------------ async def send_message_sync( self, chat_id: int, content: str, user_id: int, site_id: int, ) -> dict: """发送消息并获取同步 AI 回复。 流程: 1. 验证 chatId 归属 2. 存入用户消息(立即写入) 3. 调用 AI 获取回复 4. 存入 AI 回复 5. 更新 session 的 last_message / last_message_at 6. AI 失败时返回错误提示消息(HTTP 200) """ self._verify_ownership(chat_id, user_id, site_id) # 1. 立即存入用户消息(P5 PRD 合规:发送时即写入) user_msg_id, user_created_at = self._save_message(chat_id, "user", content) # 2. 调用 AI ai_reply_text: str tokens_used: int | None = None try: ai_reply_text, tokens_used = await self._call_ai(chat_id, content, user_id, site_id) except Exception as e: logger.error("AI 服务调用失败: %s", e, exc_info=True) ai_reply_text = "抱歉,AI 助手暂时无法回复,请稍后重试" # 3. 存入 AI 回复 ai_msg_id, ai_created_at = self._save_message( chat_id, "assistant", ai_reply_text, tokens_used=tokens_used, ) # 4. 更新 session 元数据 self._update_session_metadata(chat_id, ai_reply_text) return { "user_message": { "id": user_msg_id, "content": content, "created_at": user_created_at, }, "ai_reply": { "id": ai_msg_id, "content": ai_reply_text, "created_at": ai_created_at, }, } # ------------------------------------------------------------------ # referenceCard 组装 # ------------------------------------------------------------------ def build_reference_card( self, customer_id: int, site_id: int, ) -> dict | None: """从 FDW 查询客户关键指标,组装 referenceCard。 ⚠️ DWD-DOC 规则:金额用 items_sum 口径(ledger_amount), 会员信息通过 member_id JOIN dim_member(scd2_is_current=1)。 FDW 查询失败时静默降级返回 None(不影响消息本身)。 """ try: biz_conn = get_connection() try: # 客户姓名 info_map = fdw_queries.get_member_info(biz_conn, site_id, [customer_id]) if customer_id not in info_map: return None member_name = info_map[customer_id].get("nickname") or "未知客户" # 余额 balance: Decimal | None = None try: balance_map = fdw_queries.get_member_balance(biz_conn, site_id, [customer_id]) balance = balance_map.get(customer_id) except Exception: logger.warning("referenceCard: 查询余额失败", exc_info=True) # 近 30 天消费(items_sum 口径) consume_30d: Decimal | None = None try: consume_30d = self._get_consumption_30d(biz_conn, site_id, customer_id) except Exception: logger.warning("referenceCard: 查询近30天消费失败", exc_info=True) # 近 30 天到店次数 visit_count: int | None = None try: visit_count = self._get_visit_count_30d(biz_conn, site_id, customer_id) except Exception: logger.warning("referenceCard: 查询到店次数失败", exc_info=True) finally: biz_conn.close() # 格式化 balance_str = f"¥{balance:,.2f}" if balance is not None else "—" consume_str = f"¥{consume_30d:,.2f}" if consume_30d is not None else "—" visit_str = f"{visit_count}次" if visit_count is not None else "—" return { "type": "customer", "title": f"{member_name} — 消费概览", "summary": f"余额 {balance_str},近30天消费 {consume_str}", "data": { "余额": balance_str, "近30天消费": consume_str, "到店次数": visit_str, }, } except Exception: logger.warning("referenceCard 组装失败,降级为 null", exc_info=True) return None # ------------------------------------------------------------------ # 标题生成 # ------------------------------------------------------------------ def generate_title( self, title: str | None = None, customer_name: str | None = None, conversation_id: int | None = None, first_message: str | None = None, ) -> str: """生成对话标题:自定义标题 > 客户姓名 > 首条消息前 20 字。 结果始终非空。 """ # 优先级 1:自定义标题 if title and title.strip(): return title.strip() # 优先级 2:客户姓名 if customer_name and customer_name.strip(): return customer_name.strip() # 优先级 3:首条消息前 20 字 if first_message is None and conversation_id is not None: first_message = self._get_first_message(conversation_id) if first_message and first_message.strip(): text = first_message.strip() return text[:20] if len(text) > 20 else text return "新对话" # ------------------------------------------------------------------ # 内部辅助方法 # ------------------------------------------------------------------ def _verify_ownership(self, chat_id: int, user_id: int, site_id: int) -> None: """验证对话归属当前用户,不属于时抛出 HTTP 403/404。""" conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT user_id FROM biz.ai_conversations WHERE id = %s AND site_id = %s """, (chat_id, site_id), ) row = cur.fetchone() if not row: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="对话不存在", ) if str(row[0]) != str(user_id): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此对话", ) finally: conn.close() def _save_message( self, conversation_id: int, role: str, content: str, tokens_used: int | None = None, reference_card: dict | None = None, ) -> tuple[int, str]: """写入消息记录,返回 (message_id, created_at ISO 字符串)。""" conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ INSERT INTO biz.ai_messages (conversation_id, role, content, tokens_used, reference_card) VALUES (%s, %s, %s, %s, %s) RETURNING id, created_at """, ( conversation_id, role, content, tokens_used, json.dumps(reference_card, ensure_ascii=False) if reference_card else None, ), ) row = cur.fetchone() conn.commit() msg_id = row[0] created_at = row[1] return msg_id, created_at.isoformat() if isinstance(created_at, datetime) else str(created_at) except Exception: conn.rollback() raise finally: conn.close() def _update_session_metadata(self, chat_id: int, last_message: str) -> None: """更新对话的 last_message 和 last_message_at。""" # 截断至 100 字 truncated = last_message[:100] if len(last_message) > 100 else last_message conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ UPDATE biz.ai_conversations SET last_message = %s, last_message_at = NOW() WHERE id = %s """, (truncated, chat_id), ) conn.commit() except Exception: conn.rollback() raise finally: conn.close() def _get_first_message(self, conversation_id: int) -> str | None: """查询对话的首条 user 消息内容。""" conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT content FROM biz.ai_messages WHERE conversation_id = %s AND role = 'user' ORDER BY created_at ASC LIMIT 1 """, (conversation_id,), ) row = cur.fetchone() return row[0] if row else None finally: conn.close() async def _call_ai( self, chat_id: int, content: str, user_id: int, site_id: int, ) -> tuple[str, int | None]: """调用百炼 API 获取非流式回复,返回 (reply_text, tokens_used)。 构建历史消息上下文发送给 AI。 """ bailian = _get_bailian_client() # 获取历史消息作为上下文(最近 20 条) conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT role, content FROM biz.ai_messages WHERE conversation_id = %s ORDER BY created_at ASC """, (chat_id,), ) history = cur.fetchall() finally: conn.close() # 构建消息列表 messages: list[dict] = [] # 取最近 20 条(含刚写入的 user 消息) recent = history[-20:] if len(history) > 20 else history for role, msg_content in recent: messages.append({"role": role, "content": msg_content}) # 如果没有 system 消息,添加默认 system prompt if not messages or messages[0]["role"] != "system": system_prompt = { "role": "system", "content": json.dumps( {"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"}, ensure_ascii=False, ), } messages.insert(0, system_prompt) # 非流式调用(chat_stream 用于 SSE,这里用 chat_stream 收集完整回复) full_parts: list[str] = [] async for chunk in bailian.chat_stream(messages): full_parts.append(chunk) reply = "".join(full_parts) # 流式模式不返回 tokens_used,按字符数估算 estimated_tokens = len(reply) return reply, estimated_tokens @staticmethod def _get_consumption_30d(conn: Any, site_id: int, member_id: int) -> Decimal | None: """查询客户近 30 天消费金额(items_sum 口径)。 ⚠️ DWD-DOC 规则 1: 使用 ledger_amount(items_sum 口径),禁用 consume_money。 """ with fdw_queries._fdw_context(conn, site_id) as cur: cur.execute( """ SELECT COALESCE(SUM(ledger_amount), 0) FROM app.v_dwd_assistant_service_log WHERE tenant_member_id = %s AND is_delete = 0 AND create_time >= (CURRENT_DATE - INTERVAL '30 days')::timestamptz """, (member_id,), ) row = cur.fetchone() return Decimal(str(row[0])) if row and row[0] is not None else None @staticmethod def _get_visit_count_30d(conn: Any, site_id: int, member_id: int) -> int | None: """查询客户近 30 天到店次数。""" with fdw_queries._fdw_context(conn, site_id) as cur: cur.execute( """ SELECT COUNT(DISTINCT create_time::date) FROM app.v_dwd_assistant_service_log WHERE tenant_member_id = %s AND is_delete = 0 AND create_time >= (CURRENT_DATE - INTERVAL '30 days')::timestamptz """, (member_id,), ) row = cur.fetchone() return int(row[0]) if row and row[0] is not None else None # ── 模块级辅助函数 ────────────────────────────────────────────── def _get_bailian_client() -> BailianClient: """从环境变量构建 BailianClient,缺失时报错。""" api_key = os.environ.get("BAILIAN_API_KEY") base_url = os.environ.get("BAILIAN_BASE_URL") model = os.environ.get("BAILIAN_MODEL") if not api_key or not base_url or not model: raise RuntimeError( "百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL" ) return BailianClient(api_key=api_key, base_url=base_url, model=model)