This commit is contained in:
Neo
2026-03-15 10:15:02 +08:00
parent 2dd217522c
commit 72bb11b34f
916 changed files with 65306 additions and 16102803 deletions

View File

@@ -0,0 +1,222 @@
"""应用 1通用对话SSE 流式)。
每次进入 chat 页面新建 ai_conversations 记录(不复用),
首条消息注入页面上下文,流式返回 AI 回复。
app_id = "app1_chat"
"""
from __future__ import annotations
import json
import logging
from typing import AsyncGenerator
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import SSEEvent
logger = logging.getLogger(__name__)
APP_ID = "app1_chat"
async def chat_stream(
*,
message: str,
user_id: int | str,
nickname: str,
role: str,
site_id: int,
source_page: str | None = None,
page_context: dict | None = None,
screen_content: str | None = None,
bailian: BailianClient,
conv_svc: ConversationService,
) -> AsyncGenerator[SSEEvent, None]:
"""流式对话入口,返回 SSEEvent 异步生成器。
流程:
1. 创建 conversation 记录
2. 写入 user message
3. 构建 system prompt注入页面上下文
4. 调用 bailian.chat_stream 流式获取回复
5. 逐 chunk yield SSEEvent(type="chunk")
6. 完成后写入 assistant messageyield SSEEvent(type="done")
7. 异常时 yield SSEEvent(type="error")
"""
conversation_id: int | None = None
try:
# 1. 每次新建 conversation不复用
source_ctx = _build_source_context(
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
)
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_page=source_page,
source_context=source_ctx,
)
logger.info(
"App1 新建对话: conversation_id=%s user_id=%s site_id=%s",
conversation_id, user_id, site_id,
)
# 2. 立即写入 user message
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=message,
)
# 3. 构建消息列表system prompt + user message
messages = _build_messages(
message=message,
user_id=user_id,
nickname=nickname,
role=role,
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
)
# 4-5. 流式调用百炼,逐 chunk yield
full_reply_parts: list[str] = []
async for chunk in bailian.chat_stream(messages):
full_reply_parts.append(chunk)
yield SSEEvent(type="chunk", content=chunk)
# 6. 流式完成,拼接完整回复并写入 assistant message
full_reply = "".join(full_reply_parts)
# 百炼流式模式不返回 tokens_used按字符数估算粗略
estimated_tokens = len(full_reply)
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=full_reply,
tokens_used=estimated_tokens,
)
yield SSEEvent(
type="done",
conversation_id=conversation_id,
tokens_used=estimated_tokens,
)
except Exception as e:
logger.error(
"App1 对话异常: conversation_id=%s error=%s",
conversation_id, e,
exc_info=True,
)
yield SSEEvent(type="error", message=str(e))
def _build_messages(
*,
message: str,
user_id: int | str,
nickname: str,
role: str,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
) -> list[dict]:
"""构建发送给百炼的消息列表。
首条 system 消息注入页面上下文和用户信息。
"""
system_content = _build_system_prompt(
user_id=user_id,
nickname=nickname,
role=role,
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": message},
]
def _build_system_prompt(
*,
user_id: int | str,
nickname: str,
role: str,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
) -> dict:
"""构建 system prompt JSON。
通过 biz_params.user_prompt_params 传入用户信息,
注入页面上下文供 AI 理解当前场景。
"""
prompt: dict = {
"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。",
"biz_params": {
"user_prompt_params": {
"User_ID": str(user_id),
"Role": role,
"Nickname": nickname,
},
},
}
# 注入页面上下文(首条消息)
page_ctx = _build_page_context(
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
)
if page_ctx:
prompt["page_context"] = page_ctx
return prompt
def _build_page_context(
*,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
) -> dict:
"""构建页面上下文信息。
P5-A 阶段:直接透传前端传入的上下文字段。
P5-B 阶段:各页面逐步实现文本化工具,丰富 screen_content。
"""
# TODO: P5-B 各页面文本化工具细化
ctx: dict = {}
if source_page:
ctx["source_page"] = source_page
if page_context:
ctx["page_context"] = page_context
if screen_content:
ctx["screen_content"] = screen_content
return ctx
def _build_source_context(
*,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
) -> dict | None:
"""构建存入 ai_conversations.source_context 的 JSON。"""
ctx: dict = {}
if source_page:
ctx["source_page"] = source_page
if page_context:
ctx["page_context"] = page_context
if screen_content:
ctx["screen_content"] = screen_content
return ctx if ctx else None

View File

