feat: 累积功能变更 — 聊天集成、租户管理、小程序更新、ETL 增强、迁移脚本

包含多个会话的累积代码变更:
- backend: AI 聊天服务、触发器调度、认证增强、WebSocket、调度器最小间隔
- admin-web: ETL 状态页、任务管理、调度配置、登录优化
- miniprogram: 看板页面、聊天集成、UI 组件、导航更新
- etl: DWS 新任务(finance_area_daily/board_cache)、连接器增强
- tenant-admin: 项目初始化
- db: 19 个迁移脚本(etl_feiqiu 11 + zqyy_app 8)
- packages/shared: 枚举和工具函数更新
- tools: 数据库工具、报表生成、健康检查
- docs: PRD/架构/部署/合约文档更新

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Neo
2026-04-06 00:03:48 +08:00
parent 70324d8542
commit 6f8f12314f
515 changed files with 76604 additions and 7456 deletions

View File

@@ -21,10 +21,28 @@ apps/backend/
│ ├── auth/ # 认证模块
│ │ ├── dependencies.py # FastAPI 依赖注入CurrentUser
│ │ └── jwt.py # JWT 签发/验证/密码哈希
│ ├── routers/ # 17 个路由模块(详见 API 参考)
│ ├── routers/ # 18 个路由模块(详见 API 参考)
│ ├── schemas/ # Pydantic 请求/响应模型
│ ├── services/ # 业务逻辑层
│ ├── middleware/ # 中间件ResponseWrapper 全局响应包装)
│ ├── ai/ # AI 模块DashScope Application API + 8 个应用)
│ │ ├── config.py # AIConfig — 环境变量加载DASHSCOPE_*
│ │ ├── dashscope_client.py # DashScope Application API 统一封装
│ │ ├── dispatcher.py # AIDispatcher — 事件调度 + 调用链编排
│ │ ├── circuit_breaker.py # 熔断器(按 app_id 独立)
│ │ ├── rate_limiter.py # 限流器(用户/门店维度)
│ │ ├── budget_tracker.py # Token 预算追踪(日/月限额)
│ │ ├── run_log_service.py # AI 运行日志 CRUD
│ │ ├── exceptions.py # 异常层级DashScopeError 基类)
│ │ ├── cache_service.py # AI 缓存读写biz.ai_cache + status 状态控制)
│ │ ├── conversation_service.py # 对话管理session_id 双轨)
│ │ ├── schemas.py # AI 相关 SchemaSSEEvent 等)
│ │ ├── apps/ # 8 个 AI 应用app1_chat ~ app8_consolidation
│ │ ├── prompts/ # Prompt 模板app2/app8 独立模板)
│ │ └── data_fetchers/ # 共享数据获取层NS2 新增)
│ │ ├── member_data.py # 客户消费/会员卡/备注数据
│ │ ├── assistant_data.py # 助教信息/服务记录
│ │ └── page_context.py # 页面上下文文本化10 种入口)
│ └── ws/ # WebSocket实时日志
├── tests/ # 后端测试
├── pyproject.toml # 依赖声明
@@ -71,6 +89,7 @@ ETL 只读连接自动设置 `default_transaction_read_only = on` 和 RLS `app.c
1. 管理后台认证(`/api/auth/*`):用户名 + 密码 → JWT
2. 小程序认证(`/api/xcx-auth/*`):微信 code → openid → JWT
3. 租户管理后台认证(`/api/tenant/auth/*`):用户名 + 密码 → JWT`aud=tenant-admin`,与小程序完全隔离)
JWT 令牌分两种:
- 完整令牌:已审批用户,包含 `user_id` + `site_id` + `roles`
@@ -128,7 +147,20 @@ JWT 令牌分两种:
| `/api/xcx-test` | `xcx_test.py` | MVP 全链路验证 | 无 |
| `/api/wx-callback` | `wx_callback.py` | 微信消息推送回调 | 签名验证 |
| `/api/retention-clue` | `member_retention_clue.py` | 维客线索 CRUD | JWT |
| `/api/tenant/auth` | `tenant_auth.py` | 租户管理员登录/刷新令牌 | 无 |
| `/api/tenant` | `tenant_users.py` | 租户用户审核/管理(申请列表/关联建议/审核/用户编辑/绑定) | 租户JWT |
| `/api/tenant/excel` | `tenant_excel.py` | 租户 Excel 上传/校验/冲突/确认/记录/模板下载 | 租户JWT |
| `/api/tenant` | `tenant_clues.py` | 租户维客线索管理(客户搜索/线索CRUD/隐藏显示) | 租户JWT |
| `/api/tenant/site-admins` | `tenant_site_admins.py` | 店铺管理员 CRUD列表/创建/编辑/删除/重置密码,仅 tenant_admin | 租户JWT |
| `/api/admin` | `admin_tenant_admins.py` | 管理端租户管理员 CRUD列表/创建/编辑/删除/重置密码) | JWT+管理员 |
| `/api/admin` | `admin_registry.py` | 注册体系管理(租户列表/店铺列表/简写ID/店铺同步) | JWT+管理员 |
| `/api/admin/ai` | `admin_ai.py` | AI 监控后台Dashboard/调度状态/调用明细/缓存/预算/批量/告警13 端点) | JWT+管理员 |
| `/api/admin/dev-trace` | `admin_dev_trace.py` | 开发调试全链路日志(日期/请求列表/详情/清理/设置/覆盖率8 端点) | JWT+管理员 |
| `/api/admin/task-engine` | `admin_task_engine.py` | P18 任务引擎运营看板(转移日志分页+历史、待审核任务分页+重新分配+关闭、参数管理 CRUD9 端点) | JWT+管理员 |
| `/api/xcx/chat` | `xcx_chat.py` | 小程序 CHAT 对话/消息/发送/SSE 流式 | JWT |
| `/api/admin/db-health` | `admin_db_health.py` | 数据库健康监控4 库连接池/大小/慢查询) | JWT |
| `/api/admin/triggers` | `admin_triggers.py` | 触发器统一视图biz/ai/etl 三源聚合) | JWT |
| `/api/trigger-jobs` | `trigger_jobs.py` | 触发器任务管理(列表/详情/PATCH 配置编辑) | JWT |
| `/api/ops` | `ops_panel.py` | 运维面板(服务启停/Git/系统信息) | 无 |
| `/ws/logs` | `ws/logs.py` | WebSocket 实时日志推送 | — |
| `/health` | `main.py` | 健康检查 | 无 |
@@ -148,17 +180,57 @@ JWT 令牌分两种:
| `task_queue.py` | 任务队列管理(入队/消费/重排) |
| `task_registry.py` | ETL 任务/Flow/DWD 表静态注册表 |
| `cli_builder.py` | ETL CLI 命令构建器 |
| `task_generator.py` | 任务生成器(基于 WBI/NCI 指数 |
| `task_generator.py` | 任务生成器(四级漏斗 + 保底 relationship_building独立连接 |
| `task_manager.py` | 任务管理(置顶/放弃/状态变更) |
| `task_expiry.py` | 任务过期检查与处理 |
| `task_manager.py` | 任务管理CRUD + 列表扩展 + 详情) |
| `performance_service.py` | 绩效概览 + 明细ETL 直连查询) |
| `note_service.py` | 备注服务CRUD + 星星评分) |
| `fdw_queries.py` | ETL 查询集中封装(直连 ETL 库 + 门店隔离 RLS |
| `fdw_queries.py` | ETL 查询集中封装(直连 ETL 库 + 门店隔离 RLS,含区域日粒度查询(`get_finance_overview_area`/`get_finance_revenue_area`)和缓存读写(`get_finance_board_cache`/`set_finance_board_cache` |
| `note_reclassifier.py` | 备注重分类(召回完成后回填) |
| `recall_detector.py` | 召回完成检测ETL 数据更新触发) |
| `trigger_scheduler.py` | 触发器调度器cron/interval/event |
| `chat_service.py` | CHAT 模块业务逻辑(对话管理/消息持久化/referenceCard |
| `ai/admin_service.py` | AI 监控后台聚合服务Dashboard 统计/批量执行/告警管理) |
| `ai/cleanup_service.py` | AI 数据清理服务90 天保留 + 缓存上限 20000/App |
| `admin_task_engine.py` | P18 任务引擎运营看板路由(转移日志/待审核任务/参数管理9 端点) |
## AI 模块NS2 Prompt 细化)
8 个千问 AI 应用,通过百炼平台调用 Qwen3.5-Plus 模型。分三层架构:
```
应用层apps/app1_chat ~ app8_consolidation
↓ 调用
数据获取层data_fetchers/ ← NS2 新增
↓ 查询
基础设施层database.py / cache_service.py / dashscope_client.py
```
### 数据获取层(`app/ai/data_fetchers/`
NS2 新增的共享模块,封装 FDW 查询逻辑,供多个应用复用:
| 函数 | 数据来源 | 消费方 |
|------|---------|--------|
| `fetch_member_consumption_data()` | ETL FDW 视图(结算/商品/会员卡/到店) | App3/6/7 |
| `fetch_member_notes()` | `biz.notes` | App4/6 |
| `fetch_assistant_info()` | ETL FDW 视图(助教维度/月度汇总) | App4/5 |
| `fetch_service_history()` | ETL FDW 视图(服务日志/亲密度) | App4/5 |
| `build_page_text()` | 多数据源(按 contextType 路由) | App1 |
关键约束:
- 金额口径使用 `items_sum`,禁止 `consume_money`
- 所有 FDW 查询通过 `SET LOCAL app.current_site_id` 实现 RLS 隔离
- 部分数据获取失败不阻断 Prompt 生成(错误降级)
### 页面上下文App1
App1 通用对话支持 10 种页面入口,通过 `contextType` 路由到对应的文本化函数:
`task-detail` / `customer-detail` / `coach-detail` / `task-list` / `customer-service-records` / `board-finance` / `board-customer` / `board-coach` / `performance` / `my-profile`
每种入口自动获取页面数据并格式化为结构化中文文本,注入 system prompt。
## 依赖

View File

@@ -12,15 +12,19 @@ import json
import logging
from typing import AsyncGenerator
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.data_fetchers import build_page_text
from app.ai.schemas import SSEEvent
logger = logging.getLogger(__name__)
APP_ID = "app1_chat"
# system prompt 总字符数上限
_MAX_SYSTEM_PROMPT_LEN = 4000
async def chat_stream(
*,
@@ -32,7 +36,7 @@ async def chat_stream(
source_page: str | None = None,
page_context: dict | None = None,
screen_content: str | None = None,
bailian: BailianClient,
client: DashScopeClient,
conv_svc: ConversationService,
) -> AsyncGenerator[SSEEvent, None]:
"""流式对话入口,返回 SSEEvent 异步生成器。
@@ -76,11 +80,12 @@ async def chat_stream(
)
# 3. 构建消息列表system prompt + user message
messages = _build_messages(
messages = await _build_messages(
message=message,
user_id=user_id,
nickname=nickname,
role=role,
site_id=site_id,
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
@@ -118,12 +123,13 @@ async def chat_stream(
yield SSEEvent(type="error", message=str(e))
def _build_messages(
async def _build_messages(
*,
message: str,
user_id: int | str,
nickname: str,
role: str,
site_id: int,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
@@ -132,25 +138,38 @@ def _build_messages(
首条 system 消息注入页面上下文和用户信息。
"""
system_content = _build_system_prompt(
system_content = await _build_system_prompt(
user_id=user_id,
nickname=nickname,
role=role,
site_id=site_id,
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
)
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
# system prompt 总字符数控制
if len(content_str) > _MAX_SYSTEM_PROMPT_LEN:
# 截断 page_context 中的 data_text
pc = system_content.get("page_context", {})
dt = pc.get("data_text", "")
if dt and len(dt) > 500:
pc["data_text"] = dt[:500] + "…(已截断)"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "system", "content": content_str},
{"role": "user", "content": message},
]
def _build_system_prompt(
async def _build_system_prompt(
*,
user_id: int | str,
nickname: str,
role: str,
site_id: int,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
@@ -161,7 +180,12 @@ def _build_system_prompt(
注入页面上下文供 AI 理解当前场景。
"""
prompt: dict = {
"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。",
"task": (
"你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"
"当 page_context 中包含 memberNickname、contextId 或 data_text 时,"
"你必须直接使用这些信息回答问题,不要再向用户索要已有的信息。"
"例如用户在客户详情页提问时,直接基于该客户的数据回答,无需要求提供会员 ID。"
),
"biz_params": {
"user_prompt_params": {
"User_ID": str(user_id),
@@ -172,10 +196,11 @@ def _build_system_prompt(
}
# 注入页面上下文(首条消息)
page_ctx = _build_page_context(
page_ctx = await _build_page_context(
source_page=source_page,
page_context=page_context,
screen_content=screen_content,
site_id=site_id,
)
if page_ctx:
prompt["page_context"] = page_ctx
@@ -183,25 +208,52 @@ def _build_system_prompt(
return prompt
def _build_page_context(
async def _build_page_context(
*,
source_page: str | None,
page_context: dict | None,
screen_content: str | None,
site_id: int,
) -> dict:
"""构建页面上下文信息。
P5-A 阶段:直接透传前端传入的上下文字段。
P5-B 阶段:各页面逐步实现文本化工具,丰富 screen_content
根据 source_pagecontextType调用 build_page_text 获取结构化文本,
看板类页面从 page_context 提取筛选参数传入 filters
contextType 为空或未识别时返回空 dict跳过注入
"""
# TODO: P5-B 各页面文本化工具细化
ctx: dict = {}
if source_page:
ctx["source_page"] = source_page
# 从 page_context 提取 contextId 和筛选参数
context_id = None
filters: dict = {}
if page_context:
context_id = page_context.get("contextId")
# 看板类页面筛选参数透传
for key in ("timeDimension", "areaFilter", "dimension", "typeFilter", "projectFilter"):
if key in page_context:
filters[key] = page_context[key]
# 调用 data_fetcher 获取页面数据文本
try:
data_text = await build_page_text(
source_page=source_page,
context_id=context_id,
site_id=site_id,
filters=filters if filters else None,
)
if data_text:
ctx["data_text"] = data_text
except Exception:
logger.warning("页面上下文文本化失败: source_page=%s", source_page, exc_info=True)
if page_context:
ctx["page_context"] = page_context
if screen_content:
ctx["screen_content"] = screen_content
return ctx

View File

@@ -15,7 +15,7 @@ import logging
import os
from datetime import date, datetime, timedelta
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.prompts.app2_finance_prompt import build_prompt
@@ -124,7 +124,7 @@ def compute_time_range(dimension: str, business_date: date) -> tuple[date, date]
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:

View File

@@ -15,25 +15,42 @@ from __future__ import annotations
import json
import logging
from datetime import datetime
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.data_fetchers import fetch_member_consumption_data
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app3_clue"
# system message content 上限
_MAX_SYSTEM_CONTENT_LEN = 8000
def build_prompt(
def _default_member_data() -> dict:
"""数据获取失败时的默认空值。"""
return {
"member_nickname": "",
"consumption_records": [],
"member_cards": [],
"card_balance_total": 0,
"stored_value_balance_total": 0,
"expected_visit_date": None,
"days_since_last_visit": None,
}
async def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段
P5-B 阶段P9-T1补充 consumption_records 等完整数据。
从 data_fetchers 获取真实消费数据,失败时降级为空值
Args:
context: 包含 site_id, member_id, nickname 等
@@ -45,9 +62,28 @@ def build_prompt(
site_id = context["site_id"]
member_id = context["member_id"]
# 获取消费数据(失败时降级)
data_fetch_failed = False
try:
member_data = await fetch_member_consumption_data(site_id, member_id)
except Exception:
logger.warning("App3 消费数据获取失败,使用默认空值: site_id=%s member_id=%s", site_id, member_id, exc_info=True)
member_data = _default_member_data()
data_fetch_failed = True
# 构建 referenceApp6 线索 + 最近 2 套 App8 历史(附 generated_at
reference = _build_reference(site_id, member_id, cache_svc)
member_nickname = member_data.get("member_nickname", "")
consumption_records = member_data.get("consumption_records", [])
# 空数据标注
if not consumption_records:
if data_fetch_failed:
consumption_records = "⚠ 消费数据获取失败,该客户暂无消费记录可供分析"
else:
consumption_records = "该客户暂无消费记录"
system_content = {
"task": "分析客户消费数据,提取维客线索。",
"app_id": APP_ID,
@@ -67,14 +103,28 @@ def build_prompt(
}
]
},
# TODO: P9-T1 细化 - consumption_records 等客户消费数据
"data": {
"consumption_records": "待 P9-T1 补充",
"member_info": "待 P9-T1 补充",
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
"member_nickname": member_nickname,
"main_data": {
"consumption_records": consumption_records,
"member_cards": member_data.get("member_cards", []),
"card_balance_total": member_data.get("card_balance_total", 0),
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
"expected_visit_date": member_data.get("expected_visit_date"),
"days_since_last_visit": member_data.get("days_since_last_visit"),
},
"reference": reference,
}
# Token 预算控制:截断 consumption_records
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
records = system_content["main_data"].get("consumption_records")
if isinstance(records, list) and len(records) > 5:
system_content["main_data"]["consumption_records"] = records[:5]
system_content["main_data"]["_truncated"] = f"消费记录已截断,原始共 {len(records)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
user_content = (
f"请分析会员 {member_id} 的消费数据,提取维客线索。"
"每条线索包含 category、summary、detail、emoji 四个字段。"
@@ -82,7 +132,7 @@ def build_prompt(
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "system", "content": content_str},
{"role": "user", "content": user_content},
]
@@ -134,7 +184,7 @@ def _build_reference(
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
@@ -162,7 +212,7 @@ async def run(
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
messages = await build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(

View File

@@ -11,27 +11,50 @@ app_id = "app4_analysis"
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.data_fetchers import (
fetch_assistant_info,
fetch_member_consumption_data,
fetch_member_notes,
fetch_service_history,
)
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app4_analysis"
# system message content 上限
_MAX_SYSTEM_CONTENT_LEN = 8000
def build_prompt(
def _default_member_data() -> dict:
"""数据获取失败时的默认空值。"""
return {
"member_nickname": "",
"consumption_records": [],
"member_cards": [],
"card_balance_total": 0,
"stored_value_balance_total": 0,
"expected_visit_date": None,
"days_since_last_visit": None,
}
async def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段
P5-B 阶段P6-T4补充 service_history、assistant_info 等完整数据。
并发获取助教信息、服务历史、客户消费数据、备注,部分失败不阻断
Args:
context: 包含 site_id, assistant_id, member_id
@@ -44,10 +67,50 @@ def build_prompt(
assistant_id = context["assistant_id"]
member_id = context["member_id"]
# 并发获取 4 类数据,部分失败不阻断
results = await asyncio.gather(
fetch_assistant_info(site_id, assistant_id),
fetch_service_history(site_id, assistant_id, member_id),
fetch_member_consumption_data(site_id, member_id),
fetch_member_notes(site_id, member_id),
return_exceptions=True,
)
# 降级处理
fetch_errors: list[str] = []
if isinstance(results[0], Exception):
logger.warning("App4 助教信息获取失败: %s", results[0])
assistant_info = {}
fetch_errors.append("助教信息获取失败")
else:
assistant_info = results[0]
if isinstance(results[1], Exception):
logger.warning("App4 服务历史获取失败: %s", results[1])
service_history: list = []
fetch_errors.append("服务历史获取失败")
else:
service_history = results[1]
if isinstance(results[2], Exception):
logger.warning("App4 消费数据获取失败: %s", results[2])
member_data = _default_member_data()
fetch_errors.append("消费数据获取失败")
else:
member_data = results[2]
if isinstance(results[3], Exception):
logger.warning("App4 备注获取失败: %s", results[3])
notes: list = []
fetch_errors.append("备注获取失败")
else:
notes = results[3]
# 构建 referenceApp8 最新 + 最近 2 套历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
system_content: dict = {
"task": "分析助教与客户的关系,生成任务建议。",
"app_id": APP_ID,
"output_format": {
@@ -55,14 +118,51 @@ def build_prompt(
"action_suggestions": ["建议1", "建议2"],
"one_line_summary": "一句话总结",
},
# TODO: P6-T4 细化 - service_history、assistant_info
"data": {
"service_history": "待 P6-T4 补充",
"assistant_info": "待 P6-T4 补充",
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
"assistant_info": assistant_info if assistant_info else "⚠ 助教信息获取失败",
"service_history": service_history if service_history else "暂无服务记录",
"task_assignment_basis": {
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
"member_cards": member_data.get("member_cards", []),
"card_balance_total": member_data.get("card_balance_total", 0),
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
"expected_visit_date": member_data.get("expected_visit_date"),
"days_since_last_visit": member_data.get("days_since_last_visit"),
},
"customer_data": {
"system_data": {
"member_nickname": member_data.get("member_nickname", ""),
},
"notes": notes if notes else "暂无备注",
},
"reference": reference,
}
if fetch_errors:
system_content["_data_warnings"] = fetch_errors
# Token 预算控制
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
# 优先截断 service_history
sh = system_content.get("service_history")
if isinstance(sh, list) and len(sh) > 5:
system_content["service_history"] = sh[:5]
system_content["_truncated_service_history"] = f"服务记录已截断,原始共 {len(sh)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
records = system_content["task_assignment_basis"].get("consumption_records")
if isinstance(records, list) and len(records) > 5:
system_content["task_assignment_basis"]["consumption_records"] = records[:5]
system_content["task_assignment_basis"]["_truncated"] = f"消费记录已截断,原始共 {len(records)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
n = system_content["customer_data"].get("notes")
if isinstance(n, list) and len(n) > 10:
system_content["customer_data"]["notes"] = n[:10]
system_content["customer_data"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
# 缓存不存在时在 user prompt 中标注
no_history_hint = ""
if not reference:
@@ -75,7 +175,7 @@ def build_prompt(
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "system", "content": content_str},
{"role": "user", "content": user_content},
]
@@ -127,7 +227,7 @@ def _build_reference(
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
@@ -149,7 +249,7 @@ async def run(
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
messages = await build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(

View File

@@ -10,27 +10,51 @@ app_id = "app5_tactics"
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.data_fetchers import (
fetch_assistant_info,
fetch_member_consumption_data,
fetch_member_notes,
fetch_service_history,
)
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app5_tactics"
# system message content 上限
_MAX_SYSTEM_CONTENT_LEN = 8000
def build_prompt(
def _default_member_data() -> dict:
"""数据获取失败时的默认空值。"""
return {
"member_nickname": "",
"consumption_records": [],
"member_cards": [],
"card_balance_total": 0,
"stored_value_balance_total": 0,
"expected_visit_date": None,
"days_since_last_visit": None,
}
async def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段。
P5-B 阶段P6-T4补充 service_history、assistant_info随 App4 同步)
复用 App4 的数据获取逻辑(并发获取助教信息、服务历史、消费数据、备注),
额外从 context["app4_result"] 获取 task_suggestion
Args:
context: 包含 site_id, assistant_id, member_id, app4_result(dict)
@@ -42,35 +66,117 @@ def build_prompt(
site_id = context["site_id"]
assistant_id = context["assistant_id"]
member_id = context["member_id"]
app4_result = context.get("app4_result", {})
# App4 结果作为 task_suggestion缺失时设为空对象
task_suggestion = context.get("app4_result") or {}
# 并发获取 4 类数据,部分失败不阻断
results = await asyncio.gather(
fetch_assistant_info(site_id, assistant_id),
fetch_service_history(site_id, assistant_id, member_id),
fetch_member_consumption_data(site_id, member_id),
fetch_member_notes(site_id, member_id),
return_exceptions=True,
)
# 降级处理
fetch_errors: list[str] = []
if isinstance(results[0], Exception):
logger.warning("App5 助教信息获取失败: %s", results[0])
assistant_info = {}
fetch_errors.append("助教信息获取失败")
else:
assistant_info = results[0]
if isinstance(results[1], Exception):
logger.warning("App5 服务历史获取失败: %s", results[1])
service_history: list = []
fetch_errors.append("服务历史获取失败")
else:
service_history = results[1]
if isinstance(results[2], Exception):
logger.warning("App5 消费数据获取失败: %s", results[2])
member_data = _default_member_data()
fetch_errors.append("消费数据获取失败")
else:
member_data = results[2]
if isinstance(results[3], Exception):
logger.warning("App5 备注获取失败: %s", results[3])
notes: list = []
fetch_errors.append("备注获取失败")
else:
notes = results[3]
# 构建 reference最近 2 套 App8 历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
"task": "基于关系分析和任务建议,生成沟通话术参考。",
system_content: dict = {
"task": (
"基于关系分析和任务建议,生成沟通话术参考。"
"输出必须严格遵循 output_format 中定义的 JSON 结构,"
"每条话术必须包含 scenario场景描述和 script话术内容两个字段"
"禁止使用 content 或其他字段名替代。"
),
"app_id": APP_ID,
"task_suggestion": app4_result,
"task_suggestion": task_suggestion,
"output_format": {
"tactics": [
{"scenario": "场景描述", "script": "话术内容"}
]
},
# TODO: P6-T4 细化 - service_history、assistant_info随 App4 同步)
"data": {
"service_history": "待 P6-T4 补充",
"assistant_info": "待 P6-T4 补充",
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
"assistant_info": assistant_info if assistant_info else "⚠ 助教信息获取失败",
"service_history": service_history if service_history else "暂无服务记录",
"task_assignment_basis": {
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
"member_cards": member_data.get("member_cards", []),
"card_balance_total": member_data.get("card_balance_total", 0),
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
"expected_visit_date": member_data.get("expected_visit_date"),
"days_since_last_visit": member_data.get("days_since_last_visit"),
},
"customer_data": {
"system_data": {
"member_nickname": member_data.get("member_nickname", ""),
},
"notes": notes if notes else "暂无备注",
},
"reference": reference,
}
if fetch_errors:
system_content["_data_warnings"] = fetch_errors
# Token 预算控制
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
sh = system_content.get("service_history")
if isinstance(sh, list) and len(sh) > 5:
system_content["service_history"] = sh[:5]
system_content["_truncated_service_history"] = f"服务记录已截断,原始共 {len(sh)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
records = system_content["task_assignment_basis"].get("consumption_records")
if isinstance(records, list) and len(records) > 5:
system_content["task_assignment_basis"]["consumption_records"] = records[:5]
system_content["task_assignment_basis"]["_truncated"] = f"消费记录已截断,原始共 {len(records)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
n = system_content["customer_data"].get("notes")
if isinstance(n, list) and len(n) > 10:
system_content["customer_data"]["notes"] = n[:10]
system_content["customer_data"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
user_content = (
f"请为助教 {assistant_id} 生成与会员 {member_id} 沟通的话术参考。"
"返回 tactics 数组,每条包含 scenario 和 script 字段。"
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "system", "content": content_str},
{"role": "user", "content": user_content},
]
@@ -109,7 +215,7 @@ def _build_reference(
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
@@ -131,7 +237,7 @@ async def run(
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
messages = await build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(

View File

@@ -13,27 +13,45 @@ app_id = "app6_note"
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.data_fetchers import fetch_member_consumption_data, fetch_member_notes
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app6_note"
# system message content 上限
_MAX_SYSTEM_CONTENT_LEN = 8000
def build_prompt(
def _default_member_data() -> dict:
"""数据获取失败时的默认空值。"""
return {
"member_nickname": "",
"consumption_records": [],
"member_cards": [],
"card_balance_total": 0,
"stored_value_balance_total": 0,
"expected_visit_date": None,
"days_since_last_visit": None,
}
async def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段
P5-B 阶段P9-T1补充 consumption_data 等完整数据。
并发获取消费数据和备注,失败时降级为空值
Args:
context: 包含 site_id, member_id, note_content, noted_by_name
@@ -46,11 +64,47 @@ def build_prompt(
member_id = context["member_id"]
note_content = context.get("note_content", "")
noted_by_name = context.get("noted_by_name", "")
noted_by_created_at = context.get("noted_by_created_at", "")
# 并发获取消费数据和备注
results = await asyncio.gather(
fetch_member_consumption_data(site_id, member_id),
fetch_member_notes(site_id, member_id),
return_exceptions=True,
)
fetch_errors: list[str] = []
if isinstance(results[0], Exception):
logger.warning("App6 消费数据获取失败: %s", results[0])
member_data = _default_member_data()
fetch_errors.append("消费数据获取失败")
else:
member_data = results[0]
if isinstance(results[1], Exception):
logger.warning("App6 备注获取失败: %s", results[1])
all_notes: list = []
fetch_errors.append("备注获取失败")
else:
all_notes = results[1]
# 构建 referenceApp3 线索 + 最近 2 套 App8 历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
# 将消费数据和备注注入 reference
reference["member_nickname"] = member_data.get("member_nickname", "")
reference["consumption_data"] = {
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
"member_cards": member_data.get("member_cards", []),
"card_balance_total": member_data.get("card_balance_total", 0),
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
"expected_visit_date": member_data.get("expected_visit_date"),
"days_since_last_visit": member_data.get("days_since_last_visit"),
}
reference["all_notes"] = all_notes if all_notes else []
system_content: dict = {
"task": "分析备注内容,提取维客线索并评分。",
"app_id": APP_ID,
"rules": {
@@ -73,15 +127,33 @@ def build_prompt(
}
],
},
"note_content": note_content,
"noted_by_name": noted_by_name,
# TODO: P9-T1 细化 - consumption_data 等客户消费数据
"data": {
"consumption_data": "待 P9-T1 补充",
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
"current_note": {
"content": note_content,
"recorded_by": noted_by_name,
"created_at": noted_by_created_at,
},
"reference": reference,
}
if fetch_errors:
system_content["_data_warnings"] = fetch_errors
# Token 预算控制
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
records = system_content["reference"].get("consumption_data", {}).get("consumption_records")
if isinstance(records, list) and len(records) > 5:
system_content["reference"]["consumption_data"]["consumption_records"] = records[:5]
system_content["reference"]["consumption_data"]["_truncated"] = f"消费记录已截断,原始共 {len(records)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
n = system_content["reference"].get("all_notes")
if isinstance(n, list) and len(n) > 10:
system_content["reference"]["all_notes"] = n[:10]
system_content["reference"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
user_content = (
f"请分析以下备注内容,提取维客线索并评分。\n"
f"备注提供人:{noted_by_name}\n"
@@ -91,7 +163,7 @@ def build_prompt(
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "system", "content": content_str},
{"role": "user", "content": user_content},
]
@@ -143,7 +215,7 @@ def _build_reference(
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
@@ -164,7 +236,7 @@ async def run(
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
messages = await build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(

View File

@@ -13,27 +13,45 @@ app_id = "app7_customer"
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.data_fetchers import fetch_member_consumption_data, fetch_member_notes
from app.ai.schemas import CacheTypeEnum
logger = logging.getLogger(__name__)
APP_ID = "app7_customer"
# system message content 上限
_MAX_SYSTEM_CONTENT_LEN = 8000
def build_prompt(
def _default_member_data() -> dict:
"""数据获取失败时的默认空值。"""
return {
"member_nickname": "",
"consumption_records": [],
"member_cards": [],
"card_balance_total": 0,
"stored_value_balance_total": 0,
"expected_visit_date": None,
"days_since_last_visit": None,
}
async def build_prompt(
context: dict,
cache_svc: AICacheService | None = None,
) -> list[dict]:
"""构建 Prompt 消息列表。
P5-A 阶段:返回占位 Prompt标注待细化字段
P5-B 阶段P9-T1补充 objective_data 等完整数据。
并发获取消费数据和备注,备注标注来源信息
Args:
context: 包含 site_id, member_id
@@ -45,10 +63,46 @@ def build_prompt(
site_id = context["site_id"]
member_id = context["member_id"]
# 并发获取消费数据和备注
results = await asyncio.gather(
fetch_member_consumption_data(site_id, member_id),
fetch_member_notes(site_id, member_id),
return_exceptions=True,
)
fetch_errors: list[str] = []
if isinstance(results[0], Exception):
logger.warning("App7 消费数据获取失败: %s", results[0])
member_data = _default_member_data()
fetch_errors.append("消费数据获取失败")
else:
member_data = results[0]
if isinstance(results[1], Exception):
logger.warning("App7 备注获取失败: %s", results[1])
notes_raw: list = []
fetch_errors.append("备注获取失败")
else:
notes_raw = results[1]
# 备注标注来源信息
if notes_raw:
subjective_notes = []
for note in notes_raw:
recorded_by = note.get("recorded_by", "未知")
annotated = dict(note)
annotated["content"] = f"{note.get('content', '')}【来源:{recorded_by},请甄别信息真实性】"
subjective_notes.append(annotated)
else:
subjective_notes = "该客户暂无主观备注信息"
member_nickname = member_data.get("member_nickname", "")
# 构建 reference最新 + 最近 2 套 App8 历史
reference = _build_reference(site_id, member_id, cache_svc)
system_content = {
system_content: dict = {
"task": "综合分析客户数据,生成运营策略建议。",
"app_id": APP_ID,
"rules": {
@@ -62,13 +116,41 @@ def build_prompt(
],
"summary": "一句话总结",
},
# TODO: P9-T1 细化 - objective_data 等客户消费数据
"data": {
"objective_data": "待 P9-T1 补充",
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
"member_id": member_id,
"member_nickname": member_nickname,
"objective_data": {
"consumption_records": member_data.get("consumption_records", []) or "该客户暂无消费记录",
"member_cards": member_data.get("member_cards", []),
"card_balance_total": member_data.get("card_balance_total", 0),
"stored_value_balance_total": member_data.get("stored_value_balance_total", 0),
"expected_visit_date": member_data.get("expected_visit_date"),
"days_since_last_visit": member_data.get("days_since_last_visit"),
},
"subjective_data": {
"notes": subjective_notes,
},
"reference": reference,
}
if fetch_errors:
system_content["_data_warnings"] = fetch_errors
# Token 预算控制
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
records = system_content["objective_data"].get("consumption_records")
if isinstance(records, list) and len(records) > 5:
system_content["objective_data"]["consumption_records"] = records[:5]
system_content["objective_data"]["_truncated"] = f"消费记录已截断,原始共 {len(records)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
if len(content_str) > _MAX_SYSTEM_CONTENT_LEN:
n = system_content["subjective_data"].get("notes")
if isinstance(n, list) and len(n) > 10:
system_content["subjective_data"]["notes"] = n[:10]
system_content["subjective_data"]["_truncated_notes"] = f"备注已截断,原始共 {len(n)}"
content_str = json.dumps(system_content, ensure_ascii=False, default=str)
user_content = (
f"请综合分析会员 {member_id} 的客户数据,生成运营策略建议。"
"返回 strategies 数组(每条含 title 和 content和 summary 字段。"
@@ -76,7 +158,7 @@ def build_prompt(
)
return [
{"role": "system", "content": json.dumps(system_content, ensure_ascii=False)},
{"role": "system", "content": content_str},
{"role": "user", "content": user_content},
]
@@ -128,7 +210,7 @@ def _build_reference(
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:
@@ -149,7 +231,7 @@ async def run(
nickname = context.get("nickname", "")
# 1. 构建 Prompt
messages = build_prompt(context, cache_svc)
messages = await build_prompt(context, cache_svc)
# 2. 创建对话记录
conversation_id = conv_svc.create_conversation(

View File

@@ -11,7 +11,7 @@ from __future__ import annotations
import json
import logging
from app.ai.bailian_client import BailianClient
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.prompts.app8_consolidation_prompt import build_prompt
@@ -120,7 +120,7 @@ def _determine_source(providers: str) -> str:
async def run(
context: dict,
bailian: BailianClient,
client: DashScopeClient,
cache_svc: AICacheService,
conv_svc: ConversationService,
) -> dict:

View File

@@ -1,273 +0,0 @@
"""百炼 API 统一封装层。
使用 openai Python SDK百炼兼容 OpenAI 协议),提供流式和非流式两种调用模式。
所有 AI 应用通过此客户端统一调用阿里云通义千问。
"""
from __future__ import annotations
import asyncio
import copy
import json
import logging
from datetime import datetime
from typing import Any, AsyncGenerator
import openai
logger = logging.getLogger(__name__)
# ── 异常类 ──────────────────────────────────────────────────────────
class BailianApiError(Exception):
"""百炼 API 调用失败(重试耗尽后)。"""
def __init__(self, message: str, status_code: int | None = None):
super().__init__(message)
self.status_code = status_code
class BailianJsonParseError(Exception):
"""百炼 API 返回的 JSON 解析失败。"""
def __init__(self, message: str, raw_content: str = ""):
super().__init__(message)
self.raw_content = raw_content
class BailianAuthError(BailianApiError):
"""百炼 API Key 无效HTTP 401"""
def __init__(self, message: str = "API Key 无效或已过期"):
super().__init__(message, status_code=401)
# ── 客户端 ──────────────────────────────────────────────────────────
class BailianClient:
"""百炼 API 统一封装层。
使用 openai.AsyncOpenAI 客户端base_url 指向百炼端点。
提供流式chat_stream和非流式chat_json两种调用模式。
"""
# 重试配置
MAX_RETRIES = 3
BASE_INTERVAL = 1 # 秒
def __init__(self, api_key: str, base_url: str, model: str):
"""初始化百炼客户端。
Args:
api_key: 百炼 API Key环境变量 BAILIAN_API_KEY
base_url: 百炼 API 端点(环境变量 BAILIAN_BASE_URL
model: 模型标识,如 qwen-plus环境变量 BAILIAN_MODEL
"""
self.model = model
self._client = openai.AsyncOpenAI(
api_key=api_key,
base_url=base_url,
)
async def chat_stream(
self,
messages: list[dict],
*,
temperature: float = 0.7,
max_tokens: int = 2000,
) -> AsyncGenerator[str, None]:
"""流式调用,逐 chunk yield 文本。用于应用 1 SSE。
Args:
messages: 消息列表
temperature: 温度参数,默认 0.7
max_tokens: 最大 token 数,默认 2000
Yields:
文本 chunk
"""
messages = self._inject_current_time(messages)
response = await self._call_with_retry(
model=self.model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
async for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
async def chat_json(
self,
messages: list[dict],
*,
temperature: float = 0.3,
max_tokens: int = 4000,
) -> tuple[dict, int]:
"""非流式调用,返回解析后的 JSON dict 和 tokens_used。
用于应用 2-8 的结构化输出。使用 response_format={"type": "json_object"}
确保返回合法 JSON。
Args:
messages: 消息列表
temperature: 温度参数,默认 0.3(结构化输出用低温度)
max_tokens: 最大 token 数,默认 4000
Returns:
(parsed_json_dict, tokens_used) 元组
Raises:
BailianJsonParseError: 响应内容无法解析为 JSON
BailianApiError: API 调用失败(重试耗尽后)
"""
messages = self._inject_current_time(messages)
response = await self._call_with_retry(
model=self.model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
response_format={"type": "json_object"},
)
raw_content = response.choices[0].message.content or ""
tokens_used = response.usage.total_tokens if response.usage else 0
try:
parsed = json.loads(raw_content)
except (json.JSONDecodeError, TypeError) as e:
logger.error("百炼 API 返回非法 JSON: %s", raw_content[:500])
raise BailianJsonParseError(
f"JSON 解析失败: {e}",
raw_content=raw_content,
) from e
return parsed, tokens_used
def _inject_current_time(self, messages: list[dict]) -> list[dict]:
"""纯函数:在首条消息的 contentJSON 字符串)中注入 current_time 字段。
- 深拷贝输入,不修改原始 messages
- 首条消息 content 尝试解析为 JSON注入 current_time
- 如果首条消息 content 不是 JSON则包装为 JSON
- 其余消息不变
- current_time 格式ISO 8601 精确到秒,如 2026-03-08T14:30:00
Args:
messages: 原始消息列表
Returns:
注入 current_time 后的新消息列表
"""
if not messages:
return []
result = copy.deepcopy(messages)
first = result[0]
content = first.get("content", "")
now_str = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
try:
parsed = json.loads(content)
if isinstance(parsed, dict):
parsed["current_time"] = now_str
else:
# content 是合法 JSON 但不是 dict如数组、字符串包装为 dict
parsed = {"original_content": parsed, "current_time": now_str}
except (json.JSONDecodeError, TypeError):
# content 不是 JSON包装为 dict
parsed = {"content": content, "current_time": now_str}
first["content"] = json.dumps(parsed, ensure_ascii=False)
return result
async def _call_with_retry(self, **kwargs: Any) -> Any:
"""带指数退避的重试封装。
重试策略:
- 最多重试 MAX_RETRIES 次(默认 3 次)
- 间隔BASE_INTERVAL × 2^(n-1),即 1s → 2s → 4s
- HTTP 4xx不重试直接抛出401 → BailianAuthError
- HTTP 5xx / 超时:重试
Args:
**kwargs: 传递给 openai client 的参数
Returns:
API 响应对象
Raises:
BailianAuthError: API Key 无效HTTP 401
BailianApiError: API 调用失败(重试耗尽后)
"""
is_stream = kwargs.get("stream", False)
last_error: Exception | None = None
for attempt in range(self.MAX_RETRIES):
try:
if is_stream:
# 流式调用:返回 async iterator
return await self._client.chat.completions.create(**kwargs)
else:
return await self._client.chat.completions.create(**kwargs)
except openai.AuthenticationError as e:
# 401API Key 无效,不重试
logger.error("百炼 API 认证失败: %s", e)
raise BailianAuthError(str(e)) from e
except openai.BadRequestError as e:
# 400请求参数错误不重试
logger.error("百炼 API 请求参数错误: %s", e)
raise BailianApiError(str(e), status_code=400) from e
except openai.RateLimitError as e:
# 429限流不重试属于 4xx
logger.error("百炼 API 限流: %s", e)
raise BailianApiError(str(e), status_code=429) from e
except openai.PermissionDeniedError as e:
# 403权限不足不重试
logger.error("百炼 API 权限不足: %s", e)
raise BailianApiError(str(e), status_code=403) from e
except openai.NotFoundError as e:
# 404资源不存在不重试
logger.error("百炼 API 资源不存在: %s", e)
raise BailianApiError(str(e), status_code=404) from e
except openai.UnprocessableEntityError as e:
# 422不可处理不重试
logger.error("百炼 API 不可处理的请求: %s", e)
raise BailianApiError(str(e), status_code=422) from e
except (openai.InternalServerError, openai.APIConnectionError, openai.APITimeoutError) as e:
# 5xx / 超时 / 连接错误:重试
last_error = e
if attempt < self.MAX_RETRIES - 1:
wait_time = self.BASE_INTERVAL * (2 ** attempt)
logger.warning(
"百炼 API 调用失败(第 %d/%d 次),%ds 后重试: %s",
attempt + 1,
self.MAX_RETRIES,
wait_time,
e,
)
await asyncio.sleep(wait_time)
else:
logger.error(
"百炼 API 调用失败,已达最大重试次数 %d: %s",
self.MAX_RETRIES,
e,
)
# 重试耗尽
status_code = getattr(last_error, "status_code", None)
raise BailianApiError(
f"百炼 API 调用失败(重试 {self.MAX_RETRIES} 次后): {last_error}",
status_code=status_code,
) from last_error

View File

@@ -0,0 +1,101 @@
"""Token 预算追踪器 — 从 ai_run_logs 聚合日/月 token 消耗。
每次 AI 调用前检查预算,超限时拒绝请求。
日预算默认 100,000 tokens月预算默认 2,000,000 tokens。
聚合数据通过构造函数注入的 callable 获取(解耦 AIRunLogService
callable 签名:() -> int分别返回当日/当月已消耗 token 数。
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Protocol
class UsageProvider(Protocol):
"""Token 用量数据提供者协议。"""
def get_daily_usage(self) -> int:
"""返回当日已消耗 token 数。"""
...
def get_monthly_usage(self) -> int:
"""返回当月已消耗 token 数。"""
...
@dataclass
class BudgetStatus:
"""预算检查结果。"""
allowed: bool
daily_used: int
monthly_used: int
reason: str | None = None # "daily_exceeded" / "monthly_exceeded" / None
class BudgetTracker:
"""Token 预算追踪器,从 ai_run_logs 聚合。
支持两种注入方式:
1. 传入 UsageProvider 实例(如 AIRunLogService
2. 传入两个 callableget_daily_usage / get_monthly_usage
"""
def __init__(
self,
daily_limit: int = 100_000,
monthly_limit: int = 2_000_000,
*,
get_daily_usage: Callable[[], int] | None = None,
get_monthly_usage: Callable[[], int] | None = None,
usage_provider: UsageProvider | None = None,
) -> None:
self.daily_limit = daily_limit
self.monthly_limit = monthly_limit
# 优先使用 usage_provider其次使用独立 callable
if usage_provider is not None:
self._get_daily_usage = usage_provider.get_daily_usage
self._get_monthly_usage = usage_provider.get_monthly_usage
elif get_daily_usage is not None and get_monthly_usage is not None:
self._get_daily_usage = get_daily_usage
self._get_monthly_usage = get_monthly_usage
else:
raise ValueError(
"必须提供 usage_provider 或同时提供 "
"get_daily_usage 和 get_monthly_usage callable"
)
def check_budget(self) -> BudgetStatus:
"""检查当前预算状态。
先检查日预算,再检查月预算。
任一超限即返回 allowed=False 并附带原因。
"""
daily_used = self._get_daily_usage()
monthly_used = self._get_monthly_usage()
if daily_used >= self.daily_limit:
return BudgetStatus(
allowed=False,
daily_used=daily_used,
monthly_used=monthly_used,
reason="daily_exceeded",
)
if monthly_used >= self.monthly_limit:
return BudgetStatus(
allowed=False,
daily_used=daily_used,
monthly_used=monthly_used,
reason="monthly_exceeded",
)
return BudgetStatus(
allowed=True,
daily_used=daily_used,
monthly_used=monthly_used,
reason=None,
)

View File

@@ -3,18 +3,38 @@ AI 缓存读写服务。
负责 biz.ai_cache 表的 CRUD 和保留策略管理。
所有查询和写入操作强制 site_id 隔离。
P14 改造:
- 新增 status 字段处理valid/expired/invalidated/generating
- 查询仅返回 status='valid' 且未过期的记录
- 按 App 类型设置过期时间
- 每 App 保留最新 20,000 条
"""
from __future__ import annotations
import json
import logging
from datetime import datetime
from datetime import datetime, timedelta, timezone
from app.database import get_connection
logger = logging.getLogger(__name__)
# 缓存过期策略cache_type → 过期天数0 表示当日 23:59:59
CACHE_EXPIRY_DAYS: dict[str, int] = {
"app2_finance": 0, # 当日 23:59:59
"app3_clue": 7,
"app4_analysis": 7,
"app5_tactics": 7,
"app6_note_analysis": 30,
"app7_customer_analysis": 7,
"app8_clue_consolidated": 7,
}
# 每 App 保留上限
CACHE_MAX_PER_APP = 20_000
class AICacheService:
"""AI 缓存读写服务。"""
@@ -25,9 +45,9 @@ class AICacheService:
site_id: int,
target_id: str,
) -> dict | None:
"""查询最新缓存记录。
"""查询最新有效缓存记录。
按 (cache_type, site_id, target_id) 查询 created_at 最新的一条
仅返回 status='valid' 且未过期的记录
无记录时返回 None。
"""
conn = get_connection()
@@ -37,9 +57,11 @@ class AICacheService:
"""
SELECT id, cache_type, site_id, target_id,
result_json, score, triggered_by,
created_at, expires_at
created_at, expires_at, status
FROM biz.ai_cache
WHERE cache_type = %s AND site_id = %s AND target_id = %s
AND (status = 'valid' OR status IS NULL)
AND (expires_at IS NULL OR expires_at > now())
ORDER BY created_at DESC
LIMIT 1
""",
@@ -95,7 +117,15 @@ class AICacheService:
score: int | None = None,
expires_at: datetime | None = None,
) -> int:
"""写入缓存记录,返回 id。写入后清理超限记录。"""
"""写入缓存记录,返回 id。
自动设置 status='valid' 和按 App 类型计算 expires_at。
写入后清理超限记录(每 App 保留 20,000 条)。
"""
# 自动计算过期时间(如果未显式指定)
if expires_at is None:
expires_at = self._calc_expires_at(cache_type)
conn = get_connection()
try:
with conn.cursor() as cur:
@@ -103,8 +133,8 @@ class AICacheService:
"""
INSERT INTO biz.ai_cache
(cache_type, site_id, target_id, result_json,
triggered_by, score, expires_at)
VALUES (%s, %s, %s, %s, %s, %s, %s)
triggered_by, score, expires_at, status)
VALUES (%s, %s, %s, %s, %s, %s, %s, 'valid')
RETURNING id
""",
(
@@ -126,7 +156,7 @@ class AICacheService:
finally:
conn.close()
# 写入成功后清理超限记录(失败仅记录警告,不影响写入结果)
# 写入成功后清理超限记录
try:
deleted = self._cleanup_excess(cache_type, site_id, target_id)
if deleted > 0:
@@ -143,12 +173,89 @@ class AICacheService:
return cache_id
def set_generating(
self,
cache_type: str,
site_id: int,
target_id: str,
triggered_by: str | None = None,
) -> int:
"""写入 generating 状态占位记录,返回 id。完成后调用 finalize_cache 更新。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO biz.ai_cache
(cache_type, site_id, target_id, result_json, status, triggered_by)
VALUES (%s, %s, %s, '{}', 'generating', %s)
RETURNING id
""",
(cache_type, site_id, target_id, triggered_by),
)
row = cur.fetchone()
conn.commit()
return row[0]
except Exception:
conn.rollback()
raise
finally:
conn.close()
def finalize_cache(
self,
cache_id: int,
result_json: dict,
score: int | None = None,
cache_type: str | None = None,
) -> None:
"""将 generating 记录更新为 valid填充结果和过期时间。"""
expires_at = self._calc_expires_at(cache_type) if cache_type else None
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_cache
SET result_json = %s, score = %s, status = 'valid', expires_at = %s
WHERE id = %s AND status = 'generating'
""",
(
json.dumps(result_json, ensure_ascii=False),
score,
expires_at,
cache_id,
),
)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
@staticmethod
def _calc_expires_at(cache_type: str | None) -> datetime | None:
"""根据 cache_type 计算过期时间。未知类型返回 None。"""
if cache_type is None:
return None
days = CACHE_EXPIRY_DAYS.get(cache_type)
if days is None:
return None
now = datetime.now(timezone.utc)
if days == 0:
# 当日 23:59:59UTC+8
local_now = now + timedelta(hours=8)
end_of_day = local_now.replace(hour=23, minute=59, second=59, microsecond=0)
return end_of_day - timedelta(hours=8) # 转回 UTC
return now + timedelta(days=days)
def _cleanup_excess(
self,
cache_type: str,
site_id: int,
target_id: str,
max_count: int = 500,
max_count: int = CACHE_MAX_PER_APP,
) -> int:
"""清理超限记录,保留最近 max_count 条,返回删除数量。"""
conn = get_connection()

View File

@@ -0,0 +1,116 @@
"""熔断器 — 按 app_id 独立的断路保护。
状态机CLOSED → OPEN连续失败达阈值→ HALF_OPEN超时后探测→ CLOSED/OPEN。
内存实现,单实例部署,不依赖外部存储。
"""
from __future__ import annotations
import enum
import time
from dataclasses import dataclass, field
class CircuitState(enum.Enum):
"""熔断器状态。"""
CLOSED = "closed" # 正常放行
OPEN = "open" # 熔断中,拒绝请求
HALF_OPEN = "half_open" # 探测中,放行单个请求
@dataclass
class _BreakerState:
"""单个 app_id 的熔断内部状态。"""
state: CircuitState = CircuitState.CLOSED
failure_count: int = 0
last_failure_time: float = 0.0
last_state_change: float = field(default_factory=time.monotonic)
class CircuitBreaker:
"""按 app_id 独立的熔断器。
- check()检查当前状态OPEN 且超时自动转 HALF_OPEN
- record_success()HALF_OPEN→CLOSEDCLOSED 重置失败计数
- record_failure()连续达阈值→OPENHALF_OPEN 失败→重新 OPEN
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: int = 60,
) -> None:
self._failure_threshold = failure_threshold
self._recovery_timeout = recovery_timeout
self._breakers: dict[str, _BreakerState] = {}
def _get_state(self, app_id: str) -> _BreakerState:
"""获取或初始化指定 app_id 的状态。"""
if app_id not in self._breakers:
self._breakers[app_id] = _BreakerState()
return self._breakers[app_id]
def check(self, app_id: str) -> CircuitState:
"""检查当前熔断状态。
- CLOSED / HALF_OPEN允许通过返回对应状态
- OPEN 且未超时:返回 OPEN拒绝
- OPEN 且已超时:自动转 HALF_OPEN返回 HALF_OPEN允许探测
"""
breaker = self._get_state(app_id)
if breaker.state == CircuitState.CLOSED:
return CircuitState.CLOSED
if breaker.state == CircuitState.HALF_OPEN:
return CircuitState.HALF_OPEN
# OPEN 状态:检查是否超过恢复超时
elapsed = time.monotonic() - breaker.last_failure_time
if elapsed >= self._recovery_timeout:
# 超时,转为 HALF_OPEN 探测
breaker.state = CircuitState.HALF_OPEN
breaker.last_state_change = time.monotonic()
return CircuitState.HALF_OPEN
return CircuitState.OPEN
def record_success(self, app_id: str) -> None:
"""记录调用成功。
- HALF_OPEN→CLOSED探测成功恢复正常
- CLOSED 下重置失败计数
"""
breaker = self._get_state(app_id)
if breaker.state == CircuitState.HALF_OPEN:
breaker.state = CircuitState.CLOSED
breaker.failure_count = 0
breaker.last_state_change = time.monotonic()
elif breaker.state == CircuitState.CLOSED:
# CLOSED 状态下成功重置失败计数
breaker.failure_count = 0
def record_failure(self, app_id: str) -> None:
"""记录调用失败。
- CLOSED累加失败计数达阈值→OPEN
- HALF_OPEN探测失败→重新 OPEN
"""
breaker = self._get_state(app_id)
now = time.monotonic()
if breaker.state == CircuitState.HALF_OPEN:
# 探测失败,重新熔断
breaker.state = CircuitState.OPEN
breaker.failure_count = self._failure_threshold
breaker.last_failure_time = now
breaker.last_state_change = now
elif breaker.state == CircuitState.CLOSED:
breaker.failure_count += 1
breaker.last_failure_time = now
if breaker.failure_count >= self._failure_threshold:
breaker.state = CircuitState.OPEN
breaker.last_state_change = now

View File

@@ -0,0 +1,68 @@
"""AI 模块配置 — 从环境变量加载 DashScope 相关参数。
所有 DASHSCOPE_* 环境变量和 INTERNAL_API_TOKEN 统一在此管理,
启动时通过 from_env() 校验必需变量,缺失立即报错。
"""
from __future__ import annotations
import os
from dataclasses import dataclass
@dataclass(frozen=True)
class AIConfig:
"""AI 模块配置从环境变量加载。不可变frozen"""
api_key: str # DASHSCOPE_API_KEY
workspace_id: str | None # DASHSCOPE_WORKSPACE_ID可选
app_id_1_chat: str # DASHSCOPE_APP_ID_1_CHAT
app_id_2_finance: str # DASHSCOPE_APP_ID_2_FINANCE
app_id_3_clue: str # DASHSCOPE_APP_ID_3_CLUE
app_id_4_analysis: str # DASHSCOPE_APP_ID_4_ANALYSIS
app_id_5_tactics: str # DASHSCOPE_APP_ID_5_TACTICS
app_id_6_note: str # DASHSCOPE_APP_ID_6_NOTE
app_id_7_customer: str # DASHSCOPE_APP_ID_7_CUSTOMER
app_id_8_consolidate: str # DASHSCOPE_APP_ID_8_CONSOLIDATE
internal_api_token: str # INTERNAL_API_TOKEN
@classmethod
def from_env(cls) -> AIConfig:
"""从环境变量加载配置。
必需变量缺失时立即抛出 ValueError禁止静默回退空字符串。
可选变量DASHSCOPE_WORKSPACE_ID缺失时为 None。
"""
required_mapping: dict[str, str] = {
"DASHSCOPE_API_KEY": "api_key",
"DASHSCOPE_APP_ID_1_CHAT": "app_id_1_chat",
"DASHSCOPE_APP_ID_2_FINANCE": "app_id_2_finance",
"DASHSCOPE_APP_ID_3_CLUE": "app_id_3_clue",
"DASHSCOPE_APP_ID_4_ANALYSIS": "app_id_4_analysis",
"DASHSCOPE_APP_ID_5_TACTICS": "app_id_5_tactics",
"DASHSCOPE_APP_ID_6_NOTE": "app_id_6_note",
"DASHSCOPE_APP_ID_7_CUSTOMER": "app_id_7_customer",
"DASHSCOPE_APP_ID_8_CONSOLIDATE": "app_id_8_consolidate",
"INTERNAL_API_TOKEN": "internal_api_token",
}
# 收集所有缺失的必需变量,一次性报错
missing: list[str] = []
values: dict[str, str] = {}
for env_name, field_name in required_mapping.items():
val = os.environ.get(env_name)
if not val: # None 或空字符串均视为缺失
missing.append(env_name)
else:
values[field_name] = val
if missing:
raise ValueError(
f"AI 配置缺失必需环境变量: {', '.join(missing)}"
)
# 可选变量
workspace_id = os.environ.get("DASHSCOPE_WORKSPACE_ID") or None
return cls(workspace_id=workspace_id, **values)

View File

@@ -27,6 +27,7 @@ class ConversationService:
site_id: int,
source_page: str | None = None,
source_context: dict | None = None,
title: str | None = None,
) -> int:
"""创建对话记录,返回 conversation_id。
@@ -38,8 +39,8 @@ class ConversationService:
cur.execute(
"""
INSERT INTO biz.ai_conversations
(user_id, nickname, app_id, site_id, source_page, source_context)
VALUES (%s, %s, %s, %s, %s, %s)
(user_id, nickname, app_id, site_id, source_page, source_context, title)
VALUES (%s, %s, %s, %s, %s, %s, %s)
RETURNING id
""",
(
@@ -49,6 +50,7 @@ class ConversationService:
site_id,
source_page,
json.dumps(source_context, ensure_ascii=False) if source_context else None,
title,
),
)
row = cur.fetchone()
@@ -89,6 +91,22 @@ class ConversationService:
finally:
conn.close()
def update_title(self, conversation_id: int, title: str) -> None:
"""更新对话标题。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"UPDATE biz.ai_conversations SET title = %s WHERE id = %s",
(title, conversation_id),
)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def get_conversations(
self,
user_id: int | str,
@@ -104,7 +122,7 @@ class ConversationService:
cur.execute(
"""
SELECT id, user_id, nickname, app_id, site_id,
source_page, source_context, created_at
source_page, source_context, title, created_at
FROM biz.ai_conversations
WHERE user_id = %s AND site_id = %s
ORDER BY created_at DESC

View File

@@ -0,0 +1,318 @@
"""DashScope Application API 统一封装层。
使用 dashscope.Application.call() 调用百炼智能体应用,
替代原 openai SDK 的通用模型 API。
- call_app_stream(): App1 流式调用asyncio.Queue 桥接 async generator
- call_app(): App2~8 单轮调用asyncio.to_thread() 包装
- _call_with_retry(): 指数退避重试1s→2s→4s
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, AsyncGenerator, Callable
import dashscope
from dashscope import Application
from app.ai.exceptions import (
DashScopeApiError,
DashScopeAuthError,
DashScopeJsonParseError,
DashScopeTimeoutError,
)
logger = logging.getLogger(__name__)
class DashScopeClient:
"""DashScope Application API 统一封装层。
通过 app_id 调用百炼控制台配置的智能体应用,
充分利用云端 System Prompt 和 MCP 工具。
"""
MAX_RETRIES = 3
BASE_INTERVAL = 1 # 秒
def __init__(self, api_key: str, workspace_id: str | None = None):
"""初始化。dashscope 通过全局变量设置密钥。
Args:
api_key: DashScope API Key
workspace_id: 百炼工作空间 ID可选
"""
dashscope.api_key = api_key
self._workspace_id = workspace_id
async def call_app_stream(
self,
app_id: str,
prompt: str,
session_id: str | None = None,
biz_params: dict | None = None,
) -> AsyncGenerator[str, None]:
"""App1 流式调用。
在线程中消费同步迭代器,通过 asyncio.Queue 桥接到 async generator。
错误通过 queue 传递给调用方。
Args:
app_id: 百炼应用 ID
prompt: 用户输入
session_id: 百炼 session_id多轮对话
biz_params: 业务参数(如 user_prompt_params
Yields:
文本 chunk
"""
queue: asyncio.Queue[str | BaseException | None] = asyncio.Queue()
loop = asyncio.get_running_loop()
def _consume_in_thread() -> None:
"""在线程中消费同步迭代器,逐 chunk 放入 queue。"""
try:
call_kwargs: dict[str, Any] = {
"app_id": app_id,
"prompt": prompt,
"stream": True,
"incremental_output": True,
}
if session_id is not None:
call_kwargs["session_id"] = session_id
if biz_params is not None:
call_kwargs["biz_params"] = biz_params
if self._workspace_id is not None:
call_kwargs["workspace"] = self._workspace_id
response = Application.call(**call_kwargs)
for chunk in response:
if chunk.status_code == 200:
text = chunk.output.get("text", "")
if text:
asyncio.run_coroutine_threadsafe(
queue.put(text), loop
)
else:
# 非 200 状态码,构造异常传递给调用方
status = chunk.status_code
msg = getattr(chunk, "message", "") or f"状态码 {status}"
if status == 401:
err = DashScopeAuthError(msg)
else:
err = DashScopeApiError(msg, status_code=status)
asyncio.run_coroutine_threadsafe(
queue.put(err), loop
)
return
# 正常结束信号
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
except Exception as exc:
# 线程内未预期异常,传递给调用方
asyncio.run_coroutine_threadsafe(
queue.put(exc), loop
)
loop.run_in_executor(None, _consume_in_thread)
while True:
item = await queue.get()
if item is None:
break
if isinstance(item, BaseException):
raise item
yield item
async def call_app(
self,
app_id: str,
prompt: str,
session_id: str | None = None,
biz_params: dict | None = None,
) -> tuple[dict, int, str | None]:
"""App2~8 单轮调用。
通过 asyncio.to_thread() 包装同步 Application.call()
解析 response.output.text 获取 JSON 内容。
非合法 JSON 触发重试(最多 3 次),不做本地修复。
Args:
app_id: 百炼应用 ID
prompt: 后端拼好的完整数据 JSON 字符串
session_id: 百炼 session_id可选
biz_params: 业务参数(可选)
Returns:
(parsed_json, tokens_used, new_session_id) 元组
Raises:
DashScopeApiError: API 调用失败(重试耗尽)
DashScopeJsonParseError: JSON 解析失败(重试耗尽)
"""
call_kwargs: dict[str, Any] = {
"app_id": app_id,
"prompt": prompt,
}
if session_id is not None:
call_kwargs["session_id"] = session_id
if biz_params is not None:
call_kwargs["biz_params"] = biz_params
if self._workspace_id is not None:
call_kwargs["workspace"] = self._workspace_id
# 非合法 JSON 纯重试,最多 MAX_RETRIES 次
last_json_error: DashScopeJsonParseError | None = None
for json_attempt in range(self.MAX_RETRIES):
response = await self._call_with_retry(
Application.call, **call_kwargs
)
# 提取 output.text
raw_text: str = ""
if hasattr(response, "output"):
output = response.output
if isinstance(output, dict):
raw_text = output.get("text", "")
elif hasattr(output, "text"):
raw_text = output.text or ""
# 提取 tokens_used
tokens_used = 0
if hasattr(response, "usage") and response.usage:
usage = response.usage
if isinstance(usage, dict):
# input_tokens + output_tokens
tokens_used = usage.get("input_tokens", 0) + usage.get(
"output_tokens", 0
)
elif hasattr(usage, "total_tokens"):
tokens_used = usage.total_tokens or 0
# 提取 new_session_id
new_session_id: str | None = None
if hasattr(response, "output") and isinstance(response.output, dict):
new_session_id = response.output.get("session_id")
# 解析 JSON
try:
parsed = json.loads(raw_text)
if isinstance(parsed, list):
# CHANGE 2026-03-23 | Prompt: App2 LLM 返回 list 而非 dict
# 百炼 LLM 有时直接返回 insights 数组而非包裹 dict
# 自动包装为 {"insights": list} 避免无意义重试
logger.info(
"LLM 返回 list长度 %d),自动包装为 {\"insights\": [...]}",
len(parsed),
)
parsed = {"insights": parsed}
if not isinstance(parsed, dict):
raise TypeError(f"期望 dict实际 {type(parsed).__name__}")
return parsed, tokens_used, new_session_id
except (json.JSONDecodeError, TypeError) as e:
last_json_error = DashScopeJsonParseError(
f"JSON 解析失败(第 {json_attempt + 1}/{self.MAX_RETRIES} 次): {e}",
raw_content=raw_text,
)
logger.warning(
"Application API 返回非法 JSON%d/%d 次): %s",
json_attempt + 1,
self.MAX_RETRIES,
raw_text[:500],
)
# 非合法 JSON 纯重试,不做本地修复
continue
# JSON 重试耗尽
raise last_json_error # type: ignore[misc]
async def _call_with_retry(self, func: Callable, **kwargs: Any) -> Any:
"""指数退避重试封装。
重试策略:
- 最多重试 MAX_RETRIES 次(默认 3 次)
- 间隔BASE_INTERVAL × 2^(n-1),即 1s → 2s → 4s
- HTTP 4xx → 不重试立即抛出401 → DashScopeAuthError
- HTTP 5xx / 超时 / 连接错误 → 重试
Args:
func: 同步调用函数(如 Application.call
**kwargs: 传递给 func 的参数
Returns:
API 响应对象status_code == 200
Raises:
DashScopeAuthError: API Key 无效HTTP 401
DashScopeTimeoutError: 调用超时(重试耗尽)
DashScopeApiError: API 调用失败(重试耗尽)
"""
last_error: Exception | None = None
for attempt in range(self.MAX_RETRIES):
try:
response = await asyncio.to_thread(func, **kwargs)
except Exception as exc:
# 网络/连接/超时等底层异常 → 可重试
last_error = exc
if attempt < self.MAX_RETRIES - 1:
wait_time = self.BASE_INTERVAL * (2**attempt)
logger.warning(
"DashScope API 底层异常(第 %d/%d 次),%ds 后重试: %s",
attempt + 1,
self.MAX_RETRIES,
wait_time,
exc,
)
await asyncio.sleep(wait_time)
continue
else:
logger.error(
"DashScope API 底层异常,已达最大重试次数 %d: %s",
self.MAX_RETRIES,
exc,
)
raise DashScopeApiError(
f"DashScope API 调用失败(重试 {self.MAX_RETRIES} 次后): {exc}",
) from exc
# Application.call() 返回 response 对象,通过 status_code 判断成功/失败
status_code = getattr(response, "status_code", None)
if status_code == 200:
return response
# 非 200根据状态码分类处理
message = getattr(response, "message", "") or f"状态码 {status_code}"
if status_code is not None and 400 <= status_code < 500:
# 4xx不重试立即抛出
if status_code == 401:
raise DashScopeAuthError(message)
raise DashScopeApiError(message, status_code=status_code)
# 5xx 或其他未知状态码 → 可重试
last_error = DashScopeApiError(message, status_code=status_code)
if attempt < self.MAX_RETRIES - 1:
wait_time = self.BASE_INTERVAL * (2**attempt)
logger.warning(
"DashScope API 调用失败(第 %d/%d 次,状态码 %s%ds 后重试: %s",
attempt + 1,
self.MAX_RETRIES,
status_code,
wait_time,
message,
)
await asyncio.sleep(wait_time)
else:
logger.error(
"DashScope API 调用失败,已达最大重试次数 %d(状态码 %s: %s",
self.MAX_RETRIES,
status_code,
message,
)
# 重试耗尽
raise last_error # type: ignore[misc]

View File

@@ -0,0 +1,29 @@
"""AI 数据获取层。
为 AI 应用提供共享的数据获取函数,封装 FDW 查询和业务库查询逻辑。
所有 FDW 查询通过 get_etl_readonly_connection(site_id) 获取只读连接,
自动设置 RLS 门店隔离。
模块:
- member_data: 客户消费数据获取(应用 3/6/7 共用)
- assistant_data: 助教数据获取(应用 4/5 共用)
- page_context: 页面上下文文本化(应用 1 专用)
"""
from app.ai.data_fetchers.member_data import (
fetch_member_consumption_data,
fetch_member_notes,
)
from app.ai.data_fetchers.assistant_data import (
fetch_assistant_info,
fetch_service_history,
)
from app.ai.data_fetchers.page_context import build_page_text
__all__ = [
"fetch_member_consumption_data",
"fetch_member_notes",
"fetch_assistant_info",
"fetch_service_history",
"build_page_text",
]

View File

@@ -0,0 +1,253 @@
"""助教数据获取模块(应用 4/5 共用)。
从 ETL 库 app.v_* RLS 视图获取助教基本信息和助教-客户服务历史。
使用 is_delete 字段排除废单is_delete=0 为正常),禁止使用已废弃的 dwd_assistant_trash_event 表。
"""
# CHANGE 2026-03-23 | Prompt: FDW 迁移——fdw_etl.* → app.* 直连 ETL 库
# intent: 将所有 fdw_etl.* 外部表引用改为 app.v_* RLS 视图(直连 ETL 库),列名同步修正
# 连接方式不变get_etl_readonly_connection仅改 SQL 表名和列名
from __future__ import annotations
import asyncio
import logging
from datetime import date, datetime
from decimal import Decimal
from functools import partial
from typing import Any
from app.database import get_etl_readonly_connection
logger = logging.getLogger(__name__)
FDW_QUERY_TIMEOUT_SEC = 5
async def fetch_assistant_info(
site_id: int,
assistant_id: int,
) -> dict[str, Any]:
"""获取助教基本信息。
返回:
{
"nickname": str,
"level": str,
"hire_date": str,
"tenure_months": int,
"monthly_customers": int,
"performance_tier": str,
}
Raises:
ValueError: 助教不存在
TimeoutError: FDW 查询超时
ConnectionError: FDW 连接失败
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(_fetch_assistant_info_sync, site_id, assistant_id),
)
def _fetch_assistant_info_sync(site_id: int, assistant_id: int) -> dict[str, Any]:
"""同步实现。"""
conn = None
try:
conn = get_etl_readonly_connection(site_id)
# RLS 隔离 + 语句超时get_etl_readonly_connection 的 SET LOCAL 在 commit 后失效,
# 需在查询事务中重新设置)
with conn.cursor() as cur:
cur.execute(
"SET LOCAL app.current_site_id = %s", (str(site_id),)
)
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
# 基本信息
# ⚠️ v_dim_assistant 列名: hire_date→entry_time
cur.execute(
"""
SELECT nickname, level, entry_time AS hire_date
FROM app.v_dim_assistant
WHERE assistant_id = %s AND scd2_is_current = 1
LIMIT 1
""",
(assistant_id,),
)
row = cur.fetchone()
if not row:
raise ValueError(f"assistant not found: assistant_id={assistant_id}")
nickname = row[0] or ""
level = row[1] or ""
hire_date = row[2]
# 计算工龄
tenure_months = 0
if hire_date and isinstance(hire_date, date):
today = date.today()
tenure_months = (today.year - hire_date.year) * 12 + (today.month - hire_date.month)
# 绩效数据
# ⚠️ 列名映射: monthly_customers 不存在(用 0 占位performance_tier→tier_name
# ⚠️ salary_month 是 date 类型YYYY-MM-01按月降序取最新
cur.execute(
"""
SELECT
0 AS monthly_customers,
COALESCE(tier_name, '') AS performance_tier
FROM app.v_dws_assistant_salary_calc
WHERE assistant_id = %s
ORDER BY salary_month DESC
LIMIT 1
""",
(assistant_id,),
)
perf_row = cur.fetchone()
monthly_customers = perf_row[0] if perf_row else 0
performance_tier = perf_row[1] if perf_row else ""
conn.commit()
return {
"nickname": nickname,
"level": level,
"hire_date": hire_date.isoformat() if isinstance(hire_date, date) else "",
"tenure_months": tenure_months,
"monthly_customers": monthly_customers,
"performance_tier": performance_tier,
}
except (ValueError, TimeoutError, ConnectionError):
raise
except Exception as e:
err_msg = str(e).lower()
if "statement timeout" in err_msg or "timeout" in err_msg:
raise TimeoutError(
f"FDW 查询超时: assistant_id={assistant_id}"
) from e
if "connection" in err_msg or "connect" in err_msg:
raise ConnectionError(
f"FDW 连接失败: assistant_id={assistant_id}"
) from e
raise
finally:
if conn:
conn.close()
async def fetch_service_history(
site_id: int,
assistant_id: int,
member_id: int,
months: int = 3,
) -> list[dict[str, Any]]:
"""获取助教服务该客户的历史记录。
使用 is_delete 排除废单WHERE is_delete = 0
返回:
[
{
"service_date": str,
"duration_minutes": int,
"items_sum": float,
"room_name": str,
"is_pd": bool,
},
...
]
Raises:
TimeoutError: FDW 查询超时
ConnectionError: FDW 连接失败
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(_fetch_service_history_sync, site_id, assistant_id, member_id, months),
)
def _fetch_service_history_sync(
site_id: int,
assistant_id: int,
member_id: int,
months: int,
) -> list[dict[str, Any]]:
"""同步实现。"""
conn = None
try:
conn = get_etl_readonly_connection(site_id)
# RLS 隔离 + 语句超时get_etl_readonly_connection 的 SET LOCAL 在 commit 后失效,
# 需在查询事务中重新设置)
with conn.cursor() as cur:
cur.execute(
"SET LOCAL app.current_site_id = %s", (str(site_id),)
)
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
# ⚠️ 列名映射: assistant_id→site_assistant_id, member_id→tenant_member_id,
# is_trash=false→is_delete=0, service_date→create_time,
# duration_minutes→real_use_seconds/60, items_sum→ledger_amount,
# room_name→site_table_id, is_pd→(order_assistant_type=1)
cur.execute(
"""
SELECT
create_time AS service_date,
COALESCE(real_use_seconds / 60, 0) AS duration_minutes,
ledger_amount AS items_sum,
site_table_id AS room_name,
(order_assistant_type = 1) AS is_pd
FROM app.v_dwd_assistant_service_log
WHERE site_assistant_id = %s
AND tenant_member_id = %s
AND is_delete = 0
AND create_time >= (CURRENT_DATE - INTERVAL '%s months')
ORDER BY create_time DESC
""",
(assistant_id, member_id, months),
)
columns = [desc[0] for desc in cur.description]
rows = cur.fetchall()
conn.commit()
records = []
for row in rows:
record = {}
for col, val in zip(columns, row):
if isinstance(val, (date, datetime)):
record[col] = val.isoformat()
elif isinstance(val, Decimal):
record[col] = float(val)
elif isinstance(val, bool):
record[col] = val
else:
record[col] = val
records.append(record)
return records
except (TimeoutError, ConnectionError):
raise
except Exception as e:
err_msg = str(e).lower()
if "statement timeout" in err_msg or "timeout" in err_msg:
raise TimeoutError(
f"FDW 查询超时: assistant_id={assistant_id}, member_id={member_id}"
) from e
if "connection" in err_msg or "connect" in err_msg:
raise ConnectionError(
f"FDW 连接失败: assistant_id={assistant_id}, member_id={member_id}"
) from e
raise
finally:
if conn:
conn.close()

View File

@@ -0,0 +1,402 @@
"""客户消费数据获取模块(应用 3/6/7 共用)。
从 ETL 库 app.v_* RLS 视图获取客户近 N 个月消费数据,从业务库获取备注。
金额口径统一使用拆分字段table_charge_money + assistant_pd/cx_money + goods_money禁止 consume_money。
会员信息通过 member_id JOIN v_dim_member (scd2_is_current=1) 获取。
"""
# CHANGE 2026-03-23 | Prompt: FDW 迁移——fdw_etl.* → app.* 直连 ETL 库
# intent: 将所有 fdw_etl.* 外部表引用改为 app.v_* RLS 视图(直连 ETL 库),列名同步修正
# 连接方式不变get_etl_readonly_connection仅改 SQL 表名和列名
from __future__ import annotations
import asyncio
import logging
from datetime import date, datetime
from decimal import Decimal
from functools import partial
from typing import Any
from app.database import get_connection, get_etl_readonly_connection
logger = logging.getLogger(__name__)
# 消费记录最大返回数
MAX_CONSUMPTION_RECORDS = 100
# 备注最大返回数
MAX_NOTES = 50
# 备注单条最大字符数
MAX_NOTE_LENGTH = 500
# FDW 查询超时(秒)
FDW_QUERY_TIMEOUT_SEC = 5
async def fetch_member_consumption_data(
site_id: int,
member_id: int,
months: int = 3,
) -> dict[str, Any]:
"""获取客户近 N 个月消费数据。
返回结构对应 NS2 设计文档中 main_data
- consumption_records: 消费记录列表(最多 100 条settle_date DESC
- member_cards: 会员卡明细列表
- card_balance_total: 储值卡余额合计
- stored_value_balance_total: 储值余额合计
- expected_visit_date: 预计到店日期
- days_since_last_visit: 距上次到店天数
- member_nickname: 会员昵称
Raises:
TimeoutError: FDW 查询超时(>5s
ConnectionError: FDW 连接失败
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(_fetch_member_consumption_data_sync, site_id, member_id, months),
)
def _fetch_member_consumption_data_sync(
site_id: int,
member_id: int,
months: int,
) -> dict[str, Any]:
"""同步实现:在单个 FDW 连接上串行执行多个查询。"""
conn = None
try:
conn = get_etl_readonly_connection(site_id)
# RLS 隔离 + 语句超时get_etl_readonly_connection 的 SET LOCAL 在 commit 后失效,
# 需在查询事务中重新设置)
with conn.cursor() as cur:
cur.execute(
"SET LOCAL app.current_site_id = %s", (str(site_id),)
)
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",), # 毫秒
)
# 1. 会员昵称
nickname = _query_member_nickname(conn, member_id)
# 2. 消费记录(台桌结账 + 商城订单)
records, total_count = _query_consumption_records(conn, member_id, months)
# 3. 会员卡明细
cards = _query_member_cards(conn, member_id)
# 4. 余额汇总
balance_info = _query_balance_summary(conn, member_id)
# 5. 到店数据
visit_info = _query_visit_info(conn, member_id)
result: dict[str, Any] = {
"member_nickname": nickname,
"consumption_records": records,
"member_cards": cards,
"card_balance_total": balance_info.get("card_balance_total", Decimal("0")),
"stored_value_balance_total": balance_info.get(
"stored_value_balance_total", Decimal("0")
),
"expected_visit_date": visit_info.get("expected_visit_date"),
"days_since_last_visit": visit_info.get("days_since_last_visit"),
}
if total_count > MAX_CONSUMPTION_RECORDS:
result["truncated"] = True
result["total_count"] = total_count
conn.commit()
return result
except Exception as e:
# psycopg2 超时异常包含 "statement timeout"
err_msg = str(e).lower()
if "statement timeout" in err_msg or "timeout" in err_msg:
raise TimeoutError(
f"FDW 查询超时(>{FDW_QUERY_TIMEOUT_SEC}s: member_id={member_id}"
) from e
if "connection" in err_msg or "connect" in err_msg:
raise ConnectionError(
f"FDW 连接失败: member_id={member_id}, error={e}"
) from e
raise
finally:
if conn:
conn.close()
def _query_member_nickname(conn: Any, member_id: int) -> str:
"""从 app.v_dim_member 获取会员昵称scd2_is_current=1"""
with conn.cursor() as cur:
cur.execute(
"""
SELECT nickname
FROM app.v_dim_member
WHERE member_id = %s AND scd2_is_current = 1
LIMIT 1
""",
(member_id,),
)
row = cur.fetchone()
return row[0] if row and row[0] else ""
def _query_consumption_records(
conn: Any, member_id: int, months: int
) -> tuple[list[dict], int]:
"""从 app.v_dwd_settlement_head + app.v_dwd_table_fee_log 获取消费记录。
仅包含正向交易settle_type IN (1, 3))。
⚠️ 费用拆分字段table_charge_money, assistant_pd/cx_money在 settlement_head 上。
⚠️ table_fee_log 提供台桌时长real_table_use_seconds和桌台IDsite_table_id
⚠️ 列名映射: settle_date→create_time, settle_id→order_settle_id, sale_amount→ledger_amount。
返回 (records, total_count)。
"""
with conn.cursor() as cur:
# 先查总数
cur.execute(
"""
SELECT COUNT(*)
FROM app.v_dwd_settlement_head sh
WHERE sh.member_id = %s
AND sh.settle_type IN (1, 3)
AND sh.create_time >= (CURRENT_DATE - INTERVAL '%s months')
""",
(member_id, months),
)
total_count = cur.fetchone()[0]
# 查询消费记录(限制 100 条)
# table_charge_money/assistant_pd_money/assistant_cx_money 直接从 settlement_head 取
# 台桌信息从 table_fee_log 取site_table_id, real_table_use_seconds
# 商品金额从 store_goods_sale 聚合
# 助教姓名从 service_log JOIN dim_assistant 获取
cur.execute(
"""
SELECT
sh.create_time AS settle_date,
sh.settle_type,
sh.table_charge_money + sh.assistant_pd_money + sh.assistant_cx_money
+ COALESCE(sg.goods_money, 0) AS items_sum,
COALESCE(sh.table_charge_money, 0) AS table_charge_money,
COALESCE(sh.assistant_pd_money, 0) AS assistant_pd_money,
COALESCE(sh.assistant_cx_money, 0) AS assistant_cx_money,
COALESCE(sg.goods_money, 0) AS goods_money,
tfl.site_table_id AS room_name,
COALESCE(tfl.real_table_use_seconds / 60, 0) AS duration_minutes,
coaches.assistant_names
FROM app.v_dwd_settlement_head sh
LEFT JOIN app.v_dwd_table_fee_log tfl
ON sh.order_settle_id = tfl.order_settle_id
LEFT JOIN LATERAL (
SELECT SUM(sgs.ledger_amount) AS goods_money
FROM app.v_dwd_store_goods_sale sgs
WHERE sgs.order_settle_id = sh.order_settle_id
) sg ON true
LEFT JOIN LATERAL (
SELECT string_agg(DISTINCT COALESCE(da.nickname, da.real_name, ''), ', ')
AS assistant_names
FROM app.v_dwd_assistant_service_log sl
LEFT JOIN app.v_dim_assistant da
ON sl.site_assistant_id = da.assistant_id
AND da.scd2_is_current = 1
WHERE sl.order_settle_id = sh.order_settle_id
AND sl.is_delete = 0
) coaches ON true
WHERE sh.member_id = %s
AND sh.settle_type IN (1, 3)
AND sh.create_time >= (CURRENT_DATE - INTERVAL '%s months')
ORDER BY sh.create_time DESC
LIMIT %s
""",
(member_id, months, MAX_CONSUMPTION_RECORDS),
)
columns = [desc[0] for desc in cur.description]
rows = cur.fetchall()
records = []
for row in rows:
record = {}
for col, val in zip(columns, row):
if isinstance(val, (date, datetime)):
record[col] = val.isoformat()
elif isinstance(val, Decimal):
record[col] = float(val)
else:
record[col] = val
# assistant_names: 确保是列表
names = record.get("assistant_names")
if names and isinstance(names, str):
record["assistant_names"] = [n.strip() for n in names.split(",") if n.strip()]
elif not names:
record["assistant_names"] = []
records.append(record)
return records, total_count
def _query_member_cards(conn: Any, member_id: int) -> list[dict]:
"""从 app.v_dim_member_card_account 获取会员卡明细。
⚠️ 列名映射: member_id→tenant_member_id, gift_balance 不存在(用 balance - principal_balance 近似)。
"""
with conn.cursor() as cur:
cur.execute(
"""
SELECT member_card_type_name AS card_type,
COALESCE(balance, 0) AS balance,
COALESCE(balance, 0) - COALESCE(principal_balance, 0) AS gift_balance
FROM app.v_dim_member_card_account
WHERE tenant_member_id = %s AND scd2_is_current = 1
""",
(member_id,),
)
rows = cur.fetchall()
return [
{
"card_type": row[0] or "",
"balance": float(row[1]) if row[1] else 0.0,
"gift_balance": float(row[2]) if row[2] else 0.0,
}
for row in rows
]
def _query_balance_summary(conn: Any, member_id: int) -> dict:
"""从 app.v_dws_member_consumption_summary 获取余额汇总。
⚠️ 列名映射: recharge_card_amount→cash_card_balance, balance_amount→total_card_balance。
"""
with conn.cursor() as cur:
cur.execute(
"""
SELECT
COALESCE(cash_card_balance, 0) AS card_balance_total,
COALESCE(total_card_balance, 0) AS stored_value_balance_total
FROM app.v_dws_member_consumption_summary
WHERE member_id = %s
LIMIT 1
""",
(member_id,),
)
row = cur.fetchone()
if not row:
return {
"card_balance_total": Decimal("0"),
"stored_value_balance_total": Decimal("0"),
}
return {
"card_balance_total": row[0],
"stored_value_balance_total": row[1],
}
def _query_visit_info(conn: Any, member_id: int) -> dict:
"""从 app.v_dws_member_visit_detail 获取到店数据,推算预计到店日期。
⚠️ 列名映射: last_visit_date→MAX(visit_date), avg_visit_interval_days 需从明细计算。
"""
with conn.cursor() as cur:
# 获取最近到店日期和平均到店间隔
cur.execute(
"""
WITH visits AS (
SELECT visit_date,
LAG(visit_date) OVER (ORDER BY visit_date) AS prev_visit
FROM app.v_dws_member_visit_detail
WHERE member_id = %s
)
SELECT
MAX(visit_date) AS last_visit_date,
AVG(visit_date - prev_visit) AS avg_visit_interval_days
FROM visits
WHERE prev_visit IS NOT NULL
""",
(member_id,),
)
row = cur.fetchone()
if not row or not row[0]:
return {"expected_visit_date": None, "days_since_last_visit": None}
last_visit = row[0]
avg_interval = row[1]
today = date.today()
days_since = (today - last_visit).days if isinstance(last_visit, date) else None
expected = None
if avg_interval and last_visit:
from datetime import timedelta
expected_date = last_visit + timedelta(days=int(avg_interval))
expected = expected_date.isoformat()
return {
"expected_visit_date": expected,
"days_since_last_visit": days_since,
}
async def fetch_member_notes(
site_id: int,
member_id: int,
limit: int = MAX_NOTES,
) -> list[dict]:
"""获取客户的全部备注(按 created_at DESC最多 limit 条)。
从业务库 biz.notes 查询。
单条备注内容截断 500 字符,超出附加"…(已截断)"
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
partial(_fetch_member_notes_sync, site_id, member_id, limit),
)
def _fetch_member_notes_sync(
site_id: int,
member_id: int,
limit: int,
) -> list[dict]:
"""同步实现:从业务库查询备注。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT
n.content,
u.nickname AS recorded_by,
n.created_at
FROM biz.notes n
LEFT JOIN biz.coach_tasks ct ON ct.id = n.task_id
LEFT JOIN public.users u ON u.id = n.user_id
WHERE n.target_id = %s AND n.site_id = %s
ORDER BY n.created_at DESC
LIMIT %s
""",
(member_id, site_id, limit),
)
rows = cur.fetchall()
notes = []
for row in rows:
content = row[0] or ""
recorded_by = row[1] or ""
created_at = row[2]
# 截断处理
if len(content) > MAX_NOTE_LENGTH:
content = content[:MAX_NOTE_LENGTH] + "…(已截断)"
notes.append({
"recorded_by": recorded_by,
"content": content,
"created_at": created_at.isoformat() if isinstance(created_at, (date, datetime)) else str(created_at) if created_at else "",
})
return notes
finally:
conn.close()

View File

@@ -0,0 +1,645 @@
"""页面上下文文本化模块(应用 1 专用)。
根据 contextType 从数据库获取对应页面数据,
格式化为结构化中文文本(≤ 2000 字符),供 AI 理解当前场景。
不传入 member_phone 等断档敏感字段。
"""
from __future__ import annotations
import asyncio
import logging
from datetime import date, datetime
from decimal import Decimal
from functools import partial
from typing import Any
from app.database import get_connection, get_etl_readonly_connection
logger = logging.getLogger(__name__)
MAX_PAGE_CONTEXT_LENGTH = 2000
FDW_QUERY_TIMEOUT_SEC = 5
# 支持的 10 种页面类型
SUPPORTED_PAGE_TYPES = {
"task-detail",
"customer-detail",
"coach-detail",
"board-finance",
"board-customer",
"board-coach",
"performance",
"my-profile",
"task-list",
"customer-service-records",
}
async def build_page_text(
source_page: str,
context_id: int | str | None,
site_id: int,
filters: dict | None = None,
) -> str:
"""将页面数据转换为 AI 可读的结构化中文文本。
Args:
source_page: 页面类型contextType
context_id: 实体 IDcontextId
site_id: 门店 ID
filters: 看板类页面的筛选参数
Returns:
结构化中文文本(≤ 2000 字符),失败时返回降级文本
"""
if not source_page or source_page not in SUPPORTED_PAGE_TYPES:
return ""
try:
loop = asyncio.get_event_loop()
text = await loop.run_in_executor(
None,
partial(_build_page_text_sync, source_page, context_id, site_id, filters or {}),
)
# 截断保护
if len(text) > MAX_PAGE_CONTEXT_LENGTH:
text = text[:MAX_PAGE_CONTEXT_LENGTH - 20] + "\n…(上下文已截断)"
return text
except Exception:
logger.exception("页面上下文获取失败: source_page=%s", source_page)
return "页面上下文获取失败,请直接描述您的问题"
def _build_page_text_sync(
source_page: str,
context_id: int | str | None,
site_id: int,
filters: dict,
) -> str:
"""同步路由到对应页面文本化函数。"""
handlers = {
"task-detail": _text_task_detail,
"customer-detail": _text_customer_detail,
"coach-detail": _text_coach_detail,
"board-finance": _text_board_finance,
"board-customer": _text_board_customer,
"board-coach": _text_board_coach,
"performance": _text_performance,
"my-profile": _text_my_profile,
"task-list": _text_task_list,
"customer-service-records": _text_customer_service_records,
}
handler = handlers.get(source_page)
if not handler:
return ""
return handler(context_id, site_id, filters)
# ── 详情类页面 ──────────────────────────────────────────────────
def _text_task_detail(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""任务详情页文本化。"""
if not context_id:
return ""
task_id = int(context_id)
conn = get_connection()
try:
with conn.cursor() as cur:
# 任务信息
cur.execute(
"""
SELECT ct.task_type, ct.status, ct.deadline,
ct.member_id, ct.assistant_id,
dm.nickname AS member_nickname,
da.nickname AS assistant_nickname
FROM biz.coach_tasks ct
LEFT JOIN biz.coach_tasks_member_view dm
ON dm.member_id = ct.member_id AND dm.site_id = ct.site_id
LEFT JOIN biz.coach_tasks_assistant_view da
ON da.assistant_id = ct.assistant_id AND da.site_id = ct.site_id
WHERE ct.id = %s AND ct.site_id = %s
""",
(task_id, site_id),
)
task = cur.fetchone()
if not task:
return f"任务 {task_id} 不存在"
task_type, status, deadline, member_id, assistant_id, member_nick, asst_nick = task
# 最近备注(最多 3 条)
cur.execute(
"""
SELECT content, created_at
FROM biz.notes
WHERE task_id = %s AND site_id = %s
ORDER BY created_at DESC LIMIT 3
""",
(task_id, site_id),
)
notes = cur.fetchall()
# AI 缓存(最新分析)
cur.execute(
"""
SELECT result_json, created_at
FROM biz.ai_cache
WHERE cache_type = 'app4_analysis'
AND site_id = %s
AND target_id = %s
ORDER BY created_at DESC LIMIT 1
""",
(site_id, f"{assistant_id}_{member_id}"),
)
ai_row = cur.fetchone()
lines = [
"【任务详情】",
f" 任务类型:{task_type or '未知'}",
f" 状态:{status or '未知'}",
f" 截止日期:{_fmt_date(deadline)}",
f" 客户:{member_nick or f'ID:{member_id}'}",
f" 助教:{asst_nick or f'ID:{assistant_id}'}",
]
if notes:
lines.append("【最近备注】")
for content, created_at in notes:
short = (content or "")[:100]
lines.append(f" {_fmt_date(created_at)} {short}")
if ai_row:
lines.append(f"【AI 分析】最近更新于 {_fmt_date(ai_row[1])}")
return "\n".join(lines)
finally:
conn.close()
def _text_customer_detail(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""客户详情页文本化。"""
if not context_id:
return ""
member_id = int(context_id)
# 复用 member_data 的同步查询(避免循环导入,直接查询)
etl_conn = None
biz_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
# CHANGE 2026-03-23 | Prompt: FDW 迁移——fdw_etl.* → app.* 直连 ETL 库
# 会员信息
cur.execute(
"""
SELECT nickname
FROM app.v_dim_member
WHERE member_id = %s AND scd2_is_current = 1
""",
(member_id,),
)
m_row = cur.fetchone()
nickname = m_row[0] if m_row else f"ID:{member_id}"
# 最近 5 条消费
cur.execute(
"""
SELECT settle_date, items_sum, room_name
FROM app.v_dwd_settlement_head
WHERE member_id = %s AND settle_type IN (1, 3)
ORDER BY settle_date DESC LIMIT 5
""",
(member_id,),
)
recent = cur.fetchall()
# 余额
cur.execute(
"""
SELECT balance_amount
FROM app.v_dws_member_consumption_summary
WHERE member_id = %s LIMIT 1
""",
(member_id,),
)
bal_row = cur.fetchone()
etl_conn.commit()
# 维客线索
biz_conn = get_connection()
with biz_conn.cursor() as cur:
cur.execute(
"""
SELECT summary FROM member_retention_clue
WHERE member_id = %s AND site_id = %s
ORDER BY created_at DESC LIMIT 5
""",
(member_id, site_id),
)
clues = cur.fetchall()
lines = [
"【客户详情】",
f" 昵称:{nickname}",
f" 储值余额:{_fmt_decimal(bal_row[0]) if bal_row else '未知'}",
]
if recent:
lines.append("【近期消费】")
for sd, amt, room in recent:
lines.append(f" {_fmt_date(sd)} ¥{_fmt_decimal(amt)} {room or ''}")
if clues:
lines.append("【维客线索】")
for (summary,) in clues:
lines.append(f" {summary}")
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
if biz_conn:
biz_conn.close()
def _text_coach_detail(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""助教详情页文本化。"""
if not context_id:
return ""
assistant_id = int(context_id)
etl_conn = None
biz_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
cur.execute(
"""
SELECT nickname, level, hire_date
FROM app.v_dim_assistant
WHERE assistant_id = %s LIMIT 1
""",
(assistant_id,),
)
row = cur.fetchone()
etl_conn.commit()
if not row:
return f"助教 {assistant_id} 不存在"
nickname, level, hire_date = row
biz_conn = get_connection()
with biz_conn.cursor() as cur:
# 任务统计
cur.execute(
"""
SELECT status, COUNT(*)
FROM biz.coach_tasks
WHERE assistant_id = %s AND site_id = %s
GROUP BY status
""",
(assistant_id, site_id),
)
task_stats = cur.fetchall()
lines = [
"【助教详情】",
f" 花名:{nickname or ''}",
f" 级别:{level or ''}",
f" 入职日期:{_fmt_date(hire_date)}",
]
if task_stats:
lines.append("【任务统计】")
for status, cnt in task_stats:
lines.append(f" {status}: {cnt}")
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
if biz_conn:
biz_conn.close()
# ── 看板类页面 ──────────────────────────────────────────────────
def _text_board_finance(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""财务看板文本化。"""
time_dim = filters.get("timeDimension", "this_month")
area = filters.get("areaFilter", "")
etl_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
# 简化查询:获取汇总数据
cur.execute(
"""
SELECT
COUNT(*) AS settle_count,
COALESCE(SUM(items_sum), 0) AS total_revenue,
COALESCE(AVG(items_sum), 0) AS avg_revenue
FROM app.v_dwd_settlement_head
WHERE settle_type IN (1, 3)
AND settle_date >= (CURRENT_DATE - INTERVAL '1 month')
""",
)
row = cur.fetchone()
etl_conn.commit()
lines = [
"【财务看板】",
f" 时间维度:{time_dim}",
]
if area:
lines.append(f" 区域筛选:{area}")
if row:
lines.append(f" 结算笔数:{row[0]}")
lines.append(f" 总营收:¥{_fmt_decimal(row[1])}")
lines.append(f" 笔均:¥{_fmt_decimal(row[2])}")
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
def _text_board_customer(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""客户看板文本化。"""
dimension = filters.get("dimension", "consumption")
type_filter = filters.get("typeFilter", "")
etl_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
# Top 10 客户
cur.execute(
"""
SELECT
dm.nickname,
COALESCE(SUM(sh.items_sum), 0) AS total_consumption
FROM app.v_dwd_settlement_head sh
JOIN app.v_dim_member dm
ON dm.member_id = sh.member_id AND dm.scd2_is_current = 1
WHERE sh.settle_type IN (1, 3)
AND sh.member_id > 0
AND sh.settle_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY dm.nickname
ORDER BY total_consumption DESC
LIMIT 10
""",
)
rows = cur.fetchall()
etl_conn.commit()
lines = [
"【客户看板】",
f" 排序维度:{dimension}",
]
if type_filter:
lines.append(f" 类型筛选:{type_filter}")
if rows:
lines.append(" Top 10 客户:")
for i, (nick, amt) in enumerate(rows, 1):
lines.append(f" {i}. {nick or '未知'} ¥{_fmt_decimal(amt)}")
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
def _text_board_coach(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""助教看板文本化。"""
dimension = filters.get("dimension", "service_count")
project = filters.get("projectFilter", "")
time_dim = filters.get("timeDimension", "this_month")
etl_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
cur.execute(
"""
SELECT
da.nickname,
COUNT(*) AS service_count,
COALESCE(SUM(sl.ledger_amount), 0) AS total_revenue
FROM app.v_dwd_assistant_service_log sl
JOIN app.v_dim_assistant da
ON da.assistant_id = sl.site_assistant_id
WHERE sl.is_delete = 0
AND sl.create_time >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY da.nickname
ORDER BY service_count DESC
LIMIT 10
""",
)
rows = cur.fetchall()
etl_conn.commit()
lines = [
"【助教看板】",
f" 排序维度:{dimension}",
f" 时间维度:{time_dim}",
]
if project:
lines.append(f" 技能筛选:{project}")
if rows:
lines.append(" Top 10 助教:")
for i, (nick, cnt, amt) in enumerate(rows, 1):
lines.append(f" {i}. {nick or '未知'} 服务{cnt}次 ¥{_fmt_decimal(amt)}")
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
# ── 其他页面 ──────────────────────────────────────────────────
def _text_performance(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""绩效页面文本化。"""
time_dim = filters.get("timeDimension", "this_month")
etl_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
cur.execute(
"""
SELECT
da.nickname,
sc.performance_tier,
sc.monthly_customers
FROM app.v_dws_assistant_salary_calc sc
JOIN app.v_dim_assistant da
ON da.assistant_id = sc.assistant_id
ORDER BY sc.calc_month DESC, sc.monthly_customers DESC
LIMIT 10
""",
)
rows = cur.fetchall()
etl_conn.commit()
lines = [
"【绩效数据】",
f" 时间维度:{time_dim}",
]
if rows:
for nick, tier, customers in rows:
lines.append(f" {nick or '未知'} {tier or ''} 服务{customers or 0}")
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
def _text_my_profile(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""个人信息页文本化。"""
return "【个人信息】\n 当前为个人信息页面,可查询个人绩效和任务情况。"
def _text_task_list(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""任务列表页文本化。"""
if not context_id:
# 无特定任务,返回概要
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT status, COUNT(*)
FROM biz.coach_tasks
WHERE site_id = %s
GROUP BY status
""",
(site_id,),
)
stats = cur.fetchall()
lines = ["【任务列表】"]
for status, cnt in stats:
lines.append(f" {status}: {cnt}")
return "\n".join(lines)
finally:
conn.close()
# 有特定任务 ID复用 task-detail
return _text_task_detail(context_id, site_id, filters)
def _text_customer_service_records(
context_id: int | str | None, site_id: int, filters: dict
) -> str:
"""客户服务记录页文本化。"""
if not context_id:
return ""
member_id = int(context_id)
etl_conn = None
try:
etl_conn = get_etl_readonly_connection(site_id)
with etl_conn.cursor() as cur:
cur.execute(
"SET LOCAL statement_timeout = %s",
(f"{FDW_QUERY_TIMEOUT_SEC * 1000}",),
)
cur.execute(
"""
SELECT
create_time,
real_use_seconds / 60 AS duration_minutes,
ledger_amount,
site_table_id
FROM app.v_dwd_assistant_service_log
WHERE tenant_member_id = %s AND is_delete = 0
ORDER BY create_time DESC
LIMIT 10
""",
(member_id,),
)
rows = cur.fetchall()
etl_conn.commit()
lines = ["【服务记录】"]
if not rows:
lines.append(" 暂无服务记录")
else:
for sd, dur, amt, room in rows:
lines.append(
f" {_fmt_date(sd)} {dur or 0}分钟 ¥{_fmt_decimal(amt)} {room or ''}"
)
return "\n".join(lines)
finally:
if etl_conn:
etl_conn.close()
# ── 工具函数 ──────────────────────────────────────────────────
def _fmt_date(val: Any) -> str:
"""格式化日期值。"""
if isinstance(val, datetime):
return val.strftime("%Y-%m-%d %H:%M")
if isinstance(val, date):
return val.isoformat()
return str(val) if val else "未知"
def _fmt_decimal(val: Any) -> str:
"""格式化金额值。"""
if val is None:
return "0.00"
if isinstance(val, Decimal):
return f"{val:.2f}"
if isinstance(val, float):
return f"{val:.2f}"
return str(val)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,53 @@
"""AI 模块异常层级。
所有 DashScope 相关异常继承自 DashScopeError 基类,
便于上层统一捕获和分类处理。
"""
from __future__ import annotations
class DashScopeError(Exception):
"""DashScope 异常基类。"""
class DashScopeApiError(DashScopeError):
"""Application API 调用失败(重试耗尽后)。"""
def __init__(self, message: str, status_code: int | None = None):
super().__init__(message)
self.status_code = status_code
class DashScopeAuthError(DashScopeApiError):
"""API Key 无效HTTP 401"""
def __init__(self, message: str = "API Key 无效或已过期"):
super().__init__(message, status_code=401)
class DashScopeTimeoutError(DashScopeApiError):
"""调用超时。"""
def __init__(self, message: str = "DashScope API 调用超时"):
super().__init__(message, status_code=None)
class DashScopeJsonParseError(DashScopeError):
"""响应 JSON 解析失败(重试耗尽后)。"""
def __init__(self, message: str, raw_content: str = ""):
super().__init__(message)
self.raw_content = raw_content
class CircuitOpenError(DashScopeError):
"""熔断器处于 OPEN 状态,拒绝请求。"""
class RateLimitExceededError(DashScopeError):
"""限流阈值超限。"""
class BudgetExceededError(DashScopeError):
"""Token 预算超限。"""

View File

@@ -0,0 +1,73 @@
"""限流器 — 滑动窗口内存计数器。
App1 按 user_id 限流(每用户每分钟 10 次),
App2~8 按 site_id 限流(每门店每小时 100 次)。
内存实现,单实例部署,不依赖外部存储。
"""
from __future__ import annotations
import time
from collections import deque
class RateLimiter:
"""滑动窗口内存限流器。
- check_user_rate()App1 每用户每分钟限流
- check_store_rate()App2~8 每门店每小时限流
每个 key 维护一个时间戳 deque检查时先清除过期条目
再判断窗口内请求数是否低于阈值。
"""
def __init__(self) -> None:
self._user_windows: dict[str, deque[float]] = {} # App1: user_id → 时间戳队列
self._store_windows: dict[str, deque[float]] = {} # App2~8: site_id → 时间戳队列
def _check(
self,
windows: dict[str, deque[float]],
key: str,
limit: int,
window_seconds: int,
) -> bool:
"""通用滑动窗口检查。返回 True 表示允许。"""
now = time.monotonic()
if key not in windows:
windows[key] = deque()
window = windows[key]
# 清除窗口外的过期时间戳
cutoff = now - window_seconds
while window and window[0] <= cutoff:
window.popleft()
# 判断是否超限
if len(window) >= limit:
return False
# 未超限,记录本次请求时间戳
window.append(now)
return True
def check_user_rate(
self,
user_id: str,
limit: int = 10,
window_seconds: int = 60,
) -> bool:
"""App1 每用户每分钟限流。返回 True 表示允许。"""
return self._check(self._user_windows, user_id, limit, window_seconds)
def check_store_rate(
self,
site_id: int,
limit: int = 100,
window_seconds: int = 3600,
) -> bool:
"""App2~8 每门店每小时限流。返回 True 表示允许。"""
# site_id 为 int转为 str 作为 dict key
return self._check(self._store_windows, str(site_id), limit, window_seconds)

View File

@@ -0,0 +1,207 @@
"""AI 运行日志服务 — biz.ai_run_logs 表的 CRUD 操作。
每次 Application API 调用前创建 pending 记录,调用过程中更新状态,
调用结束后记录结果。同时提供日/月 token 聚合查询,实现 UsageProvider 协议
以便注入 BudgetTracker。
request_prompt 写入前截断为前 2000 字符,避免大 prompt 占用过多存储。
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Callable
import psycopg2.extensions
# prompt 最大存储长度
_MAX_PROMPT_LENGTH = 2000
def _truncate_prompt(prompt: str | None) -> str | None:
"""截断 prompt 为前 2000 字符。None 原样返回。"""
if prompt is None:
return None
return prompt[:_MAX_PROMPT_LENGTH]
class AIRunLogService:
"""AI 运行日志 CRUD实现 UsageProvider 协议。
构造函数接受 get_conn callable每次操作时获取数据库连接
避免长期持有连接导致超时或连接池耗尽。
"""
def __init__(self, get_conn: Callable[[], psycopg2.extensions.connection]) -> None:
self._get_conn = get_conn
# ── 创建 ──────────────────────────────────────────────
def create_log(
self,
site_id: int,
app_type: str,
trigger_type: str,
*,
member_id: int | None = None,
request_prompt: str | None = None,
session_id: str | None = None,
) -> int:
"""创建日志记录status: pending返回 log_id。
request_prompt 自动截断为前 2000 字符。
"""
truncated = _truncate_prompt(request_prompt)
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO biz.ai_run_logs
(site_id, app_type, trigger_type, member_id,
request_prompt, session_id, status)
VALUES (%s, %s, %s, %s, %s, %s, 'pending')
RETURNING id
""",
(site_id, app_type, trigger_type, member_id,
truncated, session_id),
)
row = cur.fetchone()
assert row is not None, "INSERT RETURNING 应返回 id"
log_id: int = row[0]
conn.commit()
return log_id
except Exception:
conn.rollback()
raise
# ── 状态转换 ──────────────────────────────────────────
def update_running(self, log_id: int) -> None:
"""更新为 running。"""
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_run_logs
SET status = 'running'
WHERE id = %s
""",
(log_id,),
)
conn.commit()
except Exception:
conn.rollback()
raise
def update_success(
self,
log_id: int,
response_text: str,
tokens_used: int,
latency_ms: int,
) -> None:
"""更新为 success记录响应、token 消耗和耗时。"""
now = datetime.now(timezone.utc)
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_run_logs
SET status = 'success',
response_text = %s,
tokens_used = %s,
latency_ms = %s,
finished_at = %s
WHERE id = %s
""",
(response_text, tokens_used, latency_ms, now, log_id),
)
conn.commit()
except Exception:
conn.rollback()
raise
def update_failed(
self,
log_id: int,
error_message: str,
latency_ms: int,
) -> None:
"""更新为 failed记录错误信息和耗时。"""
now = datetime.now(timezone.utc)
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_run_logs
SET status = 'failed',
error_message = %s,
latency_ms = %s,
finished_at = %s
WHERE id = %s
""",
(error_message, latency_ms, now, log_id),
)
conn.commit()
except Exception:
conn.rollback()
raise
def update_timeout(self, log_id: int, latency_ms: int) -> None:
"""更新为 timeout。"""
now = datetime.now(timezone.utc)
conn = self._get_conn()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_run_logs
SET status = 'timeout',
latency_ms = %s,
finished_at = %s
WHERE id = %s
""",
(latency_ms, now, log_id),
)
conn.commit()
except Exception:
conn.rollback()
raise
# ── UsageProvider 协议实现 ────────────────────────────
def get_daily_usage(self) -> int:
"""聚合今日 token 消耗status='success'created_at 为今日)。"""
conn = self._get_conn()
with conn.cursor() as cur:
cur.execute(
"""
SELECT COALESCE(SUM(tokens_used), 0)
FROM biz.ai_run_logs
WHERE status = 'success'
AND created_at >= CURRENT_DATE
AND created_at < CURRENT_DATE + INTERVAL '1 day'
"""
)
row = cur.fetchone()
return int(row[0]) if row else 0
def get_monthly_usage(self) -> int:
"""聚合本月 token 消耗status='success'created_at 为本月)。"""
conn = self._get_conn()
with conn.cursor() as cur:
cur.execute(
"""
SELECT COALESCE(SUM(tokens_used), 0)
FROM biz.ai_run_logs
WHERE status = 'success'
AND created_at >= date_trunc('month', CURRENT_DATE)
AND created_at < date_trunc('month', CURRENT_DATE) + INTERVAL '1 month'
"""
)
row = cur.fetchone()
return int(row[0]) if row else 0

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""RateLimiter 单元测试。
被测代码apps/backend/app/ai/rate_limiter.py
纯内存测试,不涉及 DB/网络。
"""
from __future__ import annotations
import time
from unittest.mock import patch
from app.ai.rate_limiter import RateLimiter
class TestCheckUserRate:
"""App1 每用户每分钟限流。"""
def test_allows_under_limit(self):
rl = RateLimiter()
for _ in range(10):
assert rl.check_user_rate("u1", limit=10) is True
def test_rejects_at_limit(self):
rl = RateLimiter()
for _ in range(10):
rl.check_user_rate("u1", limit=10)
assert rl.check_user_rate("u1", limit=10) is False
def test_different_users_independent(self):
rl = RateLimiter()
for _ in range(10):
rl.check_user_rate("u1", limit=10)
# u1 已满u2 不受影响
assert rl.check_user_rate("u1", limit=10) is False
assert rl.check_user_rate("u2", limit=10) is True
def test_window_expiry_allows_again(self):
"""窗口过期后,历史请求不影响当前判断。"""
rl = RateLimiter()
base = time.monotonic()
with patch("app.ai.rate_limiter.time.monotonic", return_value=base):
for _ in range(10):
rl.check_user_rate("u1", limit=10, window_seconds=60)
# 61 秒后,窗口内无请求
with patch("app.ai.rate_limiter.time.monotonic", return_value=base + 61):
assert rl.check_user_rate("u1", limit=10, window_seconds=60) is True
class TestCheckStoreRate:
"""App2~8 每门店每小时限流。"""
def test_allows_under_limit(self):
rl = RateLimiter()
for _ in range(5):
assert rl.check_store_rate(123, limit=5) is True
def test_rejects_at_limit(self):
rl = RateLimiter()
for _ in range(5):
rl.check_store_rate(123, limit=5)
assert rl.check_store_rate(123, limit=5) is False
def test_different_stores_independent(self):
rl = RateLimiter()
for _ in range(5):
rl.check_store_rate(100, limit=5)
assert rl.check_store_rate(100, limit=5) is False
assert rl.check_store_rate(200, limit=5) is True
def test_site_id_int_works(self):
"""site_id 为 int内部转 str 存储。"""
rl = RateLimiter()
assert rl.check_store_rate(2790685415443269, limit=100) is True
def test_window_expiry_allows_again(self):
rl = RateLimiter()
base = time.monotonic()
with patch("app.ai.rate_limiter.time.monotonic", return_value=base):
for _ in range(100):
rl.check_store_rate(123, limit=100, window_seconds=3600)
# 3601 秒后
with patch("app.ai.rate_limiter.time.monotonic", return_value=base + 3601):
assert rl.check_store_rate(123, limit=100, window_seconds=3600) is True
class TestRejectedRequestNotRecorded:
"""被拒绝的请求不应记录时间戳(不占用窗口配额)。"""
def test_rejected_user_request_not_counted(self):
rl = RateLimiter()
for _ in range(3):
rl.check_user_rate("u1", limit=3)
# 连续拒绝不应增加窗口内计数
rl.check_user_rate("u1", limit=3)
rl.check_user_rate("u1", limit=3)
assert len(rl._user_windows["u1"]) == 3
def test_rejected_store_request_not_counted(self):
rl = RateLimiter()
for _ in range(3):
rl.check_store_rate(1, limit=3)
rl.check_store_rate(1, limit=3)
rl.check_store_rate(1, limit=3)
assert len(rl._store_windows["1"]) == 3

View File

@@ -13,18 +13,92 @@ FastAPI 依赖注入:从 JWT 提取当前用户信息。
... # 受限逻辑
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from datetime import datetime
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from app.auth.jwt import decode_access_token
from app.trace.context import SpanType, TraceSpan, get_current_trace
from app.trace.decorators import truncate_token
# Bearer token 提取器
_bearer_scheme = HTTPBearer(auto_error=True)
# ── 鉴权失败原因分类常量 ──
AUTH_EXPIRED = "AUTH_EXPIRED"
AUTH_INVALID = "AUTH_INVALID"
AUTH_MALFORMED = "AUTH_MALFORMED"
AUTH_LIMITED = "AUTH_LIMITED"
AUTH_FORBIDDEN = "AUTH_FORBIDDEN"
def _record_auth_span(
*,
token: str,
success: bool,
user_id: int | None = None,
site_id: int | None = None,
roles: list[str] | None = None,
user_status: str = "",
failure_reason: str = "",
detail: str = "",
duration_ms: float = 0.0,
) -> None:
"""向当前 TraceContext 添加 AUTH span无 trace 时静默跳过)。"""
ctx = get_current_trace()
if ctx is None:
return
token_prefix = truncate_token(token)
if success:
desc_zh = f"JWT 鉴权通过user_id={user_id}, site_id={site_id}, roles={roles}"
desc_en = f"JWT auth passed: user_id={user_id}, site_id={site_id}, roles={roles}"
result_summary = "approved"
else:
desc_zh = f"JWT 鉴权失败:{failure_reason}{detail}"
desc_en = f"JWT auth failed: {failure_reason}{detail}"
result_summary = failure_reason
extra: dict = {}
if failure_reason:
extra["failure_reason"] = failure_reason
ctx.add_span(TraceSpan(
span_type=SpanType.AUTH,
module="auth.dependencies",
function="get_current_user",
description_zh=desc_zh,
description_en=desc_en,
params={"token_prefix": token_prefix},
result_summary=result_summary,
duration_ms=duration_ms,
timestamp=datetime.now().isoformat(),
extra=extra,
))
# 鉴权成功时将 user_id / site_id 写入 TraceContext
if success and user_id is not None:
ctx.user_id = user_id
if site_id is not None:
ctx.site_id = site_id
def _classify_jwt_error(exc: JWTError) -> str:
"""根据 JWTError 消息分类失败原因。"""
msg = str(exc).lower()
if "expired" in msg or "exp" in msg:
return AUTH_EXPIRED
return AUTH_INVALID
@dataclass(frozen=True)
class CurrentUser:
"""从 JWT 解析出的当前用户上下文。"""
@@ -45,9 +119,17 @@ async def get_current_user(
要求完整令牌(非 limited失败时抛出 401。
"""
token = credentials.credentials
start = time.perf_counter()
try:
payload = decode_access_token(token)
except JWTError:
except JWTError as exc:
elapsed = (time.perf_counter() - start) * 1000
reason = _classify_jwt_error(exc)
_record_auth_span(
token=token, success=False,
failure_reason=reason, detail="无效的令牌",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的令牌",
@@ -56,6 +138,12 @@ async def get_current_user(
# 受限令牌不允许通过此依赖
if payload.get("limited"):
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=False,
failure_reason=AUTH_LIMITED, detail="受限令牌无法访问此端点",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="受限令牌无法访问此端点",
@@ -66,6 +154,12 @@ async def get_current_user(
site_id = payload.get("site_id")
if user_id_raw is None or site_id is None:
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=False,
failure_reason=AUTH_MALFORMED, detail="令牌缺少必要字段",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌缺少必要字段",
@@ -75,6 +169,12 @@ async def get_current_user(
try:
user_id = int(user_id_raw)
except (TypeError, ValueError):
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=False,
failure_reason=AUTH_MALFORMED, detail="令牌中 user_id 格式无效",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌中 user_id 格式无效",
@@ -82,6 +182,13 @@ async def get_current_user(
)
roles = payload.get("roles", [])
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=True,
user_id=user_id, site_id=site_id, roles=roles,
user_status="approved", duration_ms=elapsed,
)
return CurrentUser(
user_id=user_id,
@@ -102,9 +209,17 @@ async def get_current_user_or_limited(
- 完整令牌:正常返回 CurrentUser
"""
token = credentials.credentials
start = time.perf_counter()
try:
payload = decode_access_token(token)
except JWTError:
except JWTError as exc:
elapsed = (time.perf_counter() - start) * 1000
reason = _classify_jwt_error(exc)
_record_auth_span(
token=token, success=False,
failure_reason=reason, detail="无效的令牌",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的令牌",
@@ -113,6 +228,12 @@ async def get_current_user_or_limited(
user_id_raw = payload.get("sub")
if user_id_raw is None:
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=False,
failure_reason=AUTH_MALFORMED, detail="令牌缺少必要字段",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌缺少必要字段",
@@ -122,6 +243,12 @@ async def get_current_user_or_limited(
try:
user_id = int(user_id_raw)
except (TypeError, ValueError):
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=False,
failure_reason=AUTH_MALFORMED, detail="令牌中 user_id 格式无效",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌中 user_id 格式无效",
@@ -130,6 +257,12 @@ async def get_current_user_or_limited(
# 受限令牌pending 用户
if payload.get("limited"):
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=True,
user_id=user_id, site_id=0, roles=[],
user_status="pending", duration_ms=elapsed,
)
return CurrentUser(
user_id=user_id,
site_id=0,
@@ -141,6 +274,12 @@ async def get_current_user_or_limited(
# 完整令牌:要求 site_id
site_id = payload.get("site_id")
if site_id is None:
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=False,
failure_reason=AUTH_MALFORMED, detail="令牌缺少必要字段",
duration_ms=elapsed,
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌缺少必要字段",
@@ -148,6 +287,13 @@ async def get_current_user_or_limited(
)
roles = payload.get("roles", [])
elapsed = (time.perf_counter() - start) * 1000
_record_auth_span(
token=token, success=True,
user_id=user_id, site_id=site_id, roles=roles,
user_status="approved", duration_ms=elapsed,
)
return CurrentUser(
user_id=user_id,

View File

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
"""
通用 Internal-Token 认证依赖。
从环境变量 INTERNAL_API_TOKEN 读取期望 token
供 /api/internal/* 端点使用(不依赖 AIConfig
"""
from __future__ import annotations
import os
from fastapi import Header, HTTPException, status
def verify_internal_token(authorization: str = Header(...)) -> str:
"""校验 Internal-Token 认证。
Header 格式Authorization: Internal-Token {token}
"""
prefix = "Internal-Token "
if not authorization.startswith(prefix):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证格式,需要 Internal-Token",
)
token = authorization[len(prefix):]
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 不能为空",
)
expected = os.environ.get("INTERNAL_API_TOKEN", "")
if not expected:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="INTERNAL_API_TOKEN 未配置",
)
if token != expected:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 不匹配",
)
return token

View File

@@ -0,0 +1,208 @@
# -*- coding: utf-8 -*-
"""
租户管理员认证依赖注入。
提供 require_tenant_admin() 依赖,验证 JWT aud=tenant-admin
与小程序端 get_current_user()aud 隐含为 xcx完全隔离。
用法:
@router.get("/protected")
async def endpoint(admin: CurrentTenantAdmin = Depends(require_tenant_admin)):
print(admin.admin_id, admin.managed_site_ids)
"""
from __future__ import annotations
# AI_CHANGELOG
# - 2026-03-23 21:00:00 | Prompt: P20260323-210000根治 tenant_admin managed_site_ids 限制)| Direct causeJWT managed_site_ids 静态签发,新建店铺后所有端点受限 | Summary新增 get_tenant_site_ids(tenant_id) 和 get_effective_site_ids(admin) 函数;改造 site_filter_clause 和 verify_site_access 支持 admin= keyword-only 参数(向后兼容旧签名)| Verifytenant_admin 新建店铺后无需重新登录即可访问所有端点
from dataclasses import dataclass, field
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from app.auth.jwt import decode_access_token
from app import config as _config
from jose import jwt as _jose_jwt
# 复用与 dependencies.py 相同的 Bearer 提取器
_bearer_scheme = HTTPBearer(auto_error=True)
@dataclass(frozen=True)
class CurrentTenantAdmin:
"""从 JWT 解析出的租户管理员上下文。"""
admin_id: int
tenant_id: int
managed_site_ids: list[int] = field(default_factory=list)
display_name: str | None = None
admin_type: str = "tenant_admin" # tenant_admin / site_admin
async def require_tenant_admin(
credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme),
) -> CurrentTenantAdmin:
"""
FastAPI 依赖:验证 JWT aud=tenant-admin提取管理员信息。
拒绝小程序 JWTaud 不匹配)及任何无效/过期令牌。
"""
token = credentials.credentials
try:
# 直接解码并验证 aud=tenant-admin + type=access
# 不能复用 decode_access_token(),因为它不传 audience 参数,
# jose 遇到 aud claim 但无 audience 参数时会直接拒绝。
payload = _jose_jwt.decode(
token,
_config.JWT_SECRET_KEY,
algorithms=[_config.JWT_ALGORITHM],
audience="tenant-admin",
)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的令牌",
headers={"WWW-Authenticate": "Bearer"},
)
# 验证 token type 为 access与 decode_access_token 一致)
if payload.get("type") != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌类型不匹配",
headers={"WWW-Authenticate": "Bearer"},
)
# jose 在 aud claim 缺失时不会拒绝,需要显式检查
if payload.get("aud") != "tenant-admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌类型不匹配",
headers={"WWW-Authenticate": "Bearer"},
)
# 提取必要字段
sub = payload.get("sub")
tenant_id = payload.get("tenant_id")
managed_site_ids = payload.get("managed_site_ids")
if sub is None or tenant_id is None or managed_site_ids is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌缺少必要字段",
headers={"WWW-Authenticate": "Bearer"},
)
try:
admin_id = int(sub)
except (TypeError, ValueError):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌中 admin_id 格式无效",
headers={"WWW-Authenticate": "Bearer"},
)
return CurrentTenantAdmin(
admin_id=admin_id,
tenant_id=tenant_id,
managed_site_ids=managed_site_ids,
display_name=payload.get("display_name"),
admin_type=payload.get("admin_type", "tenant_admin"),
)
# ── 数据隔离工具函数 ─────────────────────────────────────────
# [CHANGE P20260323-210000] intent: 根治 tenant_admin 的 managed_site_ids 限制,
# tenant_admin 按 tenant_id 查 biz.sites 获取有效 site_ids
# site_admin 仍用 JWT 中的 managed_site_ids。
# assumptions: biz.sites 数据量极小(几条),无需缓存
# verify: tenant_admin 新建店铺后无需重新登录即可访问
def get_tenant_site_ids(tenant_id: int) -> list[int]:
"""查询租户下所有活跃店铺的 site_id 列表。
通过 biz.tenants.tenant_id外部租户标识→ biz.tenants.id内部 PK
→ biz.sites.tenant_id 关联查询。
"""
from app.database import get_connection
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT s.site_id
FROM biz.sites s
JOIN biz.tenants t ON t.id = s.tenant_id
WHERE t.tenant_id = %s AND t.is_active = true
AND s.is_active = true
""",
(tenant_id,),
)
return [row[0] for row in cur.fetchall()]
finally:
conn.close()
def get_effective_site_ids(admin: CurrentTenantAdmin) -> list[int]:
"""获取管理员的有效 site_id 列表。
- tenant_admin实时查 biz.sites覆盖新建店铺
- site_admin使用 JWT 中的 managed_site_ids精确控制
"""
if admin.admin_type == "tenant_admin":
return get_tenant_site_ids(admin.tenant_id)
return admin.managed_site_ids
def site_filter_clause(
managed_site_ids: list[int] | None = None,
*,
admin: CurrentTenantAdmin | None = None,
) -> tuple[str, tuple]:
"""生成 site_id IN (...) SQL 片段,用于数据隔离查询。
优先使用 admin 参数(自动区分 tenant_admin/site_admin
也兼容旧的 managed_site_ids 直传方式。
返回 (sql_fragment, params_tuple),可直接拼入 WHERE 子句。
"""
if admin is not None:
site_ids = get_effective_site_ids(admin)
elif managed_site_ids is not None:
site_ids = managed_site_ids
else:
return "1 = 0", ()
if not site_ids:
return "1 = 0", ()
placeholders = ", ".join(["%s"] * len(site_ids))
return f"site_id IN ({placeholders})", tuple(site_ids)
def verify_site_access(
site_id: int,
managed_site_ids: list[int] | None = None,
*,
admin: CurrentTenantAdmin | None = None,
) -> None:
"""校验 site_id 是否在管辖范围内,不在则抛 403。
优先使用 admin 参数(自动区分 tenant_admin/site_admin
也兼容旧的 managed_site_ids 直传方式。
"""
if admin is not None:
effective_ids = get_effective_site_ids(admin)
elif managed_site_ids is not None:
effective_ids = managed_site_ids
else:
effective_ids = []
if site_id not in effective_ids:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问该门店数据",
)

View File

@@ -111,7 +111,8 @@ APP_DB_NAME: str = get("APP_DB_NAME", "test_zqyy_app")
JWT_SECRET_KEY: str = get("JWT_SECRET_KEY", "") # 生产环境必须设置
JWT_ALGORITHM: str = get("JWT_ALGORITHM", "HS256")
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int(get("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int(get("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# CHANGE 2026-03-27 | 权限改造 W1refresh_token 有效期 7天→30天配合滑动窗口续期
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int(get("JWT_REFRESH_TOKEN_EXPIRE_DAYS", "30"))
# ---- ETL 数据库连接参数(可独立配置,缺省时复用 zqyy_app 的连接参数) ----
ETL_DB_HOST: str = get("ETL_DB_HOST") or DB_HOST
@@ -177,6 +178,10 @@ WX_SECRET: str = get("WX_SECRET", "")
# 开发模式WX_DEV_MODE=true 时启用 mock 登录端点,跳过微信 code2Session
WX_DEV_MODE: bool = get("WX_DEV_MODE", "false").lower() in ("true", "1", "yes")
# ---- 用户头像存储 ----
# chooseAvatar 上传后保存到此目录,文件名 {user_id}.jpg
AVATAR_EXPORT_PATH: str = get("AVATAR_EXPORT_PATH", "")
# ---- 营业日分割点 ----
BUSINESS_DAY_START_HOUR: int = int(get("BUSINESS_DAY_START_HOUR", "8"))

View File

@@ -8,8 +8,13 @@
- get_connection()zqyy_app 读写连接(用户/队列/调度等业务数据)
- get_etl_readonly_connection(site_id)etl_feiqiu 只读连接(数据库查看器),
自动设置 RLS site_id 隔离
当 DEV_TRACE_ENABLED=true 且存在活跃 TraceContext 时,
get_connection() 返回 TracedConnection 包装,自动记录 DB_CONN / DB_QUERY / DB_CONN_RELEASE span。
"""
import time
import psycopg2
from psycopg2.extensions import connection as PgConnection
@@ -32,8 +37,19 @@ def get_connection() -> PgConnection:
获取 zqyy_app 数据库连接。
调用方负责关闭连接(推荐配合 contextmanager 或 try/finally 使用)。
当 trace 启用且有活跃 TraceContext 时,返回 TracedConnection 包装,
自动记录 DB_CONN span连接获取耗时并拦截后续 SQL 执行。
"""
return psycopg2.connect(
# CHANGE 2026-03-22 | task 8.2 | 集成 trace db_wrapper仅 trace 启用时包装
from app.trace.config import get_trace_config
from app.trace.context import SpanType, TraceSpan, get_current_trace
config = get_trace_config()
should_trace = config.enabled and get_current_trace() is not None
start = time.perf_counter() if should_trace else 0.0
conn = psycopg2.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USER,
@@ -41,6 +57,52 @@ def get_connection() -> PgConnection:
dbname=APP_DB_NAME,
)
if should_trace:
from datetime import datetime
from app.trace.db_wrapper import traced_connection
elapsed_ms = (time.perf_counter() - start) * 1000
ctx = get_current_trace()
# ctx 不为 None上面已检查
ctx.add_span(TraceSpan(
span_type=SpanType.DB_CONN,
module="app.database",
function="get_connection",
description_zh=f"获取数据库连接,耗时 {elapsed_ms:.1f}ms",
description_en=f"Acquired database connection in {elapsed_ms:.1f}ms",
params={},
result_summary=f"{elapsed_ms:.1f}ms",
duration_ms=elapsed_ms,
timestamp=datetime.now().isoformat(),
))
return traced_connection(conn)
return conn
def get_etl_global_readonly_connection() -> PgConnection:
"""
获取 ETL 数据库的全局只读连接(不设 RLS
用于系统管理后台等不需要门店隔离的场景(如 ETL 状态监控)。
"""
conn = psycopg2.connect(
host=ETL_DB_HOST,
port=ETL_DB_PORT,
user=ETL_DB_USER,
password=ETL_DB_PASSWORD,
dbname=ETL_DB_NAME,
)
try:
conn.autocommit = False
with conn.cursor() as cur:
cur.execute("SET default_transaction_read_only = on")
conn.commit()
except Exception:
conn.close()
raise
return conn
def get_etl_readonly_connection(site_id: int | str) -> PgConnection:
"""

View File

@@ -16,6 +16,7 @@ from app.middleware.response_wrapper import (
http_exception_handler,
unhandled_exception_handler,
)
from app.trace.middleware import TraceMiddleware
from app import config
# CHANGE 2026-02-19 | 新增 xcx_test 路由MVP 验证)+ wx_callback 路由(微信消息推送)
@@ -29,9 +30,16 @@ from app import config
# CHANGE 2026-03-18 | 新增 xcx_customers 路由CUST-1 客户详情、CUST-2 客户服务记录)
# CHANGE 2026-03-19 | 新增 xcx_coaches 路由COACH-1 助教详情)
# CHANGE 2026-03-19 | 新增 xcx_board / xcx_config 路由RNS1.3 三看板 + 技能类型配置)
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback, member_retention_clue, ops_panel, xcx_auth, admin_applications, business_day, xcx_tasks, xcx_notes, xcx_chat, xcx_ai_cache, xcx_performance, xcx_customers, xcx_coaches, xcx_board, xcx_config
# CHANGE 2026-03-22 | 新增 admin_registry 路由NS4.1 注册体系:租户/店铺/简写ID 管理)
# CHANGE 2026-03-23 | 新增 admin_ai 路由P15 AI 监控后台Dashboard/调度/调用/缓存/预算/批量/告警)
# CHANGE 2026-03-24 | 新增 admin_dev_trace 路由dev-trace-log: 开发调试日志管理 API
# CHANGE 2026-03-23 | 新增 trigger_jobs 路由(定时任务管理页面 API
# CHANGE 2026-03-24 | P18 任务引擎运营看板:新增 admin_task_engine 路由
# CHANGE 2026-03-29 | DWS_TASK_ENGINE新增 internal_events 路由(按 job_name 执行任务)
from app.routers import auth, execution, schedules, tasks, env_config, db_viewer, etl_status, xcx_test, wx_callback, member_retention_clue, ops_panel, xcx_auth, xcx_avatar, admin_applications, business_day, xcx_tasks, xcx_notes, xcx_chat, xcx_ai_cache, xcx_performance, xcx_customers, xcx_coaches, xcx_board, xcx_config, tenant_auth, tenant_users, tenant_excel, tenant_clues, tenant_site_admins, admin_tenant_admins, admin_registry, internal_ai, admin_ai, admin_dev_trace, trigger_jobs, admin_task_engine, admin_db_health, admin_triggers, internal_events
from app.services.scheduler import scheduler
from app.services.task_queue import task_queue
from app.services.task_executor import task_executor
from app.ws.logs import ws_router
@@ -56,40 +64,76 @@ async def lifespan(app: FastAPI):
)
print(_banner, flush=True)
# CHANGE 2026-03-22 | 启动时清理本机僵尸任务(上次非正常关闭遗留的 running 记录)
task_executor.recover_stale()
# 启动
task_queue.start()
scheduler.start()
# CHANGE 2026-02-27 | 注册触发器 job handler核心业务模块
# CHANGE 2026-03-24 | dev-trace-log: 用 trace_job 包装 job handler追踪后台任务执行
from app.services.trigger_scheduler import register_job
from app.services import task_generator, task_expiry, recall_detector, note_reclassifier
from app.trace.job_wrapper import trace_job
register_job("task_generator", lambda **_kw: task_generator.run())
register_job("task_expiry_check", lambda **_kw: task_expiry.run())
register_job("recall_completion_check", recall_detector.run)
register_job("note_reclassify_backfill", note_reclassifier.run)
register_job("task_generator", trace_job("task_generator")(lambda **_kw: task_generator.run()))
register_job("task_expiry_check", trace_job("task_expiry_check")(lambda **_kw: task_expiry.run()))
register_job("recall_completion_check", trace_job("recall_completion_check")(recall_detector.run))
register_job("note_reclassify_backfill", trace_job("note_reclassify_backfill")(note_reclassifier.run))
# CHANGE 2026-03-23 | 启动时检查定时任务是否今天执行过,打印提示
from app.services.trigger_scheduler import check_startup_jobs
try:
pending_jobs = check_startup_jobs()
if pending_jobs:
_lines = ["╔══ 定时任务提醒 ══════════════════════════════════════╗"]
for j in pending_jobs:
_lines.append(f"║ ⚠ {j['description']}{j['job_name']})— {j['last_run_at']}")
_lines.append("║ → 请在管理后台「定时任务」页面手动执行")
_lines.append("╚══════════════════════════════════════════════════════╝")
print("\n".join(_lines), flush=True)
else:
print("✓ 所有定时任务今天已执行过", flush=True)
except Exception:
import logging as _log
_log.getLogger(__name__).warning("启动检查定时任务失败", exc_info=True)
# CHANGE 2026-03-10 | 注册 AI 事件处理器(消费/备注/任务分配 → AI 调用链)
# CHANGE 2026-03-22 | P14 迁移BailianClient → DashScopeClient + AIConfig + 防护层
try:
import os
_api_key = os.environ.get("BAILIAN_API_KEY", "")
_base_url = os.environ.get("BAILIAN_BASE_URL", "")
_model = os.environ.get("BAILIAN_MODEL", "qwen-plus")
if _api_key and _base_url:
from app.ai.bailian_client import BailianClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.dispatcher import AIDispatcher, register_ai_handlers
from app.ai.config import AIConfig
from app.ai.dashscope_client import DashScopeClient
from app.ai.cache_service import AICacheService
from app.ai.conversation_service import ConversationService
from app.ai.circuit_breaker import CircuitBreaker
from app.ai.rate_limiter import RateLimiter
from app.ai.budget_tracker import BudgetTracker
from app.ai.run_log_service import AIRunLogService
from app.ai.dispatcher import AIDispatcher, register_ai_handlers
from app.database import get_connection
_bailian = BailianClient(api_key=_api_key, base_url=_base_url, model=_model)
_dispatcher = AIDispatcher(_bailian, AICacheService(), ConversationService())
register_ai_handlers(_dispatcher)
_ai_config = AIConfig.from_env()
_client = DashScopeClient(api_key=_ai_config.api_key, workspace_id=_ai_config.workspace_id)
_run_log_svc = AIRunLogService(get_conn=get_connection)
_dispatcher = AIDispatcher(
client=_client,
cache_svc=AICacheService(),
conv_svc=ConversationService(),
circuit_breaker=CircuitBreaker(),
rate_limiter=RateLimiter(),
budget_tracker=BudgetTracker(usage_provider=_run_log_svc),
run_log_svc=_run_log_svc,
config=_ai_config,
)
register_ai_handlers(_dispatcher)
except Exception:
import logging as _log
_log.getLogger(__name__).warning("AI 事件处理器注册失败AI 功能不可用", exc_info=True)
yield
# 关闭
# CHANGE 2026-03-22 | 优雅关闭先终止所有运行中的子进程3s 超时),再停调度和队列
await task_executor.shutdown(timeout=3.0)
await scheduler.stop()
await task_queue.stop()
@@ -117,6 +161,10 @@ app.add_middleware(
# CHANGE 2026-03-16 | RNS1.0 T0-1: 全局响应包装 + 异常处理器
app.add_middleware(ResponseWrapperMiddleware)
# ---- 全链路追踪中间件(最后添加 = 最先执行 = 最外层) ----
# CHANGE 2026-03-24 | dev-trace-log: TraceMiddleware 包裹所有中间件,仅拦截 /api/xcx/ 路由
app.add_middleware(TraceMiddleware)
# ---- 全局异常处理器 ----
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)
@@ -135,6 +183,7 @@ app.include_router(wx_callback.router)
app.include_router(member_retention_clue.router)
app.include_router(ops_panel.router)
app.include_router(xcx_auth.router)
app.include_router(xcx_avatar.router)
app.include_router(admin_applications.router)
app.include_router(business_day.router)
app.include_router(xcx_tasks.router)
@@ -146,6 +195,21 @@ app.include_router(xcx_customers.router)
app.include_router(xcx_coaches.router)
app.include_router(xcx_board.router)
app.include_router(xcx_config.router)
app.include_router(tenant_auth.router)
app.include_router(tenant_users.router)
app.include_router(tenant_excel.router)
app.include_router(tenant_clues.router)
app.include_router(tenant_site_admins.router)
app.include_router(admin_tenant_admins.router)
app.include_router(admin_registry.router)
app.include_router(internal_ai.router)
app.include_router(internal_events.router)
app.include_router(admin_ai.router)
app.include_router(admin_dev_trace.router)
app.include_router(trigger_jobs.router)
app.include_router(admin_task_engine.router)
app.include_router(admin_db_health.router)
app.include_router(admin_triggers.router)
@app.get("/health", tags=["系统"])

View File

@@ -180,6 +180,10 @@ def _update_content_length(
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
"""HTTPException → { code: <status_code>, message: <detail> }"""
# 记录 ERROR spantrace 未激活时静默跳过)
from app.trace.error_handler import record_http_exception
record_http_exception(exc)
return JSONResponse(
status_code=exc.status_code,
content={"code": exc.status_code, "message": exc.detail},
@@ -189,6 +193,10 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""未捕获异常 → { code: 500, message: "Internal Server Error" }
完整堆栈写入服务端日志。"""
# 记录 ERROR spantrace 未激活时静默跳过)
from app.trace.error_handler import record_unhandled_exception
record_unhandled_exception(exc)
logger.exception("未捕获异常: %s", exc)
return JSONResponse(
status_code=500,

View File

@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
"""
管理端 — AI 监控后台路由。
端点清单13 个,全部需要 JWT + admin 角色):
- GET /api/admin/ai/dashboard — 总览统计
- GET /api/admin/ai/trigger-jobs — 调度任务分页列表
- GET /api/admin/ai/trigger-jobs/{job_id} — 调度任务详情
- POST /api/admin/ai/trigger-jobs/{job_id}/retry — 手动重跑
- GET /api/admin/ai/run-logs — 调用记录分页列表
- GET /api/admin/ai/run-logs/{log_id} — 调用记录详情
- POST /api/admin/ai/cache/invalidate — 缓存失效
- GET /api/admin/ai/budget — Token 预算
- POST /api/admin/ai/batch-run — 创建批量执行(返回预估)
- POST /api/admin/ai/batch-run/confirm — 确认批量执行
- GET /api/admin/ai/alerts — 告警列表
- POST /api/admin/ai/alerts/{log_id}/ack — 确认告警
- POST /api/admin/ai/alerts/{log_id}/ignore — 忽略告警
需求: A1.1, A2.1, A2.4, A3.1, A4.1, A4.3, A5.1, A6.1, A7.1, A7.3, A8.1, A8.2, A8.3, A9.1, A9.2, A9.3
"""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_permission
from app.schemas.admin_ai import (
AlertActionResponse,
AlertListResponse,
BatchRunConfirm,
BatchRunConfirmResponse,
BatchRunEstimate,
BatchRunRequest,
BudgetResponse,
CacheInvalidateRequest,
CacheInvalidateResponse,
DashboardResponse,
RetryResponse,
RunLogDetailResponse,
RunLogListResponse,
TriggerJobDetailResponse,
TriggerJobListResponse,
)
from app.services.ai.admin_service import AdminAIService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin/ai", tags=["admin-ai"])
# ── 模块级服务实例 ────────────────────────────────────────
_admin_svc = AdminAIService()
# ── 权限依赖 ──────────────────────────────────────────────
def _require_admin():
"""
管理端依赖:要求 JWT status=approved 且角色包含 site_admin 或 tenant_admin。
"""
async def _dependency(
user: CurrentUser = Depends(require_permission()),
) -> CurrentUser:
admin_roles = {"site_admin", "tenant_admin"}
if not admin_roles.intersection(user.roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限site_admin 或 tenant_admin",
)
return user
return _dependency
# ── Dashboard ─────────────────────────────────────────────
@router.get("/dashboard", response_model=DashboardResponse)
async def get_dashboard(
site_id: Optional[int] = Query(None, description="门店 ID 筛选"),
user: CurrentUser = Depends(_require_admin()),
) -> DashboardResponse:
"""总览统计(支持 site_id 筛选)。"""
data = await _admin_svc.get_dashboard(site_id=site_id)
return DashboardResponse(**data)
# ── 调度任务 ──────────────────────────────────────────────
@router.get("/trigger-jobs", response_model=TriggerJobListResponse)
async def list_trigger_jobs(
event_type: Optional[str] = Query(None),
status_filter: Optional[str] = Query(None, alias="status"),
site_id: Optional[int] = Query(None),
date_from: Optional[str] = Query(None, description="起始日期 YYYY-MM-DD"),
date_to: Optional[str] = Query(None, description="截止日期 YYYY-MM-DD"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(_require_admin()),
) -> TriggerJobListResponse:
"""调度任务分页列表。"""
filters: dict = {}
if event_type is not None:
filters["event_type"] = event_type
if status_filter is not None:
filters["status"] = status_filter
if site_id is not None:
filters["site_id"] = site_id
if date_from is not None:
filters["date_from"] = date_from
if date_to is not None:
filters["date_to"] = date_to
data = await _admin_svc.list_trigger_jobs(filters, page=page, page_size=page_size)
return TriggerJobListResponse(**data)
@router.get("/trigger-jobs/{job_id}", response_model=TriggerJobDetailResponse)
async def get_trigger_job(
job_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> TriggerJobDetailResponse:
"""调度任务详情。"""
data = await _admin_svc.get_trigger_job(job_id)
if data is None:
raise HTTPException(status_code=404, detail="调度任务不存在")
return TriggerJobDetailResponse(**data)
@router.post("/trigger-jobs/{job_id}/retry", response_model=RetryResponse)
async def retry_trigger_job(
job_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> RetryResponse:
"""手动重跑:创建新 trigger_jobis_forced=true异步执行。"""
try:
new_job_id = await _admin_svc.retry_trigger_job(job_id)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
return RetryResponse(trigger_job_id=new_job_id, status="pending")
# ── 调用记录 ──────────────────────────────────────────────
@router.get("/run-logs", response_model=RunLogListResponse)
async def list_run_logs(
app_type: Optional[str] = Query(None),
status_filter: Optional[str] = Query(None, alias="status"),
trigger_type: Optional[str] = Query(None),
site_id: Optional[int] = Query(None),
date_from: Optional[str] = Query(None, description="起始日期 YYYY-MM-DD"),
date_to: Optional[str] = Query(None, description="截止日期 YYYY-MM-DD"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(_require_admin()),
) -> RunLogListResponse:
"""调用记录分页列表。"""
filters: dict = {}
if app_type is not None:
filters["app_type"] = app_type
if status_filter is not None:
filters["status"] = status_filter
if trigger_type is not None:
filters["trigger_type"] = trigger_type
if site_id is not None:
filters["site_id"] = site_id
if date_from is not None:
filters["date_from"] = date_from
if date_to is not None:
filters["date_to"] = date_to
data = await _admin_svc.list_run_logs(filters, page=page, page_size=page_size)
return RunLogListResponse(**data)
@router.get("/run-logs/{log_id}", response_model=RunLogDetailResponse)
async def get_run_log(
log_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> RunLogDetailResponse:
"""调用记录详情(含完整 prompt/response/error不脱敏"""
data = await _admin_svc.get_run_log(log_id)
if data is None:
raise HTTPException(status_code=404, detail="调用记录不存在")
return RunLogDetailResponse(**data)
# ── 缓存管理 ──────────────────────────────────────────────
@router.post("/cache/invalidate", response_model=CacheInvalidateResponse)
async def invalidate_cache(
body: CacheInvalidateRequest,
user: CurrentUser = Depends(_require_admin()),
) -> CacheInvalidateResponse:
"""批量缓存失效:将匹配条件的 ai_cache.status 设为 invalidated。"""
affected = await _admin_svc.invalidate_cache(
site_id=body.site_id,
app_type=body.app_type,
member_id=body.member_id,
)
return CacheInvalidateResponse(affected_count=affected)
# ── Token 预算 ────────────────────────────────────────────
@router.get("/budget", response_model=BudgetResponse)
async def get_budget(
user: CurrentUser = Depends(_require_admin()),
) -> BudgetResponse:
"""Token 预算使用情况:日/月已用量、上限、百分比。"""
data = await _admin_svc.get_budget()
return BudgetResponse(**data)
# ── 批量执行 ──────────────────────────────────────────────
@router.post("/batch-run", response_model=BatchRunEstimate)
async def create_batch_run(
body: BatchRunRequest,
user: CurrentUser = Depends(_require_admin()),
) -> BatchRunEstimate:
"""创建批量执行请求,返回预估(不立即执行)。"""
data = await _admin_svc.estimate_batch(
app_types=body.app_types,
member_ids=body.member_ids,
site_id=body.site_id,
)
return BatchRunEstimate(**data)
@router.post("/batch-run/confirm", response_model=BatchRunConfirmResponse)
async def confirm_batch_run(
body: BatchRunConfirm,
user: CurrentUser = Depends(_require_admin()),
) -> BatchRunConfirmResponse:
"""确认批量执行,后台异步执行。"""
try:
await _admin_svc.confirm_batch(batch_id=body.batch_id)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return BatchRunConfirmResponse(status="started")
# ── 告警管理 ──────────────────────────────────────────────
@router.get("/alerts", response_model=AlertListResponse)
async def list_alerts(
alert_status: Optional[str] = Query(None, description="pending / acknowledged / ignored"),
site_id: Optional[int] = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(_require_admin()),
) -> AlertListResponse:
"""告警列表ai_run_logs WHERE status IN ('failed','timeout','circuit_open'))。"""
data = await _admin_svc.list_alerts(
alert_status=alert_status,
site_id=site_id,
page=page,
page_size=page_size,
)
return AlertListResponse(**data)
@router.post("/alerts/{log_id}/ack", response_model=AlertActionResponse)
async def ack_alert(
log_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> AlertActionResponse:
"""确认告警alert_status → acknowledged。"""
new_status = await _admin_svc.ack_alert(log_id)
return AlertActionResponse(id=log_id, alert_status=new_status)
@router.post("/alerts/{log_id}/ignore", response_model=AlertActionResponse)
async def ignore_alert(
log_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> AlertActionResponse:
"""忽略告警alert_status → ignored。"""
new_status = await _admin_svc.ignore_alert(log_id)
return AlertActionResponse(id=log_id, alert_status=new_status)

View File

@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
"""管理端 — 数据库健康监控 API
提供 1 个端点:
- GET /api/admin/db-health — 返回 4 个数据库的健康状态
遍历 etl_feiqiu / test_etl_feiqiu / zqyy_app / test_zqyy_app
对每个库执行诊断 SQL连接池、大小、慢查询
连接失败时返回 disconnected 状态,不抛出 HTTP 错误。
需求: 6.1, 6.2, 6.3, 6.4
"""
from __future__ import annotations
import logging
import os
import psycopg2
from fastapi import APIRouter, Depends
from app.auth.dependencies import CurrentUser, get_current_user
from app.config import (
DB_HOST,
DB_PASSWORD,
DB_PORT,
DB_USER,
ETL_DB_HOST,
ETL_DB_PASSWORD,
ETL_DB_PORT,
ETL_DB_USER,
)
from app.schemas.admin_db_health import DbHealthItem
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin/db-health", tags=["系统管理"])
# 4 个数据库的连接参数:业务库正式/测试 + ETL 库正式/测试
DB_CONFIGS: list[dict] = [
{
"db_name": "zqyy_app",
"host": DB_HOST,
"port": DB_PORT,
"user": DB_USER,
"password": DB_PASSWORD,
"dbname": "zqyy_app",
},
{
"db_name": "test_zqyy_app",
"host": DB_HOST,
"port": DB_PORT,
"user": DB_USER,
"password": DB_PASSWORD,
"dbname": "test_zqyy_app",
},
{
"db_name": "etl_feiqiu",
"host": ETL_DB_HOST,
"port": ETL_DB_PORT,
"user": ETL_DB_USER,
"password": ETL_DB_PASSWORD,
"dbname": "etl_feiqiu",
},
{
"db_name": "test_etl_feiqiu",
"host": ETL_DB_HOST,
"port": ETL_DB_PORT,
"user": ETL_DB_USER,
"password": ETL_DB_PASSWORD,
"dbname": "test_etl_feiqiu",
},
]
# 诊断 SQL连接池状态
_SQL_CONNECTIONS = """
SELECT
count(*) FILTER (WHERE state = 'active') AS active_connections,
count(*) FILTER (WHERE state = 'idle') AS idle_connections
FROM pg_stat_activity
WHERE datname = current_database();
"""
# 诊断 SQL数据库大小MB
_SQL_DB_SIZE = """
SELECT pg_database_size(current_database()) / (1024.0 * 1024.0) AS db_size_mb;
"""
# 诊断 SQL慢查询最近 1 小时内执行时间超过 1 秒)
_SQL_SLOW_QUERIES = """
SELECT count(*) AS slow_query_count
FROM pg_stat_activity
WHERE datname = current_database()
AND state = 'active'
AND query_start < now() - interval '1 second'
AND query_start > now() - interval '1 hour';
"""
def _check_single_db(cfg: dict) -> DbHealthItem:
"""对单个数据库执行诊断,连接失败时返回 disconnected。"""
db_name = cfg["db_name"]
try:
# CHANGE 2026-03-29 | Windows GBK 环境下 psycopg2/libpq 构建连接字符串时
# 会读取系统用户名/计算机名,含中文时触发 UnicodeDecodeError0xd6 是 GBK 首字节)。
# 用显式 DSN 字符串连接,避免 libpq 自动拼接时混入系统 locale 信息。
dsn = (
f"host={cfg['host']} port={cfg['port']} "
f"dbname={cfg['dbname']} user={cfg['user']} "
f"password={cfg['password']} "
f"connect_timeout=5 client_encoding=UTF8 "
f"application_name=neozqyy_health"
)
os.environ.setdefault("PGCLIENTENCODING", "UTF8")
conn = psycopg2.connect(dsn)
except Exception:
logger.warning("数据库 %s 连接失败", db_name, exc_info=True)
return DbHealthItem(db_name=db_name, status="disconnected")
try:
with conn.cursor() as cur:
# 连接池状态
cur.execute(_SQL_CONNECTIONS)
row = cur.fetchone()
active_connections = row[0] if row else 0
idle_connections = row[1] if row else 0
# 数据库大小
cur.execute(_SQL_DB_SIZE)
row = cur.fetchone()
db_size_mb = round(float(row[0]), 2) if row else 0.0
# 慢查询
cur.execute(_SQL_SLOW_QUERIES)
row = cur.fetchone()
slow_query_count = row[0] if row else 0
return DbHealthItem(
db_name=db_name,
status="connected",
active_connections=active_connections,
idle_connections=idle_connections,
db_size_mb=db_size_mb,
slow_query_count=slow_query_count,
)
except Exception:
logger.warning("数据库 %s 诊断 SQL 执行失败", db_name, exc_info=True)
return DbHealthItem(db_name=db_name, status="disconnected")
finally:
conn.close()
@router.get("", response_model=list[DbHealthItem])
async def get_db_health(
user: CurrentUser = Depends(get_current_user),
) -> list[DbHealthItem]:
"""返回 4 个数据库的健康状态。
遍历 DB_CONFIGS 中的 4 个库,逐一执行诊断 SQL。
连接失败时返回 disconnected 状态,不抛出 HTTP 错误。
即使所有库都连接失败,仍返回 HTTP 200。
"""
return [_check_single_db(cfg) for cfg in DB_CONFIGS]

View File

@@ -0,0 +1,374 @@
# -*- coding: utf-8 -*-
"""
管理端 — 开发调试全链路日志路由。
端点清单8 个,全部需要 JWT + admin 角色):
- GET /api/admin/dev-trace/dates — 有日志数据的日期列表
- GET /api/admin/dev-trace/requests — 按条件分页查询请求列表
- GET /api/admin/dev-trace/request/{id} — 指定 request_id 的完整 trace
- POST /api/admin/dev-trace/cleanup — 按日期范围手动清理日志
- GET /api/admin/dev-trace/settings — 当前设置
- PUT /api/admin/dev-trace/settings — 更新运行时设置
- GET /api/admin/dev-trace/coverage — 最近一次覆盖率扫描结果
- POST /api/admin/dev-trace/coverage/scan — 手动触发覆盖率扫描
"""
from __future__ import annotations
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from app.auth.dependencies import CurrentUser, get_current_user
from app.trace.cleanup import cleanup_date_range
from app.trace.config import get_trace_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin/dev-trace", tags=["开发调试日志"])
# 日期目录名格式
_DATE_FORMAT = "%Y-%m-%d"
# ── 权限依赖 ──────────────────────────────────────────────
def _require_admin():
"""管理端依赖:仅要求 JWT 认证通过。
dev-trace 是开发调试工具,不涉及业务数据,无需检查业务角色
site_admin / tenant_admin。只要是 admin-web 的已认证用户即可访问。
"""
async def _dependency(
user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
return user
return _dependency
# ── Pydantic 请求/响应模型 ────────────────────────────────
class CleanupRequest(BaseModel):
"""手动清理请求体。"""
start_date: str
end_date: str
class SettingsUpdate(BaseModel):
"""运行时设置更新请求体(所有字段可选)。"""
enabled: Optional[bool] = None
retention_days: Optional[int] = None
log_sql: Optional[bool] = None
log_params: Optional[bool] = None
# ── 辅助函数 ──────────────────────────────────────────────
def _get_base_dir() -> Path:
"""获取日志根目录 Path 对象。"""
return Path(get_trace_config().log_dir)
def _is_date_dir(name: str) -> bool:
"""判断目录名是否为 YYYY-MM-DD 格式。"""
try:
datetime.strptime(name, _DATE_FORMAT)
return True
except ValueError:
return False
def _read_jsonl_file(filepath: Path) -> list[dict[str, Any]]:
"""逐行读取 .jsonl 文件,跳过解析失败的行。"""
records: list[dict[str, Any]] = []
if not filepath.exists():
return records
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
records.append(json.loads(line))
except json.JSONDecodeError:
continue
return records
def _match_filter(record: dict[str, Any], **filters: Any) -> bool:
"""检查单条 trace 记录是否满足所有筛选条件。"""
# trace_type
if filters.get("trace_type") and record.get("trace_type") != filters["trace_type"]:
return False
# method
if filters.get("method") and record.get("method", "").upper() != filters["method"].upper():
return False
# path_contains
if filters.get("path_contains") and filters["path_contains"].lower() not in (record.get("path") or "").lower():
return False
# status_code
if filters.get("status_code") is not None and record.get("status_code") != filters["status_code"]:
return False
# min_duration
if filters.get("min_duration") is not None and (record.get("total_duration_ms") or 0) < filters["min_duration"]:
return False
# has_error
if filters.get("has_error") is not None:
has_err = record.get("error") is not None
if filters["has_error"] != has_err:
return False
# span_type — 检查 spans 列表中是否包含指定类型
if filters.get("span_type"):
span_types = {s.get("span_type") for s in record.get("spans", [])}
if filters["span_type"] not in span_types:
return False
# start_time / end_time — 基于 record.timestamp
ts_str = record.get("timestamp", "")
if ts_str and (filters.get("start_time") or filters.get("end_time")):
try:
rec_dt = datetime.fromisoformat(ts_str)
rec_time = rec_dt.time()
if filters.get("start_time") and rec_time < filters["start_time"]:
return False
if filters.get("end_time") and rec_time > filters["end_time"]:
return False
except (ValueError, TypeError):
pass
return True
# ── 1. GET /dates — 有日志数据的日期列表 ─────────────────
@router.get("/dates")
async def list_dates(
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, list[str]]:
"""返回有日志数据的日期列表(降序排列)。"""
base = _get_base_dir()
if not base.exists():
return {"dates": []}
dates = sorted(
[d.name for d in base.iterdir() if d.is_dir() and _is_date_dir(d.name)],
reverse=True,
)
return {"dates": dates}
# ── 2. GET /requests — 按条件分页查询请求列表 ─────────────
@router.get("/requests")
async def list_requests(
date: str = Query(..., description="日期,格式 YYYY-MM-DD"),
start_time: Optional[str] = Query(None, description="起始时间 HH:MM:SS"),
end_time: Optional[str] = Query(None, description="结束时间 HH:MM:SS"),
trace_type: Optional[str] = Query(None, description="trace 类型http/sse/ws/job"),
method: Optional[str] = Query(None, description="HTTP 方法GET/POST/PUT/DELETE"),
path_contains: Optional[str] = Query(None, description="路径关键词"),
status_code: Optional[int] = Query(None, description="HTTP 状态码"),
min_duration: Optional[float] = Query(None, description="最小耗时ms"),
has_error: Optional[bool] = Query(None, description="是否有错误"),
span_type: Optional[str] = Query(None, description="包含的 span 类型"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(50, ge=1, le=200, description="每页条数"),
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""按条件分页查询指定日期的请求列表。"""
# 解析时间参数
parsed_start = None
parsed_end = None
if start_time:
try:
parsed_start = datetime.strptime(start_time, "%H:%M:%S").time()
except ValueError:
raise HTTPException(status_code=400, detail="start_time 格式无效,应为 HH:MM:SS")
if end_time:
try:
parsed_end = datetime.strptime(end_time, "%H:%M:%S").time()
except ValueError:
raise HTTPException(status_code=400, detail="end_time 格式无效,应为 HH:MM:SS")
# 读取指定日期目录下所有 .jsonl 文件
date_dir = _get_base_dir() / date
if not date_dir.exists() or not date_dir.is_dir():
return {"items": [], "total": 0, "page": page, "page_size": page_size}
all_records: list[dict[str, Any]] = []
for f in sorted(date_dir.glob("*.jsonl")):
all_records.extend(_read_jsonl_file(f))
# 过滤
filtered = [
r for r in all_records
if _match_filter(
r,
trace_type=trace_type,
method=method,
path_contains=path_contains,
status_code=status_code,
min_duration=min_duration,
has_error=has_error,
span_type=span_type,
start_time=parsed_start,
end_time=parsed_end,
)
]
# 按时间降序排列
filtered.sort(key=lambda r: r.get("timestamp", ""), reverse=True)
total = len(filtered)
start_idx = (page - 1) * page_size
items = filtered[start_idx : start_idx + page_size]
# 列表项不返回完整 spans只返回摘要字段
summary_items = []
for item in items:
summary_items.append({
"request_id": item.get("request_id"),
"trace_type": item.get("trace_type"),
"timestamp": item.get("timestamp"),
"method": item.get("method"),
"path": item.get("path"),
"status_code": item.get("status_code"),
"total_duration_ms": item.get("total_duration_ms"),
"user_id": item.get("user_id"),
"site_id": item.get("site_id"),
"db_query_count": item.get("db_query_count"),
"db_total_ms": item.get("db_total_ms"),
"error": item.get("error"),
"span_count": len(item.get("spans", [])),
})
return {"items": summary_items, "total": total, "page": page, "page_size": page_size}
# ── 3. GET /request/{request_id} — 完整 trace 记录 ──────
@router.get("/request/{request_id}")
async def get_request_detail(
request_id: str,
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""返回指定 request_id 的完整 trace 记录(含所有 spans"""
base = _get_base_dir()
if not base.exists():
raise HTTPException(status_code=404, detail="未找到该 request_id 的 trace 记录")
# 搜索所有日期目录下的 .jsonl 文件
for date_dir in sorted(base.iterdir(), reverse=True):
if not date_dir.is_dir() or not _is_date_dir(date_dir.name):
continue
for f in date_dir.glob("*.jsonl"):
for record in _read_jsonl_file(f):
if record.get("request_id") == request_id:
return record
raise HTTPException(status_code=404, detail="未找到该 request_id 的 trace 记录")
# ── 4. POST /cleanup — 按日期范围手动清理 ────────────────
@router.post("/cleanup")
async def cleanup_logs(
body: CleanupRequest,
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""按日期范围手动清理日志目录。"""
# 校验日期格式
try:
datetime.strptime(body.start_date, _DATE_FORMAT)
datetime.strptime(body.end_date, _DATE_FORMAT)
except ValueError:
raise HTTPException(status_code=400, detail="日期格式无效,应为 YYYY-MM-DD")
if body.start_date > body.end_date:
raise HTTPException(status_code=400, detail="start_date 不能晚于 end_date")
result = cleanup_date_range(body.start_date, body.end_date)
return {
"deleted_dates": result["deleted_dirs"],
"deleted_files": result["deleted_count"],
}
# ── 5. GET /settings — 当前设置 ──────────────────────────
@router.get("/settings")
async def get_settings(
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""返回当前 trace 运行时设置。"""
return get_trace_config().get_settings()
# ── 6. PUT /settings — 更新运行时设置 ────────────────────
@router.put("/settings")
async def update_settings(
body: SettingsUpdate,
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""更新 trace 运行时设置(不需重启,重启后回退到 .env 值)。"""
cfg = get_trace_config()
cfg.update_settings(
enabled=body.enabled,
retention_days=body.retention_days,
log_sql=body.log_sql,
log_params=body.log_params,
)
return cfg.get_settings()
# ── 7. GET /coverage — 最近一次覆盖率扫描结果 ────────────
@router.get("/coverage")
async def get_coverage(
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""返回最近一次覆盖率扫描结果(缓存)。"""
from app.trace.coverage import get_cached_coverage, run_coverage_scan
result = get_cached_coverage()
if result is None:
# 首次访问时自动扫描一次
result = run_coverage_scan()
return result
# ── 8. POST /coverage/scan — 手动触发覆盖率扫描 ──────────
@router.post("/coverage/scan")
async def trigger_coverage_scan(
user: CurrentUser = Depends(_require_admin()),
) -> dict[str, Any]:
"""手动触发覆盖率扫描,返回最新结果。"""
from app.trace.coverage import run_coverage_scan
return run_coverage_scan()

View File

@@ -0,0 +1,673 @@
# -*- coding: utf-8 -*-
"""
管理端路由 — 注册体系(连接器/租户/店铺/简写ID/店铺同步)。
端点清单:
- GET /api/admin/tenants — 所有活跃租户列表
- GET /api/admin/tenants/{tenant_id}/sites — 指定租户下所有活跃店铺
- PUT /api/admin/sites/{site_id}/site-code — 设置/修改简写ID
- GET /api/admin/sites/{site_id}/site-code-history — 简写ID 变更历史
- POST /api/admin/sites/sync — 手动触发店铺同步
- POST /api/admin/sites/sync/internal — 内部 APIETL DWD 完成后触发同步(无认证,隐藏)
除 /sites/sync/internal 外,所有端点要求 JWT + site_admin 或 tenant_admin 角色。
需求: A2.1, A2.2, A2.4, A2.5, A3.1, A3.2, A3.3, A3.4, A5.1, A5.2, A5.3, A5.4
"""
from __future__ import annotations
import logging
import re
import psycopg2
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2.extensions import connection as PgConnection
from app.auth.dependencies import CurrentUser, get_current_user
from app.config import (
ETL_DB_HOST,
ETL_DB_NAME,
ETL_DB_PASSWORD,
ETL_DB_PORT,
ETL_DB_USER,
)
from app.database import get_connection
from app.schemas.admin_registry import (
CreateSiteRequest,
SiteCodeHistoryItem,
SiteCodeResult,
SiteItem,
SiteSyncResult,
TenantItem,
UpdateSiteCodeRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin", tags=["admin-registry"])
# 简写ID 格式:前 3 位字母/数字 + 后 3 位数字(共 6 位)
_SITE_CODE_PATTERN = re.compile(r"^[A-Z0-9]{3}\d{3}$")
# ── ETL 库直连(无 RLS管理端同步专用 ─────────────────
def _get_etl_admin_connection() -> PgConnection:
"""获取 ETL 库只读连接(无 RLS 隔离),用于管理端跨站点同步。
与 database.get_etl_readonly_connection 不同:
- 不设置 app.current_site_id需要读取所有站点数据
- 仍设置 read_only 防止误写
"""
conn = psycopg2.connect(
host=ETL_DB_HOST,
port=ETL_DB_PORT,
user=ETL_DB_USER,
password=ETL_DB_PASSWORD,
dbname=ETL_DB_NAME,
)
try:
conn.autocommit = False
with conn.cursor() as cur:
cur.execute("SET default_transaction_read_only = on")
conn.commit()
except Exception:
conn.close()
raise
return conn
# ── 店铺同步核心逻辑 ─────────────────────────────────────
def sync_sites_from_etl() -> SiteSyncResult:
"""从 ETL 库 dwd.dim_site 同步店铺到 biz.sites。
逻辑:
1. 读取 dwd.dim_sitescd2_is_current=1获取当前有效店铺
2. 对比 biz.sites
- 新 site_id → INSERTsite_code 留空tenant_id 通过 dim_site.tenant_id 映射 biz.tenants
- 已有 site_id 且 shop_name/site_label 变更 → UPDATE
3. 不删除已有记录
需求: A5.1, A5.2
"""
# 1. 从 ETL 库读取当前有效店铺
etl_conn = _get_etl_admin_connection()
try:
with etl_conn.cursor() as cur:
cur.execute(
"""
SELECT site_id, tenant_id, shop_name, site_label
FROM dwd.dim_site
WHERE scd2_is_current = 1
"""
)
etl_sites = cur.fetchall()
finally:
etl_conn.close()
if not etl_sites:
return SiteSyncResult(inserted=0, updated=0)
# 2. 在 app 库中执行对比和写入
app_conn = get_connection()
inserted = 0
updated = 0
try:
with app_conn.cursor() as cur:
# 构建 tenant_id → biz.tenants.id 映射
cur.execute("SELECT tenant_id, id FROM biz.tenants WHERE is_active = true")
tenant_map: dict[int, int] = {row[0]: row[1] for row in cur.fetchall()}
# 获取 biz.sites 现有数据site_id → (biz_id, site_name, site_label)
cur.execute(
"SELECT site_id, id, site_name, site_label FROM biz.sites"
)
existing: dict[int, tuple[int, str | None, str | None]] = {
row[0]: (row[1], row[2], row[3]) for row in cur.fetchall()
}
for etl_site_id, etl_tenant_id, etl_shop_name, etl_site_label in etl_sites:
biz_tenant_id = tenant_map.get(etl_tenant_id)
if biz_tenant_id is None:
# 租户未注册,跳过
logger.warning(
"同步跳过: site_id=%s 的 tenant_id=%s 在 biz.tenants 中不存在",
etl_site_id, etl_tenant_id,
)
continue
if etl_site_id not in existing:
# 新增店铺site_code 留空
cur.execute(
"""
INSERT INTO biz.sites (tenant_id, site_id, site_name, site_label)
VALUES (%s, %s, %s, %s)
""",
(biz_tenant_id, etl_site_id, etl_shop_name, etl_site_label),
)
inserted += 1
else:
# 已有店铺:检查名称/标签是否变更
_, cur_name, cur_label = existing[etl_site_id]
if cur_name != etl_shop_name or cur_label != etl_site_label:
cur.execute(
"""
UPDATE biz.sites
SET site_name = %s, site_label = %s, updated_at = NOW()
WHERE site_id = %s
""",
(etl_shop_name, etl_site_label, etl_site_id),
)
updated += 1
app_conn.commit()
except Exception:
app_conn.rollback()
raise
finally:
app_conn.close()
logger.info("店铺同步完成: 新增 %d, 更新 %d", inserted, updated)
return SiteSyncResult(inserted=inserted, updated=updated)
# ── 管理端权限依赖 ──────────────────────────────────────
def _require_admin():
"""管理端依赖:要求 JWT 中角色包含 site_admin 或 tenant_admin。"""
async def _dependency(
user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
admin_roles = {"site_admin", "tenant_admin"}
if not admin_roles.intersection(user.roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限site_admin 或 tenant_admin",
)
return user
return _dependency
# ── GET /api/admin/tenants ────────────────────────────────
@router.get("/tenants")
async def list_tenants(
user: CurrentUser = Depends(_require_admin()),
) -> list[TenantItem]:
"""
所有活跃租户列表(含连接器名称)。
JOIN biz.connectors 获取 connector_name。
需求 A2.1
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT t.id, t.tenant_id, t.tenant_name,
c.display_name AS connector_name, t.is_active
FROM biz.tenants t
JOIN biz.connectors c ON c.id = t.connector_id
WHERE t.is_active = true
ORDER BY t.id
"""
)
rows = cur.fetchall()
finally:
conn.close()
return [
TenantItem(
id=r[0],
tenant_id=r[1],
tenant_name=r[2],
connector_name=r[3],
is_active=r[4],
)
for r in rows
]
# ── GET /api/admin/tenants/{tenant_id}/sites ──────────────
@router.get("/tenants/{tenant_id}/sites")
async def list_tenant_sites(
tenant_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> list[SiteItem]:
"""
指定租户下所有活跃店铺。
tenant_id 参数支持两种格式:
- 上游系统租户 IDBIGINT如 2790683160709957
- 内部主键SERIAL如 1, 2, 3...
自动判断:> 10000 视为上游 ID否则视为内部 PK。
需求 A2.2
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-03-22 | Prompt: 管辖门店下拉为空 | 兼容上游 tenant_id 和内部 PK
# 自动判断:上游 ID 是 BIGINT远大于内部 SERIAL阈值 10000 足够区分
if tenant_id > 10000:
cur.execute(
"SELECT id FROM biz.tenants WHERE tenant_id = %s AND is_active = true",
(tenant_id,),
)
else:
cur.execute(
"SELECT id FROM biz.tenants WHERE id = %s AND is_active = true",
(tenant_id,),
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
internal_tenant_id = row[0]
cur.execute(
"""
SELECT id, site_id, site_name, site_code, site_label, is_active
FROM biz.sites
WHERE tenant_id = %s AND is_active = true
ORDER BY site_id
""",
(internal_tenant_id,),
)
rows = cur.fetchall()
finally:
conn.close()
return [
SiteItem(
id=r[0],
site_id=r[1],
site_name=r[2],
site_code=r[3],
site_label=r[4],
is_active=r[5],
)
for r in rows
]
# ── PUT /api/admin/sites/{site_id}/site-code ──────────────
@router.put("/sites/{site_id}/site-code")
async def update_site_code(
site_id: int,
body: UpdateSiteCodeRequest,
user: CurrentUser = Depends(_require_admin()),
) -> SiteCodeResult:
"""
设置/修改店铺简写ID事务内执行历史记录管理。
校验规则:
- 格式6 位,前 3 位字母/数字 + 后 3 位数字,统一大写
- 全局唯一biz.sites.site_code + biz.site_code_history.site_code
事务步骤:
a. 旧 code 在 site_code_history 中标记 is_current=false, retired_at=NOW()
b. 新 code 插入 site_code_historyis_current=true
c. 更新 biz.sites.site_code
d. 检查旧 code 是否有未审核申请引用,无引用则从 history 中删除旧记录
需求 A2.4, A3.1, A3.2, A3.3, A3.4
"""
new_code = body.new_code.strip().upper()
# ── 格式校验 ──
if not _SITE_CODE_PATTERN.match(new_code):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="简写ID 格式错误,需 6 位3+3 模式:前 3 位字母/数字 + 后 3 位数字)",
)
conn = get_connection()
try:
with conn.cursor() as cur:
# ── 校验店铺存在 ──
cur.execute(
"SELECT site_id, site_code FROM biz.sites WHERE id = %s",
(site_id,),
)
site_row = cur.fetchone()
if site_row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="店铺不存在",
)
db_site_id = site_row[0] # biz.sites.site_id上游 ID
old_code = site_row[1]
# ── 全局唯一性校验biz.sites + biz.site_code_history ──
cur.execute(
"""
SELECT 1 FROM biz.sites
WHERE site_code = %s AND id != %s
LIMIT 1
""",
(new_code, site_id),
)
if cur.fetchone():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"简写ID '{new_code}' 已被使用",
)
cur.execute(
"""
SELECT 1 FROM biz.site_code_history
WHERE site_code = %s AND site_id != %s
LIMIT 1
""",
(new_code, db_site_id),
)
if cur.fetchone():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"简写ID '{new_code}' 已被使用",
)
# ── 事务内执行变更 ──
# a. 旧 code 标记 retired
history_cleaned = False
if old_code:
cur.execute(
"""
UPDATE biz.site_code_history
SET is_current = false, retired_at = NOW()
WHERE site_id = %s AND site_code = %s AND is_current = true
""",
(db_site_id, old_code),
)
# b. 新 code 插入 history
cur.execute(
"""
INSERT INTO biz.site_code_history (site_id, site_code, is_current)
VALUES (%s, %s, true)
""",
(db_site_id, new_code),
)
# c. 更新 biz.sites.site_code
cur.execute(
"""
UPDATE biz.sites SET site_code = %s, updated_at = NOW()
WHERE id = %s
""",
(new_code, site_id),
)
# d. 检查旧 code 是否有未审核申请引用,无引用则清理历史
if old_code:
cur.execute(
"""
SELECT 1 FROM auth.user_applications
WHERE site_code = %s AND status = 'pending'
LIMIT 1
""",
(old_code,),
)
has_pending = cur.fetchone() is not None
if not has_pending:
cur.execute(
"""
DELETE FROM biz.site_code_history
WHERE site_id = %s AND site_code = %s AND is_current = false
""",
(db_site_id, old_code),
)
history_cleaned = True
conn.commit()
except HTTPException:
conn.rollback()
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return SiteCodeResult(
site_id=db_site_id,
old_code=old_code,
new_code=new_code,
history_cleaned=history_cleaned,
)
# ── GET /api/admin/sites/{site_id}/site-code-history ──────
@router.get("/sites/{site_id}/site-code-history")
async def get_site_code_history(
site_id: int,
user: CurrentUser = Depends(_require_admin()),
) -> list[SiteCodeHistoryItem]:
"""
查看简写ID 变更历史。
需求 A2.5
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 校验店铺存在,获取上游 site_id
cur.execute(
"SELECT site_id FROM biz.sites WHERE id = %s",
(site_id,),
)
site_row = cur.fetchone()
if site_row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="店铺不存在",
)
db_site_id = site_row[0]
cur.execute(
"""
SELECT id, site_code, is_current, created_at, retired_at
FROM biz.site_code_history
WHERE site_id = %s
ORDER BY created_at DESC
""",
(db_site_id,),
)
rows = cur.fetchall()
finally:
conn.close()
return [
SiteCodeHistoryItem(
id=r[0],
site_code=r[1],
is_current=r[2],
created_at=r[3],
retired_at=r[4],
)
for r in rows
]
# ── POST /api/admin/sites/sync ────────────────────────────
@router.post("/sites/sync")
async def sync_sites(
user: CurrentUser = Depends(_require_admin()),
) -> SiteSyncResult:
"""
手动触发店铺同步:从 ETL 库 dwd.dim_site 同步到 biz.sites。
返回同步结果(新增数/更新数)。
需求 A5.3
"""
return sync_sites_from_etl()
# ── POST /api/admin/sites/sync/internal ───────────────────
@router.post("/sites/sync/internal", include_in_schema=False)
async def sync_sites_internal() -> SiteSyncResult:
"""内部 APIETL DWD 完成后触发店铺同步。
不需要 JWT 认证(内部调用),通过 include_in_schema=False 隐藏。
后续可添加 API key 或 IP 白名单认证。
需求 A5.4
"""
return sync_sites_from_etl()
# ── POST /api/admin/sites测试功能手动创建店铺 ────────
@router.post("/sites", status_code=status.HTTP_201_CREATED)
async def create_site(
body: CreateSiteRequest,
user: CurrentUser = Depends(_require_admin()),
):
"""
手动创建店铺(测试功能)。
向 biz.sites 插入一条记录,可选指定 site_code。
site_id 和 site_code 需全局唯一,冲突返回 409。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 校验 tenant_id 存在
cur.execute(
"SELECT id FROM biz.tenants WHERE tenant_id = %s AND is_active = true",
(body.tenant_id,),
)
tenant_row = cur.fetchone()
if tenant_row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
internal_tenant_id = tenant_row[0]
# site_code 格式校验(如果提供)
site_code = None
if body.site_code:
site_code = body.site_code.strip().upper()
if not _SITE_CODE_PATTERN.match(site_code):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="简写ID 格式错误,需 6 位3+3 格式)",
)
cur.execute(
"""
INSERT INTO biz.sites (tenant_id, site_id, site_name, site_code)
VALUES (%s, %s, %s, %s)
RETURNING id, site_id, site_name, site_code
""",
(internal_tenant_id, body.site_id, body.site_name, site_code),
)
row = cur.fetchone()
# 如果有 site_code同步插入 history
if site_code:
cur.execute(
"""
INSERT INTO biz.site_code_history (site_id, site_code, is_current)
VALUES (%s, %s, true)
""",
(body.site_id, site_code),
)
conn.commit()
except HTTPException:
conn.rollback()
raise
except psycopg2.errors.UniqueViolation as e:
conn.rollback()
detail = str(e)
if "site_id" in detail:
raise HTTPException(status_code=409, detail="site_id 已存在")
if "site_code" in detail:
raise HTTPException(status_code=409, detail="简写ID 已被占用")
raise HTTPException(status_code=409, detail="唯一约束冲突")
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": row[0], "siteId": row[1], "siteName": row[2], "siteCode": row[3]}
# ── DELETE /api/admin/sites/{site_id}(测试功能:删除店铺) ─
@router.delete("/sites/{site_id}")
async def delete_site(
site_id: int,
user: CurrentUser = Depends(_require_admin()),
):
"""
删除店铺(测试功能,硬删除)。
同时清理 site_code_history 中的关联记录。
site_id 参数为 biz.sites.id内部主键
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 获取上游 site_id 用于清理 history
cur.execute(
"SELECT site_id FROM biz.sites WHERE id = %s",
(site_id,),
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="店铺不存在",
)
upstream_site_id = row[0]
# 清理 site_code_history
cur.execute(
"DELETE FROM biz.site_code_history WHERE site_id = %s",
(upstream_site_id,),
)
# 删除店铺
cur.execute("DELETE FROM biz.sites WHERE id = %s", (site_id,))
conn.commit()
except HTTPException:
conn.rollback()
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": site_id}

View File

@@ -0,0 +1,640 @@
# -*- coding: utf-8 -*-
"""P18 任务引擎运营看板 API
提供转移日志查看、待审核任务管理、参数配置等端点。
所有端点需要 JWT 认证;写操作仅限 super_admin。
"""
from __future__ import annotations
import logging
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from psycopg2.extras import RealDictCursor
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.admin_task_engine import (
CandidateAssistant,
CandidateListResponse,
CloseRequest,
CloseResponse,
ConfigParam,
ConfigParamCreate,
ConfigParamList,
ConfigParamResponse,
ConfigParamUpdate,
PendingReviewItem,
PendingReviewPage,
ReassignRequest,
ReassignResponse,
TransferLogItem,
TransferLogPage,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin/task-engine", tags=["任务引擎管理"])
# ---- 任务类型中文映射 ----
TASK_TYPE_LABELS = {
"high_priority_recall": "高优先召回",
"priority_recall": "优先召回",
"follow_up_visit": "客户回访",
"relationship_building": "关系构建",
}
# ---- 权限辅助函数 ----
def _require_super_admin(user: CurrentUser) -> None:
"""写操作权限校验:仅超级管理员可执行。"""
if "super_admin" not in user.roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="仅超级管理员可执行此操作",
)
def _filter_site_id(user: CurrentUser, query_site_id: int | None) -> int | None:
"""读操作门店过滤:门店管理员强制按自身 site_id 过滤。"""
if "super_admin" in user.roles:
return query_site_id
return user.site_id
# =====================================================================
# 1. 转移日志
# =====================================================================
@router.get("/transfer-log", response_model=TransferLogPage)
async def list_transfer_logs(
site_id: int | None = Query(None, description="门店 ID"),
from_date: date | None = Query(None, description="起始日期"),
to_date: date | None = Query(None, description="截止日期"),
assistant_id: int | None = Query(None, description="助教 ID"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(get_current_user),
) -> TransferLogPage:
"""转移日志分页列表。"""
effective_site_id = _filter_site_id(user, site_id)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
conditions = []
params: list = []
if effective_site_id is not None:
conditions.append("t.site_id = %s")
params.append(effective_site_id)
if from_date is not None:
conditions.append("t.created_at >= %s")
params.append(from_date)
if to_date is not None:
conditions.append("t.created_at < %s::date + interval '1 day'")
params.append(to_date)
if assistant_id is not None:
conditions.append("(t.from_assistant_id = %s OR t.to_assistant_id = %s)")
params.extend([assistant_id, assistant_id])
where_clause = " AND ".join(conditions) if conditions else "1=1"
# 总数
cur.execute(
f"SELECT count(*) AS cnt FROM biz.coach_task_transfer_log t WHERE {where_clause}",
params,
)
total = cur.fetchone()["cnt"]
# 分页数据
offset = (page - 1) * page_size
cur.execute(
f"""
SELECT t.*, s.site_name
FROM biz.coach_task_transfer_log t
LEFT JOIN biz.sites s ON s.site_id = t.site_id
WHERE {where_clause}
ORDER BY t.created_at DESC
LIMIT %s OFFSET %s
""",
params + [page_size, offset],
)
rows = cur.fetchall()
items = [TransferLogItem(**row) for row in rows]
return TransferLogPage(items=items, total=total)
except HTTPException:
raise
except Exception as exc:
logger.exception("查询转移日志失败")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询转移日志失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.get("/transfer-log/{member_id}/history", response_model=list[TransferLogItem])
async def get_member_transfer_history(
member_id: int,
user: CurrentUser = Depends(get_current_user),
) -> list[TransferLogItem]:
"""某客户全部转移历史。"""
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
# 门店管理员只能看自己门店的记录
effective_site_id = _filter_site_id(user, None)
if effective_site_id is not None:
cur.execute(
"""
SELECT t.*, s.site_name
FROM biz.coach_task_transfer_log t
LEFT JOIN biz.sites s ON s.site_id = t.site_id
WHERE t.member_id = %s AND t.site_id = %s
ORDER BY t.created_at DESC
""",
[member_id, effective_site_id],
)
else:
cur.execute(
"""
SELECT t.*, s.site_name
FROM biz.coach_task_transfer_log t
LEFT JOIN biz.sites s ON s.site_id = t.site_id
WHERE t.member_id = %s
ORDER BY t.created_at DESC
""",
[member_id],
)
rows = cur.fetchall()
return [TransferLogItem(**row) for row in rows]
except HTTPException:
raise
except Exception as exc:
logger.exception("查询客户转移历史失败: member_id=%s", member_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询客户转移历史失败: {str(exc)[:200]}",
)
finally:
conn.close()
# =====================================================================
# 2. 待审核任务
# =====================================================================
@router.get("/pending-review", response_model=PendingReviewPage)
async def list_pending_reviews(
site_id: int | None = Query(None, description="门店 ID"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(get_current_user),
) -> PendingReviewPage:
"""待审核任务列表。"""
effective_site_id = _filter_site_id(user, site_id)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
conditions = ["ct.status = 'pending_review'"]
params: list = []
if effective_site_id is not None:
conditions.append("ct.site_id = %s")
params.append(effective_site_id)
where_clause = " AND ".join(conditions)
# 总数
cur.execute(
f"SELECT count(*) AS cnt FROM biz.coach_tasks ct WHERE {where_clause}",
params,
)
total = cur.fetchone()["cnt"]
# 分页数据
offset = (page - 1) * page_size
cur.execute(
f"""
SELECT ct.*, s.site_name
FROM biz.coach_tasks ct
LEFT JOIN biz.sites s ON s.site_id = ct.site_id
WHERE {where_clause}
ORDER BY ct.created_at DESC
LIMIT %s OFFSET %s
""",
params + [page_size, offset],
)
rows = cur.fetchall()
items = []
for row in rows:
row["task_type_label"] = TASK_TYPE_LABELS.get(row.get("task_type", ""), "")
items.append(PendingReviewItem(**row))
return PendingReviewPage(items=items, total=total)
except HTTPException:
raise
except Exception as exc:
logger.exception("查询待审核任务失败")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询待审核任务失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.post("/pending-review/{task_id}/reassign", response_model=ReassignResponse)
async def reassign_task(
task_id: int,
body: ReassignRequest,
user: CurrentUser = Depends(get_current_user),
) -> ReassignResponse:
"""重新分配待审核任务(仅超级管理员)。
逻辑:原任务 status → 'transferred',新建 active 任务,写 transfer_log。
"""
_require_super_admin(user)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
# 查询原任务
cur.execute(
"SELECT * FROM biz.coach_tasks WHERE id = %s FOR UPDATE",
[task_id],
)
task = cur.fetchone()
if task is None:
raise HTTPException(status_code=404, detail="任务不存在")
if task["status"] != "pending_review":
raise HTTPException(status_code=400, detail="任务状态不是待审核,无法重新分配")
# 原任务标记为 transferred
cur.execute(
"UPDATE biz.coach_tasks SET status = 'transferred', updated_at = now() WHERE id = %s",
[task_id],
)
# 新建 active 任务
cur.execute(
"""
INSERT INTO biz.coach_tasks
(site_id, member_id, assistant_id, task_type, priority_score, status, created_at, updated_at)
VALUES (%s, %s, %s, %s, %s, 'active', now(), now())
RETURNING id
""",
[task["site_id"], task["member_id"], body.to_assistant_id,
task["task_type"], task.get("priority_score")],
)
new_task_id = cur.fetchone()["id"]
# 写转移日志
cur.execute(
"""
INSERT INTO biz.coach_task_transfer_log
(site_id, member_id, from_assistant_id, to_assistant_id,
transfer_reason, transfer_score, created_at)
VALUES (%s, %s, %s, %s, %s, %s, now())
""",
[task["site_id"], task["member_id"], task["assistant_id"],
body.to_assistant_id, "manual_reassign", task.get("priority_score")],
)
conn.commit()
return ReassignResponse(success=True, new_task_id=new_task_id)
except HTTPException:
conn.rollback()
raise
except Exception as exc:
conn.rollback()
logger.exception("重新分配任务失败: task_id=%s", task_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"重新分配任务失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.post("/pending-review/{task_id}/close", response_model=CloseResponse)
async def close_task(
task_id: int,
body: CloseRequest,
user: CurrentUser = Depends(get_current_user),
) -> CloseResponse:
"""关闭待审核任务(仅超级管理员)。
逻辑:任务 status → 'inactive',记录 abandon_reason。
"""
_require_super_admin(user)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"SELECT id, status FROM biz.coach_tasks WHERE id = %s FOR UPDATE",
[task_id],
)
task = cur.fetchone()
if task is None:
raise HTTPException(status_code=404, detail="任务不存在")
if task["status"] != "pending_review":
raise HTTPException(status_code=400, detail="任务状态不是待审核,无法关闭")
cur.execute(
"""
UPDATE biz.coach_tasks
SET status = 'inactive', abandon_reason = %s, updated_at = now()
WHERE id = %s
""",
[body.reason, task_id],
)
conn.commit()
return CloseResponse(success=True)
except HTTPException:
conn.rollback()
raise
except Exception as exc:
conn.rollback()
logger.exception("关闭任务失败: task_id=%s", task_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"关闭任务失败: {str(exc)[:200]}",
)
finally:
conn.close()
# =====================================================================
# 3. 参数管理
# =====================================================================
# 权重参数 key 列表(联合校验用)
_WEIGHT_KEYS = {"w_rs", "w_ms", "w_ml"}
@router.get("/config", response_model=ConfigParamList)
async def list_config_params(
site_id: int | None = Query(None, description="门店 ID不传则返回全部"),
user: CurrentUser = Depends(get_current_user),
) -> ConfigParamList:
"""参数列表。"""
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
conditions: list[str] = []
params: list = []
if site_id is not None:
# 返回指定门店覆盖 + 全局默认
conditions.append("(p.site_id = %s OR p.site_id IS NULL)")
params.append(site_id)
where_clause = " AND ".join(conditions) if conditions else "1=1"
cur.execute(
f"""
SELECT p.*, s.site_name
FROM biz.cfg_task_generator_params p
LEFT JOIN biz.sites s ON s.site_id = p.site_id
WHERE {where_clause}
ORDER BY p.site_id NULLS FIRST, p.param_key
""",
params,
)
rows = cur.fetchall()
return ConfigParamList(params=[ConfigParam(**row) for row in rows])
except HTTPException:
raise
except Exception as exc:
logger.exception("查询参数配置失败")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询参数配置失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.put("/config/{param_id}", response_model=ConfigParamResponse)
async def update_config_param(
param_id: int,
body: ConfigParamUpdate,
user: CurrentUser = Depends(get_current_user),
) -> ConfigParamResponse:
"""更新参数值(仅超级管理员)。
权重参数w_rs / w_ms / w_ml更新后会校验三者之和是否为 1.0。
"""
_require_super_admin(user)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
# 查询当前参数
cur.execute(
"SELECT * FROM biz.cfg_task_generator_params WHERE id = %s FOR UPDATE",
[param_id],
)
param = cur.fetchone()
if param is None:
raise HTTPException(status_code=404, detail="参数不存在")
# 更新
cur.execute(
"""
UPDATE biz.cfg_task_generator_params
SET param_value = %s, updated_at = now()
WHERE id = %s
""",
[body.param_value, param_id],
)
# 权重参数联合校验w_rs + w_ms + w_ml = 1.0
if param["param_key"] in _WEIGHT_KEYS:
cur.execute(
"""
SELECT param_key, param_value
FROM biz.cfg_task_generator_params
WHERE site_id IS NOT DISTINCT FROM %s
AND param_key = ANY(%s)
""",
[param["site_id"], list(_WEIGHT_KEYS)],
)
weight_rows = cur.fetchall()
weight_sum = sum(r["param_value"] for r in weight_rows)
if abs(weight_sum - 1.0) > 0.001:
conn.rollback()
raise HTTPException(
status_code=400,
detail=f"权重参数之和必须为 1.0,当前为 {weight_sum:.4f}",
)
conn.commit()
return ConfigParamResponse(success=True, id=param_id)
except HTTPException:
conn.rollback()
raise
except Exception as exc:
conn.rollback()
logger.exception("更新参数失败: param_id=%s", param_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新参数失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.post("/config", response_model=ConfigParamResponse)
async def create_config_param(
body: ConfigParamCreate,
user: CurrentUser = Depends(get_current_user),
) -> ConfigParamResponse:
"""新增门店覆盖参数(仅超级管理员)。"""
_require_super_admin(user)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
# 检查是否已存在同 site_id + param_key 的记录
cur.execute(
"""
SELECT id FROM biz.cfg_task_generator_params
WHERE site_id = %s AND param_key = %s
""",
[body.site_id, body.param_key],
)
if cur.fetchone() is not None:
raise HTTPException(
status_code=400,
detail=f"门店 {body.site_id} 已存在参数 {body.param_key} 的覆盖配置",
)
cur.execute(
"""
INSERT INTO biz.cfg_task_generator_params
(site_id, param_key, param_value, updated_at)
VALUES (%s, %s, %s, now())
RETURNING id
""",
[body.site_id, body.param_key, body.param_value],
)
new_id = cur.fetchone()["id"]
conn.commit()
return ConfigParamResponse(success=True, id=new_id)
except HTTPException:
conn.rollback()
raise
except Exception as exc:
conn.rollback()
logger.exception("新增参数失败")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"新增参数失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.delete("/clear-all-tasks")
async def clear_all_tasks(
user: CurrentUser = Depends(get_current_user),
) -> dict:
"""【测试用】清空所有 coach_tasks 及关联数据(仅超级管理员)。
用于开发/测试阶段重置任务数据,让 task_generator 重新生成。
按外键依赖顺序删除transfer_log → notes → history → tasks。
"""
_require_super_admin(user)
conn = get_connection()
try:
with conn.cursor() as cur:
# 按外键依赖顺序:先删引用表,再删主表
cur.execute("DELETE FROM biz.coach_task_transfer_log")
transfer_count = cur.rowcount
cur.execute("DELETE FROM biz.notes WHERE task_id IS NOT NULL")
notes_count = cur.rowcount
cur.execute("DELETE FROM biz.coach_task_history")
history_count = cur.rowcount
# coach_tasks 有自引用 FK先清 parent_task_id 和 transferred_from
cur.execute("UPDATE biz.coach_tasks SET parent_task_id = NULL, transferred_from = NULL")
cur.execute("DELETE FROM biz.coach_tasks")
task_count = cur.rowcount
conn.commit()
return {
"success": True,
"message": f"已清空 {task_count} 条任务 + {history_count} 条历史 + {transfer_count} 条转移日志 + {notes_count} 条备注",
"deleted_tasks": task_count,
"deleted_history": history_count,
"deleted_transfers": transfer_count,
"deleted_notes": notes_count,
}
except Exception as exc:
conn.rollback()
logger.exception("清空任务失败")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"清空任务失败: {str(exc)[:200]}",
)
finally:
conn.close()
@router.delete("/config/{param_id}", response_model=ConfigParamResponse)
async def delete_config_param(
param_id: int,
user: CurrentUser = Depends(get_current_user),
) -> ConfigParamResponse:
"""删除门店覆盖参数(仅超级管理员)。
不允许删除 site_id IS NULL 的全局默认参数。
"""
_require_super_admin(user)
conn = get_connection()
try:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
cur.execute(
"SELECT id, site_id FROM biz.cfg_task_generator_params WHERE id = %s",
[param_id],
)
param = cur.fetchone()
if param is None:
raise HTTPException(status_code=404, detail="参数不存在")
if param["site_id"] is None:
raise HTTPException(status_code=400, detail="不允许删除全局默认参数")
cur.execute(
"DELETE FROM biz.cfg_task_generator_params WHERE id = %s",
[param_id],
)
conn.commit()
return ConfigParamResponse(success=True, id=param_id)
except HTTPException:
conn.rollback()
raise
except Exception as exc:
conn.rollback()
logger.exception("删除参数失败: param_id=%s", param_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"删除参数失败: {str(exc)[:200]}",
)
finally:
conn.close()

View File

@@ -0,0 +1,405 @@
# -*- coding: utf-8 -*-
"""
管理端路由 — 租户管理员 CRUD。
端点清单:
- GET /api/admin/tenant-admins — 管理员列表(分页 + 关键词搜索)
- POST /api/admin/tenant-admins — 创建管理员
- PATCH /api/admin/tenant-admins/{id} — 编辑管理员
- DELETE /api/admin/tenant-admins/{id} — 软删除管理员
- POST /api/admin/tenant-admins/{id}/reset-password — 重置密码
所有端点要求 JWT + site_admin 或 tenant_admin 角色。
需求: 14.1-14.7, A2.3, A2.6, A2.7, A2.8, A4.1
"""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from psycopg2 import errors as pg_errors
from app.auth.dependencies import CurrentUser, get_current_user
from app.auth.jwt import hash_password
from app.database import get_connection
from app.schemas.admin_tenant_admins import (
ResetPasswordRequest,
TenantAdminCreateRequest,
TenantAdminEditRequest,
TenantAdminListItem,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin", tags=["管理端租户管理员"])
# ── 管理端权限依赖:要求 site_admin 或 tenant_admin 角色 ──
def _require_admin():
"""
管理端依赖:要求 JWT 中角色包含 site_admin 或 tenant_admin。
直接从 JWT 校验角色,不查 auth.users 表(管理员在 admin_users 表,
不在 auth.users 中require_permission 会报"用户不存在")。
"""
async def _dependency(
user: CurrentUser = Depends(get_current_user),
) -> CurrentUser:
admin_roles = {"site_admin", "tenant_admin"}
if not admin_roles.intersection(user.roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限site_admin 或 tenant_admin",
)
return user
return _dependency
# ── GET /api/admin/tenant-admins ──────────────────────────
@router.get("/tenant-admins")
async def list_tenant_admins(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页条数"),
keyword: Optional[str] = Query(None, description="关键词搜索(用户名/显示名称)"),
include_inactive: bool = Query(False, description="是否包含已禁用的管理员"),
user: CurrentUser = Depends(_require_admin()),
):
"""
查询租户管理员列表,支持分页和关键词搜索。
默认只返回 is_active=true 的记录include_inactive=true 时返回所有记录。
JOIN biz.tenants 获取 tenant_name。
需求 14.1, A2.7, A2.6
"""
offset = (page - 1) * page_size
conn = get_connection()
try:
with conn.cursor() as cur:
# 构建查询
where_clauses: list[str] = []
params: list = []
# CHANGE 2026-03-22 | Prompt: 删除与禁用分离 | 始终过滤已删除记录
where_clauses.append("ta.deleted_at IS NULL")
# CHANGE 2026-03-23 | Prompt: 任务5.1 A2.7 | 默认过滤 is_active
if not include_inactive:
where_clauses.append("ta.is_active = true")
if keyword:
where_clauses.append(
"(ta.username ILIKE %s OR ta.display_name ILIKE %s)"
)
like_val = f"%{keyword}%"
params.extend([like_val, like_val])
where_sql = ("WHERE " + " AND ".join(where_clauses)) if where_clauses else ""
# 查询总数
cur.execute(
f"""
SELECT COUNT(*)
FROM auth.tenant_admins ta
{where_sql}
""",
params,
)
total = cur.fetchone()[0]
# CHANGE 2026-03-23 | Prompt: 任务5.1 A2.6 | JOIN biz.tenants 获取 tenant_name
# CHANGE 2026-03-23 | Prompt: 角色体系隔离 | 加入 admin_type 列
cur.execute(
f"""
SELECT ta.id, ta.username, ta.display_name, ta.tenant_id,
ta.managed_site_ids,
ta.is_active, ta.created_at, ta.last_login_at,
bt.tenant_name, ta.admin_type
FROM auth.tenant_admins ta
LEFT JOIN biz.tenants bt ON bt.tenant_id = ta.tenant_id
{where_sql}
ORDER BY ta.created_at DESC
LIMIT %s OFFSET %s
""",
params + [page_size, offset],
)
rows = cur.fetchall()
finally:
conn.close()
items = [
TenantAdminListItem(
id=r[0],
username=r[1],
display_name=r[2],
tenant_id=r[3],
managed_site_ids=list(r[4]) if r[4] else [],
is_active=r[5],
created_at=r[6].isoformat() if r[6] else None,
last_login_at=r[7].isoformat() if r[7] else None,
tenant_name=r[8],
admin_type=r[9],
)
for r in rows
]
# 返回分页格式(由 ResponseWrapperMiddleware 包装为 {code:0, data:...}
return {"items": items, "total": total, "page": page, "page_size": page_size}
# ── POST /api/admin/tenant-admins ─────────────────────────
@router.post("/tenant-admins", status_code=status.HTTP_201_CREATED)
async def create_tenant_admin(
body: TenantAdminCreateRequest,
user: CurrentUser = Depends(_require_admin()),
):
"""
创建租户管理员。
密码 bcrypt 哈希username UNIQUE 冲突返回 409记录 created_by。
创建时校验 tenant_id 在 biz.tenants 中存在且 is_active=true。
需求 14.2, 14.3, A2.6
"""
password_hash = hash_password(body.password)
conn = get_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-03-23 | Prompt: 任务5.1 A2.6 | 校验 tenant_id 存在性
cur.execute(
"SELECT id FROM biz.tenants WHERE tenant_id = %s AND is_active = true",
(body.tenant_id,),
)
if cur.fetchone() is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# CHANGE 2026-03-23 | Prompt: 登录用户名大小写不敏感 | 存储时统一小写
cur.execute(
"""
INSERT INTO auth.tenant_admins
(username, password_hash, display_name, tenant_id, managed_site_ids, created_by)
VALUES (LOWER(%s), %s, %s, %s, %s, %s)
RETURNING id, created_at
""",
(
body.username,
password_hash,
body.display_name,
body.tenant_id,
body.managed_site_ids,
user.user_id,
),
)
row = cur.fetchone()
conn.commit()
except HTTPException:
raise
except pg_errors.UniqueViolation:
conn.rollback()
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="用户名已存在",
)
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": row[0], "created_at": row[1].isoformat() if row[1] else None}
# ── PATCH /api/admin/tenant-admins/{id} ───────────────────
@router.patch("/tenant-admins/{admin_id}")
async def edit_tenant_admin(
admin_id: int,
body: TenantAdminEditRequest,
user: CurrentUser = Depends(_require_admin()),
):
"""
编辑租户管理员信息username / display_name / managed_site_ids / is_active
管理员 ID 不存在返回 404。
修改 username 时校验全局唯一性(排除自身),冲突返回 409。
需求 14.4, 14.6, A2.8
"""
# 构建动态 SET 子句
set_clauses: list[str] = []
params: list = []
# CHANGE 2026-03-23 | Prompt: 任务5.1 A2.8 | 支持修改 username
# CHANGE 2026-03-23 | Prompt: 登录用户名大小写不敏感 | 存储时统一小写
if body.username is not None:
set_clauses.append("username = LOWER(%s)")
params.append(body.username)
if body.display_name is not None:
set_clauses.append("display_name = %s")
params.append(body.display_name)
if body.managed_site_ids is not None:
set_clauses.append("managed_site_ids = %s")
params.append(body.managed_site_ids)
if body.is_active is not None:
set_clauses.append("is_active = %s")
params.append(body.is_active)
if not set_clauses:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="至少需要提供一个修改字段",
)
params.append(admin_id)
conn = get_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-03-23 | Prompt: 任务5.1 A2.8 | username 唯一性校验(排除自身)
# CHANGE 2026-03-22 | Prompt: 删除与禁用分离 | 只在未删除记录中校验唯一性
# CHANGE 2026-03-23 | Prompt: 登录用户名大小写不敏感 | LOWER() 比较 + 存储小写
if body.username is not None:
cur.execute(
"SELECT id FROM auth.tenant_admins WHERE LOWER(username) = LOWER(%s) AND id != %s AND deleted_at IS NULL",
(body.username, admin_id),
)
if cur.fetchone() is not None:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="用户名已存在",
)
cur.execute(
f"""
UPDATE auth.tenant_admins
SET {', '.join(set_clauses)}
WHERE id = %s AND deleted_at IS NULL
RETURNING id
""",
params,
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户管理员不存在",
)
conn.commit()
except HTTPException:
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": admin_id}
# ── DELETE /api/admin/tenant-admins/{id} ──────────────────
@router.delete("/tenant-admins/{admin_id}")
async def delete_tenant_admin(
admin_id: int,
user: CurrentUser = Depends(_require_admin()),
):
"""
软删除租户管理员(设置 deleted_at=NOW())。
无论 is_active 状态如何,均可删除。
管理员不存在或已删除返回 404重复删除幂等返回 404。
需求 A2.3
"""
# CHANGE 2026-03-22 | Prompt: 删除与禁用分离 | deleted_at 软删除,不再检查 is_active
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"UPDATE auth.tenant_admins SET deleted_at = NOW() "
"WHERE id = %s AND deleted_at IS NULL "
"RETURNING id",
(admin_id,),
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="管理员不存在",
)
conn.commit()
except HTTPException:
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": admin_id}
# ── POST /api/admin/tenant-admins/{id}/reset-password ─────
@router.post("/tenant-admins/{admin_id}/reset-password")
async def reset_password(
admin_id: int,
body: ResetPasswordRequest,
user: CurrentUser = Depends(_require_admin()),
):
"""
重置租户管理员密码。
新密码 bcrypt 哈希后更新 password_hash。管理员 ID 不存在返回 404。
需求 14.5
"""
new_hash = hash_password(body.new_password)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE auth.tenant_admins
SET password_hash = %s
WHERE id = %s AND deleted_at IS NULL
RETURNING id
""",
(new_hash, admin_id),
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户管理员不存在",
)
conn.commit()
except HTTPException:
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": admin_id}

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
"""管理端 — 触发器统一视图 API
提供 1 个端点:
- GET /api/admin/triggers/unified — 聚合三张表的触发器数据
数据源:
- biz.trigger_jobs业务触发器→ source="biz"
- biz.ai_trigger_jobsAI 事件链,最近 100 条)→ source="ai"
- public.scheduled_tasksETL 调度)→ source="etl"
某数据源查询失败时记录日志,返回其他数据源数据。
需求: 4.1, 4.2, 4.3
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.admin_triggers import UnifiedTriggerItem
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/admin/triggers", tags=["系统管理"])
def _fetch_biz_triggers(conn) -> list[UnifiedTriggerItem]:
"""查询 biz.trigger_jobs映射 source='biz'"""
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, job_name, trigger_condition, status,
last_run_at, next_run_at, last_error
FROM biz.trigger_jobs
ORDER BY id
"""
)
rows = cur.fetchall()
return [
UnifiedTriggerItem(
id=row[0],
name=row[1],
source="biz",
trigger_condition=row[2] or "",
status=row[3] or "",
last_run_at=str(row[4]) if row[4] is not None else None,
next_run_at=str(row[5]) if row[5] is not None else None,
last_error=row[6],
)
for row in rows
]
def _fetch_ai_triggers(conn) -> list[UnifiedTriggerItem]:
"""查询 biz.ai_trigger_jobs最近 100 条),映射 source='ai'
字段映射DDL 实际列 → UnifiedTriggerItem
- id → id
- event_type → nameai_trigger_jobs 无 job_name 列)
- 'event' → trigger_conditionAI 触发器均为事件驱动)
- status → status
- started_at → last_run_atai_trigger_jobs 无 last_run_at 列)
- None → next_run_at事件驱动无预定下次执行时间
- error_message → last_errorai_trigger_jobs 列名为 error_message
"""
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, event_type, status,
started_at, error_message
FROM biz.ai_trigger_jobs
ORDER BY id DESC
LIMIT 100
"""
)
rows = cur.fetchall()
return [
UnifiedTriggerItem(
id=row[0],
name=row[1] or "",
source="ai",
trigger_condition="event",
status=row[2] or "",
last_run_at=str(row[3]) if row[3] is not None else None,
next_run_at=None,
last_error=row[4],
)
for row in rows
]
def _fetch_etl_triggers(conn) -> list[UnifiedTriggerItem]:
"""查询 public.scheduled_tasks映射 source='etl'
字段映射DDL 实际列 → UnifiedTriggerItem
- id → idUUID转为字符串后取 hashcode 作为 int 不合适,改用 row_number
- name → name
- schedule_config->>'schedule_type' → trigger_condition
- last_status / enabled → status组合判断
- last_run_at → last_run_at
- next_run_at → next_run_at
- None → last_errorscheduled_tasks 无 last_error 列)
注意scheduled_tasks.id 是 UUID 类型UnifiedTriggerItem.id 是 int。
使用 ROW_NUMBER() 生成临时整数 ID加 100000 偏移避免与其他数据源冲突。
"""
with conn.cursor() as cur:
cur.execute(
"""
SELECT ROW_NUMBER() OVER (ORDER BY created_at) + 100000 AS row_id,
name,
schedule_config->>'schedule_type' AS schedule_type,
CASE
WHEN enabled = FALSE THEN 'disabled'
WHEN last_status IS NOT NULL THEN last_status
ELSE 'idle'
END AS status,
last_run_at,
next_run_at
FROM scheduled_tasks
ORDER BY created_at
"""
)
rows = cur.fetchall()
return [
UnifiedTriggerItem(
id=int(row[0]),
name=row[1] or "",
source="etl",
trigger_condition=row[2] or "unknown",
status=row[3] or "idle",
last_run_at=str(row[4]) if row[4] is not None else None,
next_run_at=str(row[5]) if row[5] is not None else None,
last_error=None,
)
for row in rows
]
@router.get("/unified", response_model=list[UnifiedTriggerItem])
async def get_unified_triggers(
user: CurrentUser = Depends(get_current_user),
) -> list[UnifiedTriggerItem]:
"""聚合三张表的触发器数据。
依次查询 biz.trigger_jobs、biz.ai_trigger_jobs、scheduled_tasks
某数据源查询失败时记录日志并跳过,返回其他数据源的数据。
"""
results: list[UnifiedTriggerItem] = []
conn = get_connection()
try:
# 数据源 1biz.trigger_jobs
try:
results.extend(_fetch_biz_triggers(conn))
except Exception:
logger.warning("查询 biz.trigger_jobs 失败", exc_info=True)
# 数据源 2biz.ai_trigger_jobs
try:
results.extend(_fetch_ai_triggers(conn))
except Exception:
logger.warning("查询 biz.ai_trigger_jobs 失败", exc_info=True)
# 数据源 3public.scheduled_tasks
try:
results.extend(_fetch_etl_triggers(conn))
except Exception:
logger.warning("查询 scheduled_tasks 失败", exc_info=True)
return results
finally:
conn.close()

View File

@@ -37,7 +37,7 @@ async def login(body: LoginRequest):
try:
with conn.cursor() as cur:
cur.execute(
"SELECT id, password_hash, site_id, is_active "
"SELECT id, password_hash, site_id, is_active, roles "
"FROM admin_users WHERE username = %s",
(body.username,),
)
@@ -51,7 +51,7 @@ async def login(body: LoginRequest):
detail="用户名或密码错误",
)
user_id, password_hash, site_id, is_active = row
user_id, password_hash, site_id, is_active, roles = row
if not is_active:
raise HTTPException(
@@ -65,7 +65,7 @@ async def login(body: LoginRequest):
detail="用户名或密码错误",
)
tokens = create_token_pair(user_id, site_id)
tokens = create_token_pair(user_id, site_id, roles=roles or [])
return TokenResponse(**tokens)
@@ -88,8 +88,22 @@ async def refresh(body: RefreshRequest):
user_id = int(payload["sub"])
site_id = payload["site_id"]
# CHANGE 2026-03-24 | Prompt: 修复 refresh 丢失 roles | 刷新前查询数据库获取最新 roles
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT roles FROM admin_users WHERE id = %s",
(user_id,),
)
row = cur.fetchone()
finally:
conn.close()
roles = row[0] if row else []
# 生成新的 access_tokenrefresh_token 原样返回
new_access = create_access_token(user_id, site_id)
new_access = create_access_token(user_id, site_id, roles=roles or [])
return TokenResponse(
access_token=new_access,
refresh_token=body.refresh_token,

View File

@@ -18,7 +18,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from psycopg2 import OperationalError
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection, get_etl_readonly_connection
from app.database import get_connection, get_etl_global_readonly_connection
from app.schemas.etl_status import CursorInfo, RecentRun
logger = logging.getLogger(__name__)
@@ -40,7 +40,8 @@ async def list_cursors(
查询 ETL 数据库中的 meta.etl_cursor 表。
如果该表不存在,返回空列表而非报错。
"""
conn = get_etl_readonly_connection(user.site_id)
# CHANGE 2026-03-23 | 系统管理后台全局视角,不按门店过滤
conn = get_etl_global_readonly_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-02-15 | 对齐新库 etl_feiqiu 六层架构etl_admin → meta
@@ -60,9 +61,10 @@ async def list_cursors(
cur.execute(
"""
SELECT task_code, last_fetch_time, record_count
FROM meta.etl_cursor
ORDER BY task_code
SELECT t.task_code, c.last_start, c.last_end
FROM meta.etl_cursor c
JOIN meta.etl_task t ON c.task_id = t.task_id
ORDER BY t.task_code
"""
)
rows = cur.fetchall()
@@ -70,8 +72,8 @@ async def list_cursors(
return [
CursorInfo(
task_code=row[0],
last_fetch_time=str(row[1]) if row[1] is not None else None,
record_count=row[2],
last_start=str(row[1]) if row[1] is not None else None,
last_end=str(row[2]) if row[2] is not None else None,
)
for row in rows
]
@@ -99,16 +101,16 @@ async def list_recent_runs(
conn = get_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-03-23 | 系统管理后台全局视角,不按门店过滤
cur.execute(
"""
SELECT id, task_codes, status, started_at,
finished_at, duration_ms, exit_code
FROM task_execution_log
WHERE site_id = %s
ORDER BY started_at DESC
LIMIT %s
""",
(user.site_id, _RECENT_RUNS_LIMIT),
(_RECENT_RUNS_LIMIT,),
)
rows = cur.fetchall()

View File

@@ -35,6 +35,7 @@ from app.schemas.execution import (
from app.schemas.tasks import TaskConfigSchema
from app.services.task_executor import task_executor
from app.services.task_queue import task_queue
from app.services.output_cleanup import cleanup_output_dirs
logger = logging.getLogger(__name__)
@@ -188,6 +189,142 @@ async def cancel_execution(
return {"message": "已发送取消信号"}
import re as _re
def _parse_config_from_command(
task_codes: list[str],
command: str | None,
site_id: int,
) -> TaskConfigSchema:
"""从旧记录的 command 字符串解析出原始 CLI 参数,构建 TaskConfigSchema。
旧记录没有 config JSONB 列,但 command 包含完整的 CLI 参数。
"""
kwargs: dict = {
"tasks": task_codes or [],
"store_id": site_id,
}
if command:
# 解析 --flow
m = _re.search(r"--flow\s+(\S+)", command)
if m:
kwargs["flow"] = m.group(1)
# 解析 --processing-mode
m = _re.search(r"--processing-mode\s+(\S+)", command)
if m:
kwargs["processing_mode"] = m.group(1)
# 解析 --lookback-hours
m = _re.search(r"--lookback-hours\s+(\d+)", command)
if m:
kwargs["lookback_hours"] = int(m.group(1))
# 解析 --overlap-seconds
m = _re.search(r"--overlap-seconds\s+(\d+)", command)
if m:
kwargs["overlap_seconds"] = int(m.group(1))
# 解析 --window-start / --window-end
m = _re.search(r"--window-start\s+(\S+)", command)
if m:
kwargs["window_start"] = m.group(1)
kwargs["window_mode"] = "custom"
m = _re.search(r"--window-end\s+(\S+)", command)
if m:
kwargs["window_end"] = m.group(1)
# 解析 --dry-run
if "--dry-run" in command:
kwargs["dry_run"] = True
# 解析 --force-full
if "--force-full" in command:
kwargs["force_full"] = True
# 解析 --fetch-before-verify
if "--fetch-before-verify" in command:
kwargs["fetch_before_verify"] = True
return TaskConfigSchema(**kwargs)
# ── POST /api/execution/{id}/rerun — 重新执行 ───────────────
# CHANGE 2026-03-22 | 支持对任意历史任务重新执行
@router.post("/{execution_id}/rerun", response_model=ExecutionRunResponse)
async def rerun_execution(
execution_id: str,
user: CurrentUser = Depends(get_current_user),
) -> ExecutionRunResponse:
"""根据历史执行记录重新执行相同的任务。
优先从 config JSONB 列还原完整配置;若旧记录无 config 列,
回退到 task_codes + 默认配置。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT task_codes, site_id, config, command
FROM task_execution_log
WHERE id = %s AND site_id = %s
""",
(execution_id, user.site_id),
)
row = cur.fetchone()
conn.commit()
finally:
conn.close()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="执行记录不存在",
)
task_codes = row[0] or []
config_json = row[2] # JSONB可能为 None旧记录
command_str = row[3] # command 字符串,用于旧记录回退解析
if not task_codes and not config_json:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="原执行记录无任务代码,无法重新执行",
)
# CHANGE 2026-03-22 | 优先从存储的完整 config 还原,保留原始 processing_mode/lookback 等参数
if config_json and isinstance(config_json, dict):
# 覆盖 store_id 为当前用户的(安全)
config_json["store_id"] = user.site_id
config = TaskConfigSchema(**config_json)
else:
# 旧记录无 config 列,尝试从 command 字符串解析原始参数
config = _parse_config_from_command(task_codes, command_str, user.site_id)
new_execution_id = str(uuid.uuid4())
asyncio.create_task(
task_executor.execute(
config=config,
execution_id=new_execution_id,
site_id=user.site_id,
)
)
logger.info(
"重新执行 [%s] → [%s], tasks=%s",
execution_id, new_execution_id, task_codes,
)
return ExecutionRunResponse(
execution_id=new_execution_id,
message=f"已基于 {execution_id[:8]}… 重新执行",
)
# ── GET /api/execution/history — 执行历史 ────────────────────
@router.get("/history", response_model=list[ExecutionHistoryItem])
@@ -281,3 +418,21 @@ async def get_execution_logs(
output_log=row[0],
error_log=row[1],
)
# ── POST /api/execution/cleanup-output — 清理输出目录 ────────
# CHANGE 2026-03-27 | 新增:执行前清理 EXPORT_ROOT 下旧运行记录,每类任务只保留最近 10 个
@router.post("/cleanup-output")
async def cleanup_output(
user: CurrentUser = Depends(get_current_user),
) -> dict:
"""清理 EXPORT_ROOT 下每个任务文件夹的旧运行记录,只保留最近 10 个。"""
try:
result = cleanup_output_dirs(keep=10)
except RuntimeError as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(exc),
)
return result

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""
内部 AI 触发 API — ETL/内部服务调用入口。
端点:
- POST /api/internal/ai/trigger — 接收事件触发请求,异步执行 AI 调用链
认证方式Authorization: Internal-Token {token}
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, Field
from app.ai.config import AIConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/internal/ai", tags=["internal-ai"])
# ── 请求/响应模型 ────────────────────────────────────────────
class TriggerRequest(BaseModel):
"""内部触发请求体。"""
event_type: str = Field(..., description="事件类型: consumption / dws_completed / note_created / task_assigned")
connector_type: str = Field("feiqiu", description="连接器类型")
site_id: int = Field(..., description="门店 ID")
member_id: int | None = Field(None, description="会员 ID可选")
payload: dict | None = Field(None, description="附加数据")
is_forced: bool = Field(False, description="是否强制执行(跳过去重)")
class TriggerResponse(BaseModel):
"""触发响应。"""
trigger_job_id: int
status: str = "pending"
# ── 认证依赖 ─────────────────────────────────────────────────
def verify_internal_token(authorization: str = Header(...)) -> str:
"""校验 Internal-Token 认证。
Header 格式Authorization: Internal-Token {token}
token 不匹配或缺失时返回 HTTP 401。
"""
prefix = "Internal-Token "
if not authorization.startswith(prefix):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证格式,需要 Internal-Token",
)
token = authorization[len(prefix):]
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 不能为空",
)
# 从环境变量加载期望 token
try:
config = AIConfig.from_env()
except ValueError:
logger.error("AIConfig 加载失败,无法校验 internal token")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="AI 配置异常",
)
if token != config.internal_api_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 不匹配",
)
return token
# ── 端点 ─────────────────────────────────────────────────────
@router.post("/trigger", response_model=TriggerResponse)
async def trigger_ai_event(
body: TriggerRequest,
_token: str = Depends(verify_internal_token),
) -> TriggerResponse:
"""接收 ETL/内部事件,写 ai_trigger_jobs 后异步执行。
立即返回 trigger_job_id调用链在后台异步执行。
"""
from app.ai.dispatcher import AIDispatcher, TriggerEvent
# 构建触发事件
event = TriggerEvent(
event_type=body.event_type,
site_id=body.site_id,
member_id=body.member_id,
connector_type=body.connector_type,
payload=body.payload or {},
is_forced=body.is_forced,
)
# 获取 dispatcher 实例并触发
# 延迟导入避免循环依赖dispatcher 实例由应用启动时创建
dispatcher = _get_dispatcher()
job_id = await dispatcher.handle_trigger(event)
return TriggerResponse(trigger_job_id=job_id, status="pending")
# ── 辅助函数 ─────────────────────────────────────────────────
# 全局 dispatcher 实例(应用启动时初始化)
_dispatcher_instance: AIDispatcher | None = None
def set_dispatcher(dispatcher: "AIDispatcher") -> None:
"""设置全局 dispatcher 实例(应用启动时调用)。"""
global _dispatcher_instance
_dispatcher_instance = dispatcher
def _get_dispatcher() -> "AIDispatcher":
"""获取全局 dispatcher 实例。"""
if _dispatcher_instance is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AI Dispatcher 尚未初始化",
)
return _dispatcher_instance

View File

@@ -0,0 +1,83 @@
# AI_CHANGELOG
# - 2026-03-29 | Prompt: DWS_TASK_ENGINE ETL 任务 | 新建文件。
# 提供 POST /api/internal/run-job 端点,供 ETL 按 job_name 执行
# biz.trigger_jobs 中的任务。Internal-Token 认证。
# -*- coding: utf-8 -*-
"""
内部任务执行 API — ETL/内部服务调用入口。
端点:
- POST /api/internal/run-job — 按 job_name 执行 biz.trigger_jobs 中的任务
认证方式Authorization: Internal-Token {token}
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from app.auth.internal_token import verify_internal_token
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/internal", tags=["internal-events"])
class RunJobByNameRequest(BaseModel):
"""按 job_name 执行任务的请求体。"""
job_name: str = Field(..., description="任务名称,如 recall_completion_check")
class RunJobByNameResponse(BaseModel):
"""执行结果。"""
success: bool
message: str
job_name: str
@router.post("/run-job", response_model=RunJobByNameResponse)
async def run_job_by_name_endpoint(
body: RunJobByNameRequest,
_token: str = Depends(verify_internal_token),
) -> RunJobByNameResponse:
"""按 job_name 查找并执行 biz.trigger_jobs 中的任务。
ETL DWS_TASK_ENGINE 任务通过此端点按顺序执行后端任务引擎的各个步骤。
"""
from app.database import get_connection
from app.services.trigger_scheduler import run_job_by_id
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT id FROM biz.trigger_jobs WHERE job_name = %s",
(body.job_name,),
)
row = cur.fetchone()
conn.commit()
finally:
conn.close()
if not row:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"任务 '{body.job_name}' 不存在",
)
job_id = row[0]
result = run_job_by_id(job_id)
logger.info(
"内部任务执行: job_name=%s, success=%s",
body.job_name, result.get("success"),
)
return RunJobByNameResponse(
success=result.get("success", False),
message=result.get("message", ""),
job_name=body.job_name,
)

View File

@@ -67,6 +67,7 @@ async def get_retention_clues(member_id: int, site_id: int):
recorded_by_assistant_id, recorded_by_name, recorded_at, site_id, source
FROM member_retention_clue
WHERE member_id = %s AND site_id = %s
AND is_hidden = false
ORDER BY recorded_at DESC
"""
conn = get_connection()

View File

@@ -43,6 +43,7 @@ def _row_to_response(row) -> ScheduleResponse:
"""将数据库行转换为 ScheduleResponse。"""
task_config = row[4] if isinstance(row[4], dict) else json.loads(row[4])
schedule_config = row[5] if isinstance(row[5], dict) else json.loads(row[5])
min_run_intervals = row[14] if isinstance(row[14], dict) else json.loads(row[14]) if row[14] else {}
return ScheduleResponse(
id=str(row[0]),
site_id=row[1],
@@ -55,8 +56,12 @@ def _row_to_response(row) -> ScheduleResponse:
next_run_at=row[8],
run_count=row[9],
last_status=row[10],
created_at=row[11],
updated_at=row[12],
min_run_interval_value=row[11] or 0,
min_run_interval_unit=row[12] or "minutes",
last_success_at=row[13],
min_run_intervals=min_run_intervals,
created_at=row[15],
updated_at=row[16],
)
@@ -64,6 +69,8 @@ def _row_to_response(row) -> ScheduleResponse:
_SELECT_COLS = """
id, site_id, name, task_codes, task_config, schedule_config,
enabled, last_run_at, next_run_at, run_count, last_status,
min_run_interval_value, min_run_interval_unit, last_success_at,
min_run_intervals,
created_at, updated_at
"""
@@ -107,8 +114,9 @@ async def create_schedule(
cur.execute(
f"""
INSERT INTO scheduled_tasks
(site_id, name, task_codes, task_config, schedule_config, enabled, next_run_at)
VALUES (%s, %s, %s, %s, %s, %s, %s)
(site_id, name, task_codes, task_config, schedule_config, enabled, next_run_at,
min_run_interval_value, min_run_interval_unit, min_run_intervals)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING {_SELECT_COLS}
""",
(
@@ -119,6 +127,9 @@ async def create_schedule(
body.schedule_config.model_dump_json(),
body.schedule_config.enabled,
next_run,
body.min_run_interval_value,
body.min_run_interval_unit,
json.dumps({k: v.model_dump() for k, v in body.min_run_intervals.items()}) if body.min_run_intervals else "{}",
),
)
row = cur.fetchone()
@@ -174,6 +185,16 @@ async def update_schedule(
set_parts.append("next_run_at = %s")
params.append(next_run)
if body.min_run_interval_value is not None:
set_parts.append("min_run_interval_value = %s")
params.append(body.min_run_interval_value)
if body.min_run_interval_unit is not None:
set_parts.append("min_run_interval_unit = %s")
params.append(body.min_run_interval_unit)
if body.min_run_intervals is not None:
set_parts.append("min_run_intervals = %s")
params.append(json.dumps({k: v.model_dump() for k, v in body.min_run_intervals.items()}))
if not set_parts:
raise HTTPException(
status_code=422,
@@ -314,17 +335,22 @@ async def toggle_schedule(
@router.post("/{schedule_id}/run")
async def run_schedule_now(
schedule_id: str,
force: bool = Query(False),
user: CurrentUser = Depends(get_current_user),
) -> dict:
"""手动触发调度任务执行一次,不更新 last_run_at / next_run_at / run_count。
读取调度任务的 task_config构造 TaskConfigSchema 后入队执行。
force=true 时绕过并发和间隔检查,直接入队。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT task_config, site_id FROM scheduled_tasks WHERE id = %s AND site_id = %s",
"""SELECT task_config, site_id,
min_run_interval_value, min_run_interval_unit,
last_run_at, last_status, min_run_intervals
FROM scheduled_tasks WHERE id = %s AND site_id = %s""",
(schedule_id, user.site_id),
)
row = cur.fetchone()
@@ -338,6 +364,42 @@ async def run_schedule_now(
detail="调度任务不存在",
)
# force=false 时执行并发和间隔检查
if not force:
last_status = row[5]
if last_status == "running":
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="任务正在运行中,无法重复执行",
)
min_interval_value = row[2] or 0
min_interval_unit = row[3] or "minutes"
last_run_at = row[4]
min_run_intervals_raw = row[6] if isinstance(row[6], dict) else json.loads(row[6]) if row[6] else {}
# 计算有效间隔per-task 最大值 vs schedule 级别,取较大者
effective_interval_seconds = 0
multipliers = {"minutes": 60, "hours": 3600, "days": 86400}
if min_interval_value > 0:
effective_interval_seconds = min_interval_value * multipliers.get(min_interval_unit, 60)
for _task_code, interval_cfg in min_run_intervals_raw.items():
if isinstance(interval_cfg, dict):
v = interval_cfg.get("value", 0) or 0
u = interval_cfg.get("unit", "minutes")
task_seconds = v * multipliers.get(u, 60) if v > 0 else 0
if task_seconds > effective_interval_seconds:
effective_interval_seconds = task_seconds
if effective_interval_seconds > 0 and last_run_at is not None:
now = datetime.now(timezone.utc)
elapsed = (now - last_run_at).total_seconds()
if elapsed < effective_interval_seconds:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="未达到最小运行间隔,请稍后再试",
)
task_config_raw = row[0] if isinstance(row[0], dict) else json.loads(row[0])
config = TaskConfigSchema(**task_config_raw)
config = config.model_copy(update={"store_id": user.site_id})

View File

@@ -0,0 +1,247 @@
# -*- coding: utf-8 -*-
"""
租户管理员认证路由:登录与令牌刷新。
- POST /api/tenant/auth/login — 用户名+密码验证,签发 JWTaud=tenant-admin
- POST /api/tenant/auth/refresh — 刷新令牌换取新令牌对
JWT payload 包含sub=admin_id, tenant_id, managed_site_ids, aud=tenant-admin, type
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, HTTPException, status
from jose import JWTError, jwt as jose_jwt
from pydantic import BaseModel, Field
from app import config
from app.auth.jwt import verify_password
from app.database import get_connection
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/tenant/auth", tags=["租户认证"])
# ── Pydantic 模型 ────────────────────────────────────────────
class TenantLoginRequest(BaseModel):
"""租户管理员登录请求。"""
username: str = Field(..., min_length=1, max_length=100, description="用户名")
password: str = Field(..., min_length=1, description="密码")
class TenantRefreshRequest(BaseModel):
"""刷新令牌请求。"""
refresh_token: str = Field(..., min_length=1, description="刷新令牌")
class TenantTokenResponse(BaseModel):
"""令牌响应。"""
access_token: str
refresh_token: str
token_type: str = "bearer"
# ── JWT 签发(租户管理员专用,含 aud=tenant-admin ──────────
def _create_tenant_access_token(
admin_id: int,
tenant_id: int,
managed_site_ids: list[int],
admin_type: str = "tenant_admin",
display_name: str | None = None,
) -> str:
"""签发租户管理员 access_tokenaud=tenant-admin"""
expire = datetime.now(timezone.utc) + timedelta(
minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
)
payload: dict = {
"sub": str(admin_id),
"tenant_id": tenant_id,
"managed_site_ids": managed_site_ids,
"admin_type": admin_type,
"aud": "tenant-admin",
"type": "access",
"exp": expire,
}
if display_name is not None:
payload["display_name"] = display_name
return jose_jwt.encode(payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
def _create_tenant_refresh_token(
admin_id: int,
tenant_id: int,
managed_site_ids: list[int],
admin_type: str = "tenant_admin",
) -> str:
"""签发租户管理员 refresh_tokenaud=tenant-admin"""
expire = datetime.now(timezone.utc) + timedelta(
days=config.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
payload: dict = {
"sub": str(admin_id),
"tenant_id": tenant_id,
"managed_site_ids": managed_site_ids,
"admin_type": admin_type,
"aud": "tenant-admin",
"type": "refresh",
"exp": expire,
}
return jose_jwt.encode(payload, config.JWT_SECRET_KEY, algorithm=config.JWT_ALGORITHM)
# ── 路由端点 ─────────────────────────────────────────────────
@router.post("/login", response_model=TenantTokenResponse)
async def tenant_login(body: TenantLoginRequest):
"""
租户管理员登录。
查询 auth.tenant_admins 表验证用户名密码,成功后签发 JWT 令牌对。
- 用户不存在或密码错误401统一消息不区分
- 账号已禁用is_active=false403
- 登录成功:更新 last_login_at
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# CHANGE 2026-03-22 | Prompt: 删除与禁用分离 | 过滤已删除记录
# CHANGE 2026-03-23 | Prompt: 登录用户名大小写不敏感 | LOWER() 比较
cur.execute(
"SELECT id, password_hash, display_name, tenant_id, "
"managed_site_ids, is_active, admin_type "
"FROM auth.tenant_admins "
"WHERE LOWER(username) = LOWER(%s) AND deleted_at IS NULL",
(body.username,),
)
row = cur.fetchone()
finally:
conn.close()
# 用户不存在 → 401统一消息
if row is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
admin_id, password_hash, display_name, tenant_id, managed_site_ids, is_active, admin_type = row
# 账号禁用 → 403
if not is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账号已被禁用",
)
# 密码错误 → 401统一消息
if not verify_password(body.password, password_hash):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
# 登录成功:更新 last_login_at
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"UPDATE auth.tenant_admins SET last_login_at = NOW() WHERE id = %s",
(admin_id,),
)
conn.commit()
except Exception:
logger.warning("更新 last_login_at 失败admin_id=%s", admin_id, exc_info=True)
finally:
conn.close()
# 签发令牌对
access_token = _create_tenant_access_token(
admin_id=admin_id,
tenant_id=tenant_id,
managed_site_ids=managed_site_ids,
admin_type=admin_type,
display_name=display_name,
)
refresh_token = _create_tenant_refresh_token(
admin_id=admin_id,
tenant_id=tenant_id,
managed_site_ids=managed_site_ids,
admin_type=admin_type,
)
return TenantTokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
)
@router.post("/refresh", response_model=TenantTokenResponse)
async def tenant_refresh(body: TenantRefreshRequest):
"""
刷新租户管理员令牌。
验证 refresh_tokenaud=tenant-admin, type=refresh签发新令牌对。
"""
try:
payload = jose_jwt.decode(
body.refresh_token,
config.JWT_SECRET_KEY,
algorithms=[config.JWT_ALGORITHM],
audience="tenant-admin",
)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的刷新令牌",
)
# 验证 token type
if payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌类型不匹配",
)
# 验证 audjose 在 aud 缺失时不会拒绝,需显式检查)
if payload.get("aud") != "tenant-admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌类型不匹配",
)
# 提取字段
admin_id = int(payload["sub"])
tenant_id = payload["tenant_id"]
managed_site_ids = payload["managed_site_ids"]
admin_type = payload.get("admin_type", "tenant_admin")
# 签发新令牌对
access_token = _create_tenant_access_token(
admin_id=admin_id,
tenant_id=tenant_id,
managed_site_ids=managed_site_ids,
admin_type=admin_type,
display_name=payload.get("display_name"),
)
refresh_token = _create_tenant_refresh_token(
admin_id=admin_id,
tenant_id=tenant_id,
managed_site_ids=managed_site_ids,
admin_type=admin_type,
)
return TenantTokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
)

View File

@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
"""
租户管理后台 — 维客线索管理路由。
端点清单:
- GET /api/tenant/customers/search — 客户搜索keyword + site_id
- GET /api/tenant/customers/{member_id}/clues — 线索列表source / is_hidden 筛选)
- PATCH /api/tenant/clues/{id} — 编辑线索
- DELETE /api/tenant/clues/{id} — 物理删除线索
- PATCH /api/tenant/clues/{id}/visibility — 切换隐藏/显示
需求: 9.1-9.4, 10.1, 11.1-11.3, 12.2-12.3, 13.1-13.4
AI_CHANGELOG
- 2026-03-23 21:00:00 | Prompt: P20260323-210000根治 tenant_admin managed_site_ids 限制)| Direct causeJWT managed_site_ids 静态签发,新建店铺后所有端点受限 | Summary_get_clue_with_site_check 签名改为接受 admin: CurrentTenantAdminsearch_customers 用 get_effective_site_idslist_customer_clues 用 site_filter_clause(admin=admin);三个调用点改传 admin | Verify维客线索管理覆盖新建店铺
"""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.auth.tenant_admins import (
CurrentTenantAdmin,
get_effective_site_ids,
require_tenant_admin,
site_filter_clause,
)
from app.database import get_connection, get_etl_readonly_connection
from app.schemas.tenant_clues import (
ClueEditRequest,
ClueListItem,
ClueVisibilityRequest,
CustomerSearchItem,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/tenant", tags=["维客线索管理"])
def _mask_mobile(mobile: str | None) -> str | None:
"""手机号脱敏:中间 4 位替换为 ****,如 138****1234。"""
if not mobile or len(mobile) < 7:
return mobile
return mobile[:3] + "****" + mobile[7:]
def _get_clue_with_site_check(clue_id: int, admin: CurrentTenantAdmin):
"""
查询线索并校验 site_id 是否在管辖范围内。
不在管辖范围或不存在均返回 404避免泄露线索存在性
返回 (id, site_id, member_id, category, summary, detail,
recorded_by_name, source, recorded_at, is_hidden)。
"""
site_sql, site_params = site_filter_clause(admin=admin)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
SELECT id, site_id, member_id, category, summary, detail,
recorded_by_name, source, recorded_at::text, is_hidden
FROM public.member_retention_clue
WHERE id = %s AND {site_sql}
""",
(clue_id, *site_params),
)
row = cur.fetchone()
finally:
conn.close()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="线索不存在")
return row
# ── GET /api/tenant/customers/search ──────────────────────
@router.get("/customers/search")
async def search_customers(
keyword: str = Query(..., min_length=1, description="搜索关键词(姓名模糊/手机号精确)"),
site_id: Optional[int] = Query(None, description="指定门店 ID 筛选"),
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""
客户搜索:在管辖门店范围内搜索 v_dim_member。
nickname 模糊匹配 OR mobile 精确匹配scd2_is_current=1。
手机号脱敏返回。
"""
# 确定要搜索的门店列表
# [CHANGE P20260323-210000] intent: 使用 get_effective_site_ids 统一获取有效 site_ids
effective_ids = get_effective_site_ids(admin)
if site_id is not None:
if site_id not in effective_ids:
return {"items": []}
search_site_ids = [site_id]
else:
search_site_ids = effective_ids
if not search_site_ids:
return {"items": []}
# 逐 site_id 查询 FDWRLS 要求逐个设置 current_site_id
all_items: list[dict] = []
for sid in search_site_ids:
try:
etl_conn = get_etl_readonly_connection(sid)
try:
with etl_conn.cursor() as cur:
cur.execute(
"""
SELECT member_id, nickname, mobile
FROM fdw_etl.v_dim_member
WHERE scd2_is_current = 1
AND (nickname ILIKE %s OR mobile = %s)
LIMIT 50
""",
(f"%{keyword}%", keyword),
)
for row in cur.fetchall():
all_items.append(
CustomerSearchItem(
member_id=row[0],
nickname=row[1],
mobile_masked=_mask_mobile(row[2]),
site_id=sid,
).model_dump(by_alias=True)
)
finally:
etl_conn.close()
except Exception:
logger.warning("v_dim_member 搜索失败site_id=%s", sid, exc_info=True)
# 补充 site_name
if all_items:
site_ids_set = list({item.get("siteId") for item in all_items if item.get("siteId")})
if site_ids_set:
conn = get_connection()
try:
with conn.cursor() as cur:
placeholders = ", ".join(["%s"] * len(site_ids_set))
cur.execute(
f"SELECT site_id, site_name FROM biz.sites WHERE site_id IN ({placeholders})",
tuple(site_ids_set),
)
site_name_map = {r[0]: r[1] for r in cur.fetchall()}
finally:
conn.close()
for item in all_items:
sid_val = item.get("siteId")
if sid_val and sid_val in site_name_map:
item["siteName"] = site_name_map[sid_val]
return {"items": all_items}
# ── GET /api/tenant/customers/{member_id}/clues ───────────
@router.get("/customers/{member_id}/clues")
async def list_customer_clues(
member_id: int,
source: Optional[str] = Query(None, description="按来源筛选manual/ai_consumption/ai_note"),
is_hidden: Optional[bool] = Query(None, description="按隐藏状态筛选"),
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""返回该客户在管辖门店范围内的全部线索,支持 source 和 is_hidden 筛选。"""
site_sql, site_params = site_filter_clause(admin=admin)
conn = get_connection()
try:
with conn.cursor() as cur:
where_parts = [f"{site_sql}", "member_id = %s"]
params: list = list(site_params) + [member_id]
if source is not None:
where_parts.append("source = %s")
params.append(source)
if is_hidden is not None:
where_parts.append("is_hidden = %s")
params.append(is_hidden)
where_clause = " AND ".join(where_parts)
cur.execute(
f"""
SELECT id, category, summary, detail,
recorded_by_name, source, recorded_at::text, is_hidden
FROM public.member_retention_clue
WHERE {where_clause}
ORDER BY recorded_at DESC
""",
tuple(params),
)
rows = cur.fetchall()
finally:
conn.close()
items = [
ClueListItem(
id=r[0], category=r[1], summary=r[2], detail=r[3],
recorded_by_name=r[4], source=r[5], recorded_at=r[6], is_hidden=r[7],
).model_dump(by_alias=True)
for r in rows
]
return {"items": items}
# ── PATCH /api/tenant/clues/{id} ──────────────────────────
@router.patch("/clues/{clue_id}")
async def edit_clue(
clue_id: int,
body: ClueEditRequest,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""编辑线索 category/summary/detail。校验 category 枚举和 summary 长度。"""
# 先校验线索存在且在管辖范围内
_get_clue_with_site_check(clue_id, admin)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE public.member_retention_clue
SET category = %s, summary = %s, detail = %s
WHERE id = %s
""",
(body.category.value, body.summary, body.detail, clue_id),
)
conn.commit()
except Exception:
conn.rollback()
logger.error("编辑线索失败clue_id=%s", clue_id, exc_info=True)
raise HTTPException(status_code=500, detail="编辑操作失败")
finally:
conn.close()
return {"message": "更新成功"}
# ── DELETE /api/tenant/clues/{id} ─────────────────────────
@router.delete("/clues/{clue_id}")
async def delete_clue(
clue_id: int,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""物理删除线索。线索不存在或不在管辖范围返回 404。"""
_get_clue_with_site_check(clue_id, admin)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"DELETE FROM public.member_retention_clue WHERE id = %s",
(clue_id,),
)
conn.commit()
except Exception:
conn.rollback()
logger.error("删除线索失败clue_id=%s", clue_id, exc_info=True)
raise HTTPException(status_code=500, detail="删除操作失败")
finally:
conn.close()
return {"message": "删除成功"}
# ── PATCH /api/tenant/clues/{id}/visibility ───────────────
@router.patch("/clues/{clue_id}/visibility")
async def toggle_clue_visibility(
clue_id: int,
body: ClueVisibilityRequest,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""切换线索 is_hidden 状态。"""
_get_clue_with_site_check(clue_id, admin)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE public.member_retention_clue
SET is_hidden = %s
WHERE id = %s
""",
(body.is_hidden, clue_id),
)
conn.commit()
except Exception:
conn.rollback()
logger.error("切换线索可见性失败clue_id=%s", clue_id, exc_info=True)
raise HTTPException(status_code=500, detail="操作失败")
finally:
conn.close()
return {"message": "更新成功"}

View File

@@ -0,0 +1,996 @@
# -*- coding: utf-8 -*-
"""
租户管理后台 — Excel 上传/校验/冲突/写入路由。
端点清单:
- POST /api/tenant/excel/upload — 上传解析 + 格式校验 + 人员匹配 + 冲突检测
- POST /api/tenant/excel/confirm — 确认写入(单事务)
- GET /api/tenant/excel/logs — 上传记录列表(分页)
- GET /api/tenant/excel/template/{type} — 下载空白 Excel 模板
需求: 5.1-5.5, 6.1-6.5, 7.1-7.5, 8.1-8.5
AI_CHANGELOG
- 2026-03-23 21:00:00 | Prompt: P20260323-210000根治 tenant_admin managed_site_ids 限制)| Direct causeJWT managed_site_ids 静态签发,新建店铺后所有端点受限 | Summary两个 verify_site_access 改用 admin=adminlist_upload_logs 的 site_filter_clause 改用 admin=admin | VerifyExcel 上传/确认/日志覆盖新建店铺
"""
from __future__ import annotations
import io
import json
import logging
import re
from datetime import date, datetime, timezone
from decimal import Decimal, InvalidOperation
from typing import Any, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
from fastapi.responses import StreamingResponse
from app.auth.tenant_admins import (
CurrentTenantAdmin,
require_tenant_admin,
site_filter_clause,
verify_site_access,
)
from app.database import get_connection, get_etl_readonly_connection
from app.schemas.tenant_excel import (
ConfirmRequest,
ConflictDiff,
FieldDiff,
UploadLogItem,
ValidationError as VError,
ValidationResult,
ValidationWarning as VWarning,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/tenant/excel", tags=["租户Excel上传"])
# ── 常量 ──────────────────────────────────────────────────
VALID_UPLOAD_TYPES = {"expense", "platform_income", "salary_adj", "recharge_commission"}
EXPENSE_CATEGORIES = [
"房租", "水电", "物业", "食品饮料进货", "耗材", "报销", "固定人员工资", "其他费用",
]
SALARY_ADJ_TYPES = {"扣款": "deduction", "奖金": "bonus"}
# 模板列定义(中文表头 → 内部字段名)
TEMPLATE_COLUMNS: dict[str, list[tuple[str, str]]] = {
"expense": [
("月份", "expense_month"),
("支出类别", "category"),
("金额", "amount"),
("备注", "remark"),
],
"platform_income": [
("月份", "income_month"),
("平台名称", "platform_name"),
("金额", "amount"),
("备注", "remark"),
],
"salary_adj": [
("月份", "salary_month"),
("助教姓名", "assistant_name"),
("助教编号", "assistant_number"),
("类型", "adjustment_type"),
("金额", "amount"),
("原因", "reason"),
],
"recharge_commission": [
("充值日期", "recharge_date"),
("会员名称", "member_name"),
("充值金额", "recharge_amount"),
("归属助教", "assigned_assistant"),
("奖励金额", "reward_amount"),
],
}
# 冲突检测主键规则(不含 site_idsite_id 在查询时自动附加)
CONFLICT_KEYS: dict[str, list[str]] = {
"expense": ["expense_month", "category"],
"platform_income": ["income_month", "platform_name"],
"salary_adj": ["salary_month", "assistant_name", "assistant_number", "adjustment_type", "reason"],
"recharge_commission": ["recharge_date", "member_name", "assigned_assistant"],
}
# 目标表映射
TARGET_TABLES: dict[str, str] = {
"expense": "biz.stg_finance_expense",
"platform_income": "biz.stg_platform_income",
"salary_adj": "biz.salary_adjustments",
"recharge_commission": "biz.stg_recharge_commission",
}
# 各表写入字段(不含 id, upload_batch_id, created_at, synced_at 等自动字段)
TABLE_WRITE_FIELDS: dict[str, list[str]] = {
"expense": ["site_id", "expense_month", "category", "amount", "remark", "upload_batch_id", "created_at"],
"platform_income": ["site_id", "income_month", "platform_name", "amount", "remark", "upload_batch_id", "created_at"],
"salary_adj": ["site_id", "assistant_id", "assistant_name", "assistant_number", "salary_month", "adjustment_type", "amount", "reason", "upload_batch_id", "created_at", "created_by"],
"recharge_commission": ["site_id", "recharge_date", "member_name", "recharge_amount", "assigned_assistant", "reward_amount", "upload_batch_id", "created_at"],
}
# ── 校验工具函数 ──────────────────────────────────────────
_MONTH_RE = re.compile(r"^\d{4}-(0[1-9]|1[0-2])$")
_DATE_RE = re.compile(r"^\d{4}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$")
def _validate_month(value: str, current_month: str | None = None) -> str | None:
"""校验月份格式 YYYY-MM返回错误描述或 None。"""
if not value or not _MONTH_RE.match(str(value).strip()):
return "月份格式应为 YYYY-MM"
return None
def _validate_date(value: str) -> str | None:
"""校验日期格式 YYYY-MM-DD。"""
if not value or not _DATE_RE.match(str(value).strip()):
return "日期格式应为 YYYY-MM-DD"
# 额外验证日期合法性
try:
datetime.strptime(str(value).strip(), "%Y-%m-%d")
except ValueError:
return "无效的日期"
return None
def _validate_positive_amount(value: Any) -> str | None:
"""校验金额 > 0精度 2 位小数。"""
try:
d = Decimal(str(value))
except (InvalidOperation, TypeError, ValueError):
return "金额必须为有效数字"
if d <= 0:
return "金额必须大于 0"
if d.as_tuple().exponent is not None and abs(d.as_tuple().exponent) > 2:
return "金额精度不能超过 2 位小数"
return None
def _validate_non_negative_amount(value: Any) -> str | None:
"""校验金额 ≥ 0精度 2 位小数。"""
try:
d = Decimal(str(value))
except (InvalidOperation, TypeError, ValueError):
return "金额必须为有效数字"
if d < 0:
return "金额不能为负数"
if d.as_tuple().exponent is not None and abs(d.as_tuple().exponent) > 2:
return "金额精度不能超过 2 位小数"
return None
def _validate_not_empty(value: Any, max_len: int | None = None) -> str | None:
"""校验非空字符串,可选最大长度。"""
s = str(value).strip() if value is not None else ""
if not s:
return "不能为空"
if max_len and len(s) > max_len:
return f"长度不能超过 {max_len} 字符"
return None
def validate_rows(upload_type: str, rows: list[dict]) -> tuple[list[VError], list[dict]]:
"""
按模板类型校验数据行。
返回 (errors, passed_rows)。
passed_rows 中的字段值已做类型转换(如金额转 float
"""
errors: list[VError] = []
passed: list[dict] = []
for row in rows:
row_idx = row.get("row_index", 0)
row_errors: list[VError] = []
if upload_type == "expense":
_validate_expense_row(row, row_idx, row_errors)
elif upload_type == "platform_income":
_validate_platform_income_row(row, row_idx, row_errors)
elif upload_type == "salary_adj":
_validate_salary_adj_row(row, row_idx, row_errors)
elif upload_type == "recharge_commission":
_validate_recharge_commission_row(row, row_idx, row_errors)
if row_errors:
errors.extend(row_errors)
else:
passed.append(row)
return errors, passed
def _validate_expense_row(row: dict, row_idx: int, errors: list[VError]):
"""校验财务支出行。"""
# 月份
err = _validate_month(row.get("expense_month", ""))
if err:
errors.append(VError(row_index=row_idx, column="月份", message=err))
# 支出类别
cat = str(row.get("category", "")).strip()
if cat not in EXPENSE_CATEGORIES:
errors.append(VError(
row_index=row_idx, column="支出类别",
message=f"无效的支出类别,可选值:{''.join(EXPENSE_CATEGORIES)}",
))
# 金额
err = _validate_positive_amount(row.get("amount"))
if err:
errors.append(VError(row_index=row_idx, column="金额", message=err))
# 备注(可选,最长 500
remark = row.get("remark")
if remark is not None and str(remark).strip():
if len(str(remark).strip()) > 500:
errors.append(VError(row_index=row_idx, column="备注", message="备注长度不能超过 500 字符"))
def _validate_platform_income_row(row: dict, row_idx: int, errors: list[VError]):
"""校验团购收入行。"""
err = _validate_month(row.get("income_month", ""))
if err:
errors.append(VError(row_index=row_idx, column="月份", message=err))
err = _validate_not_empty(row.get("platform_name"))
if err:
errors.append(VError(row_index=row_idx, column="平台名称", message=err))
err = _validate_positive_amount(row.get("amount"))
if err:
errors.append(VError(row_index=row_idx, column="金额", message=err))
remark = row.get("remark")
if remark is not None and str(remark).strip():
if len(str(remark).strip()) > 500:
errors.append(VError(row_index=row_idx, column="备注", message="备注长度不能超过 500 字符"))
def _validate_salary_adj_row(row: dict, row_idx: int, errors: list[VError]):
"""校验助教奖罚行。"""
err = _validate_month(row.get("salary_month", ""))
if err:
errors.append(VError(row_index=row_idx, column="月份", message=err))
err = _validate_not_empty(row.get("assistant_name"))
if err:
errors.append(VError(row_index=row_idx, column="助教姓名", message=err))
err = _validate_not_empty(row.get("assistant_number"))
if err:
errors.append(VError(row_index=row_idx, column="助教编号", message=err))
adj_type = str(row.get("adjustment_type", "")).strip()
if adj_type not in SALARY_ADJ_TYPES:
errors.append(VError(
row_index=row_idx, column="类型",
message=f"无效的类型,可选值:{''.join(SALARY_ADJ_TYPES.keys())}",
))
err = _validate_positive_amount(row.get("amount"))
if err:
errors.append(VError(row_index=row_idx, column="金额", message=err))
err = _validate_not_empty(row.get("reason"), max_len=200)
if err:
errors.append(VError(row_index=row_idx, column="原因", message=err))
def _validate_recharge_commission_row(row: dict, row_idx: int, errors: list[VError]):
"""校验充值业绩归属行。"""
err = _validate_date(row.get("recharge_date", ""))
if err:
errors.append(VError(row_index=row_idx, column="充值日期", message=err))
err = _validate_not_empty(row.get("member_name"))
if err:
errors.append(VError(row_index=row_idx, column="会员名称", message=err))
err = _validate_positive_amount(row.get("recharge_amount"))
if err:
errors.append(VError(row_index=row_idx, column="充值金额", message=err))
err = _validate_not_empty(row.get("assigned_assistant"))
if err:
errors.append(VError(row_index=row_idx, column="归属助教", message=err))
err = _validate_non_negative_amount(row.get("reward_amount"))
if err:
errors.append(VError(row_index=row_idx, column="奖励金额", message=err))
# ── Excel 解析 ────────────────────────────────────────────
def parse_excel(file_bytes: bytes, upload_type: str) -> list[dict]:
"""
解析 Excel 文件,返回行数据列表。
每行为 dict包含 row_index从 1 开始)和各字段值。
"""
import openpyxl
wb = openpyxl.load_workbook(io.BytesIO(file_bytes), read_only=True, data_only=True)
ws = wb.active
if ws is None:
return []
columns = TEMPLATE_COLUMNS.get(upload_type, [])
if not columns:
return []
rows_data: list[dict] = []
header_row = True
for row in ws.iter_rows(values_only=True):
if header_row:
header_row = False
continue # 跳过表头行
# 跳过全空行
if all(cell is None or str(cell).strip() == "" for cell in row):
continue
row_dict: dict[str, Any] = {"row_index": len(rows_data) + 1}
for i, (_, field_name) in enumerate(columns):
val = row[i] if i < len(row) else None
# 将值转为字符串(保留 None
if val is not None:
row_dict[field_name] = str(val).strip()
else:
row_dict[field_name] = ""
rows_data.append(row_dict)
wb.close()
return rows_data
# ── 人员匹配 ─────────────────────────────────────────────
def match_personnel(
rows: list[dict],
site_id: int,
upload_type: str,
) -> list[VWarning]:
"""
对 salary_adj / recharge_commission 模板执行人员匹配校验。
优先 v_dim_assistantnickname + assistant_number
未匹配再查 v_dim_staff + v_dim_staff_exname + staff_number
匹配成功填充 assistant_id失败标记 warning 不阻断。
"""
if upload_type not in ("salary_adj", "recharge_commission"):
return []
warnings: list[VWarning] = []
# 提取需要匹配的姓名+编号对
if upload_type == "salary_adj":
name_field = "assistant_name"
number_field = "assistant_number"
else:
name_field = "assigned_assistant"
number_field = None # recharge_commission 没有编号字段
# 批量查询 v_dim_assistant
assistant_map: dict[str, int] = {}
staff_map: dict[str, int] = {}
try:
etl_conn = get_etl_readonly_connection(site_id)
try:
with etl_conn.cursor() as cur:
cur.execute(
"SELECT assistant_id, nickname, number FROM fdw_etl.v_dim_assistant WHERE scd2_is_current = 1",
)
for aid, nickname, number in cur.fetchall():
if nickname and number:
assistant_map[f"{nickname}|{number}"] = aid
if nickname:
assistant_map[f"{nickname}|"] = aid
finally:
etl_conn.close()
except Exception:
logger.warning("v_dim_assistant 查询失败site_id=%s", site_id, exc_info=True)
try:
etl_conn = get_etl_readonly_connection(site_id)
try:
with etl_conn.cursor() as cur:
cur.execute(
"SELECT staff_id, name, number FROM fdw_etl.v_dim_staff",
)
for sid, name, number in cur.fetchall():
if name and number:
staff_map[f"{name}|{number}"] = sid
if name:
staff_map[f"{name}|"] = sid
finally:
etl_conn.close()
except Exception:
logger.warning("v_dim_staff 查询失败site_id=%s", site_id, exc_info=True)
for row in rows:
name = str(row.get(name_field, "")).strip()
number = str(row.get(number_field, "")).strip() if number_field else ""
row_idx = row.get("row_index", 0)
# 优先 v_dim_assistant 匹配
key_full = f"{name}|{number}"
key_name = f"{name}|"
matched_id = assistant_map.get(key_full) or assistant_map.get(key_name)
if not matched_id:
matched_id = staff_map.get(key_full) or staff_map.get(key_name)
if matched_id:
row["assistant_id"] = matched_id
else:
row["assistant_id"] = None
warnings.append(VWarning(
row_index=row_idx,
column="助教姓名",
message=f"未匹配到助教/员工:{name}" + (f"(编号 {number}" if number else ""),
))
return warnings
# ── 冲突检测 ──────────────────────────────────────────────
def detect_conflicts(
upload_type: str,
rows: list[dict],
site_id: int,
) -> tuple[list[ConflictDiff], list[dict], list[dict]]:
"""
按模板主键规则检测冲突。
返回 (conflicts, new_rows, conflict_rows_with_existing)。
- conflicts: 冲突 diff 列表
- new_rows: 无冲突的新增行
- conflict_rows_with_existing: 冲突行(附带已有数据用于 confirm 时 UPDATE
"""
keys = CONFLICT_KEYS.get(upload_type, [])
table = TARGET_TABLES.get(upload_type, "")
if not keys or not table:
return [], rows, []
# 查询已有数据
existing_map: dict[tuple, dict] = {}
conn = get_connection()
try:
with conn.cursor() as cur:
key_cols = ", ".join(keys)
# 获取所有字段用于 diff
cur.execute(f"SELECT * FROM {table} WHERE site_id = %s LIMIT 0", (site_id,))
col_names = [desc[0] for desc in cur.description] if cur.description else []
cur.execute(
f"SELECT * FROM {table} WHERE site_id = %s",
(site_id,),
)
for row_data in cur.fetchall():
row_dict = dict(zip(col_names, row_data))
pk = tuple(str(row_dict.get(k, "")).strip() for k in keys)
existing_map[pk] = row_dict
finally:
conn.close()
conflicts: list[ConflictDiff] = []
new_rows: list[dict] = []
conflict_rows: list[dict] = []
# 对 salary_adj 类型,需要将中文类型映射为英文
for row in rows:
pk_values = []
for k in keys:
val = str(row.get(k, "")).strip()
# salary_adj 的 adjustment_type 需要映射
if upload_type == "salary_adj" and k == "adjustment_type":
val = SALARY_ADJ_TYPES.get(val, val)
pk_values.append(val)
pk = tuple(pk_values)
if pk in existing_map:
existing = existing_map[pk]
# 生成逐字段 diff
field_diffs: list[FieldDiff] = []
# 比较可变字段(排除主键和系统字段)
compare_fields = _get_compare_fields(upload_type)
for field_name, display_name in compare_fields:
old_val = str(existing.get(field_name, "")) if existing.get(field_name) is not None else ""
new_val = str(row.get(field_name, "")).strip()
# salary_adj 的 adjustment_type 需要映射
if upload_type == "salary_adj" and field_name == "adjustment_type":
new_val = SALARY_ADJ_TYPES.get(new_val, new_val)
if old_val != new_val:
field_diffs.append(FieldDiff(
field=display_name, old_value=old_val, new_value=new_val,
))
if field_diffs:
conflicts.append(ConflictDiff(
row_index=row.get("row_index", 0),
field_diffs=field_diffs,
))
row["_existing_id"] = existing.get("id")
conflict_rows.append(row)
else:
# 主键匹配但所有字段相同,视为无变化,跳过
conflict_rows.append(row)
row["_existing_id"] = existing.get("id")
else:
new_rows.append(row)
return conflicts, new_rows, conflict_rows
def _get_compare_fields(upload_type: str) -> list[tuple[str, str]]:
"""获取用于 diff 比较的字段列表 [(db_field, display_name)]。"""
if upload_type == "expense":
return [("amount", "金额"), ("remark", "备注")]
elif upload_type == "platform_income":
return [("amount", "金额"), ("remark", "备注")]
elif upload_type == "salary_adj":
return [("amount", "金额"), ("reason", "原因")]
elif upload_type == "recharge_commission":
return [("recharge_amount", "充值金额"), ("reward_amount", "奖励金额")]
return []
# ── POST /api/tenant/excel/upload ─────────────────────────
@router.post("/upload")
async def upload_excel(
file: UploadFile = File(...),
upload_type: str = Form(...),
site_id: int = Form(...),
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""
上传 Excel 文件:解析 → 格式校验 → 人员匹配 → 冲突检测。
返回 upload_id + 校验结果 + 冲突 diff。
"""
# 校验 upload_type
if upload_type not in VALID_UPLOAD_TYPES:
raise HTTPException(status_code=400, detail=f"无效的模板类型,可选值:{', '.join(VALID_UPLOAD_TYPES)}")
# 校验门店权限
verify_site_access(site_id, admin=admin)
# 校验文件格式
filename = file.filename or ""
if not filename.lower().endswith((".xlsx", ".xls")):
raise HTTPException(status_code=400, detail="请上传有效的 Excel 文件(.xlsx/.xls")
# 读取文件内容
file_bytes = await file.read()
if not file_bytes:
raise HTTPException(status_code=400, detail="文件内容为空")
# 解析 Excel
try:
rows = parse_excel(file_bytes, upload_type)
except Exception as e:
logger.warning("Excel 解析失败:%s", e, exc_info=True)
raise HTTPException(status_code=400, detail="Excel 文件解析失败,请检查文件格式")
if not rows:
raise HTTPException(status_code=400, detail="Excel 文件中没有数据行")
# 格式校验
errors, passed_rows = validate_rows(upload_type, rows)
# 如果有格式错误,直接返回(不创建 upload_log
if errors:
return ValidationResult(
errors=errors,
warnings=[],
passed_rows=[],
upload_id=None,
).model_dump(by_alias=True)
# 人员匹配校验(仅 salary_adj / recharge_commission
warnings = match_personnel(passed_rows, site_id, upload_type)
# 冲突检测
conflicts, new_rows, conflict_rows = detect_conflicts(upload_type, passed_rows, site_id)
# 创建 excel_upload_log 记录
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO biz.excel_upload_log
(site_id, upload_type, file_name, uploaded_by, row_count, conflict_count, status)
VALUES (%s, %s, %s, %s, %s, %s, 'pending')
RETURNING id
""",
(site_id, upload_type, filename, admin.admin_id, len(passed_rows), len(conflicts)),
)
upload_id = cur.fetchone()[0]
conn.commit()
except Exception:
conn.rollback()
logger.error("创建 upload_log 失败", exc_info=True)
raise HTTPException(status_code=500, detail="创建上传记录失败")
finally:
conn.close()
# 将通过的行数据临时存储到 upload_log 的 error_detail 字段JSON
# 用于 confirm 时读取(避免二次上传)
_cache_upload_data(upload_id, {
"upload_type": upload_type,
"site_id": site_id,
"new_rows": new_rows,
"conflict_rows": conflict_rows,
})
return {
**ValidationResult(
errors=[],
warnings=warnings,
passed_rows=passed_rows,
upload_id=upload_id,
).model_dump(by_alias=True),
"conflicts": [c.model_dump(by_alias=True) for c in conflicts],
}
def _cache_upload_data(upload_id: int, data: dict):
"""将上传数据缓存到 upload_log.error_detailJSON供 confirm 时使用。"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 序列化时处理 Decimal 等特殊类型
json_str = json.dumps(data, ensure_ascii=False, default=str)
cur.execute(
"UPDATE biz.excel_upload_log SET error_detail = %s::jsonb WHERE id = %s",
(json_str, upload_id),
)
conn.commit()
except Exception:
conn.rollback()
logger.warning("缓存上传数据失败upload_id=%s", upload_id, exc_info=True)
finally:
conn.close()
# ── POST /api/tenant/excel/confirm ────────────────────────
@router.post("/confirm")
async def confirm_upload(
body: ConfirmRequest,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""
确认写入:单事务写入目标表。
替换行执行 UPDATE新增行执行 INSERT。
写入失败回滚整批log status=failed。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 获取 upload_log 记录
cur.execute(
"SELECT site_id, upload_type, status, error_detail FROM biz.excel_upload_log WHERE id = %s",
(body.upload_id,),
)
log_row = cur.fetchone()
if log_row is None:
raise HTTPException(status_code=404, detail="上传记录不存在")
site_id, upload_type, log_status, cached_data = log_row
if log_status != "pending":
raise HTTPException(status_code=409, detail="该上传批次已被处理")
verify_site_access(site_id, admin=admin)
# 从缓存中读取数据
if not cached_data:
raise HTTPException(status_code=400, detail="上传数据已过期,请重新上传")
if isinstance(cached_data, str):
cached_data = json.loads(cached_data)
new_rows = cached_data.get("new_rows", [])
conflict_rows = cached_data.get("conflict_rows", [])
# 构建 resolution 映射
resolution_map: dict[int, str] = {}
for r in body.resolutions:
resolution_map[r.row_index] = r.action
table = TARGET_TABLES[upload_type]
write_fields = TABLE_WRITE_FIELDS[upload_type]
inserted_count = 0
updated_count = 0
resolved_count = 0
# 写入新增行
for row in new_rows:
_insert_row(cur, table, write_fields, row, upload_type, site_id, body.upload_id, admin.admin_id)
inserted_count += 1
# 处理冲突行
for row in conflict_rows:
row_idx = row.get("row_index", 0)
action = resolution_map.get(row_idx, "keep")
existing_id = row.get("_existing_id")
if action == "replace" and existing_id:
_update_row(cur, table, write_fields, row, upload_type, existing_id, site_id, body.upload_id, admin.admin_id)
updated_count += 1
resolved_count += 1
elif action == "keep":
resolved_count += 1
else:
# 无 existing_id 的冲突行按新增处理
_insert_row(cur, table, write_fields, row, upload_type, site_id, body.upload_id, admin.admin_id)
inserted_count += 1
# 更新 upload_log
cur.execute(
"""
UPDATE biz.excel_upload_log
SET status = 'confirmed',
row_count = %s,
resolved_count = %s,
confirmed_at = NOW(),
error_detail = NULL
WHERE id = %s
""",
(inserted_count + updated_count, resolved_count, body.upload_id),
)
conn.commit()
except HTTPException:
conn.rollback()
raise
except Exception as e:
conn.rollback()
# 记录失败状态
_mark_upload_failed(body.upload_id, str(e))
logger.error("写入失败upload_id=%s", body.upload_id, exc_info=True)
raise HTTPException(status_code=500, detail="数据写入失败,已回滚整批")
finally:
conn.close()
return {
"message": "写入成功",
"inserted": inserted_count,
"updated": updated_count,
"resolved": resolved_count,
}
def _insert_row(cur, table: str, fields: list[str], row: dict, upload_type: str, site_id: int, upload_id: int, admin_id: int):
"""插入一行数据到目标表。"""
values = _build_row_values(fields, row, upload_type, site_id, upload_id, admin_id)
placeholders = ", ".join(["%s"] * len(fields))
cols = ", ".join(fields)
cur.execute(f"INSERT INTO {table} ({cols}) VALUES ({placeholders})", tuple(values))
def _update_row(cur, table: str, fields: list[str], row: dict, upload_type: str, existing_id: int, site_id: int, upload_id: int, admin_id: int):
"""更新已有行。"""
values = _build_row_values(fields, row, upload_type, site_id, upload_id, admin_id)
set_parts = [f"{f} = %s" for f in fields]
cur.execute(
f"UPDATE {table} SET {', '.join(set_parts)} WHERE id = %s",
(*values, existing_id),
)
def _build_row_values(fields: list[str], row: dict, upload_type: str, site_id: int, upload_id: int, admin_id: int) -> list:
"""根据字段列表构建值列表。"""
values = []
for f in fields:
if f == "site_id":
values.append(site_id)
elif f == "upload_batch_id":
values.append(upload_id)
elif f == "created_at":
values.append(datetime.now(timezone.utc))
elif f == "created_by":
values.append(admin_id)
elif f == "adjustment_type":
# 中文 → 英文映射
raw = str(row.get(f, "")).strip()
values.append(SALARY_ADJ_TYPES.get(raw, raw))
elif f in ("amount", "recharge_amount", "reward_amount"):
try:
values.append(float(row.get(f, 0)))
except (ValueError, TypeError):
values.append(0.0)
elif f == "assistant_id":
values.append(row.get("assistant_id"))
else:
values.append(str(row.get(f, "")).strip() if row.get(f) is not None else None)
return values
def _mark_upload_failed(upload_id: int, error_msg: str):
"""标记上传批次为失败状态。"""
try:
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.excel_upload_log
SET status = 'failed',
error_detail = %s::jsonb
WHERE id = %s
""",
(json.dumps({"error": error_msg}, ensure_ascii=False), upload_id),
)
conn.commit()
finally:
conn.close()
except Exception:
logger.warning("标记上传失败状态失败upload_id=%s", upload_id, exc_info=True)
# ── GET /api/tenant/excel/logs ────────────────────────────
@router.get("/logs")
async def list_upload_logs(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页条数"),
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""上传记录列表,分页,附加 site_id IN 条件。"""
site_sql, site_params = site_filter_clause(admin=admin)
offset = (page - 1) * page_size
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"SELECT COUNT(*) FROM biz.excel_upload_log WHERE {site_sql}",
site_params,
)
total = cur.fetchone()[0]
cur.execute(
f"""
SELECT id, site_id, upload_type, file_name, uploaded_by,
row_count, conflict_count, resolved_count, status,
created_at::text, confirmed_at::text
FROM biz.excel_upload_log
WHERE {site_sql}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(*site_params, page_size, offset),
)
rows = cur.fetchall()
finally:
conn.close()
items = [
UploadLogItem(
id=r[0], site_id=r[1], upload_type=r[2], file_name=r[3],
uploaded_by=r[4], row_count=r[5], conflict_count=r[6],
resolved_count=r[7], status=r[8], created_at=r[9], confirmed_at=r[10],
).model_dump(by_alias=True)
for r in rows
]
return {"items": items, "total": total, "page": page, "pageSize": page_size}
# ── GET /api/tenant/excel/template/{type} ─────────────────
@router.get("/template/{template_type}")
async def download_template(template_type: str):
"""返回空白 Excel 模板文件(含表头和格式说明)。"""
if template_type not in VALID_UPLOAD_TYPES:
raise HTTPException(status_code=400, detail=f"无效的模板类型,可选值:{', '.join(VALID_UPLOAD_TYPES)}")
import openpyxl
from openpyxl.styles import Font, PatternFill, Alignment
wb = openpyxl.Workbook()
ws = wb.active
ws.title = "数据模板"
columns = TEMPLATE_COLUMNS[template_type]
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
# 写入表头
for col_idx, (header_name, _) in enumerate(columns, 1):
cell = ws.cell(row=1, column=col_idx, value=header_name)
cell.font = header_font
cell.fill = header_fill
cell.alignment = Alignment(horizontal="center")
ws.column_dimensions[cell.column_letter].width = 18
# 写入格式说明行(第 2 行,灰色字体)
hint_font = Font(color="808080", italic=True)
hints = _get_template_hints(template_type)
for col_idx, hint in enumerate(hints, 1):
cell = ws.cell(row=2, column=col_idx, value=hint)
cell.font = hint_font
# 输出为字节流
output = io.BytesIO()
wb.save(output)
output.seek(0)
from urllib.parse import quote as _url_quote
filename_map = {
"expense": "财务支出模板.xlsx",
"platform_income": "团购收入模板.xlsx",
"salary_adj": "助教奖罚模板.xlsx",
"recharge_commission": "充值业绩归属模板.xlsx",
}
raw_name = filename_map.get(template_type, "template.xlsx")
encoded_name = _url_quote(raw_name)
return StreamingResponse(
output,
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_name}",
},
)
def _get_template_hints(template_type: str) -> list[str]:
"""获取模板格式说明。"""
if template_type == "expense":
return [
"格式YYYY-MM",
f"可选值:{''.join(EXPENSE_CATEGORIES)}",
"大于0保留2位小数",
"可选最长500字符",
]
elif template_type == "platform_income":
return [
"格式YYYY-MM",
"必填",
"大于0保留2位小数",
"可选最长500字符",
]
elif template_type == "salary_adj":
return [
"格式YYYY-MM",
"必填",
"必填",
"可选值:扣款、奖金",
"大于0保留2位小数",
"必填最长200字符",
]
elif template_type == "recharge_commission":
return [
"格式YYYY-MM-DD",
"必填",
"大于0保留2位小数",
"必填",
"≥0保留2位小数",
]
return []

View File

@@ -0,0 +1,354 @@
# -*- coding: utf-8 -*-
"""
租户管理后台 — 店铺管理员 CRUD 路由。
仅 admin_type='tenant_admin' 的管理员可调用。
店铺管理员复用 auth.tenant_admins 表admin_type='site_admin'
端点清单:
- GET /api/tenant/site-admins — 店铺管理员列表
- POST /api/tenant/site-admins — 创建店铺管理员
- PATCH /api/tenant/site-admins/{id} — 编辑店铺管理员
- DELETE /api/tenant/site-admins/{id} — 软删除店铺管理员
- POST /api/tenant/site-admins/{id}/reset-password — 重置密码
AI_CHANGELOG
- 2026-03-23 21:00:00 | Prompt: P20260323-210000根治 tenant_admin managed_site_ids 限制)| Direct causeJWT managed_site_ids 静态签发,新建店铺后所有端点受限 | Summarycreate_site_admin 和 edit_site_admin 的权限子集校验改用 get_effective_site_ids(admin)(覆盖新建店铺)| Verify创建/编辑店铺管理员时可选新建店铺
"""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from psycopg2 import errors as pg_errors
from pydantic import Field
from app.auth.jwt import hash_password
from app.auth.tenant_admins import CurrentTenantAdmin, get_effective_site_ids, require_tenant_admin
from app.database import get_connection
from app.schemas.base import CamelModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/tenant", tags=["租户店铺管理员"])
# ── 权限守卫:仅 tenant_admin 可调用 ─────────────────────
def _require_tenant_admin_type(admin: CurrentTenantAdmin) -> CurrentTenantAdmin:
"""校验当前登录者为租户管理员(非店铺管理员)。"""
if admin.admin_type != "tenant_admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="仅租户管理员可执行此操作",
)
return admin
# ── Schema ────────────────────────────────────────────────
class SiteAdminCreateRequest(CamelModel):
"""创建店铺管理员请求。"""
username: str = Field(..., max_length=56, description="用户名site_code 前缀 + 最长 50 字符)")
password: str = Field(..., min_length=6, description="初始密码")
display_name: str | None = Field(None, max_length=100, description="显示名称")
managed_site_ids: list[int] = Field(..., min_length=1, description="管辖门店 ID 列表")
class SiteAdminEditRequest(CamelModel):
"""编辑店铺管理员请求。"""
display_name: str | None = Field(None, max_length=100)
managed_site_ids: list[int] | None = Field(None, min_length=1)
is_active: bool | None = None
class SiteAdminResetPasswordRequest(CamelModel):
"""重置密码请求。"""
new_password: str = Field(..., min_length=6)
# ── GET /api/tenant/site-admins ───────────────────────────
@router.get("/site-admins")
async def list_site_admins(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
keyword: Optional[str] = Query(None, description="搜索用户名/显示名称"),
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""列出当前租户下的店铺管理员。"""
_require_tenant_admin_type(admin)
offset = (page - 1) * page_size
conn = get_connection()
try:
with conn.cursor() as cur:
where_parts = [
"ta.tenant_id = %s",
"ta.admin_type = 'site_admin'",
"ta.deleted_at IS NULL",
]
params: list = [admin.tenant_id]
if keyword:
where_parts.append("(ta.username ILIKE %s OR ta.display_name ILIKE %s)")
like_val = f"%{keyword}%"
params.extend([like_val, like_val])
where_sql = " AND ".join(where_parts)
cur.execute(f"SELECT COUNT(*) FROM auth.tenant_admins ta WHERE {where_sql}", params)
total = cur.fetchone()[0]
cur.execute(
f"""
SELECT ta.id, ta.username, ta.display_name, ta.managed_site_ids,
ta.is_active, ta.created_at, ta.last_login_at
FROM auth.tenant_admins ta
WHERE {where_sql}
ORDER BY ta.created_at DESC
LIMIT %s OFFSET %s
""",
params + [page_size, offset],
)
rows = cur.fetchall()
finally:
conn.close()
items = [
{
"id": r[0], "username": r[1], "displayName": r[2],
"managedSiteIds": list(r[3]) if r[3] else [],
"isActive": r[4],
"createdAt": r[5].isoformat() if r[5] else None,
"lastLoginAt": r[6].isoformat() if r[6] else None,
}
for r in rows
]
return {"items": items, "total": total, "page": page, "pageSize": page_size}
# ── POST /api/tenant/site-admins ──────────────────────────
@router.post("/site-admins", status_code=status.HTTP_201_CREATED)
async def create_site_admin(
body: SiteAdminCreateRequest,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""
创建店铺管理员。
用户名校验:必须以管辖店铺的 site_code 开头。
managed_site_ids 必须是当前租户管理员管辖范围的子集。
"""
_require_tenant_admin_type(admin)
# 校验 managed_site_ids 是当前管理员有效管辖范围的子集
# [CHANGE P20260323-210000] intent: 使用 get_effective_site_ids 替代 JWT managed_site_ids
# tenant_admin 按 tenant_id 查库获取有效范围(覆盖新建店铺)
effective_ids = get_effective_site_ids(admin)
if not set(body.managed_site_ids).issubset(set(effective_ids)):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="所选门店超出您的管辖范围",
)
# 校验用户名以 site_code 开头
conn = get_connection()
try:
with conn.cursor() as cur:
# 查询第一个管辖店铺的 site_code
cur.execute(
"SELECT site_code FROM biz.sites WHERE site_id = %s AND is_active = true",
(body.managed_site_ids[0],),
)
site_row = cur.fetchone()
if site_row is None or site_row[0] is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="第一个管辖店铺未设置简写ID",
)
expected_prefix = site_row[0].upper()
if not body.username.upper().startswith(expected_prefix):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"用户名必须以简写ID '{expected_prefix}' 开头",
)
# 插入记录
password_hash = hash_password(body.password)
cur.execute(
"""
INSERT INTO auth.tenant_admins
(username, password_hash, display_name, tenant_id,
managed_site_ids, admin_type, created_by)
VALUES (LOWER(%s), %s, %s, %s, %s, 'site_admin', %s)
RETURNING id, created_at
""",
(
body.username,
password_hash,
body.display_name,
admin.tenant_id,
body.managed_site_ids,
admin.admin_id,
),
)
row = cur.fetchone()
conn.commit()
except HTTPException:
raise
except pg_errors.UniqueViolation:
conn.rollback()
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="用户名已存在")
except Exception:
conn.rollback()
logger.error("创建店铺管理员失败", exc_info=True)
raise HTTPException(status_code=500, detail="创建失败")
finally:
conn.close()
return {"id": row[0], "createdAt": row[1].isoformat() if row[1] else None}
# ── PATCH /api/tenant/site-admins/{id} ────────────────────
@router.patch("/site-admins/{admin_id}")
async def edit_site_admin(
admin_id: int,
body: SiteAdminEditRequest,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""编辑店铺管理员(显示名称/管辖门店/启用状态)。"""
_require_tenant_admin_type(admin)
if body.managed_site_ids is not None:
effective_ids = get_effective_site_ids(admin)
if not set(body.managed_site_ids).issubset(set(effective_ids)):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="所选门店超出您的管辖范围")
set_clauses: list[str] = []
params: list = []
if body.display_name is not None:
set_clauses.append("display_name = %s")
params.append(body.display_name)
if body.managed_site_ids is not None:
set_clauses.append("managed_site_ids = %s")
params.append(body.managed_site_ids)
if body.is_active is not None:
set_clauses.append("is_active = %s")
params.append(body.is_active)
if not set_clauses:
raise HTTPException(status_code=422, detail="至少需要提供一个修改字段")
params.extend([admin_id, admin.tenant_id])
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
UPDATE auth.tenant_admins
SET {', '.join(set_clauses)}
WHERE id = %s AND tenant_id = %s
AND admin_type = 'site_admin' AND deleted_at IS NULL
RETURNING id
""",
params,
)
if cur.fetchone() is None:
raise HTTPException(status_code=404, detail="店铺管理员不存在")
conn.commit()
except HTTPException:
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": admin_id}
# ── DELETE /api/tenant/site-admins/{id} ───────────────────
@router.delete("/site-admins/{admin_id}")
async def delete_site_admin(
admin_id: int,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""软删除店铺管理员。"""
_require_tenant_admin_type(admin)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE auth.tenant_admins SET deleted_at = NOW()
WHERE id = %s AND tenant_id = %s
AND admin_type = 'site_admin' AND deleted_at IS NULL
RETURNING id
""",
(admin_id, admin.tenant_id),
)
if cur.fetchone() is None:
raise HTTPException(status_code=404, detail="店铺管理员不存在")
conn.commit()
except HTTPException:
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": admin_id}
# ── POST /api/tenant/site-admins/{id}/reset-password ──────
@router.post("/site-admins/{admin_id}/reset-password")
async def reset_site_admin_password(
admin_id: int,
body: SiteAdminResetPasswordRequest,
admin: CurrentTenantAdmin = Depends(require_tenant_admin),
):
"""重置店铺管理员密码。"""
_require_tenant_admin_type(admin)
new_hash = hash_password(body.new_password)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE auth.tenant_admins SET password_hash = %s
WHERE id = %s AND tenant_id = %s
AND admin_type = 'site_admin' AND deleted_at IS NULL
RETURNING id
""",
(new_hash, admin_id, admin.tenant_id),
)
if cur.fetchone() is None:
raise HTTPException(status_code=404, detail="店铺管理员不存在")
conn.commit()
except HTTPException:
raise
except Exception:
conn.rollback()
raise
finally:
conn.close()
return {"id": admin_id}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""定时任务管理 API
提供 3 个端点:
- GET /api/trigger-jobs — 列出所有定时任务
- POST /api/trigger-jobs/{id}/run — 手动执行指定任务
- PATCH /api/trigger-jobs/{id}/config — 编辑触发器配置
所有端点需要 JWT 认证(系统管理后台使用)。
"""
from __future__ import annotations
import json
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from app.auth.dependencies import CurrentUser, get_current_user
from app.database import get_connection
from app.schemas.trigger_jobs import TriggerJobItem, RunJobResult, UpdateTriggerConfigRequest
from app.services.trigger_scheduler import list_trigger_jobs, run_job_by_id, _calculate_next_run
from app.utils.cron_validator import validate_cron_expression
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/trigger-jobs", tags=["定时任务"])
@router.get("", response_model=list[TriggerJobItem])
async def get_trigger_jobs(
user: CurrentUser = Depends(get_current_user),
) -> list[TriggerJobItem]:
"""返回所有定时任务列表。"""
try:
jobs = list_trigger_jobs()
return [TriggerJobItem(**j) for j in jobs]
except Exception as exc:
logger.exception("获取定时任务列表失败")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取定时任务列表失败: {str(exc)[:200]}",
)
@router.post("/{job_id}/run", response_model=RunJobResult)
async def run_trigger_job(
job_id: int,
user: CurrentUser = Depends(get_current_user),
) -> RunJobResult:
"""手动执行指定定时任务。"""
try:
result = run_job_by_id(job_id)
return RunJobResult(**result)
except Exception as exc:
logger.exception("手动执行任务 %s 失败", job_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"执行失败: {str(exc)[:200]}",
)
@router.patch("/{job_id}/config", response_model=TriggerJobItem)
async def update_trigger_config(
job_id: int,
body: UpdateTriggerConfigRequest,
user: CurrentUser = Depends(get_current_user),
) -> TriggerJobItem:
"""编辑触发器的 cron_expression 或 interval_seconds。
仅 merge 请求中非 None 的字段到 trigger_config JSONB
不覆盖其他已有字段。更新后重新计算 next_run_at。
"""
# --- 校验 cron_expression 格式 ---
if body.cron_expression is not None and not validate_cron_expression(body.cron_expression):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="cron 表达式格式无效,需要 5 字段格式",
)
# --- 校验 interval_seconds >= 1 ---
if body.interval_seconds is not None and body.interval_seconds < 1:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="interval_seconds 必须 >= 1",
)
conn = get_connection()
try:
with conn.cursor() as cur:
# 查询 trigger_job 是否存在,同时获取当前 trigger_condition 和 trigger_config
cur.execute(
"SELECT trigger_condition, trigger_config FROM biz.trigger_jobs WHERE id = %s",
(job_id,),
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"任务 {job_id} 不存在",
)
trigger_condition, current_config = row
current_config = current_config or {}
# 构建 config_updates仅包含非 None 字段
config_updates: dict = {}
if body.cron_expression is not None:
config_updates["cron_expression"] = body.cron_expression
if body.interval_seconds is not None:
config_updates["interval_seconds"] = body.interval_seconds
# 合并后的 trigger_config 用于计算 next_run_at
merged_config = {**current_config, **config_updates}
next_run_at = _calculate_next_run(trigger_condition, merged_config)
# 使用 || 操作符 merge JSONB避免覆盖其他字段
cur.execute(
"""
UPDATE biz.trigger_jobs
SET trigger_config = trigger_config || %s::jsonb,
next_run_at = %s
WHERE id = %s
RETURNING id, job_type, job_name, trigger_condition, trigger_config,
last_run_at, next_run_at, status, description, last_error, created_at
""",
(json.dumps(config_updates), next_run_at, job_id),
)
updated = cur.fetchone()
conn.commit()
col_names = [
"id", "job_type", "job_name", "trigger_condition", "trigger_config",
"last_run_at", "next_run_at", "status", "description", "last_error", "created_at",
]
result = dict(zip(col_names, updated))
# 日期时间字段转字符串psycopg2 返回 datetime 对象)
for dt_field in ("last_run_at", "next_run_at", "created_at"):
if result[dt_field] is not None:
result[dt_field] = str(result[dt_field])
return TriggerJobItem(**result)
except HTTPException:
raise
except Exception as exc:
conn.rollback()
logger.exception("更新任务 %s 触发器配置失败", job_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新触发器配置失败: {str(exc)[:200]}",
)
finally:
conn.close()

View File

@@ -15,6 +15,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.ai.cache_service import AICacheService
from app.ai.schemas import CacheTypeEnum
from app.auth.dependencies import CurrentUser, get_current_user
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -22,6 +23,7 @@ router = APIRouter(prefix="/api/ai", tags=["小程序 AI 缓存"])
@router.get("/cache/{cache_type}")
@trace_service("查询 AI 缓存", "Get AI cache")
async def get_ai_cache(
cache_type: str,
target_id: str = Query(..., description="目标 IDmember_id / assistant_id_member_id / 时间维度编码)"),

View File

@@ -1,3 +1,9 @@
# AI_CHANGELOG
# | 日期 | Prompt | 变更 |
# |------|--------|------|
# | 2026-03-23 | P20260323-190012 禁用→移除+鉴权两层模型 | login/refresh 移除 disabled 403 拦截disabled 签发受限令牌由前端路由cancel-application 接口;角色列表更新 |
# | 2026-03-23 | 角色路由+页面权限守卫 | /api/xcx/me、/api/xcx/login、/api/xcx/dev-login 返回用户角色 |
# -*- coding: utf-8 -*-
"""
小程序认证路由 —— 微信登录、申请提交、状态查询、店铺切换、令牌刷新。
@@ -37,17 +43,20 @@ from app.auth.jwt import (
from app import config
from app.database import get_connection
from app.services.application import (
cancel_application,
create_application,
get_user_applications,
)
from app.schemas.xcx_auth import (
ApplicationRequest,
ApplicationResponse,
CancelApplicationResponse,
DevLoginRequest,
DevSwitchBindingRequest,
DevSwitchRoleRequest,
DevSwitchStatusRequest,
DevContextResponse,
LatestApplicationDetail,
RefreshTokenRequest,
SiteInfo,
SwitchSiteRequest,
@@ -57,6 +66,7 @@ from app.schemas.xcx_auth import (
)
from app.services.wechat import WeChatAuthError, code2session
from app.services.role import get_user_permissions
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -74,6 +84,7 @@ def _get_user_roles_at_site(conn, user_id: int, site_id: int) -> list[str]:
FROM auth.user_site_roles usr
JOIN auth.roles r ON usr.role_id = r.id
WHERE usr.user_id = %s AND usr.site_id = %s
AND usr.is_removed = false
""",
(user_id, site_id),
)
@@ -88,6 +99,7 @@ def _get_user_default_site(conn, user_id: int) -> int | None:
SELECT DISTINCT site_id
FROM auth.user_site_roles
WHERE user_id = %s
AND is_removed = false
ORDER BY site_id
LIMIT 1
""",
@@ -100,12 +112,13 @@ def _get_user_default_site(conn, user_id: int) -> int | None:
# ── POST /api/xcx/login ──────────────────────────────────
@router.post("/login", response_model=WxLoginResponse)
@trace_service("微信登录", "WeChat login")
async def wx_login(body: WxLoginRequest):
"""
微信登录。
流程code → code2session(openid) → 查找/创建 auth.users → 签发 JWT。
- disabled 用户返回 403
- disabled 用户签发受限令牌,由前端状态路由处理
- 新用户自动创建status=new前端引导至申请页
- approved 用户签发包含 site_id + roles 的完整令牌
- new/pending/rejected 用户签发受限令牌
@@ -157,23 +170,38 @@ async def wx_login(body: WxLoginRequest):
(openid,),
)
row = cur.fetchone()
else:
# CHANGE 2026-03-22 | #8: 已有用户登录时更新 wx_union_id幂等保护
# intent: unionid 可能在首次登录时为空(未绑定开放平台),后续登录补全
if unionid:
cur.execute(
"""
UPDATE auth.users
SET wx_union_id = %s
WHERE id = %s
AND (wx_union_id IS NULL OR wx_union_id <> %s)
""",
(unionid, row[0], unionid),
)
if cur.rowcount > 0:
conn.commit()
user_id, user_status = row
# 3. disabled 用户拒绝登录
if user_status == "disabled":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账号已被禁用",
)
# CHANGE 2026-03-23 | disabled 不再拒绝登录
# 第一层微信身份始终有效disabled 只影响第二层(业务状态路由)
# disabled/new/pending/rejected 统一签发受限令牌,由前端状态路由处理
# 4. 签发令牌
# CHANGE 2026-03-23 | 角色路由:登录时查询角色并返回
login_role: str | None = None
if user_status == "approved":
# 查找默认 site_id 和角色
default_site_id = _get_user_default_site(conn, user_id)
if default_site_id is not None:
roles = _get_user_roles_at_site(conn, user_id, default_site_id)
tokens = create_token_pair(user_id, default_site_id, roles=roles)
login_role = roles[0] if roles else None
else:
# approved 但无 site 绑定(异常边界),签发受限令牌
tokens = create_limited_token_pair(user_id)
@@ -190,12 +218,14 @@ async def wx_login(body: WxLoginRequest):
token_type=tokens["token_type"],
user_status=user_status,
user_id=user_id,
role=login_role,
)
# ── POST /api/xcx/apply ──────────────────────────────────
@router.post("/apply", response_model=ApplicationResponse)
@trace_service("提交入驻申请", "Submit application")
async def submit_application(
body: ApplicationRequest,
user: CurrentUser = Depends(get_current_user_or_limited),
@@ -217,9 +247,27 @@ async def submit_application(
return ApplicationResponse(**result)
# ── POST /api/xcx/cancel-application ─────────────────────
@router.post("/cancel-application", response_model=CancelApplicationResponse)
@trace_service("取消申请", "Cancel application")
async def cancel_my_application(
user: CurrentUser = Depends(get_current_user_or_limited),
):
"""
用户主动取消当前 pending 申请。
将申请 status 改为 cancelled用户 status 回退 new。
返回被取消申请的信息(用于前端预填重新申请表单)。
"""
result = await cancel_application(user_id=user.user_id)
return CancelApplicationResponse(**result)
# ── GET /api/xcx/me ───────────────────────────────────────
@router.get("/me", response_model=UserStatusResponse)
@trace_service("查询自身状态", "Get my status")
async def get_my_status(
user: CurrentUser = Depends(get_current_user_or_limited),
):
@@ -232,8 +280,9 @@ async def get_my_status(
try:
with conn.cursor() as cur:
# 查询用户基本信息
# CHANGE 2026-03-24 | 头像:新增 avatar_url 字段查询
cur.execute(
"SELECT id, status, nickname FROM auth.users WHERE id = %s",
"SELECT id, status, nickname, avatar_url FROM auth.users WHERE id = %s",
(user.user_id,),
)
user_row = cur.fetchone()
@@ -243,25 +292,110 @@ async def get_my_status(
detail="用户不存在",
)
user_id, user_status, nickname = user_row
user_id, user_status, nickname, avatar_url = user_row
# CHANGE 2026-03-23 | 角色路由approved 用户查询当前门店角色
role: str | None = None
store_name: str | None = None
coach_level: str | None = None
if user_status == "approved":
site_id = getattr(user, "site_id", None)
# CHANGE 2026-03-24 | 受限 token 兼容token 无 site_id 时从数据库查默认 site
# 场景:用户从 pending→approved旧的受限 token 不含 site_id
if not site_id:
site_id = _get_user_default_site(conn, user_id)
if site_id:
roles = _get_user_roles_at_site(conn, user_id, site_id)
# 用户在一个门店下仅一个角色
role = roles[0] if roles else None
# CHANGE 2026-03-23 | banner 数据修复:查询门店名
cur.execute(
"SELECT site_name FROM biz.sites WHERE site_id = %s",
(site_id,),
)
sn_row = cur.fetchone()
store_name = sn_row[0] if sn_row else None
# CHANGE 2026-03-23 | banner 数据修复查询助教等级coach_level
cur.execute(
"""
SELECT assistant_id
FROM auth.user_assistant_binding
WHERE user_id = %s AND site_id = %s AND assistant_id IS NOT NULL
AND is_removed = false
LIMIT 1
""",
(user_id, site_id),
)
bind_row = cur.fetchone()
if bind_row:
try:
from datetime import datetime as _dt
from app.services import fdw_queries
_now = _dt.now()
# CHANGE 2026-03-24 | coach_level 回退链salary_calc → monthly_summary
# salary_calc 月初结算前可能无数据monthly_summary 每日更新更可靠
salary = fdw_queries.get_salary_calc(
conn, site_id, bind_row[0], _now.year, _now.month,
)
if salary:
coach_level = salary.get("coach_level") or None
if not coach_level:
ms = fdw_queries.get_monthly_summary(
conn, site_id, bind_row[0], _now.year, _now.month,
)
if ms:
coach_level = ms.get("coach_level") or None
except Exception:
pass # 优雅降级FDW 查询失败不影响主流程
finally:
conn.close()
# 委托 service 查询申请列表
# CHANGE 2026-03-27 | 权限改造 W2查询权限码列表
# get_user_permissions 内部自行获取连接,无需外部 conn
permissions: list[str] = []
if user_status == "approved" and role:
_perm_site_id = getattr(user, "site_id", None) or site_id
if _perm_site_id:
permissions = await get_user_permissions(user_id, _perm_site_id)
# 委托 service 查询申请列表(排除 cancelled
app_list = await get_user_applications(user_id)
applications = [ApplicationResponse(**a) for a in app_list]
applications = [ApplicationResponse(**a) for a in app_list if a["status"] != "cancelled"]
# 最新申请(含 phone/employee_number用于前端展示和预填
latest = None
if app_list:
la = app_list[0] # 已按 created_at DESC 排序
latest = LatestApplicationDetail(
id=la["id"],
site_code=la["site_code"],
applied_role_text=la["applied_role_text"],
phone=la.get("phone", ""),
employee_number=la.get("employee_number"),
status=la["status"],
review_note=la.get("review_note"),
created_at=la["created_at"],
reviewed_at=la.get("reviewed_at"),
)
return UserStatusResponse(
user_id=user_id,
status=user_status,
nickname=nickname,
avatar_url=avatar_url,
role=role,
permissions=permissions,
store_name=store_name,
coach_level=coach_level,
applications=applications,
latest_application=latest,
)
# ── GET /api/xcx/me/sites ────────────────────────────────
@router.get("/me/sites", response_model=list[SiteInfo])
@trace_service("查询关联店铺", "Get my sites")
async def get_my_sites(
user: CurrentUser = Depends(get_current_user),
):
@@ -281,8 +415,9 @@ async def get_my_sites(
r.name AS role_name
FROM auth.user_site_roles usr
JOIN auth.roles r ON usr.role_id = r.id
LEFT JOIN auth.site_code_mapping scm ON usr.site_id = scm.site_id
LEFT JOIN biz.sites scm ON scm.site_id = usr.site_id
WHERE usr.user_id = %s
AND usr.is_removed = false
ORDER BY usr.site_id, r.code
""",
(user.user_id,),
@@ -306,6 +441,7 @@ async def get_my_sites(
# ── POST /api/xcx/switch-site ────────────────────────────
@router.post("/switch-site", response_model=WxLoginResponse)
@trace_service("切换当前店铺", "Switch site")
async def switch_site(
body: SwitchSiteRequest,
user: CurrentUser = Depends(get_current_user),
@@ -323,6 +459,7 @@ async def switch_site(
"""
SELECT 1 FROM auth.user_site_roles
WHERE user_id = %s AND site_id = %s
AND is_removed = false
LIMIT 1
""",
(user.user_id, body.site_id),
@@ -360,13 +497,14 @@ async def switch_site(
# ── POST /api/xcx/refresh ────────────────────────────────
@router.post("/refresh", response_model=WxLoginResponse)
@trace_service("刷新令牌", "Refresh token")
async def refresh_token(body: RefreshTokenRequest):
"""
刷新令牌。
解码 refresh_token → 根据用户当前状态签发新的令牌对。
- 受限 refresh_tokenlimited=True→ 签发新的受限令牌对
- 完整 refresh_token → 签发新的完整令牌对(保持原 site_id
解码 refresh_token → 根据用户当前数据库状态签发新的令牌对。
- approved 用户 → 签发完整令牌(即使旧 token 是受限的,也自动升级)
- 其他状态 → 签发受限令牌
"""
try:
payload = decode_refresh_token(body.refresh_token)
@@ -396,26 +534,28 @@ async def refresh_token(body: RefreshTokenRequest):
_, user_status = user_row
if user_status == "disabled":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账号已被禁用",
)
if is_limited or user_status != "approved":
# 受限令牌刷新 → 仍签发受限令牌
tokens = create_limited_token_pair(user_id)
else:
# 完整令牌刷新 → 使用原 site_id 签发
site_id = payload.get("site_id")
if site_id is None:
# 回退到默认 site
# CHANGE 2026-03-23 | 令牌升级:根据数据库当前状态决定签发类型
# 旧的受限 token 不应锁死用户——审核通过后 refresh 应自动升级为完整 token
if user_status == "approved":
# approved 用户:签发完整令牌(无论旧 token 是否 limited
if is_limited:
# 受限 token 升级:查默认 site
site_id = _get_user_default_site(conn, user_id)
else:
# 完整 token 刷新:优先保持原 site_id
site_id = payload.get("site_id")
if site_id is None:
site_id = _get_user_default_site(conn, user_id)
if site_id is not None:
roles = _get_user_roles_at_site(conn, user_id, site_id)
tokens = create_token_pair(user_id, site_id, roles=roles)
else:
# approved 但无 site 绑定(异常边界)
tokens = create_limited_token_pair(user_id)
else:
# new / pending / rejected / disabled → 受限令牌
tokens = create_limited_token_pair(user_id)
finally:
conn.close()
@@ -433,6 +573,7 @@ async def refresh_token(body: RefreshTokenRequest):
if config.WX_DEV_MODE:
@router.post("/dev-login", response_model=WxLoginResponse)
@trace_service("开发模式登录", "Dev mode login")
async def dev_login(body: DevLoginRequest):
"""
开发模式 mock 登录。
@@ -482,11 +623,14 @@ if config.WX_DEV_MODE:
user_id, user_status = row
# 签发令牌(逻辑与正常登录一致)
# CHANGE 2026-03-23 | 角色路由dev-login 也返回角色
dev_login_role: str | None = None
if user_status == "approved":
default_site_id = _get_user_default_site(conn, user_id)
if default_site_id is not None:
roles = _get_user_roles_at_site(conn, user_id, default_site_id)
tokens = create_token_pair(user_id, default_site_id, roles=roles)
dev_login_role = roles[0] if roles else None
else:
tokens = create_limited_token_pair(user_id)
else:
@@ -501,11 +645,13 @@ if config.WX_DEV_MODE:
token_type=tokens["token_type"],
user_status=user_status,
user_id=user_id,
role=dev_login_role,
)
# ── GET /api/xcx/dev-context仅开发模式 ────────────────
@router.get("/dev-context", response_model=DevContextResponse)
@trace_service("查询调试上下文", "Get dev context")
async def dev_context(
user: CurrentUser = Depends(get_current_user_or_limited),
):
@@ -532,7 +678,7 @@ if config.WX_DEV_MODE:
site_name = None
if user.site_id:
cur.execute(
"SELECT site_name FROM auth.site_code_mapping WHERE site_id = %s",
"SELECT site_name FROM biz.sites WHERE site_id = %s",
(user.site_id,),
)
sn_row = cur.fetchone()
@@ -552,6 +698,7 @@ if config.WX_DEV_MODE:
SELECT assistant_id, staff_id, binding_type
FROM auth.user_assistant_binding
WHERE user_id = %s AND site_id = %s
AND is_removed = false
LIMIT 1
""",
(user.user_id, user.site_id),
@@ -572,8 +719,9 @@ if config.WX_DEV_MODE:
r.code, r.name
FROM auth.user_site_roles usr
JOIN auth.roles r ON usr.role_id = r.id
LEFT JOIN auth.site_code_mapping scm ON usr.site_id = scm.site_id
LEFT JOIN biz.sites scm ON scm.site_id = usr.site_id
WHERE usr.user_id = %s
AND usr.is_removed = false
ORDER BY usr.site_id, r.code
""",
(user.user_id,),
@@ -604,6 +752,7 @@ if config.WX_DEV_MODE:
# ── POST /api/xcx/dev-switch-role仅开发模式 ───────────
@router.post("/dev-switch-role", response_model=WxLoginResponse)
@trace_service("切换角色", "Dev switch role")
async def dev_switch_role(
body: DevSwitchRoleRequest,
user: CurrentUser = Depends(get_current_user),
@@ -613,7 +762,8 @@ if config.WX_DEV_MODE:
删除旧角色绑定,插入新角色绑定,重签 token。
"""
valid_roles = ("coach", "staff", "site_admin", "tenant_admin")
# CHANGE 2026-03-23 | 角色体系隔离:小程序端只有 4 个角色site_admin/tenant_admin 已移至租户管理后台
valid_roles = ("coach", "staff", "head_coach", "manager")
if body.role_code not in valid_roles:
raise HTTPException(
status_code=400,
@@ -669,6 +819,7 @@ if config.WX_DEV_MODE:
# ── POST /api/xcx/dev-switch-status仅开发模式 ─────────
@router.post("/dev-switch-status", response_model=WxLoginResponse)
@trace_service("切换用户状态", "Dev switch status")
async def dev_switch_status(
body: DevSwitchStatusRequest,
user: CurrentUser = Depends(get_current_user_or_limited),
@@ -718,6 +869,7 @@ if config.WX_DEV_MODE:
# ── POST /api/xcx/dev-switch-binding仅开发模式 ────────
@router.post("/dev-switch-binding")
@trace_service("切换人员绑定", "Dev switch binding")
async def dev_switch_binding(
body: DevSwitchBindingRequest,
user: CurrentUser = Depends(get_current_user),

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
"""
小程序头像上传路由。
端点清单:
- POST /api/xcx/avatar/upload — 上传头像chooseAvatar 临时文件 → 服务器持久化)
- GET /api/xcx/avatar/{user_id} — 获取头像文件(静态文件服务)
"""
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
from fastapi.responses import FileResponse
from app import config
from app.auth.dependencies import CurrentUser, get_current_user_or_limited
from app.database import get_connection
from app.schemas.base import CamelModel
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/xcx/avatar", tags=["小程序头像"])
def _get_avatar_dir() -> Path:
"""获取头像存储目录,不存在则创建。"""
avatar_path = config.AVATAR_EXPORT_PATH
if not avatar_path:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="AVATAR_EXPORT_PATH 未配置",
)
p = Path(avatar_path)
p.mkdir(parents=True, exist_ok=True)
return p
class AvatarUploadResponse(CamelModel):
"""头像上传响应。"""
avatar_url: str
# ── POST /api/xcx/avatar/upload ──────────────────────────
@router.post("/upload", response_model=AvatarUploadResponse)
@trace_service("上传头像", "Upload avatar")
async def upload_avatar(
file: UploadFile = File(...),
user: CurrentUser = Depends(get_current_user_or_limited),
):
"""
接收小程序 chooseAvatar 上传的头像文件。
流程:
1. 读取上传文件内容
2. 保存到 AVATAR_EXPORT_PATH/{user_id}.jpg覆盖式幂等
3. 更新 auth.users.avatar_url
4. 返回相对路径
"""
avatar_dir = _get_avatar_dir()
# 固定 jpg 后缀,覆盖式保存
filename = f"{user.user_id}.jpg"
filepath = avatar_dir / filename
relative_url = f"avatars/{filename}"
# 读取并保存文件
content = await file.read()
if len(content) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="上传文件为空",
)
# 限制 2MB
if len(content) > 2 * 1024 * 1024:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail="头像文件不能超过 2MB",
)
filepath.write_bytes(content)
logger.info("头像已保存: user_id=%s, path=%s", user.user_id, filepath)
# 更新数据库
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"UPDATE auth.users SET avatar_url = %s WHERE id = %s",
(relative_url, user.user_id),
)
conn.commit()
finally:
conn.close()
return AvatarUploadResponse(avatar_url=relative_url)
# ── GET /api/xcx/avatar/{user_id} ────────────────────────
@router.get("/{user_id}")
async def get_avatar(user_id: int):
"""
获取用户头像文件。
无需鉴权(头像为公开资源,通过 user_id 访问)。
文件不存在时返回 404。
"""
avatar_dir = _get_avatar_dir()
filepath = avatar_dir / f"{user_id}.jpg"
if not filepath.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="头像不存在",
)
return FileResponse(
path=str(filepath),
media_type="image/jpeg",
headers={"Cache-Control": "public, max-age=3600"},
)

View File

@@ -27,25 +27,30 @@ from app.schemas.xcx_board import (
SkillFilterEnum,
)
from app.services import board_service
from app.trace.decorators import trace_service
router = APIRouter(prefix="/api/xcx/board", tags=["xcx-board"])
@router.get("/coaches", response_model=CoachBoardResponse)
@trace_service("获取助教看板", "Get coach board")
async def get_coach_board(
sort: CoachSortEnum = Query(default=CoachSortEnum.perf_desc),
skill: SkillFilterEnum = Query(default=SkillFilterEnum.ALL),
time: BoardTimeEnum = Query(default=BoardTimeEnum.month),
page: int = Query(default=1, ge=1),
page_size: int = Query(default=20, ge=1, le=100),
user: CurrentUser = Depends(require_permission("view_board_coach")),
):
"""助教看板BOARD-1"""
return await board_service.get_coach_board(
sort=sort.value, skill=skill.value, time=time.value,
site_id=user.site_id,
page=page, page_size=page_size, site_id=user.site_id,
)
@router.get("/customers", response_model=CustomerBoardResponse)
@trace_service("获取客户看板", "Get customer board")
async def get_customer_board(
dimension: CustomerDimensionEnum = Query(default=CustomerDimensionEnum.recall),
project: ProjectFilterEnum = Query(default=ProjectFilterEnum.ALL),
@@ -65,6 +70,7 @@ async def get_customer_board(
response_model=FinanceBoardResponse,
response_model_exclude_none=True,
)
@trace_service("获取财务看板", "Get finance board")
async def get_finance_board(
time: FinanceTimeEnum = Query(default=FinanceTimeEnum.month),
area: AreaFilterEnum = Query(default=AreaFilterEnum.all),

View File

@@ -18,12 +18,12 @@ from __future__ import annotations
import json
import logging
import os
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from app.ai.bailian_client import BailianClient
from app.ai.config import AIConfig
from app.ai.dashscope_client import DashScopeClient
from app.auth.dependencies import CurrentUser
from app.database import get_connection
from app.middleware.permission import require_approved
@@ -39,6 +39,14 @@ from app.schemas.xcx_chat import (
SendMessageResponse,
)
from app.services.chat_service import ChatService
from app.trace.decorators import trace_service
from app.trace.sse_wrapper import (
record_ai_call,
record_ai_error,
record_sse_end,
record_sse_start,
record_sse_token,
)
logger = logging.getLogger(__name__)
@@ -49,6 +57,7 @@ router = APIRouter(prefix="/api/xcx/chat", tags=["小程序 CHAT"])
@router.get("/history", response_model=ChatHistoryResponse)
@trace_service("查询对话历史", "List chat history")
async def list_chat_history(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
@@ -75,6 +84,7 @@ async def list_chat_history(
@router.get("/messages", response_model=ChatMessagesResponse)
@trace_service("通过上下文查询消息", "Get messages by context")
async def get_chat_messages_by_context(
context_type: str = Query(..., alias="contextType"),
context_id: str = Query(..., alias="contextId"),
@@ -111,6 +121,7 @@ async def get_chat_messages_by_context(
@router.get("/{chat_id}/messages", response_model=ChatMessagesResponse)
@trace_service("查询对话消息", "Get chat messages")
async def get_chat_messages(
chat_id: int,
page: int = Query(1, ge=1),
@@ -140,6 +151,7 @@ async def get_chat_messages(
@router.post("/stream")
@trace_service("SSE 流式对话", "Chat stream SSE")
async def chat_stream(
body: ChatStreamRequest,
user: CurrentUser = Depends(require_approved()),
@@ -174,16 +186,74 @@ async def chat_stream(
- event: done\\ndata: {"messageId": ..., "createdAt": "..."}\\n\\n
- event: error\\ndata: {"message": "..."}\\n\\n
"""
import time as _time
full_reply_parts: list[str] = []
tokens_total = 0
_sse_start_ts = _time.perf_counter()
# SSE trace: 流开始
record_sse_start(
endpoint="/api/xcx/chat/stream",
user_id=user.user_id,
chat_id=str(body.chat_id),
)
try:
bailian = _get_bailian_client()
client = _get_dashscope_client()
config = AIConfig.from_env()
# 获取历史消息作为上下文
messages = _build_ai_messages(body.chat_id)
# 构建 prompt最近 20 条历史 + 当前消息已在历史中)
prompt = _build_prompt(body.chat_id)
# 流式调用百炼 API
async for chunk in bailian.chat_stream(messages):
# 构建 biz_params用户身份信息
biz_params = {
"User_ID": str(user.user_id),
"Role": getattr(user, "role", "coach"),
"Nickname": getattr(user, "nickname", ""),
}
# 看板入口:注入页面上下文到 prompt
if body.source_page:
try:
from app.ai.data_fetchers import build_page_text
filters = {}
if body.page_context:
filters = body.page_context
context_id = filters.pop("contextId", None)
page_text = await build_page_text(
source_page=body.source_page,
context_id=context_id,
site_id=user.site_id,
filters=filters if filters else None,
)
if page_text:
prompt = f"[页面上下文: {body.source_page}]\n{page_text}\n\n{prompt}"
except Exception:
logger.warning("页面上下文注入失败: source_page=%s", body.source_page, exc_info=True)
# 获取 session_id对话复用
session_id = svc.get_session_id(body.chat_id) if hasattr(svc, "get_session_id") else None
# SSE trace: AI 调用
record_ai_call(
app_id=config.app_id_1_chat,
prompt_length=len(prompt),
session_id=session_id or "",
)
# 流式调用 DashScope Application API
async for chunk in client.call_app_stream(
app_id=config.app_id_1_chat,
prompt=prompt,
session_id=session_id,
biz_params=biz_params,
):
full_reply_parts.append(chunk)
tokens_total += 1
# SSE trace: 每 10 个 token 记录一次
record_sse_token(token_count=1, total_tokens=tokens_total)
yield f"event: message\ndata: {json.dumps({'token': chunk}, ensure_ascii=False)}\n\n"
# 流结束:拼接完整回复并持久化
@@ -202,9 +272,23 @@ async def chat_stream(
)
yield f"event: done\ndata: {done_data}\n\n"
# SSE trace: 流正常结束
_sse_elapsed = (_time.perf_counter() - _sse_start_ts) * 1000
record_sse_end(
total_tokens=tokens_total,
total_duration_ms=_sse_elapsed,
completed=True,
)
except Exception as e:
logger.error("SSE 流式对话异常: %s", e, exc_info=True)
# SSE trace: AI 错误
record_ai_error(
error_type=type(e).__name__,
message=str(e),
)
# 如果已有部分回复,仍然持久化
if full_reply_parts:
partial = "".join(full_reply_parts)
@@ -220,6 +304,14 @@ async def chat_stream(
)
yield f"event: error\ndata: {error_data}\n\n"
# SSE trace: 流异常结束
_sse_elapsed = (_time.perf_counter() - _sse_start_ts) * 1000
record_sse_end(
total_tokens=tokens_total,
total_duration_ms=_sse_elapsed,
completed=False,
)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -235,6 +327,7 @@ async def chat_stream(
@router.post("/{chat_id}/messages", response_model=SendMessageResponse)
@trace_service("发送消息", "Send message")
async def send_message(
chat_id: int,
body: SendMessageRequest,
@@ -280,27 +373,25 @@ def _to_message_item(msg: dict) -> ChatMessageItem:
)
def _get_bailian_client() -> BailianClient:
"""从环境变量构建 BailianClient缺失时报错。"""
api_key = os.environ.get("BAILIAN_API_KEY")
base_url = os.environ.get("BAILIAN_BASE_URL")
model = os.environ.get("BAILIAN_MODEL")
if not api_key or not base_url or not model:
raise RuntimeError(
"百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
)
return BailianClient(api_key=api_key, base_url=base_url, model=model)
def _get_dashscope_client() -> DashScopeClient:
"""从环境变量构建 DashScopeClient缺失时报错。"""
config = AIConfig.from_env()
return DashScopeClient(api_key=config.api_key, workspace_id=config.workspace_id)
def _build_ai_messages(chat_id: int) -> list[dict]:
"""构建发送给 AI 的消息列表(含历史上下文)。"""
def _build_prompt(chat_id: int) -> str:
"""构建发送给 DashScope Application 的 prompt。
从 ai_messages 取最近 20 条历史,拼接为文本 prompt。
百炼 Application API 的 System Prompt 在控制台配置,此处只传用户对话内容。
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT role, content FROM biz.ai_messages
WHERE conversation_id = %s
WHERE conversation_id = %s AND role != 'system'
ORDER BY created_at ASC
""",
(chat_id,),
@@ -309,21 +400,22 @@ def _build_ai_messages(chat_id: int) -> list[dict]:
finally:
conn.close()
messages: list[dict] = []
# 取最近 20 条
recent = history[-20:] if len(history) > 20 else history
for role, msg_content in recent:
messages.append({"role": role, "content": msg_content})
# 如果没有 system 消息,添加默认 system prompt
if not messages or messages[0]["role"] != "system":
system_prompt = {
"role": "system",
"content": json.dumps(
{"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"},
ensure_ascii=False,
),
}
messages.insert(0, system_prompt)
# 如果只有一条(刚发送的用户消息),直接返回内容
if len(recent) == 1:
return recent[0][1]
return messages
# 多条历史:拼接为对话格式,最后一条为当前用户消息
parts: list[str] = []
for role, msg_content in recent[:-1]:
label = "用户" if role == "user" else "AI"
parts.append(f"{label}: {msg_content}")
# 最后一条是当前用户消息,作为主 prompt
current_msg = recent[-1][1] if recent else ""
if parts:
context = "\n".join(parts)
return f"[历史对话]\n{context}\n\n[当前问题]\n{current_msg}"
return current_msg

View File

@@ -13,17 +13,20 @@ from __future__ import annotations
from fastapi import APIRouter, Depends
from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_approved
from app.middleware.permission import require_permission
from app.schemas.xcx_coaches import CoachDetailResponse
from app.services import coach_service
from app.trace.decorators import trace_service
router = APIRouter(prefix="/api/xcx/coaches", tags=["小程序助教"])
@router.get("/{coach_id}", response_model=CoachDetailResponse)
@trace_service("获取助教详情", "Get coach detail")
async def get_coach_detail(
coach_id: int,
user: CurrentUser = Depends(require_approved()),
# CHANGE 2026-03-27 | 权限改造 W4助教详情跟助教看板走
user: CurrentUser = Depends(require_permission("view_board_coach")),
):
"""助教详情COACH-1"""
return await coach_service.get_coach_detail(

View File

@@ -14,6 +14,7 @@ from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_approved
from app.schemas.xcx_config import SkillTypeItem
from app.services import fdw_queries
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -21,6 +22,7 @@ router = APIRouter(prefix="/api/xcx/config", tags=["xcx-config"])
@router.get("/skill-types", response_model=list[SkillTypeItem])
@trace_service("获取技能类型配置", "Get skill types config")
async def get_skill_types(
user: CurrentUser = Depends(require_approved()),
):

View File

@@ -14,20 +14,24 @@ from __future__ import annotations
from fastapi import APIRouter, Depends, Query
from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_approved
from app.middleware.permission import require_approved, require_permission
from app.schemas.xcx_customers import (
CustomerDetailResponse,
CustomerRecordsResponse,
CustomerConsumptionRecordsResponse,
)
from app.services import customer_service
from app.trace.decorators import trace_service
router = APIRouter(prefix="/api/xcx/customers", tags=["小程序客户"])
@router.get("/{customer_id}", response_model=CustomerDetailResponse)
@trace_service("获取客户详情", "Get customer detail")
async def get_customer_detail(
customer_id: int,
user: CurrentUser = Depends(require_approved()),
# CHANGE 2026-03-27 | 权限改造 W4客户详情跟客户看板走
user: CurrentUser = Depends(require_permission("view_board_customer")),
):
"""客户详情CUST-1"""
return await customer_service.get_customer_detail(
@@ -36,6 +40,7 @@ async def get_customer_detail(
@router.get("/{customer_id}/records", response_model=CustomerRecordsResponse)
@trace_service("获取客户服务记录", "Get customer records")
async def get_customer_records(
customer_id: int,
year: int = Query(...),
@@ -43,9 +48,24 @@ async def get_customer_records(
table: str | None = Query(None),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_board_customer")),
):
"""客户服务记录CUST-2"""
return await customer_service.get_customer_records(
customer_id, user.site_id, year, month, table, page, page_size
customer_id, user.site_id, user.user_id,
year, month, table, page, page_size,
)
@router.get("/{customer_id}/consumption-records", response_model=CustomerConsumptionRecordsResponse)
@trace_service("获取客户消费记录", "Get customer consumption records")
async def get_customer_consumption_records(
customer_id: int,
year: int = Query(...),
month: int = Query(..., ge=1, le=12),
user: CurrentUser = Depends(require_permission("view_board_customer")),
):
"""客户消费记录CUST-3"""
return await customer_service.get_customer_consumption_records(
customer_id, user.site_id, year, month,
)

View File

@@ -18,11 +18,13 @@ from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_approved
from app.schemas.xcx_notes import NoteCreateRequest, NoteOut
from app.services import note_service
from app.trace.decorators import trace_service
router = APIRouter(prefix="/api/xcx/notes", tags=["小程序备注"])
@router.post("", response_model=NoteOut)
@trace_service("创建备注", "Create note")
async def create_note(
body: NoteCreateRequest,
user: CurrentUser = Depends(require_approved()),
@@ -37,10 +39,12 @@ async def create_note(
task_id=body.task_id,
rating_service_willingness=body.rating_service_willingness,
rating_revisit_likelihood=body.rating_revisit_likelihood,
score=body.score,
)
@router.get("")
@trace_service("查询备注列表", "Get notes")
async def get_notes(
target_type: str = Query("member", description="目标类型"),
target_id: int = Query(..., description="目标 ID"),
@@ -55,6 +59,7 @@ async def get_notes(
@router.delete("/{note_id}")
@trace_service("删除备注", "Delete note")
async def delete_note(
note_id: int,
user: CurrentUser = Depends(require_approved()),

View File

@@ -14,21 +14,24 @@ from __future__ import annotations
from fastapi import APIRouter, Depends, Query
from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_approved
from app.middleware.permission import require_approved, require_permission
from app.schemas.xcx_performance import (
PerformanceOverviewResponse,
PerformanceRecordsResponse,
)
from app.services import performance_service
from app.trace.decorators import trace_service
router = APIRouter(prefix="/api/xcx/performance", tags=["小程序绩效"])
@router.get("", response_model=PerformanceOverviewResponse)
@trace_service("获取绩效概览", "Get performance overview")
async def get_performance_overview(
year: int = Query(...),
month: int = Query(..., ge=1, le=12),
user: CurrentUser = Depends(require_approved()),
# CHANGE 2026-03-27 | 权限改造 W4绩效跟任务走
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""绩效概览PERF-1"""
return await performance_service.get_overview(
@@ -37,12 +40,13 @@ async def get_performance_overview(
@router.get("/records", response_model=PerformanceRecordsResponse)
@trace_service("获取绩效明细", "Get performance records")
async def get_performance_records(
year: int = Query(...),
month: int = Query(..., ge=1, le=12),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""绩效明细PERF-2"""
return await performance_service.get_records(

View File

@@ -3,12 +3,13 @@
小程序任务路由 —— 任务列表、任务详情、置顶、放弃、取消放弃。
端点清单:
- GET /api/xcx/tasks — 获取任务列表 + 绩效概览TASK-1
- GET /api/xcx/tasks/{task_id} — 获取任务详情完整版TASK-2
- POST /api/xcx/tasks/{id}/pin — 置顶任务
- POST /api/xcx/tasks/{id}/unpin — 取消置顶
- POST /api/xcx/tasks/{id}/abandon — 放弃任务
- POST /api/xcx/tasks/{id}/restore恢复任务
- GET /api/xcx/tasks — 获取任务列表 + 绩效概览TASK-1
- GET /api/xcx/tasks/by-member/{member_id} — 按会员查询最高优先级 active 任务详情
- GET /api/xcx/tasks/{task_id} — 获取任务详情完整版TASK-2
- POST /api/xcx/tasks/{id}/pin — 置顶任务
- POST /api/xcx/tasks/{id}/unpin — 取消置顶
- POST /api/xcx/tasks/{id}/abandon 放弃任务
- POST /api/xcx/tasks/{id}/restore — 恢复任务
所有端点均需 JWTapproved 状态)。
"""
@@ -18,23 +19,26 @@ from __future__ import annotations
from fastapi import APIRouter, Depends, Query
from app.auth.dependencies import CurrentUser
from app.middleware.permission import require_approved
from app.middleware.permission import require_approved, require_permission
from app.schemas.xcx_tasks import (
AbandonRequest,
TaskDetailResponse,
TaskListResponse,
)
from app.services import task_manager
from app.trace.decorators import trace_service
router = APIRouter(prefix="/api/xcx/tasks", tags=["小程序任务"])
@router.get("", response_model=TaskListResponse)
@trace_service("获取任务列表", "Get task list")
async def get_tasks(
status: str = Query("pending", pattern="^(pending|completed|abandoned)$"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
user: CurrentUser = Depends(require_approved()),
page_size: int = Query(20, ge=1, le=200),
# CHANGE 2026-03-27 | 权限改造 W4统一权限保护
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""获取任务列表 + 绩效概览。"""
return await task_manager.get_task_list_v2(
@@ -42,10 +46,23 @@ async def get_tasks(
)
@router.get("/by-member/{member_id}", response_model=TaskDetailResponse)
@trace_service("按会员查询任务详情", "Get task detail by member")
async def get_task_by_member(
member_id: int,
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""按 member_id 查询当前助教的最高优先级 active 任务详情。"""
return await task_manager.get_task_by_member(
member_id, user.user_id, user.site_id
)
@router.get("/{task_id}", response_model=TaskDetailResponse)
@trace_service("获取任务详情", "Get task detail")
async def get_task_detail(
task_id: int,
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""获取任务详情完整版。"""
return await task_manager.get_task_detail(
@@ -54,9 +71,10 @@ async def get_task_detail(
@router.post("/{task_id}/pin")
@trace_service("置顶任务", "Pin task")
async def pin_task(
task_id: int,
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""置顶任务。"""
result = await task_manager.pin_task(task_id, user.user_id, user.site_id)
@@ -64,9 +82,10 @@ async def pin_task(
@router.post("/{task_id}/unpin")
@trace_service("取消置顶", "Unpin task")
async def unpin_task(
task_id: int,
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""取消置顶。"""
result = await task_manager.unpin_task(task_id, user.user_id, user.site_id)
@@ -74,10 +93,11 @@ async def unpin_task(
@router.post("/{task_id}/abandon")
@trace_service("放弃任务", "Abandon task")
async def abandon_task(
task_id: int,
body: AbandonRequest,
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""放弃任务(需填写原因)。"""
return await task_manager.abandon_task(
@@ -86,9 +106,10 @@ async def abandon_task(
@router.post("/{task_id}/restore")
@trace_service("恢复任务", "Restore task")
async def restore_task(
task_id: int,
user: CurrentUser = Depends(require_approved()),
user: CurrentUser = Depends(require_permission("view_tasks")),
):
"""取消放弃,恢复为活跃状态。"""
return await task_manager.cancel_abandon(task_id, user.user_id, user.site_id)

View File

@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
"""
管理端 — AI 监控后台 Pydantic Schema。
覆盖Dashboard 总览、调度任务、调用记录、缓存失效、Token 预算、批量执行、告警管理。
需求: A1.1, A2.1, A4.1, A5.1, A6.1, A7.1, A8.1
"""
from __future__ import annotations
from pydantic import BaseModel
# ── Dashboard ─────────────────────────────────────────────
class DailyTrend(BaseModel):
"""近 7 天按日聚合趋势项。"""
date: str # YYYY-MM-DD
calls: int
success_rate: float
class AppDistItem(BaseModel):
"""各 App 调用占比分布项。"""
app_type: str
count: int
percentage: float
class BudgetInfo(BaseModel):
"""日/月 Token 预算进度。"""
daily_used: int
daily_limit: int
daily_pct: float
monthly_used: int
monthly_limit: int
monthly_pct: float
class AlertItem(BaseModel):
"""告警事件项(失败/超时/熔断)。"""
id: int
app_type: str
status: str # failed / timeout / circuit_open
alert_status: str | None # pending / acknowledged / ignored
error_message: str | None
created_at: str
class AppHealthItem(BaseModel):
"""各 App 最近一次调用状态。"""
app_type: str
last_status: str | None
last_call_at: str | None
class DashboardResponse(BaseModel):
"""Dashboard 总览统计响应。"""
today_calls: int
today_success_rate: float # 0.0 ~ 1.0
today_tokens: int
today_avg_latency_ms: float
trend_7d: list[DailyTrend]
app_distribution: list[AppDistItem]
budget: BudgetInfo
recent_alerts: list[AlertItem]
app_health: list[AppHealthItem]
# ── 调度任务 ──────────────────────────────────────────────
class TriggerJobItem(BaseModel):
"""调度任务列表项。"""
id: int
event_type: str
member_id: int | None
status: str
app_chain: str | None
is_forced: bool
site_id: int
started_at: str | None
finished_at: str | None
created_at: str
class TriggerJobListResponse(BaseModel):
"""调度任务分页列表响应。"""
items: list[TriggerJobItem]
total: int
page: int
page_size: int
today_skipped_duplicates: int # 今日去重跳过数
class TriggerJobDetailResponse(TriggerJobItem):
"""调度任务详情响应(含 payload、error_message"""
payload: dict | None
error_message: str | None
connector_type: str
class RetryResponse(BaseModel):
"""手动重跑响应。"""
trigger_job_id: int
status: str # "pending"
# ── 调用记录 ──────────────────────────────────────────────
class RunLogItem(BaseModel):
"""调用记录列表项。"""
id: int
app_type: str
trigger_type: str
member_id: int | None
tokens_used: int
latency_ms: int | None
status: str
site_id: int
created_at: str
class RunLogListResponse(BaseModel):
"""调用记录分页列表响应。"""
items: list[RunLogItem]
total: int
page: int
page_size: int
class RunLogDetailResponse(RunLogItem):
"""调用记录详情响应(含完整 prompt/response不脱敏"""
request_prompt: str | None
response_text: str | None
error_message: str | None
session_id: str | None
finished_at: str | None
# ── 缓存失效 ─────────────────────────────────────────────
class CacheInvalidateRequest(BaseModel):
"""缓存失效请求site_id 必填)。"""
site_id: int
app_type: str | None = None
member_id: int | None = None
class CacheInvalidateResponse(BaseModel):
"""缓存失效响应。"""
affected_count: int
# ── Token 预算 ────────────────────────────────────────────
class BudgetResponse(BaseModel):
"""Token 预算使用情况响应。"""
daily_used: int
daily_limit: int
daily_pct: float
monthly_used: int
monthly_limit: int
monthly_pct: float
# ── 批量执行 ──────────────────────────────────────────────
class BatchRunRequest(BaseModel):
"""批量执行请求。"""
app_types: list[str]
member_ids: list[int]
site_id: int
class BatchRunEstimate(BaseModel):
"""批量执行预估响应(不立即执行)。"""
batch_id: str
estimated_calls: int
estimated_tokens: int
class BatchRunConfirm(BaseModel):
"""批量执行确认请求。"""
batch_id: str
class BatchRunConfirmResponse(BaseModel):
"""批量执行确认响应。"""
status: str # "started"
# ── 告警 ──────────────────────────────────────────────────
class AlertListResponse(BaseModel):
"""告警分页列表响应。"""
items: list[AlertItem]
total: int
page: int
page_size: int
class AlertActionResponse(BaseModel):
"""告警操作(确认/忽略)响应。"""
id: int
alert_status: str

View File

@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""管理端 — 数据库健康监控 Pydantic Schema。
需求: 6.1, 6.2
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
class DbHealthItem(BaseModel):
"""单个数据库健康状态"""
db_name: str
status: Literal['connected', 'disconnected']
active_connections: int | None = None
idle_connections: int | None = None
db_size_mb: float | None = None
slow_query_count: int | None = None

View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
管理端 — 注册体系 Pydantic Schema。
覆盖租户列表、店铺列表、简写ID 管理、店铺同步。
需求: A2.1, A2.2, A2.4, A2.5
"""
from __future__ import annotations
from datetime import datetime
from app.schemas.base import CamelModel
# ── 租户 ──────────────────────────────────────────────────
class TenantItem(CamelModel):
"""租户列表项(含连接器名称)。"""
id: int
tenant_id: int
tenant_name: str | None = None
connector_name: str
is_active: bool
# ── 店铺 ──────────────────────────────────────────────────
class SiteItem(CamelModel):
"""店铺列表项。"""
id: int
site_id: int
site_name: str | None = None
site_code: str | None = None
site_label: str | None = None
is_active: bool
# ── 简写ID 管理 ──────────────────────────────────────────
class UpdateSiteCodeRequest(CamelModel):
"""设置/修改店铺简写ID 请求。"""
new_code: str # 6 位3+3 格式,统一大写
class SiteCodeResult(CamelModel):
"""简写ID 修改结果。"""
site_id: int
old_code: str | None = None
new_code: str
history_cleaned: bool # 旧 code 是否被清理
class SiteCodeHistoryItem(CamelModel):
"""简写ID 变更历史条目。"""
id: int
site_code: str
is_current: bool
created_at: datetime
retired_at: datetime | None = None
# ── 店铺同步 ─────────────────────────────────────────────
class SiteSyncResult(CamelModel):
"""店铺同步结果。"""
inserted: int # 新增店铺数
updated: int # 更新店铺数
# ── 测试用:手动创建/删除店铺 ─────────────────────────────
class CreateSiteRequest(CamelModel):
"""手动创建店铺请求(测试功能)。"""
tenant_id: int # 所属租户biz.tenants.tenant_id上游 BIGINT
site_id: int # 上游系统店铺 IDBIGINT
site_name: str # 店铺名称
site_code: str | None = None # 可选简写ID6 位 3+3 格式)

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
"""P18 任务引擎运营看板 — Pydantic v2 Schema
包含:转移日志、待审核任务、候选助教、参数管理等数据模型。
"""
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, Field
# ---- 转移日志 ----
class TransferLogItem(BaseModel):
id: int
site_id: int
site_name: str = ""
member_id: int
member_name: str = ""
from_assistant_id: int
from_assistant_name: str = ""
to_assistant_id: int
to_assistant_name: str = ""
transfer_reason: str | None = None
transfer_score: float | None = None
guard_checks: dict | None = None
created_at: datetime
class TransferLogPage(BaseModel):
items: list[TransferLogItem]
total: int
# ---- 待审核任务 ----
class PendingReviewItem(BaseModel):
id: int
site_id: int
site_name: str = ""
member_id: int
member_name: str = ""
assistant_id: int
assistant_name: str = ""
task_type: str
task_type_label: str = ""
transfer_count: int = 0
priority_score: float | None = None
created_at: datetime
class PendingReviewPage(BaseModel):
items: list[PendingReviewItem]
total: int
class CandidateAssistant(BaseModel):
assistant_id: int
assistant_name: str = ""
rs_display: float = 0
ms_display: float = 0
ml_display: float = 0
transfer_score: float = 0
source: str = "pool"
class CandidateListResponse(BaseModel):
candidates: list[CandidateAssistant]
class ReassignRequest(BaseModel):
to_assistant_id: int
class ReassignResponse(BaseModel):
success: bool
new_task_id: int | None = None
class CloseRequest(BaseModel):
reason: str = Field(..., min_length=1, max_length=500)
class CloseResponse(BaseModel):
success: bool
# ---- 参数管理 ----
class ConfigParam(BaseModel):
id: int
site_id: int | None = None
site_name: str | None = None
param_key: str
param_value: float
description: str | None = None
updated_at: datetime
class ConfigParamList(BaseModel):
params: list[ConfigParam]
class ConfigParamUpdate(BaseModel):
param_value: float
class ConfigParamCreate(BaseModel):
site_id: int
param_key: str = Field(..., max_length=64)
param_value: float
class ConfigParamResponse(BaseModel):
success: bool
id: int | None = None

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""
管理端 — 租户管理员 CRUD Pydantic Schema。
覆盖:管理员列表、创建、编辑、重置密码。
需求: 14.1, 14.2, 14.4, 14.5
"""
from __future__ import annotations
from datetime import datetime
from pydantic import Field
from app.schemas.base import CamelModel
# ── 管理员列表 ────────────────────────────────────────────
class TenantAdminListItem(CamelModel):
"""租户管理员列表项。"""
id: int
username: str
display_name: str | None = None
tenant_id: int # 所属租户 ID上游 BIGINT
tenant_name: str | None = None # 所属租户名称JOIN biz.tenants
admin_type: str = "tenant_admin" # tenant_admin / site_admin
managed_site_ids: list[int] = Field(default_factory=list)
is_active: bool = True
created_at: str | None = None
last_login_at: str | None = None
# ── 创建管理员 ────────────────────────────────────────────
class TenantAdminCreateRequest(CamelModel):
"""创建租户管理员请求。"""
username: str = Field(..., min_length=1, max_length=100, description="用户名")
password: str = Field(..., min_length=1, description="初始密码")
display_name: str | None = Field(None, max_length=100, description="显示名称")
# tenant_id 从 biz.tenants 选择GET /api/admin/tenants 获取可选列表)
tenant_id: int = Field(..., description="所属租户 ID来源: biz.tenants")
managed_site_ids: list[int] = Field(..., min_length=1, description="管辖门店 ID 列表")
# ── 编辑管理员 ────────────────────────────────────────────
class TenantAdminEditRequest(CamelModel):
"""编辑租户管理员请求(所有字段可选)。"""
username: str | None = Field(None, min_length=1, max_length=100, description="用户名(需全局唯一)")
display_name: str | None = Field(None, max_length=100, description="显示名称")
managed_site_ids: list[int] | None = Field(None, description="管辖门店 ID 列表")
is_active: bool | None = Field(None, description="账号状态")
# ── 重置密码 ──────────────────────────────────────────────
class ResetPasswordRequest(CamelModel):
"""重置密码请求。"""
new_password: str = Field(..., min_length=1, description="新密码")

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""管理端 — 触发器统一视图 Pydantic Schema。
需求: 4.1, 4.2
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
class UnifiedTriggerItem(BaseModel):
"""统一触发器视图项"""
id: int
name: str
source: Literal['biz', 'ai', 'etl']
trigger_condition: str
status: str
last_run_at: str | None = None
next_run_at: str | None = None
last_error: str | None = None

View File

@@ -12,8 +12,8 @@ from pydantic import BaseModel
class CursorInfo(BaseModel):
"""ETL 游标信息(单条任务的最后抓取状态)。"""
task_code: str
last_fetch_time: str | None = None
record_count: int | None = None
last_start: str | None = None
last_end: str | None = None
class RecentRun(BaseModel):

View File

@@ -25,6 +25,13 @@ class ScheduleConfigSchema(BaseModel):
end_date: str | None = None
class MinRunIntervalItem(BaseModel):
"""单个任务的最小执行间隔"""
value: int = 0
unit: Literal["minutes", "hours", "days"] = "minutes"
class CreateScheduleRequest(BaseModel):
"""创建调度任务请求"""
@@ -33,6 +40,9 @@ class CreateScheduleRequest(BaseModel):
task_config: dict[str, Any]
schedule_config: ScheduleConfigSchema
run_immediately: bool = False
min_run_interval_value: int = 0 # 0 表示无限制schedule 级别默认)
min_run_interval_unit: Literal["minutes", "hours", "days"] = "minutes"
min_run_intervals: dict[str, MinRunIntervalItem] = {} # per-task-code 间隔
class UpdateScheduleRequest(BaseModel):
@@ -42,6 +52,9 @@ class UpdateScheduleRequest(BaseModel):
task_codes: list[str] | None = None
task_config: dict[str, Any] | None = None
schedule_config: ScheduleConfigSchema | None = None
min_run_interval_value: int | None = None
min_run_interval_unit: Literal["minutes", "hours", "days"] | None = None
min_run_intervals: dict[str, MinRunIntervalItem] | None = None
class ScheduleResponse(BaseModel):
@@ -58,5 +71,9 @@ class ScheduleResponse(BaseModel):
next_run_at: datetime | None = None
run_count: int
last_status: str | None = None
min_run_interval_value: int = 0
min_run_interval_unit: str = "minutes"
last_success_at: datetime | None = None
min_run_intervals: dict[str, Any] = {}
created_at: datetime
updated_at: datetime

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""
租户管理后台 — 维客线索管理 Pydantic Schema。
覆盖:客户搜索结果、线索列表、线索编辑、线索隐藏/显示。
需求: 9.1, 10.1, 11.1
"""
from __future__ import annotations
from enum import Enum
from typing import Optional
from pydantic import Field, field_validator
from app.schemas.base import CamelModel
class ClueCategory(str, Enum):
"""线索大类枚举6 值)。"""
CUSTOMER_BASIC = "客户基础"
CONSUMPTION_HABIT = "消费习惯"
PLAY_PREFERENCE = "玩法偏好"
PROMO_PREFERENCE = "促销偏好"
SOCIAL_RELATION = "社交关系"
IMPORTANT_FEEDBACK = "重要反馈"
class CustomerSearchItem(CamelModel):
"""客户搜索结果项。"""
member_id: int
nickname: str | None = None
mobile_masked: str | None = None
site_name: str | None = None
site_id: int | None = None
class ClueListItem(CamelModel):
"""线索列表项。"""
id: int
category: str | None = None
summary: str | None = None
detail: str | None = None
recorded_by_name: str | None = None
source: str | None = None
recorded_at: str | None = None
is_hidden: bool = False
class ClueEditRequest(CamelModel):
"""线索编辑请求。"""
category: ClueCategory = Field(..., description="线索大类6 值枚举)")
summary: str = Field(..., min_length=1, max_length=200, description="摘要非空≤200 字符)")
detail: Optional[str] = Field(None, description="详情(可选)")
class ClueVisibilityRequest(CamelModel):
"""线索隐藏/显示请求。"""
is_hidden: bool = Field(..., description="是否隐藏")

View File

@@ -0,0 +1,134 @@
# -*- coding: utf-8 -*-
"""
租户管理后台 — Excel 上传 Pydantic Schema。
覆盖4 种模板行数据模型、校验结果、冲突 diff、确认请求、上传记录。
需求: 5.2, 7.2, 8.4
"""
from __future__ import annotations
from typing import Literal
from pydantic import Field
from app.schemas.base import CamelModel
# ── 4 种模板行数据模型 ────────────────────────────────────
class ExpenseRow(CamelModel):
"""财务支出行数据。"""
row_index: int = Field(..., description="行号(从 1 开始)")
expense_month: str = Field(..., description="月份 YYYY-MM")
category: str = Field(..., description="支出类别8 值枚举)")
amount: float = Field(..., description="金额(> 0精度 2 位小数)")
remark: str | None = Field(None, description="备注(可选,最长 500 字符)")
class PlatformIncomeRow(CamelModel):
"""团购收入行数据。"""
row_index: int = Field(..., description="行号")
income_month: str = Field(..., description="月份 YYYY-MM")
platform_name: str = Field(..., description="平台名称")
amount: float = Field(..., description="收入金额(> 0")
remark: str | None = Field(None, description="备注(可选,最长 500 字符)")
class SalaryAdjRow(CamelModel):
"""助教奖罚行数据。"""
row_index: int = Field(..., description="行号")
salary_month: str = Field(..., description="月份 YYYY-MM")
assistant_name: str = Field(..., description="助教姓名")
assistant_number: str = Field(..., description="助教编号")
adjustment_type: str = Field(..., description="类型(扣款/奖金)")
amount: float = Field(..., description="金额(> 0")
reason: str = Field(..., description="原因(非空,最长 200 字符)")
assistant_id: int | None = Field(None, description="匹配到的助教 ID")
class RechargeCommissionRow(CamelModel):
"""充值业绩归属行数据。"""
row_index: int = Field(..., description="行号")
recharge_date: str = Field(..., description="充值日期 YYYY-MM-DD")
member_name: str = Field(..., description="会员名称")
recharge_amount: float = Field(..., description="充值金额(> 0")
assigned_assistant: str = Field(..., description="归属助教")
reward_amount: float = Field(..., description="奖励金额(≥ 0")
assistant_id: int | None = Field(None, description="匹配到的助教 ID")
# ── 校验错误/警告 ─────────────────────────────────────────
class ValidationError(CamelModel):
"""单行校验错误。"""
row_index: int = Field(..., description="行号")
column: str = Field(..., description="列名")
message: str = Field(..., description="错误描述")
class ValidationWarning(CamelModel):
"""单行校验警告(如人员匹配失败)。"""
row_index: int = Field(..., description="行号")
column: str = Field(..., description="列名")
message: str = Field(..., description="警告描述")
class ValidationResult(CamelModel):
"""校验结果。"""
errors: list[ValidationError] = Field(default_factory=list, description="错误列表")
warnings: list[ValidationWarning] = Field(default_factory=list, description="警告列表")
passed_rows: list[dict] = Field(default_factory=list, description="通过校验的行数据")
upload_id: int | None = Field(None, description="上传批次 ID校验全部通过时创建")
# ── 冲突 diff ─────────────────────────────────────────────
class FieldDiff(CamelModel):
"""单字段差异。"""
field: str = Field(..., description="字段名")
old_value: str | None = Field(None, description="旧值")
new_value: str | None = Field(None, description="新值")
class ConflictDiff(CamelModel):
"""冲突行 diff。"""
row_index: int = Field(..., description="行号")
field_diffs: list[FieldDiff] = Field(default_factory=list, description="逐字段差异")
# ── 确认请求 ──────────────────────────────────────────────
class Resolution(CamelModel):
"""单行冲突解决方案。"""
row_index: int = Field(..., description="行号")
action: Literal["replace", "keep"] = Field(..., description="操作replace=替换/keep=保留")
class ConfirmRequest(CamelModel):
"""确认写入请求。"""
upload_id: int = Field(..., description="上传批次 ID")
resolutions: list[Resolution] = Field(default_factory=list, description="冲突解决方案列表")
# ── 上传记录 ──────────────────────────────────────────────
class UploadLogItem(CamelModel):
"""上传记录列表项。"""
id: int
site_id: int
upload_type: str = Field(..., description="模板类型")
file_name: str = Field(..., description="原始文件名")
uploaded_by: int = Field(..., description="上传人 ID")
row_count: int = Field(0, description="数据行数")
conflict_count: int = Field(0, description="冲突行数")
resolved_count: int = Field(0, description="已解决冲突数")
status: str = Field(..., description="状态pending/confirmed/failed")
created_at: str | None = Field(None, description="上传时间")
confirmed_at: str | None = Field(None, description="确认时间")

View File

@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
租户管理后台 — 用户审核 + 用户管理 Pydantic Schema。
覆盖:申请列表、关联建议、审核通过/拒绝、用户列表、编辑、绑定、角色列表、人员候选。
需求: 3.2, 4.1
AI_CHANGELOG
- 2026-03-23 17:00:00 | Prompt: P20260323-164500审核弹窗改造| Direct cause新增 roles + site-staff 端点需要响应 Schema | Summary新增 RoleItem角色列表项+ StaffCandidate人员候选项source 区分 assistant/staff| Verify后端 /roles 和 /site-staff 返回正确 JSON 结构
- 2026-03-24 | Prompt: 用户管理绑定功能改造 | Direct causeUserEditRequest 需要同时携带角色+绑定字段 | SummaryUserEditRequest 扩展 assistant_id/staff_id 字段,支持角色+绑定合并提交 | VerifyPATCH /users/{id} 接受 assistantId/staffId 参数
- 2026-03-24 | Prompt: 审核弹窗头像昵称+排版优化 | Direct causeApplicationListItem 缺少 avatar_url | Summary新增 avatar_url 字段 | VerifyGET /applications 返回 avatarUrl
"""
from __future__ import annotations
from pydantic import Field
from app.schemas.base import CamelModel
# ── 用户申请审核 ──────────────────────────────────────────
class ApplicationListItem(CamelModel):
"""申请列表项。"""
id: int
user_id: int
nickname: str | None = None
avatar_url: str | None = None
phone: str | None = None
site_code: str | None = None
applied_role_text: str | None = None
employee_number: str | None = None
created_at: str | None = None
status: str # pending / approved / rejected
class MatchSuggestion(CamelModel):
"""关联匹配建议。"""
assistant_id: int | None = None
staff_id: int | None = None
name: str
number: str | None = None
source_table: str # v_dim_assistant / v_dim_staff
class ApproveRequest(CamelModel):
"""审核通过请求。"""
role: str = Field(..., min_length=1, description="分配角色coach/staff/head_coach/manager")
assistant_id: int | None = Field(None, description="关联助教 ID")
staff_id: int | None = Field(None, description="关联员工 ID")
class RejectRequest(CamelModel):
"""审核拒绝请求。"""
reason: str = Field(..., min_length=1, description="拒绝原因")
# ── 用户管理 ──────────────────────────────────────────────
class UserListItem(CamelModel):
"""用户列表项。"""
id: int
nickname: str | None = None
role: str | None = None # 角色中文名(显示用)
role_code: str | None = None # 角色 code提交用
assistant_id: int | None = None # 当前绑定的助教 ID
staff_id: int | None = None # 当前绑定的员工 ID
assistant_name: str | None = None
site_name: str | None = None
site_id: int | None = None
status: str # approved / disabled
class UserEditRequest(CamelModel):
"""用户编辑请求(合并角色+绑定)。
角色与绑定互斥coach 只能绑 assistant_id其他角色只能绑 staff_id。
换角色时后端自动清除旧绑定。staffBinding="none" 表示解绑。
"""
role: str | None = Field(None, description="新角色 codecoach/staff/head_coach/manager")
site_id: int | None = Field(None, description="新门店 ID")
assistant_id: int | None = Field(None, description="关联助教 ID仅 coach 角色)")
staff_id: int | None = Field(None, description="关联员工 ID仅非 coach 角色)")
# CHANGE 2026-03-23 | 移除 status 字段:租户不能禁用用户,只能移除店铺关系
# CHANGE 2026-03-24 | 合并绑定字段:角色+绑定同一请求提交,换角色自动清除旧绑定
class UserBindingRequest(CamelModel):
"""用户绑定修改请求。"""
assistant_id: int | None = Field(None, description="关联助教 ID")
staff_id: int | None = Field(None, description="关联员工 ID")
# ── 角色 + 人员候选 ──────────────────────────────────────
# [CHANGE P20260323-164500] intent: 审核弹窗角色动态化 + 人员联动所需的响应 Schema
# assumptions: StaffCandidate.source 区分 assistant/staff前端据此构造 staffBinding 值
class RoleItem(CamelModel):
"""角色列表项(从 auth.roles 动态读取)。"""
id: int
code: str
name: str
description: str | None = None
class StaffCandidate(CamelModel):
"""人员候选项(审核弹窗关联下拉用)。"""
id: int = Field(..., description="assistant_id 或 staff_id")
identity_label: str | None = Field(None, description="身份角色level / staff_identity 的原始值)")
name: str = Field(..., description="姓名")
mobile: str | None = Field(None, description="手机号")
entry_time: str | None = Field(None, description="入职时间")
source: str = Field(..., description="assistant / staff")

View File

@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
"""定时任务管理 — Pydantic 响应模型"""
from __future__ import annotations
from typing import Self
from pydantic import BaseModel, model_validator
class TriggerJobItem(BaseModel):
"""单个定时任务信息"""
id: int
job_type: str
job_name: str
trigger_condition: str
trigger_config: dict | None = None
last_run_at: str | None = None
next_run_at: str | None = None
status: str
description: str | None = None
last_error: str | None = None
created_at: str | None = None
class RunJobResult(BaseModel):
"""手动执行结果"""
success: bool
message: str
class UpdateTriggerConfigRequest(BaseModel):
"""触发器配置编辑请求(部分更新)"""
cron_expression: str | None = None # 5 字段 cron 表达式
interval_seconds: int | None = None # 间隔秒数,>= 1
@model_validator(mode='after')
def at_least_one_field(self) -> Self:
if self.cron_expression is None and self.interval_seconds is None:
raise ValueError('至少提供 cron_expression 或 interval_seconds 之一')
return self

View File

@@ -1,3 +1,8 @@
# AI_CHANGELOG
# | 日期 | Prompt | 变更 |
# |------|--------|------|
# | 2026-03-23 | 角色路由+页面权限守卫 | WxLoginResponse 和 UserStatusResponse 增加 role 字段 |
"""
小程序认证相关 Pydantic 模型。
@@ -25,6 +30,8 @@ class WxLoginResponse(CamelModel):
token_type: str = "bearer"
user_status: str # pending / approved / rejected / disabled
user_id: int
# CHANGE 2026-03-23 | 角色路由:登录时返回角色 code
role: str | None = None
class DevLoginRequest(CamelModel):
@@ -37,7 +44,7 @@ class DevLoginRequest(CamelModel):
class DevSwitchRoleRequest(CamelModel):
"""切换角色请求。替换当前用户在当前门店下的所有角色为指定角色。"""
role_code: str = Field(..., description="目标角色 codecoach/staff/site_admin/tenant_admin")
role_code: str = Field(..., description="目标角色 codecoach/staff/head_coach/manager")
class DevSwitchStatusRequest(CamelModel):
@@ -71,7 +78,8 @@ class DevLoginRequest(CamelModel):
class ApplicationRequest(CamelModel):
"""用户申请提交请求。"""
site_code: str = Field(..., pattern=r"^[A-Za-z]{2}\d{3}$", description="球房ID")
# CHANGE 2026-03-23 | 球房ID 改为 6 位字母/数字,大小写不敏感
site_code: str = Field(..., pattern=r"^[A-Za-z0-9]{6}$", description="球房ID6位字母/数字)")
applied_role_text: str = Field(..., min_length=1, max_length=100, description="申请身份")
phone: str = Field(..., pattern=r"^\d{11}$", description="手机号")
employee_number: str | None = Field(None, max_length=50, description="员工编号")
@@ -89,6 +97,30 @@ class ApplicationResponse(CamelModel):
reviewed_at: str | None = None
class LatestApplicationDetail(CamelModel):
"""最新申请详情(含 phone/employee_number用于前端展示和预填"""
id: int
site_code: str
applied_role_text: str
phone: str
employee_number: str | None = None
status: str
review_note: str | None = None
created_at: str
reviewed_at: str | None = None
class CancelApplicationResponse(CamelModel):
"""取消申请响应(返回被取消申请的信息,用于预填重新申请表单)。"""
id: int
site_code: str
applied_role_text: str
phone: str
employee_number: str | None = None
status: str
created_at: str
# ── 用户状态 ──────────────────────────────────────────────
class UserStatusResponse(CamelModel):
@@ -96,7 +128,18 @@ class UserStatusResponse(CamelModel):
user_id: int
status: str
nickname: str | None = None
# CHANGE 2026-03-24 | 头像:从 auth.users.avatar_url 读取
avatar_url: str | None = None
# CHANGE 2026-03-23 | 角色路由:返回用户在当前门店下的角色 code
role: str | None = None
# CHANGE 2026-03-27 | 权限改造 W2返回权限码列表前端据此动态控制页面/tab 可见性
permissions: list[str] = []
# CHANGE 2026-03-23 | banner 数据修复:补充门店名和助教等级
store_name: str | None = None
coach_level: str | None = None
applications: list[ApplicationResponse] = []
# CHANGE 2026-03-23 | 审核流程增强:最新申请详情(含 phone/employee_number
latest_application: LatestApplicationDetail | None = None
# ── 店铺 ──────────────────────────────────────────────────

View File

@@ -2,6 +2,8 @@
# - 2026-03-20 | Prompt: R3 项目类型筛选接口重建 | SkillFilterEnum 和 ProjectFilterEnum
# 枚举值从 all/chinese/snooker/mahjong/karaoke 改为 ALL/BILLIARD/SNOOKER/MAHJONG/KTV
# 与 dws.cfg_area_category.category_code 一致,消除前后端映射层。
# - 2026-03-27 | Prompt: board-finance-integration T2.4 | AreaFilterEnum 从 7 项重建为 9 项
# (新增 vip/snooker/ktv移除 teamBuilding与区域筛选对照表一致。
"""三看板接口 Pydantic SchemaBOARD-1/2/3 请求参数枚举 + 响应模型)。"""
@@ -84,13 +86,17 @@ class FinanceTimeEnum(str, Enum):
class AreaFilterEnum(str, Enum):
"""BOARD-3 区域筛选。"""
# CHANGE 2026-03-27 | board-finance-integration T2.4 | 枚举从 7 项重建为 9 项,
# 与区域筛选对照表一致all/hall/hallA-C/vip/snooker/mahjong/ktv
all = "all"
hall = "hall"
hallA = "hallA"
hallB = "hallB"
hallC = "hallC"
vip = "vip"
snooker = "snooker"
mahjong = "mahjong"
teamBuilding = "teamBuilding"
ktv = "ktv"
# ---------------------------------------------------------------------------
@@ -137,6 +143,9 @@ class CoachBoardItem(CamelModel):
class CoachBoardResponse(CamelModel):
items: list[CoachBoardItem]
total: int
page: int
page_size: int
dim_type: str # perf/salary/sv/task
@@ -261,10 +270,10 @@ class OverviewPanel(CamelModel):
discount: float # 负值
discount_rate: float
confirmed_revenue: float
cash_in: float
cash_out: float
cash_balance: float
balance_rate: float
cash_in: float | None = None
cash_out: float | None = None
cash_balance: float | None = None
balance_rate: float | None = None
# occurrence 环比
occurrence_compare: str | None = None
occurrence_down: bool | None = None
@@ -340,6 +349,10 @@ class RechargePanel(CamelModel):
card_balance_compare: str | None = None
card_balance_down: bool | None = None
card_balance_flat: bool | None = None
# 全类别会员卡余额合计环比
all_card_balance_compare: str | None = None
all_card_balance_down: bool | None = None
all_card_balance_flat: bool | None = None
class RevenueStructureRow(CamelModel):
@@ -355,32 +368,53 @@ class RevenueStructureRow(CamelModel):
class RevenueItem(CamelModel):
label: str
desc: str | None = None
amount: float
compare: str | None = None
class ChannelItem(CamelModel):
label: str
desc: str | None = None
amount: float
compare: str | None = None
class RevenuePanel(CamelModel):
structure_rows: list[RevenueStructureRow]
price_items: list[RevenueItem] # 4 项
price_items: list[RevenueItem]
total_occurrence: float
discount_items: list[RevenueItem] # 4 项
total_occurrence_compare: str | None = None
total_occurrence_down: bool | None = None
total_occurrence_flat: bool | None = None
discount_items: list[RevenueItem]
# CHANGE 2026-03-28 | board-finance-phase2 bugfix | 优惠总计供前端展示
discount_total: float = 0.0
discount_total_compare: str | None = None
discount_total_down: bool | None = None
discount_total_flat: bool | None = None
confirmed_total: float
channel_items: list[ChannelItem] # 3 项
confirmed_total_compare: str | None = None
confirmed_total_down: bool | None = None
confirmed_total_flat: bool | None = None
channel_items: list[ChannelItem]
class CashflowItem(CamelModel):
label: str
desc: str | None = None
amount: float
compare: str | None = None
down: bool | None = None
class CashflowPanel(CamelModel):
consume_items: list[CashflowItem] # 3 项
recharge_items: list[CashflowItem] # 1 项
total: float
total_compare: str | None = None
total_down: bool | None = None
total_flat: bool | None = None
class ExpenseItem(CamelModel):
@@ -437,6 +471,6 @@ class FinanceBoardResponse(CamelModel):
overview: OverviewPanel
recharge: RechargePanel | None # area≠all 时为 null
revenue: RevenuePanel
cashflow: CashflowPanel
expense: ExpensePanel
cashflow: CashflowPanel | None # area≠all 时为 null
expense: ExpensePanel | None # area≠all 时为 null
coach_analysis: CoachAnalysisPanel

View File

@@ -104,3 +104,5 @@ class ChatStreamRequest(CamelModel):
chat_id: int
content: str
source_page: str | None = None
page_context: dict | None = None

View File

@@ -53,8 +53,9 @@ class TopCustomer(CamelModel):
score: str
score_color: str
service_count: int
balance: str
consume: str
# CHANGE 2026-03-29 | str → float后端返回原始数字前端 WXS 格式化(避免 NaN
balance: float
consume: float
class CoachServiceRecord(CamelModel):
@@ -66,7 +67,8 @@ class CoachServiceRecord(CamelModel):
type_class: str
table: str | None = None
duration: str
income: str
# CHANGE 2026-03-29 | str → float后端返回原始数字前端 WXS 格式化(避免 NaN
income: float
date: str
perf_hours: str | None = None
@@ -74,9 +76,10 @@ class CoachServiceRecord(CamelModel):
class HistoryMonth(CamelModel):
month: str
estimated: bool
customers: str
hours: str
salary: str
# CHANGE 2026-03-29 | str → int/float后端返回原始数字前端 WXS 格式化(避免 NaN
customers: int
hours: float
salary: float
callback_done: int
recall_done: int

View File

@@ -24,6 +24,7 @@ class CoachTask(CamelModel):
name: str
level: str # star / senior / middle / junior
level_color: str
heart_score: float = 0.0 # CHANGE 2026-03-29 | RSI 关系指数,用于爱心标识
task_type: str
task_color: str
bg_class: str
@@ -32,10 +33,10 @@ class CoachTask(CamelModel):
metrics: list[MetricItem] = []
class FavoriteCoach(CamelModel):
# CHANGE 2026-03-20 | M4 修复: emoji 注释与 P6 权威定义对齐4 级映射)
# intent: 注释应反映 compute_heart_icon() 的实际 4 级映射(💖🧡💛💙)
emoji: str # 💖 / 🧡 / 💛 / 💙
emoji: str
name: str
heart_score: float = 0.0
level: str = ""
relation_index: str
index_color: str
bg_class: str
@@ -46,7 +47,7 @@ class CoachServiceItem(CamelModel):
level: str
level_color: str
course_type: str # "基础课" / "激励课"
hours: float
hours: str # "2.5h" 格式
perf_hours: float | None = None
fee: float
@@ -57,15 +58,15 @@ class ConsumptionRecord(CamelModel):
table_name: str | None = None
start_time: str | None = None
end_time: str | None = None
duration: int | None = None
duration: str | None = None
table_fee: float | None = None
table_orig_price: float | None = None
coaches: list[CoachServiceItem] = []
food_amount: float | None = None
food_orig_price: float | None = None
total_amount: float
total_orig_price: float
pay_method: str
total_orig_price: float | None = None
pay_method: str | None = None
recharge_amount: float | None = None
class RetentionClue(CamelModel):
@@ -126,5 +127,25 @@ class CustomerRecordsResponse(CamelModel):
total_service_count: int
month_count: int
month_hours: float
month_income: float = 0.0
records: list[ServiceRecordItem] = []
has_more: bool = False
class CustomerConsumptionRecordsResponse(CamelModel):
"""CUST-3 响应:客户消费记录(按月)。"""
# Banner
id: int
name: str
phone: str
phone_full: str
balance: float | None = None
consumption_60d: float | None = None
ideal_interval: int | None = None
days_since_visit: int | None = None
# 月度汇总
visit_count: int = 0
consume_total: float = 0.0
recharge_total: float = 0.0
# 消费记录
records: list[ConsumptionRecord] = []

View File

@@ -12,7 +12,7 @@ from app.schemas.base import CamelModel
class NoteCreateRequest(CamelModel):
"""创建备注请求(含手动评分:再次服务意愿 + 再来店可能性,各 1-5"""
"""创建备注请求(含手动评分:再次服务意愿 + 再来店可能性,各 1-5备注星星评分 1-5"""
target_type: str = Field(default="member")
target_id: int
@@ -20,6 +20,7 @@ class NoteCreateRequest(CamelModel):
task_id: int | None = None
rating_service_willingness: int | None = Field(None, ge=1, le=5, description="再次服务意愿1-5")
rating_revisit_likelihood: int | None = Field(None, ge=1, le=5, description="再来店可能性1-5")
score: int | None = Field(None, ge=1, le=5, description="备注星星评分1-5")
class NoteOut(CamelModel):
@@ -30,6 +31,7 @@ class NoteOut(CamelModel):
content: str
rating_service_willingness: int | None
rating_revisit_likelihood: int | None
score: int | None
ai_score: int | None
ai_analysis: str | None
task_id: int | None

View File

@@ -18,12 +18,12 @@ class DateGroupRecord(CamelModel):
"""按日期分组的单条服务记录。"""
customer_name: str
member_id: int | None = None # 前端用于计算头像颜色
avatar_char: str | None = None # PERF-1 返回PERF-2 不返回
avatar_color: str | None = None # PERF-1 返回PERF-2 不返回
heart_score: float | None = None # RS 分数,前端用于 heart-icon 组件
time_range: str
hours: str
course_type: str
course_type_class: str # 'basic' | 'vip' | 'tip'
location: str
income: str
@@ -62,8 +62,9 @@ class CustomerSummary(CamelModel):
"""客户摘要(新客/常客基类)。"""
name: str
member_id: int | None = None # 前端用于计算头像颜色
avatar_char: str
avatar_color: str
heart_score: float | None = None # RS 分数
class NewCustomer(CustomerSummary):

View File

@@ -80,6 +80,11 @@ class TaskItem(CamelModel):
last_visit_days: int | None = None
balance: float | None = None
ai_suggestion: str | None = None
expected_days: int | None = None
ideal_interval_days: int | None = None
# CHANGE 2026-03-27 | 近60天服务汇总口径同 task-detail serviceSummary
recent60d_hours: float = 0.0
recent60d_income: float = 0.0
class TaskListResponse(CamelModel):
@@ -150,6 +155,7 @@ class TaskDetailResponse(CamelModel):
# 基础信息
id: int
customer_name: str
customer_phone: str | None = None
customer_avatar: str
task_type: str
task_type_label: str
@@ -160,6 +166,7 @@ class TaskDetailResponse(CamelModel):
has_note: bool
status: str
customer_id: int
balance: float | None = None
# 扩展模块
retention_clues: list[RetentionClue]
talking_points: list[str]

View File

@@ -0,0 +1 @@
# AI 监控后台服务层

View File

@@ -0,0 +1,721 @@
"""AI 监控后台聚合服务层。
提供 Dashboard 总览、调度任务管理、调用记录查询、缓存失效、
Token 预算、批量执行(含成本二次确认)、告警管理等功能。
所有数据库操作使用 psycopg2 同步连接,方法签名为 asyncFastAPI 兼容)。
查询强制 site_id 隔离(当 site_id 参数不为 None 时)。
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone, timedelta
from app.ai.budget_tracker import BudgetTracker
from app.database import get_connection
logger = logging.getLogger(__name__)
# 批量执行预估:每次调用平均 Token 消耗
AVG_TOKENS_PER_CALL = 2000
# 批量执行内存存储 TTL
_BATCH_TTL_SECONDS = 600 # 10 分钟
class AdminAIService:
"""AI 监控后台聚合服务。"""
def __init__(self, budget_tracker: BudgetTracker | None = None) -> None:
self._budget = budget_tracker
self._batch_store: dict[str, dict] = {} # batch_id → {params, expires_at}
# ── Dashboard ─────────────────────────────────────────
async def get_dashboard(self, site_id: int | None = None) -> dict:
"""聚合所有 Dashboard 数据。"""
today_stats = await self._get_today_stats(site_id)
trend_7d = await self._get_7d_trend(site_id)
app_dist = await self._get_app_distribution(site_id)
app_health = await self._get_app_health(site_id)
budget = await self.get_budget()
recent_alerts = await self._get_recent_alerts(site_id)
return {
**today_stats,
"trend_7d": trend_7d,
"app_distribution": app_dist,
"budget": budget,
"recent_alerts": recent_alerts,
"app_health": app_health,
}
async def _get_today_stats(self, site_id: int | None) -> dict:
"""今日调用次数、成功率、Token 消耗、平均延迟。"""
site_clause, params = _site_filter(site_id)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
SELECT
COUNT(*) AS total_calls,
COUNT(*) FILTER (WHERE status = 'success') AS success_count,
COALESCE(SUM(tokens_used), 0) AS total_tokens,
COALESCE(AVG(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0)
AS avg_latency
FROM biz.ai_run_logs
WHERE created_at >= CURRENT_DATE
AND created_at < CURRENT_DATE + INTERVAL '1 day'
{site_clause}
""",
params,
)
row = cur.fetchone()
conn.commit()
finally:
conn.close()
total, success, tokens, avg_lat = row if row else (0, 0, 0, 0)
rate = round(success / total, 4) if total > 0 else 0.0
return {
"today_calls": total,
"today_success_rate": rate,
"today_tokens": int(tokens),
"today_avg_latency_ms": round(float(avg_lat), 2),
}
async def _get_7d_trend(self, site_id: int | None) -> list[dict]:
"""近 7 天按日聚合。"""
site_clause, params = _site_filter(site_id)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
SELECT
created_at::date AS day,
COUNT(*) AS calls,
COUNT(*) FILTER (WHERE status = 'success') AS success_count
FROM biz.ai_run_logs
WHERE created_at >= CURRENT_DATE - INTERVAL '6 days'
{site_clause}
GROUP BY day
ORDER BY day
""",
params,
)
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
return [
{
"date": row[0].isoformat(),
"calls": row[1],
"success_rate": round(row[2] / row[1], 4) if row[1] > 0 else 0.0,
}
for row in rows
]
async def _get_app_distribution(self, site_id: int | None) -> list[dict]:
"""各 App 调用占比。"""
site_clause, params = _site_filter(site_id)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
SELECT app_type, COUNT(*) AS cnt
FROM biz.ai_run_logs
WHERE created_at >= CURRENT_DATE - INTERVAL '6 days'
{site_clause}
GROUP BY app_type
ORDER BY cnt DESC
""",
params,
)
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
total = sum(r[1] for r in rows) or 1
return [
{
"app_type": row[0],
"count": row[1],
"percentage": round(row[1] / total, 4),
}
for row in rows
]
async def _get_app_health(self, site_id: int | None) -> list[dict]:
"""各 App 最近一次调用状态。"""
site_clause, params = _site_filter(site_id)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
SELECT DISTINCT ON (app_type)
app_type,
status AS last_status,
created_at AS last_call_at
FROM biz.ai_run_logs
WHERE TRUE {site_clause}
ORDER BY app_type, created_at DESC
""",
params,
)
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
return [
{
"app_type": row[0],
"last_status": row[1],
"last_call_at": row[2].isoformat() if row[2] else None,
}
for row in rows
]
async def _get_recent_alerts(self, site_id: int | None, limit: int = 10) -> list[dict]:
"""最近告警事件Dashboard 用)。"""
site_clause, params = _site_filter(site_id)
params = (*params, limit)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
SELECT id, app_type, status, alert_status,
error_message, created_at
FROM biz.ai_run_logs
WHERE status IN ('failed', 'timeout', 'circuit_open')
{site_clause}
ORDER BY created_at DESC
LIMIT %s
""",
params,
)
cols = [d[0] for d in cur.description]
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
return [_row_to_dict(cols, r) for r in rows]
# ── 调度任务 ──────────────────────────────────────────
async def list_trigger_jobs(
self, filters: dict, page: int = 1, page_size: int = 20,
) -> dict:
"""分页查询 ai_trigger_jobs + 今日去重统计。"""
where_parts: list[str] = []
params: list = []
for key in ("event_type", "status", "site_id"):
if filters.get(key) is not None:
where_parts.append(f"{key} = %s")
params.append(filters[key])
if filters.get("date_from"):
where_parts.append("created_at >= %s")
params.append(filters["date_from"])
if filters.get("date_to"):
where_parts.append("created_at <= %s")
params.append(filters["date_to"])
where_sql = ("WHERE " + " AND ".join(where_parts)) if where_parts else ""
offset = (page - 1) * page_size
conn = get_connection()
try:
with conn.cursor() as cur:
# 总数
cur.execute(
f"SELECT COUNT(*) FROM biz.ai_trigger_jobs {where_sql}",
params,
)
total = cur.fetchone()[0]
# 分页数据
cur.execute(
f"""
SELECT id, event_type, member_id, status, app_chain,
is_forced, site_id, started_at, finished_at, created_at
FROM biz.ai_trigger_jobs
{where_sql}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(*params, page_size, offset),
)
cols = [d[0] for d in cur.description]
rows = cur.fetchall()
# 今日去重跳过数
cur.execute(
"""
SELECT COUNT(*)
FROM biz.ai_trigger_jobs
WHERE status = 'skipped_duplicate'
AND created_at >= CURRENT_DATE
AND created_at < CURRENT_DATE + INTERVAL '1 day'
""",
)
today_skipped = cur.fetchone()[0]
conn.commit()
finally:
conn.close()
return {
"items": [_row_to_dict(cols, r) for r in rows],
"total": total,
"page": page,
"page_size": page_size,
"today_skipped_duplicates": today_skipped,
}
async def get_trigger_job(self, job_id: int) -> dict | None:
"""单条调度任务详情。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, event_type, member_id, status, app_chain,
is_forced, site_id, started_at, finished_at,
created_at, payload, error_message, connector_type
FROM biz.ai_trigger_jobs
WHERE id = %s
""",
(job_id,),
)
cols = [d[0] for d in cur.description]
row = cur.fetchone()
conn.commit()
finally:
conn.close()
if row is None:
return None
return _row_to_dict(cols, row)
async def retry_trigger_job(self, job_id: int) -> int:
"""创建新 trigger_jobis_forced=true返回新 job_id。"""
original = await self.get_trigger_job(job_id)
if original is None:
raise ValueError(f"trigger_job {job_id} 不存在")
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
INSERT INTO biz.ai_trigger_jobs
(event_type, member_id, site_id, connector_type,
payload, app_chain, is_forced, status)
VALUES (%s, %s, %s, %s, %s, %s, true, 'pending')
RETURNING id
""",
(
original["event_type"],
original.get("member_id"),
original["site_id"],
original.get("connector_type", "feiqiu"),
original.get("payload"),
original.get("app_chain"),
),
)
new_id = cur.fetchone()[0]
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
return new_id
# ── 调用记录 ──────────────────────────────────────────
async def list_run_logs(
self, filters: dict, page: int = 1, page_size: int = 20,
) -> dict:
"""分页查询 ai_run_logs。"""
where_parts: list[str] = []
params: list = []
for key in ("app_type", "status", "trigger_type", "site_id"):
if filters.get(key) is not None:
where_parts.append(f"{key} = %s")
params.append(filters[key])
if filters.get("date_from"):
where_parts.append("created_at >= %s")
params.append(filters["date_from"])
if filters.get("date_to"):
where_parts.append("created_at <= %s")
params.append(filters["date_to"])
where_sql = ("WHERE " + " AND ".join(where_parts)) if where_parts else ""
offset = (page - 1) * page_size
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"SELECT COUNT(*) FROM biz.ai_run_logs {where_sql}",
params,
)
total = cur.fetchone()[0]
cur.execute(
f"""
SELECT id, app_type, trigger_type, member_id,
tokens_used, latency_ms, status, site_id, created_at
FROM biz.ai_run_logs
{where_sql}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(*params, page_size, offset),
)
cols = [d[0] for d in cur.description]
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
return {
"items": [_row_to_dict(cols, r) for r in rows],
"total": total,
"page": page,
"page_size": page_size,
}
async def get_run_log(self, log_id: int) -> dict | None:
"""单条调用记录详情(含完整 prompt/response不脱敏"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, app_type, trigger_type, member_id,
tokens_used, latency_ms, status, site_id,
created_at, request_prompt, response_text,
error_message, session_id, finished_at
FROM biz.ai_run_logs
WHERE id = %s
""",
(log_id,),
)
cols = [d[0] for d in cur.description]
row = cur.fetchone()
conn.commit()
finally:
conn.close()
if row is None:
return None
return _row_to_dict(cols, row)
# ── 缓存管理 ──────────────────────────────────────────
async def invalidate_cache(
self, site_id: int, app_type: str | None = None, member_id: int | None = None,
) -> int:
"""批量缓存失效,返回受影响记录数。"""
where_parts = ["site_id = %s"]
params: list = [site_id]
if app_type is not None:
where_parts.append("cache_type = %s")
params.append(app_type)
if member_id is not None:
where_parts.append("target_id = %s")
params.append(str(member_id))
where_sql = " AND ".join(where_parts)
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"""
UPDATE biz.ai_cache
SET status = 'invalidated'
WHERE {where_sql}
AND status != 'invalidated'
""",
params,
)
affected = cur.rowcount
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
return affected
# ── Token 预算 ────────────────────────────────────────
async def get_budget(self) -> dict:
"""Token 预算使用情况。"""
if self._budget is not None:
status = self._budget.check_budget()
daily_limit = self._budget.daily_limit
monthly_limit = self._budget.monthly_limit
return {
"daily_used": status.daily_used,
"daily_limit": daily_limit,
"daily_pct": round(status.daily_used / daily_limit, 4) if daily_limit > 0 else 0.0,
"monthly_used": status.monthly_used,
"monthly_limit": monthly_limit,
"monthly_pct": round(status.monthly_used / monthly_limit, 4) if monthly_limit > 0 else 0.0,
}
# 无 BudgetTracker 时直接查询
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT
COALESCE(SUM(tokens_used) FILTER (
WHERE created_at >= CURRENT_DATE
AND created_at < CURRENT_DATE + INTERVAL '1 day'
), 0) AS daily_used,
COALESCE(SUM(tokens_used) FILTER (
WHERE created_at >= date_trunc('month', CURRENT_DATE)
AND created_at < date_trunc('month', CURRENT_DATE) + INTERVAL '1 month'
), 0) AS monthly_used
FROM biz.ai_run_logs
WHERE status = 'success'
""",
)
row = cur.fetchone()
conn.commit()
finally:
conn.close()
daily_used, monthly_used = (int(row[0]), int(row[1])) if row else (0, 0)
daily_limit = 100_000
monthly_limit = 2_000_000
return {
"daily_used": daily_used,
"daily_limit": daily_limit,
"daily_pct": round(daily_used / daily_limit, 4) if daily_limit > 0 else 0.0,
"monthly_used": monthly_used,
"monthly_limit": monthly_limit,
"monthly_pct": round(monthly_used / monthly_limit, 4) if monthly_limit > 0 else 0.0,
}
# ── 批量执行 ──────────────────────────────────────────
async def estimate_batch(
self, app_types: list[str], member_ids: list[int], site_id: int,
) -> dict:
"""生成 batch_id存入内存TTL 10min返回预估。"""
self._cleanup_expired_batches()
batch_id = uuid.uuid4().hex
estimated_calls = len(app_types) * len(member_ids)
estimated_tokens = estimated_calls * AVG_TOKENS_PER_CALL
self._batch_store[batch_id] = {
"params": {
"app_types": app_types,
"member_ids": member_ids,
"site_id": site_id,
},
"expires_at": datetime.now(timezone.utc) + timedelta(seconds=_BATCH_TTL_SECONDS),
}
return {
"batch_id": batch_id,
"estimated_calls": estimated_calls,
"estimated_tokens": estimated_tokens,
}
async def confirm_batch(self, batch_id: str) -> None:
"""取出参数,异步执行批量调用。"""
self._cleanup_expired_batches()
entry = self._batch_store.pop(batch_id, None)
if entry is None:
raise ValueError(f"batch_id 无效或已过期: {batch_id}")
params = entry["params"]
logger.info(
"批量执行确认: batch_id=%s apps=%s members=%d site_id=%s",
batch_id,
params["app_types"],
len(params["member_ids"]),
params["site_id"],
)
# 后台异步执行(具体调用链由路由层注入 dispatcher 处理)
asyncio.create_task(
self._run_batch(params["app_types"], params["member_ids"], params["site_id"])
)
async def _run_batch(
self, app_types: list[str], member_ids: list[int], site_id: int,
) -> None:
"""后台批量执行(占位实现,实际由 dispatcher 驱动)。"""
logger.info(
"批量执行开始: apps=%s members=%d site_id=%s",
app_types, len(member_ids), site_id,
)
# 实际执行逻辑在路由层通过 dispatcher.handle_trigger 驱动
# 此处仅记录日志,避免服务层直接依赖 dispatcher 实例
def _cleanup_expired_batches(self) -> None:
"""清理过期 batch。"""
now = datetime.now(timezone.utc)
expired = [
bid for bid, entry in self._batch_store.items()
if entry["expires_at"] <= now
]
for bid in expired:
del self._batch_store[bid]
if expired:
logger.debug("清理过期 batch: %d", len(expired))
# ── 告警管理 ──────────────────────────────────────────
async def list_alerts(
self,
alert_status: str | None = None,
site_id: int | None = None,
page: int = 1,
page_size: int = 20,
) -> dict:
"""告警列表ai_run_logs WHERE status IN ('failed','timeout','circuit_open')。"""
where_parts = ["status IN ('failed', 'timeout', 'circuit_open')"]
params: list = []
if alert_status is not None:
if alert_status == "pending":
# pending 包含 NULL 和 'pending'
where_parts.append("(alert_status IS NULL OR alert_status = 'pending')")
else:
where_parts.append("alert_status = %s")
params.append(alert_status)
if site_id is not None:
where_parts.append("site_id = %s")
params.append(site_id)
where_sql = "WHERE " + " AND ".join(where_parts)
offset = (page - 1) * page_size
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
f"SELECT COUNT(*) FROM biz.ai_run_logs {where_sql}",
params,
)
total = cur.fetchone()[0]
cur.execute(
f"""
SELECT id, app_type, status, alert_status,
error_message, created_at
FROM biz.ai_run_logs
{where_sql}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
""",
(*params, page_size, offset),
)
cols = [d[0] for d in cur.description]
rows = cur.fetchall()
conn.commit()
finally:
conn.close()
return {
"items": [_row_to_dict(cols, r) for r in rows],
"total": total,
"page": page,
"page_size": page_size,
}
async def ack_alert(self, log_id: int) -> str:
"""确认告警alert_status → acknowledged。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_run_logs
SET alert_status = 'acknowledged'
WHERE id = %s
AND status IN ('failed', 'timeout', 'circuit_open')
""",
(log_id,),
)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
return "acknowledged"
async def ignore_alert(self, log_id: int) -> str:
"""忽略告警alert_status → ignored。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE biz.ai_run_logs
SET alert_status = 'ignored'
WHERE id = %s
AND status IN ('failed', 'timeout', 'circuit_open')
""",
(log_id,),
)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
return "ignored"
# ── 工具函数 ──────────────────────────────────────────────
def _site_filter(site_id: int | None) -> tuple[str, tuple]:
"""生成 site_id 过滤子句和参数。"""
if site_id is None:
return "", ()
return "AND site_id = %s", (site_id,)
def _row_to_dict(columns: list[str], row: tuple) -> dict:
"""将数据库行转换为 dict处理 datetime 序列化。"""
result = {}
for col, val in zip(columns, row):
if isinstance(val, datetime):
result[col] = val.isoformat()
else:
result[col] = val
return result

View File

@@ -0,0 +1,188 @@
# -*- coding: utf-8 -*-
"""
AI 数据清理服务。
由定时任务每日凌晨 03:00 调用,执行三步清理:
1. 删除 90 天前的 ai_run_logs
2. 删除 90 天前的 ai_trigger_jobs
3. 每个 App 类型App2~App8的 ai_cache 保留最新 20,000 条
永久保留 App1 对话记录ai_conversations + ai_messages不清理。
需求: E1.1, E1.2, E1.3, E1.4, E2.1, E2.2, E2.3
"""
from __future__ import annotations
import asyncio
import logging
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
class AICleanupService:
"""AI 数据清理服务,由定时任务调用。"""
RETENTION_DAYS = 90
CACHE_LIMIT_PER_APP = 20_000
CACHE_APP_TYPES = [
"app2_finance",
"app3_clue",
"app4_analysis",
"app5_tactics",
"app6_note_analysis",
"app7_customer_analysis",
"app8_clue_consolidated",
]
async def run_cleanup(self) -> dict:
"""执行全部清理,返回各步骤删除记录数。
单步清理失败记录错误日志,继续执行后续步骤。
"""
result: dict = {}
# 步骤 1清理 ai_run_logs
try:
result["run_logs_deleted"] = await self._cleanup_run_logs()
except Exception:
logger.exception("清理 ai_run_logs 失败")
result["run_logs_deleted"] = -1
# 步骤 2清理 ai_trigger_jobs
try:
result["trigger_jobs_deleted"] = await self._cleanup_trigger_jobs()
except Exception:
logger.exception("清理 ai_trigger_jobs 失败")
result["trigger_jobs_deleted"] = -1
# 步骤 3清理 ai_cache每个 App 类型)
try:
result["cache_deleted"] = await self._cleanup_cache()
except Exception:
logger.exception("清理 ai_cache 失败")
result["cache_deleted"] = {}
logger.info("AI 数据清理完成: %s", result)
return result
async def _cleanup_run_logs(self) -> int:
"""DELETE FROM ai_run_logs WHERE created_at < now() - 90 days。"""
from app.database import get_connection
conn = get_connection()
try:
with conn.cursor() as cur:
# 防止锁等待超时5 分钟)
cur.execute("SET statement_timeout = 300000")
cur.execute(
"""
DELETE FROM biz.ai_run_logs
WHERE created_at < NOW() - INTERVAL '%s days'
""",
(self.RETENTION_DAYS,),
)
deleted = cur.rowcount
conn.commit()
logger.info("清理 ai_run_logs: 删除 %d", deleted)
return deleted
except Exception:
conn.rollback()
raise
finally:
conn.close()
async def _cleanup_trigger_jobs(self) -> int:
"""DELETE FROM ai_trigger_jobs WHERE created_at < now() - 90 days。"""
from app.database import get_connection
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute("SET statement_timeout = 300000")
cur.execute(
"""
DELETE FROM biz.ai_trigger_jobs
WHERE created_at < NOW() - INTERVAL '%s days'
""",
(self.RETENTION_DAYS,),
)
deleted = cur.rowcount
conn.commit()
logger.info("清理 ai_trigger_jobs: 删除 %d", deleted)
return deleted
except Exception:
conn.rollback()
raise
finally:
conn.close()
async def _cleanup_cache(self) -> dict[str, int]:
"""每个 App 类型保留最新 20,000 条,删除超出部分。"""
from app.database import get_connection
result: dict[str, int] = {}
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute("SET statement_timeout = 300000")
for app_type in self.CACHE_APP_TYPES:
try:
# 子查询:找到该 app_type 第 20001 条的 created_at 作为截断点
cur.execute(
"""
DELETE FROM biz.ai_cache
WHERE app_type = %s
AND id NOT IN (
SELECT id FROM biz.ai_cache
WHERE app_type = %s
ORDER BY created_at DESC
LIMIT %s
)
""",
(app_type, app_type, self.CACHE_LIMIT_PER_APP),
)
deleted = cur.rowcount
result[app_type] = deleted
if deleted > 0:
logger.info(
"清理 ai_cache [%s]: 删除 %d",
app_type,
deleted,
)
except Exception:
logger.exception("清理 ai_cache [%s] 失败", app_type)
result[app_type] = -1
conn.rollback()
# 重新开始事务以继续后续 app_type
continue
conn.commit()
return result
except Exception:
conn.rollback()
raise
finally:
conn.close()
@trace_service(description_zh="register_cleanup_job", description_en="Register Cleanup Job")
def register_cleanup_job(scheduler) -> None: # noqa: ANN001
"""注册清理定时任务到调度器。每日 03:00 执行。
在 main.py lifespan 中调用,或通过 scheduled_tasks 表注册。
实际调度由 trigger_scheduler 的 cron 机制驱动:
- job_type: 'ai_data_cleanup'
- trigger_condition: 'cron'
- trigger_config: {"cron_expression": "0 3 * * *"}
需求: E2.1, E2.2, E2.3
"""
from app.services.trigger_scheduler import register_job
def _run_cleanup(**_kw):
"""同步包装器:在新事件循环中执行异步清理。"""
result = asyncio.run(AICleanupService().run_cleanup())
logger.info("定时清理任务完成: %s", result)
register_job("ai_data_cleanup", _run_cleanup)

View File

@@ -18,10 +18,12 @@ import logging
from fastapi import HTTPException, status
from app.database import get_connection
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@trace_service(description_zh="创建入驻申请", description_en="Create application")
async def create_application(
user_id: int,
site_code: str,
@@ -60,11 +62,21 @@ async def create_application(
detail="已有待审核的申请,请等待审核完成",
)
# 2. 查找 site_code → site_id 映射
# 2. 查找 site_code → site_id 映射(优先当前活跃编码,再查历史编码)
# CHANGE 2026-03-23 | 大小写不敏感匹配 site_codeUPPER
site_id = None
site_code_upper = site_code.upper()
cur.execute(
"SELECT site_id FROM auth.site_code_mapping WHERE site_code = %s",
(site_code,),
"""
SELECT site_id FROM biz.sites
WHERE UPPER(site_code) = %s AND is_active = true
UNION ALL
SELECT s.site_id FROM biz.site_code_history h
JOIN biz.sites s ON s.site_id = h.site_id
WHERE UPPER(h.site_code) = %s AND h.is_current = false AND s.is_active = true
LIMIT 1
""",
(site_code_upper, site_code_upper),
)
mapping_row = cur.fetchone()
if mapping_row is not None:
@@ -123,6 +135,7 @@ async def create_application(
@trace_service(description_zh="审批通过申请", description_en="Approve application")
async def approve_application(
application_id: int,
reviewer_id: int,
@@ -248,6 +261,7 @@ async def approve_application(
}
@trace_service(description_zh="驳回申请", description_en="Reject application")
async def reject_application(
application_id: int,
reviewer_id: int,
@@ -260,16 +274,18 @@ async def reject_application(
2. 检查申请状态为 pending否则 409
3. 更新 user_applications.status = 'rejected'
4. 记录 reviewer_id、review_note、reviewed_at
5. 累加 users.rejection_count达到 3 次自动禁用
返回:
更新后的申请记录 dict
更新后的申请记录 dict(含 user_disabled 标记)
"""
conn = get_connection()
user_disabled = False
try:
with conn.cursor() as cur:
# 1. 查询申请记录
# 1. 查询申请记录(含 user_id
cur.execute(
"SELECT id, status FROM auth.user_applications WHERE id = %s",
"SELECT id, user_id, status FROM auth.user_applications WHERE id = %s",
(application_id,),
)
app_row = cur.fetchone()
@@ -279,11 +295,13 @@ async def reject_application(
detail="申请不存在",
)
_, app_user_id, app_status = app_row
# 2. 检查状态为 pending
if app_row[1] != "pending":
if app_status != "pending":
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"申请当前状态为 {app_row[1]},无法审核",
detail=f"申请当前状态为 {app_status},无法审核",
)
# 3. 更新申请状态为 rejected
@@ -301,6 +319,46 @@ async def reject_application(
(reviewer_id, review_note, application_id),
)
updated_row = cur.fetchone()
# 4. 累加 rejection_count 并检查是否达到禁用阈值
cur.execute(
"""
UPDATE auth.users
SET rejection_count = rejection_count + 1,
updated_at = NOW()
WHERE id = %s
RETURNING rejection_count
""",
(app_user_id,),
)
new_count = cur.fetchone()[0]
if new_count >= 3:
# 第三次拒绝:自动禁用账号
cur.execute(
"""
UPDATE auth.users
SET status = 'disabled', updated_at = NOW()
WHERE id = %s
""",
(app_user_id,),
)
user_disabled = True
logger.warning(
"用户 %s 累计被拒绝 %d 次,已自动禁用",
app_user_id, new_count,
)
else:
# 未达阈值:回退用户状态为 rejected允许重新申请
cur.execute(
"""
UPDATE auth.users
SET status = 'rejected', updated_at = NOW()
WHERE id = %s
""",
(app_user_id,),
)
conn.commit()
finally:
conn.close()
@@ -313,9 +371,11 @@ async def reject_application(
"review_note": updated_row[4],
"created_at": updated_row[5],
"reviewed_at": updated_row[6],
"user_disabled": user_disabled,
}
@trace_service(description_zh="获取用户申请列表", description_en="Get user applications")
async def get_user_applications(user_id: int) -> list[dict]:
"""
查询用户的所有申请记录。
@@ -323,14 +383,15 @@ async def get_user_applications(user_id: int) -> list[dict]:
按创建时间倒序排列。
返回:
申请记录 dict 列表
申请记录 dict 列表(含 phone、employee_number
"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, site_code, applied_role_text, status,
SELECT id, site_code, applied_role_text, phone,
employee_number, status,
review_note, created_at::text, reviewed_at::text
FROM auth.user_applications
WHERE user_id = %s
@@ -347,10 +408,84 @@ async def get_user_applications(user_id: int) -> list[dict]:
"id": r[0],
"site_code": r[1],
"applied_role_text": r[2],
"status": r[3],
"review_note": r[4],
"created_at": r[5],
"reviewed_at": r[6],
"phone": r[3],
"employee_number": r[4],
"status": r[5],
"review_note": r[6],
"created_at": r[7],
"reviewed_at": r[8],
}
for r in rows
]
@trace_service(description_zh="取消申请", description_en="Cancel application")
async def cancel_application(user_id: int) -> dict:
"""
用户主动取消当前 pending 申请。
1. 查找用户的 pending 申请(无则 404
2. 更新申请 status = 'cancelled'
3. 回退用户 status 为 'new'
返回:
被取消的申请记录 dict
"""
conn = get_connection()
try:
with conn.cursor() as cur:
# 1. 查找 pending 申请
cur.execute(
"""
SELECT id FROM auth.user_applications
WHERE user_id = %s AND status = 'pending'
ORDER BY created_at DESC
LIMIT 1
""",
(user_id,),
)
row = cur.fetchone()
if row is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="没有待审核的申请",
)
application_id = row[0]
# 2. 更新申请状态为 cancelled
cur.execute(
"""
UPDATE auth.user_applications
SET status = 'cancelled'
WHERE id = %s
RETURNING id, site_code, applied_role_text, phone,
employee_number, status, created_at::text
""",
(application_id,),
)
updated_row = cur.fetchone()
# 3. 回退用户状态为 new
cur.execute(
"""
UPDATE auth.users
SET status = 'new', updated_at = NOW()
WHERE id = %s
""",
(user_id,),
)
conn.commit()
finally:
conn.close()
return {
"id": updated_row[0],
"site_code": updated_row[1],
"applied_role_text": updated_row[2],
"phone": updated_row[3],
"employee_number": updated_row[4],
"status": updated_row[5],
"created_at": updated_row[6],
}

View File

@@ -15,6 +15,8 @@ import calendar
from datetime import date, timedelta
from decimal import Decimal, ROUND_HALF_UP
from app.trace.decorators import trace_service
# ---------------------------------------------------------------------------
# 通用工具函数
@@ -34,10 +36,10 @@ def _calc_date_range(
"""
today = ref_date or date.today()
# --- 当月 ---
# --- 当月cap 到今天)---
if time_enum == "month":
start = today.replace(day=1)
end = today.replace(day=calendar.monthrange(today.year, today.month)[1])
end = today
return start, end
# --- 上月 ---
@@ -47,11 +49,11 @@ def _calc_date_range(
last_month_start = last_month_end.replace(day=1)
return last_month_start, last_month_end
# --- 本周(周一 ~ 周日---
# --- 本周(周一 ~ 今天---
if time_enum == "week":
monday = today - timedelta(days=today.weekday())
sunday = monday + timedelta(days=6)
return monday, sunday
end = today
return monday, end
# --- 上周 ---
if time_enum == "lastWeek":
@@ -60,12 +62,11 @@ def _calc_date_range(
last_monday = last_sunday - timedelta(days=6)
return last_monday, last_sunday
# --- 本季度 ---
# --- 本季度cap 到今天)---
if time_enum == "quarter":
q_start_month = (today.month - 1) // 3 * 3 + 1
start = date(today.year, q_start_month, 1)
q_end_month = q_start_month + 2
end = date(today.year, q_end_month, calendar.monthrange(today.year, q_end_month)[1])
end = today
return start, end
# --- 上季度 ---
@@ -106,18 +107,57 @@ def _month_offset(d: date, months: int) -> date:
return date(y, m, 1)
def _calc_prev_range(start_date: date, end_date: date) -> tuple[date, date]:
def _calc_prev_range(
time_enum: str, start_date: date, end_date: date
) -> tuple[date, date]:
"""
根据当期范围计算上期日期范围。
根据当期范围和周期类型计算上期同期日期范围。
上期长度等于当期长度prev_end = start_date - 1 天。
CHANGE 2026-03-28 | 环比改为同期对比:
- month: 当期 3/1~3/28 → 上期 2/1~2/28上月同日
- week: 当期 周一~周四 → 上期 上周一~上周四(上周同天数)
- quarter: 当期 1/1~3/28 → 上期 去年10/1~10/28 对应天数
- lastMonth: 2/1~2/28 → 1/1~1/28再上月同天数
- lastWeek: 上周一~上周日 → 再上周一~再上周日
- lastQuarter: 上季度完整 → 再上季度完整
- quarter3/half6: 往前推等长天数(无明确"同期"概念)
"""
period_length = (end_date - start_date).days + 1
elapsed_days = (end_date - start_date).days # 当期已过天数0-indexed
# 月度类上月1日 + 同样天数
if time_enum in ("month", "last_month", "lastMonth"):
prev_start = _month_offset(start_date, -1)
# 上月同日,但不超过上月末日
prev_end_day = min(end_date.day, calendar.monthrange(prev_start.year, prev_start.month)[1])
prev_end = prev_start.replace(day=prev_end_day)
return prev_start, prev_end
# 周度类:往前推 7 天
if time_enum in ("week", "lastWeek"):
prev_start = start_date - timedelta(days=7)
prev_end = end_date - timedelta(days=7)
return prev_start, prev_end
# 季度类:上季度首日 + 同样天数
if time_enum in ("quarter", "last_quarter", "lastQuarter"):
prev_q_start = _month_offset(start_date, -3)
prev_end = prev_q_start + timedelta(days=elapsed_days)
# 不超过上季度末日
prev_q_end_month = prev_q_start.month + 2
prev_q_end_max = date(prev_q_start.year, prev_q_end_month,
calendar.monthrange(prev_q_start.year, prev_q_end_month)[1])
if prev_end > prev_q_end_max:
prev_end = prev_q_end_max
return prev_q_start, prev_end
# 其他quarter3/half6往前推等长天数
period_length = elapsed_days + 1
prev_end = start_date - timedelta(days=1)
prev_start = prev_end - timedelta(days=period_length - 1)
return prev_start, prev_end
@trace_service(description_zh="计算对比数据", description_en="Calc Compare")
def calc_compare(current: Decimal, previous: Decimal) -> dict:
"""
统一环比计算。
@@ -178,6 +218,20 @@ _SORT_KEY_MAP = {
"task_desc": ("task_total", True),
}
# 项目标签 category_code → 前端显示文本 / CSS 类名
_SKILL_DISPLAY = {
"BILLIARD": "🎱",
"SNOOKER": "",
"MAHJONG": "🀄",
"KTV": "🎤",
}
_SKILL_CLS = {
"BILLIARD": "skill--chinese",
"SNOOKER": "skill--snooker",
"MAHJONG": "skill--mahjong",
"KTV": "skill--karaoke",
}
_SORT_DIM_MAP = {
"perf_desc": "perf", "perf_asc": "perf",
"salary_desc": "salary", "salary_asc": "salary",
@@ -190,8 +244,9 @@ _SORT_DIM_MAP = {
# ---------------------------------------------------------------------------
@trace_service("获取助教看板", "Get coach board")
async def get_coach_board(
sort: str, skill: str, time: str, site_id: int
sort: str, skill: str, time: str, page: int, page_size: int, site_id: int
) -> dict:
"""
BOARD-1助教看板。扁平返回所有维度字段。
@@ -244,7 +299,17 @@ async def get_coach_board(
# 5. 任务数据
task_map = _query_coach_tasks(conn, site_id, aid_list, start_str, end_str)
# 6. 组装扁平响应
# 6. 查询档位配置,计算距升档(仅本月/上月有意义)
tier_nodes: list[float] = []
show_perf_gap = time in ("month", "last_month")
if show_perf_gap:
try:
tiers = fdw_queries.get_performance_tiers(conn, site_id)
tier_nodes = [float(t["min_hours"]) for t in tiers] if tiers else []
except Exception as e:
logger.warning("BOARD-1 档位配置查询失败: %s", e, exc_info=True)
# 7. 组装扁平响应
items = []
for a in assistants:
aid = a["assistant_id"]
@@ -256,26 +321,52 @@ async def get_coach_board(
name = a["name"]
initial = name[0] if name else ""
perf_hours = sal.get("effective_hours", 0.0)
salary_val = sal.get("gross_salary", 0.0)
perf_hours = float(sal.get("effective_hours", 0) or 0)
salary_val = float(sal.get("gross_salary", 0) or 0)
task_recall = tasks.get("recall", 0)
task_callback = tasks.get("callback", 0)
# 折前课时:当 effective_hours != raw_hours 时显示(惩罚扣减导致的差异)
# 惩罚规则:同台 >2 助教重叠per_hour_contribution < 24 元时按比例扣减
raw_hours = float(sal.get("raw_hours", 0) or 0)
perf_hours_before = None
if abs(perf_hours - raw_hours) > 0.01:
perf_hours_before = raw_hours
# 计算距升档差距
perf_gap = None
perf_reached = False
if tier_nodes and perf_hours is not None:
# 找到下一个未达到的档位
for threshold in tier_nodes:
if perf_hours < threshold:
gap = threshold - perf_hours
perf_gap = f"距升档 {gap:.1f}h"
break
else:
perf_reached = True # 已达到最高档
items.append({
"id": aid,
"name": name,
"initial": initial,
"avatar_gradient": "",
"level": sal.get("level_name", a.get("level", "")),
"skills": [], # CHANGE 2026-03-20 | v_dim_assistant skill 列,暂返回空
# CHANGE 2026-03-29 | 从 get_all_assistants 返回的 skill 字段取项目标签
# Schema 要求 list[CoachSkillItem]{text, cls}),不是纯字符串
# text 映射为中文短名 + emojicls 映射为 CSS 类名
"skills": [
{"text": _SKILL_DISPLAY.get(s, s), "cls": _SKILL_CLS.get(s, "")}
for s in (a.get("skill") or "").split(",") if s
],
"top_customers": top_custs,
"perf_hours": perf_hours,
"perf_hours_before": None,
"perf_gap": None,
"perf_reached": False,
"perf_hours_before": perf_hours_before,
"perf_gap": perf_gap,
"perf_reached": perf_reached,
"salary": salary_val,
"salary_perf_hours": perf_hours,
"salary_perf_before": None,
"salary_perf_before": perf_hours_before,
"sv_amount": sv.get("sv_amount", 0.0),
"sv_customer_count": sv.get("sv_customer_count", 0),
"sv_consume": sv.get("sv_consume", 0.0),
@@ -292,8 +383,16 @@ async def get_coach_board(
for item in items:
item.pop("task_total", None)
# 8. 分页
total = len(items)
start = (page - 1) * page_size
items = items[start : start + page_size]
return {
"items": items,
"total": total,
"page": page,
"page_size": page_size,
"dim_type": _SORT_DIM_MAP.get(sort, "perf"),
}
finally:
@@ -318,8 +417,8 @@ def _query_coach_tasks(
cur.execute(
"""
SELECT assistant_id,
COUNT(*) FILTER (WHERE task_type = 'recall') AS recall_count,
COUNT(*) FILTER (WHERE task_type = 'callback') AS callback_count
COUNT(*) FILTER (WHERE task_type IN ('high_priority_recall', 'priority_recall')) AS recall_count,
COUNT(*) FILTER (WHERE task_type = 'relationship_building') AS callback_count
FROM biz.coach_tasks
WHERE assistant_id = ANY(%s)
AND site_id = %s
@@ -346,6 +445,106 @@ def _query_coach_tasks(
# BOARD-2 客户看板
# ---------------------------------------------------------------------------
def _batch_ideal_days(conn: Any, site_id: int, member_ids: list[int]) -> dict[int, int]:
"""批量查询客户理想到店间隔天数balance/recharge 维度头部用)。"""
from app.services.fdw_queries import _fdw_context
result: dict[int, int] = {}
try:
with _fdw_context(conn, site_id) as cur:
cur.execute(
"""
SELECT member_id, COALESCE(ideal_interval_days, 0)
FROM app.v_dws_member_winback_index
WHERE member_id = ANY(%s)
""",
(member_ids,),
)
for row in cur.fetchall():
result[row[0]] = int(row[1]) if row[1] is not None else 0
except Exception:
logger.warning("_batch_ideal_days 查询失败", exc_info=True)
return result
def _batch_coach_details(conn: Any, site_id: int, member_ids: list[int]) -> dict[int, list[dict]]:
"""批量查询客户-助教服务明细loyal 维度 coachDetails 用)。每个客户前 5 个。"""
from app.services.fdw_queries import _fdw_context
result: dict[int, list[dict]] = {mid: [] for mid in member_ids}
try:
with _fdw_context(conn, site_id) as cur:
# CHANGE 2026-03-29 | coach_spend 改为从 dwd_assistant_service_log 聚合 60 天消费
cur.execute(
"""
SELECT ri.member_id,
COALESCE(da.nickname, da.real_name, '') AS name,
ri.rs_display,
ri.session_count,
ri.total_duration_minutes,
COALESCE(s60.spend_60d, 0) AS spend_60d
FROM app.v_dws_member_assistant_relation_index ri
LEFT JOIN app.v_dim_assistant da
ON ri.assistant_id = da.assistant_id AND da.scd2_is_current = 1
LEFT JOIN (
SELECT tenant_member_id, site_assistant_id,
SUM(ledger_amount) AS spend_60d
FROM app.v_dwd_assistant_service_log
WHERE is_delete = 0
AND create_time >= CURRENT_DATE - INTERVAL '60 days'
AND tenant_member_id = ANY(%s)
GROUP BY tenant_member_id, site_assistant_id
) s60 ON ri.member_id = s60.tenant_member_id
AND ri.assistant_id = s60.site_assistant_id
WHERE ri.member_id = ANY(%s)
AND (da.leave_status IS NULL OR da.leave_status = 0)
ORDER BY ri.member_id, ri.rs_display DESC
""",
(member_ids, member_ids),
)
for row in cur.fetchall():
mid = row[0]
if mid in result and len(result[mid]) < 5:
svc_count = row[3] or 0
total_mins = float(row[4]) if row[4] else 0.0
avg_dur = round(total_mins / 60 / svc_count, 1) if svc_count > 0 else 0.0
result[mid].append({
"name": row[1] or "",
"cls": "",
"heart_score": float(row[2]) if row[2] is not None else 0.0,
"avg_duration": f"{avg_dur}h",
"service_count": str(svc_count),
"coach_spend": float(row[5]) if row[5] is not None else 0.0,
"relation_idx": float(row[2]) if row[2] is not None else 0.0,
})
except Exception:
logger.warning("_batch_coach_details 查询失败", exc_info=True)
return result
def _batch_member_projects(conn: Any, site_id: int, member_ids: list[int]) -> dict[int, list[str]]:
"""批量查询客户项目标签BOARD-2 用)。通过 FDW 视图查询。"""
from app.services.fdw_queries import _fdw_context
result: dict[int, list[str]] = {mid: [] for mid in member_ids}
try:
with _fdw_context(conn, site_id) as cur:
cur.execute(
"""
SELECT member_id, array_agg(DISTINCT category_code)
FROM app.v_dws_member_project_tag
WHERE member_id = ANY(%s) AND is_tagged = true
GROUP BY member_id
""",
(member_ids,),
)
for row in cur.fetchall():
mid = row[0]
codes = row[1] or []
if mid in result:
result[mid] = [c for c in codes if c]
except Exception:
logger.warning("_batch_member_projects 查询失败", exc_info=True)
return result
# 维度 → FDW 查询函数映射
_DIMENSION_QUERY_MAP = {
"recall": "get_customer_board_recall",
@@ -359,6 +558,7 @@ _DIMENSION_QUERY_MAP = {
}
@trace_service("获取客户看板", "Get customer board")
async def get_customer_board(
dimension: str, project: str, page: int, page_size: int, site_id: int
) -> dict:
@@ -388,6 +588,25 @@ async def get_customer_board(
except Exception:
logger.warning("BOARD-2 客户助教查询失败,降级为空", exc_info=True)
# 2b. 批量查询客户项目标签
member_projects: dict[int, list[str]] = {}
if member_ids:
try:
member_projects = _batch_member_projects(conn, site_id, member_ids)
except Exception:
logger.warning("BOARD-2 客户项目标签查询失败,降级为空", exc_info=True)
# 2c. balance/recharge 维度:补充 ideal_days
if dimension in ("balance", "recharge") and member_ids:
try:
ideal_map = _batch_ideal_days(conn, site_id, member_ids)
for item in items:
mid = item.get("member_id", 0)
if item.get("ideal_days") is None:
item["ideal_days"] = ideal_map.get(mid, 0)
except Exception:
logger.warning("BOARD-2 ideal_days 查询失败", exc_info=True)
# 3. 组装响应(添加基础字段 + assistants
for item in items:
mid = item.get("member_id", 0)
@@ -396,9 +615,43 @@ async def get_customer_board(
item["initial"] = name[0] if name else ""
item["avatar_cls"] = ""
item["assistants"] = assistants_map.get(mid, [])
item["projects"] = member_projects.get(mid, [])
# 3b. loyal 维度:为每个客户补充 coach_details前 5 个助教的服务明细)
if dimension == "loyal" and member_ids:
try:
coach_details_map = _batch_coach_details(conn, site_id, member_ids)
for item in items:
mid = item.get("member_id", 0)
item["coach_details"] = coach_details_map.get(mid, [])
except Exception:
logger.warning("BOARD-2 loyal coachDetails 查询失败", exc_info=True)
for item in items:
item["coach_details"] = []
# CHANGE 2026-03-28 | P5 联调修复items 是 list[dict]Pydantic CamelModel
# 不会自动转换内部 dict 的 key。手动 snake_case → camelCase。
# CHANGE 2026-03-29 | 递归处理嵌套 list[dict](如 assistants 数组)
def _to_camel(key: str) -> str:
parts = key.split("_")
return parts[0] + "".join(p.capitalize() for p in parts[1:])
def _camel_dict(d: dict) -> dict:
result = {}
for k, v in d.items():
ck = _to_camel(k)
if isinstance(v, list):
result[ck] = [_camel_dict(i) if isinstance(i, dict) else i for i in v]
elif isinstance(v, dict):
result[ck] = _camel_dict(v)
else:
result[ck] = v
return result
camel_items = [_camel_dict(item) for item in items]
return {
"items": items,
"items": camel_items,
"total": result["total"],
"page": result["page"],
"page_size": result["page_size"],
@@ -412,15 +665,26 @@ async def get_customer_board(
# ---------------------------------------------------------------------------
# CHANGE 2026-04-01 | board-finance-dws-area-refactor 9.1 | 缓存/日粒度查询路由
COMPLETED_PERIODS = {"lastMonth", "lastWeek", "lastQuarter", "quarter3", "half6"}
CURRENT_PERIODS = {"month", "week", "quarter"}
@trace_service("获取财务看板", "Get finance board")
async def get_finance_board(
time: str, area: str, compare: int, site_id: int
) -> dict:
"""
BOARD-3财务看板。6 板块独立查询、独立降级。
area≠all 时 recharge 返回 null。
compare=1 时计算上期范围并调用 calc_compare。
compare=0 时环比字段为 None序列化时排除
CHANGE 2026-04-01 | board-finance-dws-area-refactor 9.1 |
- 已完成周期先查缓存 → 未命中从日粒度表 SUM → 写缓存
- 当期周期直接从日粒度表 SUM不查缓存
- overview/revenue 改为从 dws_finance_area_daily 按 area_code 查询
- cashflow/expense/coach_analysis 不变(始终用全局数据)
- area≠all 时 recharge 返回 null
- area≠all 时 overview 覆盖逻辑保留
- compare=1 时对上期执行同样缓存/日粒度逻辑
"""
start_date, end_date = _calc_date_range(time)
start_str = str(start_date)
@@ -429,7 +693,7 @@ async def get_finance_board(
prev_start_str = None
prev_end_str = None
if compare == 1:
prev_start, prev_end = _calc_prev_range(start_date, end_date)
prev_start, prev_end = _calc_prev_range(time, start_date, end_date)
prev_start_str = str(prev_start)
prev_end_str = str(prev_end)
@@ -437,23 +701,47 @@ async def get_finance_board(
try:
# 各板块独立 try/except
overview = _build_overview(conn, site_id, start_str, end_str,
prev_start_str, prev_end_str, compare)
prev_start_str, prev_end_str, compare, area)
recharge = None
if area == "all":
recharge = _build_recharge(conn, site_id, start_str, end_str,
prev_start_str, prev_end_str, compare)
revenue = _build_revenue(conn, site_id, start_str, end_str, area)
cashflow = _build_cashflow(conn, site_id, start_str, end_str,
prev_start_str, prev_end_str, compare)
expense = _build_expense(conn, site_id, start_str, end_str,
revenue = _build_revenue(conn, site_id, start_str, end_str, area,
prev_start_str, prev_end_str, compare)
# CHANGE 2026-03-28 | 非全部区域时,用 revenue 的数据覆盖 overview 的发生额/优惠/确认收入
if area != "all" and revenue:
overview["occurrence"] = revenue.get("total_occurrence", 0.0)
overview["discount"] = revenue.get("discount_total", 0.0)
overview["confirmed_revenue"] = revenue.get("confirmed_total", 0.0)
# discount_rate 重算
occ = overview["occurrence"]
overview["discount_rate"] = (overview["discount"] / occ) if occ > 0 else 0.0
# CHANGE 2026-03-29 | area≠all 时隐藏实收流水(现金流 4 项无法按区域拆分)
overview["cash_in"] = None
overview["cash_out"] = None
overview["cash_balance"] = None
overview["balance_rate"] = None
# 移除现金流环比字段(如有)
for f in ("cash_in", "cash_out", "cash_balance", "balance_rate"):
overview.pop(f"{f}_compare", None)
overview.pop(f"{f}_down", None)
overview.pop(f"{f}_flat", None)
# CHANGE 2026-03-29 | area≠all 时隐藏现金流入和现金流出板块
cashflow = None
expense = None
if area == "all":
cashflow = _build_cashflow(conn, site_id, start_str, end_str,
prev_start_str, prev_end_str, compare)
expense = _build_expense(conn, site_id, start_str, end_str,
prev_start_str, prev_end_str, compare)
coach_analysis = _build_coach_analysis(conn, site_id, start_str, end_str,
prev_start_str, prev_end_str, compare)
prev_start_str, prev_end_str, compare, area)
return {
"overview": overview,
@@ -470,10 +758,15 @@ async def get_finance_board(
def _build_overview(
conn: Any, site_id: int, start: str, end: str,
prev_start: str | None, prev_end: str | None, compare: int,
area: str = "all",
) -> dict:
"""经营一览板块。"""
"""经营一览板块。
CHANGE 2026-04-01 | board-finance-dws-area-refactor 9.1 |
改为从 dws_finance_area_daily 按 area_code 查询(通过 get_finance_overview_area
"""
try:
data = fdw_queries.get_finance_overview(conn, site_id, start, end)
data = fdw_queries.get_finance_overview_area(conn, site_id, start, end, area)
except Exception:
logger.warning("overview 查询失败,降级为空", exc_info=True)
return _empty_overview()
@@ -482,7 +775,7 @@ def _build_overview(
if compare == 1 and prev_start and prev_end:
try:
prev = fdw_queries.get_finance_overview(conn, site_id, prev_start, prev_end)
prev = fdw_queries.get_finance_overview_area(conn, site_id, prev_start, prev_end, area)
_attach_compare(result, data, prev, [
"occurrence", "discount", "discount_rate", "confirmed_revenue",
"cash_in", "cash_out", "cash_balance", "balance_rate",
@@ -509,7 +802,7 @@ def _build_recharge(
prev = fdw_queries.get_finance_recharge(conn, site_id, prev_start, prev_end)
_attach_compare(data, data, prev, [
"actual_income", "first_charge", "renew_charge",
"consumed", "card_balance",
"consumed", "card_balance", "all_card_balance",
])
# 赠送卡矩阵环比
for i, row in enumerate(data.get("gift_rows", [])):
@@ -535,14 +828,192 @@ def _build_recharge(
def _build_revenue(
conn: Any, site_id: int, start: str, end: str, area: str,
prev_start: str | None = None, prev_end: str | None = None, compare: int = 0,
) -> dict:
"""应计收入板块。"""
"""应计收入板块。
CHANGE 2026-04-01 | board-finance-dws-area-refactor 9.1 |
改为从 dws_finance_area_daily 按 area_code 查询(通过 get_finance_revenue_area
然后在 Python 层构建 structure_rows / discount_items / channel_items 保持返回结构不变。
"""
try:
return fdw_queries.get_finance_revenue(conn, site_id, start, end, area)
if area == "all":
# CHANGE 2026-03-29 | area=all 走旧版查询,保留收入结构的区域子行拆分
data = fdw_queries.get_finance_revenue(conn, site_id, start, end, area)
else:
raw = fdw_queries.get_finance_revenue_area(conn, site_id, start, end, area)
data = _format_revenue_from_area(raw, conn, site_id, start, end, area)
except Exception:
logger.warning("revenue 查询失败,降级为空", exc_info=True)
return _empty_revenue()
if compare == 1 and prev_start and prev_end:
try:
if area == "all":
prev = fdw_queries.get_finance_revenue(conn, site_id, prev_start, prev_end, area)
else:
prev_raw = fdw_queries.get_finance_revenue_area(conn, site_id, prev_start, prev_end, area)
prev = _format_revenue_from_area(prev_raw, conn, site_id, prev_start, prev_end, area)
# 总计环比
_attach_compare(data, data, prev, [
"total_occurrence", "discount_total", "confirmed_total",
])
# structure_rows 行级环比(按 id 匹配)
prev_struct = {r["id"]: r for r in prev.get("structure_rows", [])}
for row in data.get("structure_rows", []):
prev_row = prev_struct.get(row["id"], {})
cmp = calc_compare(
Decimal(str(row.get("booked", 0))),
Decimal(str(prev_row.get("booked", 0))),
)
row["booked_compare"] = cmp["compare"]
# price_items 行级环比(按 label 匹配)
prev_prices = {r["label"]: r for r in prev.get("price_items", [])}
for item in data.get("price_items", []):
prev_item = prev_prices.get(item["label"], {})
cmp = calc_compare(
Decimal(str(item.get("amount", 0))),
Decimal(str(prev_item.get("amount", 0))),
)
item["compare"] = cmp["compare"]
# discount_items 行级环比(按 label 匹配)
prev_discounts = {r["label"]: r for r in prev.get("discount_items", [])}
for item in data.get("discount_items", []):
prev_item = prev_discounts.get(item["label"], {})
cmp = calc_compare(
Decimal(str(item.get("amount", 0))),
Decimal(str(prev_item.get("amount", 0))),
)
item["compare"] = cmp["compare"]
# channel_items 行级环比(按 label 匹配)
prev_channels = {r["label"]: r for r in prev.get("channel_items", [])}
for item in data.get("channel_items", []):
prev_item = prev_channels.get(item["label"], {})
cmp = calc_compare(
Decimal(str(item.get("amount", 0))),
Decimal(str(prev_item.get("amount", 0))),
)
item["compare"] = cmp["compare"]
except Exception:
logger.warning("revenue 环比查询失败", exc_info=True)
return data
def _format_revenue_from_area(
raw: dict, conn: Any, site_id: int, start: str, end: str, area: str,
) -> dict:
"""将 get_finance_revenue_area 的原始聚合数据格式化为前端期望的 revenue 结构。
CHANGE 2026-04-01 | board-finance-dws-area-refactor 9.1 |
从 dws_finance_area_daily 聚合数据构建 structure_rows / discount_items / channel_items
保持与旧 get_finance_revenue 返回结构完全一致。
"""
total_table_charge = raw.get("table_fee_amount", 0.0)
total_goods = raw.get("goods_amount", 0.0)
total_pd = raw.get("assistant_pd_amount", 0.0)
total_cx = raw.get("assistant_cx_amount", 0.0)
total_income = raw.get("total_occurrence", 0.0)
# 构建 structure_rows简化版不再按物理区域拆分子行因为 area_daily 已按 area_code 聚合)
structure_rows = [
{"id": "table_charge", "name": "开台与包厢", "desc": None,
"is_sub": False, "amount": total_table_charge,
"discount": 0.0, "booked": total_table_charge},
{"id": "assistant_pd", "name": "助教 基础课", "desc": None,
"is_sub": False, "amount": total_pd,
"discount": 0.0, "booked": total_pd},
{"id": "assistant_cx", "name": "助教 激励课", "desc": None,
"is_sub": False, "amount": total_cx,
"discount": 0.0, "booked": total_cx},
{"id": "goods", "name": "食品酒水", "desc": None,
"is_sub": False, "amount": total_goods,
"discount": 0.0, "booked": total_goods},
]
# 发生额构成
price_items = [
{"label": "开台消费", "amount": total_table_charge},
{"label": "酒水商品", "amount": total_goods},
{"label": "助教服务", "amount": total_pd + total_cx},
]
# 优惠拆分5 项,与旧逻辑一致)
groupbuy_d = raw.get("discount_groupbuy", 0.0)
vip_d = raw.get("discount_vip", 0.0)
manual_d = raw.get("discount_manual", 0.0)
gift_card_d = raw.get("discount_gift_card", 0.0)
# 其他 = discount_rounding + discount_other
rounding_d = raw.get("discount_rounding", 0.0)
other_d = raw.get("discount_other", 0.0)
discount_items = [
{"label": "团购优惠", "amount": groupbuy_d},
{"label": "会员折扣", "amount": vip_d},
{"label": "手动调整", "amount": manual_d + other_d},
{"label": "赠送卡抵扣", "desc": "台桌卡+酒水卡+抵用券", "amount": gift_card_d},
{"label": "其他优惠", "desc": "免单+抹零", "amount": rounding_d},
]
total_discount = raw.get("discount_total", 0.0)
# 回填收入结构表的优惠分摊
if total_table_charge > 0 and total_discount > 0:
for row in structure_rows:
if row["id"] == "table_charge":
row["discount"] = total_discount
row["booked"] = total_table_charge - total_discount
# 渠道分布(从 dws_finance_area_daily 的 all 行获取,因为渠道数据仅 all 有值)
# 需要额外查询 all 行的渠道数据
try:
channel_data = _get_channel_items(conn, site_id, start, end)
except Exception:
logger.warning("revenue 渠道数据查询失败,降级为空", exc_info=True)
channel_data = [
{"label": "储值卡结算冲销", "amount": 0.0},
{"label": "现金/线上支付", "amount": 0.0},
{"label": "团购核销确认收入", "desc": "团购成交价", "amount": 0.0},
]
confirmed_total = total_income - abs(total_discount)
return {
"structure_rows": structure_rows,
"price_items": price_items,
"total_occurrence": total_income,
"discount_items": discount_items,
"discount_total": total_discount,
"confirmed_total": confirmed_total,
"channel_items": channel_data,
}
def _get_channel_items(conn: Any, site_id: int, start: str, end: str) -> list[dict]:
"""从 v_dws_finance_daily_summary 获取渠道分布数据(全局数据,不按区域拆分)。"""
from app.services.fdw_queries import _fdw_context
with _fdw_context(conn, site_id) as cur:
cur.execute(
"""
SELECT COALESCE(SUM(cash_pay_amount), 0) AS cash_pay,
COALESCE(SUM(groupbuy_pay_amount), 0) AS groupbuy_pay,
COALESCE(SUM(cash_card_consume), 0) AS cash_card,
COALESCE(SUM(gift_card_consume), 0) AS gift_card
FROM app.v_dws_finance_daily_summary
WHERE stat_date >= %s::date AND stat_date <= %s::date
""",
(start, end),
)
ch = cur.fetchone()
cash_pay = float(ch[0]) if ch and ch[0] is not None else 0.0
groupbuy_pay = float(ch[1]) if ch and ch[1] is not None else 0.0
cash_card = float(ch[2]) if ch and ch[2] is not None else 0.0
gift_card_consume = float(ch[3]) if ch and ch[3] is not None else 0.0
return [
{"label": "储值卡结算冲销", "amount": cash_card + gift_card_consume},
{"label": "现金/线上支付", "amount": cash_pay},
{"label": "团购核销确认收入", "desc": "团购成交价", "amount": groupbuy_pay},
]
def _build_cashflow(
conn: Any, site_id: int, start: str, end: str,
@@ -555,6 +1026,37 @@ def _build_cashflow(
logger.warning("cashflow 查询失败,降级为空", exc_info=True)
return {"consume_items": [], "recharge_items": [], "total": 0.0}
if compare == 1 and prev_start and prev_end:
try:
prev = fdw_queries.get_finance_cashflow(conn, site_id, prev_start, prev_end)
total_cmp = calc_compare(
Decimal(str(data["total"])), Decimal(str(prev["total"]))
)
data["total_compare"] = total_cmp["compare"]
data["total_down"] = total_cmp["is_down"]
data["total_flat"] = total_cmp["is_flat"]
# consume_items 行级环比(按 label 匹配)
prev_consumes = {r["label"]: r for r in prev.get("consume_items", [])}
for item in data.get("consume_items", []):
prev_item = prev_consumes.get(item["label"], {})
cmp = calc_compare(
Decimal(str(item.get("amount", 0))),
Decimal(str(prev_item.get("amount", 0))),
)
item["compare"] = cmp["compare"]
item["down"] = cmp["is_down"]
# recharge_items 行级环比(按 label 匹配)
prev_recharges = {r["label"]: r for r in prev.get("recharge_items", [])}
for item in data.get("recharge_items", []):
prev_item = prev_recharges.get(item["label"], {})
cmp = calc_compare(
Decimal(str(item.get("amount", 0))),
Decimal(str(prev_item.get("amount", 0))),
)
item["compare"] = cmp["compare"]
except Exception:
logger.warning("cashflow 环比查询失败", exc_info=True)
return data
@@ -590,10 +1092,18 @@ def _build_expense(
def _build_coach_analysis(
conn: Any, site_id: int, start: str, end: str,
prev_start: str | None, prev_end: str | None, compare: int,
area: str = "all",
) -> dict:
"""助教分析板块。"""
"""助教分析板块。
CHANGE 2026-03-29 | Prompt: 助教分析按区域细化 |
area=all 走现有 salary_calc 查询area≠all 走 coach_area_hours JOIN salary_calc。
"""
try:
data = fdw_queries.get_finance_coach_analysis(conn, site_id, start, end)
if area == "all":
data = fdw_queries.get_finance_coach_analysis(conn, site_id, start, end)
else:
data = fdw_queries.get_finance_coach_analysis_area(conn, site_id, start, end, area)
except Exception:
logger.warning("coachAnalysis 查询失败,降级为空", exc_info=True)
empty_table = {"total_pay": 0.0, "total_share": 0.0, "avg_hourly": 0.0, "rows": []}
@@ -601,15 +1111,33 @@ def _build_coach_analysis(
if compare == 1 and prev_start and prev_end:
try:
prev = fdw_queries.get_finance_coach_analysis(
conn, site_id, prev_start, prev_end
)
if area == "all":
prev = fdw_queries.get_finance_coach_analysis(
conn, site_id, prev_start, prev_end
)
else:
prev = fdw_queries.get_finance_coach_analysis_area(
conn, site_id, prev_start, prev_end, area
)
for key in ("basic", "incentive"):
cur_t = data[key]
prev_t = prev[key]
_attach_compare(cur_t, cur_t, prev_t, [
"total_pay", "total_share", "avg_hourly",
])
# 行级环比(按 level 匹配)
prev_rows = {r["level"]: r for r in prev_t.get("rows", [])}
for row in cur_t.get("rows", []):
prev_row = prev_rows.get(row["level"], {})
pay_cmp = calc_compare(Decimal(str(row.get("pay", 0))), Decimal(str(prev_row.get("pay", 0))))
row["pay_compare"] = pay_cmp["compare"]
row["pay_down"] = pay_cmp["is_down"]
share_cmp = calc_compare(Decimal(str(row.get("share", 0))), Decimal(str(prev_row.get("share", 0))))
row["share_compare"] = share_cmp["compare"]
row["share_down"] = share_cmp["is_down"]
hourly_cmp = calc_compare(Decimal(str(row.get("hourly", 0))), Decimal(str(prev_row.get("hourly", 0))))
row["hourly_compare"] = hourly_cmp["compare"]
row["hourly_flat"] = hourly_cmp["is_flat"]
except Exception:
logger.warning("coachAnalysis 环比查询失败", exc_info=True)
@@ -658,6 +1186,7 @@ def _empty_revenue() -> dict:
"price_items": [],
"total_occurrence": 0.0,
"discount_items": [],
"discount_total": 0.0,
"confirmed_total": 0.0,
"channel_items": [],
}

View File

@@ -27,9 +27,11 @@ from typing import Any
from fastapi import HTTPException, status
from app.ai.bailian_client import BailianClient
from app.ai.config import AIConfig
from app.ai.dashscope_client import DashScopeClient
from app.database import get_connection
from app.services import fdw_queries
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -46,6 +48,7 @@ class ChatService:
# CHAT-1: 对话历史列表
# ------------------------------------------------------------------
@trace_service("查询对话历史", "Get chat history")
def get_chat_history(
self,
user_id: int,
@@ -149,6 +152,7 @@ class ChatService:
# 对话复用 / 创建
# ------------------------------------------------------------------
@trace_service("查找或创建对话", "Get or create session")
def get_or_create_session(
self,
user_id: int,
@@ -213,7 +217,7 @@ class ChatService:
context_type: str,
context_id: str | None,
) -> int:
"""创建新对话记录,返回 conversation_id。"""
"""创建新对话记录,返回 conversation_id。同时生成 session_id。"""
conn = get_connection()
try:
with conn.cursor() as cur:
@@ -230,11 +234,23 @@ class ChatService:
INSERT INTO biz.ai_conversations
(user_id, nickname, app_id, site_id, context_type, context_id)
VALUES (%s, %s, %s, %s, %s, %s)
RETURNING id
RETURNING id, EXTRACT(EPOCH FROM created_at)::bigint
""",
(str(user_id), nickname, APP_ID, site_id, context_type, context_id),
)
new_id = cur.fetchone()[0]
result = cur.fetchone()
new_id = result[0]
created_ts = result[1]
# 生成 session_id 并回写格式conv_{id}_{timestamp}
session_id = f"conv_{new_id}_{created_ts}"
cur.execute(
"""
UPDATE biz.ai_conversations SET session_id = %s WHERE id = %s
""",
(session_id, new_id),
)
conn.commit()
return new_id
except Exception:
@@ -243,10 +259,26 @@ class ChatService:
finally:
conn.close()
@trace_service("获取对话 session_id", "Get session ID")
def get_session_id(self, chat_id: int) -> str | None:
"""获取对话的 session_id。无记录或字段为空时返回 None。"""
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"SELECT session_id FROM biz.ai_conversations WHERE id = %s",
(chat_id,),
)
row = cur.fetchone()
return row[0] if row and row[0] else None
finally:
conn.close()
# ------------------------------------------------------------------
# CHAT-2: 消息列表
# ------------------------------------------------------------------
@trace_service("查询消息列表", "Get messages")
def get_messages(
self,
chat_id: int,
@@ -312,6 +344,7 @@ class ChatService:
# CHAT-3: 发送消息(同步回复)
# ------------------------------------------------------------------
@trace_service("发送消息并获取回复", "Send message sync")
async def send_message_sync(
self,
chat_id: int,
@@ -368,6 +401,7 @@ class ChatService:
# referenceCard 组装
# ------------------------------------------------------------------
@trace_service("构建引用卡片", "Build reference card")
def build_reference_card(
self,
customer_id: int,
@@ -438,6 +472,7 @@ class ChatService:
# 标题生成
# ------------------------------------------------------------------
@trace_service("生成对话标题", "Generate title")
def generate_title(
self,
title: str | None = None,
@@ -582,11 +617,13 @@ class ChatService:
user_id: int,
site_id: int,
) -> tuple[str, int | None]:
"""调用百炼 API 获取非流式回复,返回 (reply_text, tokens_used)。
"""调用 DashScope Application API 获取非流式回复,返回 (reply_text, tokens_used)。
构建历史消息上下文发送给 AI
通过 Application.call() 调用 App1通用对话prompt 为最近历史拼接
"""
bailian = _get_bailian_client()
# CHANGE 2026-03-22 | BailianClient → DashScopeClientP14 迁移收尾)
client = _get_dashscope_client()
ai_config = AIConfig.from_env()
# 获取历史消息作为上下文(最近 20 条)
conn = get_connection()
@@ -604,33 +641,21 @@ class ChatService:
finally:
conn.close()
# 构建消息列表
messages: list[dict] = []
# 取最近 20 条(含刚写入的 user 消息)
# 拼接历史消息为 prompt 文本
recent = history[-20:] if len(history) > 20 else history
prompt_parts: list[str] = []
for role, msg_content in recent:
messages.append({"role": role, "content": msg_content})
prompt_parts.append(f"[{role}]: {msg_content}")
prompt = "\n".join(prompt_parts)
# 如果没有 system 消息,添加默认 system prompt
if not messages or messages[0]["role"] != "system":
system_prompt = {
"role": "system",
"content": json.dumps(
{"task": "你是台球门店的 AI 助手,根据用户的问题和当前页面上下文提供帮助。"},
ensure_ascii=False,
),
}
messages.insert(0, system_prompt)
# 通过 Application API 调用 App1
result, tokens_used, _session_id = await client.call_app(
ai_config.app_id_1_chat, prompt,
)
# 非流式调用chat_stream 用于 SSE这里用 chat_stream 收集完整回复
full_parts: list[str] = []
async for chunk in bailian.chat_stream(messages):
full_parts.append(chunk)
reply = "".join(full_parts)
# 流式模式不返回 tokens_used按字符数估算
estimated_tokens = len(reply)
return reply, estimated_tokens
# 从返回结果提取文本回复
reply = result.get("text", "") if isinstance(result, dict) else str(result)
return reply, tokens_used
@staticmethod
def _get_consumption_30d(conn: Any, site_id: int, member_id: int) -> Decimal | None:
@@ -673,13 +698,8 @@ class ChatService:
# ── 模块级辅助函数 ──────────────────────────────────────────────
def _get_bailian_client() -> BailianClient:
"""从环境变量构建 BailianClient缺失时报错。"""
api_key = os.environ.get("BAILIAN_API_KEY")
base_url = os.environ.get("BAILIAN_BASE_URL")
model = os.environ.get("BAILIAN_MODEL")
if not api_key or not base_url or not model:
raise RuntimeError(
"百炼 API 环境变量缺失,需要 BAILIAN_API_KEY、BAILIAN_BASE_URL、BAILIAN_MODEL"
)
return BailianClient(api_key=api_key, base_url=base_url, model=model)
def _get_dashscope_client() -> DashScopeClient:
"""从环境变量构建 DashScopeClient缺失时报错。"""
# CHANGE 2026-03-22 | BailianClient → DashScopeClientP14 迁移收尾)
ai_config = AIConfig.from_env()
return DashScopeClient(api_key=ai_config.api_key, workspace_id=ai_config.workspace_id)

View File

@@ -25,6 +25,7 @@ from decimal import Decimal
from app.services import fdw_queries
from app.services.task_generator import compute_heart_icon
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -39,9 +40,10 @@ LEVEL_COLOR_MAP = {
}
TASK_TYPE_MAP = {
"follow_up_visit": {"label": "回访", "class": "tag-callback"},
"high_priority_recall": {"label": "紧急召回", "class": "tag-recall"},
"priority_recall": {"label": "优先召回", "class": "tag-recall"},
"follow_up_visit": {"label": "客户回访", "class": "callback"},
"high_priority_recall": {"label": "高优先召回", "class": "high-priority"},
"priority_recall": {"label": "优先召回", "class": "priority"},
"relationship_building": {"label": "关系构建", "class": "relationship"},
}
# 头像渐变色池(循环使用)
@@ -85,6 +87,7 @@ def _format_currency(amount: float) -> str:
# ── 6.1 核心函数 ──────────────────────────────────────────
@trace_service("获取助教详情", "Get coach detail")
async def get_coach_detail(coach_id: int, site_id: int) -> dict:
"""
助教详情COACH-1
@@ -150,7 +153,13 @@ async def get_coach_detail(coach_id: int, site_id: int) -> dict:
performance = {
"monthly_hours": salary_this.get("total_hours", 0.0),
"monthly_salary": salary_this.get("total_income", 0.0),
# CHANGE 2026-03-26 | 到手 = base_income + bonus_income + bonus_money + room_incomeDWS 层已扣抽成)
"monthly_salary": (
salary_this.get("assistant_pd_money_total", 0.0)
+ salary_this.get("assistant_cx_money_total", 0.0)
+ salary_this.get("bonus_money", 0.0)
+ salary_this.get("room_income", 0.0)
),
"customer_balance": customer_balance,
"tasks_completed": tasks_completed,
"perf_current": salary_this.get("total_hours", 0.0),
@@ -287,22 +296,22 @@ def _build_income(
{
"label": "基础课时费",
"amount": f"¥{salary.get('assistant_pd_money_total', 0.0):,.0f}",
"color": "#42A5F5",
"color": "primary",
},
{
"label": "激励课时费",
"amount": f"¥{salary.get('assistant_cx_money_total', 0.0):,.0f}",
"color": "#FFA726",
"color": "success",
},
{
"label": "充值提成",
"amount": f"¥{salary.get('bonus_money', 0.0):,.0f}",
"color": "#66BB6A",
"color": "warning",
},
{
"label": "酒水提成",
"amount": f"¥{salary.get('room_income', 0.0):,.0f}",
"color": "#AB47BC",
"color": "purple",
},
]
@@ -385,17 +394,18 @@ def _build_top_customers(
balance = cust.get("customer_balance", 0.0)
consume = cust.get("total_consume", 0.0)
# CHANGE 2026-03-29 | coach-detail-500 修复 | relation_score → score对齐 TopCustomer.score Schema
result.append({
"id": mid or 0,
"name": name,
"initial": _get_initial(name),
"avatar_gradient": _get_avatar_gradient(i),
"heart_emoji": heart_emoji,
"relation_score": f"{score:.2f}",
"score": f"{score:.2f}",
"score_color": score_color,
"service_count": cust.get("service_count", 0),
"balance": _format_currency(balance),
"consume": _format_currency(consume),
"balance": float(balance) if balance else 0.0,
"consume": float(consume) if consume else 0.0,
})
return result
@@ -440,9 +450,9 @@ def _build_service_records(
"avatar_gradient": _get_avatar_gradient(i),
"type": course_type or "课程",
"type_class": type_class,
"table": str(rec.get("table_id")) if rec.get("table_id") else None,
"table": rec.get("table_name") or None,
"duration": f"{hours:.1f}h",
"income": _format_currency(income),
"income": float(income),
"date": date_str,
"perf_hours": None,
})
@@ -594,11 +604,12 @@ def _build_notes(coach_id: int, site_id: int, conn) -> list[dict]:
result = []
for r in rows:
# CHANGE 2026-03-29 | coach-detail-500 修复 | ai_score → score对齐 CoachNoteItem.score Schema
result.append({
"id": r[0],
"content": r[1] or "",
"timestamp": r[2].isoformat() if r[2] else "",
"ai_score": r[3],
"score": r[3],
"customer_name": member_name_map.get(r[5], ""),
"tag_label": r[4] or "",
"created_at": r[2].isoformat() if r[2] else "",
@@ -698,9 +709,9 @@ def _build_history_months(
result.append({
"month": month_label,
"estimated": month_str == current_month_str,
"customers": f"{customers}",
"hours": f"{hours:.1f}h",
"salary": _format_currency(salary_amount),
"customers": customers,
"hours": float(hours),
"salary": float(salary_amount),
"callback_done": callback_done,
"recall_done": recall_done,
})

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -19,10 +19,12 @@ from __future__ import annotations
import logging
from app.services.fdw_queries import _fdw_context
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@trace_service(description_zh="查找匹配候选人", description_en="Find matching candidates")
async def find_candidates(
site_id: int | None,
phone: str,

View File

@@ -17,6 +17,8 @@
import json
import logging
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -57,6 +59,7 @@ def _insert_history(
)
@trace_service(description_zh="ai_analyze_note", description_en="Ai Analyze Note")
def ai_analyze_note(note_id: int) -> int | None:
"""
AI 应用 6 备注分析接口(占位)。
@@ -67,6 +70,7 @@ def ai_analyze_note(note_id: int) -> int | None:
return None
@trace_service(description_zh="执行笔记重分类", description_en="Run note reclassification")
def run(payload: dict | None = None, job_id: int | None = None) -> dict:
"""
备注回溯主流程。

View File

@@ -10,6 +10,7 @@ import json
import logging
from fastapi import HTTPException
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -52,16 +53,67 @@ def _record_history(
def ai_analyze_note(note_id: int) -> int | None:
@trace_service(description_zh="ai_analyze_note", description_en="Ai Analyze Note")
async def ai_analyze_note(note_id: int, site_id: int, member_id: int, content: str, user_name: str = "") -> int | None:
"""
AI 应用 6 备注分析接口(占位)
AI 应用 6 备注分析:调用百炼 Application API 获取评分
P5 AI 集成层实现后替换此占位函数。
当前返回 None 表示 AI 未就绪,跳过评分逻辑
CHANGE 2026-03-27 | 打通 AI 应用 6 调用链
仅执行 App6 评分,不触发 App8 线索整合(后续统一处理)
返回 score1-10失败返回 None。
"""
return None
try:
from app.ai.config import AIConfig
from app.ai.dashscope_client import DashScopeClient
import json
config = AIConfig.from_env()
client = DashScopeClient(api_key=config.api_key, workspace_id=config.workspace_id)
# 构建 prompt简化版直接传给百炼应用
prompt = json.dumps({
"site_id": site_id,
"member_id": member_id,
"note_content": content,
"noted_by_name": user_name,
}, ensure_ascii=False)
result, tokens_used, _ = await client.call_app(config.app_id_6_note, prompt)
score = result.get("score") if isinstance(result, dict) else None
if score is not None:
score = max(1, min(10, int(score)))
logger.info("App6 备注评分完成: note_id=%d score=%d tokens=%d", note_id, score, tokens_used)
return score
except Exception:
logger.warning("App6 备注评分失败: note_id=%d", note_id, exc_info=True)
return None
async def _async_ai_score(note_id: int, site_id: int, member_id: int, content: str) -> None:
"""后台异步执行 AI 评分,不阻塞 API 响应。"""
try:
ai_score_val = await ai_analyze_note(
note_id=note_id, site_id=site_id, member_id=member_id, content=content,
)
if ai_score_val is not None:
conn = _get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"UPDATE biz.notes SET ai_score = %s, updated_at = NOW() WHERE id = %s",
(ai_score_val, note_id),
)
conn.commit()
logger.info("AI 评分已写入: note_id=%d ai_score=%d", note_id, ai_score_val)
finally:
conn.close()
except Exception:
logger.warning("后台 AI 评分失败: note_id=%d", note_id, exc_info=True)
@trace_service("创建备注", "Create note")
async def create_note(
site_id: int,
user_id: int,
@@ -71,6 +123,7 @@ async def create_note(
task_id: int | None = None,
rating_service_willingness: int | None = None,
rating_revisit_likelihood: int | None = None,
score: int | None = None,
) -> dict:
"""
创建备注。
@@ -91,6 +144,7 @@ async def create_note(
for label, val in [
("再次服务意愿评分", rating_service_willingness),
("再来店可能性评分", rating_revisit_likelihood),
("备注星星评分", score),
]:
if val is not None and (val < 1 or val > 5):
raise HTTPException(
@@ -139,17 +193,17 @@ async def create_note(
INSERT INTO biz.notes
(site_id, user_id, target_type, target_id, type,
content, rating_service_willingness,
rating_revisit_likelihood, task_id)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
rating_revisit_likelihood, task_id, score)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id, site_id, user_id, target_type, target_id,
type, content, rating_service_willingness,
rating_revisit_likelihood, task_id,
ai_score, ai_analysis, created_at, updated_at
ai_score, ai_analysis, created_at, updated_at, score
""",
(
site_id, user_id, target_type, target_id, note_type,
content, rating_service_willingness,
rating_revisit_likelihood, task_id,
rating_revisit_likelihood, task_id, score,
),
)
row = cur.fetchone()
@@ -169,26 +223,11 @@ async def create_note(
"ai_analysis": row[11],
"created_at": row[12].isoformat() if row[12] else None,
"updated_at": row[13].isoformat() if row[13] else None,
"score": row[14],
}
# 若 type='follow_up'触发 AI 分析并标记回访任务完成
# 若 type='follow_up',标记回访任务完成(不依赖 AI 评分)
if note_type == "follow_up" and task_id is not None:
# 保留 AI 占位调用P5 接入时调用链不变)
ai_score = ai_analyze_note(note["id"])
if ai_score is not None:
# 更新备注的 ai_score
cur.execute(
"""
UPDATE biz.notes
SET ai_score = %s, updated_at = NOW()
WHERE id = %s
""",
(ai_score, note["id"]),
)
note["ai_score"] = ai_score
# 不论 ai_score 如何有备注即标记回访任务完成T4
if task_info and task_info["status"] == "active":
cur.execute(
"""
@@ -209,13 +248,17 @@ async def create_note(
new_status="completed",
old_task_type=task_info["task_type"],
new_task_type=task_info["task_type"],
detail={
"note_id": note["id"],
"ai_score": ai_score,
},
detail={"note_id": note["id"]},
)
conn.commit()
# CHANGE 2026-03-27 | AI 评分:后台异步执行,不阻塞 API 响应
# 备注先返回给前端aiScore=nullAI 评分完成后写入数据库
# 前端下次加载页面时自动获取最新 aiScore
import asyncio
asyncio.create_task(_async_ai_score(note["id"], site_id, target_id, content))
return note
except HTTPException:
@@ -228,6 +271,7 @@ async def create_note(
conn.close()
@trace_service("查询备注列表", "Get notes")
async def get_notes(
site_id: int, target_type: str, target_id: int
) -> list[dict]:
@@ -280,6 +324,7 @@ async def get_notes(
conn.close()
@trace_service("删除备注", "Delete note")
async def delete_note(note_id: int, user_id: int, site_id: int) -> dict:
"""
删除备注。

View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
"""ETL 输出目录清理服务
遍历 EXPORT_ROOT 下每个任务文件夹,按目录名中的时间戳排序,
只保留最近 N 个运行记录,其余永久删除。
CHANGE 2026-03-27 | 新增:执行前自动清理输出目录,每类任务只保留最近 10 个运行记录
"""
from __future__ import annotations
import logging
import os
import re
import shutil
from pathlib import Path
logger = logging.getLogger(__name__)
# 运行记录目录命名格式:{TASK_CODE}-{run_id}-{YYYYMMDD}-{HHMMSS}
# 按最后两段(日期-时间)排序
_RUN_DIR_PATTERN = re.compile(r"^.+-(\d{8})-(\d{6})$")
def _get_export_root() -> Path:
"""从环境变量读取 EXPORT_ROOT缺失时报错。"""
val = os.environ.get("EXPORT_ROOT")
if not val:
raise RuntimeError(
"环境变量 EXPORT_ROOT 未设置,无法执行输出目录清理。"
"请在 .env 中配置 EXPORT_ROOT。"
)
p = Path(val)
if not p.is_dir():
raise RuntimeError(f"EXPORT_ROOT 路径不存在或不是目录: {p}")
return p
def _sort_key(dirname: str) -> tuple[str, str]:
"""从目录名提取排序键(日期, 时间),越大越新。"""
m = _RUN_DIR_PATTERN.match(dirname)
if m:
return (m.group(1), m.group(2))
# 不匹配格式的目录排到最前面(最旧),优先被清理
return ("00000000", "000000")
def cleanup_output_dirs(keep: int = 10) -> dict:
"""清理 EXPORT_ROOT 下每个任务文件夹,只保留最近 keep 个运行记录。
Returns:
清理结果摘要 dict包含 task_folders_scanned / dirs_deleted / errors
"""
export_root = _get_export_root()
total_scanned = 0
total_deleted = 0
errors: list[str] = []
for task_dir in sorted(export_root.iterdir()):
if not task_dir.is_dir():
continue
total_scanned += 1
# 列出所有子目录(运行记录)
run_dirs = [d for d in task_dir.iterdir() if d.is_dir()]
if len(run_dirs) <= keep:
continue
# 按时间戳降序排列,保留前 keep 个
run_dirs.sort(key=lambda d: _sort_key(d.name), reverse=True)
to_delete = run_dirs[keep:]
for d in to_delete:
try:
shutil.rmtree(d)
total_deleted += 1
except Exception as exc:
msg = f"删除失败 {d}: {exc}"
logger.warning(msg)
errors.append(msg)
logger.info(
"输出目录清理完成: 扫描 %d 个任务文件夹, 删除 %d 个运行记录, %d 个错误",
total_scanned, total_deleted, len(errors),
)
return {
"task_folders_scanned": total_scanned,
"dirs_deleted": total_deleted,
"errors": errors,
}

View File

@@ -17,11 +17,8 @@ from decimal import Decimal
from fastapi import HTTPException
from app.services import fdw_queries
from app.services.task_manager import (
_get_assistant_id,
compute_income_trend,
map_course_type_class,
)
from app.services.task_manager import _get_assistant_id, compute_income_trend
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -37,22 +34,8 @@ def _get_connection():
# 纯函数:可被属性测试直接调用
# ---------------------------------------------------------------------------
# 头像颜色预定义集合
_AVATAR_COLORS = [
"#0052d9", "#e34d59", "#00a870", "#ed7b2f",
"#0594fa", "#a25eb5", "#f6c244", "#2ba471",
]
def avatar_char_color(name: str) -> tuple[str, str]:
"""从客户姓名计算 avatarChar 和 avatarColor。"""
if not name:
return ("?", _AVATAR_COLORS[0])
char = name[0]
color = _AVATAR_COLORS[ord(char) % len(_AVATAR_COLORS)]
return (char, color)
@trace_service(description_zh="format_income_desc", description_en="Format Income Desc")
def format_income_desc(rate: float, hours: float) -> str:
"""
格式化收入明细描述。
@@ -65,15 +48,17 @@ def format_income_desc(rate: float, hours: float) -> str:
return f"{rate_str}元/h × {hours_str}h"
@trace_service(description_zh="group_records_by_date", description_en="Group Records By Date")
def group_records_by_date(
records: list[dict], *, include_avatar: bool = False
records: list[dict], *, include_avatar: bool = False,
rs_map: dict[int, float] | None = None,
) -> list[dict]:
"""
将服务记录按日期分组为 DateGroup 结构。
参数:
records: 服务记录列表(已按 settle_time DESC 排序)
include_avatar: 是否包含 avatarChar/avatarColorPERF-1 需要PERF-2 不需要
include_avatar: 是否包含 member_idPERF-1 需要前端计算头像颜色
返回按日期倒序排列的 DateGroup 列表。
"""
@@ -95,24 +80,33 @@ def group_records_by_date(
end_time = rec.get("end_time")
time_range = _format_time_range(start_time, end_time)
raw_course_type = rec.get("course_type", "")
type_class = map_course_type_class(raw_course_type)
# CHANGE 2026-03-24 | 课程类型直接用数据库原始值skill_name不做二次映射
raw_course_type = rec.get("course_type", "") or "基础课"
customer_name = rec.get("customer_name") or "未知客户"
record_item: dict = {
"customer_name": customer_name,
"time_range": time_range,
"hours": f"{rec.get('service_hours', 0.0):g}",
"course_type": raw_course_type or "基础课",
"course_type_class": type_class,
"hours": f"{rec.get('service_hours', 0.0):.1f}",
"course_type": raw_course_type,
"location": rec.get("table_name") or "",
"income": f"{rec.get('income', 0.0):.2f}",
}
# CHANGE 2026-03-24 | 头像颜色改为前端根据 member_id 计算,后端只传 member_id 和首字
if include_avatar:
char, color = avatar_char_color(customer_name)
record_item["avatar_char"] = char
record_item["avatar_color"] = color
mid = rec.get("member_id")
record_item["member_id"] = mid
# 散客/未知客户member_id 为空、0、负数→ "?"
if not mid or mid <= 0:
record_item["avatar_char"] = "?"
else:
record_item["avatar_char"] = customer_name[0] if customer_name else "?"
# CHANGE 2026-03-27 | 关系爱心标识:注入 heart_scoreRS 分数)
if rs_map and mid:
record_item["heart_score"] = rs_map.get(mid, 0.0)
else:
record_item["heart_score"] = 0.0
groups[date_key].append(record_item)
@@ -125,7 +119,7 @@ def group_records_by_date(
total_income = sum(float(r["income"]) for r in recs)
result.append({
"date": date_key,
"total_hours": f"{total_hours:g}",
"total_hours": f"{total_hours:.1f}",
"total_income": f"{total_income:.2f}",
"records": recs,
})
@@ -133,6 +127,7 @@ def group_records_by_date(
return result
@trace_service(description_zh="paginate_records", description_en="Paginate Records")
def paginate_records(
records: list[dict], page: int, page_size: int
) -> tuple[list[dict], bool]:
@@ -149,6 +144,7 @@ def paginate_records(
return page_records, has_more
@trace_service(description_zh="compute_summary", description_en="Compute Summary")
def compute_summary(records: list[dict]) -> dict:
"""
计算月度汇总。
@@ -204,6 +200,7 @@ def _format_date_label(dt) -> str:
# ---------------------------------------------------------------------------
@trace_service("获取绩效概览", "Get performance overview")
async def get_overview(
user_id: int, site_id: int, year: int, month: int
) -> dict:
@@ -244,11 +241,30 @@ async def get_overview(
)
# 按日期分组(含 avatar
date_groups = group_records_by_date(all_records, include_avatar=True)
# CHANGE 2026-03-27 | 批量查 RS 分数,注入到服务记录和客户列表
member_ids = list({r.get("member_id") for r in all_records if r.get("member_id")})
rs_map: dict[int, float] = {}
if member_ids:
try:
with fdw_queries._fdw_context(conn, site_id) as cur:
cur.execute(
"""
SELECT member_id, COALESCE(rs_display, 0)
FROM app.v_dws_member_assistant_relation_index
WHERE assistant_id = %s AND member_id = ANY(%s)
""",
(assistant_id, member_ids),
)
for row in cur.fetchall():
rs_map[row[0]] = float(row[1])
except Exception:
logger.warning("查询 RS 分数失败", exc_info=True)
date_groups = group_records_by_date(all_records, include_avatar=True, rs_map=rs_map)
# ── 4. 新客/常客列表 ──
new_customers, regular_customers = _build_customer_lists(
conn, site_id, assistant_id, year, month, all_records
conn, site_id, assistant_id, year, month, all_records, rs_map=rs_map
)
# ── 5. 构建响应 ──
@@ -266,6 +282,7 @@ async def get_overview(
FROM auth.user_assistant_binding uab
JOIN auth.users u ON uab.user_id = u.id
WHERE uab.assistant_id = %s AND uab.site_id = %s
AND uab.is_removed = false
LIMIT 1
""",
(assistant_id, site_id),
@@ -279,27 +296,76 @@ async def get_overview(
logger.warning("查询助教信息失败", exc_info=True)
current_income = salary["total_income"] if salary else 0.0
basic_rate = salary["basic_rate"] if salary else 0.0
incentive_rate = salary["incentive_rate"] if salary else 0.0
# CHANGE 2026-03-24 | basic_rate/incentive_rate 改为助教到手单价(客户价 - 球房提成),
# 不再使用 base_course_price/bonus_course_price客户收费标准
base_course_price = salary["basic_rate"] if salary else 0.0 # 客户收费标准
bonus_course_price = salary["incentive_rate"] if salary else 0.0 # 客户收费标准
base_deduction = salary["base_deduction"] if salary else 0.0
bonus_deduction_ratio = salary["bonus_deduction_ratio"] if salary else 0.0
# 助教到手单价 = 客户价 - 球房提成
basic_rate = base_course_price - base_deduction
incentive_rate = bonus_course_price * (1 - bonus_deduction_ratio)
basic_hours = salary["basic_hours"] if salary else 0.0
bonus_hours = salary["bonus_hours"] if salary else 0.0
pd_money = salary["assistant_pd_money_total"] if salary else 0.0
cx_money = salary["assistant_cx_money_total"] if salary else 0.0
top_rank_bonus = salary["top_rank_bonus"] if salary else 0.0
recharge_commission = salary["recharge_commission"] if salary else 0.0
# 收入明细项
income_items = _build_income_items(
basic_rate, incentive_rate, basic_hours, bonus_hours,
pd_money, cx_money,
top_rank_bonus=top_rank_bonus,
recharge_commission=recharge_commission,
)
# 档位信息
next_basic_rate = salary["next_tier_basic_rate"] if salary else 0.0
next_incentive_rate = salary["next_tier_incentive_rate"] if salary else 0.0
upgrade_hours = salary["next_tier_hours"] if salary else 0.0
# CHANGE 2026-03-24 | 档位信息从 cfg_performance_tier 配置表计算,
# 复用 task_manager._build_performance_summary 的逻辑
total_hours = salary["total_hours"] if salary else 0.0
upgrade_hours_needed = max(0.0, upgrade_hours - total_hours)
tier_completed = salary["tier_completed"] if salary else False
upgrade_bonus = 0.0 if tier_completed else (salary["bonus_money"] if salary else 0.0)
tiers: list[dict] = []
try:
tiers = fdw_queries.get_performance_tiers(conn, site_id)
except Exception:
logger.warning("查询 cfg_performance_tier 失败", exc_info=True)
# 找到当前档位和下一档
tier_completed = False
next_tier_hours = 0.0
current_tier_data = None
next_tier_data = None
if tiers:
for i, t in enumerate(tiers):
if t["min_hours"] > total_hours:
next_tier_data = t
current_tier_data = tiers[i - 1] if i > 0 else tiers[0]
next_tier_hours = t["min_hours"]
break
if next_tier_data is None:
# 已达到或超过最高档
tier_completed = True
current_tier_data = tiers[-1]
upgrade_hours_needed = max(0.0, next_tier_hours - total_hours) if not tier_completed else 0.0
# 下一档到手费率
if next_tier_data:
next_basic_rate = base_course_price - next_tier_data["base_deduction"]
next_incentive_rate = bonus_course_price * (1 - next_tier_data["bonus_deduction_ratio"])
else:
next_basic_rate = 0.0
next_incentive_rate = 0.0
# bonus_money: 升到下一档后因抽成降低能多拿的钱
# 公式同 task_manager._build_performance_summary
upgrade_bonus = 0.0
if not tier_completed and current_tier_data and next_tier_data:
base_ded_diff = current_tier_data["base_deduction"] - next_tier_data["base_deduction"]
base_saving = next_tier_data["min_hours"] * base_ded_diff if base_ded_diff > 0 else 0.0
bonus_ratio_diff = current_tier_data["bonus_deduction_ratio"] - next_tier_data["bonus_deduction_ratio"]
bonus_saving = bonus_hours * bonus_course_price * bonus_ratio_diff if bonus_ratio_diff > 0 else 0.0
upgrade_bonus = round(base_saving + bonus_saving, 2)
return {
"coach_name": coach_name,
@@ -335,27 +401,49 @@ def _build_income_items(
bonus_hours: float,
pd_money: float,
cx_money: float,
*,
top_rank_bonus: float = 0.0,
recharge_commission: float = 0.0,
) -> list[dict]:
"""构建收入明细项列表。"""
"""
构建收入明细项列表。
CHANGE 2026-03-24 | 始终显示所有项基础课、激励课、Top3销冠奖、充值提成即使为 0
Top3销冠奖为 0 时 desc 显示"继续努力"
"""
items = []
# 基础课收入
if basic_hours > 0 or pd_money > 0:
items.append({
"icon": "💰",
"label": "基础课收入",
"desc": format_income_desc(basic_rate, basic_hours),
"value": f"¥{pd_money:,.2f}",
})
# 基础课收入(始终显示)
items.append({
"icon": "💰",
"label": "基础课收入",
"desc": format_income_desc(basic_rate, basic_hours),
"value": f"¥{pd_money:,.2f}",
})
# 激励课收入
if bonus_hours > 0 or cx_money > 0:
items.append({
"icon": "🎯",
"label": "激励课收入",
"desc": format_income_desc(incentive_rate, bonus_hours),
"value": f"¥{cx_money:,.2f}",
})
# 激励课收入(始终显示)
items.append({
"icon": "🎯",
"label": "激励课收入",
"desc": format_income_desc(incentive_rate, bonus_hours),
"value": f"¥{cx_money:,.2f}",
})
# Top3销冠奖始终显示为 0 时 desc 显示"继续努力"
items.append({
"icon": "🏆",
"label": "Top3销冠奖",
"desc": "继续努力" if top_rank_bonus == 0 else "本月销冠奖励",
"value": f"¥{top_rank_bonus:,.2f}",
})
# CHANGE 2026-03-24 | 充值提成(始终显示)
items.append({
"icon": "💳",
"label": "充值提成",
"desc": "充值激励",
"value": f"¥{recharge_commission:,.2f}",
})
return items
@@ -367,12 +455,17 @@ def _build_customer_lists(
year: int,
month: int,
all_records: list[dict],
*,
rs_map: dict[int, float] | None = None,
) -> tuple[list[dict], list[dict]]:
"""
构建新客和常客列表。
新客: 本月有服务记录但本月之前无记录的客户
常客: 本月服务次数 ≥ 2 的客户
常客: 本月服务次数 ≥ 2 的客户统计数据拉近90天
CHANGE 2026-03-24 | 头像颜色改为前端根据 member_id 计算,后端只传 member_id 和首字。
CHANGE 2026-03-24 | 常客展示数据改为近90天聚合判定标准不变本月≥2次
"""
if not all_records:
return [], []
@@ -395,16 +488,12 @@ def _build_customer_lists(
stats["count"] += 1
stats["total_hours"] += rec.get("service_hours", 0.0)
stats["total_income"] += rec.get("income", 0.0)
# 更新最后服务时间(记录已按 settle_time DESC 排序,第一条即最新)
if stats["last_service"] is None:
stats["last_service"] = rec.get("settle_time")
member_ids = list(member_stats.keys())
# 查询历史记录(本月之前是否有服务记录)
# ⚠️ 直连 ETL 库查询 app.v_dwd_assistant_service_log RLS 视图
# 列名映射: assistant_id → site_assistant_id, member_id → tenant_member_id,
# is_trash → is_delete (int, 0=正常), settle_time → create_time
# 查询历史记录(本月之前是否有服务记录)— 用于新客判定
historical_members: set[int] = set()
try:
start_date = f"{year}-{month:02d}-01"
@@ -425,33 +514,64 @@ def _build_customer_lists(
except Exception:
logger.warning("查询历史客户记录失败", exc_info=True)
# 查询近90天聚合数据常客展示用
if month == 12:
next_month_start = f"{year + 1}-01-01"
else:
next_month_start = f"{year}-{month + 1:02d}-01"
stats_90d: dict[int, dict] = {}
try:
rows_90d = fdw_queries.get_service_records_90days(
conn, site_id, assistant_id, next_month_start,
)
for r in rows_90d:
stats_90d[r["member_id"]] = r
except Exception:
logger.warning("查询近90天服务记录失败", exc_info=True)
new_customers = []
regular_customers = []
for mid, stats in member_stats.items():
# CHANGE 2026-03-27 | 过滤散客/未知客户member_id ≤ 0不进入新客和常客列表
if mid <= 0:
continue
name = stats["customer_name"]
char, color = avatar_char_color(name)
char = name[0] if name else "?"
# 新客:历史无记录
if mid not in historical_members:
last_service_dt = stats["last_service"]
new_customers.append({
"name": name,
"member_id": mid,
"avatar_char": char,
"avatar_color": color,
"last_service": _format_date_label(last_service_dt),
"count": stats["count"],
"heart_score": rs_map.get(mid, 0.0) if rs_map else 0.0,
})
# 常客:本月 ≥ 2 次
# 常客:本月 ≥ 2 次展示数据用90天聚合
if stats["count"] >= 2:
s90 = stats_90d.get(mid)
if s90:
reg_count = s90["count"]
reg_hours = round(s90["total_hours"], 2)
reg_income = s90["total_income"]
else:
# 90天查询失败时回退到本月数据
reg_count = stats["count"]
reg_hours = round(stats["total_hours"], 2)
reg_income = stats["total_income"]
regular_customers.append({
"name": name,
"member_id": mid,
"avatar_char": char,
"avatar_color": color,
"hours": round(stats["total_hours"], 2),
"income": f"¥{stats['total_income']:,.2f}",
"count": stats["count"],
"hours": reg_hours,
"income": f"¥{reg_income:,.2f}",
"count": reg_count,
"heart_score": rs_map.get(mid, 0.0) if rs_map else 0.0,
})
# 新客按最后服务时间倒序
@@ -470,6 +590,7 @@ def _build_customer_lists(
# ---------------------------------------------------------------------------
@trace_service("获取绩效明细", "Get performance records")
async def get_records(
user_id: int, site_id: int,
year: int, month: int, page: int, page_size: int,
@@ -506,8 +627,27 @@ async def get_records(
# 判断 hasMore
has_more = len(all_records) > page * page_size
# 按日期分组(不含 avatar
date_groups = group_records_by_date(page_records, include_avatar=False)
# CHANGE 2026-03-27 | 批量查 RS 分数,注入到服务记录
page_member_ids = list({r.get("member_id") for r in page_records if r.get("member_id")})
rs_map: dict[int, float] = {}
if page_member_ids:
try:
with fdw_queries._fdw_context(conn, site_id) as cur:
cur.execute(
"""
SELECT member_id, COALESCE(rs_display, 0)
FROM app.v_dws_member_assistant_relation_index
WHERE assistant_id = %s AND member_id = ANY(%s)
""",
(assistant_id, page_member_ids),
)
for row in cur.fetchall():
rs_map[row[0]] = float(row[1])
except Exception:
logger.warning("PERF-2 查询 RS 分数失败", exc_info=True)
# 按日期分组(含 member_id / avatar_char前端计算头像颜色
date_groups = group_records_by_date(page_records, include_avatar=True, rs_map=rs_map)
return {
"summary": summary,

View File

@@ -16,6 +16,8 @@ ETL 数据更新后,直连 ETL 库读取助教服务记录,
import json
import logging
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@@ -56,6 +58,7 @@ def _insert_history(
)
@trace_service(description_zh="执行维客检测", description_en="Run recall detection")
def run(payload: dict | None = None, job_id: int | None = None) -> dict:
"""
召回完成检测主流程。
@@ -178,6 +181,9 @@ def _process_site(conn, site_id: int, last_run_at) -> int:
# ── 4-7. 逐条服务记录匹配并处理 ──
for assistant_id, member_id, service_time in service_records:
# 散客过滤member_id ≤ 0 不参与任务系统)
if member_id is None or member_id <= 0:
continue
try:
count = _process_service_record(
conn, site_id, assistant_id, member_id, service_time
@@ -203,7 +209,13 @@ def _process_service_record(
service_time,
) -> int:
"""
处理单条服务记录:匹配 active 任务并标记 completed。
处理单条服务记录:匹配 active 任务并标记 completed + 生成回访任务
CHANGE 2026-03-30 | 回访任务直接在此生成(不再依赖 note_reclassifier 事件链)。
规则:
- 有 active 召回任务 → 标记 completed然后生成回访任务
- 有 active 回访任务 → 关闭旧回访,生成新回访(重置 48h 倒计时)
- 无任何 active 召回/回访 → 直接生成回访任务
每条服务记录独立事务,失败不影响其他。
返回本次完成的任务数。
@@ -213,7 +225,7 @@ def _process_service_record(
with conn.cursor() as cur:
cur.execute("BEGIN")
# 查找匹配的 active 召回类任务(仅完成召回任务,回访/关系构建不在此处理)
# ── 1. 查找匹配的 active 召回类任务 ──
cur.execute(
"""
SELECT id, task_type
@@ -226,14 +238,12 @@ def _process_service_record(
""",
(site_id, assistant_id, member_id),
)
active_tasks = cur.fetchall()
active_recall_tasks = cur.fetchall()
if not active_tasks:
conn.commit()
return 0
has_active_recall = len(active_recall_tasks) > 0
# 将所有匹配的 active 任务标记为 completed
for task_id, task_type in active_tasks:
# 将所有匹配的 active 召回任务标记为 completed
for task_id, task_type in active_recall_tasks:
cur.execute(
"""
UPDATE biz.coach_tasks
@@ -260,28 +270,82 @@ def _process_service_record(
)
completed += 1
conn.commit()
# ── 2. 生成回访任务CHANGE 2026-03-30 ──
# 如果还有 active 召回任务(其他助教的),不生成回访
# 注意:上面已经把当前助教的召回任务标记为 completed 了
# 这里检查的是当前助教-客户对是否还有未完成的召回任务(不应该有了)
# ── 7. 触发 recall_completed 事件 ──
# 延迟导入 fire_event 避免循环依赖
try:
from app.services.trigger_scheduler import fire_event
# 关闭已有的 active 回访任务
cur.execute(
"""
SELECT id FROM biz.coach_tasks
WHERE site_id = %s AND assistant_id = %s AND member_id = %s
AND task_type = 'follow_up_visit' AND status = 'active'
""",
(site_id, assistant_id, member_id),
)
old_follow_ups = cur.fetchall()
for (old_id,) in old_follow_ups:
cur.execute(
"""
UPDATE biz.coach_tasks
SET status = 'inactive', updated_at = NOW()
WHERE id = %s
""",
(old_id,),
)
_insert_history(
cur, old_id,
action="superseded_by_new_visit",
old_status="active", new_status="inactive",
old_task_type="follow_up_visit", new_task_type="follow_up_visit",
detail={"reason": "new_service_record", "service_time": str(service_time)},
)
fire_event(
"recall_completed",
{
"site_id": site_id,
"assistant_id": assistant_id,
"member_id": member_id,
# 创建新的回访任务48h 过期)
from datetime import timedelta
expires_at = service_time + timedelta(hours=48) if hasattr(service_time, '__add__') else None
cur.execute(
"""
INSERT INTO biz.coach_tasks
(site_id, assistant_id, member_id, task_type, status, expires_at, created_at, updated_at)
VALUES (%s, %s, %s, 'follow_up_visit', 'active', %s, NOW(), NOW())
RETURNING id
""",
(site_id, assistant_id, member_id, expires_at),
)
new_follow_up_id = cur.fetchone()[0]
_insert_history(
cur, new_follow_up_id,
action="created",
old_status=None, new_status="active",
new_task_type="follow_up_visit",
detail={
"reason": "service_record_detected",
"service_time": str(service_time),
"had_recall": has_active_recall,
},
)
except Exception:
logger.exception(
"触发 recall_completed 事件失败: site_id=%s, assistant_id=%s, member_id=%s",
site_id,
assistant_id,
member_id,
)
conn.commit()
# ── 3. 触发 recall_completed 事件(仅当有召回任务被完成时) ──
if has_active_recall:
try:
from app.services.trigger_scheduler import fire_event
fire_event(
"recall_completed",
{
"site_id": site_id,
"assistant_id": assistant_id,
"member_id": member_id,
"service_time": str(service_time),
},
)
except Exception:
logger.exception(
"触发 recall_completed 事件失败: site_id=%s, assistant_id=%s, member_id=%s",
site_id, assistant_id, member_id,
)
return completed

View File

@@ -15,10 +15,12 @@ from __future__ import annotations
import logging
from app.database import get_connection
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
@trace_service(description_zh="获取用户权限列表", description_en="Get user permissions")
async def get_user_permissions(user_id: int, site_id: int) -> list[str]:
"""
获取用户在指定 site_id 下的权限 code 列表。
@@ -43,6 +45,7 @@ async def get_user_permissions(user_id: int, site_id: int) -> list[str]:
JOIN auth.role_permissions rp ON usr.role_id = rp.role_id
JOIN auth.permissions p ON rp.permission_id = p.id
WHERE usr.user_id = %s AND usr.site_id = %s
AND usr.is_removed = false
""",
(user_id, site_id),
)
@@ -53,11 +56,12 @@ async def get_user_permissions(user_id: int, site_id: int) -> list[str]:
return [row[0] for row in rows]
@trace_service(description_zh="获取用户门店列表", description_en="Get user sites")
async def get_user_sites(user_id: int) -> list[dict]:
"""
获取用户关联的所有店铺及对应角色。
查询 user_site_roles JOIN rolesLEFT JOIN site_code_mapping 获取店铺名称,
查询 user_site_roles JOIN rolesLEFT JOIN biz.sites 获取店铺名称,
按 site_id 分组聚合角色列表。
参数:
@@ -77,8 +81,9 @@ async def get_user_sites(user_id: int) -> list[dict]:
r.name
FROM auth.user_site_roles usr
JOIN auth.roles r ON usr.role_id = r.id
LEFT JOIN auth.site_code_mapping scm ON usr.site_id = scm.site_id
LEFT JOIN biz.sites scm ON scm.site_id = usr.site_id
WHERE usr.user_id = %s
AND usr.is_removed = false
ORDER BY usr.site_id, r.code
""",
(user_id,),
@@ -101,6 +106,7 @@ async def get_user_sites(user_id: int) -> list[dict]:
return list(sites_map.values())
@trace_service(description_zh="检查用户门店角色", description_en="Check user site role")
async def check_user_has_site_role(user_id: int, site_id: int) -> bool:
"""
检查用户在指定 site_id 下是否有任何角色绑定。
@@ -120,6 +126,7 @@ async def check_user_has_site_role(user_id: int, site_id: int) -> bool:
SELECT 1
FROM auth.user_site_roles
WHERE user_id = %s AND site_id = %s
AND is_removed = false
LIMIT 1
""",
(user_id, site_id),

View File

@@ -22,6 +22,8 @@ from ..schemas.schedules import ScheduleConfigSchema
from ..schemas.tasks import TaskConfigSchema
from .task_queue import task_queue
from app.trace.decorators import trace_service
logger = logging.getLogger(__name__)
# 调度器轮询间隔(秒)
@@ -34,6 +36,23 @@ def _parse_time(time_str: str) -> tuple[int, int]:
return int(parts[0]), int(parts[1])
def _convert_interval_to_seconds(value: int, unit: str) -> int:
"""将间隔值转换为秒数。
Args:
value: 间隔数值0 = 无限制)
unit: 间隔单位,支持 "minutes""hours""days"
Returns:
对应的秒数value <= 0 时返回 0
"""
if value <= 0:
return 0
multipliers = {"minutes": 60, "hours": 3600, "days": 86400}
return value * multipliers.get(unit, 60)
@trace_service(description_zh="calculate_next_run", description_en="Calculate Next Run")
def calculate_next_run(
schedule_config: ScheduleConfigSchema,
now: datetime | None = None,
@@ -188,7 +207,9 @@ class Scheduler:
with conn.cursor() as cur:
cur.execute(
"""
SELECT id, site_id, task_config, schedule_config
SELECT id, site_id, task_config, schedule_config,
min_run_interval_value, min_run_interval_unit,
last_run_at, last_status, min_run_intervals
FROM scheduled_tasks
WHERE enabled = TRUE
AND next_run_at IS NOT NULL
@@ -198,11 +219,32 @@ class Scheduler:
)
rows = cur.fetchall()
now = datetime.now(timezone.utc)
for row in rows:
task_id = str(row[0])
site_id = row[1]
task_config_raw = row[2] if isinstance(row[2], dict) else json.loads(row[2])
schedule_config_raw = row[3] if isinstance(row[3], dict) else json.loads(row[3])
min_interval_value = row[4] or 0
min_interval_unit = row[5] or "minutes"
last_run_at = row[6]
last_status = row[7]
# per-task 间隔:取所有任务中最大的间隔作为有效间隔
min_run_intervals_raw = row[8] if isinstance(row[8], dict) else json.loads(row[8]) if row[8] else {}
# 计算有效间隔per-task 最大值 vs schedule 级别,取较大者
effective_interval_seconds = _convert_interval_to_seconds(
min_interval_value, min_interval_unit
)
for _task_code, interval_cfg in min_run_intervals_raw.items():
if isinstance(interval_cfg, dict):
task_seconds = _convert_interval_to_seconds(
interval_cfg.get("value", 0),
interval_cfg.get("unit", "minutes"),
)
if task_seconds > effective_interval_seconds:
effective_interval_seconds = task_seconds
try:
config = TaskConfigSchema(**task_config_raw)
@@ -211,7 +253,44 @@ class Scheduler:
logger.exception("调度任务 [%s] 配置反序列化失败,跳过", task_id)
continue
# 入队
# 1. 并发检查:上次仍在运行中 → 跳过
if last_status == "running":
logger.warning(
"调度任务 [%s] skipped_concurrent上次执行仍在运行中",
task_id,
)
continue
# 2. 间隔检查:最小运行间隔未到 → 跳过并推进 next_run_at
if effective_interval_seconds > 0 and last_run_at is not None:
elapsed = (now - last_run_at).total_seconds()
if elapsed < effective_interval_seconds:
# 推进 next_run_at = last_run_at + interval
next_run_at_pushed = last_run_at + timedelta(
seconds=effective_interval_seconds
)
with conn.cursor() as cur:
cur.execute(
"""
UPDATE scheduled_tasks
SET next_run_at = %s,
updated_at = NOW()
WHERE id = %s
""",
(next_run_at_pushed, task_id),
)
conn.commit()
logger.info(
"调度任务 [%s] skipped_interval最小间隔未到"
"(已过 %.0fs / 需 %dsnext_run_at 推进至 %s",
task_id,
elapsed,
effective_interval_seconds,
next_run_at_pushed,
)
continue
# 3. 正常入队
try:
queue_id = task_queue.enqueue(config, site_id, schedule_id=task_id)
logger.info(
@@ -224,7 +303,6 @@ class Scheduler:
continue
# 更新调度任务状态
now = datetime.now(timezone.utc)
next_run = calculate_next_run(schedule_cfg, now)
with conn.cursor() as cur:
@@ -269,6 +347,9 @@ class Scheduler:
# 在线程池中执行同步数据库操作,避免阻塞事件循环
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self.check_and_enqueue)
# CHANGE 2026-03-23 | 同时检查 trigger_jobs 中到期的 cron/interval 任务
from app.services.trigger_scheduler import check_scheduled_jobs
await loop.run_in_executor(None, check_scheduled_jobs)
except Exception:
logger.exception("Scheduler 循环迭代异常")

Some files were not shown because too many files have changed in this diff Show More