# -*- coding: utf-8 -*- """ 小程序 CHAT 路由 —— CHAT-1/2/3/4 端点。 替代原 xcx_ai_chat.py(/api/ai/*),统一迁移到 /api/xcx/chat/* 路径。 端点清单: - GET /api/xcx/chat/history — CHAT-1 对话历史列表 - GET /api/xcx/chat/{chat_id}/messages — CHAT-2a 通过 chatId 查询消息 - GET /api/xcx/chat/messages?contextType=&contextId= — CHAT-2b 通过上下文查询消息 - POST /api/xcx/chat/{chat_id}/messages — CHAT-3 发送消息(同步回复) - POST /api/xcx/chat/stream — CHAT-4 SSE 流式端点 所有端点使用 require_approved() 权限检查。 """ from __future__ import annotations import json import logging import os from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi.responses import StreamingResponse from app.ai.bailian_client import BailianClient from app.auth.dependencies import CurrentUser from app.database import get_connection from app.middleware.permission import require_approved from app.schemas.xcx_chat import ( ChatHistoryItem, ChatHistoryResponse, ChatMessageItem, ChatMessagesResponse, ChatStreamRequest, MessageBrief, ReferenceCard, SendMessageRequest, SendMessageResponse, ) from app.services.chat_service import ChatService logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/xcx/chat", tags=["小程序 CHAT"]) # ── CHAT-1: 对话历史列表 ───────────────────────────────────── @router.get("/history", response_model=ChatHistoryResponse) async def list_chat_history( page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), user: CurrentUser = Depends(require_approved()), ) -> ChatHistoryResponse: """CHAT-1: 查询当前用户的对话历史列表,按最后消息时间倒序。""" svc = ChatService() items, total = svc.get_chat_history( user_id=user.user_id, site_id=user.site_id, page=page, page_size=page_size, ) return ChatHistoryResponse( items=[ChatHistoryItem(**item) for item in items], total=total, page=page, page_size=page_size, ) # ── CHAT-2b: 通过上下文查询消息 ───────────────────────────── # ⚠️ 必须在 /{chat_id}/messages 之前注册,否则 "messages" 会被当作 chat_id 路径参数 @router.get("/messages", response_model=ChatMessagesResponse) async def get_chat_messages_by_context( context_type: str = Query(..., alias="contextType"), context_id: str = Query(..., alias="contextId"), page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), user: CurrentUser = Depends(require_approved()), ) -> ChatMessagesResponse: """CHAT-2b: 通过上下文类型和 ID 查询消息(自动查找/创建对话)。""" svc = ChatService() # 按复用规则查找或创建对话 chat_id = svc.get_or_create_session( user_id=user.user_id, site_id=user.site_id, context_type=context_type, context_id=context_id if context_id else None, ) messages, total, resolved_chat_id = svc.get_messages( chat_id=chat_id, user_id=user.user_id, site_id=user.site_id, page=page, page_size=page_size, ) return ChatMessagesResponse( chat_id=resolved_chat_id, items=[_to_message_item(m) for m in messages], total=total, page=page, page_size=page_size, ) # ── CHAT-2a: 通过 chatId 查询消息 ─────────────────────────── @router.get("/{chat_id}/messages", response_model=ChatMessagesResponse) async def get_chat_messages( chat_id: int, page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), user: CurrentUser = Depends(require_approved()), ) -> ChatMessagesResponse: """CHAT-2a: 通过 chatId 查询对话消息列表,按 createdAt 正序。""" svc = ChatService() messages, total, resolved_chat_id = svc.get_messages( chat_id=chat_id, user_id=user.user_id, site_id=user.site_id, page=page, page_size=page_size, ) return ChatMessagesResponse( chat_id=resolved_chat_id, items=[_to_message_item(m) for m in messages], total=total, page=page, page_size=page_size, ) # ── CHAT-4: SSE 流式端点 ──────────────────────────────────── # ⚠️ 必须在 /{chat_id}/messages 之前注册,否则 "stream" 会被当作 chat_id 路径参数 @router.post("/stream") async def chat_stream( body: ChatStreamRequest, user: CurrentUser = Depends(require_approved()), ) -> StreamingResponse: """CHAT-4: SSE 流式对话端点。 接收用户消息,通过百炼 API 流式返回 AI 回复。 SSE 事件类型:message(逐 token)/ done(完成)/ error(错误)。 chatId 归属验证:不属于当前用户返回 HTTP 403(普通 JSON 错误,非 SSE)。 """ if not body.content or not body.content.strip(): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="消息内容不能为空", ) svc = ChatService() content = body.content.strip() # 归属验证(在 SSE 流开始前完成,失败时返回普通 HTTP 错误) svc._verify_ownership(body.chat_id, user.user_id, user.site_id) # 存入用户消息(P5 PRD 合规:发送时即写入) user_msg_id, user_created_at = svc._save_message(body.chat_id, "user", content) async def event_generator(): """SSE 事件生成器。 事件格式: - event: message\\ndata: {"token": "..."}\\n\\n - event: done\\ndata: {"messageId": ..., "createdAt": "..."}\\n\\n - event: error\\ndata: {"message": "..."}\\n\\n """ full_reply_parts: list[str] = [] try: bailian = _get_bailian_client() # 获取历史消息作为上下文 messages = _build_ai_messages(body.chat_id) # 流式调用百炼 API async for chunk in bailian.chat_stream(messages): full_reply_parts.append(chunk) yield f"event: message\ndata: {json.dumps({'token': chunk}, ensure_ascii=False)}\n\n" # 流结束:拼接完整回复并持久化 full_reply = "".join(full_reply_parts) estimated_tokens = len(full_reply) ai_msg_id, ai_created_at = svc._save_message( body.chat_id, "assistant", full_reply, tokens_used=estimated_tokens, ) svc._update_session_metadata(body.chat_id, full_reply) # 发送 done 事件 done_data = json.dumps( {"messageId": ai_msg_id, "createdAt": ai_created_at}, ensure_ascii=False, ) yield f"event: done\ndata: {done_data}\n\n" except Exception as e: logger.error("SSE 流式对话异常: %s", e, exc_info=True) # 如果已有部分回复,仍然持久化 if full_reply_parts: partial = "".join(full_reply_parts) try: svc._save_message(body.chat_id, "assistant", partial) svc._update_session_metadata(body.chat_id, partial) except Exception: logger.error("持久化部分回复失败", exc_info=True) error_data = json.dumps( {"message": "AI 服务暂时不可用"}, ensure_ascii=False, ) yield f"event: error\ndata: {error_data}\n\n" return StreamingResponse( event_generator(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) # ── CHAT-3: 发送消息(同步回复) ───────────────────────────── @router.post("/{chat_id}/messages", response_model=SendMessageResponse) async def send_message( chat_id: int, body: SendMessageRequest, user: CurrentUser = Depends(require_approved()), ) -> SendMessageResponse: """CHAT-3: 发送用户消息并获取同步 AI 回复。 chatId 归属验证:不属于当前用户返回 HTTP 403。 AI 失败时返回错误提示消息(HTTP 200)。 """ if not body.content or not body.content.strip(): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="消息内容不能为空", ) svc = ChatService() result = await svc.send_message_sync( chat_id=chat_id, content=body.content.strip(), user_id=user.user_id, site_id=user.site_id, ) return SendMessageResponse( user_message=MessageBrief(**result["user_message"]), ai_reply=MessageBrief(**result["ai_reply"]), ) # ── 辅助函数 ───────────────────────────────────────────────── def _to_message_item(msg: dict) -> ChatMessageItem: """将 chat_service 返回的消息 dict 转换为 ChatMessageItem。""" ref_card = msg.get("reference_card") reference_card = ReferenceCard(**ref_card) if ref_card and isinstance(ref_card, dict) else None return ChatMessageItem( id=msg["id"], role=msg["role"], content=msg["content"], created_at=msg["created_at"], reference_card=reference_card, ) def _get_bailian_client() -> BailianClient: """从环境变量构建 BailianClient,缺失时报错。""" api_key = os.environ.get("BAILIAN_API_KEY") base_url = os.environ.get("BAILIAN_BASE_URL") model = os.environ.get("BAILIAN_MODEL") if not api_key or not base_url or not model: raise RuntimeError( "百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL" ) return BailianClient(api_key=api_key, base_url=base_url, model=model) def _build_ai_messages(chat_id: int) -> list[dict]: """构建发送给 AI 的消息列表(含历史上下文)。""" conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT role, content FROM biz.ai_messages WHERE conversation_id = %s ORDER BY created_at ASC """, (chat_id,), ) history = cur.fetchall() finally: conn.close() messages: list[dict] = [] # 取最近 20 条 recent = history[-20:] if len(history) > 20 else history for role, msg_content in recent: messages.append({"role": role, "content": msg_content}) # 如果没有 system 消息,添加默认 system prompt if not messages or messages[0]["role"] != "system": system_prompt = { "role": "system", "content": json.dumps( {"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"}, ensure_ascii=False, ), } messages.insert(0, system_prompt) return messages