224 lines
7.5 KiB
Python
224 lines
7.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
小程序 AI 对话路由 —— SSE 流式对话、历史对话列表、消息查询。
|
||
|
||
端点清单:
|
||
- POST /api/ai/chat/stream — SSE 流式对话
|
||
- GET /api/ai/conversations — 历史对话列表(分页)
|
||
- GET /api/ai/conversations/{conversation_id}/messages — 对话消息列表
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
from app.ai.bailian_client import BailianClient
|
||
from app.ai.conversation_service import ConversationService
|
||
from app.ai.apps.app1_chat import chat_stream
|
||
from app.ai.schemas import ChatStreamRequest, SSEEvent
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.database import get_connection
|
||
|
||
# All routes declared in this module are mounted under /api/ai and grouped
# under a single OpenAPI tag.
router = APIRouter(tags=["小程序 AI 对话"], prefix="/api/ai")

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ── 辅助:获取用户 nickname ──────────────────────────────────
|
||
|
||
|
||
def _get_user_nickname(user_id: int) -> str:
    """Look up the nickname stored for ``user_id`` in ``auth.users``.

    Returns an empty string when no row matches or when the stored
    nickname value is falsy (e.g. empty or NULL). The connection obtained
    from ``get_connection`` is always closed before this function returns.
    """
    connection = get_connection()
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                "SELECT nickname FROM auth.users WHERE id = %s",
                (user_id,),
            )
            found = cursor.fetchone()
    finally:
        connection.close()

    if not found:
        return ""
    return found[0] or ""
|
||
|
||
|
||
# ── 辅助:获取用户主要角色 ───────────────────────────────────
|
||
|
||
|
||
def _get_user_role_label(roles: list[str]) -> str:
    """Map a user's role list to the single role label passed to the AI context.

    Rules are checked in priority order: management roles win over
    assistant/coach roles, and anything else gets the generic label.
    """
    # (candidate role names, label) pairs, highest priority first.
    precedence = (
        (("store_manager", "owner"), "管理者"),
        (("assistant", "coach"), "助教"),
    )
    for candidates, label in precedence:
        if any(candidate in roles for candidate in candidates):
            return label
    return "用户"
|
||
|
||
|
||
# ── 辅助:构建 BailianClient 实例 ────────────────────────────
|
||
|
||
|
||
def _get_bailian_client() -> BailianClient:
    """Construct a BailianClient from environment configuration.

    Reads BAILIAN_API_KEY, BAILIAN_BASE_URL and BAILIAN_MODEL. A variable
    that is unset or set to an empty string counts as missing.

    Raises:
        RuntimeError: if any of the three variables is missing.
    """
    env = os.environ
    api_key, base_url, model = (
        env.get(name)
        for name in ("BAILIAN_API_KEY", "BAILIAN_BASE_URL", "BAILIAN_MODEL")
    )
    if api_key and base_url and model:
        return BailianClient(api_key=api_key, base_url=base_url, model=model)
    raise RuntimeError(
        "百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
    )
|
||
|
||
|
||
# ── SSE 流式对话 ─────────────────────────────────────────────
|
||
|
||
|
||
@router.post("/chat/stream")
async def ai_chat_stream(
    body: ChatStreamRequest,
    user: CurrentUser = Depends(get_current_user),
):
    """SSE streaming chat endpoint.

    Takes the user's message and streams the AI reply produced by
    ``chat_stream`` (backed by the Bailian API) as Server-Sent Events.
    Each event is written as ``data: <json>`` followed by a blank line.
    Event types: chunk (text fragment) / done (finished) / error.

    Raises:
        HTTPException: 422 if ``body.message`` is missing or whitespace-only.
        RuntimeError: if the Bailian environment variables are missing
            (raised by ``_get_bailian_client`` before streaming starts).
    """
    from fastapi.concurrency import run_in_threadpool

    message = (body.message or "").strip()
    if not message:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="消息内容不能为空",
        )

    # _get_user_nickname performs a synchronous DB round-trip; calling it
    # directly inside this coroutine would block the event loop, so it is
    # dispatched to the threadpool instead.
    nickname = await run_in_threadpool(_get_user_nickname, user.user_id)
    role_label = _get_user_role_label(user.roles)
    bailian = _get_bailian_client()
    conv_svc = ConversationService()

    async def event_generator():
        """Yield each chat_stream event as an SSE ``data:`` frame."""
        try:
            async for event in chat_stream(
                message=message,
                user_id=user.user_id,
                nickname=nickname,
                role=role_label,
                site_id=user.site_id,
                source_page=body.source_page,
                page_context=body.page_context,
                screen_content=body.screen_content,
                bailian=bailian,
                conv_svc=conv_svc,
            ):
                yield f"data: {event.model_dump_json()}\n\n"
        except Exception as e:
            # Last-resort handler: once the response has started streaming,
            # an HTTP error status can no longer be sent, so the failure is
            # reported to the client as an SSE "error" event instead.
            logger.exception("SSE 生成器异常: %s", e)
            # NOTE(review): str(e) is sent verbatim to the client and may
            # expose internal details; consider a generic message.
            error_event = SSEEvent(type="error", message=str(e))
            yield f"data: {error_event.model_dump_json()}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx response buffering
        },
    )