@@ -0,0 +1,210 @@
"""应用 2财务洞察。
8 个时间维度独立调用,每次调用结果写入 ai_cache
同时创建 ai_conversations + ai_messages 记录。
营业日分界点:每日 08:00BUSINESS_DAY_START_HOUR 环境变量,默认 8
app_id = "app2_finance"
"""
from __future__ import annotations
import json
import logging
import os
from datetime import date, datetime, timedelta
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.prompts.app2_finance_prompt import build_prompt
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app2_finance"
# 8 个时间维度编码
TIME_DIMENSIONS = (
"this_month",
"last_month",
"this_week",
"last_week",
"last_3_months",
"this_quarter",
"last_quarter",
"last_6_months",
)
def get_business_date() -> date:
"""根据营业日分界点计算当前营业日。
分界点前(如 07:59视为前一天营业日
分界点及之后(如 08:00视为当天营业日。
"""
hour = int(os.environ.get("BUSINESS_DAY_START_HOUR", "8"))
now = datetime.now()
if now.hour < hour:
return (now - timedelta(days=1)).date()
return now.date()
def compute_time_range(dimension: str, business_date: date) -> tuple[date, date]:
"""计算时间维度对应的日期范围 [start, end](闭区间)。
Args:
dimension: 时间维度编码
business_date: 当前营业日
Returns:
(start_date, end_date) 元组
"""
y, m, d = business_date.year, business_date.month, business_date.day
if dimension == "this_month":
start = date(y, m, 1)
return start, business_date
if dimension == "last_month":
prev = _month_offset(y, m, -1)
start = date(prev[0], prev[1], 1)
end = date(y, m, 1) - timedelta(days=1)
return start, end
if dimension == "this_week":
# 周一起算
weekday = business_date.weekday() # 0=周一
start = business_date - timedelta(days=weekday)
return start, business_date
if dimension == "last_week":
weekday = business_date.weekday()
this_monday = business_date - timedelta(days=weekday)
last_monday = this_monday - timedelta(days=7)
last_sunday = this_monday - timedelta(days=1)
return last_monday, last_sunday
if dimension == "last_3_months":
# 当前月 - 3 ~ 当前月 - 1
end_ym = _month_offset(y, m, -1)
start_ym = _month_offset(y, m, -3)
start = date(start_ym[0], start_ym[1], 1)
# end = 上月最后一天
end = date(y, m, 1) - timedelta(days=1)
return start, end
if dimension == "this_quarter":
q_start_month = ((m - 1) // 3) * 3 + 1
start = date(y, q_start_month, 1)
return start, business_date
if dimension == "last_quarter":
q_start_month = ((m - 1) // 3) * 3 + 1
# 上季度结束 = 本季度第一天 - 1
this_q_start = date(y, q_start_month, 1)
end = this_q_start - timedelta(days=1)
# 上季度开始
ly, lm = end.year, end.month
lq_start_month = ((lm - 1) // 3) * 3 + 1
start = date(ly, lq_start_month, 1)
return start, end
if dimension == "last_6_months":
# 当前月 - 6 ~ 当前月 - 1
end_ym = _month_offset(y, m, -1)
start_ym = _month_offset(y, m, -6)
start = date(start_ym[0], start_ym[1], 1)
end = date(y, m, 1) - timedelta(days=1)
return start, end
raise ValueError(f"未知时间维度: {dimension}")
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App2 财务洞察调用。
Args:
context: 包含 site_id, time_dimension, user_id(默认'system'), nickname(默认'')
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONinsights 数组)
"""
site_id = context["site_id"]
time_dimension = context["time_dimension"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 构建 Prompt
prompt_context = {
"site_id": site_id,
"time_dimension": time_dimension,
"current_data": context.get("current_data", {}),
"previous_data": context.get("previous_data", {}),
}
messages = build_prompt(prompt_context)
# 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"time_dimension": time_dimension},
)
# 写入 system prompt 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
# 写入 user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 写入缓存
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP2_FINANCE.value,
site_id=site_id,
target_id=time_dimension,
result_json=result,
triggered_by=f"user:{user_id}",
)
logger.info(
"App2 财务洞察完成: site_id=%s dimension=%s conversation_id=%s tokens=%d",
site_id, time_dimension, conversation_id, tokens_used,
)
return result
def _month_offset(year: int, month: int, offset: int) -> tuple[int, int]:
"""计算月份偏移,返回 (year, month)。"""
# 转为 0-based 计算
total = (year * 12 + (month - 1)) + offset
return total // 12, total % 12 + 1

View File

@@ -0,0 +1,213 @@
"""应用 3客户数据维客线索分析骨架
客户新增消费时自动触发,通过 AI 分析客户数据提取维客线索。
线索 category 限定 3 个枚举值:客户基础、消费习惯、玩法偏好。
线索提供者统一标记为"系统"
使用 items_sum 口径(= table_charge_money + goods_money
+ assistant_pd_money + assistant_cx_money + electricity_money
禁止使用 consume_money。
app_id = "app3_clue"
"""
from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app3_clue"
def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段。
P5-B 阶段P9-T1补充 consumption_records 等完整数据。
Args:
context: 包含 site_id, member_id, nickname 等
cache_svc: 缓存服务,用于获取 reference 历史数据
Returns:
消息列表 [{"role": "system", "content": ...}, {"role": "user", ...}]
"""
site_id = context["site_id"]
member_id = context["member_id"]
# 构建 referenceApp6 线索 + 最近 2 套 App8 历史(附 generated_at
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
"task": "分析客户消费数据,提取维客线索。",
"app_id": APP_ID,
"rules": {
"category_enum": ["客户基础", "消费习惯", "玩法偏好"],
"providers": "系统",
"amount_caliber": "items_sum = table_charge_money + goods_money + assistant_pd_money + assistant_cx_money + electricity_money",
"禁止使用": "consume_money",
},
"output_format": {
"clues": [
{
"category": "枚举值(客户基础/消费习惯/玩法偏好)",
"summary": "一句话摘要",
"detail": "详细说明",
"emoji": "表情符号",
}
]
},
# TODO: P9-T1 细化 - consumption_records 等客户消费数据
"data": {
"consumption_records": "待 P9-T1 补充",
"member_info": "待 P9-T1 补充",
},
"reference": reference,
}
user_content = (
f"请分析会员 {member_id} 的消费数据,提取维客线索。"
"每条线索包含 category、summary、detail、emoji 四个字段。"
"category 必须是:客户基础、消费习惯、玩法偏好 之一。"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]
def _build_reference(
site_id: int,
member_id: int,
cache_svc: AICacheService | None,
) -> dict:
"""构建 Prompt reference 字段。
包含:
- App6 备注分析线索(最新一条,如有)
- 最近 2 套 App8 维客线索整理历史(附 generated_at
缓存不存在时返回空对象 {}
"""
if cache_svc is None:
return {}
reference: dict = {}
target_id = str(member_id)
# App6 备注分析线索
app6_latest = cache_svc.get_latest(
CacheTypeEnum.APP6_NOTE_ANALYSIS.value, site_id, target_id,
)
if app6_latest:
reference["app6_note_clues"] = {
"result_json": app6_latest.get("result_json"),
"generated_at": app6_latest.get("created_at"),
}
# 最近 2 套 App8 历史
app8_history = cache_svc.get_history(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id, limit=2,
)
if app8_history:
reference["app8_history"] = [
{
"result_json": h.get("result_json"),
"generated_at": h.get("created_at"),
}
for h in app8_history
]
return reference
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App3 客户数据维客线索分析。
流程:
1. build_prompt 构建 Prompt
2. bailian.chat_json 调用百炼
3. 写入 conversation + messages
4. 写入 ai_cache
5. 返回结果
Args:
context: site_id, member_id, user_id(默认'system'), nickname(默认'')
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONclues 数组)
"""
site_id = context["site_id"]
member_id = context["member_id"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"member_id": member_id},
)
# 写入 system + user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 3. 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 4. 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 5. 写入缓存
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP3_CLUE.value,
site_id=site_id,
target_id=str(member_id),
result_json=result,
triggered_by=f"user:{user_id}",
)
logger.info(
"App3 线索分析完成: site_id=%s member_id=%s conversation_id=%s tokens=%d",
site_id, member_id, conversation_id, tokens_used,
)
return result

View File

@@ -0,0 +1,200 @@
"""应用 4关系分析/任务建议(骨架)。
助教参与新结算或被分配召回任务时自动触发,
生成关系分析和任务建议。
Prompt reference 包含 App8 最新 + 最近 2 套历史(附 generated_at
缓存不存在时 reference 传空对象,标注"暂无历史线索"
app_id = "app4_analysis"
"""
from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app4_analysis"
def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段。
P5-B 阶段P6-T4补充 service_history、assistant_info 等完整数据。
Args:
context: 包含 site_id, assistant_id, member_id
cache_svc: 缓存服务,用于获取 reference 历史数据
Returns:
消息列表
"""
site_id = context["site_id"]
assistant_id = context["assistant_id"]
member_id = context["member_id"]
# 构建 referenceApp8 最新 + 最近 2 套历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
"task": "分析助教与客户的关系,生成任务建议。",
"app_id": APP_ID,
"output_format": {
"task_description": "任务描述文本",
"action_suggestions": ["建议1", "建议2"],
"one_line_summary": "一句话总结",
},
# TODO: P6-T4 细化 - service_history、assistant_info
"data": {
"service_history": "待 P6-T4 补充",
"assistant_info": "待 P6-T4 补充",
},
"reference": reference,
}
# 缓存不存在时在 user prompt 中标注
no_history_hint = ""
if not reference:
no_history_hint = "(暂无历史线索,请基于现有信息分析)"
user_content = (
f"请分析助教 {assistant_id} 与会员 {member_id} 的关系,"
f"生成任务建议。{no_history_hint}"
"返回 task_description、action_suggestions、one_line_summary 三个字段。"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]
def _build_reference(
site_id: int,
member_id: int,
cache_svc: AICacheService | None,
) -> dict:
"""构建 Prompt reference 字段。
包含:
- App8 最新维客线索(如有)
- 最近 2 套 App8 历史(附 generated_at
缓存不存在时返回空对象 {}
"""
if cache_svc is None:
return {}
reference: dict = {}
target_id = str(member_id)
# App8 最新
app8_latest = cache_svc.get_latest(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id,
)
if app8_latest:
reference["app8_latest"] = {
"result_json": app8_latest.get("result_json"),
"generated_at": app8_latest.get("created_at"),
}
# 最近 2 套 App8 历史
app8_history = cache_svc.get_history(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id, limit=2,
)
if app8_history:
reference["app8_history"] = [
{
"result_json": h.get("result_json"),
"generated_at": h.get("created_at"),
}
for h in app8_history
]
return reference
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App4 关系分析。
Args:
context: site_id, assistant_id, member_id
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONtask_description, action_suggestions, one_line_summary
"""
site_id = context["site_id"]
assistant_id = context["assistant_id"]
member_id = context["member_id"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"assistant_id": assistant_id, "member_id": member_id},
)
# 写入 system + user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 3. 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 4. 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 5. 写入缓存target_id = {assistant_id}_{member_id}
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP4_ANALYSIS.value,
site_id=site_id,
target_id=f"{assistant_id}_{member_id}",
result_json=result,
triggered_by=f"user:{user_id}",
)
logger.info(
"App4 关系分析完成: site_id=%s assistant=%s member=%s conversation_id=%s tokens=%d",
site_id, assistant_id, member_id, conversation_id, tokens_used,
)
return result

View File

@@ -0,0 +1,182 @@
"""应用 5话术参考骨架
App4 完成后自动联动触发,接收 App4 完整返回结果
作为 Prompt 中的 task_suggestion 字段。
Prompt reference 包含最近 2 套 App8 历史(附 generated_at
app_id = "app5_tactics"
"""
from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app5_tactics"
def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段。
P5-B 阶段P6-T4补充 service_history、assistant_info随 App4 同步)。
Args:
context: 包含 site_id, assistant_id, member_id, app4_result(dict)
cache_svc: 缓存服务,用于获取 reference 历史数据
Returns:
消息列表
"""
site_id = context["site_id"]
assistant_id = context["assistant_id"]
member_id = context["member_id"]
app4_result = context.get("app4_result", {})
# 构建 reference最近 2 套 App8 历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
"task": "基于关系分析和任务建议,生成沟通话术参考。",
"app_id": APP_ID,
"task_suggestion": app4_result,
"output_format": {
"tactics": [
{"scenario": "场景描述", "script": "话术内容"}
]
},
# TODO: P6-T4 细化 - service_history、assistant_info随 App4 同步)
"data": {
"service_history": "待 P6-T4 补充",
"assistant_info": "待 P6-T4 补充",
},
"reference": reference,
}
user_content = (
f"请为助教 {assistant_id} 生成与会员 {member_id} 沟通的话术参考。"
"返回 tactics 数组,每条包含 scenario 和 script 字段。"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]
def _build_reference(
site_id: int,
member_id: int,
cache_svc: AICacheService | None,
) -> dict:
"""构建 Prompt reference 字段。
包含最近 2 套 App8 历史(附 generated_at
缓存不存在时返回空对象 {}
"""
if cache_svc is None:
return {}
reference: dict = {}
target_id = str(member_id)
# 最近 2 套 App8 历史
app8_history = cache_svc.get_history(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id, limit=2,
)
if app8_history:
reference["app8_history"] = [
{
"result_json": h.get("result_json"),
"generated_at": h.get("created_at"),
}
for h in app8_history
]
return reference
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App5 话术参考。
Args:
context: site_id, assistant_id, member_id, app4_result(dict)
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONtactics 数组)
"""
site_id = context["site_id"]
assistant_id = context["assistant_id"]
member_id = context["member_id"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"assistant_id": assistant_id, "member_id": member_id},
)
# 写入 system + user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 3. 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 4. 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 5. 写入缓存target_id = {assistant_id}_{member_id}
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP5_TACTICS.value,
site_id=site_id,
target_id=f"{assistant_id}_{member_id}",
result_json=result,
triggered_by=f"user:{user_id}",
)
logger.info(
"App5 话术参考完成: site_id=%s assistant=%s member=%s conversation_id=%s tokens=%d",
site_id, assistant_id, member_id, conversation_id, tokens_used,
)
return result

View File

@@ -0,0 +1,217 @@
"""应用 6备注分析骨架
助教提交备注后自动触发,通过 AI 分析备注内容,
提取维客线索并评分。
返回 score1-10+ clues 数组。
评分规则6 分为标准分,重复/低价值/时效性低酌情扣分,高价值信息酌情加分。
线索 category 限定 6 个枚举值。
线索提供者标记为当前备注提供人context.noted_by_name
app_id = "app6_note"
"""
from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app6_note"
def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段。
P5-B 阶段P9-T1补充 consumption_data 等完整数据。
Args:
context: 包含 site_id, member_id, note_content, noted_by_name
cache_svc: 缓存服务,用于获取 reference 历史数据
Returns:
消息列表
"""
site_id = context["site_id"]
member_id = context["member_id"]
note_content = context.get("note_content", "")
noted_by_name = context.get("noted_by_name", "")
# 构建 referenceApp3 线索 + 最近 2 套 App8 历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
"task": "分析备注内容,提取维客线索并评分。",
"app_id": APP_ID,
"rules": {
"category_enum": [
"客户基础", "消费习惯", "玩法偏好",
"促销偏好", "社交关系", "重要反馈",
],
"providers": noted_by_name,
"scoring": "6 分为标准分,重复/低价值/时效性低酌情扣分,高价值信息酌情加分",
"score_range": "1-10",
},
"output_format": {
"score": "1-10 整数",
"clues": [
{
"category": "枚举值6 选 1",
"summary": "一句话摘要",
"detail": "详细说明",
"emoji": "表情符号",
}
],
},
"note_content": note_content,
"noted_by_name": noted_by_name,
# TODO: P9-T1 细化 - consumption_data 等客户消费数据
"data": {
"consumption_data": "待 P9-T1 补充",
},
"reference": reference,
}
user_content = (
f"请分析以下备注内容,提取维客线索并评分。\n"
f"备注提供人:{noted_by_name}\n"
f"备注内容:{note_content}\n"
"返回 score1-10 整数)和 clues 数组。"
"category 必须是:客户基础、消费习惯、玩法偏好、促销偏好、社交关系、重要反馈 之一。"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]
def _build_reference(
site_id: int,
member_id: int,
cache_svc: AICacheService | None,
) -> dict:
"""构建 Prompt reference 字段。
包含:
- App3 客户数据线索(最新一条,如有)
- 最近 2 套 App8 维客线索整理历史(附 generated_at
缓存不存在时返回空对象 {}
"""
if cache_svc is None:
return {}
reference: dict = {}
target_id = str(member_id)
# App3 客户数据线索
app3_latest = cache_svc.get_latest(
CacheTypeEnum.APP3_CLUE.value, site_id, target_id,
)
if app3_latest:
reference["app3_clues"] = {
"result_json": app3_latest.get("result_json"),
"generated_at": app3_latest.get("created_at"),
}
# 最近 2 套 App8 历史
app8_history = cache_svc.get_history(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id, limit=2,
)
if app8_history:
reference["app8_history"] = [
{
"result_json": h.get("result_json"),
"generated_at": h.get("created_at"),
}
for h in app8_history
]
return reference
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App6 备注分析。
Args:
context: site_id, member_id, note_content, noted_by_name
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONscore + clues 数组)
"""
site_id = context["site_id"]
member_id = context["member_id"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"member_id": member_id},
)
# 写入 system + user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 3. 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 4. 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 5. 写入缓存score 存入 ai_cache.score
score = result.get("score")
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP6_NOTE_ANALYSIS.value,
site_id=site_id,
target_id=str(member_id),
result_json=result,
triggered_by=f"user:{user_id}",
score=score,
)
logger.info(
"App6 备注分析完成: site_id=%s member_id=%s score=%s conversation_id=%s tokens=%d",
site_id, member_id, score, conversation_id, tokens_used,
)
return result

View File

@@ -0,0 +1,200 @@
"""应用 7客户分析骨架
消费事件链中 App8 完成后串行触发,生成客户全量分析与运营建议。
使用 items_sum 口径(= table_charge_money + goods_money
+ assistant_pd_money + assistant_cx_money + electricity_money
禁止使用 consume_money。
对主观信息来自备注标注【来源XXX请甄别信息真实性】。
app_id = "app7_customer"
"""
from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app7_customer"
def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段。
P5-B 阶段P9-T1补充 objective_data 等完整数据。
Args:
context: 包含 site_id, member_id
cache_svc: 缓存服务,用于获取 reference 历史数据
Returns:
消息列表
"""
site_id = context["site_id"]
member_id = context["member_id"]
# 构建 reference最新 + 最近 2 套 App8 历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
"task": "综合分析客户数据,生成运营策略建议。",
"app_id": APP_ID,
"rules": {
"amount_caliber": "items_sum = table_charge_money + goods_money + assistant_pd_money + assistant_cx_money + electricity_money",
"禁止使用": "consume_money",
"subjective_info_label": "对主观信息来自备注标注【来源XXX请甄别信息真实性】",
},
"output_format": {
"strategies": [
{"title": "策略标题", "content": "策略内容"}
],
"summary": "一句话总结",
},
# TODO: P9-T1 细化 - objective_data 等客户消费数据
"data": {
"objective_data": "待 P9-T1 补充",
},
"reference": reference,
}
user_content = (
f"请综合分析会员 {member_id} 的客户数据,生成运营策略建议。"
"返回 strategies 数组(每条含 title 和 content和 summary 字段。"
"对来自备注的主观信息请标注【来源XXX请甄别信息真实性】。"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]
def _build_reference(
site_id: int,
member_id: int,
cache_svc: AICacheService | None,
) -> dict:
"""构建 Prompt reference 字段。
包含:
- App8 最新维客线索(如有)
- 最近 2 套 App8 历史(附 generated_at
缓存不存在时返回空对象 {}
"""
if cache_svc is None:
return {}
reference: dict = {}
target_id = str(member_id)
# App8 最新
app8_latest = cache_svc.get_latest(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id,
)
if app8_latest:
reference["app8_latest"] = {
"result_json": app8_latest.get("result_json"),
"generated_at": app8_latest.get("created_at"),
}
# 最近 2 套 App8 历史
app8_history = cache_svc.get_history(
CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value, site_id, target_id, limit=2,
)
if app8_history:
reference["app8_history"] = [
{
"result_json": h.get("result_json"),
"generated_at": h.get("created_at"),
}
for h in app8_history
]
return reference
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App7 客户分析。
Args:
context: site_id, member_id
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONstrategies 数组 + summary
"""
site_id = context["site_id"]
member_id = context["member_id"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"member_id": member_id},
)
# 写入 system + user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 3. 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 4. 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 5. 写入缓存
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP7_CUSTOMER_ANALYSIS.value,
site_id=site_id,
target_id=str(member_id),
result_json=result,
triggered_by=f"user:{user_id}",
)
logger.info(
"App7 客户分析完成: site_id=%s member_id=%s conversation_id=%s tokens=%d",
site_id, member_id, conversation_id, tokens_used,
)
return result

View File

@@ -0,0 +1,211 @@
"""应用 8维客线索整理。
接收 App3消费分析和 App6备注分析的线索
通过百炼 AI 整合去重,然后全量替换写入 member_retention_clue 表。
app_id = "app8_consolidation"
"""
from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.prompts.app8_consolidation_prompt import build_prompt
from app.ai.schemas import CacheTypeEnum
from app.database import get_connection
logger = logging.getLogger(__name__)
APP_ID = "app8_consolidation"
class ClueWriter:
"""维客线索全量替换写入器。
DELETE source IN ('ai_consumption', 'ai_note') → INSERT 新线索(事务)。
人工线索source='manual')不受影响。
"""
def replace_ai_clues(
self,
member_id: int,
site_id: int,
clues: list[dict],
) -> int:
"""全量替换该客户的 AI 来源线索,返回写入数量。
在单个事务中执行 DELETE + INSERT失败时回滚保留原有线索。
字段映射:
- category → category
- emoji + " " + summary → summary"📅 偏好周末下午时段消费"
- detail → detail
- providers → recorded_by_name
- source: 根据 providers 判断(见 _determine_source
- recorded_by_assistant_id: NULL系统触发
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 1. 删除该客户所有 AI 来源线索
cur.execute(
"""
DELETE FROM member_retention_clue
WHERE member_id = %s AND site_id = %s
AND source IN ('ai_consumption', 'ai_note')
""",
(member_id, site_id),
)
# 2. 插入新线索
for clue in clues:
emoji = clue.get("emoji", "")
raw_summary = clue.get("summary", "")
summary = f"{emoji} {raw_summary}" if emoji else raw_summary
source = _determine_source(clue.get("providers", ""))
cur.execute(
"""
INSERT INTO member_retention_clue
(member_id, site_id, category, summary, detail,
source, recorded_by_name, recorded_by_assistant_id)
VALUES (%s, %s, %s, %s, %s, %s, %s, NULL)
""",
(
member_id,
site_id,
clue.get("category", ""),
summary,
clue.get("detail", ""),
source,
clue.get("providers", ""),
),
)
conn.commit()
return len(clues)
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _determine_source(providers: str) -> str:
"""根据 providers 判断 source 值。
- 纯 App3providers 仅含"系统")→ ai_consumption
- 纯 App6providers 不含"系统")→ ai_note
- 混合来源 → ai_consumption
"""
if not providers:
return "ai_consumption"
provider_list = [p.strip() for p in providers.split(",")]
has_system = "系统" in provider_list
has_human = any(p != "系统" for p in provider_list if p)
if has_system and not has_human:
# 纯 App3系统自动分析
return "ai_consumption"
elif has_human and not has_system:
# 纯 App6人工备注分析
return "ai_note"
else:
# 混合来源
return "ai_consumption"
async def run(
context: dict,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
"""执行 App8 维客线索整理。
流程:
1. build_prompt 构建 Prompt
2. bailian.chat_json 调用百炼
3. 写入 conversation + messages
4. 写入 ai_cache
5. ClueWriter 全量替换 member_retention_clue
6. 返回结果
Args:
context: site_id, member_id, app3_clues, app6_clues,
app3_generated_at, app6_generated_at
bailian: 百炼客户端
cache_svc: 缓存服务
conv_svc: 对话服务
Returns:
百炼返回的结构化 JSONclues 数组)
"""
site_id = context["site_id"]
member_id = context["member_id"]
user_id = context.get("user_id", "system")
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(
user_id=user_id,
nickname=nickname,
app_id=APP_ID,
site_id=site_id,
source_context={"member_id": member_id},
)
# 写入 system + user 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="system",
content=messages[0]["content"],
)
conv_svc.add_message(
conversation_id=conversation_id,
role="user",
content=messages[1]["content"],
)
# 3. 调用百炼 API
result, tokens_used = await bailian.chat_json(messages)
# 4. 写入 assistant 消息
conv_svc.add_message(
conversation_id=conversation_id,
role="assistant",
content=json.dumps(result, ensure_ascii=False),
tokens_used=tokens_used,
)
# 5. 写入缓存
cache_svc.write_cache(
cache_type=CacheTypeEnum.APP8_CLUE_CONSOLIDATED.value,
site_id=site_id,
target_id=str(member_id),
result_json=result,
triggered_by=f"user:{user_id}",
)
# 6. 全量替换 member_retention_clue
clues = result.get("clues", [])
if clues:
writer = ClueWriter()
written = writer.replace_ai_clues(member_id, site_id, clues)
logger.info(
"App8 线索写入完成: site_id=%s member_id=%s written=%d",
site_id, member_id, written,
)
logger.info(
"App8 线索整理完成: site_id=%s member_id=%s conversation_id=%s tokens=%d",
site_id, member_id, conversation_id, tokens_used,
)
return result

View File

@@ -0,0 +1,338 @@
"""AI 事件调度与调用链编排器。
根据业务事件(消费、备注、任务分配)编排 AI 应用调用链,
确保执行顺序和数据依赖正确。
调用链:
- 消费事件无助教App3 → App8 → App7
- 消费事件有助教App3 → App8 → App7 + App4 → App5
- 备注事件App6 → App8
- 任务分配事件App4 → App5读已有 App8 缓存)
容错策略:
- 某步失败记录错误日志,后续应用使用已有缓存继续
- 失败应用写入失败 conversation 记录
- 整条链后台异步执行,不阻塞业务请求
"""
from __future__ import annotations
import json
import logging
from typing import Any, Callable, Coroutine
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
class AIDispatcher:
"""AI 应用调用链编排器。"""
def __init__(
self,
bailian: BailianClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> None:
self.bailian = bailian
self.cache_svc = cache_svc
self.conv_svc = conv_svc
async def handle_consumption_event(
self,
member_id: int,
site_id: int,
settle_id: int,
assistant_id: int | None = None,
) -> None:
"""消费事件链App3 → App8 → App7+ App4 → App5 如有助教)。"""
from app.ai.apps.app3_clue import run as app3_run
from app.ai.apps.app4_analysis import run as app4_run
from app.ai.apps.app5_tactics import run as app5_run
from app.ai.apps.app7_customer import run as app7_run
from app.ai.apps.app8_consolidation import run as app8_run
context: dict[str, Any] = {
"member_id": member_id,
"site_id": site_id,
"settle_id": settle_id,
"user_id": "system",
"nickname": "",
}
# 步骤 1App3 线索分析
app3_result = await self._run_step("app3_clue", app3_run, context)
# 步骤 2App8 线索整理(需要 App3 的 clues
app8_context = {**context}
# 从 App3 结果提取 clues同时从缓存获取 App6 已有线索
if app3_result:
app8_context["app3_clues"] = app3_result.get("clues", [])
app8_context["app3_generated_at"] = None # 刚生成,无需时间戳
else:
app8_context["app3_clues"] = []
app8_context["app3_generated_at"] = None
# 从缓存获取 App6 已有线索
app6_cache = self.cache_svc.get_latest(
CacheTypeEnum.APP6_NOTE_ANALYSIS.value, site_id, str(member_id),
)
if app6_cache:
app6_result_json = app6_cache.get("result_json", {})
if isinstance(app6_result_json, str):
try:
app6_result_json = json.loads(app6_result_json)
except (json.JSONDecodeError, TypeError):
app6_result_json = {}
app8_context["app6_clues"] = app6_result_json.get("clues", [])
app8_context["app6_generated_at"] = app6_cache.get("created_at")
else:
app8_context["app6_clues"] = []
app8_context["app6_generated_at"] = None
await self._run_step("app8_consolidation", app8_run, app8_context)
# 步骤 3App7 客户分析
await self._run_step("app7_customer", app7_run, context)
# 步骤 4可选如有助教App4 → App5
if assistant_id is not None:
app4_context = {**context, "assistant_id": assistant_id}
app4_result = await self._run_step("app4_analysis", app4_run, app4_context)
app5_context = {
**context,
"assistant_id": assistant_id,
"app4_result": app4_result or {},
}
await self._run_step("app5_tactics", app5_run, app5_context)
async def handle_note_event(
self,
member_id: int,
site_id: int,
note_id: int,
note_content: str,
noted_by_name: str,
) -> None:
"""备注事件链App6 → App8。"""
from app.ai.apps.app6_note import run as app6_run
from app.ai.apps.app8_consolidation import run as app8_run
context: dict[str, Any] = {
"member_id": member_id,
"site_id": site_id,
"note_id": note_id,
"note_content": note_content,
"noted_by_name": noted_by_name,
"user_id": "system",
"nickname": "",
}
# 步骤 1App6 备注分析
app6_result = await self._run_step("app6_note", app6_run, context)
# 步骤 2App8 线索整理(需要 App6 的 clues
app8_context: dict[str, Any] = {
"member_id": member_id,
"site_id": site_id,
"user_id": "system",
"nickname": "",
}
if app6_result:
app8_context["app6_clues"] = app6_result.get("clues", [])
app8_context["app6_generated_at"] = None
else:
app8_context["app6_clues"] = []
app8_context["app6_generated_at"] = None
# 从缓存获取 App3 已有线索
app3_cache = self.cache_svc.get_latest(
CacheTypeEnum.APP3_CLUE.value, site_id, str(member_id),
)
if app3_cache:
app3_result_json = app3_cache.get("result_json", {})
if isinstance(app3_result_json, str):
try:
app3_result_json = json.loads(app3_result_json)
except (json.JSONDecodeError, TypeError):
app3_result_json = {}
app8_context["app3_clues"] = app3_result_json.get("clues", [])
app8_context["app3_generated_at"] = app3_cache.get("created_at")
else:
app8_context["app3_clues"] = []
app8_context["app3_generated_at"] = None
await self._run_step("app8_consolidation", app8_run, app8_context)
async def handle_task_assign_event(
self,
assistant_id: int,
member_id: int,
site_id: int,
task_type: str,
) -> None:
"""任务分配事件链App4 → App5读已有 App8 缓存)。"""
from app.ai.apps.app4_analysis import run as app4_run
from app.ai.apps.app5_tactics import run as app5_run
context: dict[str, Any] = {
"assistant_id": assistant_id,
"member_id": member_id,
"site_id": site_id,
"task_type": task_type,
"user_id": "system",
"nickname": "",
}
# 步骤 1App4 关系分析
app4_result = await self._run_step("app4_analysis", app4_run, context)
# 步骤 2App5 话术参考
app5_context = {
**context,
"app4_result": app4_result or {},
}
await self._run_step("app5_tactics", app5_run, app5_context)
async def _run_chain(
self,
chain: list[tuple[str, Callable[..., Coroutine], dict]],
) -> None:
"""串行执行调用链,某步失败记录日志后继续。
Args:
chain: [(app_name, run_func, context), ...] 的列表
"""
for app_name, run_func, ctx in chain:
await self._run_step(app_name, run_func, ctx)
async def _run_step(
self,
app_name: str,
run_func: Callable[..., Coroutine],
context: dict,
) -> dict | None:
"""执行单个应用步骤,失败时记录日志并写入失败 conversation。
Returns:
应用返回结果,失败时返回 None
"""
try:
result = await run_func(
context,
self.bailian,
self.cache_svc,
self.conv_svc,
)
logger.info("调用链步骤成功: %s", app_name)
return result
except Exception:
logger.exception("调用链步骤失败: %s", app_name)
# 写入失败 conversation 记录
try:
site_id = context.get("site_id", 0)
conv_id = self.conv_svc.create_conversation(
user_id="system",
nickname="",
app_id=app_name,
site_id=site_id,
source_context={"error": True, "chain_step": app_name},
)
self.conv_svc.add_message(
conversation_id=conv_id,
role="system",
content=f"调用链步骤 {app_name} 执行失败",
)
except Exception:
logger.exception("写入失败 conversation 记录也失败: %s", app_name)
return None
def _create_ai_event_handlers(dispatcher: AIDispatcher) -> dict[str, Callable]:
"""创建 AI 事件处理器,用于注册到 trigger_scheduler。
每个处理器从 payload 提取参数,通过 asyncio.create_task 后台执行,
不阻塞同步的 fire_event 调用。
Returns:
{event_job_type: handler_func} 映射
"""
import asyncio
def _get_or_create_loop() -> asyncio.AbstractEventLoop:
"""获取当前事件循环,兼容同步调用场景。"""
try:
return asyncio.get_running_loop()
except RuntimeError:
return asyncio.new_event_loop()
def handle_consumption_settled(payload: dict | None = None, **_kw: Any) -> None:
"""消费结算事件处理器(同步入口,内部异步执行)。"""
if not payload:
logger.warning("consumption_settled 事件缺少 payload")
return
loop = _get_or_create_loop()
loop.create_task(
dispatcher.handle_consumption_event(
member_id=payload["member_id"],
site_id=payload["site_id"],
settle_id=payload["settle_id"],
assistant_id=payload.get("assistant_id"),
)
)
def handle_note_created(payload: dict | None = None, **_kw: Any) -> None:
"""备注创建事件处理器。"""
if not payload:
logger.warning("note_created 事件缺少 payload")
return
loop = _get_or_create_loop()
loop.create_task(
dispatcher.handle_note_event(
member_id=payload["member_id"],
site_id=payload["site_id"],
note_id=payload["note_id"],
note_content=payload.get("note_content", ""),
noted_by_name=payload.get("noted_by_name", ""),
)
)
def handle_task_assigned(payload: dict | None = None, **_kw: Any) -> None:
"""任务分配事件处理器。"""
if not payload:
logger.warning("task_assigned 事件缺少 payload")
return
loop = _get_or_create_loop()
loop.create_task(
dispatcher.handle_task_assign_event(
assistant_id=payload["assistant_id"],
member_id=payload["member_id"],
site_id=payload["site_id"],
task_type=payload.get("task_type", ""),
)
)
return {
"ai_consumption_settled": handle_consumption_settled,
"ai_note_created": handle_note_created,
"ai_task_assigned": handle_task_assigned,
}
def register_ai_handlers(dispatcher: AIDispatcher) -> None:
"""将 AI 事件处理器注册到 trigger_scheduler。
在 FastAPI lifespan 中调用,将三个 AI 事件处理器
注册为 trigger_scheduler 的 job handler。
"""
from app.services.trigger_scheduler import register_job
handlers = _create_ai_event_handlers(dispatcher)
for job_type, handler in handlers.items():
register_job(job_type, handler)
logger.info("已注册 AI 事件处理器: %s", job_type)

View File

@@ -0,0 +1,145 @@
"""应用 2 财务洞察 Prompt 模板。
构建包含当期和上期收入结构的完整 Prompt供百炼 API 生成财务洞察。
收入字段映射(严格遵守 items_sum 口径):
- table_fee = table_charge_money台费
- assistant_pd = assistant_pd_money陪打费
- assistant_cx = assistant_cx_money超休费
- goods = goods_money商品收入
- recharge = 充值 pay_amount settle_type=5充值收入
禁止使用 consume_money统一使用
items_sum = table_charge_money + goods_money + assistant_pd_money
+ assistant_cx_money + electricity_money
"""
from __future__ import annotations
import json
def build_prompt(context: dict) -> list[dict]:
"""构建 App2 财务洞察 Prompt 消息列表。
Args:
context: 包含以下字段:
- site_id: int门店 ID
- time_dimension: str时间维度编码
- current_data: dict当期数据
- previous_data: dict上期数据
Returns:
messages 列表system + user供 BailianClient.chat_json 调用
"""
site_id = context.get("site_id", 0)
time_dimension = context.get("time_dimension", "")
current_data = context.get("current_data", {})
previous_data = context.get("previous_data", {})
system_content = _build_system_content(
site_id=site_id,
time_dimension=time_dimension,
current_data=current_data,
previous_data=previous_data,
)
user_content = (
f"请根据以上数据,为门店 {site_id} 生成 {_dimension_label(time_dimension)} 的财务洞察分析。"
"以 JSON 格式返回,包含 insights 数组,每项含 seq序号、title标题、body正文"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]
def _build_system_content(
*,
site_id: int,
time_dimension: str,
current_data: dict,
previous_data: dict,
) -> dict:
"""构建 system prompt JSON 结构。"""
return {
"task": (
"你是台球门店的财务分析 AI 助手。"
"根据提供的当期和上期经营数据,生成结构化的财务洞察。"
"分析维度包括:收入结构变化、各收入项占比、环比趋势、异常波动。"
"输出 JSON 格式:{\"insights\": [{\"seq\": 1, \"title\": \"...\", \"body\": \"...\"}]}"
),
"data": {
"site_id": site_id,
"time_dimension": time_dimension,
"time_dimension_label": _dimension_label(time_dimension),
"current_period": _build_period_data(current_data),
"previous_period": _build_period_data(previous_data),
},
"reference": {
"field_mapping": {
"items_sum": (
"table_charge_money + goods_money + assistant_pd_money"
" + assistant_cx_money + electricity_money"
),
"table_fee": "table_charge_money台费收入",
"assistant_pd": "assistant_pd_money陪打费",
"assistant_cx": "assistant_cx_money超休费",
"goods": "goods_money商品收入",
"recharge": "充值 pay_amountsettle_type=5充值收入",
"electricity": "electricity_money电费当前未启用全为 0",
},
"rules": [
"统一使用 items_sum 口径计算营收总额",
"助教费用必须拆分为 assistant_pd_money陪打和 assistant_cx_money超休",
"支付渠道恒等式balance_amount = recharge_card_amount + gift_card_amount",
"金额单位CNY保留两位小数",
],
},
}
def _build_period_data(data: dict) -> dict:
"""构建单期数据结构,确保字段名遵守 items_sum 口径。"""
return {
# 收入结构items_sum 口径)
"table_charge_money": data.get("table_charge_money", 0),
"goods_money": data.get("goods_money", 0),
"assistant_pd_money": data.get("assistant_pd_money", 0),
"assistant_cx_money": data.get("assistant_cx_money", 0),
"electricity_money": data.get("electricity_money", 0),
# 充值收入
"recharge_income": data.get("recharge_income", 0),
# 储值资产
"balance_pay": data.get("balance_pay", 0),
"recharge_card_pay": data.get("recharge_card_pay", 0),
"gift_card_pay": data.get("gift_card_pay", 0),
# 费用汇总
"discount_amount": data.get("discount_amount", 0),
"adjust_amount": data.get("adjust_amount", 0),
# 平台结算
"platform_settlement_amount": data.get("platform_settlement_amount", 0),
"groupbuy_pay_amount": data.get("groupbuy_pay_amount", 0),
# 汇总
"order_count": data.get("order_count", 0),
"member_count": data.get("member_count", 0),
}
# 时间维度编码 → 中文标签
_DIMENSION_LABELS: dict[str, str] = {
"this_month": "本月",
"last_month": "上月",
"this_week": "本周",
"last_week": "上周",
"last_3_months": "近三个月",
"this_quarter": "本季度",
"last_quarter": "上季度",
"last_6_months": "近六个月",
}
def _dimension_label(dimension: str) -> str:
"""将时间维度编码转为中文标签。"""
return _DIMENSION_LABELS.get(dimension, dimension)

View File

@@ -0,0 +1,93 @@
"""应用 8维客线索整理 Prompt 模板。
接收 App3消费分析和 App6备注分析的全部线索
整合去重后输出统一维客线索。
分类标签限定 6 个枚举值(与 member_retention_clue CHECK 约束一致):
客户基础、消费习惯、玩法偏好、促销偏好、社交关系、重要反馈。
合并规则:
- 相似线索合并providers 以逗号分隔
- 其余线索原文返回
- 最小改动原则
"""
from __future__ import annotations
import json
def build_prompt(context: dict) -> list[dict]:
"""构建 App8 维客线索整理 Prompt。
Args:
context: 包含以下字段:
- site_id: int
- member_id: int
- app3_clues: list[dict] — App3 产出的线索列表
- app6_clues: list[dict] — App6 产出的线索列表
- app3_generated_at: str | None — App3 线索生成时间
- app6_generated_at: str | None — App6 线索生成时间
Returns:
消息列表 [{"role": "system", ...}, {"role": "user", ...}]
"""
member_id = context["member_id"]
app3_clues = context.get("app3_clues", [])
app6_clues = context.get("app6_clues", [])
app3_generated_at = context.get("app3_generated_at")
app6_generated_at = context.get("app6_generated_at")
system_content = {
"task": "整合去重来自消费分析和备注分析的维客线索,输出统一线索列表。",
"app_id": "app8_consolidation",
"rules": {
"category_enum": [
"客户基础", "消费习惯", "玩法偏好",
"促销偏好", "社交关系", "重要反馈",
],
"merge_strategy": (
"相似线索合并为一条providers 以逗号分隔(如 '系统,张三'"
"不相似的线索原文保留,不做修改。最小改动原则。"
),
"output_format": {
"clues": [
{
"category": "枚举值6 选 1",
"summary": "一句话摘要",
"detail": "详细说明",
"emoji": "表情符号",
"providers": "提供者(逗号分隔)",
}
]
},
},
"input": {
"app3_clues": {
"source": "消费数据分析App3",
"generated_at": app3_generated_at,
"clues": app3_clues,
},
"app6_clues": {
"source": "备注分析App6",
"generated_at": app6_generated_at,
"clues": app6_clues,
},
},
}
user_content = (
f"请整合会员 {member_id} 的维客线索。\n"
"输入包含两个来源的线索App3消费数据分析和 App6备注分析\n"
"规则:\n"
"1. 相似线索合并为一条providers 字段以逗号分隔多个提供者\n"
"2. 不相似的线索原文保留\n"
"3. category 必须是:客户基础、消费习惯、玩法偏好、促销偏好、社交关系、重要反馈 之一\n"
"4. 每条线索包含 category、summary、detail、emoji、providers 五个字段\n"
"5. 最小改动原则,尽量保留原始表述"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "user", "content": user_content},
]

View File

@@ -17,7 +17,9 @@ from app import config
# CHANGE 2026-02-26 | member_birthday 路由替换为 member_retention_clue维客线索重构
# CHANGE 2026-02-26 | 新增 admin_applications 路由(管理端申请审核)
# CHANGE 2026-02-27 | 新增 xcx_tasks / xcx_notes 路由(小程序核心业务)
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback, member_retention_clue, ops_panel, xcx_auth, admin_applications, business_day, xcx_tasks, xcx_notes
# CHANGE 2026-03-09 | 新增 xcx_ai_chat 路由AI SSE 对话 + 历史对话)
# CHANGE 2026-03-09 | 新增 xcx_ai_cache 路由AI 缓存查询)
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback, member_retention_clue, ops_panel, xcx_auth, admin_applications, business_day, xcx_tasks, xcx_notes, xcx_ai_chat, xcx_ai_cache
from app.services.scheduler import scheduler
from app.services.task_queue import task_queue
from app.ws.logs import ws_router
@@ -57,6 +59,25 @@ async def lifespan(app: FastAPI):
register_job("recall_completion_check", recall_detector.run)
register_job("note_reclassify_backfill", note_reclassifier.run)
# CHANGE 2026-03-10 | 注册 AI 事件处理器(消费/备注/任务分配 → AI 调用链)
try:
import os
_api_key = os.environ.get("BAILIAN_API_KEY", "")
_base_url = os.environ.get("BAILIAN_BASE_URL", "")
_model = os.environ.get("BAILIAN_MODEL", "qwen-plus")
if _api_key and _base_url:
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.dispatcher import AIDispatcher, register_ai_handlers
_bailian = BailianClient(api_key=_api_key, base_url=_base_url, model=_model)
_dispatcher = AIDispatcher(_bailian, AICacheService(), ConversationService())
register_ai_handlers(_dispatcher)
except Exception:
import logging as _log
_log.getLogger(__name__).warning("AI 事件处理器注册失败AI 功能不可用", exc_info=True)
yield
# 关闭
await scheduler.stop()
@@ -100,6 +121,8 @@ app.include_router(admin_applications.router)
app.include_router(business_day.router)
app.include_router(xcx_tasks.router)
app.include_router(xcx_notes.router)
app.include_router(xcx_ai_chat.router)
app.include_router(xcx_ai_cache.router)
@app.get("/health", tags=["系统"])

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
小程序 AI 缓存查询路由 —— 查询各 AI 应用的最新缓存结果。
端点清单:
- GET /api/ai/cache/{cache_type}?target_id=xxx — 查询最新缓存
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.ai.cache_service import AICacheService
from app.ai.schemas import CacheTypeEnum
from app.auth.dependencies import CurrentUser, get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["小程序 AI 缓存"])
@router.get("/cache/{cache_type}")
async def get_ai_cache(
cache_type: str,
target_id: str = Query(..., description="目标 IDmember_id / assistant_id_member_id / 时间维度编码)"),
user: CurrentUser = Depends(get_current_user),
):
"""查询指定类型的最新 AI 缓存结果。
site_id 从 JWT 提取,强制过滤,确保门店隔离。
"""
# 校验 cache_type 合法性
valid_types = {e.value for e in CacheTypeEnum}
if cache_type not in valid_types:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"无效的 cache_type: {cache_type},合法值: {sorted(valid_types)}",
)
cache_svc = AICacheService()
result = cache_svc.get_latest(
cache_type=cache_type,
site_id=user.site_id,
target_id=target_id,
)
if result is None:
return None
return {
"id": result.get("id"),
"cache_type": result.get("cache_type"),
"target_id": result.get("target_id"),
"result_json": result.get("result_json"),
"score": result.get("score"),
"created_at": result.get("created_at"),
}

