# -*- coding: utf-8 -*-
"""
Mini-program CHAT routes -- CHAT-1/2/3/4 endpoints.

Replaces the former xcx_ai_chat.py (/api/ai/*); all chat endpoints now live
under /api/xcx/chat/*.

Endpoints:
- GET  /api/xcx/chat/history                           -- CHAT-1  chat history list
- GET  /api/xcx/chat/{chat_id}/messages                -- CHAT-2a messages by chatId
- GET  /api/xcx/chat/messages?contextType=&contextId=  -- CHAT-2b messages by context
- POST /api/xcx/chat/{chat_id}/messages                -- CHAT-3  send message (synchronous reply)
- POST /api/xcx/chat/stream                            -- CHAT-4  SSE streaming endpoint

Every endpoint is guarded by require_approved().
"""

from __future__ import annotations

import json
import logging
import time

from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse

from app.ai.config import AIConfig
from app.ai.dashscope_client import DashScopeClient
from app.auth.dependencies import CurrentUser
from app.database import get_connection
from app.middleware.permission import require_approved
from app.schemas.xcx_chat import (
    ChatHistoryItem,
    ChatHistoryResponse,
    ChatMessageItem,
    ChatMessagesResponse,
    ChatStreamRequest,
    MessageBrief,
    ReferenceCard,
    SendMessageRequest,
    SendMessageResponse,
)
from app.services.chat_service import ChatService
from app.trace.decorators import trace_service
from app.trace.sse_wrapper import (
    record_ai_call,
    record_ai_error,
    record_sse_end,
    record_sse_start,
    record_sse_token,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/xcx/chat", tags=["小程序 CHAT"])

# Default number of most recent non-system messages sent to the model as context.
_PROMPT_HISTORY_LIMIT = 20


# ── CHAT-1: chat history list ─────────────────────────────────


@router.get("/history", response_model=ChatHistoryResponse)
@trace_service("查询对话历史", "List chat history")
async def list_chat_history(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    user: CurrentUser = Depends(require_approved()),
) -> ChatHistoryResponse:
    """CHAT-1: list the current user's conversations, newest last-message first."""
    svc = ChatService()
    items, total = svc.get_chat_history(
        user_id=user.user_id,
        site_id=user.site_id,
        page=page,
        page_size=page_size,
    )
    return ChatHistoryResponse(
        items=[ChatHistoryItem(**item) for item in items],
        total=total,
        page=page,
        page_size=page_size,
    )


# ── CHAT-2b: messages by context ──────────────────────────────
# Static path kept registered before the parameterised /{chat_id}/messages
# routes (defensive ordering so the literal route is always tried first).


@router.get("/messages", response_model=ChatMessagesResponse)
@trace_service("通过上下文查询消息", "Get messages by context")
async def get_chat_messages_by_context(
    context_type: str = Query(..., alias="contextType"),
    context_id: str = Query(..., alias="contextId"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
    user: CurrentUser = Depends(require_approved()),
) -> ChatMessagesResponse:
    """CHAT-2b: fetch messages by context type + id, finding or creating the chat."""
    svc = ChatService()

    # Resolve the conversation via the session-reuse rules (find or create);
    # an empty contextId is passed on as None.
    chat_id = svc.get_or_create_session(
        user_id=user.user_id,
        site_id=user.site_id,
        context_type=context_type,
        context_id=context_id if context_id else None,
    )

    messages, total, resolved_chat_id = svc.get_messages(
        chat_id=chat_id,
        user_id=user.user_id,
        site_id=user.site_id,
        page=page,
        page_size=page_size,
    )
    return ChatMessagesResponse(
        chat_id=resolved_chat_id,
        items=[_to_message_item(m) for m in messages],
        total=total,
        page=page,
        page_size=page_size,
    )


# ── CHAT-2a: messages by chatId ───────────────────────────────


@router.get("/{chat_id}/messages", response_model=ChatMessagesResponse)
@trace_service("查询对话消息", "Get chat messages")
async def get_chat_messages(
    chat_id: int,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
    user: CurrentUser = Depends(require_approved()),
) -> ChatMessagesResponse:
    """CHAT-2a: page through a conversation's messages, oldest (createdAt) first."""
    svc = ChatService()
    messages, total, resolved_chat_id = svc.get_messages(
        chat_id=chat_id,
        user_id=user.user_id,
        site_id=user.site_id,
        page=page,
        page_size=page_size,
    )
    return ChatMessagesResponse(
        chat_id=resolved_chat_id,
        items=[_to_message_item(m) for m in messages],
        total=total,
        page=page,
        page_size=page_size,
    )


# ── CHAT-4: SSE streaming endpoint ────────────────────────────
# Registered before POST /{chat_id}/messages (defensive ordering).


@router.post("/stream")
@trace_service("SSE 流式对话", "Chat stream SSE")
async def chat_stream(
    body: ChatStreamRequest,
    user: CurrentUser = Depends(require_approved()),
) -> StreamingResponse:
    """CHAT-4: stream an AI reply over Server-Sent Events.

    The user message is persisted, then the DashScope (Bailian) application is
    called in streaming mode and each chunk is forwarded as an SSE frame.
    SSE event types: ``message`` (one chunk), ``done`` (finished), ``error``.

    Content validation and the chatId ownership check run before the stream
    starts, so those failures are returned as ordinary JSON HTTP errors
    (the original contract documents 403 for a chat not owned by the user),
    not as SSE frames.

    Raises:
        HTTPException: 422 when the message content is empty/blank.
    """
    content = _require_content(body.content)
    svc = ChatService()

    # Ownership check before streaming so a failure is a plain HTTP error.
    svc._verify_ownership(body.chat_id, user.user_id, user.site_id)

    # Persist the user message right away (P5 PRD: written at send time).
    svc._save_message(body.chat_id, "user", content)

    async def event_generator():
        """Yield SSE frames for a single streamed reply.

        Frame formats:
        - event: message\\ndata: {"token": "..."}\\n\\n
        - event: done\\ndata: {"messageId": ..., "createdAt": "..."}\\n\\n
        - event: error\\ndata: {"message": "..."}\\n\\n
        """
        full_reply_parts: list[str] = []
        tokens_total = 0
        # Set once the complete assistant reply is saved, so a later failure
        # (metadata update, done-frame serialisation, ...) does not make the
        # error path store the same reply a second time as a "partial" one.
        reply_persisted = False
        started_at = time.perf_counter()

        record_sse_start(
            endpoint="/api/xcx/chat/stream",
            user_id=user.user_id,
            chat_id=str(body.chat_id),
        )

        try:
            config = AIConfig.from_env()
            client = _get_dashscope_client(config)

            # Recent history; the current user message was saved above, so it
            # is the last entry and becomes the "current question".
            prompt = _build_prompt(body.chat_id)

            # Caller identity forwarded to the DashScope application.
            biz_params = {
                "User_ID": str(user.user_id),
                "Role": getattr(user, "role", "coach"),
                "Nickname": getattr(user, "nickname", ""),
            }

            # Dashboard entry point: prepend the page context to the prompt.
            if body.source_page:
                prompt = await _prepend_page_context(
                    prompt, body.source_page, body.page_context, user.site_id
                )

            # Session id for conversation reuse, when the service supports it.
            session_id = (
                svc.get_session_id(body.chat_id)
                if hasattr(svc, "get_session_id")
                else None
            )

            record_ai_call(
                app_id=config.app_id_1_chat,
                prompt_length=len(prompt),
                session_id=session_id or "",
            )

            async for chunk in client.call_app_stream(
                app_id=config.app_id_1_chat,
                prompt=prompt,
                session_id=session_id,
                biz_params=biz_params,
            ):
                full_reply_parts.append(chunk)
                tokens_total += 1
                # Trace every streamed chunk; each chunk is counted as one token.
                record_sse_token(token_count=1, total_tokens=tokens_total)
                yield f"event: message\ndata: {json.dumps({'token': chunk}, ensure_ascii=False)}\n\n"

            # Stream finished: persist the complete reply.
            full_reply = "".join(full_reply_parts)
            # Rough estimate: character count stands in for the token count.
            ai_msg_id, ai_created_at = svc._save_message(
                body.chat_id,
                "assistant",
                full_reply,
                tokens_used=len(full_reply),
            )
            reply_persisted = True
            svc._update_session_metadata(body.chat_id, full_reply)

            done_data = json.dumps(
                {"messageId": ai_msg_id, "createdAt": ai_created_at},
                ensure_ascii=False,
            )
            yield f"event: done\ndata: {done_data}\n\n"

            record_sse_end(
                total_tokens=tokens_total,
                total_duration_ms=(time.perf_counter() - started_at) * 1000,
                completed=True,
            )

        except Exception as e:
            logger.error("SSE 流式对话异常: %s", e, exc_info=True)
            record_ai_error(
                error_type=type(e).__name__,
                message=str(e),
            )

            # Keep whatever was streamed, unless the full reply is already saved.
            if full_reply_parts and not reply_persisted:
                partial = "".join(full_reply_parts)
                try:
                    svc._save_message(body.chat_id, "assistant", partial)
                    svc._update_session_metadata(body.chat_id, partial)
                except Exception:
                    logger.error("持久化部分回复失败", exc_info=True)

            error_data = json.dumps(
                {"message": "AI 服务暂时不可用"},
                ensure_ascii=False,
            )
            yield f"event: error\ndata: {error_data}\n\n"

            record_sse_end(
                total_tokens=tokens_total,
                total_duration_ms=(time.perf_counter() - started_at) * 1000,
                completed=False,
            )

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )


# ── CHAT-3: send message (synchronous reply) ──────────────────


@router.post("/{chat_id}/messages", response_model=SendMessageResponse)
@trace_service("发送消息", "Send message")
async def send_message(
    chat_id: int,
    body: SendMessageRequest,
    user: CurrentUser = Depends(require_approved()),
) -> SendMessageResponse:
    """CHAT-3: send a user message and return the synchronous AI reply.

    Per the original contract: a chat not owned by the user yields HTTP 403,
    and an AI failure is reported as an error reply message with HTTP 200
    (both handled inside ChatService.send_message_sync).

    Raises:
        HTTPException: 422 when the message content is empty/blank.
    """
    content = _require_content(body.content)
    svc = ChatService()
    result = await svc.send_message_sync(
        chat_id=chat_id,
        content=content,
        user_id=user.user_id,
        site_id=user.site_id,
    )
    return SendMessageResponse(
        user_message=MessageBrief(**result["user_message"]),
        ai_reply=MessageBrief(**result["ai_reply"]),
    )


# ── Helpers ───────────────────────────────────────────────────


def _require_content(content: str | None) -> str:
    """Return *content* stripped, or raise HTTP 422 if it is empty/blank."""
    stripped = (content or "").strip()
    if not stripped:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="消息内容不能为空",
        )
    return stripped


