# -*- coding: utf-8 -*-
"""
Mini-program AI chat routes: SSE streaming chat, conversation history list,
and per-conversation message listing.

Endpoints:
- POST /api/ai/chat/stream                              — SSE streaming chat
- GET  /api/ai/conversations                            — conversation history (paginated)
- GET  /api/ai/conversations/{conversation_id}/messages — messages of one conversation

The database helpers used here (``get_connection`` and ``ConversationService``)
are synchronous — they are called without ``await``. Calling them directly
inside an ``async def`` endpoint would block the event loop, and with it every
other in-flight request including open SSE streams, for the duration of the
query. The endpoints below therefore run them via ``run_in_threadpool``.
"""
from __future__ import annotations

import json
import logging
import os

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

from app.ai.bailian_client import BailianClient
from app.ai.conversation_service import ConversationService
from app.ai.apps.app1_chat import chat_stream
from app.ai.schemas import ChatStreamRequest, SSEEvent
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/ai", tags=["小程序 AI 对话"])

# Maximum number of characters of the first user message shown as a preview
# in the conversation list.
_PREVIEW_MAX_CHARS = 50


# ── Helper: look up the user's nickname ──────────────────────


def _get_user_nickname(user_id: int) -> str:
    """Return ``auth.users.nickname`` for *user_id*, or ``""`` if absent.

    Blocking (synchronous DB access) — call it from a worker thread when
    inside an async endpoint.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT nickname FROM auth.users WHERE id = %s",
                (user_id,),
            )
            row = cur.fetchone()
            # Missing user and NULL/empty nickname both collapse to "".
            return row[0] if row and row[0] else ""
    finally:
        conn.close()


# ── Helper: derive the user's primary role label ─────────────


def _get_user_role_label(roles: list[str]) -> str:
    """Map a role list to a single label used as context for the AI.

    Manager roles take precedence over assistant roles; anything else is
    reported as a generic user.
    """
    if "store_manager" in roles or "owner" in roles:
        return "管理者"
    if "assistant" in roles or "coach" in roles:
        return "助教"
    return "用户"


# ── Helper: build a BailianClient instance ───────────────────


def _get_bailian_client() -> BailianClient:
    """Build a ``BailianClient`` from environment variables.

    Raises:
        RuntimeError: if any of BAILIAN_API_KEY, BAILIAN_BASE_URL or
            BAILIAN_MODEL is unset or empty.
    """
    api_key = os.environ.get("BAILIAN_API_KEY")
    base_url = os.environ.get("BAILIAN_BASE_URL")
    model = os.environ.get("BAILIAN_MODEL")
    if not api_key or not base_url or not model:
        raise RuntimeError(
            "百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
        )
    return BailianClient(api_key=api_key, base_url=base_url, model=model)


# ── SSE streaming chat ───────────────────────────────────────


@router.post("/chat/stream")
async def ai_chat_stream(
    body: ChatStreamRequest,
    user: CurrentUser = Depends(get_current_user),
):
    """SSE streaming chat endpoint.

    Takes the user's message and streams the AI reply from the Bailian API.
    Each SSE event is framed as ``data: {json}\\n\\n``; event types are
    ``chunk`` (text fragment), ``done`` (finished) and ``error``.

    Raises:
        HTTPException(422): if the message is missing or whitespace-only.
        RuntimeError: if the Bailian environment variables are missing.
    """
    if not body.message or not body.message.strip():
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="消息内容不能为空",
        )

    # Blocking DB lookup: run it in the threadpool so the event loop stays free.
    nickname = await run_in_threadpool(_get_user_nickname, user.user_id)
    role_label = _get_user_role_label(user.roles)
    bailian = _get_bailian_client()
    conv_svc = ConversationService()
    message = body.message.strip()

    async def event_generator():
        """Yield each chat event as one SSE ``data: {json}\\n\\n`` frame."""
        try:
            async for event in chat_stream(
                message=message,
                user_id=user.user_id,
                nickname=nickname,
                role=role_label,
                site_id=user.site_id,
                source_page=body.source_page,
                page_context=body.page_context,
                screen_content=body.screen_content,
                bailian=bailian,
                conv_svc=conv_svc,
            ):
                yield f"data: {event.model_dump_json()}\n\n"
        except Exception as e:
            # Last-resort guard: the HTTP status is already sent once streaming
            # starts, so failures are reported to the client as an SSE error event.
            logger.error("SSE 生成器异常: %s", e, exc_info=True)
            error_event = SSEEvent(type="error", message=str(e))
            yield f"data: {error_event.model_dump_json()}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # tell nginx not to buffer the stream
        },
    )


# ── Conversation history list ────────────────────────────────


def _first_user_message_preview(messages) -> str | None:
    """Return the first ``_PREVIEW_MAX_CHARS`` chars of the first user message.

    Returns ``None`` when *messages* contains no message with role ``"user"``.
    A user message whose content is ``None`` yields ``""``.
    """
    for msg in messages:
        if msg["role"] == "user":
            # Slicing already handles strings shorter than the limit.
            return (msg["content"] or "")[:_PREVIEW_MAX_CHARS]
    return None


def _load_conversation_summaries(
    conv_svc: ConversationService,
    *,
    user_id: int,
    site_id,
    page: int,
    page_size: int,
) -> list[dict]:
    """Fetch one page of conversations and attach a first-message preview.

    Blocking (synchronous ``ConversationService`` calls) — run it in a
    worker thread when called from an async endpoint.
    """
    conversations = conv_svc.get_conversations(
        user_id=user_id,
        site_id=site_id,
        page=page,
        page_size=page_size,
    )
    # NOTE(review): this issues one get_messages() call per conversation and
    # loads the whole message list just to build a 50-char preview; a
    # dedicated "first user message" query would be cheaper if available.
    return [
        {
            "id": conv["id"],
            "app_id": conv["app_id"],
            "source_page": conv.get("source_page"),
            "created_at": conv["created_at"],
            "first_message_preview": _first_user_message_preview(
                conv_svc.get_messages(conv["id"])
            ),
        }
        for conv in conversations
    ]


@router.get("/conversations")
async def list_conversations(
    page: int = 1,
    page_size: int = 20,
    user: CurrentUser = Depends(get_current_user),
):
    """List the current user's conversations, newest first, paginated.

    Out-of-range inputs are normalized rather than rejected: ``page < 1``
    becomes 1, and ``page_size`` outside ``[1, 100]`` falls back to 20.
    """
    if page < 1:
        page = 1
    if page_size < 1 or page_size > 100:
        page_size = 20

    conv_svc = ConversationService()
    return await run_in_threadpool(
        _load_conversation_summaries,
        conv_svc,
        user_id=user.user_id,
        site_id=user.site_id,
        page=page,
        page_size=page_size,
    )


# ── Conversation message list ────────────────────────────────


def _conversation_is_owned(conversation_id: int, user_id: int, site_id) -> bool:
    """Return True if the conversation belongs to *user_id* within *site_id*.

    Blocking (synchronous DB access) — run it in a worker thread when called
    from an async endpoint.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # NOTE(review): user_id is bound as a string here (unlike the
            # integer bind in _get_user_nickname); presumably
            # biz.ai_conversations.user_id is a text column — confirm.
            cur.execute(
                """
                SELECT id FROM biz.ai_conversations
                WHERE id = %s AND user_id = %s AND site_id = %s
                """,
                (conversation_id, str(user_id), site_id),
            )
            return cur.fetchone() is not None
    finally:
        conn.close()


@router.get("/conversations/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: int,
    user: CurrentUser = Depends(get_current_user),
):
    """Return all messages of one conversation in chronological order.

    The conversation must belong to the current user and site_id; otherwise
    a 404 is returned so that other users' conversation ids are not revealed.

    Raises:
        HTTPException(404): if the conversation does not exist or is not
            owned by the current user/site.
    """
    owned = await run_in_threadpool(
        _conversation_is_owned, conversation_id, user.user_id, user.site_id
    )
    if not owned:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="对话不存在或无权访问",
        )

    conv_svc = ConversationService()
    messages = await run_in_threadpool(conv_svc.get_messages, conversation_id)
    return [
        {
            "id": msg["id"],
            "role": msg["role"],
            "content": msg["content"],
            "tokens_used": msg.get("tokens_used"),
            "created_at": msg["created_at"],
        }
        for msg in messages
    ]