This commit is contained in:
Neo
2026-03-15 10:15:02 +08:00
parent 2dd217522c
commit 72bb11b34f
916 changed files with 65306 additions and 16102803 deletions

View File

@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-
"""
小程序 AI 对话路由 —— SSE 流式对话、历史对话列表、消息查询。
端点清单:
- POST /api/ai/chat/stream — SSE 流式对话
- GET /api/ai/conversations — 历史对话列表(分页)
- GET /api/ai/conversations/{conversation_id}/messages — 对话消息列表
"""
from __future__ import annotations
import json
import logging
import os
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from app.ai.bailian_client import BailianClient
from app.ai.conversation_service import ConversationService
from app.ai.apps.app1_chat import chat_stream
from app.ai.schemas import ChatStreamRequest, SSEEvent
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["小程序 AI 对话"])
# ── 辅助:获取用户 nickname ──────────────────────────────────
def _get_user_nickname(user_id: int) -> str:
"""从 auth.users 查询用户 nickname查不到返回空字符串。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT nickname FROM auth.users WHERE id = %s",
(user_id,),
)
row = cur.fetchone()
return row[0] if row and row[0] else ""
finally:
conn.close()
# ── 辅助:获取用户主要角色 ───────────────────────────────────
def _get_user_role_label(roles: list[str]) -> str:
"""从角色列表提取主要角色标签,用于 AI 上下文。"""
if "store_manager" in roles or "owner" in roles:
return "管理者"
if "assistant" in roles or "coach" in roles:
return "助教"
return "用户"
# ── 辅助:构建 BailianClient 实例 ────────────────────────────
def _get_bailian_client() -> BailianClient:
"""从环境变量构建 BailianClient缺失时报错。"""
api_key = os.environ.get("BAILIAN_API_KEY")
base_url = os.environ.get("BAILIAN_BASE_URL")
model = os.environ.get("BAILIAN_MODEL")
if not api_key or not base_url or not model:
raise RuntimeError(
"百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
)
return BailianClient(api_key=api_key, base_url=base_url, model=model)
# ── SSE 流式对话 ─────────────────────────────────────────────
@router.post("/chat/stream")
async def ai_chat_stream(
body: ChatStreamRequest,
user: CurrentUser = Depends(get_current_user),
):
"""SSE 流式对话端点。
接收用户消息,通过百炼 API 流式返回 AI 回复。
每个 SSE 事件格式data: {json}\n\n
事件类型chunk文本片段/ done完成/ error错误
"""
if not body.message or not body.message.strip():
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="消息内容不能为空",
)
nickname = _get_user_nickname(user.user_id)
role_label = _get_user_role_label(user.roles)
bailian = _get_bailian_client()
conv_svc = ConversationService()
async def event_generator():
"""SSE 事件生成器,逐事件 yield data: {json}\n\n 格式。"""
try:
async for event in chat_stream(
message=body.message.strip(),
user_id=user.user_id,
nickname=nickname,
role=role_label,
site_id=user.site_id,
source_page=body.source_page,
page_context=body.page_context,
screen_content=body.screen_content,
bailian=bailian,
conv_svc=conv_svc,
):
yield f"data: {event.model_dump_json()}\n\n"
except Exception as e:
# 兜底:生成器内部异常也以 SSE error 事件返回
logger.error("SSE 生成器异常: %s", e, exc_info=True)
error_event = SSEEvent(type="error", message=str(e))
yield f"data: {error_event.model_dump_json()}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # nginx 禁用缓冲
},
)
# ── 历史对话列表 ─────────────────────────────────────────────
@router.get("/conversations")
async def list_conversations(
page: int = 1,
page_size: int = 20,
user: CurrentUser = Depends(get_current_user),
):
"""查询当前用户的历史对话列表,按时间倒序,分页。"""
if page < 1:
page = 1
if page_size < 1 or page_size > 100:
page_size = 20
conv_svc = ConversationService()
conversations = conv_svc.get_conversations(
user_id=user.user_id,
site_id=user.site_id,
page=page,
page_size=page_size,
)
# 为每条对话附加首条消息预览
result = []
for conv in conversations:
item = {
"id": conv["id"],
"app_id": conv["app_id"],
"source_page": conv.get("source_page"),
"created_at": conv["created_at"],
"first_message_preview": None,
}
# 查询首条 user 消息作为预览
messages = conv_svc.get_messages(conv["id"])
for msg in messages:
if msg["role"] == "user":
content = msg["content"] or ""
item["first_message_preview"] = content[:50] if len(content) > 50 else content
break
result.append(item)
return result
# ── 对话消息列表 ─────────────────────────────────────────────
@router.get("/conversations/{conversation_id}/messages")
async def get_conversation_messages(
conversation_id: int,
user: CurrentUser = Depends(get_current_user),
):
"""查询指定对话的所有消息,按时间升序。
验证对话归属当前用户和 site_id防止越权访问。
"""
# 先验证对话归属
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id FROM biz.ai_conversations
WHERE id = %s AND user_id = %s AND site_id = %s
""",
(conversation_id, str(user.user_id), user.site_id),
)
if not cur.fetchone():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在或无权访问",
)
finally:
conn.close()
conv_svc = ConversationService()
messages = conv_svc.get_messages(conversation_id)
return [
{
"id": msg["id"],
"role": msg["role"],
"content": msg["content"],
"tokens_used": msg.get("tokens_used"),
"created_at": msg["created_at"],
}
for msg in messages
]