1
This commit is contained in:
59
apps/backend/app/routers/xcx_ai_cache.py
Normal file
59
apps/backend/app/routers/xcx_ai_cache.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
小程序 AI 缓存查询路由 —— 查询各 AI 应用的最新缓存结果。
|
||||
|
||||
端点清单:
|
||||
- GET /api/ai/cache/{cache_type}?target_id=xxx — 查询最新缓存
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.schemas import CacheTypeEnum
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/ai", tags=["小程序 AI 缓存"])
|
||||
|
||||
|
||||
@router.get("/cache/{cache_type}")
|
||||
async def get_ai_cache(
|
||||
cache_type: str,
|
||||
target_id: str = Query(..., description="目标 ID(member_id / assistant_id_member_id / 时间维度编码)"),
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""查询指定类型的最新 AI 缓存结果。
|
||||
|
||||
site_id 从 JWT 提取,强制过滤,确保门店隔离。
|
||||
"""
|
||||
# 校验 cache_type 合法性
|
||||
valid_types = {e.value for e in CacheTypeEnum}
|
||||
if cache_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"无效的 cache_type: {cache_type},合法值: {sorted(valid_types)}",
|
||||
)
|
||||
|
||||
cache_svc = AICacheService()
|
||||
result = cache_svc.get_latest(
|
||||
cache_type=cache_type,
|
||||
site_id=user.site_id,
|
||||
target_id=target_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": result.get("id"),
|
||||
"cache_type": result.get("cache_type"),
|
||||
"target_id": result.get("target_id"),
|
||||
"result_json": result.get("result_json"),
|
||||
"score": result.get("score"),
|
||||
"created_at": result.get("created_at"),
|
||||
}
|
||||
223
apps/backend/app/routers/xcx_ai_chat.py
Normal file
223
apps/backend/app/routers/xcx_ai_chat.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
小程序 AI 对话路由 —— SSE 流式对话、历史对话列表、消息查询。
|
||||
|
||||
端点清单:
|
||||
- POST /api/ai/chat/stream — SSE 流式对话
|
||||
- GET /api/ai/conversations — 历史对话列表(分页)
|
||||
- GET /api/ai/conversations/{conversation_id}/messages — 对话消息列表
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.apps.app1_chat import chat_stream
|
||||
from app.ai.schemas import ChatStreamRequest, SSEEvent
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.database import get_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/ai", tags=["小程序 AI 对话"])
|
||||
|
||||
|
||||
# ── 辅助:获取用户 nickname ──────────────────────────────────
|
||||
|
||||
|
||||
def _get_user_nickname(user_id: int) -> str:
|
||||
"""从 auth.users 查询用户 nickname,查不到返回空字符串。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT nickname FROM auth.users WHERE id = %s",
|
||||
(user_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return row[0] if row and row[0] else ""
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ── 辅助:获取用户主要角色 ───────────────────────────────────
|
||||
|
||||
|
||||
def _get_user_role_label(roles: list[str]) -> str:
|
||||
"""从角色列表提取主要角色标签,用于 AI 上下文。"""
|
||||
if "store_manager" in roles or "owner" in roles:
|
||||
return "管理者"
|
||||
if "assistant" in roles or "coach" in roles:
|
||||
return "助教"
|
||||
return "用户"
|
||||
|
||||
|
||||
# ── 辅助:构建 BailianClient 实例 ────────────────────────────
|
||||
|
||||
|
||||
def _get_bailian_client() -> BailianClient:
|
||||
"""从环境变量构建 BailianClient,缺失时报错。"""
|
||||
api_key = os.environ.get("BAILIAN_API_KEY")
|
||||
base_url = os.environ.get("BAILIAN_BASE_URL")
|
||||
model = os.environ.get("BAILIAN_MODEL")
|
||||
if not api_key or not base_url or not model:
|
||||
raise RuntimeError(
|
||||
"百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
|
||||
)
|
||||
return BailianClient(api_key=api_key, base_url=base_url, model=model)
|
||||
|
||||
|
||||
# ── SSE 流式对话 ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/chat/stream")
|
||||
async def ai_chat_stream(
|
||||
body: ChatStreamRequest,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""SSE 流式对话端点。
|
||||
|
||||
接收用户消息,通过百炼 API 流式返回 AI 回复。
|
||||
每个 SSE 事件格式:data: {json}\n\n
|
||||
事件类型:chunk(文本片段)/ done(完成)/ error(错误)
|
||||
"""
|
||||
if not body.message or not body.message.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="消息内容不能为空",
|
||||
)
|
||||
|
||||
nickname = _get_user_nickname(user.user_id)
|
||||
role_label = _get_user_role_label(user.roles)
|
||||
bailian = _get_bailian_client()
|
||||
conv_svc = ConversationService()
|
||||
|
||||
async def event_generator():
|
||||
"""SSE 事件生成器,逐事件 yield data: {json}\n\n 格式。"""
|
||||
try:
|
||||
async for event in chat_stream(
|
||||
message=body.message.strip(),
|
||||
user_id=user.user_id,
|
||||
nickname=nickname,
|
||||
role=role_label,
|
||||
site_id=user.site_id,
|
||||
source_page=body.source_page,
|
||||
page_context=body.page_context,
|
||||
screen_content=body.screen_content,
|
||||
bailian=bailian,
|
||||
conv_svc=conv_svc,
|
||||
):
|
||||
yield f"data: {event.model_dump_json()}\n\n"
|
||||
except Exception as e:
|
||||
# 兜底:生成器内部异常也以 SSE error 事件返回
|
||||
logger.error("SSE 生成器异常: %s", e, exc_info=True)
|
||||
error_event = SSEEvent(type="error", message=str(e))
|
||||
yield f"data: {error_event.model_dump_json()}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # nginx 禁用缓冲
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── 历史对话列表 ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/conversations")
|
||||
async def list_conversations(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""查询当前用户的历史对话列表,按时间倒序,分页。"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1 or page_size > 100:
|
||||
page_size = 20
|
||||
|
||||
conv_svc = ConversationService()
|
||||
conversations = conv_svc.get_conversations(
|
||||
user_id=user.user_id,
|
||||
site_id=user.site_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
# 为每条对话附加首条消息预览
|
||||
result = []
|
||||
for conv in conversations:
|
||||
item = {
|
||||
"id": conv["id"],
|
||||
"app_id": conv["app_id"],
|
||||
"source_page": conv.get("source_page"),
|
||||
"created_at": conv["created_at"],
|
||||
"first_message_preview": None,
|
||||
}
|
||||
# 查询首条 user 消息作为预览
|
||||
messages = conv_svc.get_messages(conv["id"])
|
||||
for msg in messages:
|
||||
if msg["role"] == "user":
|
||||
content = msg["content"] or ""
|
||||
item["first_message_preview"] = content[:50] if len(content) > 50 else content
|
||||
break
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── 对话消息列表 ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/conversations/{conversation_id}/messages")
|
||||
async def get_conversation_messages(
|
||||
conversation_id: int,
|
||||
user: CurrentUser = Depends(get_current_user),
|
||||
):
|
||||
"""查询指定对话的所有消息,按时间升序。
|
||||
|
||||
验证对话归属当前用户和 site_id,防止越权访问。
|
||||
"""
|
||||
# 先验证对话归属
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id FROM biz.ai_conversations
|
||||
WHERE id = %s AND user_id = %s AND site_id = %s
|
||||
""",
|
||||
(conversation_id, str(user.user_id), user.site_id),
|
||||
)
|
||||
if not cur.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="对话不存在或无权访问",
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
conv_svc = ConversationService()
|
||||
messages = conv_svc.get_messages(conversation_id)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": msg["id"],
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
"tokens_used": msg.get("tokens_used"),
|
||||
"created_at": msg["created_at"],
|
||||
}
|
||||
for msg in messages
|
||||
]
|
||||
Reference in New Issue
Block a user