View File

@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-
"""
小程序 AI 对话路由 —— SSE 流式对话、历史对话列表、消息查询。
端点清单:
- POST /api/ai/chat/stream — SSE 流式对话
- GET /api/ai/conversations — 历史对话列表(分页)
- GET /api/ai/conversations/{conversation_id}/messages — 对话消息列表
"""
from __future__ import annotations
import json
import logging
import os
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from app.ai.bailian_client import BailianClient
from app.ai.conversation_service import ConversationService
from app.ai.apps.app1_chat import chat_stream
from app.ai.schemas import ChatStreamRequest, SSEEvent
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["小程序 AI 对话"])
# ── 辅助:获取用户 nickname ──────────────────────────────────
def _get_user_nickname(user_id: int) -> str:
"""从 auth.users 查询用户 nickname查不到返回空字符串。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT nickname FROM auth.users WHERE id = %s",
(user_id,),
)
row = cur.fetchone()
return row[0] if row and row[0] else ""
finally:
conn.close()
# ── 辅助:获取用户主要角色 ───────────────────────────────────
def _get_user_role_label(roles: list[str]) -> str:
"""从角色列表提取主要角色标签,用于 AI 上下文。"""
if "store_manager" in roles or "owner" in roles:
return "管理者"
if "assistant" in roles or "coach" in roles:
return "助教"
return "用户"
# ── 辅助:构建 BailianClient 实例 ────────────────────────────
def _get_bailian_client() -> BailianClient:
"""从环境变量构建 BailianClient缺失时报错。"""
api_key = os.environ.get("BAILIAN_API_KEY")
base_url = os.environ.get("BAILIAN_BASE_URL")
model = os.environ.get("BAILIAN_MODEL")
if not api_key or not base_url or not model:
raise RuntimeError(
"百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
)
return BailianClient(api_key=api_key, base_url=base_url, model=model)
# ── SSE 流式对话 ─────────────────────────────────────────────
@router.post("/chat/stream")
async def ai_chat_stream(
body: ChatStreamRequest,
user: CurrentUser = Depends(get_current_user),
):
"""SSE 流式对话端点。
接收用户消息,通过百炼 API 流式返回 AI 回复。
每个 SSE 事件格式data: {json}\n\n
事件类型chunk文本片段/ done完成/ error错误
"""
if not body.message or not body.message.strip():
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="消息内容不能为空",
)
nickname = _get_user_nickname(user.user_id)
role_label = _get_user_role_label(user.roles)
bailian = _get_bailian_client()
conv_svc = ConversationService()
async def event_generator():
"""SSE 事件生成器,逐事件 yield data: {json}\n\n 格式。"""
try:
async for event in chat_stream(
message=body.message.strip(),
user_id=user.user_id,
nickname=nickname,
role=role_label,
site_id=user.site_id,
source_page=body.source_page,
page_context=body.page_context,
screen_content=body.screen_content,
bailian=bailian,
conv_svc=conv_svc,
):
yield f"data: {event.model_dump_json()}\n\n"
except Exception as e:
# 兜底:生成器内部异常也以 SSE error 事件返回
logger.error("SSE 生成器异常: %s", e, exc_info=True)
error_event = SSEEvent(type="error", message=str(e))
yield f"data: {error_event.model_dump_json()}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # nginx 禁用缓冲
},
)
# ── 历史对话列表 ─────────────────────────────────────────────
@router.get("/conversations")
async def list_conversations(
page: int = 1,
page_size: int = 20,
user: CurrentUser = Depends(get_current_user),
):
"""查询当前用户的历史对话列表,按时间倒序,分页。"""
if page < 1:
page = 1
if page_size < 1 or page_size > 100:
page_size = 20
conv_svc = ConversationService()
conversations = conv_svc.get_conversations(
user_id=user.user_id,
site_id=user.site_id,
page=page,
page_size=page_size,
)
# 为每条对话附加首条消息预览
result = []
for conv in conversations:
item = {
"id": conv["id"],
"app_id": conv["app_id"],
"source_page": conv.get("source_page"),
"created_at": conv["created_at"],
"first_message_preview": None,
}
# 查询首条 user 消息作为预览
messages = conv_svc.get_messages(conv["id"])
for msg in messages:
if msg["role"] == "user":
content = msg["content"] or ""
item["first_message_preview"] = content[:50] if len(content) > 50 else content
break
result.append(item)
return result
# ── 对话消息列表 ─────────────────────────────────────────────
@router.get("/conversations/{conversation_id}/messages")
async def get_conversation_messages(
conversation_id: int,
user: CurrentUser = Depends(get_current_user),
):
"""查询指定对话的所有消息,按时间升序。
验证对话归属当前用户和 site_id防止越权访问。
"""
# 先验证对话归属
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id FROM biz.ai_conversations
WHERE id = %s AND user_id = %s AND site_id = %s
""",
(conversation_id, str(user.user_id), user.site_id),
)
if not cur.fetchone():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在或无权访问",
)
finally:
conn.close()
conv_svc = ConversationService()
messages = conv_svc.get_messages(conversation_id)
return [
{
"id": msg["id"],
"role": msg["role"],
"content": msg["content"],
"tokens_used": msg.get("tokens_used"),
"created_at": msg["created_at"],
}
for msg in messages
]

