init: 项目初始提交 - NeoZQYY Monorepo 完整代码

Neo
2026-02-15 14:58:14 +08:00
commit ded6dfb9d8
769 changed files with 182616 additions and 0 deletions


@@ -0,0 +1,45 @@
# tasks/ — ETL 任务
## 目录结构
```
tasks/
├── base_task.py # BaseTask 基类Extract → Transform → Load 模板方法)
├── ods/ # ODS 层:从 API 抓取或离线 JSON 回放,写入 ODS 表
├── dwd/ # DWD 层:从 ODS 清洗装载到 DWD维度 SCD2 + 事实增量)
├── dws/ # DWS 层:汇总统计(助教业绩、财务日报、会员分析等)
│ └── index/ # 指数计算(亲密度、新客转化、召回、关系、赢回)
├── utility/ # 工具类任务Schema 初始化、手动入库、完整性检查等)
└── verification/ # 校验任务ODS/DWD/DWS/指数层的数据一致性校验)
```
## 新增任务流程
1. 在对应子目录创建任务文件,继承 `BaseTask`
2. 实现 `get_task_code()` 返回大写蛇形任务代码(如 `DWS_MEMBER_VISIT`
3. 实现 `extract(context)` / `transform(extracted, context)` / `load(transformed, context)` 钩子,由基类的 `execute()` 统一编排 Extract → Transform → Load 流程
4. 在 `orchestration/task_registry.py` 中注册任务,指定元数据:
- `layer``ODS` / `DWD` / `DWS` / `UTILITY` / `VERIFICATION`
- `task_type``ETL` / `UTILITY` / `VERIFICATION`
- `requires_db_config`:是否需要数据库连接
```python
# 示例:注册一个新的 DWS 任务
registry.register(
task_code="DWS_NEW_REPORT",
task_class=NewReportTask,
layer="DWS",
task_type="ETL",
requires_db_config=True,
)
```
## 任务命名约定
- 任务代码:大写蛇形(`DWD_LOAD_FROM_ODS`、`DWS_ASSISTANT_DAILY`
- 文件名:小写蛇形 + `_task.py` 后缀(`assistant_daily_task.py`
- 类名:驼峰 + `Task` 后缀(`AssistantDailyTask`
## ODS 任务特殊说明
ODS 任务通过 `ods/ods_tasks.py` 中的 `ODS_TASK_SPECS` 声明式定义,无需为每个实体单独写 execute 逻辑。新增 ODS 实体只需在 `ODS_TASK_SPECS` 列表中添加一条 spec 记录。
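下面是一条假设性的 spec 示例(字段名与结构仅作示意,实际以 `ods/ods_tasks.py` 中 `ODS_TASK_SPECS` 的真实定义为准):
```python
# 示意:为新的 ODS 实体追加一条 spec 记录(字段名为假设,非真实定义)
ODS_TASK_SPECS = [
    # ... 既有实体的 spec ...
    {
        "task_code": "ODS_NEW_ENTITY",                    # 大写蛇形任务代码
        "endpoint": "/api/new-entity/list",               # 假设:对应的抓取接口路径
        "target_table": "billiards_ods.ods_new_entity",   # 假设:写入的 ODS 表
    },
]
```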


@@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-14] 默认时区从 Asia/Taipei 修正为 Asia/Shanghai与运营地区一致
"""ETL任务基类引入 Extract/Transform/Load 模板方法)"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from utils.windowing import build_window_segments, calc_window_minutes, calc_window_days, format_window_days
@dataclass(frozen=True)
class TaskContext:
"""统一透传给 Extract/Transform/Load 的运行期信息。"""
store_id: int
window_start: datetime
window_end: datetime
window_minutes: int
cursor: dict | None = None
class BaseTask:
"""提供 E/T/L 模板的任务基类。"""
def __init__(self, config, db_connection, api_client, logger):
self.config = config
self.db = db_connection
self.api = api_client
self.logger = logger
self.tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
# ------------------------------------------------------------------ 基本信息
def get_task_code(self) -> str:
"""获取任务代码"""
raise NotImplementedError("子类需实现 get_task_code 方法")
# ------------------------------------------------------------------ E/T/L 钩子
def extract(self, context: TaskContext):
"""提取数据"""
raise NotImplementedError("子类需实现 extract 方法")
def transform(self, extracted, context: TaskContext):
"""转换数据"""
return extracted
def load(self, transformed, context: TaskContext) -> dict:
"""加载数据并返回统计信息"""
raise NotImplementedError("子类需实现 load 方法")
# ------------------------------------------------------------------ 主流程
def execute(self, cursor_data: dict | None = None) -> dict:
"""统一 orchestrate Extract → Transform → Load"""
base_context = self._build_context(cursor_data)
task_code = self.get_task_code()
segments = build_window_segments(
self.config,
base_context.window_start,
base_context.window_end,
tz=self.tz,
override_only=True,
)
if not segments:
segments = [(base_context.window_start, base_context.window_end)]
total_segments = len(segments)
total_days = sum(calc_window_days(s, e) for s, e in segments) if segments else 0.0
processed_days = 0.0
if total_segments > 1:
self.logger.info(
"%s: 窗口拆分为 %s 段(共 %s 天)",
task_code,
total_segments,
format_window_days(total_days),
)
total_counts: dict = {}
segment_results: list[dict] = []
for idx, (window_start, window_end) in enumerate(segments, start=1):
context = self._build_context_for_window(window_start, window_end, cursor_data)
self.logger.info(
"%s: 开始执行(%s/%s),窗口[%s ~ %s]",
task_code,
idx,
total_segments,
context.window_start,
context.window_end,
)
try:
extracted = self.extract(context)
transformed = self.transform(extracted, context)
counts = self.load(transformed, context) or {}
self.db.commit()
except Exception:
self.db.rollback()
self.logger.error("%s: 执行失败", task_code, exc_info=True)
raise
self._accumulate_counts(total_counts, counts)
segment_days = calc_window_days(context.window_start, context.window_end)
processed_days += segment_days
if total_segments > 1:
self.logger.info(
"%s: 完成(%s/%s),已处理 %s/%s",
task_code,
idx,
total_segments,
format_window_days(processed_days),
format_window_days(total_days),
)
if total_segments > 1:
segment_results.append(
{
"window": {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
},
"counts": counts,
}
)
overall_start = segments[0][0]
overall_end = segments[-1][1]
result = self._build_result("SUCCESS", total_counts)
result["window"] = {
"start": overall_start,
"end": overall_end,
"minutes": calc_window_minutes(overall_start, overall_end),
}
if segment_results:
result["segments"] = segment_results
self.logger.info("%s: 完成,统计=%s", task_code, result["counts"])
return result
def _build_context(self, cursor_data: dict | None) -> TaskContext:
window_start, window_end, window_minutes = self._get_time_window(cursor_data)
return TaskContext(
store_id=self.config.get("app.store_id"),
window_start=window_start,
window_end=window_end,
window_minutes=window_minutes,
cursor=cursor_data,
)
def _build_context_for_window(
self,
window_start: datetime,
window_end: datetime,
cursor_data: dict | None,
) -> TaskContext:
return TaskContext(
store_id=self.config.get("app.store_id"),
window_start=window_start,
window_end=window_end,
window_minutes=calc_window_minutes(window_start, window_end),
cursor=cursor_data,
)
@staticmethod
def _accumulate_counts(total: dict, current: dict) -> dict:
for key, value in (current or {}).items():
if isinstance(value, (int, float)):
total[key] = (total.get(key) or 0) + value
else:
total.setdefault(key, value)
return total
def _get_time_window(self, cursor_data: dict | None = None) -> tuple:
"""计算时间窗口"""
now = datetime.now(self.tz)
override_start = self.config.get("run.window_override.start")
override_end = self.config.get("run.window_override.end")
if override_start or override_end:
if not (override_start and override_end):
raise ValueError("run.window_override.start/end 需要同时提供")
window_start = override_start
if isinstance(window_start, str):
window_start = dtparser.parse(window_start)
if isinstance(window_start, datetime) and window_start.tzinfo is None:
window_start = window_start.replace(tzinfo=self.tz)
elif isinstance(window_start, datetime):
window_start = window_start.astimezone(self.tz)
window_end = override_end
if isinstance(window_end, str):
window_end = dtparser.parse(window_end)
if isinstance(window_end, datetime) and window_end.tzinfo is None:
window_end = window_end.replace(tzinfo=self.tz)
elif isinstance(window_end, datetime):
window_end = window_end.astimezone(self.tz)
if not isinstance(window_start, datetime) or not isinstance(window_end, datetime):
raise ValueError("run.window_override.start/end 解析失败")
if window_end <= window_start:
raise ValueError("run.window_override.end 必须大于 start")
window_minutes = max(1, int((window_end - window_start).total_seconds() // 60))
return window_start, window_end, window_minutes
idle_start = self.config.get("run.idle_window.start", "04:00")
idle_end = self.config.get("run.idle_window.end", "16:00")
is_idle = self._is_in_idle_window(now, idle_start, idle_end)
if is_idle:
window_minutes = self.config.get("run.window_minutes.default_idle", 180)
else:
window_minutes = self.config.get("run.window_minutes.default_busy", 30)
overlap_seconds = self.config.get("run.overlap_seconds", 600)
if cursor_data and cursor_data.get("last_end"):
window_start = cursor_data["last_end"] - timedelta(seconds=overlap_seconds)
else:
window_start = now - timedelta(minutes=window_minutes)
window_end = now
return window_start, window_end, window_minutes
def _is_in_idle_window(self, dt: datetime, start_time: str, end_time: str) -> bool:
"""判断是否在闲时窗口"""
current_time = dt.strftime("%H:%M")
return start_time <= current_time <= end_time
def _merge_common_params(self, base: dict) -> dict:
"""
合并全局/任务级参数池,便于在配置中统一覆盖/追加过滤条件。
支持:
- api.params 下的通用键值
- api.params.<task_code_lower> 下的任务级键值
"""
merged: dict = {}
common = self.config.get("api.params", {}) or {}
if isinstance(common, dict):
merged.update(common)
task_key = f"api.params.{self.get_task_code().lower()}"
scoped = self.config.get(task_key, {}) or {}
if isinstance(scoped, dict):
merged.update(scoped)
merged.update(base)
return merged
def _build_result(self, status: str, counts: dict) -> dict:
"""构建结果字典"""
return {"status": status, "counts": counts}


@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""DWD 层装载任务"""


@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
"""DWD任务基类"""
import json
from typing import Any, Dict, Iterator, List, Optional, Tuple
from datetime import datetime
from tasks.base_task import BaseTask
from models.parsers import TypeParser
class BaseDwdTask(BaseTask):
"""
DWD 层任务基类
负责从 ODS 表读取数据,供子类清洗和写入事实/维度表
"""
def _get_ods_cursor(self, task_code: str) -> Optional[datetime]:
"""
获取上次处理的 ODS 数据的时间点 (fetched_at)
这里简化处理,实际应该从 etl_cursor 表读取
目前先依赖 BaseTask 的时间窗口逻辑,或者子类自己管理
"""
# TODO: 对接真正的 CursorManager
# 暂时返回一个较早的时间,或者由子类通过 _get_time_window 获取
return None
def iter_ods_rows(
self,
table_name: str,
columns: List[str],
start_time: datetime,
end_time: datetime,
time_col: str = "fetched_at",
batch_size: int = 1000
) -> Iterator[List[Dict[str, Any]]]:
"""
分批迭代读取 ODS 表数据
Args:
table_name: ODS 表名
columns: 需要查询的字段列表 (必须包含 payload)
start_time: 开始时间 (包含)
end_time: 结束时间 (包含)
time_col: 时间过滤字段,默认 fetched_at
batch_size: 批次大小
"""
offset = 0
cols_str = ", ".join(columns)
while True:
sql = f"""
SELECT {cols_str}
FROM {table_name}
WHERE {time_col} >= %s AND {time_col} <= %s
ORDER BY {time_col} ASC
LIMIT %s OFFSET %s
"""
rows = self.db.query(sql, (start_time, end_time, batch_size, offset))
if not rows:
break
yield rows
if len(rows) < batch_size:
break
offset += batch_size
def parse_payload(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""
解析 ODS 行中的 payload JSON
"""
payload = row.get("payload")
if isinstance(payload, str):
return json.loads(payload)
elif isinstance(payload, dict):
return payload
return {}
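下面是一个在 DWD 子类中使用 `iter_ods_rows` / `parse_payload` 的示意(表名与字段为假设,仅作说明):
```python
# 示意BaseDwdTask 子类中的 extract分批读取 ODS 数据并解析 payload
from typing import Any, Dict, List

class DemoDwdTask(BaseDwdTask):
    def extract(self, context) -> List[Dict[str, Any]]:
        records: List[Dict[str, Any]] = []
        # iter_ods_rows 按 fetched_at 升序、以 batch_size 为步长分批返回
        for batch in self.iter_ods_rows(
            table_name="billiards_ods.ods_demo_entity",   # 假设的 ODS 表名
            columns=["id", "payload", "fetched_at"],
            start_time=context.window_start,
            end_time=context.window_end,
        ):
            for row in batch:
                records.append(self.parse_payload(row))   # payload 为 JSON 字符串或 dict
        return records
```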

File diff suppressed because it is too large


@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
"""DWD 质量核对任务:按 dwd_quality_check.md 输出行数/金额对照报表。"""
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence, Tuple
from psycopg2.extras import RealDictCursor
from tasks.base_task import BaseTask, TaskContext
from tasks.dwd.dwd_load_task import DwdLoadTask
class DwdQualityTask(BaseTask):
"""对 ODS 与 DWD 进行行数、金额对照核查,生成 JSON 报表。"""
REPORT_PATH = Path("reports/dwd_quality_report.json")
AMOUNT_KEYWORDS = ("amount", "money", "fee", "balance")
def get_task_code(self) -> str:
"""返回任务编码。"""
return "DWD_QUALITY_CHECK"
def extract(self, context: TaskContext) -> dict[str, Any]:
"""准备运行时上下文。"""
return {"now": datetime.now()}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict[str, Any]:
"""输出行数/金额差异报表到本地文件。"""
report: Dict[str, Any] = {
"generated_at": extracted["now"].isoformat(),
"tables": [],
"note": "行数/金额核对,金额字段基于列名包含 amount/money/fee/balance 的数值列自动扫描。",
}
with self.db.conn.cursor(cursor_factory=RealDictCursor) as cur:
for dwd_table, ods_table in DwdLoadTask.TABLE_MAP.items():
count_info = self._compare_counts(cur, dwd_table, ods_table)
amount_info = self._compare_amounts(cur, dwd_table, ods_table)
report["tables"].append(
{
"dwd_table": dwd_table,
"ods_table": ods_table,
"count": count_info,
"amounts": amount_info,
}
)
self.REPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
self.REPORT_PATH.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
self.logger.info("DWD 质检报表已生成:%s", self.REPORT_PATH)
return {"report_path": str(self.REPORT_PATH)}
# ---------------------- 辅助方法 ----------------------
def _compare_counts(self, cur, dwd_table: str, ods_table: str) -> Dict[str, Any]:
"""统计两端行数并返回差异。"""
dwd_schema, dwd_name = self._split_table_name(dwd_table, default_schema="billiards_dwd")
ods_schema, ods_name = self._split_table_name(ods_table, default_schema="billiards_ods")
cur.execute(f'SELECT COUNT(1) AS cnt FROM "{dwd_schema}"."{dwd_name}"')
dwd_cnt = cur.fetchone()["cnt"]
cur.execute(f'SELECT COUNT(1) AS cnt FROM "{ods_schema}"."{ods_name}"')
ods_cnt = cur.fetchone()["cnt"]
return {"dwd": dwd_cnt, "ods": ods_cnt, "diff": dwd_cnt - ods_cnt}
def _compare_amounts(self, cur, dwd_table: str, ods_table: str) -> List[Dict[str, Any]]:
"""扫描金额相关列,生成 ODS 与 DWD 的汇总对照。"""
dwd_schema, dwd_name = self._split_table_name(dwd_table, default_schema="billiards_dwd")
ods_schema, ods_name = self._split_table_name(ods_table, default_schema="billiards_ods")
dwd_amount_cols = self._get_numeric_amount_columns(cur, dwd_schema, dwd_name)
ods_amount_cols = self._get_numeric_amount_columns(cur, ods_schema, ods_name)
common_amount_cols = sorted(set(dwd_amount_cols) & set(ods_amount_cols))
results: List[Dict[str, Any]] = []
for col in common_amount_cols:
cur.execute(f'SELECT COALESCE(SUM("{col}"),0) AS val FROM "{dwd_schema}"."{dwd_name}"')
dwd_sum = cur.fetchone()["val"]
cur.execute(f'SELECT COALESCE(SUM("{col}"),0) AS val FROM "{ods_schema}"."{ods_name}"')
ods_sum = cur.fetchone()["val"]
results.append({"column": col, "dwd_sum": float(dwd_sum or 0), "ods_sum": float(ods_sum or 0), "diff": float(dwd_sum or 0) - float(ods_sum or 0)})
return results
def _get_numeric_amount_columns(self, cur, schema: str, table: str) -> List[str]:
"""获取列名包含金额关键词的数值型字段。"""
cur.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s
AND table_name = %s
AND data_type IN ('numeric','double precision','integer','bigint','smallint','real','decimal')
""",
(schema, table),
)
cols = [r["column_name"].lower() for r in cur.fetchall()]
return [c for c in cols if any(key in c for key in self.AMOUNT_KEYWORDS)]
def _split_table_name(self, name: str, default_schema: str) -> Tuple[str, str]:
"""拆分 schema 与表名,缺省使用 default_schema。"""
parts = name.split(".")
if len(parts) == 2:
return parts[0], parts[1]
return default_schema, name
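生成的 `reports/dwd_quality_report.json` 结构大致如下(表名与数值仅为示例,实际表映射以 `DwdLoadTask.TABLE_MAP` 为准):
```python
# 示意:质检报表的 JSON 结构(数值为示例)
{
    "generated_at": "2026-02-15T04:00:00",
    "note": "行数/金额核对,金额字段基于列名包含 amount/money/fee/balance 的数值列自动扫描。",
    "tables": [
        {
            "dwd_table": "billiards_dwd.dwd_assistant_service_log",
            "ods_table": "billiards_ods.ods_assistant_service_log",
            "count": {"dwd": 1000, "ods": 1000, "diff": 0},
            "amounts": [
                {"column": "ledger_amount", "dwd_sum": 12345.0, "ods_sum": 12345.0, "diff": 0.0},
            ],
        },
    ],
}
```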


@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 移除 RecallIndexTask/IntimacyIndexTask 导入,更新 __all__
"""
DWS层ETL任务模块
包含:
- BaseDwsTask: DWS任务基类
- 助教维度任务
- 客户维度任务
- 财务维度任务
- 指数算法任务
"""
from .base_dws_task import BaseDwsTask, TimeLayer, TimeWindow, CourseType, DiscountType
from .assistant_daily_task import AssistantDailyTask
from .assistant_monthly_task import AssistantMonthlyTask
from .assistant_customer_task import AssistantCustomerTask
from .assistant_salary_task import AssistantSalaryTask
from .assistant_finance_task import AssistantFinanceTask
from .member_consumption_task import MemberConsumptionTask
from .member_visit_task import MemberVisitTask
from .finance_daily_task import FinanceDailyTask
from .finance_recharge_task import FinanceRechargeTask
from .finance_income_task import FinanceIncomeStructureTask
from .finance_discount_task import FinanceDiscountDetailTask
from .retention_cleanup_task import DwsRetentionCleanupTask
from .mv_refresh_task import DwsMvRefreshFinanceDailyTask, DwsMvRefreshAssistantDailyTask
# 指数算法任务
from .index import (
WinbackIndexTask,
NewconvIndexTask,
MlManualImportTask,
RelationIndexTask,
)
__all__ = [
# 基类
"BaseDwsTask",
"TimeLayer",
"TimeWindow",
"CourseType",
"DiscountType",
# 助教维度
"AssistantDailyTask",
"AssistantMonthlyTask",
"AssistantCustomerTask",
"AssistantSalaryTask",
"AssistantFinanceTask",
# 客户维度
"MemberConsumptionTask",
"MemberVisitTask",
# 财务维度
"FinanceDailyTask",
"FinanceRechargeTask",
"FinanceIncomeStructureTask",
"FinanceDiscountDetailTask",
"DwsRetentionCleanupTask",
"DwsMvRefreshFinanceDailyTask",
"DwsMvRefreshAssistantDailyTask",
# 指数算法
"WinbackIndexTask",
"NewconvIndexTask",
"MlManualImportTask",
"RelationIndexTask",
]


@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
"""
助教服务客户统计任务
功能说明:
"助教+客户"为粒度,统计服务关系和滚动窗口指标
数据来源:
- dwd_assistant_service_log: 助教服务流水
- dim_member: 会员维度
目标表:
billiards_dws.dws_assistant_customer_stats
更新策略:
- 更新频率:每日更新
- 幂等方式delete-before-insert按统计日期
业务规则:
- 散客处理member_id=0 不进入此表统计
- 滚动窗口7/10/15/30/60/90天
- 活跃度近7天/30天是否有服务
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class AssistantCustomerTask(BaseDwsTask):
"""
助教服务客户统计任务
统计每个助教与每个客户的服务关系:
- 首次/最近服务日期
- 累计服务统计
- 滚动窗口统计7/10/15/30/60/90天
- 活跃度指标
"""
def get_task_code(self) -> str:
return "DWS_ASSISTANT_CUSTOMER"
def get_target_table(self) -> str:
return "dws_assistant_customer_stats"
def get_primary_keys(self) -> List[str]:
return ["site_id", "assistant_id", "member_id", "stat_date"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据
"""
stat_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
self.logger.info(
"%s: 提取数据,统计日期 %s",
self.get_task_code(), stat_date
)
# 计算最大回溯日期90天窗口
lookback_start = stat_date - timedelta(days=90)
# 1. 获取助教-客户服务记录(包含历史全量用于累计统计)
service_pairs = self._extract_service_pairs(site_id, stat_date)
# 2. 获取会员信息
member_info = self._extract_member_info(site_id)
# 3. 获取助教信息
assistant_info = self._extract_assistant_info(site_id)
return {
'service_pairs': service_pairs,
'member_info': member_info,
'assistant_info': assistant_info,
'stat_date': stat_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据:计算各窗口统计
"""
service_pairs = extracted['service_pairs']
member_info = extracted['member_info']
assistant_info = extracted['assistant_info']
stat_date = extracted['stat_date']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,%d 条服务关系记录",
self.get_task_code(), len(service_pairs)
)
# 构建统计记录
results = []
for pair in service_pairs:
assistant_id = pair.get('assistant_id')
member_id = pair.get('member_id')
# 跳过散客
if self.is_guest(member_id):
continue
asst_info = assistant_info.get(assistant_id, {})
memb_info = member_info.get(member_id, {})
# 构建记录
record = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'assistant_id': assistant_id,
'assistant_nickname': asst_info.get('nickname', pair.get('assistant_nickname')),
'member_id': member_id,
'member_nickname': memb_info.get('nickname'),
'member_mobile': self._mask_mobile(memb_info.get('mobile')),
'stat_date': stat_date,
# 全量累计统计
'first_service_date': pair.get('first_service_date'),
'last_service_date': pair.get('last_service_date'),
'total_service_count': self.safe_int(pair.get('total_service_count', 0)),
'total_service_hours': self.safe_decimal(pair.get('total_service_hours', 0)),
'total_service_amount': self.safe_decimal(pair.get('total_service_amount', 0)),
# 滚动窗口统计
'service_count_7d': self.safe_int(pair.get('service_count_7d', 0)),
'service_count_10d': self.safe_int(pair.get('service_count_10d', 0)),
'service_count_15d': self.safe_int(pair.get('service_count_15d', 0)),
'service_count_30d': self.safe_int(pair.get('service_count_30d', 0)),
'service_count_60d': self.safe_int(pair.get('service_count_60d', 0)),
'service_count_90d': self.safe_int(pair.get('service_count_90d', 0)),
'service_hours_7d': self.safe_decimal(pair.get('service_hours_7d', 0)),
'service_hours_10d': self.safe_decimal(pair.get('service_hours_10d', 0)),
'service_hours_15d': self.safe_decimal(pair.get('service_hours_15d', 0)),
'service_hours_30d': self.safe_decimal(pair.get('service_hours_30d', 0)),
'service_hours_60d': self.safe_decimal(pair.get('service_hours_60d', 0)),
'service_hours_90d': self.safe_decimal(pair.get('service_hours_90d', 0)),
'service_amount_7d': self.safe_decimal(pair.get('service_amount_7d', 0)),
'service_amount_10d': self.safe_decimal(pair.get('service_amount_10d', 0)),
'service_amount_15d': self.safe_decimal(pair.get('service_amount_15d', 0)),
'service_amount_30d': self.safe_decimal(pair.get('service_amount_30d', 0)),
'service_amount_60d': self.safe_decimal(pair.get('service_amount_60d', 0)),
'service_amount_90d': self.safe_decimal(pair.get('service_amount_90d', 0)),
# 活跃度指标
'days_since_last': self._calc_days_since(stat_date, pair.get('last_service_date')),
'is_active_7d': self.safe_int(pair.get('service_count_7d', 0)) > 0,
'is_active_30d': self.safe_int(pair.get('service_count_30d', 0)) > 0,
}
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
# 删除已存在的数据
deleted = self.delete_existing_data(context, date_col="stat_date")
# 批量插入
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _extract_service_pairs(
self,
site_id: int,
stat_date: date
) -> List[Dict[str, Any]]:
"""
提取助教-客户服务统计(含滚动窗口)
"""
sql = """
WITH service_base AS (
SELECT
site_assistant_id AS assistant_id,
nickname AS assistant_nickname,
tenant_member_id AS member_id,
DATE(start_use_time) AS service_date,
income_seconds,
ledger_amount
FROM billiards_dwd.dwd_assistant_service_log
WHERE site_id = %s
AND tenant_member_id IS NOT NULL
AND tenant_member_id != 0
AND is_delete = 0
)
SELECT
assistant_id,
MAX(assistant_nickname) AS assistant_nickname,
member_id,
MIN(service_date) AS first_service_date,
MAX(service_date) AS last_service_date,
-- 全量累计
COUNT(*) AS total_service_count,
SUM(income_seconds) / 3600.0 AS total_service_hours,
SUM(ledger_amount) AS total_service_amount,
-- 7天窗口
COUNT(CASE WHEN service_date >= %s - INTERVAL '6 days' THEN 1 END) AS service_count_7d,
SUM(CASE WHEN service_date >= %s - INTERVAL '6 days' THEN income_seconds ELSE 0 END) / 3600.0 AS service_hours_7d,
SUM(CASE WHEN service_date >= %s - INTERVAL '6 days' THEN ledger_amount ELSE 0 END) AS service_amount_7d,
-- 10天窗口
COUNT(CASE WHEN service_date >= %s - INTERVAL '9 days' THEN 1 END) AS service_count_10d,
SUM(CASE WHEN service_date >= %s - INTERVAL '9 days' THEN income_seconds ELSE 0 END) / 3600.0 AS service_hours_10d,
SUM(CASE WHEN service_date >= %s - INTERVAL '9 days' THEN ledger_amount ELSE 0 END) AS service_amount_10d,
-- 15天窗口
COUNT(CASE WHEN service_date >= %s - INTERVAL '14 days' THEN 1 END) AS service_count_15d,
SUM(CASE WHEN service_date >= %s - INTERVAL '14 days' THEN income_seconds ELSE 0 END) / 3600.0 AS service_hours_15d,
SUM(CASE WHEN service_date >= %s - INTERVAL '14 days' THEN ledger_amount ELSE 0 END) AS service_amount_15d,
-- 30天窗口
COUNT(CASE WHEN service_date >= %s - INTERVAL '29 days' THEN 1 END) AS service_count_30d,
SUM(CASE WHEN service_date >= %s - INTERVAL '29 days' THEN income_seconds ELSE 0 END) / 3600.0 AS service_hours_30d,
SUM(CASE WHEN service_date >= %s - INTERVAL '29 days' THEN ledger_amount ELSE 0 END) AS service_amount_30d,
-- 60天窗口
COUNT(CASE WHEN service_date >= %s - INTERVAL '59 days' THEN 1 END) AS service_count_60d,
SUM(CASE WHEN service_date >= %s - INTERVAL '59 days' THEN income_seconds ELSE 0 END) / 3600.0 AS service_hours_60d,
SUM(CASE WHEN service_date >= %s - INTERVAL '59 days' THEN ledger_amount ELSE 0 END) AS service_amount_60d,
-- 90天窗口
COUNT(CASE WHEN service_date >= %s - INTERVAL '89 days' THEN 1 END) AS service_count_90d,
SUM(CASE WHEN service_date >= %s - INTERVAL '89 days' THEN income_seconds ELSE 0 END) / 3600.0 AS service_hours_90d,
SUM(CASE WHEN service_date >= %s - INTERVAL '89 days' THEN ledger_amount ELSE 0 END) AS service_amount_90d
FROM service_base
GROUP BY assistant_id, member_id
HAVING MAX(service_date) >= %s - INTERVAL '90 days'
"""
# 构建参数每个窗口需要3个日期参数
params = [site_id]
for _ in range(6): # 6个窗口每个3个参数
params.extend([stat_date, stat_date, stat_date])
params.append(stat_date) # HAVING条件
rows = self.db.query(sql, tuple(params))
return [dict(row) for row in rows] if rows else []
def _extract_member_info(self, site_id: int) -> Dict[int, Dict[str, Any]]:
"""
提取会员信息
"""
sql = """
SELECT
member_id,
nickname,
mobile
FROM billiards_dwd.dim_member
WHERE site_id = %s
"""
rows = self.db.query(sql, (site_id,))
result = {}
for row in (rows or []):
row_dict = dict(row)
result[row_dict['member_id']] = row_dict
return result
def _extract_assistant_info(self, site_id: int) -> Dict[int, Dict[str, Any]]:
"""
提取助教信息
"""
sql = """
SELECT
assistant_id,
nickname
FROM billiards_dwd.dim_assistant
WHERE site_id = %s
AND scd2_is_current = 1
"""
rows = self.db.query(sql, (site_id,))
result = {}
for row in (rows or []):
row_dict = dict(row)
result[row_dict['assistant_id']] = row_dict
return result
# ==========================================================================
# 工具方法
# ==========================================================================
def _mask_mobile(self, mobile: Optional[str]) -> Optional[str]:
"""
手机号脱敏
"""
if not mobile or len(mobile) < 7:
return mobile
return mobile[:3] + "****" + mobile[-4:]
def _calc_days_since(self, stat_date: date, last_date: Optional[date]) -> Optional[int]:
"""
计算距离最近服务的天数
"""
if not last_date:
return None
if isinstance(last_date, datetime):
last_date = last_date.date()
return (stat_date - last_date).days
# 便于外部导入
__all__ = ['AssistantCustomerTask']


@@ -0,0 +1,356 @@
# -*- coding: utf-8 -*-
"""
助教日度业绩明细任务
功能说明:
"助教+日期"为粒度,汇总每日业绩明细
数据来源:
- dwd_assistant_service_log: 助教服务流水
- dwd_assistant_trash_event: 废除记录(排除)
- dim_assistant: 助教维度SCD2获取当日等级
- cfg_skill_type: 技能→课程类型映射
目标表:
billiards_dws.dws_assistant_daily_detail
更新策略:
- 更新频率:每小时增量更新
- 幂等方式delete-before-insert按日期窗口
业务规则:
- 有效业绩需排除dwd_assistant_trash_event中的废除记录
- 助教等级使用SCD2 as-of取值获取统计日当日生效的等级
- 课程类型通过skill_id映射分为基础课和附加课
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_dws_task import BaseDwsTask, CourseType, TaskContext
class AssistantDailyTask(BaseDwsTask):
"""
助教日度业绩明细任务
汇总每个助教每天的:
- 服务次数(总/基础课/附加课)
- 计费时长(秒/小时)
- 计费金额
- 服务客户数(去重)
- 服务台桌数(去重)
- 被废除的记录统计
"""
def get_task_code(self) -> str:
return "DWS_ASSISTANT_DAILY"
def get_target_table(self) -> str:
return "dws_assistant_daily_detail"
def get_primary_keys(self) -> List[str]:
return ["site_id", "assistant_id", "stat_date"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据从DWD层读取助教服务记录
"""
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
self.logger.info(
"%s: 提取数据,日期范围 %s ~ %s",
self.get_task_code(), start_date, end_date
)
# 1. 获取助教服务记录
service_records = self._extract_service_records(site_id, start_date, end_date)
# 2. 获取废除记录
trash_records = self._extract_trash_records(site_id, start_date, end_date)
# 3. 加载配置缓存
self.load_config_cache()
return {
'service_records': service_records,
'trash_records': trash_records,
'start_date': start_date,
'end_date': end_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据:按助教+日期聚合
"""
service_records = extracted['service_records']
trash_records = extracted['trash_records']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,服务记录 %d 条,废除记录 %d",
self.get_task_code(), len(service_records), len(trash_records)
)
# 构建废除记录索引assistant_service_id -> trash_info
trash_index = self._build_trash_index(trash_records)
# 按助教+日期聚合
aggregated = self._aggregate_by_assistant_date(
service_records,
trash_index,
site_id
)
return aggregated
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据写入DWS表
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
# 删除已存在的数据(幂等)
deleted = self.delete_existing_data(context, date_col="stat_date")
# 批量插入
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _extract_service_records(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取助教服务记录
"""
sql = """
SELECT
asl.assistant_service_id,
asl.order_settle_id,
asl.site_assistant_id AS assistant_id,
asl.nickname AS assistant_nickname,
asl.assistant_level,
asl.skill_id,
asl.skill_name,
asl.tenant_member_id AS member_id,
asl.site_table_id AS table_id,
asl.income_seconds,
asl.real_use_seconds,
asl.ledger_amount,
asl.ledger_unit_price,
DATE(asl.start_use_time) AS service_date
FROM billiards_dwd.dwd_assistant_service_log asl
WHERE asl.site_id = %s
AND DATE(asl.start_use_time) >= %s
AND DATE(asl.start_use_time) <= %s
AND asl.is_delete = 0
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_trash_records(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取废除记录
有效业绩的排除规则:仅对"助教废除表"的记录进行处理(排除)
"""
sql = """
SELECT
assistant_service_id,
trash_seconds,
trash_reason,
trash_time
FROM billiards_dwd.dwd_assistant_trash_event
WHERE site_id = %s
AND DATE(trash_time) >= %s
AND DATE(trash_time) <= %s
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
# ==========================================================================
# 数据转换方法
# ==========================================================================
def _build_trash_index(
self,
trash_records: List[Dict[str, Any]]
) -> Dict[int, Dict[str, Any]]:
"""
构建废除记录索引
"""
index = {}
for record in trash_records:
service_id = record.get('assistant_service_id')
if service_id:
index[service_id] = record
return index
def _aggregate_by_assistant_date(
self,
service_records: List[Dict[str, Any]],
trash_index: Dict[int, Dict[str, Any]],
site_id: int
) -> List[Dict[str, Any]]:
"""
按助教+日期聚合服务记录
"""
# 聚合字典:(assistant_id, service_date) -> aggregated_data
agg_dict: Dict[Tuple[int, date], Dict[str, Any]] = {}
for record in service_records:
assistant_id = record.get('assistant_id')
service_date = record.get('service_date')
if not assistant_id or not service_date:
continue
key = (assistant_id, service_date)
# 初始化聚合数据
if key not in agg_dict:
# 获取助教当日等级SCD2 as-of
level_info = self.get_assistant_level_asof(assistant_id, service_date)
agg_dict[key] = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'assistant_id': assistant_id,
'assistant_nickname': record.get('assistant_nickname'),
'stat_date': service_date,
'assistant_level_code': level_info.get('level_code') if level_info else record.get('assistant_level'),
'assistant_level_name': level_info.get('level_name') if level_info else None,
'total_service_count': 0,
'base_service_count': 0,
'bonus_service_count': 0,
'room_service_count': 0,
'total_seconds': 0,
'base_seconds': 0,
'bonus_seconds': 0,
'room_seconds': 0,
'total_hours': Decimal('0'),
'base_hours': Decimal('0'),
'bonus_hours': Decimal('0'),
'room_hours': Decimal('0'),
'total_ledger_amount': Decimal('0'),
'base_ledger_amount': Decimal('0'),
'bonus_ledger_amount': Decimal('0'),
'room_ledger_amount': Decimal('0'),
'unique_customers': set(),
'unique_tables': set(),
'trashed_seconds': 0,
'trashed_count': 0,
}
agg = agg_dict[key]
# 获取服务信息
service_id = record.get('assistant_service_id')
income_seconds = self.safe_int(record.get('income_seconds', 0))
ledger_amount = self.safe_decimal(record.get('ledger_amount', 0))
skill_id = record.get('skill_id')
member_id = record.get('member_id')
table_id = record.get('table_id')
# 判断课程类型
course_type = self.get_course_type(skill_id) if skill_id else CourseType.BASE
is_base = course_type == CourseType.BASE
is_bonus = course_type == CourseType.BONUS
is_room = course_type == CourseType.ROOM
# 检查是否被废除
is_trashed = service_id in trash_index
if is_trashed:
# 废除记录单独统计
trash_info = trash_index[service_id]
trash_seconds = self.safe_int(trash_info.get('trash_seconds', income_seconds))
agg['trashed_seconds'] += trash_seconds
agg['trashed_count'] += 1
else:
# 正常记录累加
agg['total_service_count'] += 1
agg['total_seconds'] += income_seconds
agg['total_ledger_amount'] += ledger_amount
if is_base:
agg['base_service_count'] += 1
agg['base_seconds'] += income_seconds
agg['base_ledger_amount'] += ledger_amount
elif is_bonus:
agg['bonus_service_count'] += 1
agg['bonus_seconds'] += income_seconds
agg['bonus_ledger_amount'] += ledger_amount
elif is_room:
agg['room_service_count'] += 1
agg['room_seconds'] += income_seconds
agg['room_ledger_amount'] += ledger_amount
# 客户和台桌去重统计(不论是否废除)
if member_id and not self.is_guest(member_id):
agg['unique_customers'].add(member_id)
if table_id:
agg['unique_tables'].add(table_id)
# 转换为列表并计算派生字段
result = []
for key, agg in agg_dict.items():
# 计算小时数
agg['total_hours'] = self.seconds_to_hours(agg['total_seconds'])
agg['base_hours'] = self.seconds_to_hours(agg['base_seconds'])
agg['bonus_hours'] = self.seconds_to_hours(agg['bonus_seconds'])
agg['room_hours'] = self.seconds_to_hours(agg['room_seconds'])
# 转换set为count
agg['unique_customers'] = len(agg['unique_customers'])
agg['unique_tables'] = len(agg['unique_tables'])
result.append(agg)
return result
# 便于外部导入
__all__ = ['AssistantDailyTask']


@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
"""
助教收支分析任务
功能说明:
"日期+助教"为粒度,分析助教产出的收入和成本
数据来源:
- dwd_assistant_service_log: 助教服务流水(收入)
- dws_assistant_salary_calc: 工资计算(成本)
目标表:
billiards_dws.dws_assistant_finance_analysis
更新策略:
- 更新频率:每日更新
- 幂等方式delete-before-insert按日期
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_dws_task import BaseDwsTask, CourseType, TaskContext
class AssistantFinanceTask(BaseDwsTask):
"""
助教收支分析任务
"""
def get_task_code(self) -> str:
return "DWS_ASSISTANT_FINANCE"
def get_target_table(self) -> str:
return "dws_assistant_finance_analysis"
def get_primary_keys(self) -> List[str]:
return ["site_id", "stat_date", "assistant_id"]
def extract(self, context: TaskContext) -> Dict[str, Any]:
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
# 获取助教日度收入
daily_revenue = self._extract_daily_revenue(site_id, start_date, end_date)
# 获取月度工资(用于计算日均成本)
monthly_salary = self._extract_monthly_salary(site_id, start_date, end_date)
# 加载配置
self.load_config_cache()
return {
'daily_revenue': daily_revenue,
'monthly_salary': monthly_salary,
'start_date': start_date,
'end_date': end_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
daily_revenue = extracted['daily_revenue']
monthly_salary = extracted['monthly_salary']
site_id = extracted['site_id']
# 构建月度工资索引
salary_index = {}
for sal in monthly_salary:
asst_id = sal.get('assistant_id')
month = sal.get('salary_month')
if asst_id and month:
salary_index[(asst_id, month)] = sal
results = []
for rev in daily_revenue:
assistant_id = rev.get('assistant_id')
stat_date = rev.get('stat_date')
# 获取对应月份的工资
month_start = stat_date.replace(day=1) if isinstance(stat_date, date) else None
salary = salary_index.get((assistant_id, month_start), {})
# 计算日均成本
gross_salary = self.safe_decimal(salary.get('gross_salary', 0))
work_days = self.safe_int(salary.get('work_days', 1)) or 1
cost_daily = gross_salary / Decimal(str(work_days))
revenue_total = self.safe_decimal(rev.get('revenue_total', 0))
gross_profit = revenue_total - cost_daily
gross_margin = gross_profit / revenue_total if revenue_total > 0 else Decimal('0')
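# 数值示例gross_salary=6000、work_days=20 → cost_daily=300
# 若当日 revenue_total=800则 gross_profit=500gross_margin=0.625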
record = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'stat_date': stat_date,
'assistant_id': assistant_id,
'assistant_nickname': rev.get('assistant_nickname'),
'revenue_total': revenue_total,
'revenue_base': self.safe_decimal(rev.get('revenue_base', 0)),
'revenue_bonus': self.safe_decimal(rev.get('revenue_bonus', 0)),
'revenue_room': self.safe_decimal(rev.get('revenue_room', 0)),
'cost_daily': cost_daily,
'gross_profit': gross_profit,
'gross_margin': gross_margin,
'service_count': self.safe_int(rev.get('service_count', 0)),
'service_hours': self.safe_decimal(rev.get('service_hours', 0)),
'room_service_count': self.safe_int(rev.get('room_service_count', 0)),
'room_service_hours': self.safe_decimal(rev.get('room_service_hours', 0)),
'unique_customers': self.safe_int(rev.get('unique_customers', 0)),
}
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
if not transformed:
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
deleted = self.delete_existing_data(context, date_col="stat_date")
inserted = self.bulk_insert(transformed)
return {
"counts": {"fetched": len(transformed), "inserted": inserted, "updated": 0, "skipped": 0, "errors": 0},
"extra": {"deleted": deleted}
}
def _extract_daily_revenue(self, site_id: int, start_date: date, end_date: date) -> List[Dict[str, Any]]:
sql = """
SELECT
DATE(s.start_use_time) AS stat_date,
s.site_assistant_id AS assistant_id,
MAX(s.nickname) AS assistant_nickname,
COUNT(*) AS service_count,
SUM(s.income_seconds) / 3600.0 AS service_hours,
SUM(s.ledger_amount) AS revenue_total,
SUM(CASE WHEN COALESCE(st.course_type_code, 'BASE') = 'BASE' THEN s.ledger_amount ELSE 0 END) AS revenue_base,
SUM(CASE WHEN COALESCE(st.course_type_code, 'BASE') = 'BONUS' THEN s.ledger_amount ELSE 0 END) AS revenue_bonus,
SUM(CASE WHEN COALESCE(st.course_type_code, 'BASE') = 'ROOM' THEN s.ledger_amount ELSE 0 END) AS revenue_room,
COUNT(CASE WHEN COALESCE(st.course_type_code, 'BASE') = 'ROOM' THEN 1 END) AS room_service_count,
SUM(CASE WHEN COALESCE(st.course_type_code, 'BASE') = 'ROOM' THEN s.income_seconds ELSE 0 END) / 3600.0 AS room_service_hours,
COUNT(DISTINCT CASE WHEN s.tenant_member_id > 0 THEN s.tenant_member_id END) AS unique_customers
FROM billiards_dwd.dwd_assistant_service_log s
LEFT JOIN billiards_dws.cfg_skill_type st
ON st.skill_id = s.skill_id AND st.is_active = TRUE
WHERE s.site_id = %s
AND DATE(s.start_use_time) >= %s
AND DATE(s.start_use_time) <= %s
AND s.is_delete = 0
GROUP BY DATE(s.start_use_time), s.site_assistant_id
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_monthly_salary(self, site_id: int, start_date: date, end_date: date) -> List[Dict[str, Any]]:
# 获取涉及的月份
month_start = start_date.replace(day=1)
month_end = end_date.replace(day=1)
sql = """
SELECT
assistant_id,
salary_month,
gross_salary,
effective_hours
FROM billiards_dws.dws_assistant_salary_calc
WHERE site_id = %s
AND salary_month >= %s
AND salary_month <= %s
"""
rows = self.db.query(sql, (site_id, month_start, month_end))
# 获取每月工作天数
work_days_sql = """
SELECT
assistant_id,
DATE_TRUNC('month', stat_date)::DATE AS month,
COUNT(DISTINCT stat_date) AS work_days
FROM billiards_dws.dws_assistant_daily_detail
WHERE site_id = %s
AND stat_date >= %s
AND stat_date <= %s
GROUP BY assistant_id, DATE_TRUNC('month', stat_date)
"""
work_days_rows = self.db.query(work_days_sql, (site_id, start_date, end_date))
work_days_index = {(r['assistant_id'], r['month']): r['work_days'] for r in (work_days_rows or [])}
results = []
for row in (rows or []):
row_dict = dict(row)
asst_id = row_dict.get('assistant_id')
month = row_dict.get('salary_month')
row_dict['work_days'] = work_days_index.get((asst_id, month), 20)
results.append(row_dict)
return results
__all__ = ['AssistantFinanceTask']


@@ -0,0 +1,600 @@
# -*- coding: utf-8 -*-
"""
助教月度业绩汇总任务
功能说明:
"助教+月份"为粒度,汇总月度业绩及档位计算
数据来源:
- dws_assistant_daily_detail: 日度明细(聚合)
- dim_assistant: 助教维度(入职日期、等级)
- cfg_performance_tier: 绩效档位配置
目标表:
billiards_dws.dws_assistant_monthly_summary
更新策略:
- 更新频率:每日更新当月数据
- 幂等方式delete-before-insert按月份
业务规则:
- 新入职判断入职日期在月1日0点之后则为新入职
- 有效业绩total_hours - trashed_hours
- 档位匹配根据有效业绩小时数匹配cfg_performance_tier
- 排名计算按有效业绩小时数降序考虑并列如2个第一则都是1下一个是3
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class AssistantMonthlyTask(BaseDwsTask):
"""
助教月度业绩汇总任务
汇总每个助教每月的:
- 工作天数、服务次数、时长
- 有效业绩(扣除废除记录后)
- 档位匹配
- 月度排名用于Top3奖金
"""
def get_task_code(self) -> str:
return "DWS_ASSISTANT_MONTHLY"
def get_target_table(self) -> str:
return "dws_assistant_monthly_summary"
def get_primary_keys(self) -> List[str]:
return ["site_id", "assistant_id", "stat_month"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据:从日度明细表聚合
"""
# 确定月份范围
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
# 获取涉及的月份列表
months = self._get_months_in_range(start_date, end_date)
months = self._filter_months_for_schedule(months, end_date)
self.logger.info(
"%s: 提取数据,月份范围 %s",
self.get_task_code(), [str(m) for m in months]
)
if not months:
self.logger.info("%s: 无需处理月份,跳过", self.get_task_code())
return {
'daily_aggregates': [],
'monthly_uniques': [],
'assistant_info': {},
'months': [],
'site_id': site_id
}
# 1. 获取日度明细聚合数据
daily_aggregates = self._extract_daily_aggregates(site_id, months)
# 1.1 获取月度去重客户/台桌统计从DWD直接去重
monthly_uniques = self._extract_monthly_uniques(site_id, months)
# 2. 获取助教基本信息
assistant_info = self._extract_assistant_info(site_id)
# 3. 加载配置缓存
self.load_config_cache()
return {
'daily_aggregates': daily_aggregates,
'monthly_uniques': monthly_uniques,
'assistant_info': assistant_info,
'months': months,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据:计算月度汇总、档位匹配、排名
"""
daily_aggregates = extracted['daily_aggregates']
monthly_uniques = extracted['monthly_uniques']
assistant_info = extracted['assistant_info']
months = extracted['months']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,%d 个月份,%d 条聚合记录",
self.get_task_code(), len(months), len(daily_aggregates)
)
# 月度去重索引
monthly_unique_index = {
(row.get('assistant_id'), row.get('stat_month')): row
for row in (monthly_uniques or [])
if row.get('assistant_id') and row.get('stat_month')
}
# 按月份处理
all_results = []
for month in months:
month_results = self._process_month(
daily_aggregates,
assistant_info,
monthly_unique_index,
month,
site_id
)
all_results.extend(month_results)
return all_results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据写入DWS表
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
# 删除已存在的数据(按月份)
deleted = self._delete_by_months(context, transformed)
# 批量插入
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _get_months_in_range(self, start_date: date, end_date: date) -> List[date]:
"""
获取日期范围内的所有月份(月第一天)
"""
months = []
current = start_date.replace(day=1)
end_month = end_date.replace(day=1)
while current <= end_month:
months.append(current)
# 下个月
if current.month == 12:
current = current.replace(year=current.year + 1, month=1)
else:
current = current.replace(month=current.month + 1)
return months
def _filter_months_for_schedule(self, months: List[date], end_date: date) -> List[date]:
"""
按调度口径过滤历史月份(默认仅当月,月初可包含上月)
"""
if not months:
return []
history_months = self.safe_int(self.config.get("dws.monthly.history_months", 0))
if history_months > 0:
current_month = self.get_month_first_day(end_date)
allowed = {current_month}
for offset in range(1, history_months + 1):
allowed.add(self.get_month_first_day(self._shift_months(current_month, -offset)))
filtered = [m for m in months if m in allowed]
skipped = [m for m in months if m not in allowed]
if skipped:
self.logger.info(
"%s: 跳过历史月份 %s",
self.get_task_code(),
[str(m) for m in skipped]
)
return filtered
allow_history = bool(self.config.get("dws.monthly.allow_history", False))
if allow_history:
return months
current_month = self.get_month_first_day(end_date)
allowed = {current_month}
grace_days = self.safe_int(self.config.get("dws.monthly.prev_month_grace_days", 5))
if grace_days > 0 and end_date.day <= grace_days:
prev_month = self.get_month_first_day(self._shift_months(current_month, -1))
allowed.add(prev_month)
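# 例如grace_days=5 且 end_date=2026-03-03 → 同时允许处理 2026-02 与 2026-03 两个月份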
filtered = [m for m in months if m in allowed]
skipped = [m for m in months if m not in allowed]
if skipped:
self.logger.info(
"%s: 跳过历史月份 %s",
self.get_task_code(),
[str(m) for m in skipped]
)
return filtered
def _extract_daily_aggregates(
self,
site_id: int,
months: List[date]
) -> List[Dict[str, Any]]:
"""
从日度明细表提取并按月聚合
"""
if not months:
return []
# 构建月份条件
month_conditions = []
for month in months:
next_month = (month.replace(day=28) + timedelta(days=4)).replace(day=1)
month_conditions.append(f"(stat_date >= '{month}' AND stat_date < '{next_month}')")
month_where = " OR ".join(month_conditions)
sql = f"""
SELECT
assistant_id,
assistant_nickname,
assistant_level_code,
assistant_level_name,
DATE_TRUNC('month', stat_date)::DATE AS stat_month,
COUNT(DISTINCT stat_date) AS work_days,
SUM(total_service_count) AS total_service_count,
SUM(base_service_count) AS base_service_count,
SUM(bonus_service_count) AS bonus_service_count,
SUM(room_service_count) AS room_service_count,
SUM(total_hours) AS total_hours,
SUM(base_hours) AS base_hours,
SUM(bonus_hours) AS bonus_hours,
SUM(room_hours) AS room_hours,
SUM(total_ledger_amount) AS total_ledger_amount,
SUM(base_ledger_amount) AS base_ledger_amount,
SUM(bonus_ledger_amount) AS bonus_ledger_amount,
SUM(room_ledger_amount) AS room_ledger_amount,
SUM(unique_customers) AS total_unique_customers,
SUM(unique_tables) AS total_unique_tables,
SUM(trashed_seconds) AS trashed_seconds,
SUM(trashed_count) AS trashed_count
FROM billiards_dws.dws_assistant_daily_detail
WHERE site_id = %s AND ({month_where})
GROUP BY assistant_id, assistant_nickname, assistant_level_code, assistant_level_name,
DATE_TRUNC('month', stat_date)
"""
rows = self.db.query(sql, (site_id,))
return [dict(row) for row in rows] if rows else []
def _extract_monthly_uniques(
self,
site_id: int,
months: List[date]
) -> List[Dict[str, Any]]:
"""
从DWD按月直接去重客户与台桌
"""
if not months:
return []
start_month = min(months)
end_month = max(months)
next_month = (end_month.replace(day=28) + timedelta(days=4)).replace(day=1)
sql = """
SELECT
site_assistant_id AS assistant_id,
DATE_TRUNC('month', start_use_time)::DATE AS stat_month,
COUNT(DISTINCT CASE WHEN tenant_member_id > 0 THEN tenant_member_id END) AS unique_customers,
COUNT(DISTINCT site_table_id) AS unique_tables
FROM billiards_dwd.dwd_assistant_service_log
WHERE site_id = %s
AND start_use_time >= %s
AND start_use_time < %s
AND is_delete = 0
GROUP BY site_assistant_id, DATE_TRUNC('month', start_use_time)
"""
rows = self.db.query(sql, (site_id, start_month, next_month))
return [dict(row) for row in rows] if rows else []
def _extract_assistant_info(self, site_id: int) -> Dict[int, Dict[str, Any]]:
"""
提取助教基本信息
"""
sql = """
SELECT
assistant_id,
nickname,
level AS assistant_level,
entry_time AS hire_date
FROM billiards_dwd.dim_assistant
WHERE site_id = %s
AND scd2_is_current = 1 -- 当前有效记录
"""
rows = self.db.query(sql, (site_id,))
result = {}
for row in (rows or []):
row_dict = dict(row)
result[row_dict['assistant_id']] = row_dict
return result
# ==========================================================================
# 数据转换方法
# ==========================================================================
def _process_month(
self,
daily_aggregates: List[Dict[str, Any]],
assistant_info: Dict[int, Dict[str, Any]],
monthly_unique_index: Dict[Tuple[int, date], Dict[str, Any]],
month: date,
site_id: int
) -> List[Dict[str, Any]]:
"""
处理单个月份的数据
"""
# 筛选该月份的数据
month_data = [
agg for agg in daily_aggregates
if agg.get('stat_month') == month
]
if not month_data:
return []
# 构建月度汇总记录
month_records = []
for agg in month_data:
assistant_id = agg.get('assistant_id')
asst_info = assistant_info.get(assistant_id, {})
# 计算有效业绩
total_hours = self.safe_decimal(agg.get('total_hours', 0))
trashed_hours = self.seconds_to_hours(self.safe_int(agg.get('trashed_seconds', 0)))
effective_hours = total_hours - trashed_hours
# 判断是否新入职
hire_date = asst_info.get('hire_date')
is_new_hire = False
if hire_date:
if isinstance(hire_date, datetime):
hire_date = hire_date.date()
is_new_hire = self.is_new_hire_in_month(hire_date, month)
# 匹配档位
tier_hours = effective_hours
max_tier_level = None
if is_new_hire:
tier_hours = self._calc_new_hire_tier_hours(effective_hours, self.safe_int(agg.get('work_days', 0)))
if self._should_apply_new_hire_tier_cap(month, hire_date):
max_tier_level = self._get_new_hire_max_tier_level()
tier = self.get_performance_tier(
tier_hours,
is_new_hire,
effective_date=month,
max_tier_level=max_tier_level
)
# 获取月末的等级信息(用于记录)
month_end = self._get_month_end(month)
level_info = self.get_assistant_level_asof(assistant_id, month_end)
# 月度去重客户/台桌从DWD直接去重
unique_info = monthly_unique_index.get((assistant_id, month), {})
unique_customers = self.safe_int(
unique_info.get('unique_customers', agg.get('total_unique_customers', 0))
)
unique_tables = self.safe_int(
unique_info.get('unique_tables', agg.get('total_unique_tables', 0))
)
record = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'assistant_id': assistant_id,
'assistant_nickname': agg.get('assistant_nickname'),
'stat_month': month,
'assistant_level_code': level_info.get('level_code') if level_info else agg.get('assistant_level_code'),
'assistant_level_name': level_info.get('level_name') if level_info else agg.get('assistant_level_name'),
'hire_date': hire_date,
'is_new_hire': is_new_hire,
'work_days': self.safe_int(agg.get('work_days', 0)),
'total_service_count': self.safe_int(agg.get('total_service_count', 0)),
'base_service_count': self.safe_int(agg.get('base_service_count', 0)),
'bonus_service_count': self.safe_int(agg.get('bonus_service_count', 0)),
'room_service_count': self.safe_int(agg.get('room_service_count', 0)),
'total_hours': total_hours,
'base_hours': self.safe_decimal(agg.get('base_hours', 0)),
'bonus_hours': self.safe_decimal(agg.get('bonus_hours', 0)),
'room_hours': self.safe_decimal(agg.get('room_hours', 0)),
'effective_hours': effective_hours,
'trashed_hours': trashed_hours,
'total_ledger_amount': self.safe_decimal(agg.get('total_ledger_amount', 0)),
'base_ledger_amount': self.safe_decimal(agg.get('base_ledger_amount', 0)),
'bonus_ledger_amount': self.safe_decimal(agg.get('bonus_ledger_amount', 0)),
'room_ledger_amount': self.safe_decimal(agg.get('room_ledger_amount', 0)),
'unique_customers': unique_customers,
'unique_tables': unique_tables,
'avg_service_seconds': self._calc_avg_service_seconds(agg),
'tier_id': tier.get('tier_id') if tier else None,
'tier_code': tier.get('tier_code') if tier else None,
'tier_name': tier.get('tier_name') if tier else None,
'rank_by_hours': None, # 后面计算
'rank_with_ties': None, # 后面计算
}
month_records.append(record)
# 计算排名
self._calculate_ranks(month_records)
return month_records
def _get_month_end(self, month: date) -> date:
"""
获取月末日期
"""
if month.month == 12:
next_month = month.replace(year=month.year + 1, month=1, day=1)
else:
next_month = month.replace(month=month.month + 1, day=1)
return next_month - timedelta(days=1)
def _calc_avg_service_seconds(self, agg: Dict[str, Any]) -> Decimal:
"""
计算平均单次服务时长
"""
total_count = self.safe_int(agg.get('total_service_count', 0))
if total_count == 0:
return Decimal('0')
total_hours = self.safe_decimal(agg.get('total_hours', 0))
total_seconds = total_hours * Decimal('3600')
return total_seconds / Decimal(str(total_count))
def _calc_new_hire_tier_hours(self, effective_hours: Decimal, work_days: int) -> Decimal:
"""
新入职定档:日均 * 30仅用于定档不影响奖金与排名
"""
if work_days <= 0:
return Decimal('0')
return (effective_hours / Decimal(str(work_days))) * Decimal('30')
def _should_apply_new_hire_tier_cap(self, stat_month: date, hire_date: Optional[date]) -> bool:
"""
新入职封顶规则是否生效:
- 仅在规则生效月及之后(默认 2026-03-01 起)
- 仅当入职日期晚于封顶日(默认当月 25 日)
"""
if not hire_date:
return False
effective_from = self._get_new_hire_cap_effective_from()
cap_day = self._get_new_hire_cap_day()
return stat_month >= effective_from and hire_date.day > cap_day
def _get_new_hire_cap_effective_from(self) -> date:
"""
获取新入职封顶规则生效月份(默认 2026-03-01
"""
raw_value = self.config.get("dws.monthly.new_hire_cap_effective_from", "2026-03-01")
if isinstance(raw_value, datetime):
return raw_value.date()
if isinstance(raw_value, date):
return raw_value
if isinstance(raw_value, str):
try:
return datetime.strptime(raw_value.strip(), "%Y-%m-%d").date()
except ValueError:
pass
return date(2026, 3, 1)
def _get_new_hire_cap_day(self) -> int:
"""
获取新入职封顶日(默认 25
"""
value = self.safe_int(self.config.get("dws.monthly.new_hire_cap_day", 25))
return min(max(value, 1), 31)
def _get_new_hire_max_tier_level(self) -> int:
"""
获取新入职封顶档位等级(默认 2 档)
"""
value = self.safe_int(self.config.get("dws.monthly.new_hire_max_tier_level", 2))
return max(value, 0)
def _calculate_ranks(self, records: List[Dict[str, Any]]) -> None:
"""
计算排名(考虑并列)
Top3排名口径按有效业绩总小时数排名
如遇并列则都算比如2个第一则记为2个第一一个第三
"""
if not records:
return
# 按有效业绩降序排序
sorted_records = sorted(
records,
key=lambda x: x.get('effective_hours', Decimal('0')),
reverse=True
)
# 计算考虑并列的排名
values = [
(r.get('assistant_id'), r.get('effective_hours', Decimal('0')))
for r in sorted_records
]
ranked = self.calculate_rank_with_ties(values)
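# 例如values=[(A, 100), (B, 100), (C, 80)] → rank 依次为 1、1、3并列第一占两席dense_rank 依次为 1、1、2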
# 创建排名映射
rank_map = {
assistant_id: (rank, dense_rank)
for assistant_id, rank, dense_rank in ranked
}
# 更新记录
for record in records:
assistant_id = record.get('assistant_id')
if assistant_id in rank_map:
rank, _ = rank_map[assistant_id]
record['rank_by_hours'] = rank
record['rank_with_ties'] = rank # 使用考虑并列的排名
def _delete_by_months(
self,
context: TaskContext,
records: List[Dict[str, Any]]
) -> int:
"""
按月份删除已存在的数据
"""
# 获取涉及的月份
months = set(r.get('stat_month') for r in records if r.get('stat_month'))
if not months:
return 0
target_table = self.get_target_table()
full_table = f"{self.DWS_SCHEMA}.{target_table}"
total_deleted = 0
with self.db.conn.cursor() as cur:
for month in months:
sql = f"""
DELETE FROM {full_table}
WHERE site_id = %s AND stat_month = %s
"""
cur.execute(sql, (context.store_id, month))
total_deleted += cur.rowcount
return total_deleted
# 便于外部导入
__all__ = ['AssistantMonthlyTask']


@@ -0,0 +1,437 @@
# -*- coding: utf-8 -*-
"""
助教工资计算任务
功能说明:
"助教+月份"为粒度,计算月度工资明细
数据来源:
- dws_assistant_monthly_summary: 月度业绩汇总
- dws_assistant_recharge_commission: 充值提成Excel导入
- cfg_performance_tier: 绩效档位配置
- cfg_assistant_level_price: 等级定价配置
- cfg_bonus_rules: 奖金规则配置
目标表:
billiards_dws.dws_assistant_salary_calc
更新策略:
- 更新频率:月初计算上月工资
- 幂等方式delete-before-insert按月份
业务规则来自DWS数据库处理需求.md
- 基础课收入 = 基础课小时数 × (客户支付价格 - 专业课抽成)
中级助教基础课170小时3档 = 170 × (108 - 13) = 16,150元
- 附加课收入 = 附加课小时数 × 附加课价格 × (1 - 打赏课抽成比例)
附加课15小时3档 = 15 × 190 × (1 - 0.35) = 1,852.5元
- 包厢课收入 = 包厢课小时数 × (包厢课客户支付价格 - 专业课抽成)
- 冲刺奖金:按规则表配置(历史口径,不累计取最高档)
- Top3奖金1st:1000, 2nd:600, 3rd:400并列都算
- 充值提成来自dws_assistant_recharge_commission
- SCD2口径等级定价使用月份对应的历史值
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class AssistantSalaryTask(BaseDwsTask):
"""
助教工资计算任务
计算每个助教每月的工资明细:
- 课时收入(基础课+附加课)
- 扣款(档位扣款+其他)
- 奖金(档位奖金+冲刺+Top3+充值提成+其他)
- 应发工资
"""
def get_task_code(self) -> str:
return "DWS_ASSISTANT_SALARY"
def get_target_table(self) -> str:
return "dws_assistant_salary_calc"
def get_primary_keys(self) -> List[str]:
return ["site_id", "assistant_id", "salary_month"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据
"""
# 确定工资月份(通常是上月)
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
if self._should_skip_run(end_date):
self.logger.info("%s: 非工资结算期,跳过", self.get_task_code())
return {
'monthly_summary': [],
'recharge_commission': [],
'salary_month': None,
'site_id': context.store_id,
}
salary_month = self._get_salary_month(end_date)
site_id = context.store_id
self.logger.info(
"%s: 提取数据,工资月份 %s",
self.get_task_code(), salary_month
)
# 1. 获取月度业绩汇总
monthly_summary = self._extract_monthly_summary(site_id, salary_month)
# 2. 获取充值提成
recharge_commission = self._extract_recharge_commission(site_id, salary_month)
# 3. 加载配置缓存
self.load_config_cache()
return {
'monthly_summary': monthly_summary,
'recharge_commission': recharge_commission,
'salary_month': salary_month,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据:计算工资
"""
if not extracted.get('salary_month'):
return []
monthly_summary = extracted['monthly_summary']
recharge_commission = extracted['recharge_commission']
salary_month = extracted['salary_month']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,%d 条月度汇总记录",
self.get_task_code(), len(monthly_summary)
)
# 构建充值提成索引
commission_index = {}
for comm in recharge_commission:
asst_id = comm.get('assistant_id')
if asst_id:
commission_index[asst_id] = commission_index.get(asst_id, Decimal('0')) + \
self.safe_decimal(comm.get('commission_amount', 0))
# 计算工资
results = []
for summary in monthly_summary:
record = self._calculate_salary(summary, commission_index, salary_month, site_id)
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
# 删除已存在的数据
deleted = self._delete_by_month(context, transformed)
# 批量插入
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _get_salary_month(self, end_date: date) -> date:
"""
获取工资月份(默认为上月)
"""
# 如果是月初,计算上月工资
if end_date.day <= 5:
if end_date.month == 1:
return date(end_date.year - 1, 12, 1)
else:
return date(end_date.year, end_date.month - 1, 1)
else:
# 否则计算当月(可能是调整)
return end_date.replace(day=1)
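# 示例(仅为说明):窗口结束日为 2026-03-03(在月初 5 日内)时返回 2026-02-01;
# 结束日为 2026-03-15 时返回 2026-03-01(视为当月数据的调整重算)。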
def _should_skip_run(self, end_date: date) -> bool:
"""
工资计算仅在月初运行(默认前 N 天)
"""
allow_out_of_cycle = bool(self.config.get("dws.salary.allow_out_of_cycle", False))
if allow_out_of_cycle:
return False
run_days = self.safe_int(self.config.get("dws.salary.run_days", 5))
if run_days <= 0:
return False
return end_date.day > run_days
def _extract_monthly_summary(
self,
site_id: int,
salary_month: date
) -> List[Dict[str, Any]]:
"""
提取月度业绩汇总
"""
sql = """
SELECT
assistant_id,
assistant_nickname,
stat_month,
assistant_level_code,
assistant_level_name,
hire_date,
is_new_hire,
effective_hours,
base_hours,
bonus_hours,
room_hours,
tier_id,
tier_code,
tier_name,
rank_with_ties
FROM billiards_dws.dws_assistant_monthly_summary
WHERE site_id = %s AND stat_month = %s
"""
rows = self.db.query(sql, (site_id, salary_month))
return [dict(row) for row in rows] if rows else []
def _extract_recharge_commission(
self,
site_id: int,
salary_month: date
) -> List[Dict[str, Any]]:
"""
提取充值提成
"""
sql = """
SELECT
assistant_id,
commission_amount
FROM billiards_dws.dws_assistant_recharge_commission
WHERE site_id = %s AND commission_month = %s
"""
rows = self.db.query(sql, (site_id, salary_month))
return [dict(row) for row in rows] if rows else []
# ==========================================================================
# 工资计算方法
# ==========================================================================
def _calculate_salary(
self,
summary: Dict[str, Any],
commission_index: Dict[int, Decimal],
salary_month: date,
site_id: int
) -> Dict[str, Any]:
"""
计算单个助教的月度工资
"""
assistant_id = summary.get('assistant_id')
level_code = summary.get('assistant_level_code')
effective_hours = self.safe_decimal(summary.get('effective_hours', 0))
base_hours = self.safe_decimal(summary.get('base_hours', 0))
bonus_hours = self.safe_decimal(summary.get('bonus_hours', 0))
room_hours = self.safe_decimal(summary.get('room_hours', 0))
is_new_hire = summary.get('is_new_hire', False)
rank = summary.get('rank_with_ties')
# 获取等级定价(SCD2口径,按月份取值)
# base_course_price: 客户支付价格(初级98/中级108/高级118/星级138)
# bonus_course_price: 附加课客户支付价格(固定190元)
# room_course_price: 包厢课客户支付价格(固定138元)
level_price = self.get_level_price(level_code, salary_month)
base_course_price = self.safe_decimal(
level_price.get('base_course_price', 98) if level_price else 98
)
bonus_course_price = self.safe_decimal(
level_price.get('bonus_course_price', 190) if level_price else 190
)
room_course_price = self.safe_decimal(
self.config.get("dws.salary.room_course_price", 138)
)
# 获取档位配置
# base_deduction: 专业课抽成(元/小时),球房从每小时扣除
# bonus_deduction_ratio: 打赏课抽成比例,球房从附加课收入扣除的比例
tier = self.get_performance_tier_by_id(summary.get('tier_id'), salary_month)
if not tier:
tier = self.get_performance_tier(
effective_hours,
is_new_hire,
effective_date=salary_month
)
base_deduction = self.safe_decimal(tier.get('base_deduction', 18)) if tier else Decimal('18')
bonus_deduction_ratio = self.safe_decimal(tier.get('bonus_deduction_ratio', 0.40)) if tier else Decimal('0.40')
vacation_days = tier.get('vacation_days', 0) if tier else 0
vacation_unlimited = tier.get('vacation_unlimited', False) if tier else False
# ============================================================
# 工资计算公式(来自 DWS数据库处理需求.md)
# ============================================================
# 基础课收入 = 基础课小时数 × (客户支付价格 - 专业课抽成)
# 例:中级助教170小时(3档)= 170 × (108 - 13) = 16,150元
base_income = base_hours * (base_course_price - base_deduction)
# 附加课收入 = 附加课小时数 × 附加课价格 × (1 - 打赏课抽成比例)
# 例:15小时(3档)= 15 × 190 × (1 - 0.35) = 1,852.5元
bonus_income = bonus_hours * bonus_course_price * (Decimal('1') - bonus_deduction_ratio)
# 包厢课收入(按包厢课统一价格口径)
room_income = room_hours * (room_course_price - base_deduction)
# 课时收入合计
total_course_income = base_income + bonus_income + room_income
# 计算冲刺奖金(按规则表配置,不累计取最高)
sprint_bonus = self.calculate_sprint_bonus(effective_hours, salary_month)
# 计算Top3排名奖金(1st:1000, 2nd:600, 3rd:400,并列都算)
top_rank_bonus = Decimal('0')
if rank and rank <= 3:
top_rank_bonus = self.calculate_top_rank_bonus(rank, salary_month)
# 获取充值提成
recharge_commission = commission_index.get(assistant_id, Decimal('0'))
# 汇总奖金
other_bonus = Decimal('0') # 预留其他奖金
total_bonus = sprint_bonus + top_rank_bonus + recharge_commission + other_bonus
# 计算应发工资 = 课时收入 + 奖金
gross_salary = total_course_income + total_bonus
# 构建记录
return {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'assistant_id': assistant_id,
'assistant_nickname': summary.get('assistant_nickname'),
'salary_month': salary_month,
'assistant_level_code': level_code,
'assistant_level_name': summary.get('assistant_level_name'),
'hire_date': summary.get('hire_date'),
'is_new_hire': is_new_hire,
'effective_hours': effective_hours,
'base_hours': base_hours,
'bonus_hours': bonus_hours,
'room_hours': room_hours,
'tier_id': summary.get('tier_id'),
'tier_code': tier.get('tier_code') if tier else None,
'tier_name': tier.get('tier_name') if tier else None,
'rank_with_ties': rank,
# 定价信息
'base_course_price': base_course_price,
'bonus_course_price': bonus_course_price,
'base_deduction': base_deduction,
'bonus_deduction_ratio': bonus_deduction_ratio,
# 收入明细
'base_income': base_income,
'bonus_income': bonus_income,
'room_income': room_income,
'total_course_income': total_course_income,
# 奖金明细
'sprint_bonus': sprint_bonus,
'top_rank_bonus': top_rank_bonus,
'recharge_commission': recharge_commission,
'other_bonus': other_bonus,
'total_bonus': total_bonus,
# 应发工资
'gross_salary': gross_salary,
# 假期
'vacation_days': vacation_days,
'vacation_unlimited': vacation_unlimited,
'calc_notes': self._build_calc_notes(summary, tier, sprint_bonus, top_rank_bonus),
}
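# 完整示例(数字取自模块 docstring 的示例口径,仅为说明):
# 中级助教、3档:base_course_price=108, base_deduction=13, bonus_deduction_ratio=0.35
#   基础课 170 小时 → base_income = 170 × (108 - 13) = 16,150
#   附加课 15 小时  → bonus_income = 15 × 190 × (1 - 0.35) = 1,852.5
#   total_course_income = 16,150 + 1,852.5 + 包厢课收入,再叠加冲刺/Top3/充值提成等奖金得到 gross_salary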
def _build_calc_notes(
self,
summary: Dict[str, Any],
tier: Optional[Dict[str, Any]],
sprint_bonus: Decimal,
top_rank_bonus: Decimal
) -> Optional[str]:
"""
构建计算备注
"""
notes = []
if summary.get('is_new_hire'):
notes.append("新入职首月")
if tier:
notes.append(f"档位: {tier.get('tier_name', 'N/A')}")
if sprint_bonus > 0:
notes.append(f"冲刺奖金: {sprint_bonus}")
if top_rank_bonus > 0:
rank = summary.get('rank_with_ties')
notes.append(f"Top{rank}奖金: {top_rank_bonus}")
return "; ".join(notes) if notes else None
def _delete_by_month(
self,
context: TaskContext,
records: List[Dict[str, Any]]
) -> int:
"""
按月份删除已存在的数据
"""
months = set(r.get('salary_month') for r in records if r.get('salary_month'))
if not months:
return 0
target_table = self.get_target_table()
full_table = f"{self.DWS_SCHEMA}.{target_table}"
total_deleted = 0
with self.db.conn.cursor() as cur:
for month in months:
sql = f"""
DELETE FROM {full_table}
WHERE site_id = %s AND salary_month = %s
"""
cur.execute(sql, (context.store_id, month))
total_deleted += cur.rowcount
return total_deleted
# 便于外部导入
__all__ = ['AssistantSalaryTask']

File diff suppressed because it is too large


@@ -0,0 +1,627 @@
# -*- coding: utf-8 -*-
"""
财务日度汇总任务
功能说明:
"日期"为粒度,汇总当日财务数据
数据来源:
- dwd_settlement_head: 结账单头表
- dwd_groupbuy_redemption: 团购核销
- dwd_recharge_order: 充值订单
- dws_finance_expense_summary: 支出汇总Excel导入
- dws_platform_settlement: 平台回款/服务费Excel导入
目标表:
billiards_dws.dws_finance_daily_summary
更新策略:
- 更新频率:每小时更新当日数据
- 幂等方式delete-before-insert按日期
业务规则:
- 发生额table_charge_money + goods_money + assistant_pd_money + assistant_cx_money
- 团购优惠coupon_amount - 团购支付金额
- 团购支付pl_coupon_sale_amount 或关联 groupbuy_redemption.ledger_unit_price
- 首充/续充:通过 is_first 字段区分
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
import calendar
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class FinanceDailyTask(BaseDwsTask):
"""
财务日度汇总任务
汇总每日的:
- 发生额(正价)
- 优惠拆分
- 确认收入
- 现金流(流入/流出)
- 充值统计(首充/续充)
- 订单统计
"""
def get_task_code(self) -> str:
return "DWS_FINANCE_DAILY"
def get_target_table(self) -> str:
return "dws_finance_daily_summary"
def get_primary_keys(self) -> List[str]:
return ["site_id", "stat_date"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据
"""
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
self.logger.info(
"%s: 提取数据,日期范围 %s ~ %s",
self.get_task_code(), start_date, end_date
)
# 1. 获取结账单汇总
settlement_summary = self._extract_settlement_summary(site_id, start_date, end_date)
# 2. 获取团购核销汇总
groupbuy_summary = self._extract_groupbuy_summary(site_id, start_date, end_date)
# 3. 获取充值汇总
recharge_summary = self._extract_recharge_summary(site_id, start_date, end_date)
# 3.1 获取赠送卡消费汇总(余额变动)
gift_card_summary = self._extract_gift_card_consume_summary(site_id, start_date, end_date)
# 4. 获取支出汇总(来自导入表)
expense_summary = self._extract_expense_summary(site_id, start_date, end_date)
# 5. 获取平台回款汇总(来自导入表)
platform_summary = self._extract_platform_summary(site_id, start_date, end_date)
# 6. 获取大客户优惠明细(用于拆分手动优惠)
big_customer_summary = self._extract_big_customer_discounts(site_id, start_date, end_date)
return {
'settlement_summary': settlement_summary,
'groupbuy_summary': groupbuy_summary,
'recharge_summary': recharge_summary,
'gift_card_summary': gift_card_summary,
'expense_summary': expense_summary,
'platform_summary': platform_summary,
'big_customer_summary': big_customer_summary,
'start_date': start_date,
'end_date': end_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据:按日期聚合
"""
settlement_summary = extracted['settlement_summary']
groupbuy_summary = extracted['groupbuy_summary']
recharge_summary = extracted['recharge_summary']
gift_card_summary = extracted['gift_card_summary']
expense_summary = extracted['expense_summary']
platform_summary = extracted['platform_summary']
big_customer_summary = extracted['big_customer_summary']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,%d 天结账数据,%d 天充值数据",
self.get_task_code(), len(settlement_summary), len(recharge_summary)
)
# 按日期合并数据
dates = set()
for item in settlement_summary + recharge_summary + gift_card_summary + expense_summary + platform_summary:
stat_date = item.get('stat_date')
if stat_date:
dates.add(stat_date)
# 构建索引
settle_index = {s['stat_date']: s for s in settlement_summary}
groupbuy_index = {g['stat_date']: g for g in groupbuy_summary}
recharge_index = {r['stat_date']: r for r in recharge_summary}
gift_card_index = {g['stat_date']: g for g in gift_card_summary}
expense_index = {e['stat_date']: e for e in expense_summary}
platform_index = {p['stat_date']: p for p in platform_summary}
big_customer_index = {b['stat_date']: b for b in big_customer_summary}
results = []
for stat_date in sorted(dates):
settle = settle_index.get(stat_date, {})
groupbuy = groupbuy_index.get(stat_date, {})
recharge = recharge_index.get(stat_date, {})
gift_card = gift_card_index.get(stat_date, {})
expense = expense_index.get(stat_date, {})
platform = platform_index.get(stat_date, {})
big_customer = big_customer_index.get(stat_date, {})
record = self._build_daily_record(
stat_date, settle, groupbuy, recharge, gift_card, expense, platform, big_customer, site_id
)
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
deleted = self.delete_existing_data(context, date_col="stat_date")
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _extract_settlement_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取结账单日汇总
"""
sql = """
SELECT
DATE(pay_time) AS stat_date,
COUNT(*) AS order_count,
COUNT(CASE WHEN member_id != 0 AND member_id IS NOT NULL THEN 1 END) AS member_order_count,
COUNT(CASE WHEN member_id = 0 OR member_id IS NULL THEN 1 END) AS guest_order_count,
-- 发生额(正价)
SUM(table_charge_money) AS table_fee_amount,
SUM(goods_money) AS goods_amount,
SUM(assistant_pd_money) AS assistant_pd_amount,
SUM(assistant_cx_money) AS assistant_cx_amount,
SUM(table_charge_money + goods_money + assistant_pd_money + assistant_cx_money) AS gross_amount,
-- 支付
SUM(pay_amount) AS cash_pay_amount,
SUM(recharge_card_amount) AS card_pay_amount,
SUM(balance_amount) AS balance_pay_amount,
-- 优惠
SUM(coupon_amount) AS coupon_amount,
SUM(adjust_amount) AS adjust_amount,
SUM(member_discount_amount) AS member_discount_amount,
SUM(rounding_amount) AS rounding_amount,
SUM(pl_coupon_sale_amount) AS pl_coupon_sale_amount,
-- 消费金额
SUM(consume_money) AS total_consume
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %s
AND DATE(pay_time) >= %s
AND DATE(pay_time) <= %s
GROUP BY DATE(pay_time)
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_groupbuy_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取团购核销日汇总
"""
sql = """
SELECT
sh.pay_time::DATE AS stat_date,
COUNT(CASE WHEN sh.coupon_amount > 0 THEN 1 END) AS groupbuy_count,
SUM(
CASE
WHEN sh.coupon_amount > 0 THEN
CASE
WHEN sh.pl_coupon_sale_amount > 0 THEN sh.pl_coupon_sale_amount
ELSE COALESCE(gr.ledger_unit_price, 0)
END
ELSE 0
END
) AS groupbuy_pay_total
FROM billiards_dwd.dwd_settlement_head sh
LEFT JOIN billiards_dwd.dwd_groupbuy_redemption gr
ON gr.order_settle_id = sh.order_settle_id
AND COALESCE(gr.is_delete, 0) = 0
WHERE sh.site_id = %s
AND sh.pay_time >= %s
AND sh.pay_time < %s + INTERVAL '1 day'
GROUP BY sh.pay_time::DATE
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_recharge_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取充值日汇总
"""
sql = """
SELECT
DATE(pay_time) AS stat_date,
COUNT(*) AS recharge_count,
SUM(pay_money + gift_money) AS recharge_total,
SUM(pay_money) AS recharge_cash,
SUM(gift_money) AS recharge_gift,
COUNT(CASE WHEN is_first = 1 THEN 1 END) AS first_recharge_count,
SUM(CASE WHEN is_first = 1 THEN pay_money + gift_money ELSE 0 END) AS first_recharge_total,
SUM(CASE WHEN is_first = 1 THEN pay_money ELSE 0 END) AS first_recharge_cash,
SUM(CASE WHEN is_first = 1 THEN gift_money ELSE 0 END) AS first_recharge_gift,
COUNT(CASE WHEN is_first = 0 OR is_first IS NULL THEN 1 END) AS renewal_count,
SUM(CASE WHEN is_first = 0 OR is_first IS NULL THEN pay_money + gift_money ELSE 0 END) AS renewal_total,
SUM(CASE WHEN is_first = 0 OR is_first IS NULL THEN pay_money ELSE 0 END) AS renewal_cash,
SUM(CASE WHEN is_first = 0 OR is_first IS NULL THEN gift_money ELSE 0 END) AS renewal_gift,
COUNT(DISTINCT member_id) AS recharge_member_count
FROM billiards_dwd.dwd_recharge_order
WHERE site_id = %s
AND DATE(pay_time) >= %s
AND DATE(pay_time) <= %s
GROUP BY DATE(pay_time)
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_gift_card_consume_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取赠送卡消费汇总(来自余额变动)
"""
gift_card_type_ids = (
2791990152417157, # 台费卡
2794699703437125, # 酒水卡
2793266846533445, # 活动抵用券
)
id_list = ", ".join(str(card_id) for card_id in gift_card_type_ids)
sql = f"""
SELECT
change_time::DATE AS stat_date,
SUM(ABS(change_amount)) AS gift_card_consume
FROM billiards_dwd.dwd_member_balance_change
WHERE site_id = %s
AND change_time >= %s
AND change_time < %s + INTERVAL '1 day'
AND from_type = 1
AND change_amount < 0
AND COALESCE(is_delete, 0) = 0
AND card_type_id IN ({id_list})
GROUP BY change_time::DATE
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_expense_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取支出汇总(来自导入表,按月分摊到日)
"""
if start_date > end_date:
return []
start_month = start_date.replace(day=1)
end_month = end_date.replace(day=1)
sql = """
SELECT
expense_month,
SUM(expense_amount) AS expense_amount
FROM billiards_dws.dws_finance_expense_summary
WHERE site_id = %s
AND expense_month >= %s
AND expense_month <= %s
GROUP BY expense_month
"""
rows = self.db.query(sql, (site_id, start_month, end_month))
if not rows:
return []
daily_totals: Dict[date, Decimal] = {}
for row in rows:
row_dict = dict(row)
month_date = row_dict.get('expense_month')
if not month_date:
continue
amount = self.safe_decimal(row_dict.get('expense_amount', 0))
days_in_month = calendar.monthrange(month_date.year, month_date.month)[1]
daily_amount = amount / Decimal(str(days_in_month)) if days_in_month > 0 else Decimal('0')
for day in range(1, days_in_month + 1):
stat_date = date(month_date.year, month_date.month, day)
if stat_date < start_date or stat_date > end_date:
continue
daily_totals[stat_date] = daily_totals.get(stat_date, Decimal('0')) + daily_amount
return [
{'stat_date': stat_date, 'expense_amount': amount}
for stat_date, amount in sorted(daily_totals.items())
]
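# 示例(仅为说明,数字为假设):某月支出合计 31,000 元、当月 31 天,
# 则按日分摊为 1,000 元/天;仅落在 [start_date, end_date] 窗口内的日期会累计到结果中。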
def _extract_platform_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取平台回款/服务费汇总(来自导入表)
"""
sql = """
SELECT
settlement_date AS stat_date,
SUM(settlement_amount) AS settlement_amount,
SUM(commission_amount) AS commission_amount,
SUM(service_fee) AS service_fee
FROM billiards_dws.dws_platform_settlement
WHERE site_id = %s
AND settlement_date >= %s
AND settlement_date <= %s
GROUP BY settlement_date
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_big_customer_discounts(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取大客户优惠(用于拆分手动调整)
"""
member_ids = self._parse_id_list(self.config.get("dws.discount.big_customer_member_ids"))
order_ids = self._parse_id_list(self.config.get("dws.discount.big_customer_order_ids"))
if not member_ids and not order_ids:
return []
sql = """
SELECT
pay_time::DATE AS stat_date,
order_settle_id,
member_id,
adjust_amount
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %s
AND pay_time >= %s
AND pay_time < %s + INTERVAL '1 day'
AND adjust_amount != 0
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
if not rows:
return []
result: Dict[date, Dict[str, Any]] = {}
for row in rows:
row_dict = dict(row)
stat_date = row_dict.get('stat_date')
if not stat_date:
continue
order_id = row_dict.get('order_settle_id')
member_id = row_dict.get('member_id')
if order_id not in order_ids and member_id not in member_ids:
continue
amount = abs(self.safe_decimal(row_dict.get('adjust_amount', 0)))
entry = result.setdefault(stat_date, {'stat_date': stat_date, 'big_customer_amount': Decimal('0'), 'big_customer_count': 0})
entry['big_customer_amount'] += amount
entry['big_customer_count'] += 1
return list(result.values())
def _parse_id_list(self, value: Any) -> set:
if not value:
return set()
if isinstance(value, str):
items = [v.strip() for v in value.split(",") if v.strip()]
return {int(v) for v in items if v.isdigit()}
if isinstance(value, (list, tuple, set)):
result = set()
for item in value:
if item is None:
continue
try:
result.add(int(item))
except (ValueError, TypeError):
continue
return result
return set()
# ==========================================================================
# 数据转换方法
# ==========================================================================
def _build_daily_record(
self,
stat_date: date,
settle: Dict[str, Any],
groupbuy: Dict[str, Any],
recharge: Dict[str, Any],
gift_card: Dict[str, Any],
expense: Dict[str, Any],
platform: Dict[str, Any],
big_customer: Dict[str, Any],
site_id: int
) -> Dict[str, Any]:
"""
构建日度财务记录
"""
# 发生额
gross_amount = self.safe_decimal(settle.get('gross_amount', 0))
table_fee_amount = self.safe_decimal(settle.get('table_fee_amount', 0))
goods_amount = self.safe_decimal(settle.get('goods_amount', 0))
assistant_pd_amount = self.safe_decimal(settle.get('assistant_pd_amount', 0))
assistant_cx_amount = self.safe_decimal(settle.get('assistant_cx_amount', 0))
# 支付
cash_pay_amount = self.safe_decimal(settle.get('cash_pay_amount', 0))
card_pay_amount = self.safe_decimal(settle.get('card_pay_amount', 0))
balance_pay_amount = self.safe_decimal(settle.get('balance_pay_amount', 0))
# 优惠
coupon_amount = self.safe_decimal(settle.get('coupon_amount', 0))
pl_coupon_sale = self.safe_decimal(settle.get('pl_coupon_sale_amount', 0))
groupbuy_pay = self.safe_decimal(groupbuy.get('groupbuy_pay_total', 0))
# 团购支付金额:优先使用pl_coupon_sale_amount,否则使用groupbuy核销金额
if pl_coupon_sale > 0:
groupbuy_pay_amount = pl_coupon_sale
else:
groupbuy_pay_amount = groupbuy_pay
# 团购优惠 = 团购抵消台费 - 团购支付金额
discount_groupbuy = coupon_amount - groupbuy_pay_amount
if discount_groupbuy < 0:
discount_groupbuy = Decimal('0')
adjust_amount = self.safe_decimal(settle.get('adjust_amount', 0))
member_discount = self.safe_decimal(settle.get('member_discount_amount', 0))
rounding_amount = self.safe_decimal(settle.get('rounding_amount', 0))
big_customer_amount = self.safe_decimal(big_customer.get('big_customer_amount', 0))
other_discount = adjust_amount - big_customer_amount
if other_discount < 0:
other_discount = Decimal('0')
# 赠送卡消费(来自余额变动)
gift_card_consume_amount = self.safe_decimal(gift_card.get('gift_card_consume', 0))
# 优惠合计
discount_total = discount_groupbuy + member_discount + gift_card_consume_amount + adjust_amount + rounding_amount
# 确认收入
confirmed_income = gross_amount - discount_total
# 现金流
platform_settlement_amount = self.safe_decimal(platform.get('settlement_amount', 0))
platform_fee_amount = (
self.safe_decimal(platform.get('commission_amount', 0))
+ self.safe_decimal(platform.get('service_fee', 0))
)
recharge_cash_inflow = self.safe_decimal(recharge.get('recharge_cash', 0))
platform_inflow = platform_settlement_amount if platform_settlement_amount > 0 else groupbuy_pay_amount
cash_inflow_total = cash_pay_amount + platform_inflow + recharge_cash_inflow
cash_outflow_total = self.safe_decimal(expense.get('expense_amount', 0)) + platform_fee_amount
cash_balance_change = cash_inflow_total - cash_outflow_total
# 卡消费
cash_card_consume = card_pay_amount + balance_pay_amount
gift_card_consume = gift_card_consume_amount
card_consume_total = cash_card_consume + gift_card_consume
# 充值统计
recharge_count = self.safe_int(recharge.get('recharge_count', 0))
recharge_total = self.safe_decimal(recharge.get('recharge_total', 0))
recharge_cash = self.safe_decimal(recharge.get('recharge_cash', 0))
recharge_gift = self.safe_decimal(recharge.get('recharge_gift', 0))
first_recharge_count = self.safe_int(recharge.get('first_recharge_count', 0))
first_recharge_amount = self.safe_decimal(recharge.get('first_recharge_total', 0))
renewal_count = self.safe_int(recharge.get('renewal_count', 0))
renewal_amount = self.safe_decimal(recharge.get('renewal_total', 0))
# 订单统计
order_count = self.safe_int(settle.get('order_count', 0))
member_order_count = self.safe_int(settle.get('member_order_count', 0))
guest_order_count = self.safe_int(settle.get('guest_order_count', 0))
avg_order_amount = gross_amount / order_count if order_count > 0 else Decimal('0')
return {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'stat_date': stat_date,
# 发生额
'gross_amount': gross_amount,
'table_fee_amount': table_fee_amount,
'goods_amount': goods_amount,
'assistant_pd_amount': assistant_pd_amount,
'assistant_cx_amount': assistant_cx_amount,
# 优惠
'discount_total': discount_total,
'discount_groupbuy': discount_groupbuy,
'discount_vip': member_discount,
'discount_gift_card': gift_card_consume_amount,
'discount_manual': adjust_amount,
'discount_rounding': rounding_amount,
'discount_other': other_discount,
# 确认收入
'confirmed_income': confirmed_income,
# 现金流
'cash_inflow_total': cash_inflow_total,
'cash_pay_amount': cash_pay_amount,
'groupbuy_pay_amount': groupbuy_pay_amount,
'platform_settlement_amount': platform_settlement_amount,
'platform_fee_amount': platform_fee_amount,
'recharge_cash_inflow': recharge_cash_inflow,
'card_consume_total': card_consume_total,
'cash_card_consume': cash_card_consume,
'gift_card_consume': gift_card_consume,
'cash_outflow_total': cash_outflow_total,
'cash_balance_change': cash_balance_change,
# 充值统计
'recharge_count': recharge_count,
'recharge_total': recharge_total,
'recharge_cash': recharge_cash,
'recharge_gift': recharge_gift,
'first_recharge_count': first_recharge_count,
'first_recharge_amount': first_recharge_amount,
'renewal_count': renewal_count,
'renewal_amount': renewal_amount,
# 订单统计
'order_count': order_count,
'member_order_count': member_order_count,
'guest_order_count': guest_order_count,
'avg_order_amount': avg_order_amount,
}
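# 示例(仅为说明,数字为假设):某日 coupon_amount=100、团购实付 62,
# 则 discount_groupbuy = 100 - 62 = 38;确认收入 = 发生额 - 优惠合计;
# 现金流入 = 现金支付 + 平台回款(无回款时用团购实付)+ 充值现金。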
# 便于外部导入
__all__ = ['FinanceDailyTask']


@@ -0,0 +1,486 @@
# -*- coding: utf-8 -*-
"""
优惠明细分析任务
功能说明:
"日期+优惠类型"为粒度,分析优惠构成
数据来源:
- dwd_settlement_head: 结账单头表(优惠字段)
- dwd_groupbuy_redemption: 团购核销(团购实付金额)
- dwd_member_balance_change: 余额变动(赠送卡消费)
目标表:
billiards_dws.dws_finance_discount_detail
更新策略:
- 更新频率:每日更新
- 幂等方式:delete-before-insert(按日期)
业务规则:
- 团购优惠 (GROUPBUY): coupon_amount - 团购实付金额
- 会员折扣 (VIP): member_discount_amount
- 赠送卡抵扣 (GIFT_CARD_*): dwd_member_balance_change(台费卡/酒水卡/活动抵用券)
- 抹零 (ROUNDING): rounding_amount
- 大客户优惠 (BIG_CUSTOMER): 手动调整中标记的大客户订单
- 其他优惠 (OTHER): 手动调整中除大客户外的部分
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class FinanceDiscountDetailTask(BaseDwsTask):
"""
优惠明细分析任务
分析各类优惠的使用情况:
- 团购优惠
- 会员折扣
- 赠送卡抵扣
- 手动调整
- 抹零
- 其他优惠
"""
def get_task_code(self) -> str:
return "DWS_FINANCE_DISCOUNT_DETAIL"
def get_target_table(self) -> str:
return "dws_finance_discount_detail"
def get_primary_keys(self) -> List[str]:
return ["site_id", "stat_date", "discount_type_code"]
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
抽取优惠相关数据
数据来源:
1. settlement_head: 各类优惠字段
2. groupbuy_redemption: 团购实付金额
"""
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
# 从settlement_head抽取优惠数据
discount_summary = self._extract_discount_summary(site_id, start_date, end_date)
# 从groupbuy_redemption获取团购实付金额
groupbuy_payments = self._extract_groupbuy_payments(site_id, start_date, end_date)
# 提取大客户优惠(拆分手动调整)
big_customer_summary = self._extract_big_customer_discounts(site_id, start_date, end_date)
# 提取赠送卡消费(按卡类型拆分)
gift_card_consumes = self._extract_gift_card_consumes(site_id, start_date, end_date)
return {
'discount_summary': discount_summary,
'groupbuy_payments': groupbuy_payments,
'big_customer_summary': big_customer_summary,
'gift_card_consumes': gift_card_consumes,
}
def _extract_discount_summary(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
从结账单头表抽取优惠汇总
字段说明:
- coupon_amount: 团购抵消台费金额
- adjust_amount: 手动调整金额(台费打折)
- member_discount_amount: 会员折扣
- rounding_amount: 抹零金额
- pl_coupon_sale_amount: 平台券销售金额(团购实付路径1)
"""
sql = """
SELECT
pay_time::DATE AS stat_date,
-- 团购相关
COALESCE(SUM(coupon_amount), 0) AS coupon_amount_total,
COALESCE(SUM(pl_coupon_sale_amount), 0) AS pl_coupon_sale_total,
COUNT(CASE WHEN coupon_amount > 0 THEN 1 END) AS coupon_order_count,
-- 手动调整
COALESCE(SUM(adjust_amount), 0) AS adjust_amount_total,
COUNT(CASE WHEN adjust_amount != 0 THEN 1 END) AS adjust_order_count,
-- 会员折扣
COALESCE(SUM(member_discount_amount), 0) AS member_discount_total,
COUNT(CASE WHEN member_discount_amount > 0 THEN 1 END) AS member_discount_order_count,
-- 抹零
COALESCE(SUM(rounding_amount), 0) AS rounding_amount_total,
COUNT(CASE WHEN rounding_amount != 0 THEN 1 END) AS rounding_order_count,
-- 总订单数
COUNT(*) AS total_orders
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %(site_id)s
AND pay_time >= %(start_date)s
AND pay_time < %(end_date)s + INTERVAL '1 day'
AND settle_status = 1 -- 已结账
GROUP BY pay_time::DATE
ORDER BY stat_date
"""
rows = self.db.query(sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
return [dict(row) for row in rows] if rows else []
def _extract_groupbuy_payments(
self,
site_id: int,
start_date: date,
end_date: date
) -> Dict[date, Decimal]:
"""
从团购核销表获取团购实付金额
团购实付金额计算:
- 若 pl_coupon_sale_amount > 0,使用该值
- 否则使用 groupbuy_redemption.ledger_unit_price
返回:{日期: 团购实付总额}
"""
sql = """
SELECT
sh.pay_time::DATE AS stat_date,
SUM(
CASE
WHEN sh.pl_coupon_sale_amount > 0 THEN sh.pl_coupon_sale_amount
ELSE COALESCE(gr.ledger_unit_price, 0)
END
) AS groupbuy_payment
FROM billiards_dwd.dwd_settlement_head sh
LEFT JOIN billiards_dwd.dwd_groupbuy_redemption gr
ON gr.order_settle_id = sh.order_settle_id
AND COALESCE(gr.is_delete, 0) = 0
WHERE sh.site_id = %(site_id)s
AND sh.pay_time >= %(start_date)s
AND sh.pay_time < %(end_date)s + INTERVAL '1 day'
AND sh.settle_status = 1
AND sh.coupon_amount > 0 -- 只统计有团购的订单
GROUP BY sh.pay_time::DATE
"""
rows = self.db.query(sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
result = {}
if rows:
for row in rows:
result[row['stat_date']] = self.safe_decimal(row.get('groupbuy_payment', 0))
return result
def _extract_gift_card_consumes(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取赠送卡消费(按卡类型)
"""
gift_card_type_ids = (
2791990152417157, # 台费卡
2794699703437125, # 酒水卡
2793266846533445, # 活动抵用券
)
id_list = ", ".join(str(card_id) for card_id in gift_card_type_ids)
sql = f"""
SELECT
change_time::DATE AS stat_date,
card_type_id,
COUNT(*) AS consume_count,
SUM(ABS(change_amount)) AS consume_amount
FROM billiards_dwd.dwd_member_balance_change
WHERE site_id = %(site_id)s
AND change_time >= %(start_date)s
AND change_time < %(end_date)s + INTERVAL '1 day'
AND from_type = 1
AND change_amount < 0
AND COALESCE(is_delete, 0) = 0
AND card_type_id IN ({id_list})
GROUP BY change_time::DATE, card_type_id
"""
rows = self.db.query(sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
return [dict(row) for row in rows] if rows else []
def transform(self, data: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据
将抽取的数据转换为目标表格式:
- 每种优惠类型一条记录
- 计算团购优惠(coupon_amount - 团购实付)
- 计算优惠占比
"""
site_id = context.store_id
tenant_id = self.config.get("app.tenant_id", site_id)
discount_summary = data.get('discount_summary', [])
groupbuy_payments = data.get('groupbuy_payments', {})
big_customer_summary = {r['stat_date']: r for r in data.get('big_customer_summary', [])}
gift_card_consumes = data.get('gift_card_consumes', [])
records = []
# 优惠类型定义
# (type_code, type_name, amount_field, count_field, special_calc)
discount_types = [
('GROUPBUY', '团购优惠', 'coupon_amount_total', 'coupon_order_count', True),
('VIP', '会员折扣', 'member_discount_total', 'member_discount_order_count', False),
('ROUNDING', '抹零', 'rounding_amount_total', 'rounding_order_count', False),
]
gift_card_type_map = {
2791990152417157: ('GIFT_CARD_TABLE', '台费卡抵扣'),
2794699703437125: ('GIFT_CARD_DRINK', '酒水卡抵扣'),
2793266846533445: ('GIFT_CARD_COUPON', '活动抵用券抵扣'),
}
# 赠送卡消费按日期+类型聚合
gift_card_by_date: Dict[date, Dict[str, Dict[str, Any]]] = {}
for row in gift_card_consumes:
stat_date = row.get('stat_date')
card_type_id = row.get('card_type_id')
type_info = gift_card_type_map.get(card_type_id)
if not stat_date or not type_info:
continue
type_code, type_name = type_info
daily = gift_card_by_date.setdefault(stat_date, {})
entry = daily.setdefault(type_code, {'type_name': type_name, 'amount': Decimal('0'), 'count': 0})
entry['amount'] += self.safe_decimal(row.get('consume_amount', 0))
entry['count'] += self.safe_int(row.get('consume_count', 0))
discount_summary_map = {row.get('stat_date'): row for row in discount_summary if row.get('stat_date')}
stat_dates = set(discount_summary_map.keys())
stat_dates.update(groupbuy_payments.keys())
stat_dates.update(big_customer_summary.keys())
stat_dates.update(gift_card_by_date.keys())
for stat_date in sorted(stat_dates):
daily_data = discount_summary_map.get(stat_date, {})
# 计算各类优惠金额
daily_discounts = {}
total_discount = Decimal('0')
for type_code, type_name, amount_field, count_field, special_calc in discount_types:
if special_calc and type_code == 'GROUPBUY':
# 团购优惠 = 团购抵消台费 - 团购实付
coupon_amount = self.safe_decimal(daily_data.get(amount_field, 0))
groupbuy_paid = groupbuy_payments.get(stat_date, Decimal('0'))
discount_amount = coupon_amount - groupbuy_paid
# 确保优惠金额为正数
discount_amount = max(discount_amount, Decimal('0'))
else:
discount_amount = abs(self.safe_decimal(daily_data.get(amount_field, 0)))
usage_count = daily_data.get(count_field, 0) or 0
daily_discounts[type_code] = {
'type_name': type_name,
'amount': discount_amount,
'count': usage_count,
}
total_discount += discount_amount
# 赠送卡拆分(台费卡/酒水卡/活动券)
gift_daily = gift_card_by_date.get(stat_date, {})
for type_code, type_name in gift_card_type_map.values():
info = gift_daily.get(type_code, {'amount': Decimal('0'), 'count': 0})
daily_discounts[type_code] = {
'type_name': type_name,
'amount': self.safe_decimal(info.get('amount', 0)),
'count': self.safe_int(info.get('count', 0)),
}
total_discount += self.safe_decimal(info.get('amount', 0))
# 拆分手动调整为大客户/其他
adjust_amount = abs(self.safe_decimal(daily_data.get('adjust_amount_total', 0)))
adjust_count = daily_data.get('adjust_order_count', 0) or 0
big_customer_info = big_customer_summary.get(stat_date, {})
big_customer_amount = self.safe_decimal(big_customer_info.get('big_customer_amount', 0))
big_customer_count = big_customer_info.get('big_customer_count', 0) or 0
other_amount = adjust_amount - big_customer_amount
if other_amount < 0:
other_amount = Decimal('0')
other_count = adjust_count - big_customer_count
if other_count < 0:
other_count = 0
daily_discounts['BIG_CUSTOMER'] = {
'type_name': '大客户优惠',
'amount': big_customer_amount,
'count': big_customer_count,
}
daily_discounts['OTHER'] = {
'type_name': '其他优惠',
'amount': other_amount,
'count': other_count,
}
total_discount += big_customer_amount + other_amount
# 为每种优惠类型生成记录
for type_code, discount_info in daily_discounts.items():
discount_amount = discount_info['amount']
usage_count = discount_info['count']
# 计算占比(避免除零)
discount_ratio = (discount_amount / total_discount) if total_discount > 0 else Decimal('0')
records.append({
'site_id': site_id,
'tenant_id': tenant_id,
'stat_date': stat_date,
'discount_type_code': type_code,
'discount_type_name': discount_info['type_name'],
'discount_amount': discount_amount,
'discount_ratio': round(discount_ratio, 4),
'usage_count': usage_count,
'affected_orders': usage_count, # 简化:使用次数=影响订单数
})
return records
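# 示例(仅为说明,数字为假设):某日团购优惠 200、会员折扣 200、抹零 100,其余类型为 0,
# 则 total_discount = 500,GROUPBUY 记录的 discount_ratio = 200 / 500 = 0.4,其余类型同理。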
def _extract_big_customer_discounts(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取大客户优惠(基于手动调整)
"""
member_ids = self._parse_id_list(self.config.get("dws.discount.big_customer_member_ids"))
order_ids = self._parse_id_list(self.config.get("dws.discount.big_customer_order_ids"))
if not member_ids and not order_ids:
return []
sql = """
SELECT
pay_time::DATE AS stat_date,
order_settle_id,
member_id,
adjust_amount
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %(site_id)s
AND pay_time >= %(start_date)s
AND pay_time < %(end_date)s + INTERVAL '1 day'
AND adjust_amount != 0
"""
rows = self.db.query(sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
if not rows:
return []
result: Dict[date, Dict[str, Any]] = {}
for row in rows:
row_dict = dict(row)
stat_date = row_dict.get('stat_date')
if not stat_date:
continue
order_id = row_dict.get('order_settle_id')
member_id = row_dict.get('member_id')
if order_id not in order_ids and member_id not in member_ids:
continue
amount = abs(self.safe_decimal(row_dict.get('adjust_amount', 0)))
entry = result.setdefault(stat_date, {'stat_date': stat_date, 'big_customer_amount': Decimal('0'), 'big_customer_count': 0})
entry['big_customer_amount'] += amount
entry['big_customer_count'] += 1
return list(result.values())
def _parse_id_list(self, value: Any) -> set:
if not value:
return set()
if isinstance(value, str):
items = [v.strip() for v in value.split(",") if v.strip()]
return {int(v) for v in items if v.isdigit()}
if isinstance(value, (list, tuple, set)):
result = set()
for item in value:
if item is None:
continue
try:
result.add(int(item))
except (ValueError, TypeError):
continue
return result
return set()
def load(self, records: List[Dict[str, Any]], context: TaskContext) -> Dict[str, Any]:
"""
加载数据到目标表
使用幂等方式:delete-before-insert(按日期范围)
"""
if not records:
return {'inserted': 0, 'deleted': 0}
site_id = context.store_id
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
# 删除窗口内的旧数据
delete_sql = """
DELETE FROM billiards_dws.dws_finance_discount_detail
WHERE site_id = %(site_id)s
AND stat_date >= %(start_date)s
AND stat_date <= %(end_date)s
"""
deleted = self.db.execute(delete_sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
# 批量插入新数据
insert_sql = """
INSERT INTO billiards_dws.dws_finance_discount_detail (
site_id, tenant_id, stat_date,
discount_type_code, discount_type_name,
discount_amount, discount_ratio,
usage_count, affected_orders,
created_at, updated_at
) VALUES (
%(site_id)s, %(tenant_id)s, %(stat_date)s,
%(discount_type_code)s, %(discount_type_name)s,
%(discount_amount)s, %(discount_ratio)s,
%(usage_count)s, %(affected_orders)s,
NOW(), NOW()
)
"""
inserted = 0
for record in records:
self.db.execute(insert_sql, record)
inserted += 1
return {
'deleted': deleted or 0,
'inserted': inserted,
}


@@ -0,0 +1,412 @@
# -*- coding: utf-8 -*-
"""
收入结构分析任务
功能说明:
"日期+区域/类型"为粒度,分析收入结构
数据来源:
- dwd_settlement_head: 结账单头表(台费、商品、助教正价)
- dwd_table_fee_log: 台费流水(区域关联)
- dwd_assistant_service_log: 助教服务流水(区域关联)
- cfg_area_category: 区域分类映射
目标表:
billiards_dws.dws_finance_income_structure
更新策略:
- 更新频率:每日更新
- 幂等方式:delete-before-insert(按日期+类型)
业务规则:
- 结构类型1(INCOME_TYPE):按收入类型分析(台费/商品/助教基础课/助教附加课)
- 结构类型2(AREA):按区域分析(普通台球区/VIP包厢/斯诺克/麻将/KTV等)
- 区域映射:使用cfg_area_category配置
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class FinanceIncomeStructureTask(BaseDwsTask):
"""
收入结构分析任务
分析收入的两种维度:
1. INCOME_TYPE: 按收入类型(台费/商品/助教基础课/助教附加课)
2. AREA: 按区域使用cfg_area_category映射
"""
def get_task_code(self) -> str:
return "DWS_FINANCE_INCOME_STRUCTURE"
def get_target_table(self) -> str:
return "dws_finance_income_structure"
def get_primary_keys(self) -> List[str]:
return ["site_id", "stat_date", "structure_type", "category_code"]
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
抽取数据
分两条路径抽取:
1. 按收入类型汇总(来自settlement_head)
2. 按区域汇总(来自table_fee_log和assistant_service_log)
"""
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
# 按收入类型汇总
income_by_type = self._extract_income_by_type(site_id, start_date, end_date)
# 按区域汇总
income_by_area = self._extract_income_by_area(site_id, start_date, end_date)
return {
'income_by_type': income_by_type,
'income_by_area': income_by_area,
}
def _extract_income_by_type(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
按收入类型汇总
收入类型分类:
- TABLE_FEE: 台费收入 (table_charge_money)
- GOODS: 商品收入 (goods_money)
- ASSISTANT_BASE: 助教基础课 (assistant_pd_money)
- ASSISTANT_BONUS: 助教附加课 (assistant_cx_money)
"""
sql = """
SELECT
pay_time::DATE AS stat_date,
-- 台费收入
COALESCE(SUM(table_charge_money), 0) AS table_fee_income,
COUNT(CASE WHEN table_charge_money > 0 THEN 1 END) AS table_fee_orders,
-- 商品收入
COALESCE(SUM(goods_money), 0) AS goods_income,
COUNT(CASE WHEN goods_money > 0 THEN 1 END) AS goods_orders,
-- 助教基础课收入(PD=陪打)
COALESCE(SUM(assistant_pd_money), 0) AS assistant_base_income,
COUNT(CASE WHEN assistant_pd_money > 0 THEN 1 END) AS assistant_base_orders,
-- 助教附加课收入(CX=超休/促销)
COALESCE(SUM(assistant_cx_money), 0) AS assistant_bonus_income,
COUNT(CASE WHEN assistant_cx_money > 0 THEN 1 END) AS assistant_bonus_orders,
-- 总订单数
COUNT(*) AS total_orders
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %(site_id)s
AND pay_time >= %(start_date)s
AND pay_time < %(end_date)s + INTERVAL '1 day'
AND settle_status = 1 -- 已结账
GROUP BY pay_time::DATE
ORDER BY stat_date
"""
rows = self.db.query(sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
return [dict(row) for row in rows] if rows else []
def _extract_income_by_area(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
按区域汇总收入
关联dim_table获取区域名称,再映射到cfg_area_category
"""
sql = """
WITH area_orders AS (
SELECT
tfl.pay_time::DATE AS stat_date,
dt.site_table_area_name AS area_name,
tfl.order_settle_id,
COALESCE(tfl.ledger_amount, 0) AS income_amount,
COALESCE(tfl.ledger_time_seconds, 0) AS duration_seconds
FROM billiards_dwd.dwd_table_fee_log tfl
LEFT JOIN billiards_dwd.dim_table dt
ON dt.site_table_id = tfl.site_table_id
WHERE tfl.site_id = %(site_id)s
AND tfl.pay_time >= %(start_date)s
AND tfl.pay_time < %(end_date)s + INTERVAL '1 day'
AND COALESCE(tfl.is_delete, 0) = 0
UNION ALL
SELECT
asl.start_use_time::DATE AS stat_date,
dt.site_table_area_name AS area_name,
asl.order_settle_id,
COALESCE(asl.ledger_amount, 0) AS income_amount,
COALESCE(asl.income_seconds, 0) AS duration_seconds
FROM billiards_dwd.dwd_assistant_service_log asl
LEFT JOIN billiards_dwd.dim_table dt
ON dt.site_table_id = asl.site_table_id
WHERE asl.site_id = %(site_id)s
AND asl.start_use_time >= %(start_date)s
AND asl.start_use_time < %(end_date)s + INTERVAL '1 day'
AND COALESCE(asl.is_delete, 0) = 0  -- 与上方台费流水分支口径一致,NULL 视为未删除
)
SELECT
stat_date,
area_name,
COALESCE(SUM(income_amount), 0) AS income_amount,
COALESCE(SUM(duration_seconds), 0) AS duration_seconds,
COUNT(DISTINCT order_settle_id) AS order_count
FROM area_orders
GROUP BY stat_date, area_name
ORDER BY stat_date, area_name
"""
rows = self.db.query(sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
return [dict(row) for row in rows] if rows else []
def transform(self, data: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据
将抽取的数据转换为目标表格式:
1. 按收入类型展开(每种类型一条记录)
2. 按区域展开(每个区域一条记录)
3. 计算占比
"""
site_id = context.store_id
tenant_id = self.config.get("app.tenant_id", site_id)
records = []
# 处理按收入类型的数据
income_type_records = self._transform_income_by_type(
data.get('income_by_type', []),
site_id,
tenant_id
)
records.extend(income_type_records)
# 处理按区域的数据
area_records = self._transform_income_by_area(
data.get('income_by_area', []),
site_id,
tenant_id
)
records.extend(area_records)
return records
def _transform_income_by_type(
self,
income_data: List[Dict[str, Any]],
site_id: int,
tenant_id: int
) -> List[Dict[str, Any]]:
"""
转换按收入类型的数据
将每日汇总数据展开为4条记录(台费/商品/基础课/附加课)
"""
# 收入类型定义
income_types = [
('TABLE_FEE', '台费收入', 'table_fee_income', 'table_fee_orders'),
('GOODS', '商品收入', 'goods_income', 'goods_orders'),
('ASSISTANT_BASE', '助教基础课', 'assistant_base_income', 'assistant_base_orders'),
('ASSISTANT_BONUS', '助教附加课', 'assistant_bonus_income', 'assistant_bonus_orders'),
]
records = []
for daily_data in income_data:
stat_date = daily_data.get('stat_date')
# 计算当日总收入(用于计算占比)
total_income = sum(
self.safe_decimal(daily_data.get(field, 0))
for _, _, field, _ in income_types
)
# 为每种收入类型生成一条记录
for type_code, type_name, income_field, order_field in income_types:
income_amount = self.safe_decimal(daily_data.get(income_field, 0))
order_count = daily_data.get(order_field, 0) or 0
# 计算占比(避免除零)
income_ratio = (income_amount / total_income) if total_income > 0 else Decimal('0')
records.append({
'site_id': site_id,
'tenant_id': tenant_id,
'stat_date': stat_date,
'structure_type': 'INCOME_TYPE',
'category_code': type_code,
'category_name': type_name,
'income_amount': income_amount,
'income_ratio': round(income_ratio, 4),
'order_count': order_count,
'duration_minutes': 0, # 收入类型维度不统计时长
})
return records
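# 示例(仅为说明,数字为假设):某日台费 6,000、商品 2,500、基础课 1,200、附加课 300,
# 总收入 10,000,则四条 INCOME_TYPE 记录的 income_ratio 分别为 0.6 / 0.25 / 0.12 / 0.03。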
def _transform_income_by_area(
self,
area_data: List[Dict[str, Any]],
site_id: int,
tenant_id: int
) -> List[Dict[str, Any]]:
"""
转换按区域的数据
将区域名称映射到cfg_area_category的category_code
"""
records = []
# 加载区域分类配置
self.load_config_cache()
# 按日期分组计算总收入(用于计算占比)
daily_totals = {}
for row in area_data:
stat_date = row.get('stat_date')
income = self.safe_decimal(row.get('income_amount', 0))
daily_totals[stat_date] = daily_totals.get(stat_date, Decimal('0')) + income
# 按日期+区域聚合(相同category_code需要合并)
aggregated = {}
for row in area_data:
stat_date = row.get('stat_date')
area_name = row.get('area_name') or '未知区域'
income_amount = self.safe_decimal(row.get('income_amount', 0))
duration_seconds = row.get('duration_seconds', 0) or 0
order_count = row.get('order_count', 0) or 0
# 映射区域名称到分类代码
category = self.get_area_category(area_name)
category_code = category.get('category_code', 'OTHER')
category_name = category.get('category_name', '其他区域')
# 聚合键
key = (stat_date, category_code)
if key not in aggregated:
aggregated[key] = {
'stat_date': stat_date,
'category_code': category_code,
'category_name': category_name,
'income_amount': Decimal('0'),
'duration_seconds': 0,
'order_count': 0,
}
aggregated[key]['income_amount'] += income_amount
aggregated[key]['duration_seconds'] += duration_seconds
aggregated[key]['order_count'] += order_count
# 生成记录
for key, agg_data in aggregated.items():
stat_date = agg_data['stat_date']
total_income = daily_totals.get(stat_date, Decimal('1'))
income_amount = agg_data['income_amount']
# 计算占比
income_ratio = (income_amount / total_income) if total_income > 0 else Decimal('0')
records.append({
'site_id': site_id,
'tenant_id': tenant_id,
'stat_date': stat_date,
'structure_type': 'AREA',
'category_code': agg_data['category_code'],
'category_name': agg_data['category_name'],
'income_amount': income_amount,
'income_ratio': round(income_ratio, 4),
'order_count': agg_data['order_count'],
'duration_minutes': agg_data['duration_seconds'] // 60,
})
return records
def _map_area_to_category(
self,
area_name: str,
area_categories: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
"""
兼容旧逻辑的映射方法(当前使用 get_area_category)
"""
return self.get_area_category(area_name)
def load(self, records: List[Dict[str, Any]], context: TaskContext) -> Dict[str, Any]:
"""
加载数据到目标表
使用幂等方式:delete-before-insert(按日期范围)
"""
if not records:
return {'inserted': 0, 'deleted': 0}
site_id = context.store_id
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
# 删除窗口内的旧数据
delete_sql = """
DELETE FROM billiards_dws.dws_finance_income_structure
WHERE site_id = %(site_id)s
AND stat_date >= %(start_date)s
AND stat_date <= %(end_date)s
"""
deleted = self.db.execute(delete_sql, {
'site_id': site_id,
'start_date': start_date,
'end_date': end_date,
})
# 批量插入新数据
insert_sql = """
INSERT INTO billiards_dws.dws_finance_income_structure (
site_id, tenant_id, stat_date,
structure_type, category_code, category_name,
income_amount, income_ratio,
order_count, duration_minutes,
created_at, updated_at
) VALUES (
%(site_id)s, %(tenant_id)s, %(stat_date)s,
%(structure_type)s, %(category_code)s, %(category_name)s,
%(income_amount)s, %(income_ratio)s,
%(order_count)s, %(duration_minutes)s,
NOW(), NOW()
)
"""
inserted = 0
for record in records:
self.db.execute(insert_sql, record)
inserted += 1
return {
'deleted': deleted or 0,
'inserted': inserted,
}


@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
"""
充值统计任务
功能说明:
"日期"为粒度,统计充值数据
数据来源:
- dwd_recharge_order: 充值订单
- dim_member_card_account: 会员卡账户(余额快照)
目标表:
billiards_dws.dws_finance_recharge_summary
更新策略:
- 更新频率:每日更新
- 幂等方式:delete-before-insert(按日期)
业务规则:
- 首充/续充:通过 is_first 字段区分
- 现金/赠送:通过 pay_money/gift_money 区分
- 卡余额:区分储值卡和赠送卡
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class FinanceRechargeTask(BaseDwsTask):
"""
充值统计任务
"""
def get_task_code(self) -> str:
return "DWS_FINANCE_RECHARGE"
def get_target_table(self) -> str:
return "dws_finance_recharge_summary"
def get_primary_keys(self) -> List[str]:
return ["site_id", "stat_date"]
def extract(self, context: TaskContext) -> Dict[str, Any]:
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
recharge_summary = self._extract_recharge_summary(site_id, start_date, end_date)
card_balances = self._extract_card_balances(site_id, end_date)
return {
'recharge_summary': recharge_summary,
'card_balances': card_balances,
'start_date': start_date,
'end_date': end_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
recharge_summary = extracted['recharge_summary']
card_balances = extracted['card_balances']
site_id = extracted['site_id']
results = []
for recharge in recharge_summary:
stat_date = recharge.get('stat_date')
# 卡余额仅有当前快照,对窗口内各日期统一写入当前值,避免其他日期为0
balance = card_balances
record = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'stat_date': stat_date,
'recharge_count': self.safe_int(recharge.get('recharge_count', 0)),
'recharge_total': self.safe_decimal(recharge.get('recharge_total', 0)),
'recharge_cash': self.safe_decimal(recharge.get('recharge_cash', 0)),
'recharge_gift': self.safe_decimal(recharge.get('recharge_gift', 0)),
'first_recharge_count': self.safe_int(recharge.get('first_recharge_count', 0)),
'first_recharge_cash': self.safe_decimal(recharge.get('first_recharge_cash', 0)),
'first_recharge_gift': self.safe_decimal(recharge.get('first_recharge_gift', 0)),
'first_recharge_total': self.safe_decimal(recharge.get('first_recharge_total', 0)),
'renewal_count': self.safe_int(recharge.get('renewal_count', 0)),
'renewal_cash': self.safe_decimal(recharge.get('renewal_cash', 0)),
'renewal_gift': self.safe_decimal(recharge.get('renewal_gift', 0)),
'renewal_total': self.safe_decimal(recharge.get('renewal_total', 0)),
'recharge_member_count': self.safe_int(recharge.get('recharge_member_count', 0)),
'new_member_count': self.safe_int(recharge.get('new_member_count', 0)),
'total_card_balance': self.safe_decimal(balance.get('total_balance', 0)),
'cash_card_balance': self.safe_decimal(balance.get('cash_balance', 0)),
'gift_card_balance': self.safe_decimal(balance.get('gift_balance', 0)),
}
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
if not transformed:
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
deleted = self.delete_existing_data(context, date_col="stat_date")
inserted = self.bulk_insert(transformed)
return {
"counts": {"fetched": len(transformed), "inserted": inserted, "updated": 0, "skipped": 0, "errors": 0},
"extra": {"deleted": deleted}
}
def _extract_recharge_summary(self, site_id: int, start_date: date, end_date: date) -> List[Dict[str, Any]]:
sql = """
SELECT
DATE(pay_time) AS stat_date,
COUNT(*) AS recharge_count,
SUM(pay_money + gift_money) AS recharge_total,
SUM(pay_money) AS recharge_cash,
SUM(gift_money) AS recharge_gift,
COUNT(CASE WHEN is_first = 1 THEN 1 END) AS first_recharge_count,
SUM(CASE WHEN is_first = 1 THEN pay_money ELSE 0 END) AS first_recharge_cash,
SUM(CASE WHEN is_first = 1 THEN gift_money ELSE 0 END) AS first_recharge_gift,
SUM(CASE WHEN is_first = 1 THEN pay_money + gift_money ELSE 0 END) AS first_recharge_total,
COUNT(CASE WHEN is_first != 1 OR is_first IS NULL THEN 1 END) AS renewal_count,
SUM(CASE WHEN is_first != 1 OR is_first IS NULL THEN pay_money ELSE 0 END) AS renewal_cash,
SUM(CASE WHEN is_first != 1 OR is_first IS NULL THEN gift_money ELSE 0 END) AS renewal_gift,
SUM(CASE WHEN is_first != 1 OR is_first IS NULL THEN pay_money + gift_money ELSE 0 END) AS renewal_total,
COUNT(DISTINCT member_id) AS recharge_member_count,
COUNT(DISTINCT CASE WHEN is_first = 1 THEN member_id END) AS new_member_count
FROM billiards_dwd.dwd_recharge_order
WHERE site_id = %s AND DATE(pay_time) >= %s AND DATE(pay_time) <= %s
GROUP BY DATE(pay_time)
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_card_balances(self, site_id: int, stat_date: date) -> Dict[str, Decimal]:
CASH_CARD_TYPE_ID = 2793249295533893
GIFT_CARD_TYPE_IDS = [2791990152417157, 2793266846533445, 2794699703437125]
sql = """
SELECT card_type_id, SUM(balance) AS total_balance
FROM billiards_dwd.dim_member_card_account
WHERE site_id = %s AND scd2_is_current = 1
AND COALESCE(is_delete, 0) = 0
GROUP BY card_type_id
"""
rows = self.db.query(sql, (site_id,))
cash_balance = Decimal('0')
gift_balance = Decimal('0')
for row in (rows or []):
card_type_id = row['card_type_id']
balance = self.safe_decimal(row['total_balance'])
if card_type_id == CASH_CARD_TYPE_ID:
cash_balance += balance
elif card_type_id in GIFT_CARD_TYPE_IDS:
gift_balance += balance
return {
'cash_balance': cash_balance,
'gift_balance': gift_balance,
'total_balance': cash_balance + gift_balance
}
__all__ = ['FinanceRechargeTask']

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 移除 RecallIndexTask/IntimacyIndexTask 导出,仅保留 WBI/NCI/ML/Relation
"""
指数算法任务模块
包含:
- WinbackIndexTask: 老客挽回指数 (WBI)
- NewconvIndexTask: 新客转化指数 (NCI)
- MlManualImportTask: ML 人工台账导入任务
- RelationIndexTask: 关系指数计算任务(RS/OS/MS/ML)
"""
from .winback_index_task import WinbackIndexTask
from .newconv_index_task import NewconvIndexTask
from .ml_manual_import_task import MlManualImportTask
from .relation_index_task import RelationIndexTask
__all__ = [
'WinbackIndexTask',
'NewconvIndexTask',
'MlManualImportTask',
'RelationIndexTask',
]

View File

@@ -0,0 +1,572 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 更新 docstring,移除 RECALL/INTIMACY 引用,反映当前指数体系(WBI/NCI/RS/OS/MS/ML)
"""
指数算法任务基类
功能说明:
- 提供半衰期时间衰减函数
- 提供分位数计算和分位截断
- 提供0-10映射方法
- 提供算法参数加载
- 提供分位点历史记录(用于EWMA平滑)
算法原理:
1. 时间衰减函数(半衰期模型):decay(d; h) = exp(-ln(2) * d / h)
当 d=h 时权重衰减到 0.5,越近权重越大
2. 0-10映射流程:
Raw Score → Winsorize(P5, P95) → [可选 Log/asinh 压缩] → MinMax(0, 10)
作者:ETL团队
创建日期:2026-02-03
"""
from __future__ import annotations
import math
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from ..base_dws_task import BaseDwsTask, TaskContext
# =============================================================================
# 数据类定义
# =============================================================================
@dataclass
class IndexParameters:
"""指数算法参数数据类"""
params: Dict[str, float]
loaded_at: datetime
@dataclass
class PercentileHistory:
"""分位点历史记录"""
percentile_5: float
percentile_95: float
percentile_5_smoothed: float
percentile_95_smoothed: float
record_count: int
calc_time: datetime
# =============================================================================
# 指数任务基类
# =============================================================================
class BaseIndexTask(BaseDwsTask):
"""
指数算法任务基类
提供指数计算通用功能:
1. 半衰期时间衰减函数
2. 分位数计算与截断
3. 0-10归一化映射
4. 算法参数加载
    5. 分位点历史管理(EWMA平滑)
"""
# 子类需要定义的指数类型
INDEX_TYPE: str = ""
# 参数缓存TTL
_index_params_ttl: int = 300
def __init__(self, config, db_connection, api_client, logger):
super().__init__(config, db_connection, api_client, logger)
# 参数缓存:按 index_type 隔离,避免单任务多指数串参
self._index_params_cache_by_type: Dict[str, IndexParameters] = {}
# 默认参数
DEFAULT_LOOKBACK_DAYS = 60
DEFAULT_PERCENTILE_LOWER = 5
DEFAULT_PERCENTILE_UPPER = 95
DEFAULT_EWMA_ALPHA = 0.2
# ==========================================================================
# 抽象方法(子类需实现)
# ==========================================================================
@abstractmethod
def get_index_type(self) -> str:
"""获取指数类型(如 WBI/NCI/RS/OS/MS/ML"""
raise NotImplementedError
# ==========================================================================
# 时间衰减函数
# ==========================================================================
def decay(self, days: float, halflife: float) -> float:
"""
半衰期衰减函数
公式: decay(d; h) = exp(-ln(2) * d / h)
解释:当 d=h 时权重衰减到 0.5;越近权重越大,符合"近期更重要"的直觉
Args:
days: 事件距今天数 (d >= 0)
halflife: 半衰期 (h > 0),单位:天
Returns:
衰减后的权重,范围 (0, 1]
Examples:
>>> decay(0, 7) # 今天,权重=1.0
1.0
        >>> decay(7, 7)  # 7天前(半衰期=7),权重=0.5
0.5
        >>> decay(14, 7) # 14天前,权重=0.25
0.25
"""
if halflife <= 0:
raise ValueError("半衰期必须大于0")
if days < 0:
days = 0
return math.exp(-math.log(2) * days / halflife)
# ==========================================================================
# 分位数计算
# ==========================================================================
def calculate_percentiles(
self,
scores: List[float],
lower: int = 5,
upper: int = 95
) -> Tuple[float, float]:
"""
计算分位点
Args:
scores: 分数列表
            lower: 下分位点百分比(默认5)
            upper: 上分位点百分比(默认95)
Returns:
(下分位值, 上分位值) 元组
"""
if not scores:
return 0.0, 0.0
sorted_scores = sorted(scores)
n = len(sorted_scores)
# 计算分位点索引
lower_idx = max(0, int(n * lower / 100) - 1)
upper_idx = min(n - 1, int(n * upper / 100))
return sorted_scores[lower_idx], sorted_scores[upper_idx]
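    # 取分位示例(仅作说明):n=100、lower=5、upper=95 时,
    # lower_idx = max(0, int(100*5/100) - 1) = 4,upper_idx = min(99, int(100*95/100)) = 95,
    # 即取排序后第 5 个与第 96 个值作为上下分位(近似分位,不做线性插值)。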
def winsorize(self, value: float, lower: float, upper: float) -> float:
"""
分位截断Winsorize
将值限制在 [lower, upper] 范围内
Args:
value: 原始值
            lower: 下限(P5分位)
            upper: 上限(P95分位)
Returns:
截断后的值
"""
return min(max(value, lower), upper)
# ==========================================================================
# 0-10映射
# ==========================================================================
def normalize_to_display(
self,
value: float,
min_val: float,
max_val: float,
use_log: bool = False,
compression: Optional[str] = None,
epsilon: float = 1e-6
) -> float:
"""
归一化到0-10分
映射流程:
        1. [可选] 压缩:y = ln(1 + x) 或 asinh(x)
        2. MinMax映射:score = 10 * (y - min) / (max - min)
Args:
            value: 原始值(已Winsorize)
            min_val: 最小值(通常为P5)
            max_val: 最大值(通常为P95)
            use_log: 是否使用log1p压缩(兼容历史参数)
            compression: 压缩方式(none/log1p/asinh),优先级高于use_log
epsilon: 防除零小量
Returns:
0-10范围的分数
"""
compression_mode = self._resolve_compression(compression, use_log)
if compression_mode == "log1p":
value = math.log1p(value)
min_val = math.log1p(min_val)
max_val = math.log1p(max_val)
elif compression_mode == "asinh":
value = math.asinh(value)
min_val = math.asinh(min_val)
max_val = math.asinh(max_val)
# 防止分母为0
range_val = max_val - min_val
if range_val < epsilon:
return 5.0 # 几乎全员相同时返回中间值
score = 10.0 * (value - min_val) / range_val
# 确保在0-10范围内
return max(0.0, min(10.0, score))
def batch_normalize_to_display(
self,
raw_scores: List[Tuple[Any, float]], # [(entity_id, raw_score), ...]
use_log: bool = False,
compression: Optional[str] = None,
percentile_lower: int = 5,
percentile_upper: int = 95,
use_smoothing: bool = False,
site_id: Optional[int] = None,
index_type: Optional[str] = None,
) -> List[Tuple[Any, float, float]]:
"""
批量归一化Raw Score到Display Score
流程:
1. 提取所有raw_score
        2. 计算分位点(可选EWMA平滑)
3. Winsorize截断
4. MinMax映射到0-10
Args:
raw_scores: (entity_id, raw_score) 元组列表
            use_log: 是否使用log1p压缩(兼容历史参数)
            compression: 压缩方式(none/log1p/asinh),优先级高于use_log
            percentile_lower: 下分位百分比
            percentile_upper: 上分位百分比
            use_smoothing: 是否使用EWMA平滑分位点
            site_id: 门店ID(平滑时需要)
index_type: 指数类型(平滑时用于分位历史隔离)
Returns:
(entity_id, raw_score, display_score) 元组列表
"""
if not raw_scores:
return []
# 提取raw_score
scores = [s for _, s in raw_scores]
# 计算分位点
q_l, q_u = self.calculate_percentiles(scores, percentile_lower, percentile_upper)
# EWMA平滑
if use_smoothing and site_id is not None:
q_l, q_u = self._apply_ewma_smoothing(
site_id=site_id,
current_p5=q_l,
current_p95=q_u,
index_type=index_type,
)
# 映射
results = []
compression_mode = self._resolve_compression(compression, use_log)
for entity_id, raw_score in raw_scores:
clipped = self.winsorize(raw_score, q_l, q_u)
display = self.normalize_to_display(
clipped,
q_l,
q_u,
compression=compression_mode,
)
results.append((entity_id, raw_score, round(display, 2)))
return results
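    # 用法示意(仅作演示,entity_id 为虚构):输入 [(101, 3.2), (102, 8.8)]、不压缩、不平滑时,
    # 两个样本的分位点即 3.2/8.8,Winsorize 后数值不变,
    # 输出 [(101, 3.2, 0.0), (102, 8.8, 10.0)]。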
# ==========================================================================
# 算法参数加载
# ==========================================================================
def load_index_parameters(
self,
index_type: Optional[str] = None,
force_reload: bool = False
) -> Dict[str, float]:
"""
加载指数算法参数
Args:
            index_type: 指数类型(默认使用子类定义的INDEX_TYPE)
force_reload: 是否强制重新加载
Returns:
参数名到参数值的字典
"""
if index_type is None:
index_type = self.get_index_type()
now = datetime.now(self.tz)
cache_key = str(index_type).upper()
cache_item = self._index_params_cache_by_type.get(cache_key)
# 检查缓存
if (
not force_reload
and cache_item is not None
and (now - cache_item.loaded_at).total_seconds() < self._index_params_ttl
):
return cache_item.params
self.logger.debug("加载指数算法参数: %s", index_type)
sql = """
SELECT param_name, param_value
FROM billiards_dws.cfg_index_parameters
WHERE index_type = %s
AND effective_from <= CURRENT_DATE
AND (effective_to IS NULL OR effective_to >= CURRENT_DATE)
ORDER BY effective_from DESC
"""
rows = self.db.query(sql, (index_type,))
params = {}
seen = set()
for row in (rows or []):
row_dict = dict(row)
name = row_dict['param_name']
if name not in seen:
params[name] = float(row_dict['param_value'])
seen.add(name)
self._index_params_cache_by_type[cache_key] = IndexParameters(
params=params,
loaded_at=now
)
return params
def get_param(
self,
name: str,
default: float = 0.0,
index_type: Optional[str] = None,
) -> float:
"""
获取单个参数值
Args:
name: 参数名
default: 默认值
Returns:
参数值
"""
params = self.load_index_parameters(index_type=index_type)
return params.get(name, default)
# ==========================================================================
    # 分位点历史管理(EWMA平滑)
# ==========================================================================
def get_last_percentile_history(
self,
site_id: int,
index_type: Optional[str] = None
) -> Optional[PercentileHistory]:
"""
获取最近一次分位点历史
Args:
site_id: 门店ID
index_type: 指数类型
Returns:
PercentileHistory 或 None
"""
if index_type is None:
index_type = self.get_index_type()
sql = """
SELECT
percentile_5, percentile_95,
percentile_5_smoothed, percentile_95_smoothed,
record_count, calc_time
FROM billiards_dws.dws_index_percentile_history
WHERE site_id = %s AND index_type = %s
ORDER BY calc_time DESC
LIMIT 1
"""
rows = self.db.query(sql, (site_id, index_type))
if not rows:
return None
row = dict(rows[0])
return PercentileHistory(
percentile_5=float(row['percentile_5'] or 0),
percentile_95=float(row['percentile_95'] or 0),
percentile_5_smoothed=float(row['percentile_5_smoothed'] or 0),
percentile_95_smoothed=float(row['percentile_95_smoothed'] or 0),
record_count=int(row['record_count'] or 0),
calc_time=row['calc_time']
)
def save_percentile_history(
self,
site_id: int,
percentile_5: float,
percentile_95: float,
percentile_5_smoothed: float,
percentile_95_smoothed: float,
record_count: int,
min_raw: float,
max_raw: float,
avg_raw: float,
index_type: Optional[str] = None
) -> None:
"""
保存分位点历史
Args:
site_id: 门店ID
percentile_5: 原始5分位
percentile_95: 原始95分位
percentile_5_smoothed: 平滑后5分位
percentile_95_smoothed: 平滑后95分位
record_count: 记录数
min_raw: 最小Raw Score
max_raw: 最大Raw Score
avg_raw: 平均Raw Score
index_type: 指数类型
"""
if index_type is None:
index_type = self.get_index_type()
sql = """
INSERT INTO billiards_dws.dws_index_percentile_history (
site_id, index_type, calc_time,
percentile_5, percentile_95,
percentile_5_smoothed, percentile_95_smoothed,
record_count, min_raw_score, max_raw_score, avg_raw_score
) VALUES (%s, %s, NOW(), %s, %s, %s, %s, %s, %s, %s, %s)
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (
site_id, index_type,
percentile_5, percentile_95,
percentile_5_smoothed, percentile_95_smoothed,
record_count, min_raw, max_raw, avg_raw
))
self.db.conn.commit()
def _apply_ewma_smoothing(
self,
site_id: int,
current_p5: float,
current_p95: float,
alpha: Optional[float] = None,
index_type: Optional[str] = None,
) -> Tuple[float, float]:
"""
应用EWMA平滑到分位点
公式: Q_t = (1 - α) * Q_{t-1} + α * Q_now
Args:
site_id: 门店ID
current_p5: 当前5分位
current_p95: 当前95分位
            alpha: 平滑系数(默认0.2)
index_type: 指数类型(用于参数和历史隔离)
Returns:
(平滑后的P5, 平滑后的P95)
"""
if index_type is None:
index_type = self.get_index_type()
if alpha is None:
alpha = self.get_param(
'ewma_alpha',
self.DEFAULT_EWMA_ALPHA,
index_type=index_type,
)
history = self.get_last_percentile_history(site_id, index_type=index_type)
if history is None:
# 首次计算,不平滑
return current_p5, current_p95
smoothed_p5 = (1 - alpha) * history.percentile_5_smoothed + alpha * current_p5
smoothed_p95 = (1 - alpha) * history.percentile_95_smoothed + alpha * current_p95
return smoothed_p5, smoothed_p95
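    # 平滑示例(仅作说明):α=0.2,上次平滑后 P95=100,本次样本 P95=120,
    # 则新的平滑值 = 0.8 * 100 + 0.2 * 120 = 104,分位点波动被逐步吸收而非立即跳变。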
# ==========================================================================
# 统计工具方法
# ==========================================================================
def calculate_median(self, values: List[float]) -> float:
"""计算中位数"""
if not values:
return 0.0
sorted_vals = sorted(values)
n = len(sorted_vals)
mid = n // 2
if n % 2 == 0:
return (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
return sorted_vals[mid]
def calculate_mad(self, values: List[float]) -> float:
"""
        计算MAD(中位绝对偏差)
        MAD = median(|x - median(x)|)
        MAD是比标准差更稳健的离散度度量,不受极端值影响
"""
if not values:
return 0.0
median_val = self.calculate_median(values)
deviations = [abs(v - median_val) for v in values]
return self.calculate_median(deviations)
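    # 数值示例(仅作说明):values = [1, 2, 4, 10] 时,
    # 中位数 = (2 + 4) / 2 = 3,偏差 = [2, 1, 1, 7],MAD = median([1, 1, 2, 7]) = 1.5。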
def safe_log(self, value: float, default: float = 0.0) -> float:
"""安全的对数运算"""
if value <= 0:
return default
return math.log(value)
def safe_ln1p(self, value: float) -> float:
"""安全的ln(1+x)运算"""
if value < -1:
return 0.0
return math.log1p(value)
def _resolve_compression(self, compression: Optional[str], use_log: bool) -> str:
"""规范化压缩方式"""
if compression is None:
return "log1p" if use_log else "none"
compression_key = str(compression).strip().lower()
if compression_key in ("none", "log1p", "asinh"):
return compression_key
if hasattr(self, "logger"):
self.logger.warning("未知压缩方式: %s,已降级为 none", compression)
return "none"

View File

@@ -0,0 +1,461 @@
# -*- coding: utf-8 -*-
"""
会员层召回/转化指数共享逻辑
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext
@dataclass
class MemberActivityData:
"""Shared member activity features for WBI/NCI."""
member_id: int
site_id: int
tenant_id: int
member_create_time: Optional[datetime] = None
first_visit_time: Optional[datetime] = None
last_visit_time: Optional[datetime] = None
last_recharge_time: Optional[datetime] = None
t_v: float = 60.0
t_r: float = 60.0
t_a: float = 60.0
days_since_first_visit: Optional[int] = None
days_since_last_visit: Optional[int] = None
days_since_last_recharge: Optional[int] = None
visits_14d: int = 0
visits_60d: int = 0
visits_total: int = 0
spend_30d: float = 0.0
spend_180d: float = 0.0
sv_balance: float = 0.0
recharge_60d_amt: float = 0.0
interval_count: int = 0
intervals: List[float] = field(default_factory=list)
interval_ages_days: List[int] = field(default_factory=list)
recharge_unconsumed: int = 0
class MemberIndexBaseTask(BaseIndexTask):
"""Shared extraction and feature building for WBI/NCI."""
DEFAULT_VISIT_LOOKBACK_DAYS = 180
DEFAULT_RECENCY_LOOKBACK_DAYS = 60
CASH_CARD_TYPE_ID = 2793249295533893
def _get_site_id(self, context: Optional[TaskContext]) -> int:
"""获取门店ID"""
if context and hasattr(context, 'store_id') and context.store_id:
return context.store_id
site_id = self.config.get('app.default_site_id') or self.config.get('app.store_id')
if site_id is not None:
return int(site_id)
sql = "SELECT DISTINCT site_id FROM billiards_dwd.dwd_settlement_head WHERE site_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
value = dict(rows[0]).get('site_id')
if value is not None:
return int(value)
self.logger.warning("无法确定门店ID使用 0 继续执行")
return 0
def _get_tenant_id(self) -> int:
"""获取租户ID"""
tenant_id = self.config.get('app.tenant_id')
if tenant_id is not None:
return int(tenant_id)
sql = "SELECT DISTINCT tenant_id FROM billiards_dwd.dwd_settlement_head WHERE tenant_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
value = dict(rows[0]).get('tenant_id')
if value is not None:
return int(value)
self.logger.warning("无法确定租户ID使用 0 继续执行")
return 0
def _load_params(self) -> Dict[str, float]:
"""Load index parameters with defaults and runtime overrides."""
params = self.load_index_parameters()
result = dict(self.DEFAULT_PARAMS)
result.update(params)
# GUI/环境变量可通过 run.index_lookback_days 覆盖 recency 窗口
override_days = self.config.get('run.index_lookback_days')
if override_days is not None:
try:
override_days_int = int(override_days)
if override_days_int < 7 or override_days_int > 180:
self.logger.warning(
"%s: run.index_lookback_days=%s 超出建议范围[7,180],已自动截断",
self.get_task_code(),
override_days,
)
override_days_int = max(7, min(180, override_days_int))
result['lookback_days_recency'] = float(override_days_int)
self.logger.info(
"%s: 使用回溯天数覆盖 lookback_days_recency=%d",
self.get_task_code(),
override_days_int,
)
except (TypeError, ValueError):
self.logger.warning(
"%s: run.index_lookback_days=%s is invalid; ignore override and use parameter table value",
self.get_task_code(),
override_days,
)
return result
def _build_visit_condition_sql(self) -> str:
"""Build visit-scope condition SQL."""
return """
(
s.settle_type = 1
OR (
s.settle_type = 3
AND EXISTS (
SELECT 1
FROM billiards_dwd.dwd_assistant_service_log asl
JOIN billiards_dws.cfg_skill_type st
ON asl.skill_id = st.skill_id
AND st.course_type_code = 'BONUS'
AND st.is_active = TRUE
WHERE asl.order_settle_id = s.order_settle_id
AND asl.site_id = s.site_id
AND asl.tenant_member_id = s.member_id
AND asl.is_delete = 0
)
)
)
"""
def _extract_visit_day_rows(
self,
site_id: int,
start_date: date,
end_date: date,
) -> List[Dict[str, Any]]:
"""提取到店记录(按天去重)"""
condition_sql = self._build_visit_condition_sql()
sql = f"""
WITH visit_source AS (
SELECT
COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) AS canonical_member_id,
s.pay_time,
s.pay_amount
FROM billiards_dwd.dwd_settlement_head s
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON s.member_card_account_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = s.site_id
AND COALESCE(mca.is_delete, 0) = 0
WHERE s.site_id = %s
AND s.pay_time >= %s
AND s.pay_time < %s + INTERVAL '1 day'
AND {condition_sql}
)
SELECT
canonical_member_id AS member_id,
DATE(pay_time) AS visit_date,
MAX(pay_time) AS last_visit_time,
SUM(COALESCE(pay_amount, 0)) AS day_pay_amount
FROM visit_source
WHERE canonical_member_id > 0
GROUP BY canonical_member_id, DATE(pay_time)
ORDER BY canonical_member_id, visit_date
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in (rows or [])]
def _extract_recharge_rows(
self,
site_id: int,
start_date: date,
end_date: date,
) -> Dict[int, Dict[str, Any]]:
"""提取充值记录近60天"""
sql = """
WITH recharge_source AS (
SELECT
COALESCE(NULLIF(r.member_id, 0), mca.tenant_member_id) AS canonical_member_id,
r.pay_time,
r.pay_amount
FROM billiards_dwd.dwd_recharge_order r
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON r.tenant_member_card_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = r.site_id
AND COALESCE(mca.is_delete, 0) = 0
WHERE r.site_id = %s
AND r.settle_type = 5
AND r.pay_time >= %s
AND r.pay_time < %s + INTERVAL '1 day'
)
SELECT
canonical_member_id AS member_id,
MAX(pay_time) AS last_recharge_time,
SUM(COALESCE(pay_amount, 0)) AS recharge_60d_amt
FROM recharge_source
WHERE canonical_member_id > 0
GROUP BY canonical_member_id
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
result: Dict[int, Dict[str, Any]] = {}
for row in (rows or []):
row_dict = dict(row)
result[int(row_dict['member_id'])] = row_dict
return result
def _extract_member_create_times(self, member_ids: List[int]) -> Dict[int, datetime]:
"""提取会员建档时间"""
if not member_ids:
return {}
member_ids_str = ','.join(str(m) for m in member_ids)
sql = f"""
SELECT
member_id,
create_time
FROM billiards_dwd.dim_member
WHERE member_id IN ({member_ids_str})
AND scd2_is_current = 1
"""
rows = self.db.query(sql)
result = {}
for row in (rows or []):
row_dict = dict(row)
member_id = int(row_dict['member_id'])
create_time = row_dict.get('create_time')
if create_time:
result[member_id] = create_time
return result
def _extract_first_visit_times(self, site_id: int, member_ids: List[int]) -> Dict[int, datetime]:
"""提取首次到店时间(全量)"""
if not member_ids:
return {}
member_ids_str = ','.join(str(m) for m in member_ids)
condition_sql = self._build_visit_condition_sql()
sql = f"""
WITH visit_source AS (
SELECT
COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) AS canonical_member_id,
s.pay_time
FROM billiards_dwd.dwd_settlement_head s
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON s.member_card_account_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = s.site_id
AND COALESCE(mca.is_delete, 0) = 0
WHERE s.site_id = %s
AND {condition_sql}
)
SELECT
canonical_member_id AS member_id,
MIN(pay_time) AS first_visit_time
FROM visit_source
WHERE canonical_member_id IN ({member_ids_str})
GROUP BY canonical_member_id
"""
rows = self.db.query(sql, (site_id,))
result = {}
for row in (rows or []):
row_dict = dict(row)
member_id = int(row_dict['member_id'])
first_visit_time = row_dict.get('first_visit_time')
if first_visit_time:
result[member_id] = first_visit_time
return result
def _extract_sv_balances(self, site_id: int, tenant_id: int, member_ids: List[int]) -> Dict[int, Decimal]:
"""Fetch member stored-value card balances."""
if not member_ids:
return {}
member_ids_str = ','.join(str(m) for m in member_ids)
sql = f"""
SELECT
tenant_member_id AS member_id,
SUM(CASE WHEN card_type_id = %s THEN balance ELSE 0 END) AS sv_balance
FROM billiards_dwd.dim_member_card_account
WHERE tenant_id = %s
AND register_site_id = %s
AND scd2_is_current = 1
AND COALESCE(is_delete, 0) = 0
AND tenant_member_id IN ({member_ids_str})
GROUP BY tenant_member_id
"""
rows = self.db.query(sql, (self.CASH_CARD_TYPE_ID, tenant_id, site_id))
result: Dict[int, Decimal] = {}
for row in (rows or []):
row_dict = dict(row)
member_id = int(row_dict['member_id'])
result[member_id] = row_dict.get('sv_balance') or Decimal('0')
return result
def _build_member_activity(
self,
site_id: int,
tenant_id: int,
params: Dict[str, float],
) -> Dict[int, MemberActivityData]:
"""构建会员活动特征"""
now = datetime.now(self.tz)
base_date = now.date()
visit_lookback_days = int(params.get('visit_lookback_days', self.DEFAULT_VISIT_LOOKBACK_DAYS))
recency_days = int(params.get('lookback_days_recency', self.DEFAULT_RECENCY_LOOKBACK_DAYS))
visit_start_date = base_date - timedelta(days=visit_lookback_days)
visit_rows = self._extract_visit_day_rows(site_id, visit_start_date, base_date)
member_day_rows: Dict[int, List[Dict[str, Any]]] = {}
for row in (visit_rows or []):
member_id = int(row['member_id'])
member_day_rows.setdefault(member_id, []).append(row)
recharge_start_date = base_date - timedelta(days=recency_days)
recharge_rows = self._extract_recharge_rows(site_id, recharge_start_date, base_date)
member_ids = set(member_day_rows.keys()) | set(recharge_rows.keys())
if not member_ids:
return {}
member_id_list = list(member_ids)
member_create_times = self._extract_member_create_times(member_id_list)
first_visit_times = self._extract_first_visit_times(site_id, member_id_list)
sv_balances = self._extract_sv_balances(site_id, tenant_id, member_id_list)
results: Dict[int, MemberActivityData] = {}
for member_id in member_ids:
data = MemberActivityData(
member_id=member_id,
site_id=site_id,
tenant_id=tenant_id,
)
day_rows = member_day_rows.get(member_id, [])
if day_rows:
day_rows_sorted = sorted(day_rows, key=lambda x: x['visit_date'])
data.visits_total = len(day_rows_sorted)
last_visit_time = max(r.get('last_visit_time') for r in day_rows_sorted)
data.last_visit_time = last_visit_time
# 近14/60天到店次数
days_14_ago = base_date - timedelta(days=14)
days_60_ago = base_date - timedelta(days=60)
for r in day_rows_sorted:
visit_date = r.get('visit_date')
if visit_date is None:
continue
if visit_date >= days_14_ago:
data.visits_14d += 1
if visit_date >= days_60_ago:
data.visits_60d += 1
# 消费金额
days_30_ago = base_date - timedelta(days=30)
for r in day_rows_sorted:
visit_date = r.get('visit_date')
day_pay = float(r.get('day_pay_amount') or 0)
data.spend_180d += day_pay
if visit_date and visit_date >= days_30_ago:
data.spend_30d += day_pay
# 计算到店间隔(按天)
visit_dates = [r.get('visit_date') for r in day_rows_sorted if r.get('visit_date')]
intervals: List[float] = []
interval_ages_days: List[int] = []
for i in range(1, len(visit_dates)):
interval = (visit_dates[i] - visit_dates[i - 1]).days
intervals.append(float(min(recency_days, interval)))
interval_ages_days.append(max(0, (base_date - visit_dates[i]).days))
data.intervals = intervals
data.interval_ages_days = interval_ages_days
data.interval_count = len(intervals)
recharge_info = recharge_rows.get(member_id)
if recharge_info:
data.last_recharge_time = recharge_info.get('last_recharge_time')
data.recharge_60d_amt = float(recharge_info.get('recharge_60d_amt') or 0)
data.member_create_time = member_create_times.get(member_id)
data.first_visit_time = first_visit_times.get(member_id)
sv_balance = sv_balances.get(member_id)
if sv_balance is not None:
data.sv_balance = float(sv_balance)
# 时间差计算
if data.first_visit_time:
data.days_since_first_visit = (base_date - data.first_visit_time.date()).days
if data.last_visit_time:
data.days_since_last_visit = (base_date - data.last_visit_time.date()).days
if data.last_recharge_time:
data.days_since_last_recharge = (base_date - data.last_recharge_time.date()).days
# tV/tR/tA
data.t_v = float(min(recency_days, data.days_since_last_visit)) if data.days_since_last_visit is not None else float(recency_days)
data.t_r = float(min(recency_days, data.days_since_last_recharge)) if data.days_since_last_recharge is not None else float(recency_days)
data.t_a = float(min(data.t_v, data.t_r))
# 充值是否未回访
if data.last_recharge_time and (data.last_visit_time is None or data.last_recharge_time > data.last_visit_time):
data.recharge_unconsumed = 1
results[member_id] = data
return results
def classify_segment(
self,
data: MemberActivityData,
params: Dict[str, float],
) -> Tuple[str, str, bool]:
"""Classify member into NEW/OLD/STOP buckets."""
recency_days = int(params.get('lookback_days_recency', self.DEFAULT_RECENCY_LOOKBACK_DAYS))
enable_stop_exception = int(params.get('enable_stop_high_balance_exception', 0)) == 1
high_balance_threshold = float(params.get('high_balance_threshold', 1000))
if data.t_a >= recency_days:
if enable_stop_exception and data.sv_balance >= high_balance_threshold:
return "STOP", "STOP_HIGH_BALANCE", True
return "STOP", "STOP", False
new_visit_threshold = int(params.get('new_visit_threshold', 2))
new_days_threshold = int(params.get('new_days_threshold', 30))
recharge_recent_days = int(params.get('recharge_recent_days', 14))
new_recharge_max_visits = int(params.get('new_recharge_max_visits', 10))
is_new_by_visits = data.visits_total <= new_visit_threshold
is_new_by_first_visit = data.days_since_first_visit is not None and data.days_since_first_visit <= new_days_threshold
is_new_by_recharge = (
data.recharge_unconsumed == 1
and data.days_since_last_recharge is not None
and data.days_since_last_recharge <= recharge_recent_days
and data.visits_total <= new_recharge_max_visits
)
if is_new_by_visits or is_new_by_first_visit or is_new_by_recharge:
return "NEW", "NEW", True
return "OLD", "OLD", True

View File

@@ -0,0 +1,623 @@
# -*- coding: utf-8 -*-
"""
ML 人工台账导入任务。
设计目标:
1. 人工台账作为 ML 唯一真源;
2. 同一订单支持多助教归因,默认均分;
3. 覆盖策略:
- 近 30 天:按 site_id + biz_date 日覆盖;
   - 超过 30 天:按固定纪元(2026-01-01)切 30 天批次覆盖。
"""
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext
@dataclass(frozen=True)
class ImportScope:
"""导入覆盖范围定义。"""
site_id: int
scope_type: str # DAY / P30
start_date: date
end_date: date
@property
def scope_key(self) -> str:
if self.scope_type == "DAY":
return f"DAY:{self.site_id}:{self.start_date.isoformat()}"
return (
f"P30:{self.site_id}:{self.start_date.isoformat()}:{self.end_date.isoformat()}"
)
class MlManualImportTask(BaseIndexTask):
"""导入并拆分 ML 人工台账(订单宽表 + 助教分摊窄表)。"""
INDEX_TYPE = "ML"
EPOCH_ANCHOR = date(2026, 1, 1)
HISTORICAL_BUCKET_DAYS = 30
ASSISTANT_SLOT_COUNT = 5
# Excel 模板字段(按列顺序)
TEMPLATE_COLUMNS = [
"site_id",
"biz_date",
"external_id",
"member_id",
"pay_time",
"order_amount",
"currency",
"assistant_id_1",
"assistant_name_1",
"assistant_id_2",
"assistant_name_2",
"assistant_id_3",
"assistant_name_3",
"assistant_id_4",
"assistant_name_4",
"assistant_id_5",
"assistant_name_5",
"remark",
]
def get_task_code(self) -> str:
return "DWS_ML_MANUAL_IMPORT"
def get_target_table(self) -> str:
return "dws_ml_manual_order_source"
def get_primary_keys(self) -> List[str]:
return ["site_id", "external_id", "import_scope_key", "row_no"]
def get_index_type(self) -> str:
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
"""
执行导入。
说明:该任务按“文件”运行,不依赖时间窗口。调度器会以工具任务方式直接触发。
"""
file_path = self._resolve_file_path()
if not file_path:
raise ValueError(
"未找到 ML 台账文件,请通过环境变量 ML_MANUAL_LEDGER_FILE 或配置 run.ml_manual_ledger_file 指定"
)
rows = self._read_excel_rows(file_path)
if not rows:
self.logger.warning("台账文件为空:%s", file_path)
return {
"status": "SUCCESS",
"counts": {
"source_rows": 0,
"alloc_rows": 0,
"deleted_source_rows": 0,
"deleted_alloc_rows": 0,
"scopes": 0,
},
}
now = datetime.now(self.tz)
today = now.date()
import_batch_no = self._build_import_batch_no(now)
import_file_name = Path(file_path).name
import_user = self._resolve_import_user()
source_rows: List[Dict[str, Any]] = []
alloc_rows: List[Dict[str, Any]] = []
scope_set: Dict[Tuple[int, str, date, date], ImportScope] = {}
for idx, raw in enumerate(rows, start=2):
normalized = self._normalize_row(raw, row_no=idx, file_path=file_path)
row_scope = self.resolve_scope(
site_id=normalized["site_id"],
biz_date=normalized["biz_date"],
today=today,
)
scope_set[(row_scope.site_id, row_scope.scope_type, row_scope.start_date, row_scope.end_date)] = row_scope
source_row = self._build_source_row(
normalized=normalized,
scope=row_scope,
import_batch_no=import_batch_no,
import_file_name=import_file_name,
import_user=import_user,
import_time=now,
)
source_rows.append(source_row)
alloc_rows.extend(
self._build_alloc_rows(
normalized=normalized,
scope=row_scope,
import_batch_no=import_batch_no,
import_file_name=import_file_name,
import_user=import_user,
import_time=now,
)
)
scopes = list(scope_set.values())
deleted_source_rows, deleted_alloc_rows = self._delete_by_scopes(scopes)
inserted_source = self._insert_source_rows(source_rows)
upserted_alloc = self._upsert_alloc_rows(alloc_rows)
self.db.conn.commit()
self.logger.info(
"ML 人工台账导入完成: file=%s source=%d alloc=%d scopes=%d",
file_path,
inserted_source,
upserted_alloc,
len(scopes),
)
return {
"status": "SUCCESS",
"counts": {
"source_rows": inserted_source,
"alloc_rows": upserted_alloc,
"deleted_source_rows": deleted_source_rows,
"deleted_alloc_rows": deleted_alloc_rows,
"scopes": len(scopes),
},
}
def _resolve_file_path(self) -> Optional[str]:
"""解析台账文件路径。"""
raw_path = (
self.config.get("run.ml_manual_ledger_file")
or self.config.get("run.ml_manual_file")
or os.getenv("ML_MANUAL_LEDGER_FILE")
)
if not raw_path:
return None
candidate = Path(str(raw_path)).expanduser()
if not candidate.is_absolute():
candidate = Path.cwd() / candidate
if not candidate.exists():
raise FileNotFoundError(f"台账文件不存在: {candidate}")
return str(candidate)
def _read_excel_rows(self, file_path: str) -> List[Dict[str, Any]]:
"""读取 Excel 为行字典列表。"""
try:
from openpyxl import load_workbook
except Exception as exc: # noqa: BLE001
raise RuntimeError(
"缺少 openpyxl 依赖,无法读取 Excel请先安装 openpyxl"
) from exc
wb = load_workbook(file_path, data_only=True)
ws = wb.active
header_row = next(ws.iter_rows(min_row=1, max_row=1, values_only=True), None)
if not header_row:
return []
headers = [str(col).strip() if col is not None else "" for col in header_row]
if not headers:
return []
rows: List[Dict[str, Any]] = []
for values in ws.iter_rows(min_row=2, values_only=True):
if values is None:
continue
row_dict = {headers[i]: values[i] for i in range(min(len(headers), len(values)))}
if self._is_empty_row(row_dict):
continue
rows.append(row_dict)
return rows
@staticmethod
def _is_empty_row(row: Dict[str, Any]) -> bool:
for value in row.values():
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
return False
return True
def _normalize_row(
self,
raw: Dict[str, Any],
row_no: int,
file_path: str,
) -> Dict[str, Any]:
"""规范化单行字段。"""
site_id = self._to_int(raw.get("site_id"), fallback=self.config.get("app.store_id"))
biz_date = self._to_date(raw.get("biz_date"))
pay_time = self._to_datetime(raw.get("pay_time"), fallback_date=biz_date)
external_id = str(raw.get("external_id") or "").strip()
if not external_id:
raise ValueError(f"台账行 {row_no} 缺少 external_id订单ID: {file_path}")
member_id = self._to_int(raw.get("member_id"), fallback=0)
order_amount = self._to_decimal(raw.get("order_amount"))
currency = str(raw.get("currency") or "CNY").strip().upper() or "CNY"
remark = str(raw.get("remark") or "").strip()
assistants: List[Tuple[int, str]] = []
for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
aid = self._to_int(raw.get(f"assistant_id_{idx}"), fallback=None)
name = str(raw.get(f"assistant_name_{idx}") or "").strip()
if aid is None:
continue
assistants.append((aid, name))
return {
"site_id": site_id,
"biz_date": biz_date,
"external_id": external_id,
"member_id": member_id,
"pay_time": pay_time,
"order_amount": order_amount,
"currency": currency,
"assistants": assistants,
"remark": remark,
"row_no": row_no,
}
def _build_source_row(
self,
*,
normalized: Dict[str, Any],
scope: ImportScope,
import_batch_no: str,
import_file_name: str,
import_user: str,
import_time: datetime,
) -> Dict[str, Any]:
"""构造宽表入库行。"""
assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
row = {
"site_id": normalized["site_id"],
"biz_date": normalized["biz_date"],
"external_id": normalized["external_id"],
"member_id": normalized["member_id"],
"pay_time": normalized["pay_time"],
"order_amount": normalized["order_amount"],
"currency": normalized["currency"],
"import_batch_no": import_batch_no,
"import_file_name": import_file_name,
"import_scope_key": scope.scope_key,
"import_time": import_time,
"import_user": import_user,
"row_no": normalized["row_no"],
"remark": normalized["remark"],
}
for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
aid, aname = (assistants[idx - 1] if idx - 1 < len(assistants) else (None, None))
row[f"assistant_id_{idx}"] = aid
row[f"assistant_name_{idx}"] = aname
return row
def _build_alloc_rows(
self,
*,
normalized: Dict[str, Any],
scope: ImportScope,
import_batch_no: str,
import_file_name: str,
import_user: str,
import_time: datetime,
) -> List[Dict[str, Any]]:
"""构造窄表分摊行。"""
assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
if not assistants:
return []
n = Decimal(str(len(assistants)))
share_ratio = Decimal("1") / n
rows: List[Dict[str, Any]] = []
for assistant_id, assistant_name in assistants:
allocated_amount = normalized["order_amount"] * share_ratio
rows.append(
{
"site_id": normalized["site_id"],
"biz_date": normalized["biz_date"],
"external_id": normalized["external_id"],
"member_id": normalized["member_id"],
"pay_time": normalized["pay_time"],
"order_amount": normalized["order_amount"],
"assistant_id": assistant_id,
"assistant_name": assistant_name,
"share_ratio": share_ratio,
"allocated_amount": allocated_amount,
"currency": normalized["currency"],
"import_scope_key": scope.scope_key,
"import_batch_no": import_batch_no,
"import_file_name": import_file_name,
"import_time": import_time,
"import_user": import_user,
}
)
return rows
@classmethod
def resolve_scope(cls, site_id: int, biz_date: date, today: date) -> ImportScope:
"""按规则解析覆盖范围。"""
day_diff = (today - biz_date).days
if day_diff <= cls.HISTORICAL_BUCKET_DAYS:
return ImportScope(
site_id=site_id,
scope_type="DAY",
start_date=biz_date,
end_date=biz_date,
)
bucket_start, bucket_end = cls.resolve_p30_bucket(biz_date)
return ImportScope(
site_id=site_id,
scope_type="P30",
start_date=bucket_start,
end_date=bucket_end,
)
@classmethod
def resolve_p30_bucket(cls, biz_date: date) -> Tuple[date, date]:
"""固定纪元 30 天分桶。"""
delta_days = (biz_date - cls.EPOCH_ANCHOR).days
bucket_index = delta_days // cls.HISTORICAL_BUCKET_DAYS
bucket_start = cls.EPOCH_ANCHOR + timedelta(days=bucket_index * cls.HISTORICAL_BUCKET_DAYS)
bucket_end = bucket_start + timedelta(days=cls.HISTORICAL_BUCKET_DAYS - 1)
return bucket_start, bucket_end
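    # 分桶示例(仅作说明):biz_date=2026-03-15 距纪元 2026-01-01 共 73 天,
    # 73 // 30 = 2,落入第 3 个批次,覆盖区间为 [2026-03-02, 2026-03-31]。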
def _delete_by_scopes(self, scopes: Iterable[ImportScope]) -> Tuple[int, int]:
"""按 scope 先删后写,保证整批覆盖。"""
deleted_source = 0
deleted_alloc = 0
with self.db.conn.cursor() as cur:
for scope in scopes:
if scope.scope_type == "DAY":
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_source
WHERE site_id = %s AND biz_date = %s
""",
(scope.site_id, scope.start_date),
)
deleted_source += max(cur.rowcount, 0)
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_alloc
WHERE site_id = %s AND biz_date = %s
""",
(scope.site_id, scope.start_date),
)
deleted_alloc += max(cur.rowcount, 0)
else:
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_source
WHERE site_id = %s AND biz_date >= %s AND biz_date <= %s
""",
(scope.site_id, scope.start_date, scope.end_date),
)
deleted_source += max(cur.rowcount, 0)
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_alloc
WHERE site_id = %s AND biz_date >= %s AND biz_date <= %s
""",
(scope.site_id, scope.start_date, scope.end_date),
)
deleted_alloc += max(cur.rowcount, 0)
return deleted_source, deleted_alloc
def _insert_source_rows(self, rows: List[Dict[str, Any]]) -> int:
if not rows:
return 0
columns = [
"site_id",
"biz_date",
"external_id",
"member_id",
"pay_time",
"order_amount",
"currency",
"assistant_id_1",
"assistant_name_1",
"assistant_id_2",
"assistant_name_2",
"assistant_id_3",
"assistant_name_3",
"assistant_id_4",
"assistant_name_4",
"assistant_id_5",
"assistant_name_5",
"import_batch_no",
"import_file_name",
"import_scope_key",
"import_time",
"import_user",
"row_no",
"remark",
"created_at",
"updated_at",
]
sql = f"""
INSERT INTO billiards_dws.dws_ml_manual_order_source ({", ".join(columns)})
VALUES ({", ".join(["%s"] * len(columns))})
"""
inserted = 0
with self.db.conn.cursor() as cur:
for row in rows:
values = [
row.get("site_id"),
row.get("biz_date"),
row.get("external_id"),
row.get("member_id"),
row.get("pay_time"),
row.get("order_amount"),
row.get("currency"),
row.get("assistant_id_1"),
row.get("assistant_name_1"),
row.get("assistant_id_2"),
row.get("assistant_name_2"),
row.get("assistant_id_3"),
row.get("assistant_name_3"),
row.get("assistant_id_4"),
row.get("assistant_name_4"),
row.get("assistant_id_5"),
row.get("assistant_name_5"),
row.get("import_batch_no"),
row.get("import_file_name"),
row.get("import_scope_key"),
row.get("import_time"),
row.get("import_user"),
row.get("row_no"),
row.get("remark"),
row.get("import_time"),
row.get("import_time"),
]
cur.execute(sql, values)
inserted += max(cur.rowcount, 0)
return inserted
def _upsert_alloc_rows(self, rows: List[Dict[str, Any]]) -> int:
if not rows:
return 0
columns = [
"site_id",
"biz_date",
"external_id",
"member_id",
"pay_time",
"order_amount",
"assistant_id",
"assistant_name",
"share_ratio",
"allocated_amount",
"currency",
"import_scope_key",
"import_batch_no",
"import_file_name",
"import_time",
"import_user",
"created_at",
"updated_at",
]
sql = f"""
INSERT INTO billiards_dws.dws_ml_manual_order_alloc ({", ".join(columns)})
VALUES ({", ".join(["%s"] * len(columns))})
ON CONFLICT (site_id, external_id, assistant_id)
DO UPDATE SET
biz_date = EXCLUDED.biz_date,
member_id = EXCLUDED.member_id,
pay_time = EXCLUDED.pay_time,
order_amount = EXCLUDED.order_amount,
assistant_name = EXCLUDED.assistant_name,
share_ratio = EXCLUDED.share_ratio,
allocated_amount = EXCLUDED.allocated_amount,
currency = EXCLUDED.currency,
import_scope_key = EXCLUDED.import_scope_key,
import_batch_no = EXCLUDED.import_batch_no,
import_file_name = EXCLUDED.import_file_name,
import_time = EXCLUDED.import_time,
import_user = EXCLUDED.import_user,
updated_at = NOW()
"""
affected = 0
with self.db.conn.cursor() as cur:
for row in rows:
values = [
row.get("site_id"),
row.get("biz_date"),
row.get("external_id"),
row.get("member_id"),
row.get("pay_time"),
row.get("order_amount"),
row.get("assistant_id"),
row.get("assistant_name"),
row.get("share_ratio"),
row.get("allocated_amount"),
row.get("currency"),
row.get("import_scope_key"),
row.get("import_batch_no"),
row.get("import_file_name"),
row.get("import_time"),
row.get("import_user"),
row.get("import_time"),
row.get("import_time"),
]
cur.execute(sql, values)
affected += max(cur.rowcount, 0)
return affected
@staticmethod
def _to_int(value: Any, fallback: Optional[int] = None) -> Optional[int]:
if value is None:
return fallback
if isinstance(value, str) and not value.strip():
return fallback
try:
return int(value)
except Exception: # noqa: BLE001
return fallback
@staticmethod
def _to_decimal(value: Any) -> Decimal:
if value is None or value == "":
return Decimal("0")
return Decimal(str(value))
@staticmethod
def _to_date(value: Any) -> date:
if isinstance(value, datetime):
return value.date()
if isinstance(value, date):
return value
if isinstance(value, str):
text = value.strip()
if not text:
raise ValueError("biz_date 不能为空")
if len(text) >= 10:
return datetime.fromisoformat(text[:10]).date()
return datetime.fromisoformat(text).date()
raise ValueError(f"无法解析 biz_date: {value}")
@staticmethod
def _to_datetime(value: Any, fallback_date: date) -> datetime:
if isinstance(value, datetime):
return value
if isinstance(value, date):
return datetime.combine(value, datetime.min.time())
if isinstance(value, str):
text = value.strip()
if text:
text = text.replace("/", "-")
try:
return datetime.fromisoformat(text)
except Exception: # noqa: BLE001
if len(text) >= 19:
return datetime.strptime(text[:19], "%Y-%m-%d %H:%M:%S")
return datetime.fromisoformat(text[:10])
return datetime.combine(fallback_date, datetime.min.time())
@staticmethod
def _build_import_batch_no(now: datetime) -> str:
return f"MLM_{now.strftime('%Y%m%d%H%M%S')}_{str(uuid.uuid4())[:8]}"
@staticmethod
def _resolve_import_user() -> str:
return (
os.getenv("ETL_OPERATOR")
or os.getenv("USERNAME")
or os.getenv("USER")
or "system"
)
__all__ = ["MlManualImportTask", "ImportScope"]
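# 使用提示(仅作说明):运行前通过环境变量 ML_MANUAL_LEDGER_FILE 或配置项 run.ml_manual_ledger_file
# 指定台账 Excel 路径;调度器以工具任务方式直接触发本任务,同一 biz_date(近 30 天)
# 或同一 30 天历史批次会被整批先删后写覆盖。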

View File

@@ -0,0 +1,381 @@
# -*- coding: utf-8 -*-
"""
新客转化指数(NCI)计算任务。"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from .member_index_base import MemberActivityData, MemberIndexBaseTask
from ..base_dws_task import TaskContext
@dataclass
class MemberNewconvData:
activity: MemberActivityData
status: str
segment: str
need_new: float = 0.0
salvage_new: float = 0.0
recharge_new: float = 0.0
value_new: float = 0.0
welcome_new: float = 0.0
raw_score_welcome: Optional[float] = None
raw_score_convert: Optional[float] = None
raw_score: Optional[float] = None
display_score_welcome: Optional[float] = None
display_score_convert: Optional[float] = None
display_score: Optional[float] = None
class NewconvIndexTask(MemberIndexBaseTask):
"""新客转化指数NCI计算任务。"""
INDEX_TYPE = "NCI"
DEFAULT_PARAMS = {
# 通用参数
'lookback_days_recency': 60,
'visit_lookback_days': 180,
'percentile_lower': 5,
'percentile_upper': 95,
'compression_mode': 0,
'use_smoothing': 1,
'ewma_alpha': 0.2,
# 分流参数
'new_visit_threshold': 2,
'new_days_threshold': 30,
'recharge_recent_days': 14,
'new_recharge_max_visits': 10,
# NCI参数
'no_touch_days_new': 3,
't2_target_days': 7,
'salvage_start': 30,
'salvage_end': 60,
'welcome_window_days': 3,
'active_new_visit_threshold_14d': 2,
'active_new_recency_days': 7,
'active_new_penalty': 0.2,
'h_recharge': 7,
'amount_base_M0': 300,
'balance_base_B0': 500,
'value_w_spend': 1.0,
'value_w_bal': 0.8,
'w_welcome': 1.0,
'w_need': 1.6,
'w_re': 0.8,
'w_value': 1.0,
        # STOP高余额例外(默认关闭)
'enable_stop_high_balance_exception': 0,
'high_balance_threshold': 1000,
}
def get_task_code(self) -> str:
return "DWS_NEWCONV_INDEX"
def get_target_table(self) -> str:
return "dws_member_newconv_index"
def get_primary_keys(self) -> List[str]:
return ['site_id', 'member_id']
def get_index_type(self) -> str:
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
"""执行 NCI 计算"""
self.logger.info("开始计算新客转化指数(NCI)")
site_id = self._get_site_id(context)
tenant_id = self._get_tenant_id()
params = self._load_params()
activity_map = self._build_member_activity(site_id, tenant_id, params)
if not activity_map:
self.logger.warning("No member activity data available; skip calculation")
return {'status': 'skipped', 'reason': 'no_data'}
newconv_list: List[MemberNewconvData] = []
for activity in activity_map.values():
segment, status, in_scope = self.classify_segment(activity, params)
if not in_scope:
continue
if segment != "NEW":
continue
data = MemberNewconvData(activity=activity, status=status, segment=segment)
self._calculate_nci_scores(data, params)
newconv_list.append(data)
if not newconv_list:
self.logger.warning("No new-member rows to calculate")
return {'status': 'skipped', 'reason': 'no_new_members'}
# 归一化 Display Score
raw_scores = [
(d.activity.member_id, d.raw_score)
for d in newconv_list
if d.raw_score is not None
]
        use_smoothing = int(params.get('use_smoothing', 1)) == 1
        if raw_scores:
total_score_map = self._normalize_score_pairs(
raw_scores,
params=params,
site_id=site_id,
use_smoothing=use_smoothing,
)
for data in newconv_list:
if data.activity.member_id in total_score_map:
data.display_score = total_score_map[data.activity.member_id]
raw_scores_welcome = [
(d.activity.member_id, d.raw_score_welcome)
for d in newconv_list
if d.raw_score_welcome is not None
]
welcome_score_map = self._normalize_score_pairs(
raw_scores_welcome,
params=params,
site_id=site_id,
use_smoothing=False,
)
for data in newconv_list:
if data.activity.member_id in welcome_score_map:
data.display_score_welcome = welcome_score_map[data.activity.member_id]
raw_scores_convert = [
(d.activity.member_id, d.raw_score_convert)
for d in newconv_list
if d.raw_score_convert is not None
]
convert_score_map = self._normalize_score_pairs(
raw_scores_convert,
params=params,
site_id=site_id,
use_smoothing=False,
)
for data in newconv_list:
if data.activity.member_id in convert_score_map:
data.display_score_convert = convert_score_map[data.activity.member_id]
# 保存分位点历史
all_raw = [float(score) for _, score in raw_scores]
q_l, q_u = self.calculate_percentiles(
all_raw,
int(params['percentile_lower']),
int(params['percentile_upper'])
)
if use_smoothing:
smoothed_l, smoothed_u = self._apply_ewma_smoothing(site_id, q_l, q_u)
else:
smoothed_l, smoothed_u = q_l, q_u
self.save_percentile_history(
site_id=site_id,
percentile_5=q_l,
percentile_95=q_u,
percentile_5_smoothed=smoothed_l,
percentile_95_smoothed=smoothed_u,
record_count=len(all_raw),
min_raw=min(all_raw),
max_raw=max(all_raw),
avg_raw=sum(all_raw) / len(all_raw)
)
inserted = self._save_newconv_data(newconv_list)
self.logger.info("NCI calculation finished, inserted %d rows", inserted)
return {
'status': 'success',
'member_count': len(newconv_list),
'records_inserted': inserted
}
def _calculate_nci_scores(self, data: MemberNewconvData, params: Dict[str, float]) -> None:
"""计算 NCI 分项与 Raw Score"""
activity = data.activity
# 1) 紧迫度
no_touch_days = float(params['no_touch_days_new'])
t2_target_days = float(params['t2_target_days'])
t2_max_days = t2_target_days * 2.0
if t2_max_days <= no_touch_days:
data.need_new = 0.0
else:
data.need_new = self._clip(
(activity.t_v - no_touch_days) / (t2_max_days - no_touch_days),
0.0, 1.0
)
# 2) Salvage30-60天线性衰减
salvage_start = float(params['salvage_start'])
salvage_end = float(params['salvage_end'])
if salvage_end <= salvage_start:
data.salvage_new = 0.0
elif activity.t_a <= salvage_start:
data.salvage_new = 1.0
elif activity.t_a >= salvage_end:
data.salvage_new = 0.0
else:
data.salvage_new = (salvage_end - activity.t_a) / (salvage_end - salvage_start)
# 3) 充值未回访压力
if activity.recharge_unconsumed == 1:
data.recharge_new = self.decay(activity.t_r, params['h_recharge'])
else:
data.recharge_new = 0.0
# 4) 价值分
m0 = float(params['amount_base_M0'])
b0 = float(params['balance_base_B0'])
spend_score = math.log1p(activity.spend_180d / m0) if m0 > 0 else 0.0
bal_score = math.log1p(activity.sv_balance / b0) if b0 > 0 else 0.0
data.value_new = float(params['value_w_spend']) * spend_score + float(params['value_w_bal']) * bal_score
# 5) 欢迎建联分:优先首访后立即触达
welcome_window_days = float(params.get('welcome_window_days', 3))
data.welcome_new = 0.0
if welcome_window_days > 0 and activity.visits_total <= 1 and activity.t_v <= welcome_window_days:
data.welcome_new = self._clip(1.0 - (activity.t_v / welcome_window_days), 0.0, 1.0)
# 6) 抑制高活跃新客在转化召回排名中的权重
active_visit_threshold = int(params.get('active_new_visit_threshold_14d', 2))
active_recency_days = float(params.get('active_new_recency_days', 7))
active_penalty = float(params.get('active_new_penalty', 0.2))
if activity.visits_14d >= active_visit_threshold and activity.t_v <= active_recency_days:
active_multiplier = self._clip(active_penalty, 0.0, 1.0)
else:
active_multiplier = 1.0
# 7) 价值/充值分主要在进入免打扰窗口后生效
if no_touch_days > 0:
touch_multiplier = self._clip(activity.t_v / no_touch_days, 0.0, 1.0)
else:
touch_multiplier = 1.0
data.raw_score_welcome = float(params.get('w_welcome', 1.0)) * data.welcome_new
data.raw_score_convert = active_multiplier * (
float(params['w_need']) * (data.need_new * data.salvage_new)
+ float(params['w_re']) * data.recharge_new * touch_multiplier
+ float(params['w_value']) * data.value_new * touch_multiplier
)
data.raw_score_welcome = max(0.0, data.raw_score_welcome)
data.raw_score_convert = max(0.0, data.raw_score_convert)
data.raw_score = data.raw_score_welcome + data.raw_score_convert
if data.raw_score < 0:
data.raw_score = 0.0
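    # 数值示例(仅作说明,参数取默认值):近14天到店 1 次、累计到店 2 次、t_v=5、t_a=5、
    # spend_180d=600、无储值余额、无未回访充值时:
    #   need = (5-3)/(14-3) ≈ 0.182,salvage = 1.0,value = ln(1 + 600/300) ≈ 1.099,welcome = 0,
    #   raw_score ≈ 1.6*0.182 + 1.0*1.099 ≈ 1.39。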
def _save_newconv_data(self, data_list: List[MemberNewconvData]) -> int:
"""保存 NCI 数据"""
if not data_list:
return 0
site_id = data_list[0].activity.site_id
# 按门店全量刷新,避免因分群变化导致过期数据残留。
delete_sql = """
DELETE FROM billiards_dws.dws_member_newconv_index
WHERE site_id = %s
"""
with self.db.conn.cursor() as cur:
cur.execute(delete_sql, (site_id,))
insert_sql = """
INSERT INTO billiards_dws.dws_member_newconv_index (
site_id, tenant_id, member_id,
status, segment,
member_create_time, first_visit_time, last_visit_time, last_recharge_time,
t_v, t_r, t_a,
visits_14d, visits_60d, visits_total,
spend_30d, spend_180d, sv_balance, recharge_60d_amt,
interval_count,
need_new, salvage_new, recharge_new, value_new,
welcome_new,
raw_score_welcome, raw_score_convert, raw_score,
display_score_welcome, display_score_convert, display_score,
last_wechat_touch_time,
calc_time, created_at, updated_at
) VALUES (
%s, %s, %s,
%s, %s,
%s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s,
%s, %s, %s, %s,
%s,
%s, %s, %s,
%s, %s, %s,
%s,
NOW(), NOW(), NOW()
)
"""
inserted = 0
with self.db.conn.cursor() as cur:
for data in data_list:
activity = data.activity
cur.execute(insert_sql, (
activity.site_id, activity.tenant_id, activity.member_id,
data.status, data.segment,
activity.member_create_time, activity.first_visit_time, activity.last_visit_time, activity.last_recharge_time,
activity.t_v, activity.t_r, activity.t_a,
activity.visits_14d, activity.visits_60d, activity.visits_total,
activity.spend_30d, activity.spend_180d, activity.sv_balance, activity.recharge_60d_amt,
activity.interval_count,
data.need_new, data.salvage_new, data.recharge_new, data.value_new,
data.welcome_new,
data.raw_score_welcome, data.raw_score_convert, data.raw_score,
data.display_score_welcome, data.display_score_convert, data.display_score,
None,
))
inserted += cur.rowcount
self.db.conn.commit()
return inserted
def _clip(self, value: float, low: float, high: float) -> float:
return max(low, min(high, value))
def _map_compression(self, params: Dict[str, float]) -> str:
mode = int(params.get('compression_mode', 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
def _normalize_score_pairs(
self,
raw_scores: List[tuple[int, Optional[float]]],
params: Dict[str, float],
site_id: int,
use_smoothing: bool,
) -> Dict[int, float]:
valid_scores = [(member_id, float(score)) for member_id, score in raw_scores if score is not None]
if not valid_scores:
return {}
        # 全为0时直接返回,避免 MinMax 归一化退化
if all(abs(score) <= 1e-9 for _, score in valid_scores):
return {member_id: 0.0 for member_id, _ in valid_scores}
compression = self._map_compression(params)
normalized = self.batch_normalize_to_display(
valid_scores,
compression=compression,
percentile_lower=int(params['percentile_lower']),
percentile_upper=int(params['percentile_upper']),
use_smoothing=use_smoothing,
site_id=site_id
)
return {member_id: display for member_id, _, display in normalized}
__all__ = ['NewconvIndexTask']

View File

@@ -0,0 +1,695 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 删除 _apply_last_touch_ml 方法及 source_mode/recharge_attribute_hours 参数;
# 更新 docstring,移除 last-touch 备用路径描述;
# Prompt: "ML 只用人工台账,删除所有 last-touch 备用路径"
"""
关系指数任务(RS/OS/MS/ML)
设计说明:
1. 单任务一次产出 RS / OS / MS / ML,写入统一关系表;
2. RS/MS 复用服务日志 + 会话合并口径;
3. ML 以人工台账窄表为唯一真源;
4. RS/MS/ML 的 display 映射按 index_type 隔离分位历史。
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import CourseType, TaskContext
@dataclass
class ServiceSession:
"""合并后的服务会话。"""
session_start: datetime
session_end: datetime
total_duration_minutes: int
course_weight: float
is_incentive: bool
@dataclass
class RelationPairMetrics:
"""单个 member-assistant 关系对的计算指标。"""
site_id: int
tenant_id: int
member_id: int
assistant_id: int
sessions: List[ServiceSession] = field(default_factory=list)
days_since_last_session: Optional[int] = None
session_count: int = 0
total_duration_minutes: int = 0
basic_session_count: int = 0
incentive_session_count: int = 0
rs_f: float = 0.0
rs_d: float = 0.0
rs_r: float = 0.0
rs_raw: float = 0.0
rs_display: float = 0.0
ms_f_short: float = 0.0
ms_f_long: float = 0.0
ms_raw: float = 0.0
ms_display: float = 0.0
ml_raw: float = 0.0
ml_display: float = 0.0
ml_order_count: int = 0
ml_allocated_amount: float = 0.0
os_share: float = 0.0
os_label: str = "POOL"
os_rank: Optional[int] = None
class RelationIndexTask(BaseIndexTask):
"""关系指数任务:单任务产出 RS / OS / MS / ML。"""
INDEX_TYPE = "RS"
DEFAULT_PARAMS_RS: Dict[str, float] = {
"lookback_days": 60,
"session_merge_hours": 4,
"incentive_weight": 1.5,
"halflife_session": 14.0,
"halflife_last": 10.0,
"weight_f": 1.0,
"weight_d": 0.7,
"gate_alpha": 0.6,
"percentile_lower": 5.0,
"percentile_upper": 95.0,
"compression_mode": 1.0,
"use_smoothing": 1.0,
"ewma_alpha": 0.2,
}
DEFAULT_PARAMS_OS: Dict[str, float] = {
"min_rs_raw_for_ownership": 0.05,
"min_total_rs_raw": 0.10,
"ownership_main_threshold": 0.60,
"ownership_comanage_threshold": 0.35,
"ownership_gap_threshold": 0.15,
"eps": 1e-6,
}
DEFAULT_PARAMS_MS: Dict[str, float] = {
"lookback_days": 60,
"session_merge_hours": 4,
"incentive_weight": 1.5,
"halflife_short": 7.0,
"halflife_long": 30.0,
"eps": 1e-6,
"percentile_lower": 5.0,
"percentile_upper": 95.0,
"compression_mode": 1.0,
"use_smoothing": 1.0,
"ewma_alpha": 0.2,
}
# CHANGE 2026-02-13 | intent: ML 仅使用人工台账,移除 source_mode / recharge_attribute_hours
DEFAULT_PARAMS_ML: Dict[str, float] = {
"lookback_days": 60,
"amount_base": 500.0,
"halflife_recharge": 21.0,
"percentile_lower": 5.0,
"percentile_upper": 95.0,
"compression_mode": 1.0,
"use_smoothing": 1.0,
"ewma_alpha": 0.2,
}
def get_task_code(self) -> str:
return "DWS_RELATION_INDEX"
def get_target_table(self) -> str:
return "dws_member_assistant_relation_index"
def get_primary_keys(self) -> List[str]:
return ["site_id", "member_id", "assistant_id"]
def get_index_type(self) -> str:
        # 多指数任务保留一个默认 index_type,调用处应显式传 RS/MS/ML
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
self.logger.info("开始计算关系指数RS/OS/MS/ML")
site_id = self._get_site_id(context)
tenant_id = self._get_tenant_id()
now = datetime.now(self.tz)
params_rs = self._load_params("RS", self.DEFAULT_PARAMS_RS)
params_os = self._load_params("OS", self.DEFAULT_PARAMS_OS)
params_ms = self._load_params("MS", self.DEFAULT_PARAMS_MS)
params_ml = self._load_params("ML", self.DEFAULT_PARAMS_ML)
service_lookback_days = max(
int(params_rs.get("lookback_days", 60)),
int(params_ms.get("lookback_days", 60)),
)
service_start = now - timedelta(days=service_lookback_days)
merge_hours = max(
int(params_rs.get("session_merge_hours", 4)),
int(params_ms.get("session_merge_hours", 4)),
)
raw_services = self._extract_service_records(site_id, service_start, now)
pair_map = self._group_and_merge_sessions(
raw_services=raw_services,
merge_hours=merge_hours,
incentive_weight=max(
float(params_rs.get("incentive_weight", 1.5)),
float(params_ms.get("incentive_weight", 1.5)),
),
now=now,
site_id=site_id,
tenant_id=tenant_id,
)
self.logger.info("服务关系对数量: %d", len(pair_map))
self._calculate_rs(pair_map, params_rs, now)
self._calculate_ms(pair_map, params_ms, now)
self._calculate_ml(pair_map, params_ml, site_id, now)
self._calculate_os(pair_map, params_os)
self._apply_display_scores(pair_map, params_rs, params_ms, params_ml, site_id)
inserted = self._save_relation_rows(site_id, list(pair_map.values()))
self.logger.info("关系指数计算完成,写入 %d 条记录", inserted)
return {
"status": "SUCCESS",
"records_inserted": inserted,
"pair_count": len(pair_map),
}
def _load_params(self, index_type: str, defaults: Dict[str, float]) -> Dict[str, float]:
params = dict(defaults)
params.update(self.load_index_parameters(index_type=index_type))
return params
def _extract_service_records(
self,
site_id: int,
start_datetime: datetime,
end_datetime: datetime,
) -> List[Dict[str, Any]]:
"""提取服务记录。"""
sql = """
SELECT
s.tenant_member_id AS member_id,
d.assistant_id AS assistant_id,
s.start_use_time AS start_time,
s.last_use_time AS end_time,
COALESCE(s.income_seconds, 0) / 60 AS duration_minutes,
s.skill_id
FROM billiards_dwd.dwd_assistant_service_log s
JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id
AND d.scd2_is_current = 1
AND COALESCE(d.is_delete, 0) = 0
WHERE s.site_id = %s
AND s.tenant_member_id > 0
AND s.user_id > 0
AND s.is_delete = 0
AND s.last_use_time >= %s
AND s.last_use_time < %s
ORDER BY s.tenant_member_id, d.assistant_id, s.start_use_time
"""
rows = self.db.query(sql, (site_id, start_datetime, end_datetime))
return [dict(row) for row in (rows or [])]
def _group_and_merge_sessions(
self,
*,
raw_services: List[Dict[str, Any]],
merge_hours: int,
incentive_weight: float,
now: datetime,
site_id: int,
tenant_id: int,
) -> Dict[Tuple[int, int], RelationPairMetrics]:
"""按 (member_id, assistant_id) 分组并合并会话。"""
result: Dict[Tuple[int, int], RelationPairMetrics] = {}
if not raw_services:
return result
merge_threshold = timedelta(hours=max(0, merge_hours))
grouped: Dict[Tuple[int, int], List[Dict[str, Any]]] = {}
for row in raw_services:
member_id = int(row["member_id"])
assistant_id = int(row["assistant_id"])
grouped.setdefault((member_id, assistant_id), []).append(row)
for (member_id, assistant_id), records in grouped.items():
metrics = RelationPairMetrics(
site_id=site_id,
tenant_id=tenant_id,
member_id=member_id,
assistant_id=assistant_id,
)
sorted_records = sorted(records, key=lambda r: r["start_time"])
current: Optional[ServiceSession] = None
for svc in sorted_records:
start_time = svc["start_time"]
end_time = svc["end_time"]
duration = int(svc.get("duration_minutes") or 0)
skill_id = int(svc.get("skill_id") or 0)
course_type = self.get_course_type(skill_id)
is_incentive = course_type == CourseType.BONUS
weight = incentive_weight if is_incentive else 1.0
if current is None:
current = ServiceSession(
session_start=start_time,
session_end=end_time,
total_duration_minutes=duration,
course_weight=weight,
is_incentive=is_incentive,
)
continue
if start_time - current.session_end <= merge_threshold:
current.session_end = max(current.session_end, end_time)
current.total_duration_minutes += duration
current.course_weight = max(current.course_weight, weight)
current.is_incentive = current.is_incentive or is_incentive
else:
metrics.sessions.append(current)
current = ServiceSession(
session_start=start_time,
session_end=end_time,
total_duration_minutes=duration,
course_weight=weight,
is_incentive=is_incentive,
)
if current is not None:
metrics.sessions.append(current)
metrics.session_count = len(metrics.sessions)
metrics.total_duration_minutes = sum(s.total_duration_minutes for s in metrics.sessions)
metrics.basic_session_count = sum(1 for s in metrics.sessions if not s.is_incentive)
metrics.incentive_session_count = sum(1 for s in metrics.sessions if s.is_incentive)
if metrics.sessions:
last_session = max(metrics.sessions, key=lambda s: s.session_end)
metrics.days_since_last_session = (now - last_session.session_end).days
result[(member_id, assistant_id)] = metrics
return result
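# 会话合并示例(仅示意):同一 (member_id, assistant_id) 下,
# 18:00-19:00(60 分钟)与 21:30-22:10(40 分钟)两条服务记录间隔 2.5 小时 <= merge_hours=4,
# 合并为一个会话 18:00-22:10、时长 100 分钟;若其中一条为激励课(BONUS),
# 合并后 course_weight 取 max(1.0, incentive_weight)=1.5 且 is_incentive=True;
# 次日 20:00 的记录与上一会话间隔超过 4 小时,则另起新会话。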
def _calculate_rs(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
now: datetime,
) -> None:
lookback_days = int(params.get("lookback_days", 60))
halflife_session = float(params.get("halflife_session", 14.0))
halflife_last = float(params.get("halflife_last", 10.0))
weight_f = float(params.get("weight_f", 1.0))
weight_d = float(params.get("weight_d", 0.7))
gate_alpha = max(0.0, float(params.get("gate_alpha", 0.6)))
for metrics in pair_map.values():
f_score = 0.0
d_score = 0.0
for session in metrics.sessions:
days_ago = min(
lookback_days,
max(0.0, (now - session.session_end).total_seconds() / 86400.0),
)
decay_factor = self.decay(days_ago, halflife_session)
f_score += session.course_weight * decay_factor
d_score += (
math.sqrt(max(session.total_duration_minutes, 0) / 60.0)
* session.course_weight
* decay_factor
)
if metrics.days_since_last_session is None:
r_score = 0.0
else:
r_score = self.decay(min(lookback_days, metrics.days_since_last_session), halflife_last)
base = weight_f * f_score + weight_d * d_score
gate = math.pow(r_score, gate_alpha) if r_score > 0 else 0.0
metrics.rs_f = f_score
metrics.rs_d = d_score
metrics.rs_r = r_score
metrics.rs_raw = max(0.0, base * gate)
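# RS 公式数值示例(仅示意,decay 此处假定为半衰期指数衰减 0.5 ** (days / halflife),
# 以 BaseIndexTask.decay 的实际实现为准):某关系对仅 1 个会话,7 天前结束、120 分钟、权重 1.0,
# 默认参数下:
#   f = 1.0 * decay(7, 14) ≈ 0.71
#   d = sqrt(120/60) * 1.0 * decay(7, 14) ≈ 1.00
#   r = decay(7, 10) ≈ 0.62
#   base = 1.0*0.71 + 0.7*1.00 ≈ 1.41,gate = r^0.6 ≈ 0.75
#   rs_raw = base * gate ≈ 1.05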
def _calculate_ms(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
now: datetime,
) -> None:
lookback_days = int(params.get("lookback_days", 60))
halflife_short = float(params.get("halflife_short", 7.0))
halflife_long = float(params.get("halflife_long", 30.0))
eps = float(params.get("eps", 1e-6))
for metrics in pair_map.values():
f_short = 0.0
f_long = 0.0
for session in metrics.sessions:
days_ago = min(
lookback_days,
max(0.0, (now - session.session_end).total_seconds() / 86400.0),
)
f_short += session.course_weight * self.decay(days_ago, halflife_short)
f_long += session.course_weight * self.decay(days_ago, halflife_long)
ratio = (f_short + eps) / (f_long + eps)
metrics.ms_f_short = f_short
metrics.ms_f_long = f_long
metrics.ms_raw = max(0.0, self.safe_log(ratio, 0.0))
def _calculate_ml(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
site_id: int,
now: datetime,
) -> None:
lookback_days = int(params.get("lookback_days", 60))
amount_base = float(params.get("amount_base", 500.0))
halflife_recharge = float(params.get("halflife_recharge", 21.0))
start_time = now - timedelta(days=lookback_days)
# CHANGE 2026-02-13 | intent: ML 仅使用人工台账,移除 last-touch 备用路径
manual_rows = self._extract_manual_alloc(site_id, start_time, now)
for row in manual_rows:
member_id = int(row["member_id"])
assistant_id = int(row["assistant_id"])
amount = float(row.get("allocated_amount") or 0.0)
pay_time = row.get("pay_time")
# 先校验金额与支付时间,避免为无效台账行创建全零的关系对记录
if amount <= 0 or pay_time is None:
continue
key = (member_id, assistant_id)
if key not in pair_map:
pair_map[key] = RelationPairMetrics(
site_id=site_id,
tenant_id=next(iter(pair_map.values())).tenant_id if pair_map else self._get_tenant_id(),
member_id=member_id,
assistant_id=assistant_id,
)
metrics = pair_map[key]
days_ago = min(lookback_days, max(0.0, (now - pay_time).total_seconds() / 86400.0))
metrics.ml_raw += math.log1p(amount / max(amount_base, 1e-6)) * self.decay(
days_ago,
halflife_recharge,
)
metrics.ml_order_count += 1
metrics.ml_allocated_amount += amount
def _extract_manual_alloc(
self,
site_id: int,
start_time: datetime,
end_time: datetime,
) -> List[Dict[str, Any]]:
sql = """
SELECT
member_id,
assistant_id,
pay_time,
allocated_amount
FROM billiards_dws.dws_ml_manual_order_alloc
WHERE site_id = %s
AND pay_time >= %s
AND pay_time < %s
"""
rows = self.db.query(sql, (site_id, start_time, end_time))
return [dict(row) for row in (rows or [])]
def _calculate_os(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
) -> None:
min_rs = float(params.get("min_rs_raw_for_ownership", 0.05))
min_total = float(params.get("min_total_rs_raw", 0.10))
main_threshold = float(params.get("ownership_main_threshold", 0.60))
comanage_threshold = float(params.get("ownership_comanage_threshold", 0.35))
gap_threshold = float(params.get("ownership_gap_threshold", 0.15))
member_groups: Dict[int, List[RelationPairMetrics]] = {}
for metrics in pair_map.values():
member_groups.setdefault(metrics.member_id, []).append(metrics)
for _, rows in member_groups.items():
eligible = [row for row in rows if row.rs_raw >= min_rs]
sum_rs = sum(row.rs_raw for row in eligible)
if sum_rs < min_total:
for row in rows:
row.os_share = 0.0
row.os_label = "UNASSIGNED"
row.os_rank = None
continue
for row in rows:
if row.rs_raw >= min_rs:
row.os_share = row.rs_raw / sum_rs
else:
row.os_share = 0.0
sorted_eligible = sorted(
eligible,
key=lambda item: (
-item.os_share,
-item.rs_raw,
item.days_since_last_session if item.days_since_last_session is not None else 10**9,
item.assistant_id,
),
)
for idx, row in enumerate(sorted_eligible, start=1):
row.os_rank = idx
top1 = sorted_eligible[0]
top2_share = sorted_eligible[1].os_share if len(sorted_eligible) > 1 else 0.0
gap = top1.os_share - top2_share
has_main = top1.os_share >= main_threshold and gap >= gap_threshold
if has_main:
for row in rows:
if row is top1:
row.os_label = "MAIN"
elif row.os_share >= comanage_threshold:
row.os_label = "COMANAGE"
else:
row.os_label = "POOL"
else:
for row in rows:
if row.os_share >= comanage_threshold and row.rs_raw >= min_rs:
row.os_label = "COMANAGE"
else:
row.os_label = "POOL"
# 非 eligible 不赋 rank
for row in rows:
if row.rs_raw < min_rs:
row.os_rank = None
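# OS 归属判定示例(仅示意,使用默认阈值):某会员对三位助教的 rs_raw 为 0.60 / 0.25 / 0.03,
# min_rs=0.05 过滤后 eligible 为前两位,sum_rs=0.85 >= 0.10;
# share 分别为 0.706 / 0.294 / 0。top1 share 0.706 >= 0.60 且领先差距 0.412 >= 0.15,
# 因此 top1 标记 MAIN;第二位 share 0.294 < 0.35 标记 POOL;第三位不评 share、不参与排名(os_rank=None)。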
def _apply_display_scores(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params_rs: Dict[str, float],
params_ms: Dict[str, float],
params_ml: Dict[str, float],
site_id: int,
) -> None:
pair_items = list(pair_map.items())
rs_map = self._normalize_and_record(
raw_pairs=[(key, item.rs_raw) for key, item in pair_items],
params=params_rs,
index_type="RS",
site_id=site_id,
)
ms_map = self._normalize_and_record(
raw_pairs=[(key, item.ms_raw) for key, item in pair_items],
params=params_ms,
index_type="MS",
site_id=site_id,
)
ml_map = self._normalize_and_record(
raw_pairs=[(key, item.ml_raw) for key, item in pair_items],
params=params_ml,
index_type="ML",
site_id=site_id,
)
for key, item in pair_items:
item.rs_display = rs_map.get(key, 0.0)
item.ms_display = ms_map.get(key, 0.0)
item.ml_display = ml_map.get(key, 0.0)
def _normalize_and_record(
self,
*,
raw_pairs: List[Tuple[Any, float]],
params: Dict[str, float],
index_type: str,
site_id: int,
) -> Dict[Any, float]:
if not raw_pairs:
return {}
if all(abs(score) <= 1e-9 for _, score in raw_pairs):
return {entity: 0.0 for entity, _ in raw_pairs}
percentile_lower = int(params.get("percentile_lower", 5))
percentile_upper = int(params.get("percentile_upper", 95))
use_smoothing = int(params.get("use_smoothing", 1)) == 1
compression = self._map_compression(params)
normalized = self.batch_normalize_to_display(
raw_scores=raw_pairs,
compression=compression,
percentile_lower=percentile_lower,
percentile_upper=percentile_upper,
use_smoothing=use_smoothing,
site_id=site_id,
index_type=index_type,
)
display_map = {entity: display for entity, _, display in normalized}
raw_values = [float(score) for _, score in raw_pairs]
q_l, q_u = self.calculate_percentiles(raw_values, percentile_lower, percentile_upper)
if use_smoothing:
smoothed_l, smoothed_u = self._apply_ewma_smoothing(
site_id=site_id,
current_p5=q_l,
current_p95=q_u,
index_type=index_type,
)
else:
smoothed_l, smoothed_u = q_l, q_u
self.save_percentile_history(
site_id=site_id,
percentile_5=q_l,
percentile_95=q_u,
percentile_5_smoothed=smoothed_l,
percentile_95_smoothed=smoothed_u,
record_count=len(raw_values),
min_raw=min(raw_values),
max_raw=max(raw_values),
avg_raw=sum(raw_values) / len(raw_values),
index_type=index_type,
)
return display_map
@staticmethod
def _map_compression(params: Dict[str, float]) -> str:
mode = int(params.get("compression_mode", 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
def _save_relation_rows(self, site_id: int, rows: List[RelationPairMetrics]) -> int:
with self.db.conn.cursor() as cur:
cur.execute(
"DELETE FROM billiards_dws.dws_member_assistant_relation_index WHERE site_id = %s",
(site_id,),
)
if not rows:
self.db.conn.commit()
return 0
insert_sql = """
INSERT INTO billiards_dws.dws_member_assistant_relation_index (
site_id, tenant_id, member_id, assistant_id,
session_count, total_duration_minutes, basic_session_count, incentive_session_count,
days_since_last_session,
rs_f, rs_d, rs_r, rs_raw, rs_display,
os_share, os_label, os_rank,
ms_f_short, ms_f_long, ms_raw, ms_display,
ml_order_count, ml_allocated_amount, ml_raw, ml_display,
calc_time, created_at, updated_at
) VALUES (
%s, %s, %s, %s,
%s, %s, %s, %s,
%s,
%s, %s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s, %s, %s, %s,
NOW(), NOW(), NOW()
)
"""
inserted = 0
for row in rows:
cur.execute(
insert_sql,
(
row.site_id,
row.tenant_id,
row.member_id,
row.assistant_id,
row.session_count,
row.total_duration_minutes,
row.basic_session_count,
row.incentive_session_count,
row.days_since_last_session,
row.rs_f,
row.rs_d,
row.rs_r,
row.rs_raw,
row.rs_display,
row.os_share,
row.os_label,
row.os_rank,
row.ms_f_short,
row.ms_f_long,
row.ms_raw,
row.ms_display,
row.ml_order_count,
row.ml_allocated_amount,
row.ml_raw,
row.ml_display,
),
)
inserted += max(cur.rowcount, 0)
self.db.conn.commit()
return inserted
def _get_site_id(self, context: Optional[TaskContext]) -> int:
if context and getattr(context, "store_id", None):
return int(context.store_id)
site_id = self.config.get("app.default_site_id") or self.config.get("app.store_id")
if site_id is not None:
return int(site_id)
sql = "SELECT DISTINCT site_id FROM billiards_dwd.dwd_assistant_service_log WHERE site_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
return int(dict(rows[0]).get("site_id") or 0)
self.logger.warning("无法确定门店ID使用 0 继续执行")
return 0
def _get_tenant_id(self) -> int:
tenant_id = self.config.get("app.tenant_id")
if tenant_id is not None:
return int(tenant_id)
sql = "SELECT DISTINCT tenant_id FROM billiards_dwd.dwd_assistant_service_log WHERE tenant_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
return int(dict(rows[0]).get("tenant_id") or 0)
self.logger.warning("无法确定租户ID使用 0 继续执行")
return 0
__all__ = ["RelationIndexTask", "RelationPairMetrics", "ServiceSession"]

View File

@@ -0,0 +1,408 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 修复 STOP_HIGH_BALANCE 会员不参与评分的逻辑缺陷;
# Prompt: "STOP_HIGH_BALANCE 应该参与 WBI 评分"
"""
老客挽回指数(WBI)计算任务。"""
from __future__ import annotations
import math
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Any, Dict, List, Optional, Tuple
from .member_index_base import MemberActivityData, MemberIndexBaseTask
from ..base_dws_task import TaskContext
@dataclass
class MemberWinbackData:
activity: MemberActivityData
status: str
segment: str
overdue_old: float = 0.0
overdue_cdf_p: float = 0.0
drop_old: float = 0.0
recharge_old: float = 0.0
value_old: float = 0.0
ideal_interval_days: Optional[float] = None
ideal_next_visit_date: Optional[date] = None
raw_score: Optional[float] = None
display_score: Optional[float] = None
class WinbackIndexTask(MemberIndexBaseTask):
"""老客挽回指数WBI计算任务。"""
INDEX_TYPE = "WBI"
DEFAULT_PARAMS = {
# 通用参数
'lookback_days_recency': 60,
'visit_lookback_days': 180,
'percentile_lower': 5,
'percentile_upper': 95,
'compression_mode': 0,
'use_smoothing': 1,
'ewma_alpha': 0.2,
# 分流参数
'new_visit_threshold': 2,
'new_days_threshold': 30,
'recharge_recent_days': 14,
'new_recharge_max_visits': 10,
'recency_hard_floor_days': 14,
'recency_gate_days': 14,
'recency_gate_slope_days': 3,
# WBI参数
'overdue_alpha': 2.0,
'overdue_weight_halflife_days': 30,
'overdue_weight_blend_min_samples': 8,
'h_recharge': 7,
'amount_base_M0': 300,
'balance_base_B0': 500,
'value_w_spend': 1.0,
'value_w_bal': 1.0,
'w_over': 2.0,
'w_drop': 1.0,
'w_re': 0.4,
'w_value': 1.2,
# STOP高余额例外(默认关闭)
'enable_stop_high_balance_exception': 0,
'high_balance_threshold': 1000,
}
def get_task_code(self) -> str:
return "DWS_WINBACK_INDEX"
def get_target_table(self) -> str:
return "dws_member_winback_index"
def get_primary_keys(self) -> List[str]:
return ['site_id', 'member_id']
def get_index_type(self) -> str:
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
"""执行 WBI 计算"""
self.logger.info("开始计算老客挽回指数 (WBI)")
site_id = self._get_site_id(context)
tenant_id = self._get_tenant_id()
params = self._load_params()
activity_map = self._build_member_activity(site_id, tenant_id, params)
if not activity_map:
self.logger.warning("No member activity data available; skip calculation")
return {'status': 'skipped', 'reason': 'no_data'}
winback_list: List[MemberWinbackData] = []
for activity in activity_map.values():
segment, status, in_scope = self.classify_segment(activity, params)
if not in_scope:
continue
if segment != "OLD" and status != "STOP_HIGH_BALANCE":
continue
data = MemberWinbackData(activity=activity, status=status, segment=segment)
# CHANGE 2026-02-13 | intent: STOP_HIGH_BALANCE 也参与评分
# 原先只对 segment=="OLD" 评分,STOP_HIGH_BALANCE(segment=="STOP")
# 进入范围却不评分,raw_score=NULL,无运营价值。
# 这些会员具备完整特征数据,应同等评分。
if segment == "OLD" or status == "STOP_HIGH_BALANCE":
self._calculate_wbi_scores(data, params)
winback_list.append(data)
if not winback_list:
self.logger.warning("No old-member rows to calculate")
return {'status': 'skipped', 'reason': 'no_old_members'}
# 归一化 Display Score
raw_scores = [
(d.activity.member_id, d.raw_score)
for d in winback_list
if d.raw_score is not None
]
if raw_scores:
compression = self._map_compression(params)
use_smoothing = int(params.get('use_smoothing', 1)) == 1
normalized = self.batch_normalize_to_display(
raw_scores,
compression=compression,
percentile_lower=int(params['percentile_lower']),
percentile_upper=int(params['percentile_upper']),
use_smoothing=use_smoothing,
site_id=site_id
)
score_map = {member_id: display for member_id, _, display in normalized}
for data in winback_list:
if data.activity.member_id in score_map:
data.display_score = score_map[data.activity.member_id]
# 保存分位点历史
all_raw = [float(score) for _, score in raw_scores]
q_l, q_u = self.calculate_percentiles(
all_raw,
int(params['percentile_lower']),
int(params['percentile_upper'])
)
if use_smoothing:
smoothed_l, smoothed_u = self._apply_ewma_smoothing(site_id, q_l, q_u)
else:
smoothed_l, smoothed_u = q_l, q_u
self.save_percentile_history(
site_id=site_id,
percentile_5=q_l,
percentile_95=q_u,
percentile_5_smoothed=smoothed_l,
percentile_95_smoothed=smoothed_u,
record_count=len(all_raw),
min_raw=min(all_raw),
max_raw=max(all_raw),
avg_raw=sum(all_raw) / len(all_raw)
)
inserted = self._save_winback_data(winback_list)
self.logger.info("WBI calculation finished, inserted %d rows", inserted)
return {
'status': 'success',
'member_count': len(winback_list),
'records_inserted': inserted
}
def _weighted_cdf(
self,
samples: List[Tuple[float, int]],
t_v: float,
halflife_days: float,
blend_min_samples: int,
) -> float:
if not samples:
return 0.5
if halflife_days <= 0:
p_equal = sum(1.0 for interval, _ in samples if interval <= t_v) / len(samples)
return self._clip(p_equal, 0.0, 1.0)
ln2 = math.log(2.0)
weighted_hit = 0.0
weight_sum = 0.0
equal_hit = 0.0
for interval, age_days in samples:
weight = math.exp(-ln2 * float(age_days) / halflife_days)
indicator = 1.0 if interval <= t_v else 0.0
weighted_hit += weight * indicator
weight_sum += weight
equal_hit += indicator
p_weighted = 0.5 if weight_sum <= 0 else (weighted_hit / weight_sum)
p_equal = equal_hit / len(samples)
lam = min(1.0, float(len(samples)) / float(max(1, blend_min_samples)))
p_final = lam * p_weighted + (1.0 - lam) * p_equal
return self._clip(p_final, 0.0, 1.0)
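# 加权经验 CDF 数值示例(仅示意):samples=[(10, 5), (20, 40)](间隔天数, 样本距今天数),
# t_v=15、halflife_days=30、blend_min_samples=8 时:
#   权重 w = 2^(-5/30) ≈ 0.891、2^(-40/30) ≈ 0.397,只有间隔 10 <= 15 命中,
#   p_weighted ≈ 0.891 / 1.288 ≈ 0.69,p_equal = 0.5,lam = 2/8 = 0.25,
#   p_final = 0.25*0.69 + 0.75*0.5 ≈ 0.55。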
def _weighted_quantile(
self,
samples: List[Tuple[float, int]],
quantile: float,
halflife_days: float,
blend_min_samples: int,
) -> Optional[float]:
if not samples:
return None
q = self._clip(quantile, 0.0, 1.0)
equal_weight = 1.0 / float(len(samples))
if halflife_days <= 0:
weighted = [(interval, equal_weight) for interval, _ in samples]
else:
ln2 = math.log(2.0)
raw_weighted: List[Tuple[float, float]] = []
total = 0.0
for interval, age_days in samples:
w = math.exp(-ln2 * float(age_days) / halflife_days)
raw_weighted.append((interval, w))
total += w
if total <= 0:
weighted = [(interval, equal_weight) for interval, _ in samples]
else:
weighted = [(interval, w / total) for interval, w in raw_weighted]
# 对小样本混合加权分布与等权分布。
lam = min(1.0, float(len(samples)) / float(max(1, blend_min_samples)))
blended: List[Tuple[float, float]] = []
for (interval_w, w), (interval_e, _) in zip(weighted, samples):
_ = interval_e # keep tuple alignment explicit
blended_weight = lam * w + (1.0 - lam) * equal_weight
blended.append((interval_w, blended_weight))
blended.sort(key=lambda item: item[0])
cumulative = 0.0
for interval, weight in blended:
cumulative += weight
if cumulative >= q:
return float(interval)
return float(blended[-1][0])
def _calculate_wbi_scores(self, data: MemberWinbackData, params: Dict[str, float]) -> None:
"""计算 WBI 分项与 Raw Score"""
activity = data.activity
# 1) 超期紧急性(基于近期加权经验 CDF)
overdue_alpha = float(params['overdue_alpha'])
half_life_days = float(params.get('overdue_weight_halflife_days', 30))
blend_min_samples = int(params.get('overdue_weight_blend_min_samples', 8))
if activity.interval_count <= 0:
p = 0.5
ideal_interval = None
else:
if len(activity.interval_ages_days) == activity.interval_count:
samples = list(zip(activity.intervals, activity.interval_ages_days))
else:
samples = [(interval, 0) for interval in activity.intervals]
p = self._weighted_cdf(
samples=samples,
t_v=activity.t_v,
halflife_days=half_life_days,
blend_min_samples=blend_min_samples,
)
ideal_interval = self._weighted_quantile(
samples=samples,
quantile=0.5,
halflife_days=half_life_days,
blend_min_samples=blend_min_samples,
)
data.overdue_cdf_p = p
data.overdue_old = math.pow(p, overdue_alpha)
data.ideal_interval_days = ideal_interval
if ideal_interval is not None and activity.last_visit_time is not None:
ideal_days = max(0, int(round(ideal_interval)))
data.ideal_next_visit_date = activity.last_visit_time.date() + timedelta(days=ideal_days)
else:
data.ideal_next_visit_date = None
# 2) 降频分
expected14 = activity.visits_60d * 14.0 / 60.0
data.drop_old = self._clip((expected14 - activity.visits_14d) / (expected14 + 1), 0.0, 1.0)
# 3) 充值未回访压力
if activity.recharge_unconsumed == 1:
data.recharge_old = self.decay(activity.t_r, params['h_recharge'])
else:
data.recharge_old = 0.0
# 4) 价值分
m0 = float(params['amount_base_M0'])
b0 = float(params['balance_base_B0'])
spend_score = math.log1p(activity.spend_180d / m0) if m0 > 0 else 0.0
bal_score = math.log1p(activity.sv_balance / b0) if b0 > 0 else 0.0
data.value_old = float(params['value_w_spend']) * spend_score + float(params['value_w_bal']) * bal_score
data.raw_score = (
float(params['w_over']) * data.overdue_old
+ float(params['w_drop']) * data.drop_old
+ float(params['w_re']) * data.recharge_old
+ float(params['w_value']) * data.value_old
)
hard_floor_days = float(params.get('recency_hard_floor_days', 0))
gate_days = float(params.get('recency_gate_days', 14))
slope_days = float(params.get('recency_gate_slope_days', 3))
if hard_floor_days > 0 and activity.t_v < hard_floor_days:
suppression = 0.0
elif slope_days <= 0:
suppression = 1.0 if activity.t_v >= gate_days else 0.0
else:
x = (activity.t_v - gate_days) / slope_days
x = self._clip(x, -60.0, 60.0)
suppression = 1.0 / (1.0 + math.exp(-x))
data.raw_score *= suppression
# 限制在 0 以上
if data.raw_score < 0:
data.raw_score = 0.0
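# Raw Score 与近访抑制示例(仅示意,使用默认参数):
# raw = 2.0*overdue + 1.0*drop + 0.4*recharge + 1.2*value;
# recency 抑制项:t_v < recency_hard_floor_days=14 时直接压为 0;
# t_v = 14 时 sigmoid((14-14)/3) = 0.5,t_v = 20 时 sigmoid(2) ≈ 0.88,
# 即距上次来店越久,WBI 被释放得越充分。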
def _save_winback_data(self, data_list: List[MemberWinbackData]) -> int:
"""保存 WBI 数据"""
if not data_list:
return 0
site_id = data_list[0].activity.site_id
# 按门店全量刷新,避免因分群变化导致过期数据残留。
delete_sql = """
DELETE FROM billiards_dws.dws_member_winback_index
WHERE site_id = %s
"""
with self.db.conn.cursor() as cur:
cur.execute(delete_sql, (site_id,))
insert_sql = """
INSERT INTO billiards_dws.dws_member_winback_index (
site_id, tenant_id, member_id,
status, segment,
member_create_time, first_visit_time, last_visit_time, last_recharge_time,
t_v, t_r, t_a,
visits_14d, visits_60d, visits_total,
spend_30d, spend_180d, sv_balance, recharge_60d_amt,
interval_count,
overdue_old, overdue_cdf_p, drop_old, recharge_old, value_old,
ideal_interval_days, ideal_next_visit_date,
raw_score, display_score,
last_wechat_touch_time,
calc_time, created_at, updated_at
) VALUES (
%s, %s, %s,
%s, %s,
%s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s,
%s, %s, %s, %s, %s,
%s, %s,
%s, %s,
%s,
NOW(), NOW(), NOW()
)
"""
inserted = 0
with self.db.conn.cursor() as cur:
for data in data_list:
activity = data.activity
cur.execute(insert_sql, (
activity.site_id, activity.tenant_id, activity.member_id,
data.status, data.segment,
activity.member_create_time, activity.first_visit_time, activity.last_visit_time, activity.last_recharge_time,
activity.t_v, activity.t_r, activity.t_a,
activity.visits_14d, activity.visits_60d, activity.visits_total,
activity.spend_30d, activity.spend_180d, activity.sv_balance, activity.recharge_60d_amt,
activity.interval_count,
data.overdue_old, data.overdue_cdf_p, data.drop_old, data.recharge_old, data.value_old,
data.ideal_interval_days, data.ideal_next_visit_date,
data.raw_score, data.display_score,
None,
))
inserted += cur.rowcount
self.db.conn.commit()
return inserted
def _clip(self, value: float, low: float, high: float) -> float:
return max(low, min(high, value))
def _map_compression(self, params: Dict[str, float]) -> str:
mode = int(params.get('compression_mode', 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
__all__ = ['WinbackIndexTask']

View File

@@ -0,0 +1,370 @@
# -*- coding: utf-8 -*-
"""
会员消费汇总任务
功能说明:
"会员"为粒度,统计消费行为和滚动窗口指标
数据来源:
- dwd_settlement_head: 结账单头表
- dim_member: 会员维度
- dim_member_card_account: 会员卡账户
目标表:
billiards_dws.dws_member_consumption_summary
更新策略:
- 更新频率:每日更新
- 幂等方式:delete-before-insert(按统计日期)
业务规则:
- 散客处理:member_id=0 不进入此表
- 滚动窗口:7/10/15/30/60/90 天
- 卡余额:区分储值卡(现金卡)和赠送卡
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class MemberConsumptionTask(BaseDwsTask):
"""
会员消费汇总任务
统计每个会员的:
- 首次/最近消费日期
- 累计消费统计
- 滚动窗口统计(7/10/15/30/60/90 天)
- 卡余额快照
- 活跃度指标和客户分层
"""
def get_task_code(self) -> str:
return "DWS_MEMBER_CONSUMPTION"
def get_target_table(self) -> str:
return "dws_member_consumption_summary"
def get_primary_keys(self) -> List[str]:
return ["site_id", "member_id", "stat_date"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据
"""
stat_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
self.logger.info(
"%s: 提取数据,统计日期 %s",
self.get_task_code(), stat_date
)
# 1. 获取会员消费统计(含滚动窗口)
consumption_stats = self._extract_consumption_stats(site_id, stat_date)
# 2. 获取会员信息
member_info = self._extract_member_info(site_id)
# 3. 获取会员卡余额
card_balances = self._extract_card_balances(site_id)
return {
'consumption_stats': consumption_stats,
'member_info': member_info,
'card_balances': card_balances,
'stat_date': stat_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据
"""
consumption_stats = extracted['consumption_stats']
member_info = extracted['member_info']
card_balances = extracted['card_balances']
stat_date = extracted['stat_date']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,%d 条会员消费记录",
self.get_task_code(), len(consumption_stats)
)
results = []
for stats in consumption_stats:
member_id = stats.get('member_id')
# 跳过散客
if self.is_guest(member_id):
continue
memb_info = member_info.get(member_id, {})
balance = card_balances.get(member_id, {})
# 计算活跃度和客户分层
days_since_last = self._calc_days_since(stat_date, stats.get('last_consume_date'))
customer_tier = self._calculate_customer_tier(stats, days_since_last)
record = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'member_id': member_id,
'stat_date': stat_date,
# 会员基本信息
'member_nickname': memb_info.get('nickname'),
'member_mobile': self._mask_mobile(memb_info.get('mobile')),
'card_grade_name': memb_info.get('member_card_grade_name'),
'register_date': memb_info.get('register_date'),
# 全量累计统计
'first_consume_date': stats.get('first_consume_date'),
'last_consume_date': stats.get('last_consume_date'),
'total_visit_count': self.safe_int(stats.get('total_visit_count', 0)),
'total_consume_amount': self.safe_decimal(stats.get('total_consume_amount', 0)),
'total_recharge_amount': self.safe_decimal(memb_info.get('recharge_money_sum', 0)),
'total_table_fee': self.safe_decimal(stats.get('total_table_fee', 0)),
'total_goods_amount': self.safe_decimal(stats.get('total_goods_amount', 0)),
'total_assistant_amount': self.safe_decimal(stats.get('total_assistant_amount', 0)),
# 滚动窗口统计
'visit_count_7d': self.safe_int(stats.get('visit_count_7d', 0)),
'visit_count_10d': self.safe_int(stats.get('visit_count_10d', 0)),
'visit_count_15d': self.safe_int(stats.get('visit_count_15d', 0)),
'visit_count_30d': self.safe_int(stats.get('visit_count_30d', 0)),
'visit_count_60d': self.safe_int(stats.get('visit_count_60d', 0)),
'visit_count_90d': self.safe_int(stats.get('visit_count_90d', 0)),
'consume_amount_7d': self.safe_decimal(stats.get('consume_amount_7d', 0)),
'consume_amount_10d': self.safe_decimal(stats.get('consume_amount_10d', 0)),
'consume_amount_15d': self.safe_decimal(stats.get('consume_amount_15d', 0)),
'consume_amount_30d': self.safe_decimal(stats.get('consume_amount_30d', 0)),
'consume_amount_60d': self.safe_decimal(stats.get('consume_amount_60d', 0)),
'consume_amount_90d': self.safe_decimal(stats.get('consume_amount_90d', 0)),
# 卡余额
'cash_card_balance': self.safe_decimal(balance.get('cash_balance', 0)),
'gift_card_balance': self.safe_decimal(balance.get('gift_balance', 0)),
'total_card_balance': self.safe_decimal(balance.get('total_balance', 0)),
# 活跃度指标
'days_since_last': days_since_last,
'is_active_7d': self.safe_int(stats.get('visit_count_7d', 0)) > 0,
'is_active_30d': self.safe_int(stats.get('visit_count_30d', 0)) > 0,
'is_active_90d': self.safe_int(stats.get('visit_count_90d', 0)) > 0,
# 客户分层
'customer_tier': customer_tier,
}
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
deleted = self.delete_existing_data(context, date_col="stat_date")
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _extract_consumption_stats(
self,
site_id: int,
stat_date: date
) -> List[Dict[str, Any]]:
"""
提取会员消费统计(含滚动窗口)
"""
sql = """
WITH consume_base AS (
SELECT
member_id,
DATE(pay_time) AS consume_date,
consume_money,
table_charge_money,
goods_money,
assistant_pd_money + assistant_cx_money AS assistant_amount
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %s
AND DATE(pay_time) <= %s
AND member_id IS NOT NULL
AND member_id != 0
)
SELECT
member_id,
MIN(consume_date) AS first_consume_date,
MAX(consume_date) AS last_consume_date,
-- 全量累计
COUNT(*) AS total_visit_count,
SUM(consume_money) AS total_consume_amount,
SUM(table_charge_money) AS total_table_fee,
SUM(goods_money) AS total_goods_amount,
SUM(assistant_amount) AS total_assistant_amount,
-- 滚动窗口
COUNT(CASE WHEN consume_date >= %s - INTERVAL '6 days' THEN 1 END) AS visit_count_7d,
COUNT(CASE WHEN consume_date >= %s - INTERVAL '9 days' THEN 1 END) AS visit_count_10d,
COUNT(CASE WHEN consume_date >= %s - INTERVAL '14 days' THEN 1 END) AS visit_count_15d,
COUNT(CASE WHEN consume_date >= %s - INTERVAL '29 days' THEN 1 END) AS visit_count_30d,
COUNT(CASE WHEN consume_date >= %s - INTERVAL '59 days' THEN 1 END) AS visit_count_60d,
COUNT(CASE WHEN consume_date >= %s - INTERVAL '89 days' THEN 1 END) AS visit_count_90d,
SUM(CASE WHEN consume_date >= %s - INTERVAL '6 days' THEN consume_money ELSE 0 END) AS consume_amount_7d,
SUM(CASE WHEN consume_date >= %s - INTERVAL '9 days' THEN consume_money ELSE 0 END) AS consume_amount_10d,
SUM(CASE WHEN consume_date >= %s - INTERVAL '14 days' THEN consume_money ELSE 0 END) AS consume_amount_15d,
SUM(CASE WHEN consume_date >= %s - INTERVAL '29 days' THEN consume_money ELSE 0 END) AS consume_amount_30d,
SUM(CASE WHEN consume_date >= %s - INTERVAL '59 days' THEN consume_money ELSE 0 END) AS consume_amount_60d,
SUM(CASE WHEN consume_date >= %s - INTERVAL '89 days' THEN consume_money ELSE 0 END) AS consume_amount_90d
FROM consume_base
GROUP BY member_id
"""
# 参数顺序:site_id、consume_base 的统计日上界、12 个滚动窗口比较值
params = [site_id, stat_date] + [stat_date] * 12
rows = self.db.query(sql, tuple(params))
return [dict(row) for row in rows] if rows else []
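# 滚动窗口口径示例(仅示意):stat_date = 2026-02-15 时,
# visit_count_7d 统计 consume_date ∈ [2026-02-09, 2026-02-15](含两端)的结账单数,
# 其余 10/15/30/60/90 天窗口同理,均以统计日为右边界向前回溯。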
def _extract_member_info(self, site_id: int) -> Dict[int, Dict[str, Any]]:
"""
提取会员信息
"""
sql = """
SELECT
member_id,
nickname,
mobile,
member_card_grade_name,
DATE(create_time) AS register_date,
recharge_money_sum
FROM billiards_dwd.dim_member
WHERE site_id = %s
AND scd2_is_current = 1
"""
rows = self.db.query(sql, (site_id,))
result = {}
for row in (rows or []):
row_dict = dict(row)
result[row_dict['member_id']] = row_dict
return result
def _extract_card_balances(self, site_id: int) -> Dict[int, Dict[str, Decimal]]:
"""
提取会员卡余额
"""
# 卡类型ID
CASH_CARD_TYPE_ID = 2793249295533893
GIFT_CARD_TYPE_IDS = [2791990152417157, 2793266846533445, 2794699703437125]
sql = """
SELECT
tenant_member_id AS member_id,
card_type_id,
balance
FROM billiards_dwd.dim_member_card_account
WHERE site_id = %s
AND scd2_is_current = 1
AND COALESCE(is_delete, 0) = 0
"""
rows = self.db.query(sql, (site_id,))
result: Dict[int, Dict[str, Decimal]] = {}
for row in (rows or []):
row_dict = dict(row)
member_id = row_dict.get('member_id')
card_type_id = row_dict.get('card_type_id')
balance = self.safe_decimal(row_dict.get('balance', 0))
if member_id not in result:
result[member_id] = {
'cash_balance': Decimal('0'),
'gift_balance': Decimal('0'),
'total_balance': Decimal('0')
}
if card_type_id == CASH_CARD_TYPE_ID:
result[member_id]['cash_balance'] += balance
elif card_type_id in GIFT_CARD_TYPE_IDS:
result[member_id]['gift_balance'] += balance
result[member_id]['total_balance'] = (
result[member_id]['cash_balance'] + result[member_id]['gift_balance']
)
return result
# ==========================================================================
# 工具方法
# ==========================================================================
def _mask_mobile(self, mobile: Optional[str]) -> Optional[str]:
"""手机号脱敏"""
if not mobile or len(mobile) < 7:
return mobile
return mobile[:3] + "****" + mobile[-4:]
def _calc_days_since(self, stat_date: date, last_date: Optional[date]) -> Optional[int]:
"""计算距离最近消费的天数"""
if not last_date:
return None
if isinstance(last_date, datetime):
last_date = last_date.date()
return (stat_date - last_date).days
def _calculate_customer_tier(
self,
stats: Dict[str, Any],
days_since_last: Optional[int]
) -> str:
"""
计算客户分层
分层规则:
- 高价值:90天内消费>=3次 且 消费金额>=1000
- 中等:30天内有消费
- 低活跃:90天内有消费但30天内无消费
- 流失:90天内无消费
"""
visit_90d = self.safe_int(stats.get('visit_count_90d', 0))
visit_30d = self.safe_int(stats.get('visit_count_30d', 0))
amount_90d = self.safe_decimal(stats.get('consume_amount_90d', 0))
if visit_90d >= 3 and amount_90d >= 1000:
return "高价值"
elif visit_30d > 0:
return "中等"
elif visit_90d > 0:
return "低活跃"
else:
return "流失"
# 便于外部导入
__all__ = ['MemberConsumptionTask']

View File

@@ -0,0 +1,423 @@
# -*- coding: utf-8 -*-
"""
会员来店明细任务
功能说明:
"会员+订单"为粒度,记录每次来店消费明细
数据来源:
- dwd_settlement_head: 结账单头表
- dwd_assistant_service_log: 助教服务流水
- dim_member: 会员维度
- dim_table: 台桌维度
- cfg_area_category: 区域分类映射
目标表:
billiards_dws.dws_member_visit_detail
更新策略:
- 更新频率:每日增量更新
- 幂等方式:delete-before-insert(按日期窗口)
业务规则:
- 散客处理:member_id=0 不进入此表
- 区域分类:使用 cfg_area_category 映射
- 助教服务:以 JSON 格式存储多个助教的服务明细
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
import json
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_dws_task import BaseDwsTask, TaskContext
class MemberVisitTask(BaseDwsTask):
"""
会员来店明细任务
记录每个会员每次来店的:
- 台桌信息和区域分类
- 消费金额明细
- 支付方式明细
- 助教服务明细(JSON 格式)
"""
def get_task_code(self) -> str:
return "DWS_MEMBER_VISIT"
def get_target_table(self) -> str:
return "dws_member_visit_detail"
def get_primary_keys(self) -> List[str]:
return ["site_id", "member_id", "order_settle_id"]
# ==========================================================================
# ETL主流程
# ==========================================================================
def extract(self, context: TaskContext) -> Dict[str, Any]:
"""
提取数据
"""
start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end
site_id = context.store_id
self.logger.info(
"%s: 提取数据,日期范围 %s ~ %s",
self.get_task_code(), start_date, end_date
)
# 1. 获取结账单
settlements = self._extract_settlements(site_id, start_date, end_date)
# 2. 获取助教服务明细
assistant_services = self._extract_assistant_services(site_id, start_date, end_date)
# 2.1 获取台费时长(真实秒数)
table_fee_durations = self._extract_table_fee_durations(site_id, start_date, end_date)
# 3. 获取会员信息
member_info = self._extract_member_info(site_id)
# 4. 获取台桌信息
table_info = self._extract_table_info(site_id)
# 5. 加载配置
self.load_config_cache()
return {
'settlements': settlements,
'assistant_services': assistant_services,
'member_info': member_info,
'table_info': table_info,
'table_fee_durations': table_fee_durations,
'start_date': start_date,
'end_date': end_date,
'site_id': site_id
}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> List[Dict[str, Any]]:
"""
转换数据
"""
settlements = extracted['settlements']
assistant_services = extracted['assistant_services']
member_info = extracted['member_info']
table_info = extracted['table_info']
table_fee_durations = extracted['table_fee_durations']
site_id = extracted['site_id']
self.logger.info(
"%s: 转换数据,%d 条结账单",
self.get_task_code(), len(settlements)
)
# 构建助教服务索引:order_settle_id -> [services]
service_index = self._build_service_index(assistant_services)
# 构建台费时长索引:order_settle_id -> total_seconds
table_duration_index = {
row.get('order_settle_id'): self.safe_int(row.get('table_use_seconds', 0))
for row in (table_fee_durations or [])
if row.get('order_settle_id')
}
results = []
for settle in settlements:
member_id = settle.get('member_id')
# 跳过散客
if self.is_guest(member_id):
continue
order_settle_id = settle.get('order_settle_id')
table_id = settle.get('table_id')
memb_info = member_info.get(member_id, {})
tbl_info = table_info.get(table_id, {})
services = service_index.get(order_settle_id, [])
# 获取区域分类
area_name = tbl_info.get('area_name')
area_cat = self.get_area_category(area_name)
# 构建助教服务JSON
assistant_services_json = self._build_assistant_services_json(services)
# 计算时长
table_seconds = table_duration_index.get(order_settle_id, 0)
table_duration = self._calc_table_duration(table_seconds)
assistant_duration = sum(
self.safe_int(s.get('income_seconds', 0))
for s in services
) // 60 # 转为分钟
record = {
'site_id': site_id,
'tenant_id': self.config.get("app.tenant_id", site_id),
'member_id': member_id,
'order_settle_id': order_settle_id,
'visit_date': settle.get('visit_date'),
'visit_time': settle.get('create_time'),
# 会员信息
'member_nickname': memb_info.get('nickname'),
'member_mobile': self._mask_mobile(memb_info.get('mobile')),
'member_birthday': memb_info.get('birthday'),
# 台桌信息
'table_id': table_id,
'table_name': tbl_info.get('table_name'),
'area_name': area_name,
'area_category': area_cat.get('category_name'),
# 消费金额
'table_fee': self.safe_decimal(settle.get('table_charge_money', 0)),
'goods_amount': self.safe_decimal(settle.get('goods_money', 0)),
'assistant_amount': self.safe_decimal(settle.get('assistant_pd_money', 0)) + \
self.safe_decimal(settle.get('assistant_cx_money', 0)),
'total_consume': self.safe_decimal(settle.get('consume_money', 0)),
'total_discount': self._calc_total_discount(settle),
'actual_pay': self.safe_decimal(settle.get('pay_amount', 0)),
# 支付方式
'cash_pay': self.safe_decimal(settle.get('pay_amount', 0)),
'cash_card_pay': self.safe_decimal(settle.get('balance_amount', 0)),
'gift_card_pay': self.safe_decimal(settle.get('gift_card_amount', 0)),
'groupbuy_pay': self.safe_decimal(settle.get('coupon_amount', 0)),
# 时长
'table_duration_min': table_duration,
'assistant_duration_min': assistant_duration,
# 助教服务明细
'assistant_services': assistant_services_json,
}
results.append(record)
return results
def load(self, transformed: List[Dict[str, Any]], context: TaskContext) -> Dict:
"""
加载数据
"""
if not transformed:
self.logger.info("%s: 无数据需要写入", self.get_task_code())
return {"counts": {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}}
deleted = self.delete_existing_data(context, date_col="visit_date")
inserted = self.bulk_insert(transformed)
self.logger.info(
"%s: 加载完成,删除 %d 行,插入 %d",
self.get_task_code(), deleted, inserted
)
return {
"counts": {
"fetched": len(transformed),
"inserted": inserted,
"updated": 0,
"skipped": 0,
"errors": 0
},
"extra": {"deleted": deleted}
}
# ==========================================================================
# 数据提取方法
# ==========================================================================
def _extract_settlements(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取结账单
"""
sql = """
SELECT
order_settle_id,
order_trade_no,
table_id,
member_id,
create_time,
pay_time,
DATE(pay_time) AS visit_date,
consume_money,
pay_amount,
table_charge_money,
goods_money,
assistant_pd_money,
assistant_cx_money,
coupon_amount,
adjust_amount,
member_discount_amount,
rounding_amount,
gift_card_amount,
balance_amount,
recharge_card_amount
FROM billiards_dwd.dwd_settlement_head
WHERE site_id = %s
AND DATE(pay_time) >= %s
AND DATE(pay_time) <= %s
AND member_id IS NOT NULL
AND member_id != 0
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_assistant_services(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取助教服务明细
"""
sql = """
SELECT
order_settle_id,
site_assistant_id AS assistant_id,
nickname AS assistant_nickname,
income_seconds,
ledger_amount
FROM billiards_dwd.dwd_assistant_service_log
WHERE site_id = %s
AND DATE(start_use_time) >= %s
AND DATE(start_use_time) <= %s
AND is_delete = 0
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_table_fee_durations(
self,
site_id: int,
start_date: date,
end_date: date
) -> List[Dict[str, Any]]:
"""
提取台费时长(真实秒数)
"""
sql = """
SELECT
order_settle_id,
SUM(COALESCE(real_table_use_seconds, 0)) AS table_use_seconds
FROM billiards_dwd.dwd_table_fee_log
WHERE site_id = %s
AND DATE(ledger_end_time) >= %s
AND DATE(ledger_end_time) <= %s
AND COALESCE(is_delete, 0) = 0
GROUP BY order_settle_id
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in rows] if rows else []
def _extract_member_info(self, site_id: int) -> Dict[int, Dict[str, Any]]:
"""
提取会员信息
"""
sql = """
SELECT
member_id,
nickname,
mobile,
birthday
FROM billiards_dwd.dim_member
WHERE site_id = %s
AND scd2_is_current = 1
"""
rows = self.db.query(sql, (site_id,))
return {r['member_id']: dict(r) for r in (rows or [])}
def _extract_table_info(self, site_id: int) -> Dict[int, Dict[str, Any]]:
"""
提取台桌信息
"""
sql = """
SELECT
site_table_id AS table_id,
site_table_name AS table_name,
site_table_area_name AS area_name
FROM billiards_dwd.dim_table
WHERE site_id = %s
AND scd2_is_current = 1
"""
rows = self.db.query(sql, (site_id,))
return {r['table_id']: dict(r) for r in (rows or [])}
# ==========================================================================
# 工具方法
# ==========================================================================
def _build_service_index(
self,
services: List[Dict[str, Any]]
) -> Dict[int, List[Dict[str, Any]]]:
"""
构建助教服务索引
"""
index: Dict[int, List[Dict[str, Any]]] = {}
for service in services:
order_id = service.get('order_settle_id')
if order_id:
if order_id not in index:
index[order_id] = []
index[order_id].append(service)
return index
def _build_assistant_services_json(
self,
services: List[Dict[str, Any]]
) -> Optional[str]:
"""
构建助教服务JSON
"""
if not services:
return None
json_data = []
for s in services:
json_data.append({
'assistant_id': s.get('assistant_id'),
'nickname': s.get('assistant_nickname'),
'duration_min': self.safe_int(s.get('income_seconds', 0)) // 60,
'amount': float(self.safe_decimal(s.get('ledger_amount', 0)))
})
return json.dumps(json_data, ensure_ascii=False)
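# 输出示例(字段值为虚构示意):
# [{"assistant_id": 1001, "nickname": "小王", "duration_min": 90, "amount": 180.0},
#  {"assistant_id": 1002, "nickname": "小李", "duration_min": 30, "amount": 60.0}]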
def _calc_table_duration(self, table_use_seconds: int) -> int:
"""
计算台桌使用时长(分钟)
使用真实台费流水秒数
"""
if not table_use_seconds or table_use_seconds <= 0:
return 0
return int(table_use_seconds // 60)
def _calc_total_discount(self, settle: Dict[str, Any]) -> Decimal:
"""
计算总优惠
"""
adjust = self.safe_decimal(settle.get('adjust_amount', 0))
member_discount = self.safe_decimal(settle.get('member_discount_amount', 0))
rounding = self.safe_decimal(settle.get('rounding_amount', 0))
return adjust + member_discount + rounding
def _mask_mobile(self, mobile: Optional[str]) -> Optional[str]:
"""手机号脱敏"""
if not mobile or len(mobile) < 7:
return mobile
return mobile[:3] + "****" + mobile[-4:]
# 便于外部导入
__all__ = ['MemberVisitTask']

View File

@@ -0,0 +1,196 @@
# -*- coding: utf-8 -*-
"""
DWS 物化视图刷新任务
说明:
- 按 L1/L2/L3/L4 时间分层刷新物化视图
- 默认受 dws.mv.enabled 与 dws.retention.* 配置联动控制
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from .base_dws_task import BaseDwsTask, TaskContext, TimeLayer
class BaseMvRefreshTask(BaseDwsTask):
"""物化视图刷新基类"""
BASE_TABLE: str = ""
DATE_COL: str = ""
VIEW_PREFIX = "mv_"
LAYER_ORDER = [
TimeLayer.LAST_2_DAYS,
TimeLayer.LAST_1_MONTH,
TimeLayer.LAST_3_MONTHS,
TimeLayer.LAST_6_MONTHS,
]
LAYER_SUFFIX = {
TimeLayer.LAST_2_DAYS: "l1",
TimeLayer.LAST_1_MONTH: "l2",
TimeLayer.LAST_3_MONTHS: "l3",
TimeLayer.LAST_6_MONTHS: "l4",
}
def get_target_table(self) -> str:
return self.BASE_TABLE
def get_primary_keys(self) -> List[str]:
return []
def extract(self, context: TaskContext) -> Dict[str, Any]:
return {"site_id": context.store_id}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> Dict[str, Any]:
return extracted
def load(self, transformed: Dict[str, Any], context: TaskContext) -> Dict[str, Any]:
if not self._is_enabled():
self.logger.info("%s: 未启用物化刷新,跳过", self.get_task_code())
return {"counts": {"refreshed": 0}}
layers = self._resolve_layers()
refreshed = 0
details = []
for layer in layers:
view_name = self._get_view_name(layer)
if not view_name:
continue
if not self._view_exists(view_name):
self.logger.warning("%s: 物化视图不存在,跳过 %s", self.get_task_code(), view_name)
continue
self._refresh_view(view_name)
refreshed += 1
details.append({"view": view_name, "layer": layer.value})
self.logger.info("%s: 刷新完成,物化视图数=%d", self.get_task_code(), refreshed)
return {"counts": {"refreshed": refreshed}, "extra": {"details": details}}
def _is_enabled(self) -> bool:
enabled = bool(self.config.get("dws.mv.enabled", False))
if not enabled:
return False
tables = self._parse_list(self.config.get("dws.mv.tables"))
if not tables:
tables = self._parse_list(self.config.get("dws.retention.tables"))
if tables and self.BASE_TABLE not in tables:
return False
return True
def _resolve_layers(self) -> List[TimeLayer]:
# 显式配置优先
configured = self._parse_layers(self.config.get("dws.mv.layers"))
if configured:
return configured
# 表级覆盖:优先 mv.table_layers其次 retention.table_layers
table_layers = self._resolve_layer_map(
self.config.get("dws.mv.table_layers") or self.config.get("dws.retention.table_layers")
)
layer_name = table_layers.get(self.BASE_TABLE)
if layer_name:
layer = self._get_layer(layer_name)
if layer and layer != TimeLayer.ALL:
return self._layers_up_to(layer)
# 默认使用 retention.layer
retention_layer = self._get_layer(self.config.get("dws.retention.layer"))
if retention_layer and retention_layer != TimeLayer.ALL:
return self._layers_up_to(retention_layer)
return list(self.LAYER_ORDER)
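# 配置解析示例(仅示意):dws.mv.layers="LAST_2_DAYS,LAST_1_MONTH" 时只刷新 L1/L2;
# 未显式配置时,若 dws.retention.layer=LAST_3_MONTHS,则按顺序刷新 L1~L3;
# 均未配置则刷新 LAYER_ORDER 中的全部四层。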
def _layers_up_to(self, target: TimeLayer) -> List[TimeLayer]:
layers = []
for layer in self.LAYER_ORDER:
layers.append(layer)
if layer == target:
break
return layers
def _get_view_name(self, layer: TimeLayer) -> Optional[str]:
suffix = self.LAYER_SUFFIX.get(layer)
if not suffix or not self.BASE_TABLE:
return None
return f"{self.VIEW_PREFIX}{self.BASE_TABLE}_{suffix}"
def _view_exists(self, view_name: str) -> bool:
sql = "SELECT to_regclass(%s) AS reg"
rows = self.db.query(sql, (f"{self.DWS_SCHEMA}.{view_name}",))
# 与本文件其他查询一致,先转 dict 再取字段,避免依赖行对象的具体类型
return bool(rows and dict(rows[0]).get("reg"))
def _refresh_view(self, view_name: str) -> None:
concurrently = bool(self.config.get("dws.mv.refresh_concurrently", False))
keyword = "CONCURRENTLY " if concurrently else ""
sql = f"REFRESH MATERIALIZED VIEW {keyword}{self.DWS_SCHEMA}.{view_name}"
self.db.execute(sql)
def _get_layer(self, layer_name: Optional[str]) -> Optional[TimeLayer]:
if not layer_name:
return None
name = str(layer_name).upper()
try:
return TimeLayer[name]
except KeyError:
return None
def _resolve_layer_map(self, raw: Any) -> Dict[str, str]:
if not raw:
return {}
if isinstance(raw, dict):
return {str(k): str(v) for k, v in raw.items()}
if isinstance(raw, str):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
return {str(k): str(v) for k, v in parsed.items()}
except json.JSONDecodeError:
return {}
return {}
def _parse_layers(self, raw: Any) -> List[TimeLayer]:
if not raw:
return []
if isinstance(raw, str):
items = [v.strip() for v in raw.split(",") if v.strip()]
elif isinstance(raw, (list, tuple, set)):
items = [str(v).strip() for v in raw if str(v).strip()]
else:
return []
layers = []
for item in items:
layer = self._get_layer(item)
if layer and layer not in layers:
layers.append(layer)
return layers
def _parse_list(self, raw: Any) -> List[str]:
if not raw:
return []
if isinstance(raw, str):
return [v.strip() for v in raw.split(",") if v.strip()]
if isinstance(raw, (list, tuple, set)):
return [str(v).strip() for v in raw if str(v).strip()]
return []
class DwsMvRefreshFinanceDailyTask(BaseMvRefreshTask):
BASE_TABLE = "dws_finance_daily_summary"
DATE_COL = "stat_date"
def get_task_code(self) -> str:
return "DWS_MV_REFRESH_FINANCE_DAILY"
class DwsMvRefreshAssistantDailyTask(BaseMvRefreshTask):
BASE_TABLE = "dws_assistant_daily_detail"
DATE_COL = "stat_date"
def get_task_code(self) -> str:
return "DWS_MV_REFRESH_ASSISTANT_DAILY"
__all__ = ["DwsMvRefreshFinanceDailyTask", "DwsMvRefreshAssistantDailyTask"]

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""
DWS 时间分层清理任务
功能说明:
按配置的时间分层范围,对 DWS 表执行历史数据清理。
该任务默认不启用,需通过配置显式开启。
配置示例(.env / settings):
DWS_RETENTION_ENABLED=true
DWS_RETENTION_LAYER=LAST_3_MONTHS
DWS_RETENTION_TABLES=dws_finance_daily_summary,dws_assistant_daily_detail
DWS_RETENTION_TABLE_LAYERS={"dws_finance_expense_summary":"ALL"}
作者:ETL团队
创建日期:2026-02-03
"""
from __future__ import annotations
import json
from datetime import date
from typing import Any, Dict, List, Optional
from .base_dws_task import BaseDwsTask, TaskContext, TimeLayer
class DwsRetentionCleanupTask(BaseDwsTask):
"""
DWS 时间分层清理任务
"""
DEFAULT_TABLES = [
{"table": "dws_assistant_daily_detail", "date_col": "stat_date"},
{"table": "dws_assistant_monthly_summary", "date_col": "stat_month"},
{"table": "dws_assistant_customer_stats", "date_col": "stat_date"},
{"table": "dws_assistant_salary_calc", "date_col": "salary_month"},
{"table": "dws_assistant_recharge_commission", "date_col": "commission_month"},
{"table": "dws_assistant_finance_analysis", "date_col": "stat_date"},
{"table": "dws_member_consumption_summary", "date_col": "stat_date"},
{"table": "dws_member_visit_detail", "date_col": "visit_date"},
{"table": "dws_finance_daily_summary", "date_col": "stat_date"},
{"table": "dws_finance_income_structure", "date_col": "stat_date"},
{"table": "dws_finance_discount_detail", "date_col": "stat_date"},
{"table": "dws_finance_recharge_summary", "date_col": "stat_date"},
{"table": "dws_finance_expense_summary", "date_col": "expense_month"},
{"table": "dws_platform_settlement", "date_col": "settlement_date"},
]
def get_task_code(self) -> str:
return "DWS_RETENTION_CLEANUP"
def get_target_table(self) -> str:
return "dws_finance_daily_summary"
def get_primary_keys(self) -> List[str]:
return []
def extract(self, context: TaskContext) -> Dict[str, Any]:
return {"site_id": context.store_id}
def transform(self, extracted: Dict[str, Any], context: TaskContext) -> Dict[str, Any]:
return extracted
def load(self, transformed: Dict[str, Any], context: TaskContext) -> Dict:
"""
执行清理逻辑
"""
if not self._is_retention_enabled():
self.logger.info("%s: 未启用清理配置,跳过", self.get_task_code())
return {"counts": {"cleaned": 0}}
base_date = context.window_end.date() if hasattr(context.window_end, "date") else context.window_end
default_layer = self._get_retention_layer(self.config.get("dws.retention.layer", "ALL"))
if default_layer is None:
self.logger.warning("%s: 未识别的清理层级,跳过", self.get_task_code())
return {"counts": {"cleaned": 0}}
target_tables = self._resolve_target_tables()
if not target_tables:
self.logger.info("%s: 未配置需要清理的表,跳过", self.get_task_code())
return {"counts": {"cleaned": 0}}
table_layers = self._resolve_table_layers()
total_deleted = 0
details = []
for item in target_tables:
table = item["table"]
date_col = item["date_col"]
layer_name = table_layers.get(table, default_layer.value)
layer = self._get_retention_layer(layer_name)
if layer is None or layer == TimeLayer.ALL:
continue
time_range = self.get_time_layer_range(layer, base_date)
cutoff = self._normalize_cutoff(date_col, time_range.start)
deleted = self._cleanup_table(table, date_col, cutoff, context.store_id)
total_deleted += deleted
details.append({"table": table, "deleted": deleted, "cutoff": str(cutoff)})
self.logger.info("%s: 清理完成,总删除 %d", self.get_task_code(), total_deleted)
return {"counts": {"cleaned": total_deleted}, "extra": {"details": details}}
def _is_retention_enabled(self) -> bool:
return bool(self.config.get("dws.retention.enabled", False))
def _get_retention_layer(self, layer_name: Optional[str]) -> Optional[TimeLayer]:
if not layer_name:
return None
name = str(layer_name).upper()
try:
return TimeLayer[name]
except KeyError:
return None
def _resolve_target_tables(self) -> List[Dict[str, str]]:
table_list = self.config.get("dws.retention.tables")
if not table_list:
return self.DEFAULT_TABLES
if isinstance(table_list, str):
names = [t.strip() for t in table_list.split(",") if t.strip()]
else:
names = list(table_list)
selected = []
for item in self.DEFAULT_TABLES:
if item["table"] in names:
selected.append(item)
return selected
def _resolve_table_layers(self) -> Dict[str, str]:
raw = self.config.get("dws.retention.table_layers")
if not raw:
return {}
if isinstance(raw, dict):
return {str(k): str(v) for k, v in raw.items()}
if isinstance(raw, str):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
return {str(k): str(v) for k, v in parsed.items()}
except json.JSONDecodeError:
return {}
return {}
def _normalize_cutoff(self, date_col: str, cutoff: date) -> date:
monthly_cols = {"stat_month", "salary_month", "commission_month", "expense_month"}
if date_col in monthly_cols:
return cutoff.replace(day=1)
return cutoff
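# 截断日期示例(仅示意,具体时间区间由 get_time_layer_range 决定):
# 若某表按 LAST_3_MONTHS 清理、得到的 cutoff 为 2025-11-18,
# 对 stat_date 等日粒度列直接删除 < 2025-11-18 的数据;
# 对 expense_month 等月粒度列先归一到 2025-11-01,保证整月数据同进同出。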
def _cleanup_table(self, table: str, date_col: str, cutoff: date, site_id: int) -> int:
full_table = f"{self.DWS_SCHEMA}.{table}"
sql = f"DELETE FROM {full_table} WHERE site_id = %s AND {date_col} < %s"
with self.db.conn.cursor() as cur:
cur.execute(sql, (site_id, cutoff))
deleted = cur.rowcount
# 与其余直接使用游标写库的任务保持一致:删除后显式提交,避免清理结果未落库
self.db.conn.commit()
return deleted
__all__ = ["DwsRetentionCleanupTask"]
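A quick aside on how this cleanup task is driven: everything hinges on the `dws.retention.*` configuration keys read above. A minimal sketch of those keys follows; only the key names come from the task code, while the values (and the layer name in the per-table override) are illustrative assumptions.

```python
# Illustrative only: key names are from DwsRetentionCleanupTask; values are examples,
# and "LAST_YEAR" is a hypothetical TimeLayer member used just to show the override shape.
retention_config = {
    "dws.retention.enabled": True,      # off by default, so the task exits early unless set
    "dws.retention.layer": "ALL",       # default layer; tables left at ALL are never cleaned
    "dws.retention.tables": "dws_finance_discount_detail,dws_platform_settlement",
    "dws.retention.table_layers": '{"dws_platform_settlement": "LAST_YEAR"}',
}
```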

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""ODS 层抓取任务"""

View File

@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*-
"""在线抓取 ODS 相关接口并落盘为 JSON用于后续离线回放/入库)。"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from api.client import APIClient
from models.parsers import TypeParser
from utils.json_store import dump_json, endpoint_to_filename
from tasks.base_task import BaseTask, TaskContext
@dataclass(frozen=True)
class EndpointSpec:
endpoint: str
window_style: str # site | start_end | range | pay | none
data_path: tuple[str, ...] = ("data",)
list_key: str | None = None
class OdsJsonArchiveTask(BaseTask):
"""
抓取一组 ODS 所需接口并落盘为“简化 JSON”
{"code": 0, "data": [...records...]}
说明:
- 该输出格式与 tasks/manual_ingest_task.py 的解析逻辑兼容;
- 默认每页一个文件,避免单文件过大;
- 结算小票(/Order/GetOrderSettleTicketNew)按 orderSettleId 分文件写入。
"""
ENDPOINTS: tuple[EndpointSpec, ...] = (
EndpointSpec("/MemberProfile/GetTenantMemberList", "site", list_key="tenantMemberInfos"),
EndpointSpec("/MemberProfile/GetTenantMemberCardList", "site", list_key="tenantMemberCards"),
EndpointSpec("/MemberProfile/GetMemberCardBalanceChange", "start_end"),
EndpointSpec("/PersonnelManagement/SearchAssistantInfo", "site", list_key="assistantInfos"),
EndpointSpec(
"/AssistantPerformance/GetOrderAssistantDetails",
"start_end",
list_key="orderAssistantDetails",
),
EndpointSpec(
"/AssistantPerformance/GetAbolitionAssistant",
"start_end",
list_key="abolitionAssistants",
),
EndpointSpec("/Table/GetSiteTables", "site", list_key="siteTables"),
EndpointSpec(
"/TenantGoodsCategory/QueryPrimarySecondaryCategory",
"site",
list_key="goodsCategoryList",
),
EndpointSpec("/TenantGoods/QueryTenantGoods", "site", list_key="tenantGoodsList"),
EndpointSpec("/TenantGoods/GetGoodsInventoryList", "site", list_key="orderGoodsList"),
EndpointSpec("/TenantGoods/GetGoodsStockReport", "site"),
EndpointSpec("/TenantGoods/GetGoodsSalesList", "start_end", list_key="orderGoodsLedgers"),
EndpointSpec(
"/PackageCoupon/QueryPackageCouponList",
"site",
list_key="packageCouponList",
),
EndpointSpec("/Site/GetSiteTableUseDetails", "start_end", list_key="siteTableUseDetailsList"),
EndpointSpec("/Site/GetSiteTableOrderDetails", "start_end", list_key="siteTableUseDetailsList"),
EndpointSpec("/Site/GetTaiFeeAdjustList", "start_end", list_key="taiFeeAdjustInfos"),
EndpointSpec(
"/GoodsStockManage/QueryGoodsOutboundReceipt",
"start_end",
list_key="queryDeliveryRecordsList",
),
EndpointSpec("/Promotion/GetOfflineCouponConsumePageList", "start_end"),
EndpointSpec("/Order/GetRefundPayLogList", "start_end"),
EndpointSpec("/Site/GetAllOrderSettleList", "range", list_key="settleList"),
EndpointSpec("/Site/GetRechargeSettleList", "range", list_key="settleList"),
EndpointSpec("/PayLog/GetPayLogListPage", "pay"),
)
TICKET_ENDPOINT = "/Order/GetOrderSettleTicketNew"
def get_task_code(self) -> str:
return "ODS_JSON_ARCHIVE"
def extract(self, context: TaskContext) -> dict:
base_client = getattr(self.api, "base", None) or self.api
if not isinstance(base_client, APIClient):
raise TypeError("ODS_JSON_ARCHIVE 需要 APIClient在线抓取")
output_dir = getattr(self.api, "output_dir", None)
if output_dir:
out = Path(output_dir)
else:
out = Path(self.config.get("pipeline.fetch_root") or self.config["pipeline"]["fetch_root"])
out.mkdir(parents=True, exist_ok=True)
write_pretty = bool(self.config.get("io.write_pretty_json", False))
page_size = int(self.config.get("api.page_size", 200) or 200)
store_id = int(context.store_id)
total_records = 0
ticket_ids: set[int] = set()
per_endpoint: list[dict] = []
self.logger.info(
"ODS_JSON_ARCHIVE: 开始抓取,窗口[%s ~ %s] 输出目录=%s",
context.window_start,
context.window_end,
out,
)
for spec in self.ENDPOINTS:
self.logger.info("ODS_JSON_ARCHIVE: 抓取 endpoint=%s", spec.endpoint)
built_params = self._build_params(
spec.window_style, store_id, context.window_start, context.window_end
)
# /TenantGoods/GetGoodsInventoryList 要求 siteId 为数组(标量会触发服务端异常,返回畸形状态行 HTTP/1.1 1400)
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
built_params["siteId"] = [store_id]
params = self._merge_common_params(built_params)
base_filename = endpoint_to_filename(spec.endpoint)
stem = Path(base_filename).stem
suffix = Path(base_filename).suffix or ".json"
endpoint_records = 0
endpoint_pages = 0
endpoint_error: str | None = None
try:
for page_no, records, _, _ in base_client.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
endpoint_pages += 1
total_records += len(records)
endpoint_records += len(records)
if spec.endpoint == "/PayLog/GetPayLogListPage":
for rec in records or []:
relate_id = TypeParser.parse_int(
(rec or {}).get("relateId")
or (rec or {}).get("orderSettleId")
or (rec or {}).get("order_settle_id")
)
if relate_id:
ticket_ids.add(relate_id)
out_path = out / f"{stem}__p{int(page_no):04d}{suffix}"
dump_json(out_path, {"code": 0, "data": records}, pretty=write_pretty)
except Exception as exc: # noqa: BLE001
endpoint_error = f"{type(exc).__name__}: {exc}"
self.logger.error("ODS_JSON_ARCHIVE: 接口抓取失败 endpoint=%s err=%s", spec.endpoint, endpoint_error)
per_endpoint.append(
{
"endpoint": spec.endpoint,
"file_stem": stem,
"pages": endpoint_pages,
"records": endpoint_records,
"error": endpoint_error,
}
)
if endpoint_error:
self.logger.warning(
"ODS_JSON_ARCHIVE: endpoint=%s 完成失败pages=%s records=%s err=%s",
spec.endpoint,
endpoint_pages,
endpoint_records,
endpoint_error,
)
else:
self.logger.info(
"ODS_JSON_ARCHIVE: endpoint=%s 完成 pages=%s records=%s",
spec.endpoint,
endpoint_pages,
endpoint_records,
)
# 小票详情:按 orderSettleId 获取
ticket_ids_sorted = sorted(ticket_ids)
self.logger.info("ODS_JSON_ARCHIVE: 小票候选数=%s", len(ticket_ids_sorted))
ticket_file_stem = Path(endpoint_to_filename(self.TICKET_ENDPOINT)).stem
ticket_file_suffix = Path(endpoint_to_filename(self.TICKET_ENDPOINT)).suffix or ".json"
ticket_records = 0
for order_settle_id in ticket_ids_sorted:
params = self._merge_common_params({"orderSettleId": int(order_settle_id)})
try:
records, _ = base_client.get_paginated(
endpoint=self.TICKET_ENDPOINT,
params=params,
page_size=None,
data_path=("data",),
list_key=None,
)
if not records:
continue
ticket_records += len(records)
out_path = out / f"{ticket_file_stem}__{int(order_settle_id)}{ticket_file_suffix}"
dump_json(out_path, {"code": 0, "data": records}, pretty=write_pretty)
except Exception as exc: # noqa: BLE001
self.logger.error(
"ODS_JSON_ARCHIVE: 小票抓取失败 orderSettleId=%s err=%s",
order_settle_id,
exc,
)
continue
total_records += ticket_records
manifest = {
"task": self.get_task_code(),
"store_id": store_id,
"window_start": context.window_start.isoformat(),
"window_end": context.window_end.isoformat(),
"page_size": page_size,
"total_records": total_records,
"ticket_ids": len(ticket_ids_sorted),
"ticket_records": ticket_records,
"endpoints": per_endpoint,
}
manifest_path = out / "manifest.json"
dump_json(manifest_path, manifest, pretty=True)
if hasattr(self.api, "last_dump"):
try:
self.api.last_dump = {"file": str(manifest_path), "records": total_records, "pages": None}
except Exception:
pass
self.logger.info("ODS_JSON_ARCHIVE: 抓取完成,总记录数=%s(含小票=%s", total_records, ticket_records)
return {"fetched": total_records, "ticket_ids": len(ticket_ids_sorted)}
def _build_params(self, window_style: str, store_id: int, window_start, window_end) -> dict:
if window_style == "none":
return {}
if window_style == "site":
return {"siteId": store_id}
if window_style == "range":
return {
"siteId": store_id,
"rangeStartTime": TypeParser.format_timestamp(window_start, self.tz),
"rangeEndTime": TypeParser.format_timestamp(window_end, self.tz),
}
if window_style == "pay":
return {
"siteId": store_id,
"StartPayTime": TypeParser.format_timestamp(window_start, self.tz),
"EndPayTime": TypeParser.format_timestamp(window_end, self.tz),
}
# 默认使用 startTime/endTime
return {
"siteId": store_id,
"startTime": TypeParser.format_timestamp(window_start, self.tz),
"endTime": TypeParser.format_timestamp(window_end, self.tz),
}
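For orientation, the archive this task writes is one `{stem}__pNNNN.json` file per page (envelope `{"code": 0, "data": [...]}`) plus a `manifest.json`. The sketch below re-reads all pages of one endpoint; the directory path and the example stem are assumptions, not values guaranteed by `endpoint_to_filename`.

```python
# Sketch only: read back the per-page files written by OdsJsonArchiveTask.
import json
from pathlib import Path

def iter_archived_records(archive_dir: str, endpoint_stem: str):
    """Yield records from every page file of one endpoint, in page order."""
    for page_file in sorted(Path(archive_dir).glob(f"{endpoint_stem}__p*.json")):
        payload = json.loads(page_file.read_text(encoding="utf-8"))
        yield from payload.get("data") or []

# e.g. rows = list(iter_archived_records("data/fetch/2026-02-14", "Site_GetAllOrderSettleList"))
```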

File diff suppressed because it is too large

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""工具类任务Schema 初始化、手动入库、数据完整性检查等)"""

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
"""Task: report last successful cursor cutoff times from etl_admin."""
from __future__ import annotations
from typing import Any
from tasks.base_task import BaseTask
class CheckCutoffTask(BaseTask):
"""Report per-task cursor cutoff times (etl_admin.etl_cursor.last_end)."""
def get_task_code(self) -> str:
return "CHECK_CUTOFF"
def execute(self, cursor_data: dict | None = None) -> dict:
store_id = int(self.config.get("app.store_id"))
filter_codes = self.config.get("run.cutoff_task_codes") or None
if isinstance(filter_codes, str):
filter_codes = [c.strip().upper() for c in filter_codes.split(",") if c.strip()]
sql = """
SELECT
t.task_code,
c.last_start,
c.last_end,
c.last_id,
c.last_run_id,
c.updated_at
FROM etl_admin.etl_task t
LEFT JOIN etl_admin.etl_cursor c
ON c.task_id = t.task_id AND c.store_id = t.store_id
WHERE t.store_id = %s
AND t.enabled = TRUE
ORDER BY t.task_code
"""
rows = self.db.query(sql, (store_id,))
if filter_codes:
wanted = {str(c).upper() for c in filter_codes}
rows = [r for r in rows if str(r.get("task_code", "")).upper() in wanted]
def _ts(v: Any) -> str:
return "-" if not v else str(v)
self.logger.info("截止时间检查: 门店ID=%s 启用任务数=%s", store_id, len(rows))
for r in rows:
self.logger.info(
"截止时间检查: %-24s 结束时间=%s 开始时间=%s 运行ID=%s",
str(r.get("task_code") or ""),
_ts(r.get("last_end")),
_ts(r.get("last_start")),
_ts(r.get("last_run_id")),
)
cutoff_candidates = [
r.get("last_end")
for r in rows
if r.get("last_end") is not None and not str(r.get("task_code", "")).upper().startswith("INIT_")
]
cutoff = min(cutoff_candidates) if cutoff_candidates else None
self.logger.info("截止时间检查: 总体截止时间(最小结束时间,排除INIT_*)=%s", _ts(cutoff))
ods_fetched = self._probe_ods_fetched_at(store_id)
if ods_fetched:
non_null = [v["max_fetched_at"] for v in ods_fetched.values() if v.get("max_fetched_at") is not None]
ods_cutoff = min(non_null) if non_null else None
self.logger.info("截止时间检查: ODS截止时间(最小抓取时间)=%s", _ts(ods_cutoff))
worst = sorted(
((k, v.get("max_fetched_at")) for k, v in ods_fetched.items()),
key=lambda kv: (kv[1] is None, kv[1]),
)[:8]
for table, mx in worst:
self.logger.info("截止时间检查: ODS表=%s 最大抓取时间=%s", table, _ts(mx))
dw_checks = self._probe_dw_time_columns()
for name, value in dw_checks.items():
self.logger.info("截止时间检查: %s=%s", name, _ts(value))
return {
"status": "SUCCESS",
"counts": {"fetched": len(rows), "inserted": 0, "updated": 0, "skipped": 0, "errors": 0},
"window": None,
"request_params": {"store_id": store_id, "filter_task_codes": filter_codes or []},
"report": {
"rows": rows,
"overall_cutoff": cutoff,
"ods_fetched_at": ods_fetched,
"dw_max_times": dw_checks,
},
}
def _probe_ods_fetched_at(self, store_id: int) -> dict[str, dict[str, Any]]:
try:
from tasks.dwd.dwd_load_task import DwdLoadTask # local import to avoid circulars
except Exception:
return {}
ods_tables = sorted({str(t) for t in DwdLoadTask.TABLE_MAP.values() if str(t).startswith("billiards_ods.")})
results: dict[str, dict[str, Any]] = {}
for table in ods_tables:
try:
row = self.db.query(f"SELECT MAX(fetched_at) AS mx, COUNT(*) AS cnt FROM {table}")[0]
results[table] = {"max_fetched_at": row.get("mx"), "count": row.get("cnt")}
except Exception as exc: # noqa: BLE001
results[table] = {"max_fetched_at": None, "count": None, "error": str(exc)}
return results
def _probe_dw_time_columns(self) -> dict[str, Any]:
checks: dict[str, Any] = {}
probes = {
"DWD.max_settlement_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_settlement_head",
"DWD.max_payment_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_payment",
"DWD.max_refund_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_refund",
"DWS.max_order_date": "SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary",
"DWS.max_updated_at": "SELECT MAX(updated_at) AS mx FROM billiards_dws.dws_order_summary",
}
for name, sql2 in probes.items():
try:
row = self.db.query(sql2)[0]
checks[name] = row.get("mx")
except Exception as exc: # noqa: BLE001
checks[name] = f"ERROR: {exc}"
return checks
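The "overall cutoff" reported above is just the minimum `last_end` over enabled tasks, ignoring `INIT_*` codes and tasks that have never run. A hand-made example of that rule (the rows are invented):

```python
from datetime import datetime

rows = [
    {"task_code": "INIT_ODS_SCHEMA", "last_end": datetime(2026, 2, 1)},           # excluded: INIT_*
    {"task_code": "DWD_LOAD_FROM_ODS", "last_end": datetime(2026, 2, 14, 6, 0)},
    {"task_code": "DWS_ASSISTANT_DAILY", "last_end": datetime(2026, 2, 13, 22, 0)},
    {"task_code": "ODS_JSON_ARCHIVE", "last_end": None},                           # excluded: never run
]
candidates = [
    r["last_end"]
    for r in rows
    if r["last_end"] is not None and not r["task_code"].upper().startswith("INIT_")
]
print(min(candidates) if candidates else None)  # 2026-02-13 22:00 — the slowest task defines the cutoff
```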

View File

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""Data integrity task that checks API -> ODS -> DWD completeness."""
from __future__ import annotations
from datetime import datetime
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from utils.windowing import build_window_segments, calc_window_minutes
from tasks.base_task import BaseTask
from quality.integrity_service import run_history_flow, run_window_flow, write_report
class DataIntegrityTask(BaseTask):
"""Check data completeness across API -> ODS -> DWD."""
def get_task_code(self) -> str:
return "DATA_INTEGRITY_CHECK"
def execute(self, cursor_data: dict | None = None) -> dict:
tz = ZoneInfo(self.config.get("app.timezone", "Asia/Shanghai"))
mode = str(self.config.get("integrity.mode", "history") or "history").lower()
include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
auto_backfill = bool(self.config.get("integrity.auto_backfill", False))
compare_content = self.config.get("integrity.compare_content")
if compare_content is None:
compare_content = True
content_sample_limit = self.config.get("integrity.content_sample_limit")
backfill_mismatch = self.config.get("integrity.backfill_mismatch")
if backfill_mismatch is None:
backfill_mismatch = True
recheck_after_backfill = self.config.get("integrity.recheck_after_backfill")
if recheck_after_backfill is None:
recheck_after_backfill = True
# 当提供 CLI 覆盖参数时,切换到窗口模式。
window_override_start = self.config.get("run.window_override.start")
window_override_end = self.config.get("run.window_override.end")
if window_override_start or window_override_end:
self.logger.info(
"Detected CLI window override. Switching to window mode: %s ~ %s",
window_override_start,
window_override_end,
)
mode = "window"
if mode == "window":
base_start, base_end, _ = self._get_time_window(cursor_data)
segments = build_window_segments(
self.config,
base_start,
base_end,
tz=tz,
override_only=True,
)
if not segments:
segments = [(base_start, base_end)]
total_segments = len(segments)
if total_segments > 1:
self.logger.info("Data integrity check split into %s segments.", total_segments)
report, counts = run_window_flow(
cfg=self.config,
windows=segments,
include_dimensions=include_dimensions,
task_codes=task_codes,
logger=self.logger,
compare_content=bool(compare_content),
content_sample_limit=content_sample_limit,
do_backfill=bool(auto_backfill),
include_mismatch=bool(backfill_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(self.config.get("api.page_size") or 200),
chunk_size=500,
)
overall_start = segments[0][0]
overall_end = segments[-1][1]
report_path = write_report(report, prefix="data_integrity_window", tz=tz)
report["report_path"] = report_path
return {
"status": "SUCCESS",
"counts": counts,
"window": {
"start": overall_start,
"end": overall_end,
"minutes": calc_window_minutes(overall_start, overall_end),
},
"report_path": report_path,
"backfill_result": report.get("backfill_result"),
}
history_start = str(self.config.get("integrity.history_start", "2025-07-01") or "2025-07-01")
history_end = str(self.config.get("integrity.history_end", "") or "").strip()
start_dt = dtparser.parse(history_start)
if start_dt.tzinfo is None:
start_dt = start_dt.replace(tzinfo=tz)
else:
start_dt = start_dt.astimezone(tz)
end_dt = None
if history_end:
end_dt = dtparser.parse(history_end)
if end_dt.tzinfo is None:
end_dt = end_dt.replace(tzinfo=tz)
else:
end_dt = end_dt.astimezone(tz)
report, counts = run_history_flow(
cfg=self.config,
start_dt=start_dt,
end_dt=end_dt,
include_dimensions=include_dimensions,
task_codes=task_codes,
logger=self.logger,
compare_content=bool(compare_content),
content_sample_limit=content_sample_limit,
do_backfill=bool(auto_backfill),
include_mismatch=bool(backfill_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(self.config.get("api.page_size") or 200),
chunk_size=500,
)
report_path = write_report(report, prefix="data_integrity_history", tz=tz)
report["report_path"] = report_path
end_dt_used = end_dt
if end_dt_used is None:
end_str = report.get("end")
if end_str:
parsed = dtparser.parse(end_str)
if parsed.tzinfo is None:
end_dt_used = parsed.replace(tzinfo=tz)
else:
end_dt_used = parsed.astimezone(tz)
if end_dt_used is None:
end_dt_used = start_dt
return {
"status": "SUCCESS",
"counts": counts,
"window": {
"start": start_dt,
"end": end_dt_used,
"minutes": int((end_dt_used - start_dt).total_seconds() // 60) if end_dt_used > start_dt else 0,
},
"report_path": report_path,
"backfill_result": report.get("backfill_result"),
}
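As with the other utility tasks, DATA_INTEGRITY_CHECK is configured rather than parameterised in code. A minimal sketch of the `integrity.*` keys it reads (key names and the defaults shown mirror the code above; the values are otherwise illustrative):

```python
integrity_config = {
    "integrity.mode": "history",          # "history" or "window"; a CLI window override forces "window"
    "integrity.history_start": "2025-07-01",
    "integrity.history_end": "",          # empty -> open-ended
    "integrity.include_dimensions": False,
    "integrity.ods_task_codes": "",       # optional comma-separated filter of ODS task codes
    "integrity.compare_content": True,    # compare content hashes, not just primary keys
    "integrity.auto_backfill": False,
    "integrity.backfill_mismatch": True,
    "integrity.recheck_after_backfill": True,
}
```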

View File

@@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
"""Build DWS order summary table from DWD fact tables."""
from __future__ import annotations
from datetime import date
from typing import Any
from tasks.base_task import BaseTask, TaskContext
from utils.windowing import build_window_segments, calc_window_minutes
# 原先从 scripts.rebuild.build_dws_order_summary 导入(脚本已归档,SQL 内联于此)
SQL_BUILD_SUMMARY = r"""
WITH base AS (
SELECT
sh.site_id,
sh.order_settle_id,
sh.order_trade_no,
COALESCE(sh.pay_time, sh.create_time)::date AS order_date,
sh.tenant_id,
sh.member_id,
COALESCE(sh.is_bind_member, FALSE) AS member_flag,
(COALESCE(sh.consume_money, 0) = 0 AND COALESCE(sh.pay_amount, 0) > 0) AS recharge_order_flag,
COALESCE(sh.member_discount_amount, 0) AS member_discount_amount,
COALESCE(sh.adjust_amount, 0) AS manual_discount_amount,
COALESCE(sh.pay_amount, 0) AS total_paid_amount,
COALESCE(sh.balance_amount, 0) + COALESCE(sh.recharge_card_amount, 0) + COALESCE(sh.gift_card_amount, 0) AS stored_card_deduct,
COALESCE(sh.coupon_amount, 0) AS total_coupon_deduction,
COALESCE(sh.table_charge_money, 0) AS settle_table_fee_amount,
COALESCE(sh.assistant_pd_money, 0) + COALESCE(sh.assistant_cx_money, 0) AS settle_assistant_service_amount,
COALESCE(sh.real_goods_money, 0) AS settle_goods_amount
FROM billiards_dwd.dwd_settlement_head sh
WHERE (%(site_id)s IS NULL OR sh.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR COALESCE(sh.pay_time, sh.create_time)::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR COALESCE(sh.pay_time, sh.create_time)::date <= %(end_date)s)
),
table_fee AS (
SELECT
site_id,
order_settle_id,
SUM(COALESCE(real_table_charge_money, 0)) AS table_fee_amount
FROM billiards_dwd.dwd_table_fee_log
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_use_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_use_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
assistant_fee AS (
SELECT
site_id,
order_settle_id,
SUM(COALESCE(ledger_amount, 0)) AS assistant_service_amount
FROM billiards_dwd.dwd_assistant_service_log
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_use_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_use_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
goods_fee AS (
SELECT
site_id,
order_settle_id,
COUNT(*) AS item_count,
SUM(COALESCE(ledger_count, 0)) AS total_item_quantity,
SUM(COALESCE(real_goods_money, 0)) AS goods_amount
FROM billiards_dwd.dwd_store_goods_sale
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR create_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR create_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
group_fee AS (
SELECT
site_id,
order_settle_id,
SUM(COALESCE(ledger_amount, 0)) AS group_amount
FROM billiards_dwd.dwd_groupbuy_redemption
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR create_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR create_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
refunds AS (
SELECT
r.site_id,
r.relate_id AS order_settle_id,
SUM(COALESCE(rx.refund_amount, 0)) AS refund_amount
FROM billiards_dwd.dwd_refund r
LEFT JOIN billiards_dwd.dwd_refund_ex rx ON r.refund_id = rx.refund_id
WHERE (%(site_id)s IS NULL OR r.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR r.pay_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR r.pay_time::date <= %(end_date)s)
GROUP BY r.site_id, r.relate_id
)
INSERT INTO billiards_dws.dws_order_summary (
site_id, order_settle_id, order_trade_no, order_date, tenant_id,
member_id, member_flag, recharge_order_flag,
item_count, total_item_quantity,
table_fee_amount, assistant_service_amount, goods_amount, group_amount,
total_coupon_deduction, member_discount_amount, manual_discount_amount,
order_original_amount, order_final_amount,
stored_card_deduct, external_paid_amount, total_paid_amount,
book_table_flow, book_assistant_flow, book_goods_flow, book_group_flow, book_order_flow,
order_effective_consume_cash, order_effective_recharge_cash, order_effective_flow,
refund_amount, net_income, created_at, updated_at
)
SELECT
b.site_id, b.order_settle_id, b.order_trade_no::text, b.order_date, b.tenant_id,
b.member_id, b.member_flag, b.recharge_order_flag,
COALESCE(gf.item_count, 0),
COALESCE(gf.total_item_quantity, 0),
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount),
COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount),
COALESCE(gf.goods_amount, b.settle_goods_amount),
COALESCE(gr.group_amount, 0),
b.total_coupon_deduction, b.member_discount_amount, b.manual_discount_amount,
(b.total_paid_amount + b.total_coupon_deduction + b.member_discount_amount + b.manual_discount_amount),
b.total_paid_amount,
b.stored_card_deduct,
GREATEST(b.total_paid_amount - b.stored_card_deduct, 0),
b.total_paid_amount,
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount),
COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount),
COALESCE(gf.goods_amount, b.settle_goods_amount),
COALESCE(gr.group_amount, 0),
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount)
+ COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount)
+ COALESCE(gf.goods_amount, b.settle_goods_amount)
+ COALESCE(gr.group_amount, 0),
GREATEST(b.total_paid_amount - b.stored_card_deduct, 0),
0,
b.total_paid_amount,
COALESCE(rf.refund_amount, 0),
b.total_paid_amount - COALESCE(rf.refund_amount, 0),
now(), now()
FROM base b
LEFT JOIN table_fee tf ON b.site_id = tf.site_id AND b.order_settle_id = tf.order_settle_id
LEFT JOIN assistant_fee af ON b.site_id = af.site_id AND b.order_settle_id = af.order_settle_id
LEFT JOIN goods_fee gf ON b.site_id = gf.site_id AND b.order_settle_id = gf.order_settle_id
LEFT JOIN group_fee gr ON b.site_id = gr.site_id AND b.order_settle_id = gr.order_settle_id
LEFT JOIN refunds rf ON b.site_id = rf.site_id AND b.order_settle_id = rf.order_settle_id
ON CONFLICT (site_id, order_settle_id) DO UPDATE SET
order_trade_no = EXCLUDED.order_trade_no,
order_date = EXCLUDED.order_date,
tenant_id = EXCLUDED.tenant_id,
member_id = EXCLUDED.member_id,
member_flag = EXCLUDED.member_flag,
recharge_order_flag = EXCLUDED.recharge_order_flag,
item_count = EXCLUDED.item_count,
total_item_quantity = EXCLUDED.total_item_quantity,
table_fee_amount = EXCLUDED.table_fee_amount,
assistant_service_amount = EXCLUDED.assistant_service_amount,
goods_amount = EXCLUDED.goods_amount,
group_amount = EXCLUDED.group_amount,
total_coupon_deduction = EXCLUDED.total_coupon_deduction,
member_discount_amount = EXCLUDED.member_discount_amount,
manual_discount_amount = EXCLUDED.manual_discount_amount,
order_original_amount = EXCLUDED.order_original_amount,
order_final_amount = EXCLUDED.order_final_amount,
stored_card_deduct = EXCLUDED.stored_card_deduct,
external_paid_amount = EXCLUDED.external_paid_amount,
total_paid_amount = EXCLUDED.total_paid_amount,
book_table_flow = EXCLUDED.book_table_flow,
book_assistant_flow = EXCLUDED.book_assistant_flow,
book_goods_flow = EXCLUDED.book_goods_flow,
book_group_flow = EXCLUDED.book_group_flow,
book_order_flow = EXCLUDED.book_order_flow,
order_effective_consume_cash = EXCLUDED.order_effective_consume_cash,
order_effective_recharge_cash = EXCLUDED.order_effective_recharge_cash,
order_effective_flow = EXCLUDED.order_effective_flow,
refund_amount = EXCLUDED.refund_amount,
net_income = EXCLUDED.net_income,
updated_at = now();
"""
class DwsBuildOrderSummaryTask(BaseTask):
"""Recompute/refresh `billiards_dws.dws_order_summary` for a date window."""
def get_task_code(self) -> str:
return "DWS_BUILD_ORDER_SUMMARY"
def execute(self, cursor_data: dict | None = None) -> dict:
base_context = self._build_context(cursor_data)
task_code = self.get_task_code()
segments = build_window_segments(
self.config,
base_context.window_start,
base_context.window_end,
tz=self.tz,
override_only=True,
)
if not segments:
segments = [(base_context.window_start, base_context.window_end)]
total_segments = len(segments)
if total_segments > 1:
self.logger.info("%s: 分段执行 共%s", task_code, total_segments)
total_counts: dict = {}
segment_results: list[dict] = []
request_params_list: list[dict] = []
total_deleted = 0
for idx, (window_start, window_end) in enumerate(segments, start=1):
context = self._build_context_for_window(window_start, window_end, cursor_data)
self.logger.info(
"%s: 开始执行(%s/%s), 窗口[%s ~ %s]",
task_code,
idx,
total_segments,
context.window_start,
context.window_end,
)
try:
extracted = self.extract(context)
transformed = self.transform(extracted, context)
load_result = self.load(transformed, context) or {}
self.db.commit()
except Exception:
self.db.rollback()
self.logger.error("%s: 执行失败", task_code, exc_info=True)
raise
counts = load_result.get("counts") or {}
self._accumulate_counts(total_counts, counts)
extra = load_result.get("extra") or {}
deleted = int(extra.get("deleted") or 0)
total_deleted += deleted
request_params = load_result.get("request_params")
if request_params:
request_params_list.append(request_params)
if total_segments > 1:
segment_results.append(
{
"window": {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
},
"counts": counts,
"extra": extra,
}
)
overall_start = segments[0][0]
overall_end = segments[-1][1]
result = {"status": "SUCCESS", "counts": total_counts}
result["window"] = {
"start": overall_start,
"end": overall_end,
"minutes": calc_window_minutes(overall_start, overall_end),
}
if segment_results:
result["segments"] = segment_results
if request_params_list:
result["request_params"] = request_params_list[0] if len(request_params_list) == 1 else request_params_list
if total_deleted:
result["extra"] = {"deleted": total_deleted}
self.logger.info("%s: 完成, 统计=%s", task_code, total_counts)
return result
def extract(self, context: TaskContext) -> dict[str, Any]:
store_id = int(self.config.get("app.store_id"))
full_refresh = bool(self.config.get("dws.order_summary.full_refresh", False))
site_id = self.config.get("dws.order_summary.site_id", store_id)
if site_id in ("", None, "null", "NULL"):
site_id = None
start_date = self.config.get("dws.order_summary.start_date")
end_date = self.config.get("dws.order_summary.end_date")
if not full_refresh:
if not start_date:
start_date = context.window_start.date()
if not end_date:
end_date = context.window_end.date()
else:
start_date = None
end_date = None
delete_before_insert = bool(self.config.get("dws.order_summary.delete_before_insert", True))
return {
"site_id": site_id,
"start_date": start_date,
"end_date": end_date,
"full_refresh": full_refresh,
"delete_before_insert": delete_before_insert,
}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
sql_params = {
"site_id": extracted["site_id"],
"start_date": extracted["start_date"],
"end_date": extracted["end_date"],
}
request_params = {
"site_id": extracted["site_id"],
"start_date": _jsonable_date(extracted["start_date"]),
"end_date": _jsonable_date(extracted["end_date"]),
}
with self.db.conn.cursor() as cur:
cur.execute("SELECT to_regclass('billiards_dws.dws_order_summary') AS reg;")
row = cur.fetchone()
reg = row[0] if row else None
if not reg:
raise RuntimeError("DWS 表不存在:请先运行任务 INIT_DWS_SCHEMA")
deleted = 0
if extracted["delete_before_insert"]:
if extracted["full_refresh"] and extracted["site_id"] is None:
cur.execute("TRUNCATE TABLE billiards_dws.dws_order_summary;")
self.logger.info("DWS订单汇总: 已清空 billiards_dws.dws_order_summary")
else:
delete_sql = "DELETE FROM billiards_dws.dws_order_summary WHERE 1=1"
delete_args: list[Any] = []
if extracted["site_id"] is not None:
delete_sql += " AND site_id = %s"
delete_args.append(extracted["site_id"])
if extracted["start_date"] is not None:
delete_sql += " AND order_date >= %s"
delete_args.append(_as_date(extracted["start_date"]))
if extracted["end_date"] is not None:
delete_sql += " AND order_date <= %s"
delete_args.append(_as_date(extracted["end_date"]))
cur.execute(delete_sql, delete_args)
deleted = cur.rowcount
self.logger.info("DWS订单汇总: 删除=%s 语句=%s", deleted, delete_sql)
cur.execute(SQL_BUILD_SUMMARY, sql_params)
affected = cur.rowcount
return {
"counts": {"fetched": 0, "inserted": affected, "updated": 0, "skipped": 0, "errors": 0},
"request_params": request_params,
"extra": {"deleted": deleted},
}
def _as_date(v: Any) -> date:
if isinstance(v, date):
return v
return date.fromisoformat(str(v))
def _jsonable_date(v: Any):
if v is None:
return None
if isinstance(v, date):
return v.isoformat()
return str(v)
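To make the derived money columns in `SQL_BUILD_SUMMARY` concrete, here is a hand-computed example of the three expressions that most often cause confusion (the numbers are invented, not pipeline output):

```python
pay_amount = 180.0          # total_paid_amount on the settlement head
stored_card_deduct = 50.0   # balance_amount + recharge_card_amount + gift_card_amount
coupon = 20.0               # total_coupon_deduction
member_discount = 10.0
manual_discount = 0.0

order_original_amount = pay_amount + coupon + member_discount + manual_discount  # 210.0
order_final_amount = pay_amount                                                  # 180.0
external_paid_amount = max(pay_amount - stored_card_deduct, 0.0)                 # 130.0, i.e. GREATEST(..., 0)
```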

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
"""初始化 DWD Schema执行 schema_dwd_doc.sql可选先 DROP SCHEMA。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class InitDwdSchemaTask(BaseTask):
"""通过调度执行 DWD schema 初始化。"""
def get_task_code(self) -> str:
"""返回任务编码。"""
return "INIT_DWD_SCHEMA"
def extract(self, context: TaskContext) -> dict[str, Any]:
"""读取 DWD SQL 文件与参数。"""
base_dir = Path(__file__).resolve().parents[1] / "database"
dwd_path = Path(self.config.get("schema.dwd_file", base_dir / "schema_dwd_doc.sql"))
if not dwd_path.exists():
raise FileNotFoundError(f"未找到 DWD schema 文件: {dwd_path}")
drop_first = self.config.get("dwd.drop_schema_first", False)
return {"dwd_sql": dwd_path.read_text(encoding="utf-8"), "dwd_file": str(dwd_path), "drop_first": drop_first}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
"""可选 DROP schema再执行 DWD DDL。"""
with self.db.conn.cursor() as cur:
if extracted["drop_first"]:
cur.execute("DROP SCHEMA IF EXISTS billiards_dwd CASCADE;")
self.logger.info("已执行 DROP SCHEMA billiards_dwd CASCADE")
self.logger.info("执行 DWD schema 文件: %s", extracted["dwd_file"])
cur.execute(extracted["dwd_sql"])
return {"executed": 1, "files": [extracted["dwd_file"]]}

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""Initialize DWS schema (billiards_dws)."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class InitDwsSchemaTask(BaseTask):
"""Apply DWS schema SQL."""
def get_task_code(self) -> str:
return "INIT_DWS_SCHEMA"
def extract(self, context: TaskContext) -> dict[str, Any]:
base_dir = Path(__file__).resolve().parents[1] / "database"
dws_path = Path(self.config.get("schema.dws_file", base_dir / "schema_dws.sql"))
if not dws_path.exists():
raise FileNotFoundError(f"未找到 DWS schema 文件: {dws_path}")
drop_first = bool(self.config.get("dws.drop_schema_first", False))
return {"dws_sql": dws_path.read_text(encoding="utf-8"), "dws_file": str(dws_path), "drop_first": drop_first}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
with self.db.conn.cursor() as cur:
if extracted["drop_first"]:
cur.execute("DROP SCHEMA IF EXISTS billiards_dws CASCADE;")
self.logger.info("已执行 DROP SCHEMA billiards_dws CASCADE")
self.logger.info("执行 DWS schema 文件: %s", extracted["dws_file"])
cur.execute(extracted["dws_sql"])
return {"executed": 1, "files": [extracted["dws_file"]]}

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""任务:初始化运行环境,执行 ODS 与 etl_admin 的 DDL并准备日志/导出目录。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class InitOdsSchemaTask(BaseTask):
"""通过调度执行初始化:创建必要目录,执行 ODS 与 etl_admin 的 DDL。"""
def get_task_code(self) -> str:
"""返回任务编码。"""
return "INIT_ODS_SCHEMA"
def extract(self, context: TaskContext) -> dict[str, Any]:
"""读取 SQL 文件路径,收集需创建的目录。"""
base_dir = Path(__file__).resolve().parents[1] / "database"
ods_path = Path(self.config.get("schema.ods_file", base_dir / "schema_ODS_doc.sql"))
admin_path = Path(self.config.get("schema.etl_admin_file", base_dir / "schema_etl_admin.sql"))
if not ods_path.exists():
raise FileNotFoundError(f"找不到 ODS schema 文件: {ods_path}")
if not admin_path.exists():
raise FileNotFoundError(f"找不到 etl_admin schema 文件: {admin_path}")
log_root = Path(self.config.get("io.log_root") or self.config["io"]["log_root"])
export_root = Path(self.config.get("io.export_root") or self.config["io"]["export_root"])
fetch_root = Path(self.config.get("pipeline.fetch_root") or self.config["pipeline"]["fetch_root"])
ingest_dir = Path(self.config.get("pipeline.ingest_source_dir") or fetch_root)
return {
"ods_sql": ods_path.read_text(encoding="utf-8"),
"admin_sql": admin_path.read_text(encoding="utf-8"),
"ods_file": str(ods_path),
"admin_file": str(admin_path),
"dirs": [log_root, export_root, fetch_root, ingest_dir],
}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
"""执行 DDL 并创建必要目录。
安全提示:
ODS DDL 文件可能携带头部说明或异常注释,为避免因非 SQL 文本导致执行失败,这里会做一次轻量清洗后再执行。
"""
for d in extracted["dirs"]:
Path(d).mkdir(parents=True, exist_ok=True)
self.logger.info("已确保目录存在: %s", d)
# 处理 ODS SQL:去掉头部说明行,以及易出错的 COMMENT ON 行(如 CamelCase 未加引号)
ods_sql_raw: str = extracted["ods_sql"]
drop_idx = ods_sql_raw.find("DROP SCHEMA")
if drop_idx > 0:
ods_sql_raw = ods_sql_raw[drop_idx:]
cleaned_lines: list[str] = []
for line in ods_sql_raw.splitlines():
if line.strip().upper().startswith("COMMENT ON "):
continue
cleaned_lines.append(line)
ods_sql = "\n".join(cleaned_lines)
with self.db.conn.cursor() as cur:
self.logger.info("执行 etl_admin schema 文件: %s", extracted["admin_file"])
cur.execute(extracted["admin_sql"])
self.logger.info("执行 ODS schema 文件: %s", extracted["ods_file"])
cur.execute(ods_sql)
return {
"executed": 2,
"files": [extracted["admin_file"], extracted["ods_file"]],
"dirs_prepared": [str(p) for p in extracted["dirs"]],
}
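The light-weight cleanup in `load()` only does two things: cut everything before the first `DROP SCHEMA`, and drop `COMMENT ON` lines. A small before/after sketch (the DDL fragment is made up):

```python
raw_sql = """-- header notes that are not valid SQL
DROP SCHEMA IF EXISTS billiards_ods CASCADE;
CREATE SCHEMA billiards_ods;
COMMENT ON TABLE billiards_ods.member_profiles IS 'demo';
CREATE TABLE billiards_ods.member_profiles (id bigint PRIMARY KEY);
"""

idx = raw_sql.find("DROP SCHEMA")
cleaned = raw_sql[idx:] if idx > 0 else raw_sql            # strip the non-SQL header
cleaned = "\n".join(
    line for line in cleaned.splitlines()
    if not line.strip().upper().startswith("COMMENT ON ")   # drop fragile COMMENT ON lines
)
print(cleaned)
```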

View File

@@ -0,0 +1,463 @@
# -*- coding: utf-8 -*-
"""手工示例数据灌入:按 schema_ODS_doc.sql 的表结构写入 ODS。"""
from __future__ import annotations
import hashlib
import json
import os
from datetime import datetime
from typing import Any, Iterable
from psycopg2.extras import Json, execute_values
from tasks.base_task import BaseTask
class ManualIngestTask(BaseTask):
"""本地示例 JSON 灌入 ODS确保表名/主键/插入列与 schema_ODS_doc.sql 对齐。"""
FILE_MAPPING: list[tuple[tuple[str, ...], str]] = [
(("member_profiles",), "billiards_ods.member_profiles"),
(("member_balance_changes",), "billiards_ods.member_balance_changes"),
(("member_stored_value_cards",), "billiards_ods.member_stored_value_cards"),
(("recharge_settlements",), "billiards_ods.recharge_settlements"),
(("settlement_records",), "billiards_ods.settlement_records"),
(("assistant_cancellation_records",), "billiards_ods.assistant_cancellation_records"),
(("assistant_accounts_master",), "billiards_ods.assistant_accounts_master"),
(("assistant_service_records",), "billiards_ods.assistant_service_records"),
(("site_tables_master",), "billiards_ods.site_tables_master"),
(("table_fee_discount_records",), "billiards_ods.table_fee_discount_records"),
(("table_fee_transactions",), "billiards_ods.table_fee_transactions"),
(("goods_stock_movements",), "billiards_ods.goods_stock_movements"),
(("stock_goods_category_tree",), "billiards_ods.stock_goods_category_tree"),
(("goods_stock_summary",), "billiards_ods.goods_stock_summary"),
(("payment_transactions",), "billiards_ods.payment_transactions"),
(("refund_transactions",), "billiards_ods.refund_transactions"),
(("platform_coupon_redemption_records",), "billiards_ods.platform_coupon_redemption_records"),
(("group_buy_redemption_records",), "billiards_ods.group_buy_redemption_records"),
(("group_buy_packages",), "billiards_ods.group_buy_packages"),
(("settlement_ticket_details",), "billiards_ods.settlement_ticket_details"),
(("store_goods_master",), "billiards_ods.store_goods_master"),
(("tenant_goods_master",), "billiards_ods.tenant_goods_master"),
(("store_goods_sales_records",), "billiards_ods.store_goods_sales_records"),
]
TABLE_SPECS: dict[str, dict[str, Any]] = {
"billiards_ods.member_profiles": {"pk": "id"},
"billiards_ods.member_balance_changes": {"pk": "id"},
"billiards_ods.member_stored_value_cards": {"pk": "id"},
"billiards_ods.recharge_settlements": {"pk": "id"},
"billiards_ods.settlement_records": {"pk": "id"},
"billiards_ods.assistant_cancellation_records": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.assistant_accounts_master": {"pk": "id"},
"billiards_ods.assistant_service_records": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.site_tables_master": {"pk": "id"},
"billiards_ods.table_fee_discount_records": {"pk": "id", "json_cols": ["siteProfile", "tableProfile"]},
"billiards_ods.table_fee_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.goods_stock_movements": {"pk": "siteGoodsStockId"},
"billiards_ods.stock_goods_category_tree": {"pk": "id", "json_cols": ["categoryBoxes"]},
"billiards_ods.goods_stock_summary": {"pk": "siteGoodsId"},
"billiards_ods.payment_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.refund_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.platform_coupon_redemption_records": {"pk": "id"},
"billiards_ods.tenant_goods_master": {"pk": "id"},
"billiards_ods.group_buy_packages": {"pk": "id"},
"billiards_ods.group_buy_redemption_records": {"pk": "id"},
"billiards_ods.settlement_ticket_details": {
"pk": "orderSettleId",
"json_cols": ["memberProfile", "orderItem", "tenantMemberCardLogs"],
},
"billiards_ods.store_goods_master": {"pk": "id"},
"billiards_ods.store_goods_sales_records": {"pk": "id"},
}
def get_task_code(self) -> str:
"""返回任务编码。"""
return "MANUAL_INGEST"
def execute(self, cursor_data: dict | None = None) -> dict:
"""从目录读取 JSON按表定义批量入库按文件提交事务避免长事务导致连接不稳定"""
data_dir = (
self.config.get("manual.data_dir")
or self.config.get("pipeline.ingest_source_dir")
or os.path.join("tests", "testdata_json")
)
if not os.path.exists(data_dir):
self.logger.error("Data directory not found: %s", data_dir)
return {"status": "error", "message": "Directory not found"}
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
include_files_cfg = self.config.get("manual.include_files") or []
include_files = {str(x).strip().lower() for x in include_files_cfg if str(x).strip()} if include_files_cfg else set()
for filename in sorted(os.listdir(data_dir)):
if not filename.endswith(".json"):
continue
stem = os.path.splitext(filename)[0].lower()
if include_files and stem not in include_files:
continue
filepath = os.path.join(data_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as fh:
raw_entries = json.load(fh)
except Exception:
counts["errors"] += 1
self.logger.exception("Failed to read %s", filename)
continue
entries = raw_entries if isinstance(raw_entries, list) else [raw_entries]
records = self._extract_records(entries)
if not records:
counts["skipped"] += 1
continue
target_table = self._match_by_filename(filename)
if not target_table:
self.logger.warning("No mapping found for file: %s", filename)
counts["skipped"] += 1
continue
self.logger.info("Ingesting %s into %s", filename, target_table)
try:
inserted, updated, row_errors = self._ingest_table(target_table, records, filename)
counts["inserted"] += inserted
counts["updated"] += updated
counts["fetched"] += len(records)
counts["errors"] += row_errors
# 每个文件一次提交:降低单次事务体积,避免长事务/连接异常导致整体回滚失败。
self.db.commit()
except Exception:
counts["errors"] += 1
self.logger.exception("Error processing %s", filename)
try:
self.db.rollback()
except Exception:
pass
# 若连接已断开,后续文件无法继续,直接抛出让上层处理(重连/重跑)。
if getattr(self.db.conn, "closed", 0):
raise
continue
return {"status": "SUCCESS", "counts": counts}
def _match_by_filename(self, filename: str) -> str | None:
"""根据文件名关键字匹配目标表。"""
for keywords, table in self.FILE_MAPPING:
if any(keyword and keyword in filename for keyword in keywords):
return table
return None
def _extract_records(self, raw_entries: Iterable[Any]) -> list[dict]:
"""兼容多层 data/list 包装,抽取记录列表。"""
records: list[dict] = []
for entry in raw_entries:
if isinstance(entry, dict):
preferred = entry
if "data" in entry and not any(k not in {"data", "code"} for k in entry.keys()):
preferred = entry["data"]
data = preferred
if isinstance(data, dict):
# 特殊处理 settleList(充值、结算记录):展开 data.settleList 下的 settleList,抛弃上层 siteProfile
if "settleList" in data:
settle_list_val = data.get("settleList")
if isinstance(settle_list_val, dict):
settle_list_iter = [settle_list_val]
elif isinstance(settle_list_val, list):
settle_list_iter = settle_list_val
else:
settle_list_iter = []
handled = False
for item in settle_list_iter or []:
if not isinstance(item, dict):
continue
inner = item.get("settleList")
merged = dict(inner) if isinstance(inner, dict) else dict(item)
# 保留 siteProfile 供后续字段补充,但不落库
site_profile = data.get("siteProfile")
if isinstance(site_profile, dict):
merged.setdefault("siteProfile", site_profile)
records.append(merged)
handled = True
if handled:
continue
list_used = False
for v in data.values():
if isinstance(v, list) and v and isinstance(v[0], dict):
records.extend(v)
list_used = True
break
if list_used:
continue
if isinstance(data, list) and data and isinstance(data[0], dict):
records.extend(data)
elif isinstance(data, dict):
records.append(data)
elif isinstance(entry, list):
records.extend([item for item in entry if isinstance(item, dict)])
return records
def _get_table_columns(self, table: str) -> list[tuple[str, str, str]]:
"""查询 information_schema获取目标表列信息。"""
cache = getattr(self, "_table_columns_cache", {})
if table in cache:
return cache[table]
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
cache[table] = cols
self._table_columns_cache = cache
return cols
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int, int]:
"""
构建 INSERT/ON CONFLICT 语句并批量执行(优先向量化,小批次提交)。
设计目标:
- 控制单条 SQL 体积(避免一次性 VALUES 过大导致服务端 backend 被 OOM/异常终止);
- 发生异常时,可降级逐行并用 SAVEPOINT 跳过异常行;
- 统计口径偏“尽量可跑通”,插入/更新计数为近似值(不强依赖 RETURNING)
"""
spec = self.TABLE_SPECS.get(table)
if not spec:
raise ValueError(f"No table spec for {table}")
pk_col = spec.get("pk")
json_cols = set(spec.get("json_cols", []))
json_cols_lower = {c.lower() for c in json_cols}
columns_info = self._get_table_columns(table)
columns = [c[0] for c in columns_info]
db_json_cols_lower = {
c[0].lower() for c in columns_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
pk_col_db = None
if pk_col:
pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)
pk_index = None
if pk_col_db:
try:
pk_index = next(i for i, c in enumerate(columns_info) if c[0] == pk_col_db)
except Exception:
pk_index = None
has_content_hash = any(c[0].lower() == "content_hash" for c in columns_info)
col_list = ", ".join(f'"{c}"' for c in columns)
sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
if pk_col_db:
if has_content_hash:
sql_prefix += f' ON CONFLICT ("{pk_col_db}", "content_hash") DO NOTHING'
else:
update_cols = [c for c in columns if c != pk_col_db]
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
params = []
now = datetime.now()
json_dump = lambda v: json.dumps(v, ensure_ascii=False) # noqa: E731
for rec in records:
merged_rec = rec if isinstance(rec, dict) else {}
data_part = merged_rec.get("data")
while isinstance(data_part, dict):
merged_rec = {**data_part, **merged_rec}
data_part = data_part.get("data")
# 针对充值/结算,补齐 siteProfile 中的店铺信息
if table in {
"billiards_ods.recharge_settlements",
"billiards_ods.settlement_records",
}:
site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile")
if isinstance(site_profile, dict):
merged_rec.setdefault("tenantid", site_profile.get("tenant_id") or site_profile.get("tenantId"))
merged_rec.setdefault("siteid", site_profile.get("id") or site_profile.get("siteId"))
merged_rec.setdefault("sitename", site_profile.get("shop_name") or site_profile.get("siteName"))
pk_val = self._get_value_case_insensitive(merged_rec, pk_col) if pk_col else None
if pk_col and (pk_val is None or pk_val == ""):
continue
content_hash = None
if has_content_hash:
# Keep hash semantics aligned with ODS task ingestion:
# fetched_at is ETL metadata and should not create a new content version.
content_hash = self._compute_content_hash(merged_rec, include_fetched_at=False)
row_vals = []
for col_name, data_type, udt in columns_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append(source_file)
continue
if col_lower == "fetched_at":
row_vals.append(merged_rec.get(col_name, now))
continue
if col_lower == "content_hash":
row_vals.append(content_hash)
continue
value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
if col_lower in json_cols_lower or col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
casted = self._cast_value(value, data_type)
row_vals.append(casted)
params.append(tuple(row_vals))
if not params:
return 0, 0, 0
# 先尝试向量化执行(速度快);若失败,再降级逐行并用 SAVEPOINT 跳过异常行。
try:
with self.db.conn.cursor() as cur:
# 分批提交:降低单次事务/单次 SQL 压力,避免服务端异常中断连接。
affected = 0
chunk_size = int(self.config.get("manual.execute_values_page_size", 50) or 50)
chunk_size = max(1, min(chunk_size, 500))
for i in range(0, len(params), chunk_size):
chunk = params[i : i + chunk_size]
execute_values(cur, sql_prefix, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
affected += int(cur.rowcount)
# 这里无法精确拆分 inserted/updated(除非 RETURNING),按“受影响行数≈插入”近似返回。
return int(affected), 0, 0
except Exception as exc:
self.logger.warning("批量入库失败准备降级逐行处理table=%s, err=%s", table, exc)
try:
self.db.rollback()
except Exception:
pass
inserted = 0
updated = 0
errors = 0
with self.db.conn.cursor() as cur:
for row in params:
cur.execute("SAVEPOINT sp_manual_ingest_row")
try:
cur.execute(sql_prefix.replace(" VALUES %s", f" VALUES ({', '.join(['%s'] * len(row))})"), row)
inserted += 1
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
except Exception as exc: # noqa: BLE001
errors += 1
try:
cur.execute("ROLLBACK TO SAVEPOINT sp_manual_ingest_row")
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
except Exception:
pass
pk_val = None
if pk_index is not None:
try:
pk_val = row[pk_index]
except Exception:
pk_val = None
self.logger.warning("跳过异常行table=%s pk=%s err=%s", table, pk_val, exc)
return inserted, updated, errors
@staticmethod
def _get_value_case_insensitive(record: dict, col: str | None):
"""忽略大小写获取值,兼容 information_schema 与 JSON 原始字段。"""
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
@staticmethod
def _normalize_scalar(value):
"""将空字符串/空 JSON 规范为 None避免类型转换错误。"""
if value == "" or value == "{}" or value == "[]":
return None
return value
@staticmethod
def _cast_value(value, data_type: str):
"""根据列类型做简单转换,保证批量插入兼容。"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, str) else None
return value
@staticmethod
def _hash_default(value):
if isinstance(value, datetime):
return value.isoformat()
return str(value)
@classmethod
def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
exclude = {
"data",
"payload",
"source_file",
"source_endpoint",
"content_hash",
"record_index",
}
if not include_fetched_at:
exclude.add("fetched_at")
def _strip(value):
if isinstance(value, dict):
cleaned = {}
for k, v in value.items():
if isinstance(k, str) and k.lower() in exclude:
continue
cleaned[k] = _strip(v)
return cleaned
if isinstance(value, list):
return [_strip(v) for v in value]
return value
return _strip(record or {})
@classmethod
def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
cleaned = cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at)
payload = json.dumps(
cleaned,
ensure_ascii=False,
sort_keys=True,
separators=(",", ":"),
default=cls._hash_default,
)
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
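The hash helpers above strip ETL metadata (`fetched_at`, `payload`, `source_file`, …) before hashing, so re-fetching an unchanged business record does not produce a new content version. A quick sketch of that property (field values are invented; run where `ManualIngestTask` is importable):

```python
rec_a = {"id": 1, "payAmount": "88.00", "fetched_at": "2026-02-14T10:00:00+08:00"}
rec_b = {"id": 1, "payAmount": "88.00", "fetched_at": "2026-02-15T03:00:00+08:00"}

h_a = ManualIngestTask._compute_content_hash(rec_a, include_fetched_at=False)
h_b = ManualIngestTask._compute_content_hash(rec_b, include_fetched_at=False)
assert h_a == h_b  # only fetched_at differs, so the content hash is identical
```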

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""
DWS配置数据初始化任务
功能说明:
执行 seed_dws_config.sql,向配置表插入初始数据
执行前提:
- billiards_dws schema 已创建(INIT_DWS_SCHEMA)
- 配置表已存在
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class SeedDwsConfigTask(BaseTask):
"""
DWS配置数据初始化任务
执行 seed_dws_config.sql 文件,向以下配置表插入初始数据:
- cfg_performance_tier: 绩效档位配置
- cfg_assistant_level_price: 助教等级定价
- cfg_bonus_rules: 奖金规则配置
- cfg_area_category: 台区分类映射
- cfg_skill_type: 技能课程类型映射
"""
def get_task_code(self) -> str:
return "SEED_DWS_CONFIG"
def extract(self, context: TaskContext) -> dict[str, Any]:
"""
读取配置数据SQL文件
"""
base_dir = Path(__file__).resolve().parents[1] / "database"
seed_path = Path(self.config.get("schema.seed_dws_file", base_dir / "seed_dws_config.sql"))
if not seed_path.exists():
raise FileNotFoundError(f"未找到 DWS 配置数据文件: {seed_path}")
return {
"seed_sql": seed_path.read_text(encoding="utf-8"),
"seed_file": str(seed_path)
}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
"""
执行配置数据SQL
"""
with self.db.conn.cursor() as cur:
self.logger.info("执行 DWS 配置数据文件: %s", extracted["seed_file"])
cur.execute(extracted["seed_sql"])
self.logger.info("DWS 配置数据初始化完成")
return {"executed": 1, "files": [extracted["seed_file"]]}

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""批量后置校验框架
提供各层数据的批量校验和补齐功能:
- ODS 层:主键 + content_hash 对比,批量 UPSERT
- DWD 层:维度 SCD2 / 事实主键对比,批量 UPSERT
- DWS 层:聚合对比,批量重算 UPSERT
- INDEX 层:实体覆盖对比,批量重算 UPSERT
"""
from .models import (
VerificationResult,
VerificationSummary,
VerificationStatus,
WindowSegment,
build_window_segments,
filter_verify_tables,
)
from .base_verifier import BaseVerifier
from .ods_verifier import OdsVerifier
from .dwd_verifier import DwdVerifier
from .dws_verifier import DwsVerifier
from .index_verifier import IndexVerifier
__all__ = [
# 模型
"VerificationResult",
"VerificationSummary",
"VerificationStatus",
"WindowSegment",
"build_window_segments",
"filter_verify_tables",
# 校验器
"BaseVerifier",
"OdsVerifier",
"DwdVerifier",
"DwsVerifier",
"IndexVerifier",
]
def get_verifier_for_layer(layer: str, db_connection, logger=None, **kwargs):
"""
根据层名获取对应的校验器实例
Args:
layer: 层名 ("ODS", "DWD", "DWS", "INDEX")
db_connection: 数据库连接
logger: 日志器
**kwargs: 额外参数
- api_client: API 客户端(ODS 层需要)
- fetch_from_api: 是否从 API 获取源数据(ODS 层需要)
- local_dump_dirs: 本地 JSON dump 目录映射(ODS 层需要)
- use_local_json: 是否优先使用本地 JSON(ODS 层需要)
Returns:
对应的校验器实例
"""
verifier_map = {
"ODS": OdsVerifier,
"DWD": DwdVerifier,
"DWS": DwsVerifier,
"INDEX": IndexVerifier,
}
verifier_class = verifier_map.get(layer.upper())
if verifier_class is None:
raise ValueError(f"未知的数据层: {layer}")
# ODS 层支持额外参数
if layer.upper() == "ODS":
api_client = kwargs.pop("api_client", None)
fetch_from_api = kwargs.pop("fetch_from_api", False)
local_dump_dirs = kwargs.pop("local_dump_dirs", None)
use_local_json = kwargs.pop("use_local_json", False)
return verifier_class(
db_connection,
api_client=api_client,
logger=logger,
fetch_from_api=fetch_from_api,
local_dump_dirs=local_dump_dirs,
use_local_json=use_local_json,
**kwargs
)
return verifier_class(db_connection, logger=logger, **kwargs)
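A minimal usage sketch of this factory, assuming a `db_connection` and `logger` are already available (the window values and the chosen layer are arbitrary):

```python
from datetime import datetime
# from tasks.verification import get_verifier_for_layer  # import path assumed from this package layout

verifier = get_verifier_for_layer("DWD", db_connection, logger=logger)
summary = verifier.verify_and_backfill(
    window_start=datetime(2026, 2, 1),
    window_end=datetime(2026, 2, 15),
    split_unit="month",     # "none" / "day" / "week" / "month"
    tables=None,            # None -> every table the verifier reports
    auto_backfill=True,
    compare_content=True,
)
print(summary.format_summary())
```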

View File

@@ -0,0 +1,382 @@
# -*- coding: utf-8 -*-
"""批量校验基类"""
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from .models import (
VerificationResult,
VerificationSummary,
VerificationStatus,
WindowSegment,
build_window_segments,
)
class VerificationFetchError(RuntimeError):
"""校验数据获取失败(用于显式标记 ERROR"""
class BaseVerifier(ABC):
"""批量校验基类
提供统一的校验流程:
1. 切分时间窗口
2. 批量读取源数据
3. 批量读取目标数据
4. 内存对比
5. 批量补齐
"""
def __init__(
self,
db_connection: Any,
logger: Optional[logging.Logger] = None,
):
"""
初始化校验器
Args:
db_connection: 数据库连接
logger: 日志器
"""
self.db = db_connection
self.logger = logger or logging.getLogger(self.__class__.__name__)
@property
@abstractmethod
def layer_name(self) -> str:
"""数据层名称"""
pass
@abstractmethod
def get_tables(self) -> List[str]:
"""获取需要校验的表列表"""
pass
@abstractmethod
def get_primary_keys(self, table: str) -> List[str]:
"""获取表的主键列"""
pass
@abstractmethod
def get_time_column(self, table: str) -> Optional[str]:
"""获取表的时间列(用于窗口过滤)"""
pass
@abstractmethod
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""批量获取源数据主键集合"""
pass
@abstractmethod
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""批量获取目标数据主键集合"""
pass
@abstractmethod
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""批量获取源数据主键->内容哈希映射"""
pass
@abstractmethod
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""批量获取目标数据主键->内容哈希映射"""
pass
@abstractmethod
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""批量补齐缺失数据,返回补齐的记录数"""
pass
@abstractmethod
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""批量更新不一致数据,返回更新的记录数"""
pass
def verify_table(
self,
table: str,
window_start: datetime,
window_end: datetime,
auto_backfill: bool = False,
compare_content: bool = True,
) -> VerificationResult:
"""
校验单表
Args:
table: 表名
window_start: 窗口开始
window_end: 窗口结束
auto_backfill: 是否自动补齐
compare_content: 是否对比内容(True=对比 hash,False=仅对比主键)
Returns:
校验结果
"""
start_time = time.time()
result = VerificationResult(
layer=self.layer_name,
table=table,
window_start=window_start,
window_end=window_end,
)
try:
# 确保连接可用,避免“connection already closed”导致误判 OK
self._ensure_connection()
self.logger.info(
"%s 校验开始: %s [%s ~ %s]",
self.layer_name, table,
window_start.strftime("%Y-%m-%d %H:%M"),
window_end.strftime("%Y-%m-%d %H:%M")
)
if compare_content:
# 对比内容哈希
source_hashes = self.fetch_source_hashes(table, window_start, window_end)
target_hashes = self.fetch_target_hashes(table, window_start, window_end)
result.source_count = len(source_hashes)
result.target_count = len(target_hashes)
source_keys = set(source_hashes.keys())
target_keys = set(target_hashes.keys())
# 计算缺失
missing_keys = source_keys - target_keys
result.missing_count = len(missing_keys)
# 计算不一致(两边都有但 hash 不同)
common_keys = source_keys & target_keys
mismatch_keys = {
k for k in common_keys
if source_hashes[k] != target_hashes[k]
}
result.mismatch_count = len(mismatch_keys)
else:
# 仅对比主键
source_keys = self.fetch_source_keys(table, window_start, window_end)
target_keys = self.fetch_target_keys(table, window_start, window_end)
result.source_count = len(source_keys)
result.target_count = len(target_keys)
missing_keys = source_keys - target_keys
result.missing_count = len(missing_keys)
mismatch_keys = set()
# 判断状态
if result.missing_count > 0:
result.status = VerificationStatus.MISSING
elif result.mismatch_count > 0:
result.status = VerificationStatus.MISMATCH
else:
result.status = VerificationStatus.OK
# 自动补齐
if auto_backfill and (missing_keys or mismatch_keys):
backfill_missing_count = 0
backfill_mismatch_count = 0
if missing_keys:
self.logger.info(
"%s 补齐缺失: %s, 数量=%d",
self.layer_name, table, len(missing_keys)
)
backfill_missing_count += self.backfill_missing(
table, missing_keys, window_start, window_end
)
if mismatch_keys:
self.logger.info(
"%s 更新不一致: %s, 数量=%d",
self.layer_name, table, len(mismatch_keys)
)
backfill_mismatch_count += self.backfill_mismatch(
table, mismatch_keys, window_start, window_end
)
result.backfilled_missing_count = backfill_missing_count
result.backfilled_mismatch_count = backfill_mismatch_count
result.backfilled_count = backfill_missing_count + backfill_mismatch_count
if result.backfilled_count > 0:
result.status = VerificationStatus.BACKFILLED
self.logger.info(
"%s 校验完成: %s, 源=%d, 目标=%d, 缺失=%d, 不一致=%d, 补齐=%d(缺失=%d, 不一致=%d)",
self.layer_name, table,
result.source_count, result.target_count,
result.missing_count, result.mismatch_count, result.backfilled_count,
result.backfilled_missing_count, result.backfilled_mismatch_count
)
except Exception as e:
result.status = VerificationStatus.ERROR
result.error_message = str(e)
if isinstance(e, VerificationFetchError):
# 连接不可用等致命错误,标记后续应中止
result.details["fatal"] = True
self.logger.exception("%s 校验失败: %s, error=%s", self.layer_name, table, e)
# 回滚事务,避免 PostgreSQL "当前事务被终止" 错误影响后续查询
try:
self.db.conn.rollback()
except Exception:
pass # 忽略回滚错误
result.elapsed_seconds = time.time() - start_time
return result
def verify_and_backfill(
self,
window_start: datetime,
window_end: datetime,
split_unit: str = "month",
tables: Optional[List[str]] = None,
auto_backfill: bool = True,
compare_content: bool = True,
) -> VerificationSummary:
"""
按时间窗口切分执行批量校验
Args:
window_start: 开始时间
window_end: 结束时间
split_unit: 切分单位 ("none", "day", "week", "month")
tables: 指定校验的表None 表示全部
auto_backfill: 是否自动补齐
compare_content: 是否对比内容
Returns:
校验汇总结果
"""
summary = VerificationSummary(
layer=self.layer_name,
window_start=window_start,
window_end=window_end,
)
# 获取要校验的表
all_tables = tables or self.get_tables()
# 切分时间窗口
segments = build_window_segments(window_start, window_end, split_unit)
self.logger.info(
"%s 批量校验开始: 表数=%d, 窗口切分=%d",
self.layer_name, len(all_tables), len(segments)
)
fatal_error = False
for segment in segments:
# 每段开始前检查连接状态,异常时立即终止,避免大量空跑
self._ensure_connection()
self.logger.info(
"%s 处理窗口 [%d/%d]: %s",
self.layer_name, segment.index + 1, segment.total, segment.label
)
for table in all_tables:
result = self.verify_table(
table=table,
window_start=segment.start,
window_end=segment.end,
auto_backfill=auto_backfill,
compare_content=compare_content,
)
summary.add_result(result)
if result.details.get("fatal"):
fatal_error = True
break
# 每段完成后提交
try:
self.db.commit()
except Exception as e:
self.logger.warning("提交失败: %s", e)
if fatal_error:
self.logger.warning("%s 校验中止:连接不可用或发生致命错误", self.layer_name)
break
self.logger.info(summary.format_summary())
return summary
def _ensure_connection(self):
"""确保数据库连接可用,必要时尝试重连。"""
if not hasattr(self.db, "conn"):
raise VerificationFetchError("校验器未绑定有效数据库连接")
if getattr(self.db.conn, "closed", 0):
# 优先使用连接对象的重连能力
if hasattr(self.db, "ensure_open"):
if not self.db.ensure_open():
raise VerificationFetchError("数据库连接已关闭,无法继续校验")
else:
raise VerificationFetchError("数据库连接已关闭,无法继续校验")
def quick_check(
self,
window_start: datetime,
window_end: datetime,
tables: Optional[List[str]] = None,
) -> Dict[str, dict]:
"""
快速检查(仅对比数量,不对比内容)
Args:
window_start: 开始时间
window_end: 结束时间
tables: 指定表None 表示全部
Returns:
{表名: {source_count, target_count, diff}}
"""
all_tables = tables or self.get_tables()
results = {}
for table in all_tables:
source_keys = self.fetch_source_keys(table, window_start, window_end)
target_keys = self.fetch_target_keys(table, window_start, window_end)
results[table] = {
"source_count": len(source_keys),
"target_count": len(target_keys),
"diff": len(source_keys) - len(target_keys),
}
return results
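# 用法示意(草图):以下函数仅演示如何驱动某个 BaseVerifier 具体子类;
# verifier 参数假定为调用方已构造好的子类实例(如 DWD/DWS 校验器),并非本模块提供的正式入口。
def _example_run_verification(verifier: BaseVerifier) -> bool:
    """按月切分窗口执行批量校验并自动补齐,返回是否全部一致(示意代码)。"""
    from datetime import datetime
    summary = verifier.verify_and_backfill(
        window_start=datetime(2026, 1, 1),
        window_end=datetime(2026, 2, 1),
        split_unit="month",        # 可选 "none" / "day" / "week" / "month"
        auto_backfill=True,
        compare_content=True,      # False 时仅对比主键数量
    )
    return summary.is_all_consistent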

File diff suppressed because it is too large

View File

@@ -0,0 +1,455 @@
# -*- coding: utf-8 -*-
"""DWS 汇总层批量校验器
校验逻辑:对比 DWD 聚合数据与 DWS 表数据
- 按日期/门店聚合对比
- 对比数值一致性
- 批量重算 UPSERT 补齐
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_verifier import BaseVerifier, VerificationFetchError
class DwsVerifier(BaseVerifier):
"""DWS 汇总层校验器"""
def __init__(
self,
db_connection: Any,
logger: Optional[logging.Logger] = None,
):
"""
初始化 DWS 校验器
Args:
db_connection: 数据库连接
logger: 日志器
"""
super().__init__(db_connection, logger)
self._table_config = self._load_table_config()
@property
def layer_name(self) -> str:
return "DWS"
def _load_table_config(self) -> Dict[str, dict]:
"""加载 DWS 汇总表配置"""
# DWS 汇总表通常有以下结构:
# - 主键site_id, stat_date 或类似组合
# - 数值列:各种统计值
# - 源表:对应的 DWD 事实表
return {
# 财务日度汇总表 - 包含结算、台费、商品、助教等汇总数据
# 注意:实际 DWS 表使用 gross_amount, table_fee_amount, goods_amount 等列
"dws_finance_daily_summary": {
"pk_columns": ["site_id", "stat_date"],
"time_column": "stat_date",
"source_table": "billiards_dwd.dwd_settlement_head",
"source_time_column": "pay_time",
"agg_sql": """
SELECT
site_id,
tenant_id,
DATE(pay_time) as stat_date,
COALESCE(SUM(pay_amount), 0) as cash_pay_amount,
COALESCE(SUM(table_charge_money), 0) as table_fee_amount,
COALESCE(SUM(goods_money), 0) as goods_amount,
COALESCE(SUM(table_charge_money), 0) + COALESCE(SUM(goods_money), 0) + COALESCE(SUM(assistant_pd_money), 0) + COALESCE(SUM(assistant_cx_money), 0) as gross_amount
FROM billiards_dwd.dwd_settlement_head
WHERE pay_time >= %s AND pay_time < %s
GROUP BY site_id, tenant_id, DATE(pay_time)
""",
"compare_columns": ["cash_pay_amount", "table_fee_amount", "goods_amount", "gross_amount"],
},
# 助教日度明细表 - 按助教+日期汇总服务次数、时长、金额
# 注意DWD 表中使用 site_assistant_idDWS 表中使用 assistant_id
"dws_assistant_daily_detail": {
"pk_columns": ["site_id", "assistant_id", "stat_date"],
"time_column": "stat_date",
"source_table": "billiards_dwd.dwd_assistant_service_log",
"source_time_column": "start_use_time",
"agg_sql": """
SELECT
site_id,
tenant_id,
site_assistant_id as assistant_id,
DATE(start_use_time) as stat_date,
COUNT(*) as total_service_count,
COALESCE(SUM(income_seconds), 0) as total_seconds,
COALESCE(SUM(ledger_amount), 0) as total_ledger_amount
FROM billiards_dwd.dwd_assistant_service_log
WHERE start_use_time >= %s AND start_use_time < %s
AND is_delete = 0
GROUP BY site_id, tenant_id, site_assistant_id, DATE(start_use_time)
""",
"compare_columns": ["total_service_count", "total_seconds", "total_ledger_amount"],
},
# 会员来店明细表 - 按会员+订单记录每次来店消费
# 注意DWD 表主键是 order_settle_id不是 id
"dws_member_visit_detail": {
"pk_columns": ["site_id", "member_id", "order_settle_id"],
"time_column": "visit_date",
"source_table": "billiards_dwd.dwd_settlement_head",
"source_time_column": "pay_time",
"agg_sql": """
SELECT
site_id,
tenant_id,
member_id,
order_settle_id,
DATE(pay_time) as visit_date,
COALESCE(table_charge_money, 0) as table_fee,
COALESCE(goods_money, 0) as goods_amount,
COALESCE(pay_amount, 0) as actual_pay
FROM billiards_dwd.dwd_settlement_head
WHERE pay_time >= %s AND pay_time < %s
AND member_id > 0
""",
"compare_columns": ["table_fee", "goods_amount", "actual_pay"],
},
}
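    # 新增 DWS 校验表时,可在上面 _load_table_config 返回的字典中追加类似条目;
    # 以下仅为示意(表名、源表与列名均为假设值,不对应真实 schema
    # "dws_example_daily": {
    #     "pk_columns": ["site_id", "stat_date"],
    #     "time_column": "stat_date",
    #     "source_table": "billiards_dwd.dwd_example_fact",
    #     "source_time_column": "event_time",
    #     "agg_sql": "SELECT site_id, tenant_id, DATE(event_time) AS stat_date, COUNT(*) AS cnt"
    #                " FROM billiards_dwd.dwd_example_fact"
    #                " WHERE event_time >= %s AND event_time < %s"
    #                " GROUP BY site_id, tenant_id, DATE(event_time)",
    #     "compare_columns": ["cnt"],
    # },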
def get_tables(self) -> List[str]:
"""获取需要校验的 DWS 汇总表列表"""
if self._table_config:
return list(self._table_config.keys())
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'billiards_dws'
AND table_type = 'BASE TABLE'
AND table_name LIKE 'dws_%'
AND table_name NOT LIKE 'cfg_%'
ORDER BY table_name
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql)
return [row[0] for row in cur.fetchall()]
except Exception as e:
self.logger.warning("获取 DWS 表列表失败: %s", e)
try:
self.db.conn.rollback()
except Exception:
pass
return []
def get_primary_keys(self, table: str) -> List[str]:
"""获取表的主键列"""
if table in self._table_config:
return self._table_config[table].get("pk_columns", ["site_id", "stat_date"])
return ["site_id", "stat_date"]
def get_time_column(self, table: str) -> Optional[str]:
"""获取表的时间列"""
if table in self._table_config:
return self._table_config[table].get("time_column", "stat_date")
return "stat_date"
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 DWD 聚合获取源数据主键集合"""
config = self._table_config.get(table, {})
agg_sql = config.get("agg_sql")
if not agg_sql:
return set()
pk_cols = self.get_primary_keys(table)
try:
with self.db.conn.cursor() as cur:
cur.execute(agg_sql, (window_start, window_end))
columns = [desc[0] for desc in cur.description]
pk_indices = [columns.index(c) for c in pk_cols if c in columns]
return {tuple(row[i] for i in pk_indices) for row in cur.fetchall()}
except Exception as e:
self.logger.warning("获取 DWD 聚合主键失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取 DWD 聚合主键失败: {table}") from e
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 DWS 表获取目标数据主键集合"""
pk_cols = self.get_primary_keys(table)
time_col = self.get_time_column(table)
pk_select = ", ".join(pk_cols)
sql = f"""
SELECT {pk_select}
FROM billiards_dws.{table}
WHERE {time_col} >= %s AND {time_col} < %s
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql, (window_start.date(), window_end.date()))
return {tuple(row) for row in cur.fetchall()}
except Exception as e:
self.logger.warning("获取 DWS 主键失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取 DWS 主键失败: {table}") from e
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""从 DWD 聚合获取数据,返回主键->聚合值字符串"""
config = self._table_config.get(table, {})
agg_sql = config.get("agg_sql")
compare_cols = config.get("compare_columns", [])
if not agg_sql:
return {}
pk_cols = self.get_primary_keys(table)
result = {}
try:
with self.db.conn.cursor() as cur:
cur.execute(agg_sql, (window_start, window_end))
columns = [desc[0] for desc in cur.description]
pk_indices = [columns.index(c) for c in pk_cols if c in columns]
value_indices = [columns.index(c) for c in compare_cols if c in columns]
for row in cur.fetchall():
pk = tuple(row[i] for i in pk_indices)
values = tuple(row[i] for i in value_indices)
result[pk] = str(values)
except Exception as e:
self.logger.warning("获取 DWD 聚合数据失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取 DWD 聚合数据失败: {table}") from e
return result
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""从 DWS 表获取数据,返回主键->值字符串"""
config = self._table_config.get(table, {})
compare_cols = config.get("compare_columns", [])
pk_cols = self.get_primary_keys(table)
time_col = self.get_time_column(table)
all_cols = pk_cols + compare_cols
col_select = ", ".join(all_cols)
sql = f"""
SELECT {col_select}
FROM billiards_dws.{table}
WHERE {time_col} >= %s AND {time_col} < %s
"""
result = {}
try:
with self.db.conn.cursor() as cur:
cur.execute(sql, (window_start.date(), window_end.date()))
for row in cur.fetchall():
pk = tuple(row[:len(pk_cols)])
values = tuple(row[len(pk_cols):])
result[pk] = str(values)
except Exception as e:
self.logger.warning("获取 DWS 数据失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取 DWS 数据失败: {table}") from e
return result
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""批量补齐缺失数据(重新计算并插入)"""
if not missing_keys:
return 0
self.logger.info(
"DWS 补齐缺失: 表=%s, 数量=%d",
table, len(missing_keys)
)
# 在执行之前确保事务状态干净
try:
self.db.conn.rollback()
except Exception:
pass
# 重新计算汇总数据
return self._recalculate_and_upsert(table, window_start, window_end, missing_keys)
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""批量更新不一致数据(重新计算并更新)"""
if not mismatch_keys:
return 0
self.logger.info(
"DWS 更新不一致: 表=%s, 数量=%d",
table, len(mismatch_keys)
)
# 在执行之前确保事务状态干净
try:
self.db.conn.rollback()
except Exception:
pass
# 重新计算汇总数据
return self._recalculate_and_upsert(table, window_start, window_end, mismatch_keys)
def _recalculate_and_upsert(
self,
table: str,
window_start: datetime,
window_end: datetime,
target_keys: Optional[Set[Tuple]] = None,
) -> int:
"""重新计算汇总数据并 UPSERT"""
config = self._table_config.get(table, {})
agg_sql = config.get("agg_sql")
if not agg_sql:
return 0
pk_cols = self.get_primary_keys(table)
# 执行聚合查询
try:
with self.db.conn.cursor() as cur:
cur.execute(agg_sql, (window_start, window_end))
columns = [desc[0] for desc in cur.description]
records = [dict(zip(columns, row)) for row in cur.fetchall()]
except Exception as e:
self.logger.error("聚合查询失败: %s", e)
try:
self.db.conn.rollback()
except Exception:
pass
return 0
if not records:
return 0
# 如果指定了目标主键,只处理这些记录
if target_keys:
records = [
r for r in records
if tuple(r.get(c) for c in pk_cols) in target_keys
]
if not records:
return 0
# 构建 UPSERT SQL
col_list = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
pk_list = ", ".join(pk_cols)
update_cols = [c for c in columns if c not in pk_cols]
update_set = ", ".join(f"{c} = EXCLUDED.{c}" for c in update_cols)
upsert_sql = f"""
INSERT INTO billiards_dws.{table} ({col_list})
VALUES ({placeholders})
ON CONFLICT ({pk_list}) DO UPDATE SET {update_set}
"""
count = 0
with self.db.conn.cursor() as cur:
for record in records:
values = [record.get(c) for c in columns]
try:
cur.execute(upsert_sql, values)
count += 1
except Exception as e:
self.logger.warning("UPSERT 失败: %s", e)
self.db.commit()
return count
def verify_aggregation(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[str, Any]:
"""
详细校验聚合数据
返回源和目标的详细对比
"""
config = self._table_config.get(table, {})
compare_cols = config.get("compare_columns", [])
source_hashes = self.fetch_source_hashes(table, window_start, window_end)
target_hashes = self.fetch_target_hashes(table, window_start, window_end)
source_keys = set(source_hashes.keys())
target_keys = set(target_hashes.keys())
missing = source_keys - target_keys
extra = target_keys - source_keys
# 对比数值
mismatch_details = []
for key in source_keys & target_keys:
if source_hashes[key] != target_hashes[key]:
mismatch_details.append({
"key": key,
"source": source_hashes[key],
"target": target_hashes[key],
})
return {
"table": table,
"window": f"{window_start.date()} ~ {window_end.date()}",
"source_count": len(source_hashes),
"target_count": len(target_hashes),
"missing_count": len(missing),
"extra_count": len(extra),
"mismatch_count": len(mismatch_details),
"is_consistent": len(missing) == 0 and len(mismatch_details) == 0,
"missing_keys": list(missing)[:10], # 只返回前10个
"mismatch_details": mismatch_details[:10],
}
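# 用法示意(草图):演示如何用 verify_aggregation 对财务日度汇总做明细对比;
# db 参数假定为带 .conn psycopg2 连接的数据库封装对象,由调用方提供。
def _example_check_finance_daily(db) -> bool:
    verifier = DwsVerifier(db)
    report = verifier.verify_aggregation(
        "dws_finance_daily_summary",
        window_start=datetime(2026, 1, 1),
        window_end=datetime(2026, 2, 1),
    )
    for item in report["mismatch_details"]:
        # 打印前 10 条不一致明细(主键、DWD 聚合值、DWS 现值)
        print(item["key"], item["source"], "!=", item["target"])
    return report["is_consistent"]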

View File

@@ -0,0 +1,343 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 移除 recall/intimacy 表校验配置
"""INDEX 层批量校验器。"""
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_verifier import BaseVerifier, VerificationFetchError
class IndexVerifier(BaseVerifier):
"""INDEX 层校验器(覆盖率校验 + 重算补齐)。"""
def __init__(
self,
db_connection: Any,
logger: Optional[logging.Logger] = None,
lookback_days: int = 60,
config: Any = None,
):
super().__init__(db_connection, logger)
self.lookback_days = lookback_days
self.config = config
self._table_config = self._load_table_config()
@property
def layer_name(self) -> str:
return "INDEX"
def _load_table_config(self) -> Dict[str, dict]:
"""加载 INDEX 表配置。"""
return {
"v_member_recall_priority": {
"pk_columns": ["site_id", "member_id"],
"time_column": "calc_time",
"entity_sql": """
WITH params AS (
SELECT %s::timestamp AS start_time, %s::timestamp AS end_time
),
visit_members AS (
SELECT DISTINCT s.site_id, s.member_id
FROM billiards_dwd.dwd_settlement_head s
CROSS JOIN params p
WHERE s.pay_time >= p.start_time
AND s.pay_time < p.end_time
AND s.member_id > 0
AND (
s.settle_type = 1
OR (
s.settle_type = 3
AND EXISTS (
SELECT 1
FROM billiards_dwd.dwd_assistant_service_log asl
JOIN billiards_dws.cfg_skill_type st
ON asl.skill_id = st.skill_id
AND st.course_type_code = 'BONUS'
AND st.is_active = TRUE
WHERE asl.order_settle_id = s.order_settle_id
AND asl.site_id = s.site_id
AND asl.tenant_member_id = s.member_id
AND asl.is_delete = 0
)
)
)
),
recharge_members AS (
SELECT DISTINCT r.site_id, r.member_id
FROM billiards_dwd.dwd_recharge_order r
CROSS JOIN params p
WHERE r.pay_time >= p.start_time
AND r.pay_time < p.end_time
AND r.member_id > 0
AND r.settle_type = 5
)
SELECT site_id, member_id FROM visit_members
UNION
SELECT site_id, member_id FROM recharge_members
""",
# 该视图由 WBI + NCI 共同产出,缺失时需同时触发两类重算
"task_codes": ["DWS_WINBACK_INDEX", "DWS_NEWCONV_INDEX"],
"description": "客户召回/转化优先级视图",
},
"dws_member_assistant_relation_index": {
"pk_columns": ["site_id", "member_id", "assistant_id"],
"time_column": "calc_time",
"entity_sql": """
WITH params AS (
SELECT %s::timestamp AS start_time, %s::timestamp AS end_time
),
service_pairs AS (
SELECT DISTINCT
s.site_id,
s.tenant_member_id AS member_id,
d.assistant_id
FROM billiards_dwd.dwd_assistant_service_log s
JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id
AND d.scd2_is_current = 1
AND COALESCE(d.is_delete, 0) = 0
CROSS JOIN params p
WHERE s.last_use_time >= p.start_time
AND s.last_use_time < p.end_time
AND s.tenant_member_id > 0
AND s.user_id > 0
AND s.is_delete = 0
),
manual_pairs AS (
SELECT DISTINCT
m.site_id,
m.member_id,
m.assistant_id
FROM billiards_dws.dws_ml_manual_order_alloc m
CROSS JOIN params p
WHERE m.pay_time >= p.start_time
AND m.pay_time < p.end_time
AND m.member_id > 0
AND m.assistant_id > 0
)
SELECT site_id, member_id, assistant_id FROM service_pairs
UNION
SELECT site_id, member_id, assistant_id FROM manual_pairs
""",
"task_code": "DWS_RELATION_INDEX",
"description": "客户-助教关系指数",
},
}
def get_tables(self) -> List[str]:
return list(self._table_config.keys())
def get_primary_keys(self, table: str) -> List[str]:
if table in self._table_config:
return self._table_config[table].get("pk_columns", [])
self.logger.warning("%s 未在 INDEX 校验配置中定义,跳过", table)
return []
def get_time_column(self, table: str) -> Optional[str]:
if table in self._table_config:
return self._table_config[table].get("time_column", "calc_time")
return "calc_time"
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
config = self._table_config.get(table, {})
entity_sql = config.get("entity_sql")
if not entity_sql:
return set()
actual_start = window_end - timedelta(days=self.lookback_days)
try:
with self.db.conn.cursor() as cur:
cur.execute(entity_sql, (actual_start, window_end))
return {tuple(row) for row in cur.fetchall()}
except Exception as exc:
self.logger.warning("获取源实体失败: table=%s error=%s", table, exc)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取源实体失败: {table}") from exc
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
pk_cols = self.get_primary_keys(table)
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过目标读取", table)
return set()
pk_select = ", ".join(pk_cols)
sql = f"""
SELECT DISTINCT {pk_select}
FROM billiards_dws.{table}
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql)
return {tuple(row) for row in cur.fetchall()}
except Exception as exc:
self.logger.warning("获取目标实体失败: table=%s error=%s", table, exc)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取目标实体失败: {table}") from exc
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
keys = self.fetch_source_keys(table, window_start, window_end)
return {k: "1" for k in keys}
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
keys = self.fetch_target_keys(table, window_start, window_end)
return {k: "1" for k in keys}
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
if not missing_keys:
return 0
config = self._table_config.get(table, {})
task_codes = config.get("task_codes")
if not task_codes:
task_code = config.get("task_code")
task_codes = [task_code] if task_code else []
if not task_codes:
self.logger.warning("未找到补齐任务配置: table=%s", table)
return 0
self.logger.info(
"INDEX 补齐: table=%s missing=%d task_codes=%s",
table,
len(missing_keys),
",".join(task_codes),
)
try:
self.db.conn.rollback()
except Exception:
pass
try:
task_config = self.config
if task_config is None:
from config.settings import AppConfig
task_config = AppConfig.load()
inserted_total = 0
for task_code in task_codes:
if task_code == "DWS_WINBACK_INDEX":
from tasks.dws.index.winback_index_task import WinbackIndexTask
task = WinbackIndexTask(task_config, self.db, None, self.logger)
elif task_code == "DWS_NEWCONV_INDEX":
from tasks.dws.index.newconv_index_task import NewconvIndexTask
task = NewconvIndexTask(task_config, self.db, None, self.logger)
elif task_code == "DWS_RELATION_INDEX":
from tasks.dws.index.relation_index_task import RelationIndexTask
task = RelationIndexTask(task_config, self.db, None, self.logger)
else:
self.logger.warning("未知 INDEX 任务代码,跳过: %s", task_code)
continue
self.logger.info("执行 INDEX 补齐任务: %s", task_code)
result = task.execute(None)
inserted_total += result.get("records_inserted", 0) + result.get("records_updated", 0)
return inserted_total
except Exception as exc:
self.logger.error("INDEX 补齐失败: %s", exc)
try:
self.db.conn.rollback()
except Exception:
pass
return 0
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
return 0
def verify_coverage(
self,
table: str,
window_end: Optional[datetime] = None,
) -> Dict[str, Any]:
if window_end is None:
window_end = datetime.now()
window_start = window_end - timedelta(days=self.lookback_days)
config = self._table_config.get(table, {})
description = config.get("description", table)
source_keys = self.fetch_source_keys(table, window_start, window_end)
target_keys = self.fetch_target_keys(table, window_start, window_end)
missing = source_keys - target_keys
extra = target_keys - source_keys
coverage_rate = len(target_keys & source_keys) / len(source_keys) * 100 if source_keys else 100.0
return {
"table": table,
"description": description,
"lookback_days": self.lookback_days,
"window": f"{window_start.date()} ~ {window_end.date()}",
"source_entities": len(source_keys),
"indexed_entities": len(target_keys),
"missing_count": len(missing),
"extra_count": len(extra),
"coverage_rate": round(coverage_rate, 2),
"is_complete": len(missing) == 0,
"missing_sample": list(missing)[:10],
}
def verify_all_indices(
self,
window_end: Optional[datetime] = None,
) -> Dict[str, dict]:
results = {}
for table in self.get_tables():
results[table] = self.verify_coverage(table, window_end)
return results
def get_missing_entities(
self,
table: str,
limit: int = 100,
window_end: Optional[datetime] = None,
) -> List[Tuple]:
if window_end is None:
window_end = datetime.now()
window_start = window_end - timedelta(days=self.lookback_days)
source_keys = self.fetch_source_keys(table, window_start, window_end)
target_keys = self.fetch_target_keys(table, window_start, window_end)
missing = source_keys - target_keys
return list(missing)[:limit]
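# 用法示意(草图):检查全部 INDEX 表的实体覆盖率;
# db 参数假定为带 .conn psycopg2 连接的数据库封装对象,由调用方提供。
def _example_index_coverage(db) -> None:
    verifier = IndexVerifier(db, lookback_days=60)
    for table, report in verifier.verify_all_indices().items():
        print(f"{table}: 覆盖率={report['coverage_rate']}% 缺失={report['missing_count']}")
        if not report["is_complete"]:
            print("  缺失样本:", report["missing_sample"])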

View File

@@ -0,0 +1,283 @@
# -*- coding: utf-8 -*-
"""校验结果数据模型"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import List, Optional, Dict, Any
class VerificationStatus(Enum):
"""校验状态"""
OK = "OK" # 数据一致
MISSING = "MISSING" # 有缺失数据
MISMATCH = "MISMATCH" # 有不一致数据
BACKFILLED = "BACKFILLED" # 已补齐
ERROR = "ERROR" # 校验出错
@dataclass
class VerificationResult:
"""单表校验结果"""
layer: str # 数据层: "ODS" / "DWD" / "DWS" / "INDEX"
table: str # 表名
window_start: datetime # 校验窗口开始
window_end: datetime # 校验窗口结束
source_count: int = 0 # 源数据量
target_count: int = 0 # 目标数据量
missing_count: int = 0 # 缺失记录数
mismatch_count: int = 0 # 不一致记录数
backfilled_count: int = 0 # 已补齐记录数(缺失 + 不一致)
backfilled_missing_count: int = 0 # 缺失补齐数
backfilled_mismatch_count: int = 0 # 不一致补齐数
status: VerificationStatus = VerificationStatus.OK
elapsed_seconds: float = 0.0 # 耗时(秒)
error_message: Optional[str] = None # 错误信息
details: Dict[str, Any] = field(default_factory=dict) # 额外详情
@property
def is_consistent(self) -> bool:
"""数据是否一致"""
return self.status == VerificationStatus.OK
@property
def needs_backfill(self) -> bool:
"""是否需要补齐"""
return self.missing_count > 0 or self.mismatch_count > 0
def to_dict(self) -> dict:
"""转换为字典"""
return {
"layer": self.layer,
"table": self.table,
"window_start": self.window_start.isoformat() if self.window_start else None,
"window_end": self.window_end.isoformat() if self.window_end else None,
"source_count": self.source_count,
"target_count": self.target_count,
"missing_count": self.missing_count,
"mismatch_count": self.mismatch_count,
"backfilled_count": self.backfilled_count,
"backfilled_missing_count": self.backfilled_missing_count,
"backfilled_mismatch_count": self.backfilled_mismatch_count,
"status": self.status.value,
"elapsed_seconds": self.elapsed_seconds,
"error_message": self.error_message,
"details": self.details,
}
def format_summary(self) -> str:
"""格式化摘要"""
lines = [
f"表: {self.table}",
f"层: {self.layer}",
f"窗口: {self.window_start.strftime('%Y-%m-%d %H:%M')} ~ {self.window_end.strftime('%Y-%m-%d %H:%M')}",
f"源数据量: {self.source_count:,}",
f"目标数据量: {self.target_count:,}",
f"缺失: {self.missing_count:,}",
f"不一致: {self.mismatch_count:,}",
f"缺失补齐: {self.backfilled_missing_count:,}",
f"不一致补齐: {self.backfilled_mismatch_count:,}",
f"已补齐: {self.backfilled_count:,}",
f"状态: {self.status.value}",
f"耗时: {self.elapsed_seconds:.2f}s",
]
if self.error_message:
lines.append(f"错误: {self.error_message}")
return "\n".join(lines)
@dataclass
class VerificationSummary:
"""校验汇总结果"""
layer: str # 数据层
window_start: datetime # 校验窗口开始
window_end: datetime # 校验窗口结束
total_tables: int = 0 # 总表数
consistent_tables: int = 0 # 一致的表数
inconsistent_tables: int = 0 # 不一致的表数
total_source_count: int = 0 # 总源数据量
total_target_count: int = 0 # 总目标数据量
total_missing: int = 0 # 总缺失数
total_mismatch: int = 0 # 总不一致数
total_backfilled: int = 0 # 总补齐数
total_backfilled_missing: int = 0 # 总缺失补齐数
total_backfilled_mismatch: int = 0 # 总不一致补齐数
error_tables: int = 0 # 发生错误的表数
elapsed_seconds: float = 0.0 # 总耗时
results: List[VerificationResult] = field(default_factory=list) # 各表结果
status: VerificationStatus = VerificationStatus.OK
def add_result(self, result: VerificationResult):
"""添加单表结果"""
self.results.append(result)
self.total_tables += 1
self.total_source_count += result.source_count
self.total_target_count += result.target_count
self.total_missing += result.missing_count
self.total_mismatch += result.mismatch_count
self.total_backfilled += result.backfilled_count
self.total_backfilled_missing += result.backfilled_missing_count
self.total_backfilled_mismatch += result.backfilled_mismatch_count
self.elapsed_seconds += result.elapsed_seconds
if result.status == VerificationStatus.ERROR:
self.error_tables += 1
self.inconsistent_tables += 1
# 错误优先级最高,直接覆盖汇总状态
self.status = VerificationStatus.ERROR
elif result.is_consistent:
self.consistent_tables += 1
else:
self.inconsistent_tables += 1
if self.status == VerificationStatus.OK:
self.status = result.status
@property
def is_all_consistent(self) -> bool:
"""是否全部一致"""
return self.inconsistent_tables == 0
def to_dict(self) -> dict:
"""转换为字典"""
return {
"layer": self.layer,
"window_start": self.window_start.isoformat() if self.window_start else None,
"window_end": self.window_end.isoformat() if self.window_end else None,
"total_tables": self.total_tables,
"consistent_tables": self.consistent_tables,
"inconsistent_tables": self.inconsistent_tables,
"total_source_count": self.total_source_count,
"total_target_count": self.total_target_count,
"total_missing": self.total_missing,
"total_mismatch": self.total_mismatch,
"total_backfilled": self.total_backfilled,
"total_backfilled_missing": self.total_backfilled_missing,
"total_backfilled_mismatch": self.total_backfilled_mismatch,
"error_tables": self.error_tables,
"elapsed_seconds": self.elapsed_seconds,
"status": self.status.value,
"results": [r.to_dict() for r in self.results],
}
def format_summary(self) -> str:
"""格式化汇总摘要"""
lines = [
f"{'=' * 60}",
f"校验汇总 - {self.layer}",
f"{'=' * 60}",
f"窗口: {self.window_start.strftime('%Y-%m-%d %H:%M')} ~ {self.window_end.strftime('%Y-%m-%d %H:%M')}",
f"表数: {self.total_tables} (一致: {self.consistent_tables}, 不一致: {self.inconsistent_tables})",
f"源数据量: {self.total_source_count:,}",
f"目标数据量: {self.total_target_count:,}",
f"总缺失: {self.total_missing:,}",
f"总不一致: {self.total_mismatch:,}",
f"总补齐: {self.total_backfilled:,} (缺失: {self.total_backfilled_missing:,}, 不一致: {self.total_backfilled_mismatch:,})",
f"错误表数: {self.error_tables}",
f"总耗时: {self.elapsed_seconds:.2f}s",
f"状态: {self.status.value}",
f"{'=' * 60}",
]
return "\n".join(lines)
@dataclass
class WindowSegment:
"""时间窗口片段"""
start: datetime
end: datetime
index: int = 0
total: int = 1
@property
def label(self) -> str:
"""片段标签"""
return f"{self.start.strftime('%Y-%m-%d')} ~ {self.end.strftime('%Y-%m-%d')}"
def build_window_segments(
window_start: datetime,
window_end: datetime,
split_unit: str = "month",
) -> List[WindowSegment]:
"""
按指定单位切分时间窗口
Args:
window_start: 开始时间
window_end: 结束时间
split_unit: 切分单位 ("none", "day", "week", "month")
Returns:
时间窗口片段列表
"""
if split_unit == "none" or not split_unit:
return [WindowSegment(start=window_start, end=window_end, index=0, total=1)]
segments = []
current = window_start
while current < window_end:
if split_unit == "day":
# 按天切分
next_boundary = current.replace(hour=0, minute=0, second=0, microsecond=0)
next_boundary = next_boundary + timedelta(days=1)
elif split_unit == "week":
# 按周切分(周一为起点)
days_until_monday = (7 - current.weekday()) % 7
if days_until_monday == 0:
days_until_monday = 7
next_boundary = current.replace(hour=0, minute=0, second=0, microsecond=0)
next_boundary = next_boundary + timedelta(days=days_until_monday)
elif split_unit == "month":
# 按月切分
if current.month == 12:
next_boundary = current.replace(year=current.year + 1, month=1, day=1,
hour=0, minute=0, second=0, microsecond=0)
else:
next_boundary = current.replace(month=current.month + 1, day=1,
hour=0, minute=0, second=0, microsecond=0)
else:
# 默认不切分
next_boundary = window_end
segment_end = min(next_boundary, window_end)
segments.append(WindowSegment(start=current, end=segment_end))
current = segment_end
# 更新索引
total = len(segments)
for i, seg in enumerate(segments):
seg.index = i
seg.total = total
return segments
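# 用法示意:将 2026-01-15 ~ 2026-03-10 按月切分,可得到三段窗口(示例时间为假设值)。
def _example_split_by_month() -> List[WindowSegment]:
    segments = build_window_segments(
        datetime(2026, 1, 15),
        datetime(2026, 3, 10),
        split_unit="month",
    )
    # 依次为 01-15~02-0102-01~03-0103-01~03-10index/total 已在函数内回填
    for seg in segments:
        print(f"[{seg.index + 1}/{seg.total}] {seg.label}")
    return segments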
def filter_verify_tables(layer: str, tables: list[str] | None) -> list[str] | None:
"""按层过滤校验表名,避免非目标层全量校验。
Args:
layer: 数据层名称("ODS" / "DWD" / "DWS" / "INDEX"
tables: 待过滤的表名列表,为 None 或空时直接返回 None
Returns:
过滤后的表名列表,或 None
"""
if not tables:
return None
layer_upper = layer.upper()
normalized = [t.strip().lower() for t in tables if t and t.strip()]
if layer_upper == "DWD":
return [t for t in normalized if t.startswith(("dwd_", "dim_", "fact_"))]
if layer_upper == "DWS":
return [t for t in normalized if t.startswith("dws_")]
if layer_upper == "INDEX":
return [t for t in normalized if t.startswith("v_") or t.endswith("_index")]
if layer_upper == "ODS":
return [t for t in normalized if t.startswith("ods_")]
return normalized
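# 用法示意:同一份表名列表按不同层过滤(表名仅为举例)。
def _example_filter_tables() -> None:
    tables = ["dwd_settlement_head", "dws_finance_daily_summary", "v_member_recall_priority"]
    print(filter_verify_tables("DWD", tables))    # ['dwd_settlement_head']
    print(filter_verify_tables("DWS", tables))    # ['dws_finance_daily_summary']
    print(filter_verify_tables("INDEX", tables))  # ['v_member_recall_priority']
    print(filter_verify_tables("ODS", None))      # None 表示不限制表名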

View File

@@ -0,0 +1,871 @@
# -*- coding: utf-8 -*-
"""ODS 层批量校验器
校验逻辑:对比 API 源数据与 ODS 表数据
- 主键 + content_hash 对比
- 批量 UPSERT 补齐缺失/不一致数据
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from psycopg2.extras import execute_values
from api.local_json_client import LocalJsonClient
from .base_verifier import BaseVerifier, VerificationFetchError
class OdsVerifier(BaseVerifier):
"""ODS 层校验器"""
def __init__(
self,
db_connection: Any,
api_client: Any = None,
logger: Optional[logging.Logger] = None,
fetch_from_api: bool = False,
local_dump_dirs: Optional[Dict[str, str]] = None,
use_local_json: bool = False,
):
"""
初始化 ODS 校验器
Args:
db_connection: 数据库连接
api_client: API 客户端(用于重新获取数据)
logger: 日志器
fetch_from_api: 是否从 API 获取源数据进行校验(默认 False仅校验 ODS 内部一致性)
local_dump_dirs: 本地 JSON dump 目录映射task_code -> 目录)
use_local_json: 是否优先使用本地 JSON 作为源数据
"""
super().__init__(db_connection, logger)
self.api_client = api_client
self.fetch_from_api = fetch_from_api
self.local_dump_dirs = local_dump_dirs or {}
self.use_local_json = bool(use_local_json or self.local_dump_dirs)
# 缓存从 API 获取的数据(避免重复调用)
self._api_data_cache: Dict[str, List[dict]] = {}
self._api_key_cache: Dict[str, Set[Tuple]] = {}
self._api_hash_cache: Dict[str, Dict[Tuple, str]] = {}
self._table_column_cache: Dict[Tuple[str, str], bool] = {}
self._table_pk_cache: Dict[str, List[str]] = {}
self._local_json_clients: Dict[str, LocalJsonClient] = {}
# ODS 表配置:{表名: {pk_columns, time_column, api_endpoint}}
self._table_config = self._load_table_config()
@property
def layer_name(self) -> str:
return "ODS"
def _load_table_config(self) -> Dict[str, dict]:
"""加载 ODS 表配置"""
# 从任务定义中动态获取配置
try:
from tasks.ods.ods_tasks import ODS_TASK_SPECS
config = {}
for spec in ODS_TASK_SPECS:
# time_fields 是一个元组 (start_field, end_field),仅用于构造 API 请求的窗口参数
# ODS 表的窗口过滤统一使用 fetched_at 作为时间列
time_column = "fetched_at"
# 使用 table_name 属性(不是 table
table_name = spec.table_name
# 提取不带 schema 前缀的表名作为 key
if "." in table_name:
table_key = table_name.split(".")[-1]
else:
table_key = table_name
# 从 sources 中提取 ODS 表的实际主键列名
# sources 格式如 ("settleList.id", "id"),最后一个简单名称是 ODS 列名
pk_columns = []
for col in spec.pk_columns:
ods_col_name = self._extract_ods_column_name(col)
pk_columns.append(ods_col_name)
# 如果 pk_columns 为空,尝试使用 conflict_columns_override 或跳过校验
# 一些特殊表(如 goods_stock_summary, settlement_ticket_details没有标准主键
if not pk_columns:
# 跳过没有明确主键定义的表
self.logger.debug("%s 没有定义主键列,跳过校验配置", table_key)
continue
config[table_key] = {
"full_table_name": table_name,
"pk_columns": pk_columns,
"time_column": time_column,
"api_endpoint": spec.endpoint,
"task_code": spec.code,
}
return config
except ImportError:
self.logger.warning("无法加载 ODS 任务定义,使用默认配置")
return {}
def _extract_ods_column_name(self, col) -> str:
"""
从 ColumnSpec 中提取 ODS 表的实际列名
ODS 表使用原始 JSON 字段名(小写),而 col.column 是 DWD 层的命名。
sources 中的最后一个简单字段名通常就是 ODS 表的列名。
"""
# 如果 sources 为空,使用 column假设 column 就是 ODS 列名)
if not col.sources:
return col.column
# 遍历 sources找到最简单的字段名(不含点号的)
for source in reversed(col.sources):
if "." not in source:
return source.lower() # ODS 列名通常是小写
# 如果都是复杂路径,取最后一个路径的最后一部分
last_source = col.sources[-1]
if "." in last_source:
return last_source.split(".")[-1].lower()
return last_source.lower()
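    # 列名提取的示意ColumnSpec 字段均为假设值,仅说明映射规则):
    #   sources=("settleList.id", "id"), column="order_settle_id"  ->  返回 "id"
    #   sources=("member.memberId",),    column="member_id"        ->  返回 "memberid"
    #   sources=(),                      column="site_id"          ->  返回 "site_id"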
def get_tables(self) -> List[str]:
"""获取需要校验的 ODS 表列表"""
if self._table_config:
return list(self._table_config.keys())
# 从数据库查询 ODS schema 中的表
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'billiards_ods'
AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql)
return [row[0] for row in cur.fetchall()]
except Exception as e:
self.logger.warning("获取 ODS 表列表失败: %s", e)
try:
self.db.conn.rollback()
except Exception:
pass
return []
def get_primary_keys(self, table: str) -> List[str]:
"""获取表的主键列"""
if table in self._table_config:
return self._table_config[table].get("pk_columns", [])
# 表不在配置中,返回空列表表示无法校验
return []
def get_time_column(self, table: str) -> Optional[str]:
"""获取表的时间列"""
if table in self._table_config:
return self._table_config[table].get("time_column", "fetched_at")
return "fetched_at"
def _get_full_table_name(self, table: str) -> str:
"""获取完整的表名(包含 schema"""
if table in self._table_config:
return self._table_config[table].get("full_table_name", f"billiards_ods.{table}")
# 如果表名已经包含 schema直接返回
if "." in table:
return table
return f"billiards_ods.{table}"
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""
从源获取主键集合
根据 fetch_from_api 参数决定数据来源:
- fetch_from_api=True: 从 API 获取数据(真正的源到目标校验)
- fetch_from_api=False: 从 ODS 表获取ODS 内部一致性校验)
"""
if self._has_external_source():
return self._fetch_keys_from_api(table, window_start, window_end)
else:
# ODS 内部校验:直接从 ODS 表获取
return self._fetch_keys_from_db(table, window_start, window_end)
def _fetch_keys_from_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 API 获取源数据主键集合"""
# 尝试获取缓存的 API 数据
cache_key = f"{table}_{window_start}_{window_end}"
if cache_key in self._api_key_cache:
return self._api_key_cache[cache_key]
if cache_key not in self._api_data_cache:
# 调用 API 获取数据
api_records = self._call_api_for_table(table, window_start, window_end)
self._api_data_cache[cache_key] = api_records
api_records = self._api_data_cache.get(cache_key, [])
if not api_records:
self.logger.debug("%s 从 API 未获取到数据", table)
return set()
# 获取主键列
pk_cols = self.get_primary_keys(table)
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过 API 校验", table)
return set()
# 提取主键
keys = set()
for record in api_records:
pk_values = []
for col in pk_cols:
# API 返回的字段名可能是原始格式(如 id, Id, ID
# 尝试多种格式
value = record.get(col)
if value is None:
value = record.get(col.lower())
if value is None:
value = record.get(col.upper())
pk_values.append(value)
if all(v is not None for v in pk_values):
keys.add(tuple(pk_values))
self.logger.info("%s 从源数据获取 %d 条记录,%d 个唯一主键", table, len(api_records), len(keys))
self._api_key_cache[cache_key] = keys
return keys
def _call_api_for_table(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> List[dict]:
"""调用源数据获取表对应的数据"""
config = self._table_config.get(table, {})
task_code = config.get("task_code")
endpoint = config.get("api_endpoint")
if not task_code or not endpoint:
self.logger.warning(
"%s 没有完整的任务配置task_code=%s, endpoint=%s),无法获取源数据",
table, task_code, endpoint
)
return []
source_client = self._get_source_client(task_code)
if not source_client:
self.logger.warning("%s 未找到可用源API/本地JSON跳过获取源数据", table)
return []
source_label = "本地 JSON" if self._is_using_local_json(task_code) else "API"
self.logger.info(
"%s 获取数据: 表=%s, 端点=%s, 时间窗口=%s ~ %s",
source_label, table, endpoint, window_start, window_end
)
try:
# 获取 ODS 任务规格以获取正确的参数配置
from tasks.ods.ods_tasks import ODS_TASK_SPECS
# 查找对应的任务规格
spec = None
for s in ODS_TASK_SPECS:
if s.code == task_code:
spec = s
break
if not spec:
self.logger.warning("未找到任务规格: %s", task_code)
return []
# 构建 API 参数
params = {}
if spec.include_site_id:
# 从 API 客户端获取 store_id如果可用
store_id = getattr(self.api_client, 'store_id', None)
if store_id:
params["siteId"] = store_id
if spec.requires_window and spec.time_fields:
start_key, end_key = spec.time_fields
# 格式化时间戳
params[start_key] = window_start.strftime("%Y-%m-%d %H:%M:%S")
params[end_key] = window_end.strftime("%Y-%m-%d %H:%M:%S")
# 合并额外参数
params.update(spec.extra_params)
# 调用源数据
all_records = []
for _, page_records, _, _ in source_client.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=200,
data_path=spec.data_path,
list_key=spec.list_key,
):
all_records.extend(page_records)
self.logger.info("源数据返回 %d 条原始记录", len(all_records))
return all_records
except Exception as e:
self.logger.warning("获取源数据失败: 表=%s, error=%s", table, e)
import traceback
self.logger.debug("调用栈: %s", traceback.format_exc())
raise VerificationFetchError(f"获取源数据失败: {table}") from e
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 ODS 表获取目标数据主键集合"""
if self._has_external_source():
cache_key = f"{table}_{window_start}_{window_end}"
api_keys = self._api_key_cache.get(cache_key)
if api_keys is None:
api_keys = self._fetch_keys_from_api(table, window_start, window_end)
return self._fetch_keys_from_db_by_keys(table, api_keys)
return self._fetch_keys_from_db(table, window_start, window_end)
def _fetch_keys_from_db(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从数据库获取主键集合"""
pk_cols = self.get_primary_keys(table)
# 如果没有主键列配置,跳过校验
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过获取主键", table)
return set()
time_col = self.get_time_column(table)
full_table = self._get_full_table_name(table)
pk_select = ", ".join(pk_cols)
sql = f"""
SELECT {pk_select}
FROM {full_table}
WHERE {time_col} >= %s AND {time_col} < %s
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql, (window_start, window_end))
return {tuple(row) for row in cur.fetchall()}
except Exception as e:
self.logger.warning("获取 ODS 主键失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取 ODS 主键失败: {table}") from e
def _fetch_keys_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Set[Tuple]:
"""按主键集合反查 ODS 表是否存在记录(不依赖时间窗口)"""
if not keys:
return set()
pk_cols = self.get_primary_keys(table)
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过按主键反查", table)
return set()
full_table = self._get_full_table_name(table)
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {full_table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: Set[Tuple] = set()
try:
with self.db.conn.cursor() as cur:
for chunk in self._chunked(list(keys), 500):
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
except Exception as e:
self.logger.warning("按主键反查 ODS 失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"按主键反查 ODS 失败: {table}") from e
return existing
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取源数据的主键->content_hash 映射"""
if self._has_external_source():
return self._fetch_hashes_from_api(table, window_start, window_end)
else:
# ODS 表自带 content_hash 列
return self._fetch_hashes_from_db(table, window_start, window_end)
def _fetch_hashes_from_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""从 API 数据计算哈希"""
cache_key = f"{table}_{window_start}_{window_end}"
if cache_key in self._api_hash_cache:
return self._api_hash_cache[cache_key]
api_records = self._api_data_cache.get(cache_key, [])
if not api_records:
# 尝试从 API 获取
api_records = self._call_api_for_table(table, window_start, window_end)
self._api_data_cache[cache_key] = api_records
if not api_records:
return {}
pk_cols = self.get_primary_keys(table)
if not pk_cols:
return {}
result = {}
for record in api_records:
# 提取主键
pk_values = []
for col in pk_cols:
value = record.get(col)
if value is None:
value = record.get(col.lower())
if value is None:
value = record.get(col.upper())
pk_values.append(value)
if all(v is not None for v in pk_values):
pk = tuple(pk_values)
# 计算内容哈希
content_hash = self._compute_hash(record)
result[pk] = content_hash
self._api_hash_cache[cache_key] = result
if cache_key not in self._api_key_cache:
self._api_key_cache[cache_key] = set(result.keys())
return result
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取目标数据的主键->content_hash 映射"""
if self._has_external_source():
cache_key = f"{table}_{window_start}_{window_end}"
api_hashes = self._api_hash_cache.get(cache_key)
if api_hashes is None:
api_hashes = self._fetch_hashes_from_api(table, window_start, window_end)
api_keys = set(api_hashes.keys())
return self._fetch_hashes_from_db_by_keys(table, api_keys)
return self._fetch_hashes_from_db(table, window_start, window_end)
def _fetch_hashes_from_db(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""从数据库获取主键->hash 映射"""
pk_cols = self.get_primary_keys(table)
# 如果没有主键列配置,跳过校验
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过获取哈希", table)
return {}
time_col = self.get_time_column(table)
full_table = self._get_full_table_name(table)
pk_select = ", ".join(pk_cols)
sql = f"""
SELECT {pk_select}, content_hash
FROM {full_table}
WHERE {time_col} >= %s AND {time_col} < %s
"""
result = {}
try:
with self.db.conn.cursor() as cur:
cur.execute(sql, (window_start, window_end))
for row in cur.fetchall():
pk = tuple(row[:-1])
content_hash = row[-1]
result[pk] = content_hash or ""
except Exception as e:
# 查询失败时回滚事务,避免影响后续查询
self.logger.warning("获取 ODS hash 失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"获取 ODS hash 失败: {table}") from e
return result
def _fetch_hashes_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Dict[Tuple, str]:
"""按主键集合反查 ODS 的对比哈希(不依赖时间窗口)"""
if not keys:
return {}
pk_cols = self.get_primary_keys(table)
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过按主键反查 hash", table)
return {}
full_table = self._get_full_table_name(table)
has_payload = self._table_has_column(full_table, "payload")
select_tail = 't."payload"' if has_payload else 't."content_hash"'
select_cols = ", ".join([*(f't."{c}"' for c in pk_cols), select_tail])
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {full_table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
result: Dict[Tuple, str] = {}
try:
with self.db.conn.cursor() as cur:
for chunk in self._chunked(list(keys), 500):
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
pk = tuple(row[:-1])
tail_value = row[-1]
if has_payload:
compare_hash = self._compute_compare_hash_from_payload(tail_value)
result[pk] = compare_hash or ""
else:
result[pk] = tail_value or ""
except Exception as e:
self.logger.warning("按主键反查 ODS hash 失败: %s, error=%s", table, e)
try:
self.db.conn.rollback()
except Exception:
pass
raise VerificationFetchError(f"按主键反查 ODS hash 失败: {table}") from e
return result
@staticmethod
def _chunked(items: List[Tuple], chunk_size: int) -> List[List[Tuple]]:
"""将列表按固定大小分块"""
if chunk_size <= 0:
return [items]
return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量补齐缺失数据
ODS 层补齐需要重新从 API 获取数据
"""
if not self._has_external_source():
self.logger.warning("未配置 API/本地JSON 源,无法补齐 ODS 缺失数据")
return 0
if not missing_keys:
return 0
# 获取表配置
config = self._table_config.get(table, {})
task_code = config.get("task_code")
if not task_code:
self.logger.warning("未找到表 %s 的任务配置,跳过补齐", table)
return 0
self.logger.info(
"ODS 补齐缺失: 表=%s, 数量=%d, 任务=%s",
table, len(missing_keys), task_code
)
# ODS 层的补齐实际上是重新执行 ODS 任务从 API 获取数据
# 但由于 ODS 任务已经在 "校验前先从 API 获取数据" 步骤执行过了,
# 这里补齐失败是预期的(数据已经在 ODS 表中,只是校验窗口可能不一致)
#
# 实际的 ODS 补齐应该在 verify_only 模式下启用 fetch_before_verify 选项,
# 这会先执行 ODS 任务获取 API 数据,然后再校验。
#
# 如果仍然有缺失,说明:
# 1. API 返回的数据时间窗口与校验窗口不完全匹配
# 2. 或者 ODS 任务的时间参数配置问题
self.logger.info(
"ODS 补齐提示: 表=%s%d 条缺失记录,建议使用 '校验前先从 API 获取数据' 选项获取完整数据",
table, len(missing_keys)
)
return 0
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量更新不一致数据
ODS 层更新也需要重新从 API 获取
"""
# 与 backfill_missing 类似,重新获取数据会自动 UPSERT
return self.backfill_missing(table, mismatch_keys, window_start, window_end)
def _has_external_source(self) -> bool:
return bool(self.fetch_from_api and (self.api_client or self.use_local_json))
def _is_using_local_json(self, task_code: str) -> bool:
return bool(self.use_local_json and task_code in self.local_dump_dirs)
def _get_local_json_client(self, task_code: str) -> Optional[LocalJsonClient]:
if task_code in self._local_json_clients:
return self._local_json_clients[task_code]
dump_dir = self.local_dump_dirs.get(task_code)
if not dump_dir:
return None
try:
client = LocalJsonClient(dump_dir)
except Exception as exc: # noqa: BLE001
self.logger.warning(
"本地 JSON 目录不可用: task=%s, dir=%s, error=%s",
task_code, dump_dir, exc,
)
return None
self._local_json_clients[task_code] = client
return client
def _get_source_client(self, task_code: str):
if self.use_local_json:
return self._get_local_json_client(task_code)
return self.api_client
def verify_against_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
auto_backfill: bool = False,
) -> Dict[str, Any]:
"""
与 API 源数据对比校验
这是更严格的校验,直接调用 API 获取数据进行对比
"""
if not self.api_client:
return {"error": "未配置 API 客户端"}
config = self._table_config.get(table, {})
endpoint = config.get("api_endpoint")
if not endpoint:
return {"error": f"未找到表 {table} 的 API 端点配置"}
self.logger.info("开始与 API 对比校验: %s", table)
# 1. 从 API 获取数据
try:
api_records = self.api_client.fetch_records(
endpoint=endpoint,
start_time=window_start,
end_time=window_end,
)
except Exception as e:
return {"error": f"API 调用失败: {e}"}
# 2. 从 ODS 获取数据
ods_hashes = self.fetch_target_hashes(table, window_start, window_end)
# 3. 计算 API 数据的 hash
pk_cols = self.get_primary_keys(table)
api_hashes = {}
for record in api_records:
pk = tuple(record.get(col) for col in pk_cols)
content_hash = self._compute_hash(record)
api_hashes[pk] = content_hash
# 4. 对比
api_keys = set(api_hashes.keys())
ods_keys = set(ods_hashes.keys())
missing = api_keys - ods_keys
extra = ods_keys - api_keys
mismatch = {
k for k in (api_keys & ods_keys)
if api_hashes[k] != ods_hashes[k]
}
result = {
"table": table,
"api_count": len(api_hashes),
"ods_count": len(ods_hashes),
"missing_count": len(missing),
"extra_count": len(extra),
"mismatch_count": len(mismatch),
"is_consistent": len(missing) == 0 and len(mismatch) == 0,
}
# 5. 自动补齐
if auto_backfill and (missing or mismatch):
# 需要重新获取的主键
keys_to_refetch = missing | mismatch
# 筛选需要重新插入的记录
records_to_upsert = [
r for r in api_records
if tuple(r.get(col) for col in pk_cols) in keys_to_refetch
]
if records_to_upsert:
backfilled = self._batch_upsert(table, records_to_upsert)
result["backfilled_count"] = backfilled
return result
def _table_has_column(self, full_table: str, column: str) -> bool:
"""检查表是否包含指定列(带缓存)"""
cache_key = (full_table, column)
if cache_key in self._table_column_cache:
return self._table_column_cache[cache_key]
schema = "public"
table = full_table
if "." in full_table:
schema, table = full_table.split(".", 1)
sql = """
SELECT 1
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s AND column_name = %s
LIMIT 1
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, table, column))
exists = cur.fetchone() is not None
except Exception:
exists = False
try:
self.db.conn.rollback()
except Exception:
pass
self._table_column_cache[cache_key] = exists
return exists
def _get_db_primary_keys(self, full_table: str) -> List[str]:
"""Read primary key columns from database metadata (ordered)."""
if full_table in self._table_pk_cache:
return self._table_pk_cache[full_table]
schema = "public"
table = full_table
if "." in full_table:
schema, table = full_table.split(".", 1)
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
AND tc.table_name = kcu.table_name
WHERE tc.table_schema = %s
AND tc.table_name = %s
AND tc.constraint_type = 'PRIMARY KEY'
ORDER BY kcu.ordinal_position
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, table))
rows = cur.fetchall()
cols = [r[0] if not isinstance(r, dict) else r.get("column_name") for r in rows]
result = [c for c in cols if c]
except Exception:
result = []
try:
self.db.conn.rollback()
except Exception:
pass
self._table_pk_cache[full_table] = result
return result
def _compute_compare_hash_from_payload(self, payload: Any) -> Optional[str]:
"""使用 ODS 任务的算法计算对比哈希"""
try:
from tasks.ods.ods_tasks import BaseOdsTask
return BaseOdsTask._compute_compare_hash_from_payload(payload)
except Exception:
return None
def _compute_hash(self, record: dict) -> str:
"""计算记录的对比哈希(与 ODS 入库一致,不包含 fetched_at"""
compare_hash = self._compute_compare_hash_from_payload(record)
return compare_hash or ""
def _batch_upsert(self, table: str, records: List[dict]) -> int:
"""Batch backfill in snapshot-safe mode (insert-only on PK conflict)."""
if not records:
return 0
full_table = self._get_full_table_name(table)
db_pk_cols = self._get_db_primary_keys(full_table)
if not db_pk_cols:
self.logger.warning("%s 未找到主键,跳过回填", full_table)
return 0
has_content_hash_col = self._table_has_column(full_table, "content_hash")
# 获取所有列(从第一条记录),并在存在 content_hash 列时补齐该列。
all_cols = list(records[0].keys())
if has_content_hash_col and "content_hash" not in all_cols:
all_cols.append("content_hash")
# Snapshot-safe strategy: never update historical rows; only insert new snapshots.
col_list = ", ".join(all_cols)
placeholders = ", ".join(["%s"] * len(all_cols))
pk_list = ", ".join(db_pk_cols)
sql = f"""
INSERT INTO {full_table} ({col_list})
VALUES ({placeholders})
ON CONFLICT ({pk_list}) DO NOTHING
"""
count = 0
with self.db.conn.cursor() as cur:
for record in records:
row = dict(record)
if has_content_hash_col:
row["content_hash"] = self._compute_hash(record)
values = [row.get(col) for col in all_cols]
try:
cur.execute(sql, values)
affected = int(cur.rowcount or 0)
if affected > 0:
count += affected
except Exception as e:
self.logger.warning("UPSERT 失败: %s, error=%s", record, e)
self.db.commit()
return count