async def _prepend_page_context(
    prompt: str,
    source_page: str,
    page_context: dict | None,
    site_id,
) -> str:
    """Prefix *prompt* with the text context of the dashboard page, if available.

    Best-effort: any failure (including importing the fetcher) is logged as a
    warning and the prompt is returned unchanged.
    """
    try:
        from app.ai.data_fetchers import build_page_text

        # Work on a copy so pop() does not mutate the request body's dict.
        filters = dict(page_context) if page_context else {}
        context_id = filters.pop("contextId", None)
        page_text = await build_page_text(
            source_page=source_page,
            context_id=context_id,
            site_id=site_id,
            filters=filters if filters else None,
        )
    except Exception:
        logger.warning("页面上下文注入失败: source_page=%s", source_page, exc_info=True)
        return prompt

    if page_text:
        return f"[页面上下文: {source_page}]\n{page_text}\n\n{prompt}"
    return prompt


def _to_message_item(msg: dict) -> ChatMessageItem:
    """Convert a message dict returned by chat_service into a ChatMessageItem.

    ``reference_card`` is only materialised when it is a non-empty dict.
    """
    ref_card = msg.get("reference_card")
    reference_card = (
        ReferenceCard(**ref_card) if ref_card and isinstance(ref_card, dict) else None
    )
    return ChatMessageItem(
        id=msg["id"],
        role=msg["role"],
        content=msg["content"],
        created_at=msg["created_at"],
        reference_card=reference_card,
    )