View File

@@ -26,6 +26,8 @@ class TaskListItem(BaseModel):
# RS 指数 + 爱心 icon
rs_score: float | None
heart_icon: str # 💖 / 🧡 / 💛 / 💙
# 放弃原因(仅 abandoned 任务有值)
abandon_reason: str | None = None
class AbandonRequest(BaseModel):

View File

@@ -2,9 +2,14 @@
"""
备注回溯重分类器Note Reclassifier
召回完成后,回溯检查是否有普通备注需重分类为回访备注。
查找 service_time 之后的第一条 normal 备注 → 更新为 follow_up →
触发 AI 应用 6 接口(占位)→ 根据 ai_score 生成 follow_up_visit 任务。
召回完成后,回溯检查是否有普通备注需重分类为回访备注,并创建回访任务
流程:
1. 查找 service_time 之后的第一条 normal 备注
2. 若找到 → 重分类为 follow_up任务状态 = completed回溯完成
3. 若未找到 → 任务状态 = active等待备注
4. 冲突检查:已有 completed → 跳过;已有 active → 顶替;否则正常创建
5. 保留 ai_analyze_note() 占位调用,返回值仅更新 ai_score 字段
由 trigger_jobs 中的 note_reclassify_backfill 配置驱动event: recall_completed
"""
@@ -62,21 +67,27 @@ def ai_analyze_note(note_id: int) -> int | None:
return None
def run(payload: dict | None = None) -> dict:
def run(payload: dict | None = None, job_id: int | None = None) -> dict:
"""
备注回溯主流程。
payload 包含: {site_id, assistant_id, member_id, service_time}
1. 查找 biz.notes 中该 (site_id, target_type='member', target_id=member_id)
service_time 之后提交的第一条 type='normal' 的备注
2. 将该备注 type 从 'normal' 更新为 'follow_up'
3. 触发 AI 应用 6 接口P5 实现,本 SPEC 仅定义触发接口):
- 调用 ai_analyze_note(note_id) → 返回 ai_score
4. 若 ai_score >= 6
- 生成 follow_up_visit 任务status='completed'(回溯完成)
5. 若 ai_score < 6
- 生成 follow_up_visit 任务status='active'(需助教重新备注)
流程:
1. 查找 service_time 之后的第一条 normal 备注 → note_id
2. 若 note_id 存在:重分类为 follow_uptask_status = 'completed'(回溯完成)
3. 若 note_id 不存在task_status = 'active'(等待备注)
4. 保留 ai_analyze_note() 占位调用,返回值仅更新 ai_score 字段
5. 冲突检查T3
- 已有 completed → 跳过创建
- 已有 active → 旧任务标记 inactive + superseded 历史,创建新任务
- 不存在(或仅 inactive/abandoned→ 正常创建
6. 创建 follow_up_visit 任务
参数:
payload: 事件载荷(由 trigger_scheduler 传入)
job_id: 触发器 job ID由 trigger_scheduler 传入),用于在最终事务中
更新 last_run_at保证 handler 数据变更与 last_run_at 原子提交
返回: {"reclassified_count": int, "tasks_created": int}
"""
@@ -119,84 +130,166 @@ def run(payload: dict | None = None) -> dict:
note_id = row[0]
conn.commit()
if note_id is None:
logger.info(
"未找到符合条件的 normal 备注: site_id=%s, member_id=%s",
site_id, member_id,
)
return {"reclassified_count": 0, "tasks_created": 0}
# ── 2. 将备注 type 从 'normal' 更新为 'follow_up' ──
with conn.cursor() as cur:
cur.execute("BEGIN")
cur.execute(
"""
UPDATE biz.notes
SET type = 'follow_up', updated_at = NOW()
WHERE id = %s AND type = 'normal'
""",
(note_id,),
)
conn.commit()
reclassified_count = 1
# ── 3. 触发 AI 应用 6 接口(占位,当前返回 None ──
ai_score = ai_analyze_note(note_id)
# ── 4/5. 根据 ai_score 生成 follow_up_visit 任务 ──
if ai_score is not None:
if ai_score >= 6:
# 回溯完成:生成 completed 任务
task_status = "completed"
else:
# 需助教重新备注:生成 active 任务
task_status = "active"
# ── 2. 根据是否找到备注确定任务状态T4 ──
if note_id is not None:
# 找到备注 → 重分类为 follow_up
with conn.cursor() as cur:
cur.execute("BEGIN")
cur.execute(
"""
INSERT INTO biz.coach_tasks
(site_id, assistant_id, member_id, task_type,
status, completed_at, completed_task_type)
VALUES (
%s, %s, %s, 'follow_up_visit',
%s,
CASE WHEN %s = 'completed' THEN NOW() ELSE NULL END,
CASE WHEN %s = 'completed' THEN 'follow_up_visit' ELSE NULL END
)
RETURNING id
UPDATE biz.notes
SET type = 'follow_up', updated_at = NOW()
WHERE id = %s AND type = 'normal'
""",
(
site_id, assistant_id, member_id,
task_status, task_status, task_status,
),
)
new_task_row = cur.fetchone()
new_task_id = new_task_row[0]
# 记录任务创建历史
_insert_history(
cur,
new_task_id,
action="created_by_reclassify",
old_status=None,
new_status=task_status,
old_task_type=None,
new_task_type="follow_up_visit",
detail={
"note_id": note_id,
"ai_score": ai_score,
"source": "note_reclassifier",
},
(note_id,),
)
conn.commit()
tasks_created = 1
reclassified_count = 1
# 保留 AI 占位调用,返回值仅用于更新 ai_score 字段
ai_score = ai_analyze_note(note_id)
if ai_score is not None:
with conn.cursor() as cur:
cur.execute("BEGIN")
cur.execute(
"""
UPDATE biz.notes
SET ai_score = %s, updated_at = NOW()
WHERE id = %s
""",
(ai_score, note_id),
)
conn.commit()
# 有备注 → 回溯完成
task_status = "completed"
else:
# AI 未就绪,跳过任务创建
# 未找到备注 → 等待备注
logger.info(
"AI 接口未就绪,跳过任务创建: note_id=%s", note_id
"未找到符合条件的 normal 备注: site_id=%s, member_id=%s",
site_id, member_id,
)
ai_score = None
task_status = "active"
# ── 3. 冲突检查T3查询已有 follow_up_visit 任务 ──
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, status
FROM biz.coach_tasks
WHERE site_id = %s AND assistant_id = %s AND member_id = %s
AND task_type = 'follow_up_visit'
AND status IN ('active', 'completed')
ORDER BY CASE WHEN status = 'completed' THEN 0 ELSE 1 END
LIMIT 1
""",
(site_id, assistant_id, member_id),
)
existing = cur.fetchone()
conn.commit()
if existing:
existing_id, existing_status = existing
if existing_status == "completed":
# 已完成 → 跳过创建(回访完成语义已满足)
logger.info(
"已存在 completed 回访任务 id=%s,跳过创建: "
"site_id=%s, assistant_id=%s, member_id=%s",
existing_id, site_id, assistant_id, member_id,
)
# 事务安全T5即使跳过创建handler 仍成功,更新 last_run_at
if job_id is not None:
from app.services.trigger_scheduler import (
update_job_last_run_at,
)
with conn.cursor() as cur:
cur.execute("BEGIN")
update_job_last_run_at(cur, job_id)
conn.commit()
return {
"reclassified_count": reclassified_count,
"tasks_created": 0,
}
elif existing_status == "active":
# 顶替:旧任务 → inactive + superseded 历史
with conn.cursor() as cur:
cur.execute("BEGIN")
cur.execute(
"""
UPDATE biz.coach_tasks
SET status = 'inactive', updated_at = NOW()
WHERE id = %s AND status = 'active'
""",
(existing_id,),
)
_insert_history(
cur,
existing_id,
action="superseded",
old_status="active",
new_status="inactive",
detail={
"reason": "new_reclassify_task_supersedes",
"source": "note_reclassifier",
},
)
conn.commit()
logger.info(
"顶替旧 active 回访任务 id=%s → inactive: "
"site_id=%s, assistant_id=%s, member_id=%s",
existing_id, site_id, assistant_id, member_id,
)
# ── 4. 创建 follow_up_visit 任务 ──
with conn.cursor() as cur:
cur.execute("BEGIN")
cur.execute(
"""
INSERT INTO biz.coach_tasks
(site_id, assistant_id, member_id, task_type,
status, completed_at, completed_task_type)
VALUES (
%s, %s, %s, 'follow_up_visit',
%s,
CASE WHEN %s = 'completed' THEN NOW() ELSE NULL END,
CASE WHEN %s = 'completed' THEN 'follow_up_visit' ELSE NULL END
)
RETURNING id
""",
(
site_id, assistant_id, member_id,
task_status, task_status, task_status,
),
)
new_task_row = cur.fetchone()
new_task_id = new_task_row[0]
# 记录任务创建历史
_insert_history(
cur,
new_task_id,
action="created_by_reclassify",
old_status=None,
new_status=task_status,
old_task_type=None,
new_task_type="follow_up_visit",
detail={
"note_id": note_id,
"ai_score": ai_score,
"source": "note_reclassifier",
},
)
# 事务安全T5在最终 commit 前更新 last_run_at
if job_id is not None:
from app.services.trigger_scheduler import update_job_last_run_at
update_job_last_run_at(cur, job_id)
conn.commit()
tasks_created = 1
except Exception:
logger.exception(
@@ -215,3 +308,4 @@ def run(payload: dict | None = None) -> dict:
"reclassified_count": reclassified_count,
"tasks_created": tasks_created,
}

View File

@@ -81,8 +81,8 @@ async def create_note(
- 否则 → type='normal'
3. INSERT INTO biz.notes
4. 若 type='follow_up'
- 触发 AI 应用 6 分析P5 实现)
- ai_score >= 6 且关联任务 status='active' → 标记任务 completed
- 保留 AI 占位调用P5 接入时调用链不变),返回值仅更新 ai_score
- 不论 ai_score 如何,有备注即标记关联 active 回访任务 completed
5. 返回创建的备注记录
注意:星星评分不参与回访完成判定,不参与 AI 分析,仅存储。
@@ -171,8 +171,9 @@ async def create_note(
"updated_at": row[13].isoformat() if row[13] else None,
}
# 若 type='follow_up',触发 AI 分析并可能标记任务完成
# 若 type='follow_up',触发 AI 分析并标记回访任务完成
if note_type == "follow_up" and task_id is not None:
# 保留 AI 占位调用P5 接入时调用链不变)
ai_score = ai_analyze_note(note["id"])
if ai_score is not None:
@@ -187,32 +188,32 @@ async def create_note(
)
note["ai_score"] = ai_score
# ai_score >= 6 且关联任务 status='active' → 标记任务 completed
if ai_score >= 6 and task_info and task_info["status"] == "active":
cur.execute(
"""
UPDATE biz.coach_tasks
SET status = 'completed',
completed_at = NOW(),
completed_task_type = task_type,
updated_at = NOW()
WHERE id = %s AND status = 'active'
""",
(task_id,),
)
_record_history(
cur,
task_id,
action="completed_by_note",
old_status="active",
new_status="completed",
old_task_type=task_info["task_type"],
new_task_type=task_info["task_type"],
detail={
"note_id": note["id"],
"ai_score": ai_score,
},
)
# 不论 ai_score 如何有备注即标记回访任务完成T4
if task_info and task_info["status"] == "active":
cur.execute(
"""
UPDATE biz.coach_tasks
SET status = 'completed',
completed_at = NOW(),
completed_task_type = task_type,
updated_at = NOW()
WHERE id = %s AND status = 'active'
""",
(task_id,),
)
_record_history(
cur,
task_id,
action="completed_by_note",
old_status="active",
new_status="completed",
old_task_type=task_info["task_type"],
new_task_type=task_info["task_type"],
detail={
"note_id": note["id"],
"ai_score": ai_score,
},
)
conn.commit()
return note

View File

@@ -52,7 +52,7 @@ def _insert_history(
)
def run(payload: dict | None = None) -> dict:
def run(payload: dict | None = None, job_id: int | None = None) -> dict:
"""
召回完成检测主流程。
@@ -69,6 +69,11 @@ def run(payload: dict | None = None) -> dict:
6. 记录 coach_task_history
7. 触发 fire_event('recall_completed', {site_id, assistant_id, member_id, service_time})
参数:
payload: 事件载荷event 触发时由 trigger_scheduler 传入)
job_id: 触发器 job ID由 trigger_scheduler 传入),用于在最终事务中
更新 last_run_at保证 handler 数据变更与 last_run_at 原子提交
返回: {"completed_count": int}
"""
completed_count = 0
@@ -111,6 +116,17 @@ def run(payload: dict | None = None) -> dict:
)
conn.rollback()
# ── 事务安全T5handler 成功后更新 last_run_at ──
# job_id 由 trigger_scheduler 传入,在 handler 最终事务中更新
# handler 异常时此处不会执行异常向上传播last_run_at 不变
if job_id is not None:
from app.services.trigger_scheduler import update_job_last_run_at
with conn.cursor() as cur:
cur.execute("BEGIN")
update_job_last_run_at(cur, job_id)
conn.commit()
finally:
conn.close()
@@ -193,7 +209,7 @@ def _process_service_record(
with conn.cursor() as cur:
cur.execute("BEGIN")
# 查找匹配的 active 任务
# 查找匹配的 active 召回类任务(仅完成召回任务,回访/关系构建不在此处理)
cur.execute(
"""
SELECT id, task_type
@@ -202,6 +218,7 @@ def _process_service_record(
AND assistant_id = %s
AND member_id = %s
AND status = 'active'
AND task_type IN ('high_priority_recall', 'priority_recall')
""",
(site_id, assistant_id, member_id),
)

View File

@@ -314,22 +314,55 @@ class TaskExecutor:
async def cancel(self, execution_id: str) -> bool:
"""向子进程发送终止信号。
如果进程仍在内存中,发送 terminate 信号;
如果进程已不在内存中(如后端重启后),但数据库中仍为 running
则直接将数据库状态标记为 cancelled幽灵记录兜底
Returns:
True 表示成功发送终止信号False 表示进程不存在或已退出
True 表示成功取消False 表示任务不存在或已完成
"""
proc = self._processes.get(execution_id)
if proc is None:
return False
# subprocess.Popen: poll() 返回 None 表示仍在运行
if proc.poll() is not None:
return False
if proc is not None:
# 进程仍在内存中
if proc.poll() is not None:
return False
logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
try:
proc.terminate()
except ProcessLookupError:
return False
return True
logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
# 进程不在内存中(后端重启等场景),尝试兜底修正数据库幽灵记录
try:
proc.terminate()
except ProcessLookupError:
return False
return True
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE task_execution_log
SET status = 'cancelled',
finished_at = NOW(),
error_log = COALESCE(error_log, '')
|| E'\n[cancel 兜底] 进程已不在内存中,标记为 cancelled'
WHERE id = %s AND status = 'running'
""",
(execution_id,),
)
updated = cur.rowcount
conn.commit()
finally:
conn.close()
if updated:
logger.info(
"兜底取消 execution_log [%s]:数据库状态从 running → cancelled",
execution_id,
)
return True
except Exception:
logger.exception("兜底取消 execution_log [%s] 失败", execution_id)
return False
# ------------------------------------------------------------------
# 数据库操作(同步,在线程池中执行也可,此处简单直连)