|
||
|
||
|
||
# ── 历史对话列表 ─────────────────────────────────────────────
|
||
|
||
|
||
def _build_conversation_list(user_id, site_id, page: int, page_size: int) -> list[dict]:
    """Blocking helper: load one page of conversations plus first-message previews."""
    conv_svc = ConversationService()
    conversations = conv_svc.get_conversations(
        user_id=user_id,
        site_id=site_id,
        page=page,
        page_size=page_size,
    )

    result = []
    for conv in conversations:
        item = {
            "id": conv["id"],
            "app_id": conv["app_id"],
            "source_page": conv.get("source_page"),
            "created_at": conv["created_at"],
            "first_message_preview": None,
        }
        # NOTE(review): this loads every message of each conversation just
        # to find the first "user" message (N+1 queries per page); a
        # dedicated preview query in ConversationService would be cheaper.
        messages = conv_svc.get_messages(conv["id"])
        first_user = next((m for m in messages if m["role"] == "user"), None)
        if first_user is not None:
            # Preview is at most the first 50 characters; None content -> "".
            item["first_message_preview"] = (first_user["content"] or "")[:50]
        result.append(item)
    return result


@router.get("/conversations")
async def list_conversations(
    page: int = 1,
    page_size: int = 20,
    user: CurrentUser = Depends(get_current_user),
):
    """List the current user's conversations for their site, one page at a time.

    Ordering is whatever ``ConversationService.get_conversations`` returns
    (the original documentation states newest first). Each item carries a
    preview of its first user message, truncated to 50 characters, or None
    when the conversation has no user message.

    Out-of-range paging values are silently normalized: ``page < 1``
    becomes 1, and a ``page_size`` outside 1..100 falls back to 20.
    """
    from fastapi.concurrency import run_in_threadpool

    if page < 1:
        page = 1
    if page_size < 1 or page_size > 100:
        page_size = 20

    # ConversationService is synchronous (its results are used without
    # await); run it in the threadpool so the event loop is not blocked.
    return await run_in_threadpool(
        _build_conversation_list, user.user_id, user.site_id, page, page_size
    )
|
||
|
||
|
||
# ── 对话消息列表 ─────────────────────────────────────────────
|
||
|
||
|
||
def _fetch_owned_messages(conversation_id: int, user_id, site_id) -> list[dict]:
    """Blocking helper: verify ownership of a conversation, then load its messages.

    Raises:
        HTTPException: 404 if the conversation does not exist or does not
            belong to ``user_id`` within ``site_id``.
    """
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # user_id is bound as a string here; presumably the
            # biz.ai_conversations.user_id column is textual — confirm.
            cur.execute(
                """
                SELECT id FROM biz.ai_conversations
                WHERE id = %s AND user_id = %s AND site_id = %s
                """,
                (conversation_id, str(user_id), site_id),
            )
            if not cur.fetchone():
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="对话不存在或无权访问",
                )
    finally:
        conn.close()

    messages = ConversationService().get_messages(conversation_id)
    return [
        {
            "id": msg["id"],
            "role": msg["role"],
            "content": msg["content"],
            "tokens_used": msg.get("tokens_used"),
            "created_at": msg["created_at"],
        }
        for msg in messages
    ]


@router.get("/conversations/{conversation_id}/messages")
async def get_conversation_messages(
    conversation_id: int,
    user: CurrentUser = Depends(get_current_user),
):
    """Return all messages of one conversation, in the order given by ConversationService.

    The conversation must belong to the current user and site; otherwise a
    404 is returned, so other users' conversations are not exposed. The
    original documentation states the messages are in ascending time order.

    Raises:
        HTTPException: 404 when the conversation is missing or not owned.
    """
    from fastapi.concurrency import run_in_threadpool

    # The ownership query and get_messages are synchronous DB calls; running
    # them in the threadpool keeps the event loop responsive. Exceptions
    # (including the 404) propagate out of run_in_threadpool unchanged.
    return await run_in_threadpool(
        _fetch_owned_messages, conversation_id, user.user_id, user.site_id
    )
|