feat: 累积功能变更 — 聊天集成、租户管理、小程序更新、ETL 增强、迁移脚本
包含多个会话的累积代码变更: - backend: AI 聊天服务、触发器调度、认证增强、WebSocket、调度器最小间隔 - admin-web: ETL 状态页、任务管理、调度配置、登录优化 - miniprogram: 看板页面、聊天集成、UI 组件、导航更新 - etl: DWS 新任务(finance_area_daily/board_cache)、连接器增强 - tenant-admin: 项目初始化 - db: 19 个迁移脚本(etl_feiqiu 11 + zqyy_app 8) - packages/shared: 枚举和工具函数更新 - tools: 数据库工具、报表生成、健康检查 - docs: PRD/架构/部署/合约文档更新 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,15 +12,19 @@ import json
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.data_fetchers import build_page_text
|
||||
from app.ai.schemas import SSEEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ID = "app1_chat"
|
||||
|
||||
# system prompt 总字符数上限
|
||||
_MAX_SYSTEM_PROMPT_LEN = 4000
|
||||
|
||||
|
||||
async def chat_stream(
|
||||
*,
|
||||
@@ -32,7 +36,7 @@ async def chat_stream(
|
||||
source_page: str | None = None,
|
||||
page_context: dict | None = None,
|
||||
screen_content: str | None = None,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
conv_svc: ConversationService,
|
||||
) -> AsyncGenerator[SSEEvent, None]:
|
||||
"""流式对话入口,返回 SSEEvent 异步生成器。
|
||||
@@ -76,11 +80,12 @@ async def chat_stream(
|
||||
)
|
||||
|
||||
# 3. 构建消息列表(system prompt + user message)
|
||||
messages = _build_messages(
|
||||
messages = await _build_messages(
|
||||
message=message,
|
||||
user_id=user_id,
|
||||
nickname=nickname,
|
||||
role=role,
|
||||
site_id=site_id,
|
||||
source_page=source_page,
|
||||
page_context=page_context,
|
||||
screen_content=screen_content,
|
||||
@@ -118,12 +123,13 @@ async def chat_stream(
|
||||
yield SSEEvent(type="error", message=str(e))
|
||||
|
||||
|
||||
def _build_messages(
|
||||
async def _build_messages(
|
||||
*,
|
||||
message: str,
|
||||
user_id: int | str,
|
||||
nickname: str,
|
||||
role: str,
|
||||
site_id: int,
|
||||
source_page: str | None,
|
||||
page_context: dict | None,
|
||||
screen_content: str | None,
|
||||
@@ -132,25 +138,38 @@ def _build_messages(
|
||||
|
||||
首条 system 消息注入页面上下文和用户信息。
|
||||
"""
|
||||
system_content = _build_system_prompt(
|
||||
system_content = await _build_system_prompt(
|
||||
user_id=user_id,
|
||||
nickname=nickname,
|
||||
role=role,
|
||||
site_id=site_id,
|
||||
source_page=source_page,
|
||||
page_context=page_context,
|
||||
screen_content=screen_content,
|
||||
)
|
||||
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
# system prompt 总字符数控制
|
||||
if len(content_str) > _MAX_SYSTEM_PROMPT_LEN:
|
||||
# 截断 page_context 中的 data_text
|
||||
pc = system_content.get("page_context", {})
|
||||
dt = pc.get("data_text", "")
|
||||
if dt and len(dt) > 500:
|
||||
pc["data_text"] = dt[:500] + "…(已截断)"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
|
||||
{"role": "system", "content": content_str},
|
||||
{"role": "user", "content": message},
|
||||
]
|
||||
|
||||
|
||||
def _build_system_prompt(
|
||||
async def _build_system_prompt(
|
||||
*,
|
||||
user_id: int | str,
|
||||
nickname: str,
|
||||
role: str,
|
||||
site_id: int,
|
||||
source_page: str | None,
|
||||
page_context: dict | None,
|
||||
screen_content: str | None,
|
||||
@@ -161,7 +180,12 @@ def _build_system_prompt(
|
||||
注入页面上下文供 AI 理解当前场景。
|
||||
"""
|
||||
prompt: dict = {
|
||||
"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。",
|
||||
"task": (
|
||||
"你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"
|
||||
"当 page_context 中包含 memberNickname、contextId 或 data_text 时,"
|
||||
"你必须直接使用这些信息回答问题,不要再向用户索要已有的信息。"
|
||||
"例如用户在客户详情页提问时,直接基于该客户的数据回答,无需要求提供会员 ID。"
|
||||
),
|
||||
"biz_params": {
|
||||
"user_prompt_params": {
|
||||
"User_ID": str(user_id),
|
||||
@@ -172,10 +196,11 @@ def _build_system_prompt(
|
||||
}
|
||||
|
||||
# 注入页面上下文(首条消息)
|
||||
page_ctx = _build_page_context(
|
||||
page_ctx = await _build_page_context(
|
||||
source_page=source_page,
|
||||
page_context=page_context,
|
||||
screen_content=screen_content,
|
||||
site_id=site_id,
|
||||
)
|
||||
if page_ctx:
|
||||
prompt["page_context"] = page_ctx
|
||||
@@ -183,25 +208,52 @@ def _build_system_prompt(
|
||||
return prompt
|
||||
|
||||
|
||||
async def _build_page_context(
    *,
    source_page: str | None,
    page_context: dict | None,
    screen_content: str | None,
    site_id: int,
) -> dict:
    """Assemble the page-context dict that is injected into the system prompt.

    Keys are added in this order, each only when it has a truthy value:
    ``source_page``, ``data_text``, ``page_context``, ``screen_content``.

    Args:
        source_page: Identifier of the page the user is on; forwarded to
            ``build_page_text`` as-is (it may be ``None``).
        page_context: Raw context sent by the front end. ``contextId`` and the
            board filter keys are read from it, and the whole dict is also
            passed through unchanged.
        screen_content: Text of what is on screen, passed through unchanged.
        site_id: Forwarded to ``build_page_text``.

    Returns:
        The collected context. It is empty only when none of the inputs (nor
        ``build_page_text``) produced a truthy value.

    NOTE(review): ``build_page_text`` is awaited even when ``source_page`` is
    empty; presumably it returns nothing in that case -- confirm. The raw
    ``page_context`` / ``screen_content`` are still forwarded then, so the
    result is not necessarily empty for an empty ``source_page``.
    """
    ctx: dict = {}

    if source_page:
        ctx["source_page"] = source_page

    # Pull the context id and the board filter parameters out of page_context.
    context_id = None
    filters: dict = {}
    if page_context:
        context_id = page_context.get("contextId")
        # Filter parameters are forwarded only when the front end sent them.
        for key in ("timeDimension", "areaFilter", "dimension", "typeFilter", "projectFilter"):
            if key in page_context:
                filters[key] = page_context[key]

    # Best effort: a failing fetch must not break the chat, so the error is
    # logged with its traceback and the data_text key is simply left out.
    try:
        data_text = await build_page_text(
            source_page=source_page,
            context_id=context_id,
            site_id=site_id,
            filters=filters if filters else None,
        )
        if data_text:
            ctx["data_text"] = data_text
    except Exception:
        logger.warning("页面上下文文本化失败: source_page=%s", source_page, exc_info=True)

    if page_context:
        ctx["page_context"] = page_context
    if screen_content:
        ctx["screen_content"] = screen_content

    return ctx
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import logging
|
||||
import os
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.prompts.app2_finance_prompt import build_prompt
|
||||
@@ -124,7 +124,7 @@ def compute_time_range(dimension: str, business_date: date) -> tuple[date, date]
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
|
||||
@@ -15,25 +15,42 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.data_fetchers import fetch_member_consumption_data
|
||||
from app.ai.schemas import CacheTypeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ID = "app3_clue"
|
||||
|
||||
# system message content 上限
|
||||
_MAX_SYSTEM_CONTENT_LEN = 8000
|
||||
|
||||
def build_prompt(
|
||||
|
||||
def _default_member_data() -> dict:
|
||||
"""数据获取失败时的默认空值。"""
|
||||
return {
|
||||
"member_nickname": "",
|
||||
"consumption_records": [],
|
||||
"member_cards": [],
|
||||
"card_balance_total": 0,
|
||||
"stored_value_balance_total": 0,
|
||||
"expected_visit_date": None,
|
||||
"days_since_last_visit": None,
|
||||
}
|
||||
|
||||
|
||||
async def build_prompt(
|
||||
context: dict,
|
||||
cache_svc: AICacheService | None = None,
|
||||
) -> list[dict]:
|
||||
"""构建 Prompt 消息列表。
|
||||
|
||||
P5-A 阶段:返回占位 Prompt,标注待细化字段。
|
||||
P5-B 阶段(P9-T1):补充 consumption_records 等完整数据。
|
||||
从 data_fetchers 获取真实消费数据,失败时降级为空值。
|
||||
|
||||
Args:
|
||||
context: 包含 site_id, member_id, nickname 等
|
||||
@@ -45,9 +62,28 @@ def build_prompt(
|
||||
site_id = context["site_id"]
|
||||
member_id = context["member_id"]
|
||||
|
||||
# 获取消费数据(失败时降级)
|
||||
data_fetch_failed = False
|
||||
try:
|
||||
member_data = await fetch_member_consumption_data(site_id, member_id)
|
||||
except Exception:
|
||||
logger.warning("App3 消费数据获取失败,使用默认空值: site_id=%s member_id=%s", site_id, member_id, exc_info=True)
|
||||
member_data = _default_member_data()
|
||||
data_fetch_failed = True
|
||||
|
||||
# 构建 reference:App6 线索 + 最近 2 套 App8 历史(附 generated_at)
|
||||
reference = _build_reference(site_id, member_id, cache_svc)
|
||||
|
||||
member_nickname = member_data.get("member_nickname", "")
|
||||
consumption_records = member_data.get("consumption_records", [])
|
||||
|
||||
# 空数据标注
|
||||
if not consumption_records:
|
||||
if data_fetch_failed:
|
||||
consumption_records = "⚠ 消费数据获取失败,该客户暂无消费记录可供分析"
|
||||
else:
|
||||
consumption_records = "该客户暂无消费记录"
|
||||
|
||||
system_content = {
|
||||
"task": "分析客户消费数据,提取维客线索。",
|
||||
"app_id": APP_ID,
|
||||
@@ -67,14 +103,28 @@ def build_prompt(
|
||||
}
|
||||
]
|
||||
},
|
||||
# TODO: P9-T1 细化 - consumption_records 等客户消费数据
|
||||
"data": {
|
||||
"consumption_records": "待 P9-T1 补充",
|
||||
"member_info": "待 P9-T1 补充",
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
"member_nickname": member_nickname,
|
||||
"main_data": {
|
||||
"consumption_records": consumption_records,
|
||||
"member_cards": member_data.get("member_cards", []),
|
||||
"card_balance_total": member_data.get("card_balance_total", 0),
|
||||
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
|
||||
"expected_visit_date": member_data.get("expected_visit_date"),
|
||||
"days_since_last_visit": member_data.get("days_since_last_visit"),
|
||||
},
|
||||
"reference": reference,
|
||||
}
|
||||
|
||||
# Token 预算控制:截断 consumption_records
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
records = system_content["main_data"].get("consumption_records")
|
||||
if isinstance(records, list) and len(records) > 5:
|
||||
system_content["main_data"]["consumption_records"] = records[:5]
|
||||
system_content["main_data"]["_truncated"] = f"消费记录已截断,原始共 {len(records)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
|
||||
user_content = (
|
||||
f"请分析会员 {member_id} 的消费数据,提取维客线索。"
|
||||
"每条线索包含 category、summary、detail、emoji 四个字段。"
|
||||
@@ -82,7 +132,7 @@ def build_prompt(
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
|
||||
{"role": "system", "content": content_str},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
@@ -134,7 +184,7 @@ def _build_reference(
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
@@ -162,7 +212,7 @@ async def run(
|
||||
nickname = context.get("nickname", "")
|
||||
|
||||
# 1. 构建 Prompt
|
||||
messages = build_prompt(context, cache_svc)
|
||||
messages = await build_prompt(context, cache_svc)
|
||||
|
||||
# 2. 创建对话记录
|
||||
conversation_id = conv_svc.create_conversation(
|
||||
|
||||
@@ -11,27 +11,50 @@ app_id = "app4_analysis"
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.data_fetchers import (
|
||||
fetch_assistant_info,
|
||||
fetch_member_consumption_data,
|
||||
fetch_member_notes,
|
||||
fetch_service_history,
|
||||
)
|
||||
from app.ai.schemas import CacheTypeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ID = "app4_analysis"
|
||||
|
||||
# system message content 上限
|
||||
_MAX_SYSTEM_CONTENT_LEN = 8000
|
||||
|
||||
def build_prompt(
|
||||
|
||||
def _default_member_data() -> dict:
|
||||
"""数据获取失败时的默认空值。"""
|
||||
return {
|
||||
"member_nickname": "",
|
||||
"consumption_records": [],
|
||||
"member_cards": [],
|
||||
"card_balance_total": 0,
|
||||
"stored_value_balance_total": 0,
|
||||
"expected_visit_date": None,
|
||||
"days_since_last_visit": None,
|
||||
}
|
||||
|
||||
|
||||
async def build_prompt(
|
||||
context: dict,
|
||||
cache_svc: AICacheService | None = None,
|
||||
) -> list[dict]:
|
||||
"""构建 Prompt 消息列表。
|
||||
|
||||
P5-A 阶段:返回占位 Prompt,标注待细化字段。
|
||||
P5-B 阶段(P6-T4):补充 service_history、assistant_info 等完整数据。
|
||||
并发获取助教信息、服务历史、客户消费数据、备注,部分失败不阻断。
|
||||
|
||||
Args:
|
||||
context: 包含 site_id, assistant_id, member_id
|
||||
@@ -44,10 +67,50 @@ def build_prompt(
|
||||
assistant_id = context["assistant_id"]
|
||||
member_id = context["member_id"]
|
||||
|
||||
# 并发获取 4 类数据,部分失败不阻断
|
||||
results = await asyncio.gather(
|
||||
fetch_assistant_info(site_id, assistant_id),
|
||||
fetch_service_history(site_id, assistant_id, member_id),
|
||||
fetch_member_consumption_data(site_id, member_id),
|
||||
fetch_member_notes(site_id, member_id),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# 降级处理
|
||||
fetch_errors: list[str] = []
|
||||
|
||||
if isinstance(results[0], Exception):
|
||||
logger.warning("App4 助教信息获取失败: %s", results[0])
|
||||
assistant_info = {}
|
||||
fetch_errors.append("助教信息获取失败")
|
||||
else:
|
||||
assistant_info = results[0]
|
||||
|
||||
if isinstance(results[1], Exception):
|
||||
logger.warning("App4 服务历史获取失败: %s", results[1])
|
||||
service_history: list = []
|
||||
fetch_errors.append("服务历史获取失败")
|
||||
else:
|
||||
service_history = results[1]
|
||||
|
||||
if isinstance(results[2], Exception):
|
||||
logger.warning("App4 消费数据获取失败: %s", results[2])
|
||||
member_data = _default_member_data()
|
||||
fetch_errors.append("消费数据获取失败")
|
||||
else:
|
||||
member_data = results[2]
|
||||
|
||||
if isinstance(results[3], Exception):
|
||||
logger.warning("App4 备注获取失败: %s", results[3])
|
||||
notes: list = []
|
||||
fetch_errors.append("备注获取失败")
|
||||
else:
|
||||
notes = results[3]
|
||||
|
||||
# 构建 reference:App8 最新 + 最近 2 套历史
|
||||
reference = _build_reference(site_id, member_id, cache_svc)
|
||||
|
||||
system_content = {
|
||||
system_content: dict = {
|
||||
"task": "分析助教与客户的关系,生成任务建议。",
|
||||
"app_id": APP_ID,
|
||||
"output_format": {
|
||||
@@ -55,14 +118,51 @@ def build_prompt(
|
||||
"action_suggestions": ["建议1", "建议2"],
|
||||
"one_line_summary": "一句话总结",
|
||||
},
|
||||
# TODO: P6-T4 细化 - service_history、assistant_info
|
||||
"data": {
|
||||
"service_history": "待 P6-T4 补充",
|
||||
"assistant_info": "待 P6-T4 补充",
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
"assistant_info": assistant_info if assistant_info else "⚠ 助教信息获取失败",
|
||||
"service_history": service_history if service_history else "暂无服务记录",
|
||||
"task_assignment_basis": {
|
||||
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
|
||||
"member_cards": member_data.get("member_cards", []),
|
||||
"card_balance_total": member_data.get("card_balance_total", 0),
|
||||
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
|
||||
"expected_visit_date": member_data.get("expected_visit_date"),
|
||||
"days_since_last_visit": member_data.get("days_since_last_visit"),
|
||||
},
|
||||
"customer_data": {
|
||||
"system_data": {
|
||||
"member_nickname": member_data.get("member_nickname", ""),
|
||||
},
|
||||
"notes": notes if notes else "暂无备注",
|
||||
},
|
||||
"reference": reference,
|
||||
}
|
||||
|
||||
if fetch_errors:
|
||||
system_content["_data_warnings"] = fetch_errors
|
||||
|
||||
# Token 预算控制
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
# 优先截断 service_history
|
||||
sh = system_content.get("service_history")
|
||||
if isinstance(sh, list) and len(sh) > 5:
|
||||
system_content["service_history"] = sh[:5]
|
||||
system_content["_truncated_service_history"] = f"服务记录已截断,原始共 {len(sh)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
records = system_content["task_assignment_basis"].get("consumption_records")
|
||||
if isinstance(records, list) and len(records) > 5:
|
||||
system_content["task_assignment_basis"]["consumption_records"] = records[:5]
|
||||
system_content["task_assignment_basis"]["_truncated"] = f"消费记录已截断,原始共 {len(records)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
n = system_content["customer_data"].get("notes")
|
||||
if isinstance(n, list) and len(n) > 10:
|
||||
system_content["customer_data"]["notes"] = n[:10]
|
||||
system_content["customer_data"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
|
||||
# 缓存不存在时在 user prompt 中标注
|
||||
no_history_hint = ""
|
||||
if not reference:
|
||||
@@ -75,7 +175,7 @@ def build_prompt(
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
|
||||
{"role": "system", "content": content_str},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
@@ -127,7 +227,7 @@ def _build_reference(
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
@@ -149,7 +249,7 @@ async def run(
|
||||
nickname = context.get("nickname", "")
|
||||
|
||||
# 1. 构建 Prompt
|
||||
messages = build_prompt(context, cache_svc)
|
||||
messages = await build_prompt(context, cache_svc)
|
||||
|
||||
# 2. 创建对话记录
|
||||
conversation_id = conv_svc.create_conversation(
|
||||
|
||||
@@ -10,27 +10,51 @@ app_id = "app5_tactics"
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.data_fetchers import (
|
||||
fetch_assistant_info,
|
||||
fetch_member_consumption_data,
|
||||
fetch_member_notes,
|
||||
fetch_service_history,
|
||||
)
|
||||
from app.ai.schemas import CacheTypeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ID = "app5_tactics"
|
||||
|
||||
# system message content 上限
|
||||
_MAX_SYSTEM_CONTENT_LEN = 8000
|
||||
|
||||
def build_prompt(
|
||||
|
||||
def _default_member_data() -> dict:
|
||||
"""数据获取失败时的默认空值。"""
|
||||
return {
|
||||
"member_nickname": "",
|
||||
"consumption_records": [],
|
||||
"member_cards": [],
|
||||
"card_balance_total": 0,
|
||||
"stored_value_balance_total": 0,
|
||||
"expected_visit_date": None,
|
||||
"days_since_last_visit": None,
|
||||
}
|
||||
|
||||
|
||||
async def build_prompt(
|
||||
context: dict,
|
||||
cache_svc: AICacheService | None = None,
|
||||
) -> list[dict]:
|
||||
"""构建 Prompt 消息列表。
|
||||
|
||||
P5-A 阶段:返回占位 Prompt,标注待细化字段。
|
||||
P5-B 阶段(P6-T4):补充 service_history、assistant_info(随 App4 同步)。
|
||||
复用 App4 的数据获取逻辑(并发获取助教信息、服务历史、消费数据、备注),
|
||||
额外从 context["app4_result"] 获取 task_suggestion。
|
||||
|
||||
Args:
|
||||
context: 包含 site_id, assistant_id, member_id, app4_result(dict)
|
||||
@@ -42,35 +66,117 @@ def build_prompt(
|
||||
site_id = context["site_id"]
|
||||
assistant_id = context["assistant_id"]
|
||||
member_id = context["member_id"]
|
||||
app4_result = context.get("app4_result", {})
|
||||
# App4 结果作为 task_suggestion,缺失时设为空对象
|
||||
task_suggestion = context.get("app4_result") or {}
|
||||
|
||||
# 并发获取 4 类数据,部分失败不阻断
|
||||
results = await asyncio.gather(
|
||||
fetch_assistant_info(site_id, assistant_id),
|
||||
fetch_service_history(site_id, assistant_id, member_id),
|
||||
fetch_member_consumption_data(site_id, member_id),
|
||||
fetch_member_notes(site_id, member_id),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# 降级处理
|
||||
fetch_errors: list[str] = []
|
||||
|
||||
if isinstance(results[0], Exception):
|
||||
logger.warning("App5 助教信息获取失败: %s", results[0])
|
||||
assistant_info = {}
|
||||
fetch_errors.append("助教信息获取失败")
|
||||
else:
|
||||
assistant_info = results[0]
|
||||
|
||||
if isinstance(results[1], Exception):
|
||||
logger.warning("App5 服务历史获取失败: %s", results[1])
|
||||
service_history: list = []
|
||||
fetch_errors.append("服务历史获取失败")
|
||||
else:
|
||||
service_history = results[1]
|
||||
|
||||
if isinstance(results[2], Exception):
|
||||
logger.warning("App5 消费数据获取失败: %s", results[2])
|
||||
member_data = _default_member_data()
|
||||
fetch_errors.append("消费数据获取失败")
|
||||
else:
|
||||
member_data = results[2]
|
||||
|
||||
if isinstance(results[3], Exception):
|
||||
logger.warning("App5 备注获取失败: %s", results[3])
|
||||
notes: list = []
|
||||
fetch_errors.append("备注获取失败")
|
||||
else:
|
||||
notes = results[3]
|
||||
|
||||
# 构建 reference:最近 2 套 App8 历史
|
||||
reference = _build_reference(site_id, member_id, cache_svc)
|
||||
|
||||
system_content = {
|
||||
"task": "基于关系分析和任务建议,生成沟通话术参考。",
|
||||
system_content: dict = {
|
||||
"task": (
|
||||
"基于关系分析和任务建议,生成沟通话术参考。"
|
||||
"输出必须严格遵循 output_format 中定义的 JSON 结构,"
|
||||
"每条话术必须包含 scenario(场景描述)和 script(话术内容)两个字段,"
|
||||
"禁止使用 content 或其他字段名替代。"
|
||||
),
|
||||
"app_id": APP_ID,
|
||||
"task_suggestion": app4_result,
|
||||
"task_suggestion": task_suggestion,
|
||||
"output_format": {
|
||||
"tactics": [
|
||||
{"scenario": "场景描述", "script": "话术内容"}
|
||||
]
|
||||
},
|
||||
# TODO: P6-T4 细化 - service_history、assistant_info(随 App4 同步)
|
||||
"data": {
|
||||
"service_history": "待 P6-T4 补充",
|
||||
"assistant_info": "待 P6-T4 补充",
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
"assistant_info": assistant_info if assistant_info else "⚠ 助教信息获取失败",
|
||||
"service_history": service_history if service_history else "暂无服务记录",
|
||||
"task_assignment_basis": {
|
||||
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
|
||||
"member_cards": member_data.get("member_cards", []),
|
||||
"card_balance_total": member_data.get("card_balance_total", 0),
|
||||
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
|
||||
"expected_visit_date": member_data.get("expected_visit_date"),
|
||||
"days_since_last_visit": member_data.get("days_since_last_visit"),
|
||||
},
|
||||
"customer_data": {
|
||||
"system_data": {
|
||||
"member_nickname": member_data.get("member_nickname", ""),
|
||||
},
|
||||
"notes": notes if notes else "暂无备注",
|
||||
},
|
||||
"reference": reference,
|
||||
}
|
||||
|
||||
if fetch_errors:
|
||||
system_content["_data_warnings"] = fetch_errors
|
||||
|
||||
# Token 预算控制
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
sh = system_content.get("service_history")
|
||||
if isinstance(sh, list) and len(sh) > 5:
|
||||
system_content["service_history"] = sh[:5]
|
||||
system_content["_truncated_service_history"] = f"服务记录已截断,原始共 {len(sh)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
records = system_content["task_assignment_basis"].get("consumption_records")
|
||||
if isinstance(records, list) and len(records) > 5:
|
||||
system_content["task_assignment_basis"]["consumption_records"] = records[:5]
|
||||
system_content["task_assignment_basis"]["_truncated"] = f"消费记录已截断,原始共 {len(records)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
n = system_content["customer_data"].get("notes")
|
||||
if isinstance(n, list) and len(n) > 10:
|
||||
system_content["customer_data"]["notes"] = n[:10]
|
||||
system_content["customer_data"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
|
||||
user_content = (
|
||||
f"请为助教 {assistant_id} 生成与会员 {member_id} 沟通的话术参考。"
|
||||
"返回 tactics 数组,每条包含 scenario 和 script 字段。"
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
|
||||
{"role": "system", "content": content_str},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
@@ -109,7 +215,7 @@ def _build_reference(
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
@@ -131,7 +237,7 @@ async def run(
|
||||
nickname = context.get("nickname", "")
|
||||
|
||||
# 1. 构建 Prompt
|
||||
messages = build_prompt(context, cache_svc)
|
||||
messages = await build_prompt(context, cache_svc)
|
||||
|
||||
# 2. 创建对话记录
|
||||
conversation_id = conv_svc.create_conversation(
|
||||
|
||||
@@ -13,27 +13,45 @@ app_id = "app6_note"
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.data_fetchers import fetch_member_consumption_data, fetch_member_notes
|
||||
from app.ai.schemas import CacheTypeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ID = "app6_note"
|
||||
|
||||
# system message content 上限
|
||||
_MAX_SYSTEM_CONTENT_LEN = 8000
|
||||
|
||||
def build_prompt(
|
||||
|
||||
def _default_member_data() -> dict:
|
||||
"""数据获取失败时的默认空值。"""
|
||||
return {
|
||||
"member_nickname": "",
|
||||
"consumption_records": [],
|
||||
"member_cards": [],
|
||||
"card_balance_total": 0,
|
||||
"stored_value_balance_total": 0,
|
||||
"expected_visit_date": None,
|
||||
"days_since_last_visit": None,
|
||||
}
|
||||
|
||||
|
||||
async def build_prompt(
|
||||
context: dict,
|
||||
cache_svc: AICacheService | None = None,
|
||||
) -> list[dict]:
|
||||
"""构建 Prompt 消息列表。
|
||||
|
||||
P5-A 阶段:返回占位 Prompt,标注待细化字段。
|
||||
P5-B 阶段(P9-T1):补充 consumption_data 等完整数据。
|
||||
并发获取消费数据和备注,失败时降级为空值。
|
||||
|
||||
Args:
|
||||
context: 包含 site_id, member_id, note_content, noted_by_name
|
||||
@@ -46,11 +64,47 @@ def build_prompt(
|
||||
member_id = context["member_id"]
|
||||
note_content = context.get("note_content", "")
|
||||
noted_by_name = context.get("noted_by_name", "")
|
||||
noted_by_created_at = context.get("noted_by_created_at", "")
|
||||
|
||||
# 并发获取消费数据和备注
|
||||
results = await asyncio.gather(
|
||||
fetch_member_consumption_data(site_id, member_id),
|
||||
fetch_member_notes(site_id, member_id),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
fetch_errors: list[str] = []
|
||||
|
||||
if isinstance(results[0], Exception):
|
||||
logger.warning("App6 消费数据获取失败: %s", results[0])
|
||||
member_data = _default_member_data()
|
||||
fetch_errors.append("消费数据获取失败")
|
||||
else:
|
||||
member_data = results[0]
|
||||
|
||||
if isinstance(results[1], Exception):
|
||||
logger.warning("App6 备注获取失败: %s", results[1])
|
||||
all_notes: list = []
|
||||
fetch_errors.append("备注获取失败")
|
||||
else:
|
||||
all_notes = results[1]
|
||||
|
||||
# 构建 reference:App3 线索 + 最近 2 套 App8 历史
|
||||
reference = _build_reference(site_id, member_id, cache_svc)
|
||||
|
||||
system_content = {
|
||||
# 将消费数据和备注注入 reference
|
||||
reference["member_nickname"] = member_data.get("member_nickname", "")
|
||||
reference["consumption_data"] = {
|
||||
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
|
||||
"member_cards": member_data.get("member_cards", []),
|
||||
"card_balance_total": member_data.get("card_balance_total", 0),
|
||||
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
|
||||
"expected_visit_date": member_data.get("expected_visit_date"),
|
||||
"days_since_last_visit": member_data.get("days_since_last_visit"),
|
||||
}
|
||||
reference["all_notes"] = all_notes if all_notes else []
|
||||
|
||||
system_content: dict = {
|
||||
"task": "分析备注内容,提取维客线索并评分。",
|
||||
"app_id": APP_ID,
|
||||
"rules": {
|
||||
@@ -73,15 +127,33 @@ def build_prompt(
|
||||
}
|
||||
],
|
||||
},
|
||||
"note_content": note_content,
|
||||
"noted_by_name": noted_by_name,
|
||||
# TODO: P9-T1 细化 - consumption_data 等客户消费数据
|
||||
"data": {
|
||||
"consumption_data": "待 P9-T1 补充",
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
"current_note": {
|
||||
"content": note_content,
|
||||
"recorded_by": noted_by_name,
|
||||
"created_at": noted_by_created_at,
|
||||
},
|
||||
"reference": reference,
|
||||
}
|
||||
|
||||
if fetch_errors:
|
||||
system_content["_data_warnings"] = fetch_errors
|
||||
|
||||
# Token 预算控制
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
records = system_content["reference"].get("consumption_data", {}).get("consumption_records")
|
||||
if isinstance(records, list) and len(records) > 5:
|
||||
system_content["reference"]["consumption_data"]["consumption_records"] = records[:5]
|
||||
system_content["reference"]["consumption_data"]["_truncated"] = f"消费记录已截断,原始共 {len(records)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
n = system_content["reference"].get("all_notes")
|
||||
if isinstance(n, list) and len(n) > 10:
|
||||
system_content["reference"]["all_notes"] = n[:10]
|
||||
system_content["reference"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
|
||||
user_content = (
|
||||
f"请分析以下备注内容,提取维客线索并评分。\n"
|
||||
f"备注提供人:{noted_by_name}\n"
|
||||
@@ -91,7 +163,7 @@ def build_prompt(
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
|
||||
{"role": "system", "content": content_str},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
@@ -143,7 +215,7 @@ def _build_reference(
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
@@ -164,7 +236,7 @@ async def run(
|
||||
nickname = context.get("nickname", "")
|
||||
|
||||
# 1. 构建 Prompt
|
||||
messages = build_prompt(context, cache_svc)
|
||||
messages = await build_prompt(context, cache_svc)
|
||||
|
||||
# 2. 创建对话记录
|
||||
conversation_id = conv_svc.create_conversation(
|
||||
|
||||
@@ -13,27 +13,45 @@ app_id = "app7_customer"
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.data_fetchers import fetch_member_consumption_data, fetch_member_notes
|
||||
from app.ai.schemas import CacheTypeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APP_ID = "app7_customer"
|
||||
|
||||
# system message content 上限
|
||||
_MAX_SYSTEM_CONTENT_LEN = 8000
|
||||
|
||||
def build_prompt(
|
||||
|
||||
def _default_member_data() -> dict:
|
||||
"""数据获取失败时的默认空值。"""
|
||||
return {
|
||||
"member_nickname": "",
|
||||
"consumption_records": [],
|
||||
"member_cards": [],
|
||||
"card_balance_total": 0,
|
||||
"stored_value_balance_total": 0,
|
||||
"expected_visit_date": None,
|
||||
"days_since_last_visit": None,
|
||||
}
|
||||
|
||||
|
||||
async def build_prompt(
|
||||
context: dict,
|
||||
cache_svc: AICacheService | None = None,
|
||||
) -> list[dict]:
|
||||
"""构建 Prompt 消息列表。
|
||||
|
||||
P5-A 阶段:返回占位 Prompt,标注待细化字段。
|
||||
P5-B 阶段(P9-T1):补充 objective_data 等完整数据。
|
||||
并发获取消费数据和备注,备注标注来源信息。
|
||||
|
||||
Args:
|
||||
context: 包含 site_id, member_id
|
||||
@@ -45,10 +63,46 @@ def build_prompt(
|
||||
site_id = context["site_id"]
|
||||
member_id = context["member_id"]
|
||||
|
||||
# 并发获取消费数据和备注
|
||||
results = await asyncio.gather(
|
||||
fetch_member_consumption_data(site_id, member_id),
|
||||
fetch_member_notes(site_id, member_id),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
fetch_errors: list[str] = []
|
||||
|
||||
if isinstance(results[0], Exception):
|
||||
logger.warning("App7 消费数据获取失败: %s", results[0])
|
||||
member_data = _default_member_data()
|
||||
fetch_errors.append("消费数据获取失败")
|
||||
else:
|
||||
member_data = results[0]
|
||||
|
||||
if isinstance(results[1], Exception):
|
||||
logger.warning("App7 备注获取失败: %s", results[1])
|
||||
notes_raw: list = []
|
||||
fetch_errors.append("备注获取失败")
|
||||
else:
|
||||
notes_raw = results[1]
|
||||
|
||||
# 备注标注来源信息
|
||||
if notes_raw:
|
||||
subjective_notes = []
|
||||
for note in notes_raw:
|
||||
recorded_by = note.get("recorded_by", "未知")
|
||||
annotated = dict(note)
|
||||
annotated["content"] = f"{note.get('content', '')}【来源:{recorded_by},请甄别信息真实性】"
|
||||
subjective_notes.append(annotated)
|
||||
else:
|
||||
subjective_notes = "该客户暂无主观备注信息"
|
||||
|
||||
member_nickname = member_data.get("member_nickname", "")
|
||||
|
||||
# 构建 reference:最新 + 最近 2 套 App8 历史
|
||||
reference = _build_reference(site_id, member_id, cache_svc)
|
||||
|
||||
system_content = {
|
||||
system_content: dict = {
|
||||
"task": "综合分析客户数据,生成运营策略建议。",
|
||||
"app_id": APP_ID,
|
||||
"rules": {
|
||||
@@ -62,13 +116,41 @@ def build_prompt(
|
||||
],
|
||||
"summary": "一句话总结",
|
||||
},
|
||||
# TODO: P9-T1 细化 - objective_data 等客户消费数据
|
||||
"data": {
|
||||
"objective_data": "待 P9-T1 补充",
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
"member_id": member_id,
|
||||
"member_nickname": member_nickname,
|
||||
"objective_data": {
|
||||
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
|
||||
"member_cards": member_data.get("member_cards", []),
|
||||
"card_balance_total": member_data.get("card_balance_total", 0),
|
||||
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
|
||||
"expected_visit_date": member_data.get("expected_visit_date"),
|
||||
"days_since_last_visit": member_data.get("days_since_last_visit"),
|
||||
},
|
||||
"subjective_data": {
|
||||
"notes": subjective_notes,
|
||||
},
|
||||
"reference": reference,
|
||||
}
|
||||
|
||||
if fetch_errors:
|
||||
system_content["_data_warnings"] = fetch_errors
|
||||
|
||||
# Token 预算控制
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
records = system_content["objective_data"].get("consumption_records")
|
||||
if isinstance(records, list) and len(records) > 5:
|
||||
system_content["objective_data"]["consumption_records"] = records[:5]
|
||||
system_content["objective_data"]["_truncated"] = f"消费记录已截断,原始共 {len(records)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
|
||||
n = system_content["subjective_data"].get("notes")
|
||||
if isinstance(n, list) and len(n) > 10:
|
||||
system_content["subjective_data"]["notes"] = n[:10]
|
||||
system_content["subjective_data"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)} 条"
|
||||
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
|
||||
|
||||
user_content = (
|
||||
f"请综合分析会员 {member_id} 的客户数据,生成运营策略建议。"
|
||||
"返回 strategies 数组(每条含 title 和 content)和 summary 字段。"
|
||||
@@ -76,7 +158,7 @@ def build_prompt(
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
|
||||
{"role": "system", "content": content_str},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
@@ -128,7 +210,7 @@ def _build_reference(
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
@@ -149,7 +231,7 @@ async def run(
|
||||
nickname = context.get("nickname", "")
|
||||
|
||||
# 1. 构建 Prompt
|
||||
messages = build_prompt(context, cache_svc)
|
||||
messages = await build_prompt(context, cache_svc)
|
||||
|
||||
# 2. 创建对话记录
|
||||
conversation_id = conv_svc.create_conversation(
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
|
||||
from app.ai.bailian_client import BailianClient
|
||||
from app.ai.dashscope_client import DashScopeClient
|
||||
from app.ai.cache_service import AICacheService
|
||||
from app.ai.conversation_service import ConversationService
|
||||
from app.ai.prompts.app8_consolidation_prompt import build_prompt
|
||||
@@ -120,7 +120,7 @@ def _determine_source(providers: str) -> str:
|
||||
|
||||
async def run(
|
||||
context: dict,
|
||||
bailian: BailianClient,
|
||||
client: DashScopeClient,
|
||||
cache_svc: AICacheService,
|
||||
conv_svc: ConversationService,
|
||||
) -> dict:
|
||||
|
||||
@@ -1,273 +0,0 @@
|
||||
"""百炼 API 统一封装层。
|
||||
|
||||
使用 openai Python SDK(百炼兼容 OpenAI 协议),提供流式和非流式两种调用模式。
|
||||
所有 AI 应用通过此客户端统一调用阿里云通义千问。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import openai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── 异常类 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class BailianApiError(Exception):
    """Bailian API call failed (raised once retries are exhausted).

    Attributes:
        status_code: HTTP status associated with the failure, or ``None``
            when no status is known.
    """

    def __init__(self, message: str, status_code: int | None = None):
        # Keep the status on the instance so callers can branch on it.
        self.status_code = status_code
        super().__init__(message)
|
||||
|
||||
|
||||
class BailianJsonParseError(Exception):
    """The content returned by the Bailian API could not be parsed as JSON.

    Attributes:
        raw_content: The unparsed response text (empty string by default).
    """

    def __init__(self, message: str, raw_content: str = ""):
        # Preserve the offending payload for diagnostics.
        self.raw_content = raw_content
        super().__init__(message)
|
||||
|
||||
|
||||
class BailianAuthError(BailianApiError):
    """The Bailian API key was rejected (HTTP 401).

    Always carries ``status_code == 401``.
    """

    def __init__(self, message: str = "API Key 无效或已过期"):
        unauthorized = 401
        super().__init__(message, status_code=unauthorized)
|
||||
|
||||
|
||||
# ── 客户端 ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class BailianClient:
    """Unified wrapper around the Bailian OpenAI-compatible API.

    Built on ``openai.AsyncOpenAI`` with ``base_url`` pointed at the Bailian
    endpoint. Two call modes are offered: streaming (``chat_stream``) and
    non-streaming JSON (``chat_json``).
    """

    # Retry configuration
    MAX_RETRIES = 3
    BASE_INTERVAL = 1  # seconds

    def __init__(self, api_key: str, base_url: str, model: str):
        """Create the underlying async OpenAI client.

        Args:
            api_key: Bailian API key (env var BAILIAN_API_KEY).
            base_url: Bailian API endpoint (env var BAILIAN_BASE_URL).
            model: Model identifier such as ``qwen-plus`` (env var BAILIAN_MODEL).
        """
        self.model = model
        self._client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)

    async def chat_stream(
        self,
        messages: list[dict],
        *,
        temperature: float = 0.7,
        max_tokens: int = 2000,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding each non-empty text chunk.

        Args:
            messages: Chat message list.
            temperature: Sampling temperature (default 0.7).
            max_tokens: Maximum number of tokens (default 2000).

        Yields:
            Text chunks in arrival order.
        """
        stream = await self._call_with_retry(
            model=self.model,
            messages=self._inject_current_time(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )
        async for chunk in stream:
            # Chunks without choices or without delta text carry nothing to emit.
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            if text:
                yield text

    async def chat_json(
        self,
        messages: list[dict],
        *,
        temperature: float = 0.3,
        max_tokens: int = 4000,
    ) -> tuple[dict, int]:
        """Run a non-streaming completion and return the parsed JSON.

        The request sets ``response_format={"type": "json_object"}``.

        Args:
            messages: Chat message list.
            temperature: Sampling temperature (default 0.3).
            max_tokens: Maximum number of tokens (default 4000).

        Returns:
            ``(parsed_json, tokens_used)``; ``tokens_used`` is 0 when the
            response carries no usage information.

        Raises:
            BailianJsonParseError: The response content is not valid JSON.
            BailianApiError: The API call failed (after retries).
        """
        response = await self._call_with_retry(
            model=self.model,
            messages=self._inject_current_time(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            response_format={"type": "json_object"},
        )
        raw_content = response.choices[0].message.content or ""
        usage = response.usage
        tokens_used = usage.total_tokens if usage else 0

        try:
            parsed = json.loads(raw_content)
        except (json.JSONDecodeError, TypeError) as e:
            logger.error("百炼 API 返回非法 JSON: %s", raw_content[:500])
            raise BailianJsonParseError(
                f"JSON 解析失败: {e}",
                raw_content=raw_content,
            ) from e
        return parsed, tokens_used

    def _inject_current_time(self, messages: list[dict]) -> list[dict]:
        """Return a copy of ``messages`` with ``current_time`` added to the first one.

        - The input is deep-copied; the caller's list is never modified.
        - If the first message's content is a JSON object, ``current_time`` is
          added to it.
        - If it is valid JSON but not an object, it is wrapped under
          ``original_content``.
        - If it is not JSON at all, it is wrapped under ``content``.
        - All other messages are left unchanged.
        - ``current_time`` is local wall-clock time formatted as
          ``YYYY-MM-DDTHH:MM:SS``.

        Args:
            messages: Original message list.

        Returns:
            A new message list (empty list for empty input).
        """
        if not messages:
            return []

        patched = copy.deepcopy(messages)
        head = patched[0]
        body = head.get("content", "")
        stamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

        try:
            decoded = json.loads(body)
        except (json.JSONDecodeError, TypeError):
            # Content is not JSON: wrap it in an object.
            payload = {"content": body, "current_time": stamp}
        else:
            if isinstance(decoded, dict):
                decoded["current_time"] = stamp
                payload = decoded
            else:
                # Valid JSON but not an object (array, string, ...): wrap it.
                payload = {"original_content": decoded, "current_time": stamp}

        head["content"] = json.dumps(payload, ensure_ascii=False)
        return patched

    async def _call_with_retry(self, **kwargs: Any) -> Any:
        """Call ``chat.completions.create`` with exponential-backoff retries.

        Policy:
        - At most ``MAX_RETRIES`` attempts.
        - Wait ``BASE_INTERVAL * 2**attempt`` seconds between attempts
          (1s, 2s with the defaults).
        - 400/401/403/404/422/429 errors are not retried and are re-raised
          as Bailian errors immediately.
        - Server errors, connection errors and timeouts are retried.

        Args:
            **kwargs: Forwarded unchanged to the openai client.

        Returns:
            The API response object (an async iterator when ``stream=True``).

        Raises:
            BailianAuthError: The API key was rejected (HTTP 401).
            BailianApiError: A non-retryable error, or retries exhausted.
        """
        last_error: Exception | None = None

        for attempt in range(self.MAX_RETRIES):
            try:
                # Streaming and non-streaming requests use the same entry point.
                return await self._client.chat.completions.create(**kwargs)

            except openai.AuthenticationError as e:
                # 401: invalid API key, never retried.
                logger.error("百炼 API 认证失败: %s", e)
                raise BailianAuthError(str(e)) from e

            except openai.BadRequestError as e:
                # 400: bad request parameters, never retried.
                logger.error("百炼 API 请求参数错误: %s", e)
                raise BailianApiError(str(e), status_code=400) from e

            except openai.RateLimitError as e:
                # 429: rate limited, treated like the other 4xx errors.
                logger.error("百炼 API 限流: %s", e)
                raise BailianApiError(str(e), status_code=429) from e

            except openai.PermissionDeniedError as e:
                # 403: permission denied, never retried.
                logger.error("百炼 API 权限不足: %s", e)
                raise BailianApiError(str(e), status_code=403) from e

            except openai.NotFoundError as e:
                # 404: resource missing, never retried.
                logger.error("百炼 API 资源不存在: %s", e)
                raise BailianApiError(str(e), status_code=404) from e

            except openai.UnprocessableEntityError as e:
                # 422: unprocessable request, never retried.
                logger.error("百炼 API 不可处理的请求: %s", e)
                raise BailianApiError(str(e), status_code=422) from e

            except (
                openai.InternalServerError,
                openai.APIConnectionError,
                openai.APITimeoutError,
            ) as e:
                # Server error / connection error / timeout: retry with backoff.
                last_error = e
                if attempt + 1 >= self.MAX_RETRIES:
                    logger.error(
                        "百炼 API 调用失败,已达最大重试次数 %d: %s",
                        self.MAX_RETRIES,
                        e,
                    )
                    break
                wait_time = self.BASE_INTERVAL * (2 ** attempt)
                logger.warning(
                    "百炼 API 调用失败(第 %d/%d 次),%ds 后重试: %s",
                    attempt + 1,
                    self.MAX_RETRIES,
                    wait_time,
                    e,
                )
                await asyncio.sleep(wait_time)

        # Retries exhausted.
        raise BailianApiError(
            f"百炼 API 调用失败(重试 {self.MAX_RETRIES} 次后): {last_error}",
            status_code=getattr(last_error, "status_code", None),
        ) from last_error
|
||||
101
apps/backend/app/ai/budget_tracker.py
Normal file
101
apps/backend/app/ai/budget_tracker.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Token 预算追踪器 — 从 ai_run_logs 聚合日/月 token 消耗。
|
||||
|
||||
每次 AI 调用前检查预算,超限时拒绝请求。
|
||||
日预算默认 100,000 tokens,月预算默认 2,000,000 tokens。
|
||||
|
||||
聚合数据通过构造函数注入的 callable 获取(解耦 AIRunLogService),
|
||||
callable 签名:() -> int,分别返回当日/当月已消耗 token 数。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Protocol
|
||||
|
||||
|
||||
class UsageProvider(Protocol):
    """Structural interface for objects that report token consumption."""

    def get_daily_usage(self) -> int:
        """Return the number of tokens consumed so far today."""
        ...

    def get_monthly_usage(self) -> int:
        """Return the number of tokens consumed so far this month."""
        ...
|
||||
|
||||
|
||||
@dataclass
class BudgetStatus:
    """Outcome of a budget check."""

    # True when the request may proceed.
    allowed: bool
    # Tokens consumed today at the time of the check.
    daily_used: int
    # Tokens consumed this month at the time of the check.
    monthly_used: int
    # "daily_exceeded" / "monthly_exceeded" / None
    reason: str | None = None
|
||||
|
||||
|
||||
class BudgetTracker:
    """Token budget tracker backed by injected usage readers.

    Usage figures can be supplied in either of two ways:
    1. a ``UsageProvider`` instance (for example AIRunLogService), or
    2. two callables, ``get_daily_usage`` and ``get_monthly_usage``.
    """

    def __init__(
        self,
        daily_limit: int = 100_000,
        monthly_limit: int = 2_000_000,
        *,
        get_daily_usage: Callable[[], int] | None = None,
        get_monthly_usage: Callable[[], int] | None = None,
        usage_provider: UsageProvider | None = None,
    ) -> None:
        """Store the limits and resolve the usage readers.

        Args:
            daily_limit: Daily token budget.
            monthly_limit: Monthly token budget.
            get_daily_usage: Callable returning today's consumed tokens.
            get_monthly_usage: Callable returning this month's consumed tokens.
            usage_provider: Object exposing both readers; takes precedence
                over the individual callables when given.

        Raises:
            ValueError: Neither a provider nor both callables were supplied.
        """
        self.daily_limit = daily_limit
        self.monthly_limit = monthly_limit

        # The provider wins over the standalone callables.
        if usage_provider is not None:
            daily_reader = usage_provider.get_daily_usage
            monthly_reader = usage_provider.get_monthly_usage
        elif get_daily_usage is None or get_monthly_usage is None:
            raise ValueError(
                "必须提供 usage_provider 或同时提供 "
                "get_daily_usage 和 get_monthly_usage callable"
            )
        else:
            daily_reader = get_daily_usage
            monthly_reader = get_monthly_usage

        self._get_daily_usage = daily_reader
        self._get_monthly_usage = monthly_reader

    def check_budget(self) -> BudgetStatus:
        """Read current usage and compare it with the limits.

        The daily limit is checked before the monthly one; reaching either
        limit yields ``allowed=False`` together with the matching reason.
        """
        daily_used = self._get_daily_usage()
        monthly_used = self._get_monthly_usage()

        if daily_used >= self.daily_limit:
            reason = "daily_exceeded"
        elif monthly_used >= self.monthly_limit:
            reason = "monthly_exceeded"
        else:
            reason = None

        return BudgetStatus(
            allowed=reason is None,
            daily_used=daily_used,
            monthly_used=monthly_used,
            reason=reason,
        )
|
||||
@@ -3,18 +3,38 @@ AI 缓存读写服务。
|
||||
|
||||
负责 biz.ai_cache 表的 CRUD 和保留策略管理。
|
||||
所有查询和写入操作强制 site_id 隔离。
|
||||
|
||||
P14 改造:
|
||||
- 新增 status 字段处理(valid/expired/invalidated/generating)
|
||||
- 查询仅返回 status='valid' 且未过期的记录
|
||||
- 按 App 类型设置过期时间
|
||||
- 每 App 保留最新 20,000 条
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.database import get_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期策略:cache_type → 过期天数(0 表示当日 23:59:59)
|
||||
CACHE_EXPIRY_DAYS: dict[str, int] = {
|
||||
"app2_finance": 0, # 当日 23:59:59
|
||||
"app3_clue": 7,
|
||||
"app4_analysis": 7,
|
||||
"app5_tactics": 7,
|
||||
"app6_note_analysis": 30,
|
||||
"app7_customer_analysis": 7,
|
||||
"app8_clue_consolidated": 7,
|
||||
}
|
||||
|
||||
# 每 App 保留上限
|
||||
CACHE_MAX_PER_APP = 20_000
|
||||
|
||||
|
||||
class AICacheService:
|
||||
"""AI 缓存读写服务。"""
|
||||
@@ -25,9 +45,9 @@ class AICacheService:
|
||||
site_id: int,
|
||||
target_id: str,
|
||||
) -> dict | None:
|
||||
"""查询最新缓存记录。
|
||||
"""查询最新有效缓存记录。
|
||||
|
||||
按 (cache_type, site_id, target_id) 查询 created_at 最新的一条。
|
||||
仅返回 status='valid' 且未过期的记录。
|
||||
无记录时返回 None。
|
||||
"""
|
||||
conn = get_connection()
|
||||
@@ -37,9 +57,11 @@ class AICacheService:
|
||||
"""
|
||||
SELECT id, cache_type, site_id, target_id,
|
||||
result_json, score, triggered_by,
|
||||
created_at, expires_at
|
||||
created_at, expires_at, status
|
||||
FROM biz.ai_cache
|
||||
WHERE cache_type = %s AND site_id = %s AND target_id = %s
|
||||
AND (status = 'valid' OR status IS NULL)
|
||||
AND (expires_at IS NULL OR expires_at > now())
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
@@ -95,7 +117,15 @@ class AICacheService:
|
||||
score: int | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> int:
|
||||
"""写入缓存记录,返回 id。写入后清理超限记录。"""
|
||||
"""写入缓存记录,返回 id。
|
||||
|
||||
自动设置 status='valid' 和按 App 类型计算 expires_at。
|
||||
写入后清理超限记录(每 App 保留 20,000 条)。
|
||||
"""
|
||||
# 自动计算过期时间(如果未显式指定)
|
||||
if expires_at is None:
|
||||
expires_at = self._calc_expires_at(cache_type)
|
||||
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
@@ -103,8 +133,8 @@ class AICacheService:
|
||||
"""
|
||||
INSERT INTO biz.ai_cache
|
||||
(cache_type, site_id, target_id, result_json,
|
||||
triggered_by, score, expires_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
triggered_by, score, expires_at, status)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, 'valid')
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
@@ -126,7 +156,7 @@ class AICacheService:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# 写入成功后清理超限记录(失败仅记录警告,不影响写入结果)
|
||||
# 写入成功后清理超限记录
|
||||
try:
|
||||
deleted = self._cleanup_excess(cache_type, site_id, target_id)
|
||||
if deleted > 0:
|
||||
@@ -143,12 +173,89 @@ class AICacheService:
|
||||
|
||||
return cache_id
|
||||
|
||||
def set_generating(
    self,
    cache_type: str,
    site_id: int,
    target_id: str,
    triggered_by: str | None = None,
) -> int:
    """Insert a placeholder row with status 'generating' and return its id.

    The row is written with an empty JSON object as its result; call
    ``finalize_cache`` afterwards to fill it in. Any failure rolls the
    transaction back and re-raises; the connection is always closed.
    """
    insert_sql = """
        INSERT INTO biz.ai_cache
        (cache_type, site_id, target_id, result_json, status, triggered_by)
        VALUES (%s, %s, %s, '{}', 'generating', %s)
        RETURNING id
        """
    params = (cache_type, site_id, target_id, triggered_by)

    connection = get_connection()
    try:
        with connection.cursor() as cursor:
            cursor.execute(insert_sql, params)
            inserted = cursor.fetchone()
        connection.commit()
        return inserted[0]
    except Exception:
        connection.rollback()
        raise
    finally:
        connection.close()
|
||||
|
||||
def finalize_cache(
    self,
    cache_id: int,
    result_json: dict,
    score: int | None = None,
    cache_type: str | None = None,
) -> None:
    """Promote a 'generating' row to 'valid', storing the result and expiry.

    Only a row whose id matches and whose status is still 'generating' is
    updated. The expiry is derived from ``cache_type``; when ``cache_type``
    is falsy the expiry is written as NULL. Any failure rolls the
    transaction back and re-raises; the connection is always closed.
    """
    if cache_type:
        expires_at = self._calc_expires_at(cache_type)
    else:
        expires_at = None

    update_sql = """
        UPDATE biz.ai_cache
        SET result_json = %s, score = %s, status = 'valid', expires_at = %s
        WHERE id = %s AND status = 'generating'
        """

    connection = get_connection()
    try:
        with connection.cursor() as cursor:
            serialized = json.dumps(result_json, ensure_ascii=False)
            cursor.execute(update_sql, (serialized, score, expires_at, cache_id))
        connection.commit()
    except Exception:
        connection.rollback()
        raise
    finally:
        connection.close()
|
||||
|
||||
@staticmethod
def _calc_expires_at(cache_type: str | None) -> datetime | None:
    """Compute the expiry timestamp for a cache type.

    Returns ``None`` when ``cache_type`` is ``None`` or has no entry in
    ``CACHE_EXPIRY_DAYS``. A configured value of 0 means "today at 23:59:59
    in UTC+8"; any other value means "now plus that many days". The result
    is a timezone-aware UTC datetime.
    """
    ttl_days = None if cache_type is None else CACHE_EXPIRY_DAYS.get(cache_type)
    if ttl_days is None:
        return None

    utc_now = datetime.now(timezone.utc)
    if ttl_days != 0:
        return utc_now + timedelta(days=ttl_days)

    # Shift into UTC+8 wall-clock time, snap to the end of that day,
    # then shift back so the returned instant is expressed in UTC.
    utc8_offset = timedelta(hours=8)
    shifted = utc_now + utc8_offset
    day_end = shifted.replace(hour=23, minute=59, second=59, microsecond=0)
    return day_end - utc8_offset
|
||||
|
||||
def _cleanup_excess(
|
||||
self,
|
||||
cache_type: str,
|
||||
site_id: int,
|
||||
target_id: str,
|
||||
max_count: int = 500,
|
||||
max_count: int = CACHE_MAX_PER_APP,
|
||||
) -> int:
|
||||
"""清理超限记录,保留最近 max_count 条,返回删除数量。"""
|
||||
conn = get_connection()
|
||||
|
||||
116
apps/backend/app/ai/circuit_breaker.py
Normal file
116
apps/backend/app/ai/circuit_breaker.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""熔断器 — 按 app_id 独立的断路保护。
|
||||
|
||||
状态机:CLOSED → OPEN(连续失败达阈值)→ HALF_OPEN(超时后探测)→ CLOSED/OPEN。
|
||||
内存实现,单实例部署,不依赖外部存储。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
class CircuitState(enum.Enum):
    """States of the circuit breaker."""

    # Normal operation: requests pass through.
    CLOSED = "closed"
    # Tripped: requests are rejected.
    OPEN = "open"
    # Probing: requests are let through to test recovery.
    HALF_OPEN = "half_open"
|
||||
|
||||
|
||||
@dataclass
class _BreakerState:
    """Mutable breaker bookkeeping for a single app_id."""

    # Current position in the state machine.
    state: CircuitState = CircuitState.CLOSED
    # Failure counter used against the trip threshold.
    failure_count: int = 0
    # time.monotonic() reading of the most recent recorded failure.
    last_failure_time: float = 0.0
    # time.monotonic() reading of the most recent state transition.
    last_state_change: float = field(default_factory=time.monotonic)
|
||||
|
||||
|
||||
class CircuitBreaker:
    """Circuit breaker that keeps independent state per app_id.

    - ``check()``: report the current state; an OPEN breaker whose recovery
      timeout has elapsed moves to HALF_OPEN.
    - ``record_success()``: HALF_OPEN -> CLOSED; in CLOSED resets the
      failure counter.
    - ``record_failure()``: trips to OPEN once the failure counter reaches
      the threshold; a failure in HALF_OPEN re-opens immediately.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
    ) -> None:
        """Configure the trip threshold and the recovery timeout (seconds)."""
        self._breakers: dict[str, _BreakerState] = {}
        self._failure_threshold = failure_threshold
        self._recovery_timeout = recovery_timeout

    def _get_state(self, app_id: str) -> _BreakerState:
        """Return the state record for ``app_id``, creating it on first use."""
        try:
            return self._breakers[app_id]
        except KeyError:
            created = self._breakers[app_id] = _BreakerState()
            return created

    def check(self, app_id: str) -> CircuitState:
        """Return the breaker state that applies to the next request.

        - CLOSED / HALF_OPEN: returned as-is (request allowed).
        - OPEN, timeout not yet elapsed: OPEN (request rejected).
        - OPEN, timeout elapsed: switches to HALF_OPEN and returns it.
        """
        breaker = self._get_state(app_id)
        if breaker.state != CircuitState.OPEN:
            return breaker.state

        # OPEN: see whether enough time has passed since the last failure.
        waited = time.monotonic() - breaker.last_failure_time
        if waited < self._recovery_timeout:
            return CircuitState.OPEN

        # Timeout elapsed: allow probing.
        breaker.state = CircuitState.HALF_OPEN
        breaker.last_state_change = time.monotonic()
        return CircuitState.HALF_OPEN

    def record_success(self, app_id: str) -> None:
        """Record a successful call.

        - HALF_OPEN -> CLOSED (probe succeeded) with the counter reset.
        - CLOSED: the failure counter is reset.
        - OPEN: nothing changes.
        """
        breaker = self._get_state(app_id)
        if breaker.state == CircuitState.OPEN:
            return

        if breaker.state == CircuitState.HALF_OPEN:
            breaker.state = CircuitState.CLOSED
            breaker.last_state_change = time.monotonic()
        breaker.failure_count = 0

    def record_failure(self, app_id: str) -> None:
        """Record a failed call.

        - CLOSED: increments the counter; reaching the threshold -> OPEN.
        - HALF_OPEN: probe failed -> OPEN again, counter pinned at the threshold.
        - OPEN: nothing changes.
        """
        breaker = self._get_state(app_id)
        now = time.monotonic()
        if breaker.state == CircuitState.OPEN:
            return

        if breaker.state == CircuitState.HALF_OPEN:
            # A failed probe re-trips the breaker straight away.
            breaker.failure_count = self._failure_threshold
            should_trip = True
        else:
            breaker.failure_count += 1
            should_trip = breaker.failure_count >= self._failure_threshold

        breaker.last_failure_time = now
        if should_trip:
            breaker.state = CircuitState.OPEN
            breaker.last_state_change = now
|
||||
68
apps/backend/app/ai/config.py
Normal file
68
apps/backend/app/ai/config.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""AI 模块配置 — 从环境变量加载 DashScope 相关参数。
|
||||
|
||||
所有 DASHSCOPE_* 环境变量和 INTERNAL_API_TOKEN 统一在此管理,
|
||||
启动时通过 from_env() 校验必需变量,缺失立即报错。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AIConfig:
    """AI module configuration loaded from environment variables (immutable)."""

    api_key: str  # DASHSCOPE_API_KEY
    workspace_id: str | None  # DASHSCOPE_WORKSPACE_ID (optional)
    app_id_1_chat: str  # DASHSCOPE_APP_ID_1_CHAT
    app_id_2_finance: str  # DASHSCOPE_APP_ID_2_FINANCE
    app_id_3_clue: str  # DASHSCOPE_APP_ID_3_CLUE
    app_id_4_analysis: str  # DASHSCOPE_APP_ID_4_ANALYSIS
    app_id_5_tactics: str  # DASHSCOPE_APP_ID_5_TACTICS
    app_id_6_note: str  # DASHSCOPE_APP_ID_6_NOTE
    app_id_7_customer: str  # DASHSCOPE_APP_ID_7_CUSTOMER
    app_id_8_consolidate: str  # DASHSCOPE_APP_ID_8_CONSOLIDATE
    internal_api_token: str  # INTERNAL_API_TOKEN

    @classmethod
    def from_env(cls) -> "AIConfig":
        """Build the configuration from the process environment.

        A required variable that is unset or empty counts as missing; all
        missing names are reported together in a single ``ValueError``.
        The optional DASHSCOPE_WORKSPACE_ID becomes ``None`` when unset or
        empty.

        Raises:
            ValueError: One or more required variables are missing.
        """
        # (environment variable, dataclass field) pairs, in reporting order.
        required = (
            ("DASHSCOPE_API_KEY", "api_key"),
            ("DASHSCOPE_APP_ID_1_CHAT", "app_id_1_chat"),
            ("DASHSCOPE_APP_ID_2_FINANCE", "app_id_2_finance"),
            ("DASHSCOPE_APP_ID_3_CLUE", "app_id_3_clue"),
            ("DASHSCOPE_APP_ID_4_ANALYSIS", "app_id_4_analysis"),
            ("DASHSCOPE_APP_ID_5_TACTICS", "app_id_5_tactics"),
            ("DASHSCOPE_APP_ID_6_NOTE", "app_id_6_note"),
            ("DASHSCOPE_APP_ID_7_CUSTOMER", "app_id_7_customer"),
            ("DASHSCOPE_APP_ID_8_CONSOLIDATE", "app_id_8_consolidate"),
            ("INTERNAL_API_TOKEN", "internal_api_token"),
        )

        loaded = {attr: os.environ.get(env_name) for env_name, attr in required}
        # None and the empty string both count as missing.
        absent = [env_name for env_name, attr in required if not loaded[attr]]
        if absent:
            raise ValueError(
                f"AI 配置缺失必需环境变量: {', '.join(absent)}"
            )

        return cls(
            workspace_id=os.environ.get("DASHSCOPE_WORKSPACE_ID") or None,
            **loaded,
        )
|
||||
@@ -27,6 +27,7 @@ class ConversationService:
|
||||
site_id: int,
|
||||
source_page: str | None = None,
|
||||
source_context: dict | None = None,
|
||||
title: str | None = None,
|
||||
) -> int:
|
||||
"""创建对话记录,返回 conversation_id。
|
||||
|
||||
@@ -38,8 +39,8 @@ class ConversationService:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO biz.ai_conversations
|
||||
(user_id, nickname, app_id, site_id, source_page, source_context)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
(user_id, nickname, app_id, site_id, source_page, source_context, title)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(
|
||||
@@ -49,6 +50,7 @@ class ConversationService:
|
||||
site_id,
|
||||
source_page,
|
||||
json.dumps(source_context, ensure_ascii=False) if source_context else None,
|
||||
title,
|
||||
),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
@@ -89,6 +91,22 @@ class ConversationService:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def update_title(self, conversation_id: int, title: str) -> None:
    """Set the title of the conversation row with the given id.

    Any failure rolls the transaction back and re-raises; the connection
    is always closed.
    """
    update_sql = "UPDATE biz.ai_conversations SET title = %s WHERE id = %s"
    connection = get_connection()
    try:
        with connection.cursor() as cursor:
            cursor.execute(update_sql, (title, conversation_id))
        connection.commit()
    except Exception:
        connection.rollback()
        raise
    finally:
        connection.close()
|
||||
|
||||
def get_conversations(
|
||||
self,
|
||||
user_id: int | str,
|
||||
@@ -104,7 +122,7 @@ class ConversationService:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, user_id, nickname, app_id, site_id,
|
||||
source_page, source_context, created_at
|
||||
source_page, source_context, title, created_at
|
||||
FROM biz.ai_conversations
|
||||
WHERE user_id = %s AND site_id = %s
|
||||
ORDER BY created_at DESC
|
||||
|
||||
318
apps/backend/app/ai/dashscope_client.py
Normal file
318
apps/backend/app/ai/dashscope_client.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""DashScope Application API 统一封装层。
|
||||
|
||||
使用 dashscope.Application.call() 调用百炼智能体应用,
|
||||
替代原 openai SDK 的通用模型 API。
|
||||
|
||||
- call_app_stream(): App1 流式调用,asyncio.Queue 桥接 async generator
|
||||
- call_app(): App2~8 单轮调用,asyncio.to_thread() 包装
|
||||
- _call_with_retry(): 指数退避重试(1s→2s→4s)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Callable
|
||||
|
||||
import dashscope
|
||||
from dashscope import Application
|
||||
|
||||
from app.ai.exceptions import (
|
||||
DashScopeApiError,
|
||||
DashScopeAuthError,
|
||||
DashScopeJsonParseError,
|
||||
DashScopeTimeoutError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DashScopeClient:
    """Async wrapper around the DashScope ``Application`` API.

    Applications configured in the Bailian console are invoked by ``app_id``.
    Two entry points are offered:

    - ``call_app_stream``: streaming call, bridged from a worker thread to an
      async generator through an ``asyncio.Queue``.
    - ``call_app``: single-shot call that expects a JSON document in
      ``output.text`` and retries when the text is not valid JSON.
    """

    MAX_RETRIES = 3
    BASE_INTERVAL = 1  # seconds

    def __init__(self, api_key: str, workspace_id: str | None = None):
        """Store credentials.

        The key is written to the ``dashscope`` module attribute, so it is
        shared process-wide (the most recently constructed client wins).

        Args:
            api_key: DashScope API key.
            workspace_id: Optional Bailian workspace id, forwarded as the
                ``workspace`` argument of every call.
        """
        dashscope.api_key = api_key
        self._workspace_id = workspace_id

    def _build_call_kwargs(
        self,
        app_id: str,
        prompt: str,
        session_id: str | None,
        biz_params: dict | None,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments shared by both call styles."""
        call_kwargs: dict[str, Any] = {"app_id": app_id, "prompt": prompt}
        if session_id is not None:
            call_kwargs["session_id"] = session_id
        if biz_params is not None:
            call_kwargs["biz_params"] = biz_params
        if self._workspace_id is not None:
            call_kwargs["workspace"] = self._workspace_id
        return call_kwargs

    @staticmethod
    def _extract_tokens(usage: Any) -> int:
        """Return the total token count reported in ``usage`` (0 if unknown).

        Mapping-style usage: ``input_tokens + output_tokens``, treating
        missing or ``None`` counters as 0.  When those are absent the
        counters of every entry under ``models`` are summed instead.
        Object-style usage: its ``total_tokens`` attribute.
        """
        if not usage:
            return 0
        if isinstance(usage, dict):
            total = (usage.get("input_tokens") or 0) + (usage.get("output_tokens") or 0)
            if total:
                return total
            # NOTE(review): the Application API appears to nest per-model
            # counters under "models" — confirm against the SDK in use.
            return sum(
                (model.get("input_tokens") or 0) + (model.get("output_tokens") or 0)
                for model in (usage.get("models") or [])
                if isinstance(model, dict)
            )
        return getattr(usage, "total_tokens", 0) or 0

    async def call_app_stream(
        self,
        app_id: str,
        prompt: str,
        session_id: str | None = None,
        biz_params: dict | None = None,
    ) -> AsyncGenerator[str, None]:
        """Stream text chunks from an application.

        The synchronous SDK iterator is consumed in an executor thread; each
        chunk is handed to the event loop through a queue.  Errors raised in
        the thread (or non-200 chunks) are re-raised in the caller.

        Args:
            app_id: Bailian application id.
            prompt: User input.
            session_id: Bailian session id for multi-turn conversations.
            biz_params: Extra business parameters (e.g. user_prompt_params).

        Yields:
            Non-empty incremental text chunks.

        Raises:
            DashScopeAuthError: A chunk reported status 401.
            DashScopeApiError: A chunk reported any other non-200 status.
            Exception: Anything raised by the SDK inside the worker thread.
        """
        queue: asyncio.Queue[str | BaseException | None] = asyncio.Queue()
        loop = asyncio.get_running_loop()

        call_kwargs = self._build_call_kwargs(app_id, prompt, session_id, biz_params)
        call_kwargs["stream"] = True
        call_kwargs["incremental_output"] = True

        def _publish(item: str | BaseException | None) -> None:
            # The queue is unbounded, so put_nowait never blocks; scheduling
            # it on the loop keeps all queue access on the loop thread.
            loop.call_soon_threadsafe(queue.put_nowait, item)

        def _consume_in_thread() -> None:
            """Drain the synchronous iterator, forwarding each chunk."""
            try:
                for chunk in Application.call(**call_kwargs):
                    if chunk.status_code != 200:
                        # Convert the failed chunk into an exception for the caller.
                        status = chunk.status_code
                        msg = getattr(chunk, "message", "") or f"状态码 {status}"
                        if status == 401:
                            err: Exception = DashScopeAuthError(msg)
                        else:
                            err = DashScopeApiError(msg, status_code=status)
                        _publish(err)
                        return
                    text = chunk.output.get("text", "")
                    if text:
                        _publish(text)
                # ``None`` is the normal end-of-stream sentinel.
                _publish(None)
            except Exception as exc:
                # Unexpected failure inside the thread: hand it to the caller.
                _publish(exc)

        loop.run_in_executor(None, _consume_in_thread)

        while True:
            item = await queue.get()
            if item is None:
                break
            if isinstance(item, BaseException):
                raise item
            yield item

    async def call_app(
        self,
        app_id: str,
        prompt: str,
        session_id: str | None = None,
        biz_params: dict | None = None,
    ) -> tuple[dict, int, str | None]:
        """Single-shot application call returning parsed JSON.

        ``output.text`` must be a JSON object.  A JSON array is wrapped as
        ``{"insights": [...]}``.  Anything else triggers a plain retry (up to
        ``MAX_RETRIES`` attempts); no local repair of the text is attempted.

        Args:
            app_id: Bailian application id.
            prompt: Complete prompt string prepared by the backend.
            session_id: Optional Bailian session id.
            biz_params: Optional business parameters.

        Returns:
            ``(parsed_json, tokens_used, new_session_id)``.

        Raises:
            DashScopeAuthError: The API answered 401.
            DashScopeApiError: The API call failed (retries exhausted or 4xx).
            DashScopeJsonParseError: No attempt produced a JSON object.
        """
        call_kwargs = self._build_call_kwargs(app_id, prompt, session_id, biz_params)

        last_json_error: DashScopeJsonParseError | None = None
        for json_attempt in range(self.MAX_RETRIES):
            response = await self._call_with_retry(Application.call, **call_kwargs)

            output = getattr(response, "output", None)

            # ``or ""`` also covers a present-but-null "text"; without it the
            # warning below would crash on ``raw_text[:500]``.
            raw_text: str = ""
            if isinstance(output, dict):
                raw_text = output.get("text") or ""
            elif hasattr(output, "text"):
                raw_text = output.text or ""

            tokens_used = self._extract_tokens(getattr(response, "usage", None))

            new_session_id: str | None = None
            if isinstance(output, dict):
                new_session_id = output.get("session_id")

            try:
                parsed = json.loads(raw_text)
                if isinstance(parsed, list):
                    # The model sometimes returns the bare insights array;
                    # wrap it instead of burning a retry.
                    logger.info(
                        "LLM 返回 list(长度 %d),自动包装为 {\"insights\": [...]}",
                        len(parsed),
                    )
                    parsed = {"insights": parsed}
                if not isinstance(parsed, dict):
                    raise TypeError(f"期望 dict,实际 {type(parsed).__name__}")
                return parsed, tokens_used, new_session_id
            except (json.JSONDecodeError, TypeError) as e:
                last_json_error = DashScopeJsonParseError(
                    f"JSON 解析失败(第 {json_attempt + 1}/{self.MAX_RETRIES} 次): {e}",
                    raw_content=raw_text,
                )
                logger.warning(
                    "Application API 返回非法 JSON(第 %d/%d 次): %s",
                    json_attempt + 1,
                    self.MAX_RETRIES,
                    raw_text[:500],
                )
                # Plain retry, no local repair.
                continue

        # JSON retries exhausted.
        raise last_json_error  # type: ignore[misc]

    async def _call_with_retry(self, func: Callable, **kwargs: Any) -> Any:
        """Run a synchronous SDK call in a thread with exponential backoff.

        Policy:
        - at most ``MAX_RETRIES`` attempts;
        - wait ``BASE_INTERVAL * 2**attempt`` seconds between attempts
          (1s, 2s with the defaults);
        - any exception raised by ``func`` is retried;
        - HTTP 4xx is never retried (401 maps to ``DashScopeAuthError``);
        - any other non-200 status is retried.

        Args:
            func: Synchronous callable (e.g. ``Application.call``).
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The response object whose ``status_code`` is 200.

        Raises:
            DashScopeAuthError: The API answered 401.
            DashScopeApiError: A 4xx answer, or retries exhausted.
        """
        last_error: Exception | None = None

        for attempt in range(self.MAX_RETRIES):
            try:
                response = await asyncio.to_thread(func, **kwargs)
            except Exception as exc:
                # Network / connection / timeout style failures: retryable.
                last_error = exc
                if attempt < self.MAX_RETRIES - 1:
                    wait_time = self.BASE_INTERVAL * (2**attempt)
                    logger.warning(
                        "DashScope API 底层异常(第 %d/%d 次),%ds 后重试: %s",
                        attempt + 1,
                        self.MAX_RETRIES,
                        wait_time,
                        exc,
                    )
                    await asyncio.sleep(wait_time)
                    continue
                else:
                    logger.error(
                        "DashScope API 底层异常,已达最大重试次数 %d: %s",
                        self.MAX_RETRIES,
                        exc,
                    )
                    raise DashScopeApiError(
                        f"DashScope API 调用失败(重试 {self.MAX_RETRIES} 次后): {exc}",
                    ) from exc

            # The SDK reports failures through ``status_code`` rather than raising.
            status_code = getattr(response, "status_code", None)

            if status_code == 200:
                return response

            message = getattr(response, "message", "") or f"状态码 {status_code}"

            if status_code is not None and 400 <= status_code < 500:
                # Client errors will not improve on retry.
                if status_code == 401:
                    raise DashScopeAuthError(message)
                raise DashScopeApiError(message, status_code=status_code)

            # 5xx or unknown status: retryable.
            last_error = DashScopeApiError(message, status_code=status_code)
            if attempt < self.MAX_RETRIES - 1:
                wait_time = self.BASE_INTERVAL * (2**attempt)
                logger.warning(
                    "DashScope API 调用失败(第 %d/%d 次,状态码 %s),%ds 后重试: %s",
                    attempt + 1,
                    self.MAX_RETRIES,
                    status_code,
                    wait_time,
                    message,
                )
                await asyncio.sleep(wait_time)
            else:
                logger.error(
                    "DashScope API 调用失败,已达最大重试次数 %d(状态码 %s): %s",
                    self.MAX_RETRIES,
                    status_code,
                    message,
                )

        # Retries exhausted.
        raise last_error  # type: ignore[misc]
|
||||
29
apps/backend/app/ai/data_fetchers/__init__.py
Normal file
29
apps/backend/app/ai/data_fetchers/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""AI 数据获取层。
|
||||
|
||||
为 AI 应用提供共享的数据获取函数,封装 FDW 查询和业务库查询逻辑。
|
||||
所有 FDW 查询通过 get_etl_readonly_connection(site_id) 获取只读连接,
|
||||
自动设置 RLS 门店隔离。
|
||||
|
||||
模块:
|
||||
- member_data: 客户消费数据获取(应用 3/6/7 共用)
|
||||
- assistant_data: 助教数据获取(应用 4/5 共用)
|
||||
- page_context: 页面上下文文本化(应用 1 专用)
|
||||
"""
|
||||
|
||||
from app.ai.data_fetchers.member_data import (
|
||||
fetch_member_consumption_data,
|
||||
fetch_member_notes,
|
||||
)
|
||||
from app.ai.data_fetchers.assistant_data import (
|
||||
fetch_assistant_info,
|
||||
fetch_service_history,
|
||||
)
|
||||
from app.ai.data_fetchers.page_context import build_page_text
|
||||
|
||||
__all__ = [
|
||||
"fetch_member_consumption_data",
|
||||
"fetch_member_notes",
|
||||
"fetch_assistant_info",
|
||||
"fetch_service_history",
|
||||
"build_page_text",
|
||||
]
|
||||
253
apps/backend/app/ai/data_fetchers/assistant_data.py
Normal file
253
apps/backend/app/ai/data_fetchers/assistant_data.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""助教数据获取模块(应用 4/5 共用)。
|
||||
|
||||
从 ETL 库 app.v_* RLS 视图获取助教基本信息和助教-客户服务历史。
|
||||
使用 is_delete 字段排除废单(is_delete=0 为正常),禁止使用已废弃的 dwd_assistant_trash_event 表。
|
||||
"""
|
||||
# CHANGE 2026-03-23 | Prompt: FDW 迁移——fdw_etl.* → app.* 直连 ETL 库
|
||||
# intent: 将所有 fdw_etl.* 外部表引用改为 app.v_* RLS 视图(直连 ETL 库),列名同步修正
|
||||
# 连接方式不变(get_etl_readonly_connection),仅改 SQL 表名和列名
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from app.database import get_etl_readonly_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FDW_QUERY_TIMEOUT_SEC = 5
|
||||
|
||||
|
||||
async def fetch_assistant_info(
    site_id: int,
    assistant_id: int,
) -> dict[str, Any]:
    """Load an assistant's basic profile without blocking the event loop.

    The blocking query runs in the loop's default executor.

    Returns:
        A dict with the keys ``nickname``, ``level``, ``hire_date``,
        ``tenure_months``, ``monthly_customers`` and ``performance_tier``.

    Raises:
        ValueError: The assistant does not exist.
        TimeoutError: The query timed out.
        ConnectionError: The database connection failed.
    """
    blocking_query = partial(_fetch_assistant_info_sync, site_id, assistant_id)
    return await asyncio.get_event_loop().run_in_executor(None, blocking_query)
|
||||
|
||||
|
||||
def _fetch_assistant_info_sync(site_id: int, assistant_id: int) -> dict[str, Any]:
    """Blocking implementation behind ``fetch_assistant_info``.

    Reads the current row of ``app.v_dim_assistant`` and the most recent row
    of ``app.v_dws_assistant_salary_calc`` for the assistant, inside one
    transaction on a read-only ETL connection.

    Raises:
        ValueError: No current assistant row matches ``assistant_id``.
        TimeoutError: The error text of a failure mentions a timeout.
        ConnectionError: The error text of a failure mentions a connection.
    """
    # The view exposes the hire date as ``entry_time``.
    profile_sql = """
        SELECT nickname, level, entry_time AS hire_date
        FROM app.v_dim_assistant
        WHERE assistant_id = %s AND scd2_is_current = 1
        LIMIT 1
    """
    # ``monthly_customers`` has no source column (constant 0 placeholder);
    # ``performance_tier`` comes from ``tier_name``.  ``salary_month`` is
    # ordered descending to pick the newest month.
    tier_sql = """
        SELECT
            0 AS monthly_customers,
            COALESCE(tier_name, '') AS performance_tier
        FROM app.v_dws_assistant_salary_calc
        WHERE assistant_id = %s
        ORDER BY salary_month DESC
        LIMIT 1
    """

    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as cursor:
            # Re-apply the RLS site id and the statement timeout inside the
            # transaction that runs the queries.
            cursor.execute("SET LOCAL app.current_site_id = %s", (str(site_id),))
            cursor.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
            )

            cursor.execute(profile_sql, (assistant_id,))
            profile = cursor.fetchone()
            if not profile:
                raise ValueError(f"assistant not found: assistant_id={assistant_id}")

            nickname = profile[0] or ""
            level = profile[1] or ""
            hire_date = profile[2]
            hired_on_known = isinstance(hire_date, date)

            # Tenure in whole calendar months (day of month is ignored).
            tenure_months = 0
            if hired_on_known:
                now = date.today()
                tenure_months = (now.year - hire_date.year) * 12 + (now.month - hire_date.month)

            cursor.execute(tier_sql, (assistant_id,))
            tier = cursor.fetchone()
            if tier:
                monthly_customers, performance_tier = tier[0], tier[1]
            else:
                monthly_customers, performance_tier = 0, ""

            etl_conn.commit()
            return {
                "nickname": nickname,
                "level": level,
                "hire_date": hire_date.isoformat() if hired_on_known else "",
                "tenure_months": tenure_months,
                "monthly_customers": monthly_customers,
                "performance_tier": performance_tier,
            }

    except (ValueError, TimeoutError, ConnectionError):
        raise
    except Exception as exc:
        # Classify driver errors by their message text.
        lowered = str(exc).lower()
        if any(token in lowered for token in ("statement timeout", "timeout")):
            raise TimeoutError(
                f"FDW 查询超时: assistant_id={assistant_id}"
            ) from exc
        if any(token in lowered for token in ("connection", "connect")):
            raise ConnectionError(
                f"FDW 连接失败: assistant_id={assistant_id}"
            ) from exc
        raise
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
async def fetch_service_history(
    site_id: int,
    assistant_id: int,
    member_id: int,
    months: int = 3,
) -> list[dict[str, Any]]:
    """Load one assistant's service records for one member.

    Voided records are excluded by the underlying query (``is_delete = 0``).
    The blocking query runs in the loop's default executor.

    Returns:
        A list of dicts with the keys ``service_date``, ``duration_minutes``,
        ``items_sum``, ``room_name`` and ``is_pd``.

    Raises:
        TimeoutError: The query timed out.
        ConnectionError: The database connection failed.
    """
    blocking_query = partial(
        _fetch_service_history_sync, site_id, assistant_id, member_id, months
    )
    return await asyncio.get_event_loop().run_in_executor(None, blocking_query)
|
||||
|
||||
|
||||
def _fetch_service_history_sync(
    site_id: int,
    assistant_id: int,
    member_id: int,
    months: int,
) -> list[dict[str, Any]]:
    """Blocking implementation behind ``fetch_service_history``.

    Queries ``app.v_dwd_assistant_service_log`` for non-deleted rows of the
    given assistant/member pair created within the last ``months`` months,
    newest first.

    Returns:
        One dict per row keyed by the selected column aliases; date/datetime
        values are ISO strings and ``Decimal`` values are floats.

    Raises:
        TimeoutError: The error text of a failure mentions a timeout.
        ConnectionError: The error text of a failure mentions a connection.
    """

    def _plain(value: Any) -> Any:
        """Convert driver values into JSON-friendly primitives."""
        if isinstance(value, (date, datetime)):
            return value.isoformat()
        if isinstance(value, Decimal):
            return float(value)
        return value

    conn = None
    try:
        conn = get_etl_readonly_connection(site_id)
        with conn.cursor() as cur:
            # Re-apply the RLS site id and the statement timeout inside the
            # transaction that runs the query.
            cur.execute(
                "SET LOCAL app.current_site_id = %s", (str(site_id),)
            )
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
            )

            # Column mapping onto the view:
            #   service_date <- create_time, duration_minutes <- real_use_seconds / 60,
            #   items_sum <- ledger_amount, room_name <- site_table_id,
            #   is_pd <- (order_assistant_type = 1).
            # The month count is bound as a real parameter via make_interval;
            # placeholders must not be embedded in a quoted SQL literal.
            cur.execute(
                """
                SELECT
                    create_time AS service_date,
                    COALESCE(real_use_seconds / 60, 0) AS duration_minutes,
                    ledger_amount AS items_sum,
                    site_table_id AS room_name,
                    (order_assistant_type = 1) AS is_pd
                FROM app.v_dwd_assistant_service_log
                WHERE site_assistant_id = %s
                  AND tenant_member_id = %s
                  AND is_delete = 0
                  AND create_time >= (CURRENT_DATE - make_interval(months => %s))
                ORDER BY create_time DESC
                """,
                (assistant_id, member_id, int(months)),
            )
            columns = [desc[0] for desc in cur.description]
            rows = cur.fetchall()

            conn.commit()

        return [
            {col: _plain(val) for col, val in zip(columns, row)}
            for row in rows
        ]

    except (TimeoutError, ConnectionError):
        raise
    except Exception as e:
        # Classify driver errors by their message text.
        err_msg = str(e).lower()
        if "statement timeout" in err_msg or "timeout" in err_msg:
            raise TimeoutError(
                f"FDW 查询超时: assistant_id={assistant_id}, member_id={member_id}"
            ) from e
        if "connection" in err_msg or "connect" in err_msg:
            raise ConnectionError(
                f"FDW 连接失败: assistant_id={assistant_id}, member_id={member_id}"
            ) from e
        raise
    finally:
        if conn:
            conn.close()
|
||||
402
apps/backend/app/ai/data_fetchers/member_data.py
Normal file
402
apps/backend/app/ai/data_fetchers/member_data.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""客户消费数据获取模块(应用 3/6/7 共用)。
|
||||
|
||||
从 ETL 库 app.v_* RLS 视图获取客户近 N 个月消费数据,从业务库获取备注。
|
||||
金额口径统一使用拆分字段(table_charge_money + assistant_pd/cx_money + goods_money),禁止 consume_money。
|
||||
会员信息通过 member_id JOIN v_dim_member (scd2_is_current=1) 获取。
|
||||
"""
|
||||
# CHANGE 2026-03-23 | Prompt: FDW 迁移——fdw_etl.* → app.* 直连 ETL 库
|
||||
# intent: 将所有 fdw_etl.* 外部表引用改为 app.v_* RLS 视图(直连 ETL 库),列名同步修正
|
||||
# 连接方式不变(get_etl_readonly_connection),仅改 SQL 表名和列名
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from app.database import get_connection, get_etl_readonly_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 消费记录最大返回数
|
||||
MAX_CONSUMPTION_RECORDS = 100
|
||||
# 备注最大返回数
|
||||
MAX_NOTES = 50
|
||||
# 备注单条最大字符数
|
||||
MAX_NOTE_LENGTH = 500
|
||||
# FDW 查询超时(秒)
|
||||
FDW_QUERY_TIMEOUT_SEC = 5
|
||||
|
||||
|
||||
async def fetch_member_consumption_data(
    site_id: int,
    member_id: int,
    months: int = 3,
) -> dict[str, Any]:
    """Load a member's consumption data for the last ``months`` months.

    The blocking queries run in the loop's default executor.

    Returns:
        A dict containing:
        - ``consumption_records``: consumption rows, newest first (capped);
        - ``member_cards``: per-card details;
        - ``card_balance_total``: stored-value card balance total;
        - ``stored_value_balance_total``: stored-value balance total;
        - ``expected_visit_date``: projected next visit date;
        - ``days_since_last_visit``: days since the latest visit;
        - ``member_nickname``: the member's nickname.

    Raises:
        TimeoutError: The query timed out.
        ConnectionError: The database connection failed.
    """
    blocking_query = partial(
        _fetch_member_consumption_data_sync, site_id, member_id, months
    )
    return await asyncio.get_event_loop().run_in_executor(None, blocking_query)
|
||||
|
||||
|
||||
def _fetch_member_consumption_data_sync(
    site_id: int,
    member_id: int,
    months: int,
) -> dict[str, Any]:
    """Blocking implementation: run all member queries serially on one connection.

    Raises:
        TimeoutError: The error text of a failure mentions a timeout.
        ConnectionError: The error text of a failure mentions a connection.
    """
    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)

        # Re-apply the RLS site id and the statement timeout (milliseconds)
        # inside the transaction that runs the queries.
        with etl_conn.cursor() as cursor:
            cursor.execute("SET LOCAL app.current_site_id = %s", (str(site_id),))
            cursor.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
            )

        member_nickname = _query_member_nickname(etl_conn, member_id)
        consumption, matched_total = _query_consumption_records(etl_conn, member_id, months)
        card_rows = _query_member_cards(etl_conn, member_id)
        balances = _query_balance_summary(etl_conn, member_id)
        visits = _query_visit_info(etl_conn, member_id)

        zero = Decimal("0")
        payload: dict[str, Any] = {
            "member_nickname": member_nickname,
            "consumption_records": consumption,
            "member_cards": card_rows,
            "card_balance_total": balances.get("card_balance_total", zero),
            "stored_value_balance_total": balances.get("stored_value_balance_total", zero),
            "expected_visit_date": visits.get("expected_visit_date"),
            "days_since_last_visit": visits.get("days_since_last_visit"),
        }
        # Tell the consumer when the record list was capped.
        if matched_total > MAX_CONSUMPTION_RECORDS:
            payload["truncated"] = True
            payload["total_count"] = matched_total

        etl_conn.commit()
        return payload

    except Exception as exc:
        # Classify driver errors by their message text.
        lowered = str(exc).lower()
        if any(token in lowered for token in ("statement timeout", "timeout")):
            raise TimeoutError(
                f"FDW 查询超时(>{FDW_QUERY_TIMEOUT_SEC}s): member_id={member_id}"
            ) from exc
        if any(token in lowered for token in ("connection", "connect")):
            raise ConnectionError(
                f"FDW 连接失败: member_id={member_id}, error={exc}"
            ) from exc
        raise
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
def _query_member_nickname(conn: Any, member_id: int) -> str:
|
||||
"""从 app.v_dim_member 获取会员昵称(scd2_is_current=1)。"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT nickname
|
||||
FROM app.v_dim_member
|
||||
WHERE member_id = %s AND scd2_is_current = 1
|
||||
LIMIT 1
|
||||
""",
|
||||
(member_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return row[0] if row and row[0] else ""
|
||||
|
||||
|
||||
def _query_consumption_records(
|
||||
conn: Any, member_id: int, months: int
|
||||
) -> tuple[list[dict], int]:
|
||||
"""从 app.v_dwd_settlement_head + app.v_dwd_table_fee_log 获取消费记录。
|
||||
|
||||
仅包含正向交易(settle_type IN (1, 3))。
|
||||
⚠️ 费用拆分字段(table_charge_money, assistant_pd/cx_money)在 settlement_head 上。
|
||||
⚠️ table_fee_log 提供台桌时长(real_table_use_seconds)和桌台ID(site_table_id)。
|
||||
⚠️ 列名映射: settle_date→create_time, settle_id→order_settle_id, sale_amount→ledger_amount。
|
||||
返回 (records, total_count)。
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
# 先查总数
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM app.v_dwd_settlement_head sh
|
||||
WHERE sh.member_id = %s
|
||||
AND sh.settle_type IN (1, 3)
|
||||
AND sh.create_time >= (CURRENT_DATE - INTERVAL '%s months')
|
||||
""",
|
||||
(member_id, months),
|
||||
)
|
||||
total_count = cur.fetchone()[0]
|
||||
|
||||
# 查询消费记录(限制 100 条)
|
||||
# table_charge_money/assistant_pd_money/assistant_cx_money 直接从 settlement_head 取
|
||||
# 台桌信息从 table_fee_log 取(site_table_id, real_table_use_seconds)
|
||||
# 商品金额从 store_goods_sale 聚合
|
||||
# 助教姓名从 service_log JOIN dim_assistant 获取
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
sh.create_time AS settle_date,
|
||||
sh.settle_type,
|
||||
sh.table_charge_money + sh.assistant_pd_money + sh.assistant_cx_money
|
||||
+ COALESCE(sg.goods_money, 0) AS items_sum,
|
||||
COALESCE(sh.table_charge_money, 0) AS table_charge_money,
|
||||
COALESCE(sh.assistant_pd_money, 0) AS assistant_pd_money,
|
||||
COALESCE(sh.assistant_cx_money, 0) AS assistant_cx_money,
|
||||
COALESCE(sg.goods_money, 0) AS goods_money,
|
||||
tfl.site_table_id AS room_name,
|
||||
COALESCE(tfl.real_table_use_seconds / 60, 0) AS duration_minutes,
|
||||
coaches.assistant_names
|
||||
FROM app.v_dwd_settlement_head sh
|
||||
LEFT JOIN app.v_dwd_table_fee_log tfl
|
||||
ON sh.order_settle_id = tfl.order_settle_id
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT SUM(sgs.ledger_amount) AS goods_money
|
||||
FROM app.v_dwd_store_goods_sale sgs
|
||||
WHERE sgs.order_settle_id = sh.order_settle_id
|
||||
) sg ON true
|
||||
LEFT JOIN LATERAL (
|
||||
SELECT string_agg(DISTINCT COALESCE(da.nickname, da.real_name, ''), ', ')
|
||||
AS assistant_names
|
||||
FROM app.v_dwd_assistant_service_log sl
|
||||
LEFT JOIN app.v_dim_assistant da
|
||||
ON sl.site_assistant_id = da.assistant_id
|
||||
AND da.scd2_is_current = 1
|
||||
WHERE sl.order_settle_id = sh.order_settle_id
|
||||
AND sl.is_delete = 0
|
||||
) coaches ON true
|
||||
WHERE sh.member_id = %s
|
||||
AND sh.settle_type IN (1, 3)
|
||||
AND sh.create_time >= (CURRENT_DATE - INTERVAL '%s months')
|
||||
ORDER BY sh.create_time DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(member_id, months, MAX_CONSUMPTION_RECORDS),
|
||||
)
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
rows = cur.fetchall()
|
||||
|
||||
records = []
|
||||
for row in rows:
|
||||
record = {}
|
||||
for col, val in zip(columns, row):
|
||||
if isinstance(val, (date, datetime)):
|
||||
record[col] = val.isoformat()
|
||||
elif isinstance(val, Decimal):
|
||||
record[col] = float(val)
|
||||
else:
|
||||
record[col] = val
|
||||
# assistant_names: 确保是列表
|
||||
names = record.get("assistant_names")
|
||||
if names and isinstance(names, str):
|
||||
record["assistant_names"] = [n.strip() for n in names.split(",") if n.strip()]
|
||||
elif not names:
|
||||
record["assistant_names"] = []
|
||||
records.append(record)
|
||||
|
||||
return records, total_count
|
||||
|
||||
|
||||
def _query_member_cards(conn: Any, member_id: int) -> list[dict]:
|
||||
"""从 app.v_dim_member_card_account 获取会员卡明细。
|
||||
⚠️ 列名映射: member_id→tenant_member_id, gift_balance 不存在(用 balance - principal_balance 近似)。
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT member_card_type_name AS card_type,
|
||||
COALESCE(balance, 0) AS balance,
|
||||
COALESCE(balance, 0) - COALESCE(principal_balance, 0) AS gift_balance
|
||||
FROM app.v_dim_member_card_account
|
||||
WHERE tenant_member_id = %s AND scd2_is_current = 1
|
||||
""",
|
||||
(member_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"card_type": row[0] or "",
|
||||
"balance": float(row[1]) if row[1] else 0.0,
|
||||
"gift_balance": float(row[2]) if row[2] else 0.0,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
def _query_balance_summary(conn: Any, member_id: int) -> dict:
|
||||
"""从 app.v_dws_member_consumption_summary 获取余额汇总。
|
||||
⚠️ 列名映射: recharge_card_amount→cash_card_balance, balance_amount→total_card_balance。
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
COALESCE(cash_card_balance, 0) AS card_balance_total,
|
||||
COALESCE(total_card_balance, 0) AS stored_value_balance_total
|
||||
FROM app.v_dws_member_consumption_summary
|
||||
WHERE member_id = %s
|
||||
LIMIT 1
|
||||
""",
|
||||
(member_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row:
|
||||
return {
|
||||
"card_balance_total": Decimal("0"),
|
||||
"stored_value_balance_total": Decimal("0"),
|
||||
}
|
||||
return {
|
||||
"card_balance_total": row[0],
|
||||
"stored_value_balance_total": row[1],
|
||||
}
|
||||
|
||||
|
||||
def _query_visit_info(conn: Any, member_id: int) -> dict:
|
||||
"""从 app.v_dws_member_visit_detail 获取到店数据,推算预计到店日期。
|
||||
⚠️ 列名映射: last_visit_date→MAX(visit_date), avg_visit_interval_days 需从明细计算。
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
# 获取最近到店日期和平均到店间隔
|
||||
cur.execute(
|
||||
"""
|
||||
WITH visits AS (
|
||||
SELECT visit_date,
|
||||
LAG(visit_date) OVER (ORDER BY visit_date) AS prev_visit
|
||||
FROM app.v_dws_member_visit_detail
|
||||
WHERE member_id = %s
|
||||
)
|
||||
SELECT
|
||||
MAX(visit_date) AS last_visit_date,
|
||||
AVG(visit_date - prev_visit) AS avg_visit_interval_days
|
||||
FROM visits
|
||||
WHERE prev_visit IS NOT NULL
|
||||
""",
|
||||
(member_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
|
||||
if not row or not row[0]:
|
||||
return {"expected_visit_date": None, "days_since_last_visit": None}
|
||||
|
||||
last_visit = row[0]
|
||||
avg_interval = row[1]
|
||||
today = date.today()
|
||||
days_since = (today - last_visit).days if isinstance(last_visit, date) else None
|
||||
|
||||
expected = None
|
||||
if avg_interval and last_visit:
|
||||
from datetime import timedelta
|
||||
expected_date = last_visit + timedelta(days=int(avg_interval))
|
||||
expected = expected_date.isoformat()
|
||||
|
||||
return {
|
||||
"expected_visit_date": expected,
|
||||
"days_since_last_visit": days_since,
|
||||
}
|
||||
|
||||
|
||||
async def fetch_member_notes(
    site_id: int,
    member_id: int,
    limit: int = MAX_NOTES,
) -> list[dict]:
    """Load the member's notes, newest first, at most ``limit`` entries.

    Notes are read from ``biz.notes`` in the business database.  Each note's
    content is cut to 500 characters with a truncation marker appended when
    it is longer.  The blocking query runs in the loop's default executor.
    """
    blocking_query = partial(_fetch_member_notes_sync, site_id, member_id, limit)
    return await asyncio.get_event_loop().run_in_executor(None, blocking_query)
|
||||
|
||||
|
||||
def _fetch_member_notes_sync(
    site_id: int,
    member_id: int,
    limit: int,
) -> list[dict]:
    """Blocking implementation: read the member's notes from the business DB.

    Returns:
        Dicts with ``recorded_by`` (author nickname or ""), ``content``
        (possibly truncated) and ``created_at`` (ISO string, ``str()`` of a
        non-date value, or "" when missing).
    """
    notes_sql = """
        SELECT
            n.content,
            u.nickname AS recorded_by,
            n.created_at
        FROM biz.notes n
        LEFT JOIN biz.coach_tasks ct ON ct.id = n.task_id
        LEFT JOIN public.users u ON u.id = n.user_id
        WHERE n.target_id = %s AND n.site_id = %s
        ORDER BY n.created_at DESC
        LIMIT %s
    """

    def _stamp(value: Any) -> str:
        """Render the creation time as text."""
        if isinstance(value, (date, datetime)):
            return value.isoformat()
        if value:
            return str(value)
        return ""

    db = get_connection()
    try:
        with db.cursor() as cursor:
            cursor.execute(notes_sql, (member_id, site_id, limit))
            fetched = cursor.fetchall()

        collected = []
        for raw_content, author, created_at in fetched:
            body = raw_content or ""
            # Cap the note length and flag the cut.
            if len(body) > MAX_NOTE_LENGTH:
                body = body[:MAX_NOTE_LENGTH] + "…(已截断)"
            collected.append(
                {
                    "recorded_by": author or "",
                    "content": body,
                    "created_at": _stamp(created_at),
                }
            )
        return collected
    finally:
        db.close()
|
||||
645
apps/backend/app/ai/data_fetchers/page_context.py
Normal file
645
apps/backend/app/ai/data_fetchers/page_context.py
Normal file
@@ -0,0 +1,645 @@
|
||||
"""页面上下文文本化模块(应用 1 专用)。
|
||||
|
||||
根据 contextType 从数据库获取对应页面数据,
|
||||
格式化为结构化中文文本(≤ 2000 字符),供 AI 理解当前场景。
|
||||
不传入 member_phone 等断档敏感字段。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from app.database import get_connection, get_etl_readonly_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_PAGE_CONTEXT_LENGTH = 2000
|
||||
FDW_QUERY_TIMEOUT_SEC = 5
|
||||
|
||||
# 支持的 10 种页面类型
|
||||
SUPPORTED_PAGE_TYPES = {
|
||||
"task-detail",
|
||||
"customer-detail",
|
||||
"coach-detail",
|
||||
"board-finance",
|
||||
"board-customer",
|
||||
"board-coach",
|
||||
"performance",
|
||||
"my-profile",
|
||||
"task-list",
|
||||
"customer-service-records",
|
||||
}
|
||||
|
||||
|
||||
async def build_page_text(
    source_page: str,
    context_id: int | str | None,
    site_id: int,
    filters: dict | None = None,
) -> str:
    """Turn the data behind a page into structured text for the AI.

    Args:
        source_page: Page type (contextType).
        context_id: Entity id (contextId).
        site_id: Store id.
        filters: Filter parameters used by board-style pages.

    Returns:
        "" for an empty or unsupported page type; otherwise the rendered
        text, cut to stay within ``MAX_PAGE_CONTEXT_LENGTH`` characters, or a
        fixed fallback sentence when rendering fails.
    """
    page_is_supported = bool(source_page) and source_page in SUPPORTED_PAGE_TYPES
    if not page_is_supported:
        return ""

    try:
        render = partial(
            _build_page_text_sync, source_page, context_id, site_id, filters or {}
        )
        rendered = await asyncio.get_event_loop().run_in_executor(None, render)
        # Truncation guard: leave room for the marker line.
        if len(rendered) > MAX_PAGE_CONTEXT_LENGTH:
            kept = rendered[:MAX_PAGE_CONTEXT_LENGTH - 20]
            rendered = kept + "\n…(上下文已截断)"
        return rendered
    except Exception:
        # Best effort: log the failure and degrade to a fixed message.
        logger.exception("页面上下文获取失败: source_page=%s", source_page)
        return "页面上下文获取失败,请直接描述您的问题"
|
||||
|
||||
|
||||
def _build_page_text_sync(
    source_page: str,
    context_id: int | str | None,
    site_id: int,
    filters: dict,
) -> str:
    """Dispatch synchronously to the text builder registered for a page type.

    Returns "" when ``source_page`` has no registered builder; otherwise the
    builder's result for ``(context_id, site_id, filters)``.
    """
    dispatch = {
        # detail pages
        "task-detail": _text_task_detail,
        "customer-detail": _text_customer_detail,
        "coach-detail": _text_coach_detail,
        # board pages
        "board-finance": _text_board_finance,
        "board-customer": _text_board_customer,
        "board-coach": _text_board_coach,
        # other pages
        "performance": _text_performance,
        "my-profile": _text_my_profile,
        "task-list": _text_task_list,
        "customer-service-records": _text_customer_service_records,
    }
    builder = dispatch.get(source_page)
    return builder(context_id, site_id, filters) if builder else ""
|
||||
|
||||
|
||||
# ── 详情类页面 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _text_task_detail(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Describe one coach task for the task-detail page.

    Loads the task row (with member and assistant nicknames), its three most
    recent notes and the timestamp of the latest ``app4_analysis`` cache
    entry for the assistant/member pair, all over the business connection.

    Args:
        context_id: Task id; a falsy value short-circuits to "".
        site_id: Store id used to scope every query.
        filters: Unused; present so all page handlers share one signature.

    Returns:
        Multi-line Chinese text, "" when ``context_id`` is falsy, or a
        "task does not exist" sentence when no row matches.
    """
    if not context_id:
        return ""
    task_id = int(context_id)

    conn = get_connection()
    try:
        with conn.cursor() as cur:
            # Task row plus display names
            cur.execute(
                """
                SELECT ct.task_type, ct.status, ct.deadline,
                       ct.member_id, ct.assistant_id,
                       dm.nickname AS member_nickname,
                       da.nickname AS assistant_nickname
                FROM biz.coach_tasks ct
                LEFT JOIN biz.coach_tasks_member_view dm
                    ON dm.member_id = ct.member_id AND dm.site_id = ct.site_id
                LEFT JOIN biz.coach_tasks_assistant_view da
                    ON da.assistant_id = ct.assistant_id AND da.site_id = ct.site_id
                WHERE ct.id = %s AND ct.site_id = %s
                """,
                (task_id, site_id),
            )
            task_row = cur.fetchone()
            if not task_row:
                return f"任务 {task_id} 不存在"

            (
                task_type,
                status,
                deadline,
                member_id,
                assistant_id,
                member_nick,
                asst_nick,
            ) = task_row

            # Latest notes (at most 3)
            cur.execute(
                """
                SELECT content, created_at
                FROM biz.notes
                WHERE task_id = %s AND site_id = %s
                ORDER BY created_at DESC LIMIT 3
                """,
                (task_id, site_id),
            )
            recent_notes = cur.fetchall()

            # Latest AI analysis cache entry; only its timestamp is reported
            cur.execute(
                """
                SELECT result_json, created_at
                FROM biz.ai_cache
                WHERE cache_type = 'app4_analysis'
                  AND site_id = %s
                  AND target_id = %s
                ORDER BY created_at DESC LIMIT 1
                """,
                (site_id, f"{assistant_id}_{member_id}"),
            )
            analysis = cur.fetchone()

        # Fall back to the raw id when a nickname is missing
        member_label = member_nick or f"ID:{member_id}"
        assistant_label = asst_nick or f"ID:{assistant_id}"

        out = [
            "【任务详情】",
            f" 任务类型:{task_type or '未知'}",
            f" 状态:{status or '未知'}",
            f" 截止日期:{_fmt_date(deadline)}",
            f" 客户:{member_label}",
            f" 助教:{assistant_label}",
        ]
        if recent_notes:
            out.append("【最近备注】")
            # Each note body is capped at 100 characters
            out.extend(
                f" {_fmt_date(noted_at)} {(body or '')[:100]}"
                for body, noted_at in recent_notes
            )
        if analysis:
            out.append(f"【AI 分析】最近更新于 {_fmt_date(analysis[1])}")

        return "\n".join(out)
    finally:
        conn.close()
|
||||
|
||||
|
||||
def _text_customer_detail(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Describe a customer (member) for the customer-detail page.

    Reads the nickname, the five most recent settlements and the stored-value
    balance over the ETL read-only connection, then the five most recent
    retention clues over the business connection.

    Args:
        context_id: Member id; a falsy value short-circuits to "".
        site_id: Store id; passed to ``get_etl_readonly_connection`` and used
            to filter the retention-clue query.
        filters: Unused; present so all page handlers share one signature.

    Returns:
        Multi-line Chinese text, or "" when ``context_id`` is falsy.

    Raises:
        ValueError: If ``context_id`` is a string that ``int()`` cannot parse.
        Exception: Database errors propagate to the caller.
    """
    if not context_id:
        return ""
    member_id = int(context_id)

    # Queries are issued directly here rather than reusing member_data,
    # to avoid a circular import.
    etl_conn = None
    biz_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        # NOTE(review): the ETL queries carry no explicit site filter;
        # presumably the connection itself is scoped to site_id — confirm.
        with etl_conn.cursor() as cur:
            # The timeout value is sent in milliseconds.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
            )
            # CHANGE 2026-03-23: FDW migration — fdw_etl.* replaced by app.*
            # (direct connection to the ETL database).
            # Member info
            cur.execute(
                """
                SELECT nickname
                FROM app.v_dim_member
                WHERE member_id = %s AND scd2_is_current = 1
                """,
                (member_id,),
            )
            m_row = cur.fetchone()
            # Fall back to the id when the row is missing OR the nickname is
            # NULL, so the text never contains the literal "None".
            nickname = (m_row[0] if m_row else None) or f"ID:{member_id}"

            # Five most recent settlements
            cur.execute(
                """
                SELECT settle_date, items_sum, room_name
                FROM app.v_dwd_settlement_head
                WHERE member_id = %s AND settle_type IN (1, 3)
                ORDER BY settle_date DESC LIMIT 5
                """,
                (member_id,),
            )
            recent = cur.fetchall()

            # Stored-value balance
            cur.execute(
                """
                SELECT balance_amount
                FROM app.v_dws_member_consumption_summary
                WHERE member_id = %s LIMIT 1
                """,
                (member_id,),
            )
            bal_row = cur.fetchone()
        etl_conn.commit()

        # Retention clues
        biz_conn = get_connection()
        with biz_conn.cursor() as cur:
            cur.execute(
                """
                SELECT summary FROM member_retention_clue
                WHERE member_id = %s AND site_id = %s
                ORDER BY created_at DESC LIMIT 5
                """,
                (member_id, site_id),
            )
            clues = cur.fetchall()

        lines = [
            "【客户详情】",
            f" 昵称:{nickname}",
            f" 储值余额:{_fmt_decimal(bal_row[0]) if bal_row else '未知'}",
        ]
        if recent:
            lines.append("【近期消费】")
            lines.extend(
                f" {_fmt_date(sd)} ¥{_fmt_decimal(amt)} {room or ''}"
                for sd, amt, room in recent
            )
        # Skip NULL/empty summaries instead of printing "None"; omit the
        # section header when nothing printable remains.
        clue_lines = [f" {summary}" for (summary,) in clues if summary]
        if clue_lines:
            lines.append("【维客线索】")
            lines.extend(clue_lines)

        return "\n".join(lines)
    finally:
        if etl_conn:
            etl_conn.close()
        if biz_conn:
            biz_conn.close()
|
||||
|
||||
|
||||
def _text_coach_detail(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Describe an assistant (coach) for the coach-detail page.

    Reads nickname, level and hire date over the ETL read-only connection,
    then the assistant's task counts per status over the business connection.

    Args:
        context_id: Assistant id; a falsy value short-circuits to "".
        site_id: Store id; passed to ``get_etl_readonly_connection`` and used
            to filter the task statistics.
        filters: Unused; present so all page handlers share one signature.

    Returns:
        Multi-line Chinese text, "" when ``context_id`` is falsy, or an
        "assistant does not exist" sentence when no row matches.
    """
    if not context_id:
        return ""
    assistant_id = int(context_id)

    etl_conn = None
    biz_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as etl_cur:
            # The timeout value is sent in milliseconds.
            timeout_ms = f"{FDW_QUERY_TIMEOUT_SEC * 1000}"
            etl_cur.execute("SET LOCAL statement_timeout = %s", (timeout_ms,))
            etl_cur.execute(
                """
                SELECT nickname, level, hire_date
                FROM app.v_dim_assistant
                WHERE assistant_id = %s LIMIT 1
                """,
                (assistant_id,),
            )
            profile = etl_cur.fetchone()
        etl_conn.commit()

        if not profile:
            return f"助教 {assistant_id} 不存在"
        nickname, level, hire_date = profile

        biz_conn = get_connection()
        with biz_conn.cursor() as biz_cur:
            # Task counts grouped by status
            biz_cur.execute(
                """
                SELECT status, COUNT(*)
                FROM biz.coach_tasks
                WHERE assistant_id = %s AND site_id = %s
                GROUP BY status
                """,
                (assistant_id, site_id),
            )
            status_counts = biz_cur.fetchall()

        out = [
            "【助教详情】",
            f" 花名:{nickname or ''}",
            f" 级别:{level or ''}",
            f" 入职日期:{_fmt_date(hire_date)}",
        ]
        if status_counts:
            out.append("【任务统计】")
            out.extend(f" {status}: {cnt} 个" for status, cnt in status_counts)
        return "\n".join(out)
    finally:
        if etl_conn:
            etl_conn.close()
        if biz_conn:
            biz_conn.close()
|
||||
|
||||
|
||||
# ── 看板类页面 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _text_board_finance(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Summarise the finance board.

    Reports the settlement count, total revenue and average revenue of
    settlements (settle_type 1 or 3) from the last month.

    NOTE(review): ``timeDimension`` and ``areaFilter`` are only echoed into
    the text; the query itself always covers ``CURRENT_DATE - 1 month`` and
    applies no area filter.

    Args:
        context_id: Unused by this handler.
        site_id: Store id passed to ``get_etl_readonly_connection``.
        filters: Board filters; ``timeDimension`` (default "this_month") and
            ``areaFilter`` (default "") are read.

    Returns:
        Multi-line Chinese text.
    """
    time_dim = filters.get("timeDimension", "this_month")
    area = filters.get("areaFilter", "")

    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as cur:
            # The timeout value is sent in milliseconds.
            timeout_ms = f"{FDW_QUERY_TIMEOUT_SEC * 1000}"
            cur.execute("SET LOCAL statement_timeout = %s", (timeout_ms,))
            # Simplified query: aggregate figures only
            cur.execute(
                """
                SELECT
                    COUNT(*) AS settle_count,
                    COALESCE(SUM(items_sum), 0) AS total_revenue,
                    COALESCE(AVG(items_sum), 0) AS avg_revenue
                FROM app.v_dwd_settlement_head
                WHERE settle_type IN (1, 3)
                  AND settle_date >= (CURRENT_DATE - INTERVAL '1 month')
                """,
            )
            totals = cur.fetchone()
        etl_conn.commit()

        out = ["【财务看板】", f" 时间维度:{time_dim}"]
        if area:
            out.append(f" 区域筛选:{area}")
        if totals:
            settle_count, total_revenue, avg_revenue = totals[0], totals[1], totals[2]
            out += [
                f" 结算笔数:{settle_count}",
                f" 总营收:¥{_fmt_decimal(total_revenue)}",
                f" 笔均:¥{_fmt_decimal(avg_revenue)}",
            ]
        return "\n".join(out)
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
def _text_board_customer(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Summarise the customer board as a top-10 ranking by consumption.

    Ranks members by the sum of ``items_sum`` over settlements (settle_type
    1 or 3, member_id > 0) from the last month.

    NOTE(review): ``dimension`` and ``typeFilter`` are only echoed into the
    text; the ranking is always by consumption over the last month.

    Args:
        context_id: Unused by this handler.
        site_id: Store id passed to ``get_etl_readonly_connection``.
        filters: Board filters; ``dimension`` (default "consumption") and
            ``typeFilter`` (default "") are read.

    Returns:
        Multi-line Chinese text.
    """
    dimension = filters.get("dimension", "consumption")
    type_filter = filters.get("typeFilter", "")

    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as cur:
            # The timeout value is sent in milliseconds.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
            )
            # Top 10 customers. Grouping includes member_id so that two
            # different members sharing a nickname are ranked separately
            # instead of having their totals merged.
            cur.execute(
                """
                SELECT
                    dm.nickname,
                    COALESCE(SUM(sh.items_sum), 0) AS total_consumption
                FROM app.v_dwd_settlement_head sh
                JOIN app.v_dim_member dm
                    ON dm.member_id = sh.member_id AND dm.scd2_is_current = 1
                WHERE sh.settle_type IN (1, 3)
                  AND sh.member_id > 0
                  AND sh.settle_date >= (CURRENT_DATE - INTERVAL '1 month')
                GROUP BY sh.member_id, dm.nickname
                ORDER BY total_consumption DESC
                LIMIT 10
                """,
            )
            rows = cur.fetchall()
        etl_conn.commit()

        lines = [
            "【客户看板】",
            f" 排序维度:{dimension}",
        ]
        if type_filter:
            lines.append(f" 类型筛选:{type_filter}")
        if rows:
            lines.append(" Top 10 客户:")
            lines.extend(
                f" {i}. {nick or '未知'} ¥{_fmt_decimal(amt)}"
                for i, (nick, amt) in enumerate(rows, 1)
            )

        return "\n".join(lines)
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
def _text_board_coach(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Summarise the coach board as a top-10 ranking by service count.

    Counts non-deleted service-log rows per assistant over the last month
    and sums their ``ledger_amount``.

    NOTE(review): ``dimension``, ``projectFilter`` and ``timeDimension`` are
    only echoed into the text; the ranking is always by service count over
    the last month.

    Args:
        context_id: Unused by this handler.
        site_id: Store id passed to ``get_etl_readonly_connection``.
        filters: Board filters; ``dimension`` (default "service_count"),
            ``projectFilter`` (default "") and ``timeDimension`` (default
            "this_month") are read.

    Returns:
        Multi-line Chinese text.
    """
    dimension = filters.get("dimension", "service_count")
    project = filters.get("projectFilter", "")
    time_dim = filters.get("timeDimension", "this_month")

    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as cur:
            # The timeout value is sent in milliseconds.
            cur.execute(
                "SET LOCAL statement_timeout = %s",
                (f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
            )
            # Grouping includes assistant_id so that two different assistants
            # sharing a nickname are ranked separately instead of merged.
            cur.execute(
                """
                SELECT
                    da.nickname,
                    COUNT(*) AS service_count,
                    COALESCE(SUM(sl.ledger_amount), 0) AS total_revenue
                FROM app.v_dwd_assistant_service_log sl
                JOIN app.v_dim_assistant da
                    ON da.assistant_id = sl.site_assistant_id
                WHERE sl.is_delete = 0
                  AND sl.create_time >= (CURRENT_DATE - INTERVAL '1 month')
                GROUP BY da.assistant_id, da.nickname
                ORDER BY service_count DESC
                LIMIT 10
                """,
            )
            rows = cur.fetchall()
        etl_conn.commit()

        lines = [
            "【助教看板】",
            f" 排序维度:{dimension}",
            f" 时间维度:{time_dim}",
        ]
        if project:
            lines.append(f" 技能筛选:{project}")
        if rows:
            lines.append(" Top 10 助教:")
            lines.extend(
                f" {i}. {nick or '未知'} 服务{cnt}次 ¥{_fmt_decimal(amt)}"
                for i, (nick, cnt, amt) in enumerate(rows, 1)
            )

        return "\n".join(lines)
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
# ── 其他页面 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _text_performance(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Summarise the performance page.

    Lists up to ten salary-calculation rows (nickname, performance tier,
    monthly customers), newest ``calc_month`` first and then by customer
    count.

    NOTE(review): ``timeDimension`` is only echoed into the text; the query
    is not restricted to a single month, so rows from several months can
    appear together.

    Args:
        context_id: Unused by this handler.
        site_id: Store id passed to ``get_etl_readonly_connection``.
        filters: Page filters; ``timeDimension`` (default "this_month") is
            read.

    Returns:
        Multi-line Chinese text.
    """
    time_dim = filters.get("timeDimension", "this_month")

    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as cur:
            # The timeout value is sent in milliseconds.
            timeout_ms = f"{FDW_QUERY_TIMEOUT_SEC * 1000}"
            cur.execute("SET LOCAL statement_timeout = %s", (timeout_ms,))
            cur.execute(
                """
                SELECT
                    da.nickname,
                    sc.performance_tier,
                    sc.monthly_customers
                FROM app.v_dws_assistant_salary_calc sc
                JOIN app.v_dim_assistant da
                    ON da.assistant_id = sc.assistant_id
                ORDER BY sc.calc_month DESC, sc.monthly_customers DESC
                LIMIT 10
                """,
            )
            ranking = cur.fetchall()
        etl_conn.commit()

        header = ["【绩效数据】", f" 时间维度:{time_dim}"]
        body = [
            f" {nick or '未知'} {tier or ''} 服务{customers or 0}人"
            for nick, tier, customers in ranking
        ]
        return "\n".join(header + body)
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
def _text_my_profile(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Return the fixed description used for the personal-profile page.

    No database access happens here and all three arguments are ignored;
    they exist so every page handler shares one signature.
    """
    title = "【个人信息】"
    hint = " 当前为个人信息页面,可查询个人绩效和任务情况。"
    return title + "\n" + hint
|
||||
|
||||
|
||||
def _text_task_list(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """Describe the task-list page.

    With a task id, the output is exactly that of ``_text_task_detail``.
    Without one, the store's task counts per status are listed.

    Args:
        context_id: Optional task id.
        site_id: Store id used to scope the query.
        filters: Forwarded to ``_text_task_detail`` when a task id is given.

    Returns:
        Multi-line Chinese text.
    """
    # A specific task was selected: reuse the task-detail rendering.
    if context_id:
        return _text_task_detail(context_id, site_id, filters)

    # No specific task: give an overview per status.
    conn = get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT status, COUNT(*)
                FROM biz.coach_tasks
                WHERE site_id = %s
                GROUP BY status
                """,
                (site_id,),
            )
            status_counts = cur.fetchall()

        overview = ["【任务列表】"]
        overview.extend(f" {status}: {cnt} 个" for status, cnt in status_counts)
        return "\n".join(overview)
    finally:
        conn.close()
|
||||
|
||||
|
||||
def _text_customer_service_records(
    context_id: int | str | None, site_id: int, filters: dict
) -> str:
    """List a customer's ten most recent service-log entries.

    Each entry shows the creation time, the duration in minutes
    (``real_use_seconds / 60`` as computed by the database), the ledger
    amount and the ``site_table_id`` value.

    Args:
        context_id: Member id; a falsy value short-circuits to "".
        site_id: Store id passed to ``get_etl_readonly_connection``.
        filters: Unused; present so all page handlers share one signature.

    Returns:
        Multi-line Chinese text, or "" when ``context_id`` is falsy.
    """
    if not context_id:
        return ""
    member_id = int(context_id)

    etl_conn = None
    try:
        etl_conn = get_etl_readonly_connection(site_id)
        with etl_conn.cursor() as cur:
            # The timeout value is sent in milliseconds.
            timeout_ms = f"{FDW_QUERY_TIMEOUT_SEC * 1000}"
            cur.execute("SET LOCAL statement_timeout = %s", (timeout_ms,))
            cur.execute(
                """
                SELECT
                    create_time,
                    real_use_seconds / 60 AS duration_minutes,
                    ledger_amount,
                    site_table_id
                FROM app.v_dwd_assistant_service_log
                WHERE tenant_member_id = %s AND is_delete = 0
                ORDER BY create_time DESC
                LIMIT 10
                """,
                (member_id,),
            )
            records = cur.fetchall()
        etl_conn.commit()

        out = ["【服务记录】"]
        if records:
            out.extend(
                f" {_fmt_date(served_at)} {minutes or 0}分钟 ¥{_fmt_decimal(amount)} {table or ''}"
                for served_at, minutes, amount, table in records
            )
        else:
            out.append(" 暂无服务记录")
        return "\n".join(out)
    finally:
        if etl_conn:
            etl_conn.close()
|
||||
|
||||
|
||||
# ── 工具函数 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _fmt_date(val: Any) -> str:
    """Format a date-like value for display.

    ``datetime`` values become ``YYYY-MM-DD HH:MM``, plain ``date`` values
    their ISO form. Any other truthy value is passed through ``str()``, and
    falsy values (``None``, "", 0) become "未知".
    """
    if not isinstance(val, date):
        # Not a date at all (datetime is a date subclass, so it never lands here)
        return str(val) if val else "未知"
    if isinstance(val, datetime):
        return val.strftime("%Y-%m-%d %H:%M")
    return val.isoformat()
|
||||
|
||||
|
||||
def _fmt_decimal(val: Any) -> str:
    """Format a monetary amount with two decimal places.

    Args:
        val: The amount. ``None`` is rendered as "0.00". ``Decimal``,
            ``float`` and ``int`` values are rendered with exactly two
            decimals, so an integer aggregate such as ``0`` prints as "0.00"
            like every other amount.

    Returns:
        The formatted amount; values of any other type (including ``bool``)
        are returned via ``str()`` unchanged.
    """
    if val is None:
        return "0.00"
    # bool is an int subclass; keep it out of the numeric branch.
    if isinstance(val, (Decimal, float, int)) and not isinstance(val, bool):
        return f"{val:.2f}"
    return str(val)
|
||||
File diff suppressed because it is too large
Load Diff
53
apps/backend/app/ai/exceptions.py
Normal file
53
apps/backend/app/ai/exceptions.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""AI 模块异常层级。
|
||||
|
||||
所有 DashScope 相关异常继承自 DashScopeError 基类,
|
||||
便于上层统一捕获和分类处理。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class DashScopeError(Exception):
    """Root of the DashScope exception hierarchy.

    Catch this type to handle every DashScope-related failure in one place.
    """
|
||||
|
||||
|
||||
class DashScopeApiError(DashScopeError):
    """An Application API call failed (after retries were exhausted).

    Attributes:
        status_code: Status code supplied by the raiser, or ``None`` when
            none was given.
    """

    def __init__(self, message: str, status_code: int | None = None):
        self.status_code = status_code
        super().__init__(message)
|
||||
|
||||
|
||||
class DashScopeAuthError(DashScopeApiError):
    """The API key was rejected (HTTP 401).

    ``status_code`` is always set to 401.
    """

    def __init__(self, message: str = "API Key 无效或已过期"):
        super().__init__(message, 401)
|
||||
|
||||
|
||||
class DashScopeTimeoutError(DashScopeApiError):
    """The call timed out.

    ``status_code`` is always ``None``.
    """

    def __init__(self, message: str = "DashScope API 调用超时"):
        super().__init__(message, None)
|
||||
|
||||
|
||||
class DashScopeJsonParseError(DashScopeError):
    """The response could not be parsed as JSON (after retries were exhausted).

    Attributes:
        raw_content: The raw content supplied by the raiser; defaults to "".
    """

    def __init__(self, message: str, raw_content: str = ""):
        self.raw_content = raw_content
        super().__init__(message)
|
||||
|
||||
|
||||
class CircuitOpenError(DashScopeError):
    """The circuit breaker is in the OPEN state and the request was refused."""
|
||||
|
||||
|
||||
class RateLimitExceededError(DashScopeError):
    """The rate-limit threshold was exceeded."""
|
||||
|
||||
|
||||
class BudgetExceededError(DashScopeError):
    """The token budget was exceeded."""
|
||||
73
apps/backend/app/ai/rate_limiter.py
Normal file
73
apps/backend/app/ai/rate_limiter.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""限流器 — 滑动窗口内存计数器。
|
||||
|
||||
App1 按 user_id 限流(每用户每分钟 10 次),
|
||||
App2~8 按 site_id 限流(每门店每小时 100 次)。
|
||||
内存实现,单实例部署,不依赖外部存储。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
|
||||
class RateLimiter:
    """In-memory sliding-window rate limiter.

    - ``check_user_rate()``: App1, per user (default 10 requests / 60 s)
    - ``check_store_rate()``: App2~8, per store (default 100 requests / 3600 s)

    Every key owns a deque of timestamps of accepted requests. A check first
    drops the timestamps that fell out of the window, then admits the request
    only while fewer than ``limit`` timestamps remain. Rejected requests are
    not recorded.
    """

    def __init__(self) -> None:
        # App1: user_id -> timestamps of accepted requests
        self._user_windows: dict[str, deque[float]] = {}
        # App2~8: str(site_id) -> timestamps of accepted requests
        self._store_windows: dict[str, deque[float]] = {}

    def _check(
        self,
        windows: dict[str, deque[float]],
        key: str,
        limit: int,
        window_seconds: int,
    ) -> bool:
        """Shared sliding-window check; returns True when the request is allowed."""
        now = time.monotonic()
        stamps = windows.setdefault(key, deque())

        # Evict timestamps at or before the window's lower edge
        oldest_allowed = now - window_seconds
        while stamps and stamps[0] <= oldest_allowed:
            stamps.popleft()

        if len(stamps) < limit:
            # Still under the limit: record this request and admit it
            stamps.append(now)
            return True
        return False

    def check_user_rate(
        self,
        user_id: str,
        limit: int = 10,
        window_seconds: int = 60,
    ) -> bool:
        """App1 per-user limit; returns True when the request is allowed."""
        return self._check(self._user_windows, user_id, limit, window_seconds)

    def check_store_rate(
        self,
        site_id: int,
        limit: int = 100,
        window_seconds: int = 3600,
    ) -> bool:
        """App2~8 per-store limit; returns True when the request is allowed."""
        # site_id is an int; its string form is the dict key
        store_key = str(site_id)
        return self._check(self._store_windows, store_key, limit, window_seconds)
|
||||
207
apps/backend/app/ai/run_log_service.py
Normal file
207
apps/backend/app/ai/run_log_service.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""AI 运行日志服务 — biz.ai_run_logs 表的 CRUD 操作。
|
||||
|
||||
每次 Application API 调用前创建 pending 记录,调用过程中更新状态,
|
||||
调用结束后记录结果。同时提供日/月 token 聚合查询,实现 UsageProvider 协议
|
||||
以便注入 BudgetTracker。
|
||||
|
||||
request_prompt 写入前截断为前 2000 字符,避免大 prompt 占用过多存储。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable
|
||||
|
||||
import psycopg2.extensions
|
||||
|
||||
# Maximum number of prompt characters stored per log row
_MAX_PROMPT_LENGTH = 2000


def _truncate_prompt(prompt: str | None) -> str | None:
    """Keep only the first ``_MAX_PROMPT_LENGTH`` characters of ``prompt``.

    ``None`` is passed through unchanged.
    """
    return None if prompt is None else prompt[:_MAX_PROMPT_LENGTH]
|
||||
|
||||
|
||||
class AIRunLogService:
    """CRUD access to ``biz.ai_run_logs``; also implements the UsageProvider protocol.

    The constructor takes a ``get_conn`` callable. Every operation obtains
    its own connection from it and closes that connection when the operation
    finishes, so no connection is held between calls.
    """

    def __init__(self, get_conn: Callable[[], psycopg2.extensions.connection]) -> None:
        self._get_conn = get_conn

    # ── internal helpers ──────────────────────────────────

    def _execute_write(self, sql: str, params: tuple, *, returning: bool = False):
        """Run one write statement in its own transaction, then close the connection.

        Commits on success; rolls back and re-raises on any error. When
        ``returning`` is true the first result row is fetched and returned,
        otherwise ``None`` is returned.
        """
        conn = self._get_conn()
        try:
            try:
                with conn.cursor() as cur:
                    cur.execute(sql, params)
                    row = None
                    if returning:
                        row = cur.fetchone()
                        assert row is not None, "INSERT RETURNING 应返回 id"
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            return row
        finally:
            # Release the per-operation connection on every path.
            conn.close()

    def _fetch_token_sum(self, sql: str) -> int:
        """Run a single-value aggregate query and return it as int (0 if no row)."""
        conn = self._get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute(sql)
                row = cur.fetchone()
            return int(row[0]) if row else 0
        finally:
            # Closing also ends the read transaction the SELECT opened.
            conn.close()

    # ── create ────────────────────────────────────────────

    def create_log(
        self,
        site_id: int,
        app_type: str,
        trigger_type: str,
        *,
        member_id: int | None = None,
        request_prompt: str | None = None,
        session_id: str | None = None,
    ) -> int:
        """Insert a log row with status ``pending`` and return its id.

        ``request_prompt`` is shortened by ``_truncate_prompt`` before it is
        stored.
        """
        truncated = _truncate_prompt(request_prompt)
        row = self._execute_write(
            """
            INSERT INTO biz.ai_run_logs
                (site_id, app_type, trigger_type, member_id,
                 request_prompt, session_id, status)
            VALUES (%s, %s, %s, %s, %s, %s, 'pending')
            RETURNING id
            """,
            (site_id, app_type, trigger_type, member_id, truncated, session_id),
            returning=True,
        )
        log_id: int = row[0]
        return log_id

    # ── status transitions ────────────────────────────────

    def update_running(self, log_id: int) -> None:
        """Mark the log row as ``running``."""
        self._execute_write(
            """
            UPDATE biz.ai_run_logs
            SET status = 'running'
            WHERE id = %s
            """,
            (log_id,),
        )

    def update_success(
        self,
        log_id: int,
        response_text: str,
        tokens_used: int,
        latency_ms: int,
    ) -> None:
        """Mark the row as ``success`` and store response, tokens and latency."""
        now = datetime.now(timezone.utc)
        self._execute_write(
            """
            UPDATE biz.ai_run_logs
            SET status = 'success',
                response_text = %s,
                tokens_used = %s,
                latency_ms = %s,
                finished_at = %s
            WHERE id = %s
            """,
            (response_text, tokens_used, latency_ms, now, log_id),
        )

    def update_failed(
        self,
        log_id: int,
        error_message: str,
        latency_ms: int,
    ) -> None:
        """Mark the row as ``failed`` and store the error message and latency."""
        now = datetime.now(timezone.utc)
        self._execute_write(
            """
            UPDATE biz.ai_run_logs
            SET status = 'failed',
                error_message = %s,
                latency_ms = %s,
                finished_at = %s
            WHERE id = %s
            """,
            (error_message, latency_ms, now, log_id),
        )

    def update_timeout(self, log_id: int, latency_ms: int) -> None:
        """Mark the row as ``timeout`` and store the latency."""
        now = datetime.now(timezone.utc)
        self._execute_write(
            """
            UPDATE biz.ai_run_logs
            SET status = 'timeout',
                latency_ms = %s,
                finished_at = %s
            WHERE id = %s
            """,
            (latency_ms, now, log_id),
        )

    # ── UsageProvider protocol ────────────────────────────

    def get_daily_usage(self) -> int:
        """Sum ``tokens_used`` of successful rows created today."""
        return self._fetch_token_sum(
            """
            SELECT COALESCE(SUM(tokens_used), 0)
            FROM biz.ai_run_logs
            WHERE status = 'success'
              AND created_at >= CURRENT_DATE
              AND created_at < CURRENT_DATE + INTERVAL '1 day'
            """
        )

    def get_monthly_usage(self) -> int:
        """Sum ``tokens_used`` of successful rows created this month."""
        return self._fetch_token_sum(
            """
            SELECT COALESCE(SUM(tokens_used), 0)
            FROM biz.ai_run_logs
            WHERE status = 'success'
              AND created_at >= date_trunc('month', CURRENT_DATE)
              AND created_at < date_trunc('month', CURRENT_DATE) + INTERVAL '1 month'
            """
        )
|
||||
109
apps/backend/app/ai/test_rate_limiter.py
Normal file
109
apps/backend/app/ai/test_rate_limiter.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""RateLimiter 单元测试。
|
||||
|
||||
被测代码:apps/backend/app/ai/rate_limiter.py
|
||||
纯内存测试,不涉及 DB/网络。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.ai.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class TestCheckUserRate:
    """App1: per-user, per-minute limiting."""

    def test_allows_under_limit(self):
        limiter = RateLimiter()
        outcomes = [limiter.check_user_rate("u1", limit=10) for _ in range(10)]
        assert all(outcome is True for outcome in outcomes)

    def test_rejects_at_limit(self):
        limiter = RateLimiter()
        for _ in range(10):
            limiter.check_user_rate("u1", limit=10)
        eleventh = limiter.check_user_rate("u1", limit=10)
        assert eleventh is False

    def test_different_users_independent(self):
        limiter = RateLimiter()
        for _ in range(10):
            limiter.check_user_rate("u1", limit=10)
        # u1 is exhausted; u2 must be unaffected
        assert limiter.check_user_rate("u1", limit=10) is False
        assert limiter.check_user_rate("u2", limit=10) is True

    def test_window_expiry_allows_again(self):
        """Requests older than the window no longer count."""
        limiter = RateLimiter()
        start = time.monotonic()
        clock = "app.ai.rate_limiter.time.monotonic"

        with patch(clock, return_value=start):
            for _ in range(10):
                limiter.check_user_rate("u1", limit=10, window_seconds=60)

        # 61 seconds later the window holds no requests
        with patch(clock, return_value=start + 61):
            allowed = limiter.check_user_rate("u1", limit=10, window_seconds=60)
        assert allowed is True
|
||||
|
||||
|
||||
class TestCheckStoreRate:
    """App2~8: per-store, per-hour limiting."""

    def test_allows_under_limit(self):
        limiter = RateLimiter()
        outcomes = [limiter.check_store_rate(123, limit=5) for _ in range(5)]
        assert all(outcome is True for outcome in outcomes)

    def test_rejects_at_limit(self):
        limiter = RateLimiter()
        for _ in range(5):
            limiter.check_store_rate(123, limit=5)
        sixth = limiter.check_store_rate(123, limit=5)
        assert sixth is False

    def test_different_stores_independent(self):
        limiter = RateLimiter()
        for _ in range(5):
            limiter.check_store_rate(100, limit=5)
        # store 100 is exhausted; store 200 must be unaffected
        assert limiter.check_store_rate(100, limit=5) is False
        assert limiter.check_store_rate(200, limit=5) is True

    def test_site_id_int_works(self):
        """An int site_id is accepted (it is stored under its str form)."""
        limiter = RateLimiter()
        big_site_id = 2790685415443269
        assert limiter.check_store_rate(big_site_id, limit=100) is True

    def test_window_expiry_allows_again(self):
        limiter = RateLimiter()
        start = time.monotonic()
        clock = "app.ai.rate_limiter.time.monotonic"

        with patch(clock, return_value=start):
            for _ in range(100):
                limiter.check_store_rate(123, limit=100, window_seconds=3600)

        # 3601 seconds later
        with patch(clock, return_value=start + 3601):
            allowed = limiter.check_store_rate(123, limit=100, window_seconds=3600)
        assert allowed is True
|
||||
|
||||
|
||||
class TestRejectedRequestNotRecorded:
    """A rejected request must not add a timestamp (it consumes no quota)."""

    def test_rejected_user_request_not_counted(self):
        limiter = RateLimiter()
        # 3 accepted calls followed by 2 rejected ones
        for _ in range(3 + 2):
            limiter.check_user_rate("u1", limit=3)
        recorded = limiter._user_windows["u1"]
        assert len(recorded) == 3

    def test_rejected_store_request_not_counted(self):
        limiter = RateLimiter()
        # 3 accepted calls followed by 2 rejected ones
        for _ in range(3 + 2):
            limiter.check_store_rate(1, limit=3)
        recorded = limiter._store_windows["1"]
        assert len(recorded) == 3
|
||||
Reference in New Issue
Block a user