View File

@@ -121,13 +121,13 @@ def _verify_task_ownership(
async def get_task_list(user_id: int, site_id: int) -> list[dict]:
"""
获取助教的活跃任务列表。
获取助教的任务列表(含有效 + 已放弃)
1. 通过 auth.user_assistant_binding 获取 assistant_id
2. 查询 biz.coach_tasks WHERE status='active'
2. 查询 biz.coach_tasks WHERE status IN ('active', 'abandoned')
3. 通过 FDW 读取客户基本信息dim_member和 RS 指数
4. 计算爱心 icon 档位
5. 排序is_pinned DESC, priority_score DESC, created_at ASC
5. 排序:abandoned 排最后 → is_pinned DESC priority_score DESC created_at ASC
FDW 查询需要 SET LOCAL app.current_site_id。
"""
@@ -135,17 +135,21 @@ async def get_task_list(user_id: int, site_id: int) -> list[dict]:
try:
assistant_id = _get_assistant_id(conn, user_id, site_id)
# 查询活跃任务
# 查询有效 + 已放弃任务abandoned 排最后)
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, task_type, status, priority_score, is_pinned,
expires_at, created_at, member_id
expires_at, created_at, member_id, abandon_reason
FROM biz.coach_tasks
WHERE site_id = %s
AND assistant_id = %s
AND status = 'active'
ORDER BY is_pinned DESC, priority_score DESC NULLS LAST, created_at ASC
AND status IN ('active', 'abandoned')
ORDER BY
CASE WHEN status = 'abandoned' THEN 1 ELSE 0 END ASC,
is_pinned DESC,
priority_score DESC NULLS LAST,
created_at ASC
""",
(site_id, assistant_id),
)
@@ -201,7 +205,7 @@ async def get_task_list(user_id: int, site_id: int) -> list[dict]:
result = []
for task_row in tasks:
(task_id, task_type, status, priority_score,
is_pinned, expires_at, created_at, member_id) = task_row
is_pinned, expires_at, created_at, member_id, abandon_reason) = task_row
info = member_info_map.get(member_id, {})
rs_score = rs_map.get(member_id, Decimal("0"))
@@ -220,6 +224,7 @@ async def get_task_list(user_id: int, site_id: int) -> list[dict]:
"member_phone": info.get("member_phone"),
"rs_score": float(rs_score),
"heart_icon": heart_icon,
"abandon_reason": abandon_reason,
})
return result
@@ -372,6 +377,7 @@ async def cancel_abandon(task_id: int, user_id: int, site_id: int) -> dict:
"""
UPDATE biz.coach_tasks
SET status = 'active',
is_pinned = FALSE,
abandon_reason = NULL,
updated_at = NOW()
WHERE id = %s
@@ -389,7 +395,7 @@ async def cancel_abandon(task_id: int, user_id: int, site_id: int) -> dict:
)
conn.commit()
return {"id": task_id, "status": "active"}
return {"id": task_id, "status": "active", "is_pinned": False}
finally:
conn.close()