def _get_dashscope_client(config: AIConfig | None = None) -> DashScopeClient:
    """Build a DashScopeClient from *config* (loaded from the environment if omitted).

    Errors from AIConfig.from_env() (e.g. missing configuration) are not
    handled here and propagate to the caller.
    """
    if config is None:
        config = AIConfig.from_env()
    return DashScopeClient(api_key=config.api_key, workspace_id=config.workspace_id)


def _build_prompt(chat_id: int, history_limit: int = _PROMPT_HISTORY_LIMIT) -> str:
    """Build the text prompt sent to the DashScope application.

    Only the most recent *history_limit* non-system messages of the
    conversation are loaded (the database does the limiting), oldest first.
    The application's system prompt is configured in the Bailian console, so
    only conversation content is sent from here.

    Returns:
        "" when there is no history; the single message's content when there
        is exactly one; otherwise the earlier messages as a labelled
        transcript followed by the last message as the current question.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # Take the newest N rows, then restore chronological order.
            cur.execute(
                """
                SELECT role, content
                FROM (
                    SELECT role, content, created_at
                    FROM biz.ai_messages
                    WHERE conversation_id = %s AND role != 'system'
                    ORDER BY created_at DESC
                    LIMIT %s
                ) AS recent
                ORDER BY created_at ASC
                """,
                (chat_id, history_limit),
            )
            recent = cur.fetchall()
    finally:
        conn.close()

    if not recent:
        return ""

    # The last row is the message just sent by the user.
    *earlier, (_, current_msg) = recent
    if not earlier:
        return current_msg

    context = "\n".join(
        f"{'用户' if role == 'user' else 'AI'}: {msg_content}"
        for role, msg_content in earlier
    )
    return f"[历史对话]\n{context}\n\n[当前问题]\n{current_msg}"