Files
Neo-ZQYY/apps/backend/app/routers/xcx_ai_chat.py
2026-03-15 10:15:02 +08:00

224 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
小程序 AI 对话路由 —— SSE 流式对话、历史对话列表、消息查询。
端点清单:
- POST /api/ai/chat/stream — SSE 流式对话
- GET /api/ai/conversations — 历史对话列表(分页)
- GET /api/ai/conversations/{conversation_id}/messages — 对话消息列表
"""
from __future__ import annotations
import json
import logging
import os
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from app.ai.bailian_client import BailianClient
from app.ai.conversation_service import ConversationService
from app.ai.apps.app1_chat import chat_stream
from app.ai.schemas import ChatStreamRequest, SSEEvent
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the AI chat endpoints; every route below is served under /api/ai.
router = APIRouter(prefix="/api/ai", tags=["小程序 AI 对话"])
# ── 辅助:获取用户 nickname ──────────────────────────────────
def _get_user_nickname(user_id: int) -> str:
    """Return the nickname stored in ``auth.users`` for ``user_id``.

    An empty string is returned when no row matches or when the stored
    nickname is NULL/empty. The connection obtained from ``get_connection``
    is closed before returning, whether or not the query succeeds.
    """
    db = get_connection()
    try:
        with db.cursor() as cursor:
            cursor.execute(
                "SELECT nickname FROM auth.users WHERE id = %s",
                (user_id,),
            )
            record = cursor.fetchone()
            # Guard clause: no matching user row.
            if not record:
                return ""
            # A NULL or empty nickname collapses to "".
            return record[0] or ""
    finally:
        db.close()
# ── 辅助:获取用户主要角色 ───────────────────────────────────
def _get_user_role_label(roles: list[str]) -> str:
    """Map a user's role names to the single label given to the AI as context.

    Labels are checked in priority order: manager roles win over
    assistant/coach roles, and anything else falls back to the generic
    user label.
    """
    # (label, role names that grant it), highest priority first.
    label_priority = (
        ("管理者", ("store_manager", "owner")),
        ("助教", ("assistant", "coach")),
    )
    for label, granting_roles in label_priority:
        if any(role_name in roles for role_name in granting_roles):
            return label
    return "用户"
# ── 辅助:构建 BailianClient 实例 ────────────────────────────
def _get_bailian_client() -> BailianClient:
    """Build a ``BailianClient`` from environment variables.

    Reads ``BAILIAN_API_KEY``, ``BAILIAN_BASE_URL`` and ``BAILIAN_MODEL``.

    Raises:
        RuntimeError: if any of the three variables is unset or empty.
    """
    client_kwargs = {
        "api_key": os.environ.get("BAILIAN_API_KEY"),
        "base_url": os.environ.get("BAILIAN_BASE_URL"),
        "model": os.environ.get("BAILIAN_MODEL"),
    }
    # Unset (None) and empty-string values are both treated as missing.
    if not all(client_kwargs.values()):
        raise RuntimeError(
            "百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
        )
    return BailianClient(**client_kwargs)
# ── SSE 流式对话 ─────────────────────────────────────────────
@router.post("/chat/stream")
async def ai_chat_stream(
    body: ChatStreamRequest,
    user: CurrentUser = Depends(get_current_user),
):
    """Stream an AI reply to the caller as Server-Sent Events.

    The trimmed user message is passed to ``chat_stream`` together with the
    caller's id, nickname, role label, site id and the page context fields
    from the request body. Each event it yields is written as one SSE frame:
    ``data: <json>`` followed by a blank line. The original notes list the
    event types as chunk (text fragment), done and error.

    Raises:
        HTTPException: 422 when the message is missing or whitespace only.
    """
    text = (body.message or "").strip()
    if not text:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="消息内容不能为空",
        )

    display_name = _get_user_nickname(user.user_id)
    role_name = _get_user_role_label(user.roles)
    llm_client = _get_bailian_client()
    conversation_service = ConversationService()

    def _frame(evt) -> str:
        """Serialize one event into a single SSE ``data:`` frame."""
        return f"data: {evt.model_dump_json()}\n\n"

    async def sse_frames():
        """Yield SSE frames for the reply, ending with an error frame on failure."""
        try:
            async for evt in chat_stream(
                message=text,
                user_id=user.user_id,
                nickname=display_name,
                role=role_name,
                site_id=user.site_id,
                source_page=body.source_page,
                page_context=body.page_context,
                screen_content=body.screen_content,
                bailian=llm_client,
                conv_svc=conversation_service,
            ):
                yield _frame(evt)
        except Exception as exc:
            # Last-resort handler: the response has already started, so the
            # failure is reported in-band as an SSE error event.
            # NOTE(review): str(exc) is sent to the client as-is; confirm this
            # cannot expose internal details.
            logger.error("SSE 生成器异常: %s", exc, exc_info=True)
            yield _frame(SSEEvent(type="error", message=str(exc)))

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # ask nginx not to buffer the stream
    }
    return StreamingResponse(
        sse_frames(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
# ── 历史对话列表 ─────────────────────────────────────────────
@router.get("/conversations")
async def list_conversations(
    page: int = 1,
    page_size: int = 20,
    user: CurrentUser = Depends(get_current_user),
):
    """List the current user's conversations, one page at a time.

    ``page`` values below 1 are treated as 1; a ``page_size`` outside
    1..100 is reset to 20. Ordering is whatever
    ``ConversationService.get_conversations`` returns (described as
    newest-first in the original notes). Each item carries a preview of the
    first ``user`` message, truncated to 50 characters, or ``None`` when the
    conversation has no user message.
    """
    page = max(page, 1)
    if not 1 <= page_size <= 100:
        page_size = 20

    service = ConversationService()

    def _first_user_preview(conversation_id):
        """Return up to 50 characters of the first user message, else None."""
        # NOTE(review): this loads a conversation's messages just to read one;
        # it runs once per listed conversation.
        first_user_msg = next(
            (m for m in service.get_messages(conversation_id) if m["role"] == "user"),
            None,
        )
        if first_user_msg is None:
            return None
        text = first_user_msg["content"] or ""
        return text if len(text) <= 50 else text[:50]

    rows = service.get_conversations(
        user_id=user.user_id,
        site_id=user.site_id,
        page=page,
        page_size=page_size,
    )

    return [
        {
            "id": row["id"],
            "app_id": row["app_id"],
            "source_page": row.get("source_page"),
            "created_at": row["created_at"],
            "first_message_preview": _first_user_preview(row["id"]),
        }
        for row in rows
    ]
# ── 对话消息列表 ─────────────────────────────────────────────
@router.get("/conversations/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: int,
    user: CurrentUser = Depends(get_current_user),
):
    """Return every message of one conversation owned by the current user.

    Ownership is checked first: the conversation must match both the
    caller's user id and site id, otherwise the request is rejected. Message
    ordering is whatever ``ConversationService.get_messages`` returns
    (described as oldest-first in the original notes).

    Raises:
        HTTPException: 404 when the conversation does not exist or belongs
            to a different user or site.
    """
    ownership_sql = (
        "\n"
        "SELECT id FROM biz.ai_conversations\n"
        "WHERE id = %s AND user_id = %s AND site_id = %s\n"
    )
    # user_id is bound as a string here — presumably the column is textual;
    # verify against the table definition.
    ownership_params = (conversation_id, str(user.user_id), user.site_id)

    db = get_connection()
    try:
        with db.cursor() as cursor:
            cursor.execute(ownership_sql, ownership_params)
            owned_row = cursor.fetchone()
            if not owned_row:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="对话不存在或无权访问",
                )
    finally:
        db.close()

    def _serialize(message):
        """Pick the fields exposed to the client for one message."""
        return {
            "id": message["id"],
            "role": message["role"],
            "content": message["content"],
            "tokens_used": message.get("tokens_used"),
            "created_at": message["created_at"],
        }

    service = ConversationService()
    return [_serialize(message) for message in service.get_messages(conversation_id)]