View File

@@ -366,6 +366,9 @@ class TaskQueue:
async def _process_once(self, executor: Any) -> None:
"""单次处理:扫描所有门店的 pending 队列并执行。"""
# CHANGE 2026-03-09 | 每次轮询先回收僵尸 running 任务
self._recover_zombie_tasks()
site_ids = self._get_pending_site_ids()
for site_id in site_ids:
@@ -415,6 +418,13 @@ class TaskQueue:
except Exception:
logger.exception("队列任务执行异常 [%s]", queue_id)
self._mark_failed(queue_id, "执行过程中发生未捕获异常")
finally:
# CHANGE 2026-03-09 | 兜底:确保 task_queue 不会卡在 running
# 背景_update_execution_log 内部异常(如 duration_ms integer 溢出)
# 被吞掉后_update_queue_status_from_log 读到的 execution_log 仍是
# running导致 task_queue 永远卡住,后续任务全部排队。
self._ensure_not_stuck_running(queue_id)
def _get_pending_site_ids(self) -> list[int]:
"""获取所有有 pending 任务的 site_id 列表(仅限本实例入队的)。"""
@@ -484,6 +494,84 @@ class TaskQueue:
finally:
conn.close()
def _ensure_not_stuck_running(self, queue_id: str) -> None:
"""兜底检查:如果 task_queue 仍是 running强制标记 failed。
CHANGE 2026-03-09 | 防止 _update_execution_log 内部异常导致
task_queue 永远卡在 running 状态。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT status FROM task_queue WHERE id = %s",
(queue_id,),
)
row = cur.fetchone()
if row and row[0] == "running":
logger.warning(
"兜底修正task_queue [%s] 执行完毕但仍为 running"
"强制标记 failed",
queue_id,
)
cur.execute(
"""
UPDATE task_queue
SET status = 'failed', finished_at = NOW(),
error_message = %s
WHERE id = %s AND status = 'running'
""",
(
"[兜底修正] 执行流程结束但状态未同步,"
"可能因 execution_log 更新失败",
queue_id,
),
)
conn.commit()
except Exception:
logger.exception("_ensure_not_stuck_running 异常 [%s]", queue_id)
finally:
conn.close()
def _recover_zombie_tasks(self, max_running_minutes: int = 180) -> None:
"""恢复僵尸 running 任务:超过阈值时间仍为 running 的任务强制标记 failed。
CHANGE 2026-03-09 | 在 process_loop 每次轮询时调用,作为最后防线。
场景:后端进程崩溃/重启后,之前的 running 任务永远不会被更新。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE task_queue
SET status = 'failed', finished_at = NOW(),
error_message = %s
WHERE status = 'running'
AND (enqueued_by = %s OR enqueued_by IS NULL)
AND started_at < NOW() - INTERVAL '%s minutes'
RETURNING id
""",
(
f"[僵尸回收] running 超过 {max_running_minutes} 分钟,"
"自动标记 failed",
_INSTANCE_ID,
max_running_minutes,
),
)
recovered = cur.fetchall()
if recovered:
ids = [r[0] for r in recovered]
logger.warning(
"僵尸回收:%d 个 running 任务超时,已标记 failed: %s",
len(ids), ids,
)
conn.commit()
except Exception:
logger.exception("_recover_zombie_tasks 异常")
finally:
conn.close()
# ------------------------------------------------------------------
# 生命周期
# ------------------------------------------------------------------

View File

@@ -86,6 +86,9 @@ DWS_TASKS: list[TaskDefinition] = [
TaskDefinition("DWS_ASSISTANT_FINANCE", "助教财务汇总", "汇总助教财务数据", "助教", "DWS"),
TaskDefinition("DWS_MEMBER_CONSUMPTION", "会员消费分析", "汇总会员消费数据", "会员", "DWS"),
TaskDefinition("DWS_MEMBER_VISIT", "会员到店分析", "汇总会员到店频次", "会员", "DWS"),
# CHANGE [2026-03-09] intent: 注册项目标签任务,与 ETL 侧 task_registry 同步;全量重建不依赖日期窗口
TaskDefinition("DWS_ASSISTANT_PROJECT_TAG", "助教项目标签", "按时间窗口计算助教各项目时长占比标签", "助教", "DWS", requires_window=False),
TaskDefinition("DWS_MEMBER_PROJECT_TAG", "客户项目标签", "按时间窗口计算客户各项目消费时长占比标签", "会员", "DWS", requires_window=False),
TaskDefinition("DWS_FINANCE_DAILY", "财务日报", "汇总每日财务数据", "财务", "DWS"),
TaskDefinition("DWS_FINANCE_RECHARGE", "充值汇总", "汇总充值数据", "财务", "DWS"),
TaskDefinition("DWS_FINANCE_INCOME_STRUCTURE", "收入结构", "分析收入结构", "财务", "DWS"),

View File

@@ -31,6 +31,20 @@ def register_job(job_type: str, handler: Callable) -> None:
_JOB_REGISTRY[job_type] = handler
def update_job_last_run_at(cur, job_id: int) -> None:
"""
在 handler 的事务内更新 last_run_at。
handler 在最终 commit 前调用此函数,将 last_run_at 更新纳入同一事务。
handler 成功 → last_run_at 随事务一起 commit。
handler 失败 → last_run_at 随事务一起 rollback。
"""
cur.execute(
"UPDATE biz.trigger_jobs SET last_run_at = NOW() WHERE id = %s",
(job_id,),
)
def fire_event(event_name: str, payload: dict[str, Any] | None = None) -> int:
"""
触发事件驱动型任务。
@@ -38,6 +52,10 @@ def fire_event(event_name: str, payload: dict[str, Any] | None = None) -> int:
查找 trigger_condition='event' 且 trigger_config.event_name 匹配的 enabled job
立即执行对应的 handler。
事务安全:将 job_id 传入 handler由 handler 在最终 commit 前
更新 last_run_at保证 handler 数据变更与 last_run_at 在同一事务中。
handler 失败时整个事务回滚last_run_at 不更新。
返回: 执行的 job 数量
"""
conn = _get_connection()
@@ -55,6 +73,7 @@ def fire_event(event_name: str, payload: dict[str, Any] | None = None) -> int:
(event_name,),
)
rows = cur.fetchall()
conn.commit()
for job_id, job_type, job_name in rows:
handler = _JOB_REGISTRY.get(job_type)
@@ -64,18 +83,11 @@ def fire_event(event_name: str, payload: dict[str, Any] | None = None) -> int:
)
continue
try:
handler(payload=payload)
# 将 job_id 传入 handlerhandler 在最终 commit 前更新 last_run_at
handler(payload=payload, job_id=job_id)
executed += 1
# 更新 last_run_at
with conn.cursor() as cur:
cur.execute(
"UPDATE biz.trigger_jobs SET last_run_at = NOW() WHERE id = %s",
(job_id,),
)
conn.commit()
except Exception:
logger.exception("触发器 %s 执行失败", job_name)
conn.rollback()
finally:
conn.close()
@@ -87,6 +99,11 @@ def check_scheduled_jobs() -> int:
检查 cron/interval 类型的到期 job 并执行。
由 Scheduler 后台循环调用。
事务安全:将 conn 和 job_id 传入 handler由 handler 在最终 commit 前
更新 last_run_at 和 next_run_at保证 handler 数据变更与时间戳在同一事务中。
handler 失败时整个事务回滚。
返回: 执行的 job 数量
"""
conn = _get_connection()
@@ -104,6 +121,7 @@ def check_scheduled_jobs() -> int:
""",
)
rows = cur.fetchall()
conn.commit()
for job_id, job_type, job_name, trigger_condition, trigger_config in rows:
handler = _JOB_REGISTRY.get(job_type)
@@ -111,11 +129,12 @@ def check_scheduled_jobs() -> int:
logger.warning("未注册的 job_type: %s", job_type)
continue
try:
handler()
executed += 1
# 计算 next_run_at 并更新
# cron/interval handler 接受 conn + job_id在最终 commit 前更新时间戳
handler(conn=conn, job_id=job_id)
# 计算 next_run_at 并更新(在 handler commit 后的新事务中)
next_run = _calculate_next_run(trigger_condition, trigger_config)
with conn.cursor() as cur:
cur.execute("BEGIN")
cur.execute(
"""
UPDATE biz.trigger_jobs
@@ -125,6 +144,7 @@ def check_scheduled_jobs() -> int:
(next_run, job_id),
)
conn.commit()
executed += 1
except Exception:
logger.exception("触发器 %s 执行失败", job_name)
conn.rollback()
@@ -156,6 +176,6 @@ def _calculate_next_run(
from apps.backend.app.services.scheduler import _parse_simple_cron
return _parse_simple_cron(
trigger_config.get("cron_expression", "0 4 * * *"), now
trigger_config.get("cron_expression", "0 7 * * *"), now
)
return None # event 类型无 next_run_at

View File

@@ -34,3 +34,4 @@ dev = [
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
asyncio_mode = "auto"

View File

@@ -132,7 +132,7 @@ class TestDequeue:
config_dict = {"tasks": ["ODS_MEMBER"], "flow": "api_ods"}
row = (
task_id, 42, json.dumps(config_dict), "pending", 1,
None, None, None, None, None,
None, None, None, None, None, None,
)
cur = _mock_cursor(fetchone_val=row)
conn = _mock_conn(cur)
@@ -152,7 +152,7 @@ class TestDequeue:
config_dict = {"tasks": ["ODS_MEMBER"], "flow": "api_ods"}
row = (
task_id, 42, json.dumps(config_dict), "pending", 1,
None, None, None, None, None,
None, None, None, None, None, None,
)
cur = _mock_cursor(fetchone_val=row)
conn = _mock_conn(cur)
@@ -322,14 +322,17 @@ class TestProcessLoop:
@pytest.mark.asyncio
async def test_process_once_skips_when_running(self, mock_get_conn, queue):
"""有 running 任务时不 dequeue"""
# _get_pending_site_ids 返回 [42]
# has_running(42) 返回 True
# 调用顺序_recover_zombie_tasks → _get_pending_site_ids → has_running
call_count = 0
def side_effect_conn():
nonlocal call_count
call_count += 1
if call_count == 1:
# _recover_zombie_tasks无僵尸任务
cur = _mock_cursor()
return _mock_conn(cur)
elif call_count == 2:
# _get_pending_site_ids
cur = _mock_cursor(fetchall_val=[(42,)])
return _mock_conn(cur)
@@ -372,10 +375,14 @@ class TestProcessLoop:
nonlocal call_count
call_count += 1
if call_count == 1:
# _recover_zombie_tasks无僵尸任务
cur = _mock_cursor()
return _mock_conn(cur)
elif call_count == 2:
# _get_pending_site_ids
cur = _mock_cursor(fetchall_val=[(42,)])
return _mock_conn(cur)
elif call_count == 2:
elif call_count == 3:
# has_running → False
cur = _mock_cursor(fetchone_val=(False,))
return _mock_conn(cur)
@@ -383,7 +390,7 @@ class TestProcessLoop:
# dequeue → 返回任务
row = (
task_id, 42, config_json, "pending", 1,
None, None, None, None, None,
None, None, None, None, None, None,
)
cur = _mock_cursor(fetchone_val=row)
return _mock_conn(cur)
@@ -402,9 +409,21 @@ class TestProcessLoop:
@pytest.mark.asyncio
async def test_process_once_no_pending(self, mock_get_conn, queue):
"""无 pending 任务时什么都不做"""
cur = _mock_cursor(fetchall_val=[])
conn = _mock_conn(cur)
mock_get_conn.return_value = conn
call_count = 0
def side_effect_conn():
nonlocal call_count
call_count += 1
if call_count == 1:
# _recover_zombie_tasks无僵尸任务
cur = _mock_cursor()
return _mock_conn(cur)
else:
# _get_pending_site_ids → 空
cur = _mock_cursor(fetchall_val=[])
return _mock_conn(cur)
mock_get_conn.side_effect = side_effect_conn
mock_executor = MagicMock()
await queue._process_once(mock_executor)