# -*- coding: utf-8 -*-
"""ODS 层逐任务调试脚本。

连接真实 API 和数据库,逐个执行 23 个 ODS 任务(小窗口),
验证返回结果和 ODS 表实际写入行数的一致性。

用法:
    cd apps/etl/connectors/feiqiu
    python -m scripts.debug.debug_ods [--hours 2] [--tasks ODS_MEMBER,ODS_PAYMENT]
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
import time
import traceback
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo

# ── Ensure the project root is on sys.path before project imports. ──
# parents[2] climbs scripts/debug/ -> scripts/ -> feiqiu root (matches the
# "cd apps/etl/connectors/feiqiu" usage in the module docstring).
_FEIQIU_ROOT = Path(__file__).resolve().parents[2]
if str(_FEIQIU_ROOT) not in sys.path:
    sys.path.insert(0, str(_FEIQIU_ROOT))

# Project imports must come after the sys.path bootstrap above.
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from api.client import APIClient
from orchestration.task_registry import default_registry
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_executor import TaskExecutor


@dataclass
class DebugResult:
    """Debug outcome for a single ODS task run."""
    # Pipeline layer; this script only exercises the ODS layer.
    layer: str = "ODS"
    task_code: str = ""
    status: str = ""  # PASS / FAIL / WARN / ERROR
    message: str = ""
    # Raw counters returned by the task (fetched / inserted / updated / ...).
    counts: dict = field(default_factory=dict)
    # Row count observed in the ODS table after the run (None if not queried).
    db_row_count: int | None = None
    # Whether the row-count delta was reconciled against counts.inserted.
    count_match: bool | None = None
    duration_sec: float = 0.0
    # Full traceback text when status is ERROR.
    error_detail: str | None = None
    # Target ODS table name resolved from the task registry ("" if unknown).
    table_name: str = ""
    # Description of an automatic fix, if one was applied (never set here).
    fix_applied: str | None = None


# ── 工具函数 ──────────────────────────────────────────────────


def _setup_logging() -> logging.Logger:
|
||
logger = logging.getLogger("debug_ods")
|
||
logger.setLevel(logging.INFO)
|
||
if not logger.handlers:
|
||
handler = logging.StreamHandler(sys.stdout)
|
||
handler.setFormatter(logging.Formatter(
|
||
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"
|
||
))
|
||
logger.addHandler(handler)
|
||
return logger


def _get_ods_table_name(task_code: str) -> str | None:
    """Resolve the target table name of an ODS task via the TaskRegistry.

    Returns None when the task is unregistered or its class carries no
    usable SPEC with a table_name attribute.
    """
    meta = default_registry.get_metadata(task_code)
    if meta is None:
        return None
    # Every ODS task class is expected to expose a SPEC with table_name;
    # degrade to None instead of raising when that convention is broken.
    spec = getattr(meta.task_class, "SPEC", None)
    if spec is None or not hasattr(spec, "table_name"):
        return None
    return spec.table_name


def _query_table_count(db_conn: DatabaseConnection, table_name: str,
                       window_start: datetime, window_end: datetime) -> int:
    """Count rows of an ODS table inside [window_start, window_end).

    Filters on the fetched_at column when it exists; otherwise falls back
    to a whole-table COUNT (which may exceed the window's true row count).
    """
    # Probe information_schema for a fetched_at column on this table.
    # table_name is expected in "schema.table" form (e.g. "ods.xxx").
    probe_sql = """
    SELECT 1 FROM information_schema.columns
    WHERE table_schema || '.' || table_name = %s
    AND column_name = 'fetched_at'
    LIMIT 1
    """
    has_fetched_at = bool(db_conn.query(probe_sql, (table_name,)))

    if has_fetched_at:
        rows = db_conn.query(
            f"SELECT COUNT(*) AS cnt FROM {table_name} WHERE fetched_at >= %s AND fetched_at < %s",
            (window_start, window_end),
        )
    else:
        rows = db_conn.query(f"SELECT COUNT(*) AS cnt FROM {table_name}")

    return int(rows[0]["cnt"]) if rows else 0


def _build_components(config: AppConfig, logger: logging.Logger):
    """Wire up DB / API / TaskExecutor components, mirroring the CLI main().

    Returns a (db_connection, api_client, db_operations, executor) tuple.
    """
    db_cfg = config["db"]
    api_cfg = config["api"]

    connection = DatabaseConnection(
        dsn=db_cfg["dsn"],
        session=db_cfg.get("session"),
        connect_timeout=db_cfg.get("connect_timeout_sec"),
    )
    client = APIClient(
        base_url=api_cfg["base_url"],
        token=api_cfg["token"],
        timeout=api_cfg.get("timeout_sec", 20),
        retry_max=api_cfg.get("retries", {}).get("max_attempts", 3),
        headers_extra=api_cfg.get("headers_extra"),
    )
    operations = DatabaseOperations(connection)

    executor = TaskExecutor(
        config, operations, client,
        CursorManager(connection), RunTracker(connection),
        default_registry, logger,
    )
    return connection, client, operations, executor


# ── 核心调试逻辑 ──────────────────────────────────────────────


def debug_single_ods_task(
    task_code: str,
    executor: TaskExecutor,
    db_conn: DatabaseConnection,
    config: AppConfig,
    logger: logging.Logger,
    window_start: datetime,
    window_end: datetime,
) -> DebugResult:
    """Run a single ODS task and validate its outcome.

    Flow: snapshot the table's row count, execute the task through the
    executor, sanity-check the returned counters, re-count the table, and
    derive a PASS/WARN/ERROR verdict.
    """
    result = DebugResult(task_code=task_code)
    table_name = _get_ods_table_name(task_code)
    result.table_name = table_name or ""

    store_id = int(config.get("app.store_id"))
    # Unique run id so this debug execution is traceable in the run tracker.
    run_uuid = f"debug-ods-{task_code.lower()}-{int(time.time())}"

    logger.info("━" * 60)
    logger.info("▶ 开始调试: %s (表: %s)", task_code, table_name or "未知")

    # Row count before execution (used below to measure the insert delta).
    pre_count = None
    if table_name:
        try:
            pre_count = _query_table_count(db_conn, table_name, window_start, window_end)
            logger.info(" 执行前表行数 (窗口内): %d", pre_count)
        except Exception as exc:
            # Best effort only — a failed pre-count disables delta checking.
            logger.warning(" 查询执行前行数失败: %s", exc)

    # Execute the task.
    t0 = time.monotonic()
    try:
        task_result = executor.run_single_task(
            task_code=task_code,
            run_uuid=run_uuid,
            store_id=store_id,
            data_source="online",
        )
        result.duration_sec = round(time.monotonic() - t0, 2)
    except Exception as exc:
        result.duration_sec = round(time.monotonic() - t0, 2)
        result.status = "ERROR"
        result.message = f"任务执行异常: {exc}"
        result.error_detail = traceback.format_exc()
        logger.error(" ✗ 执行异常: %s", exc)
        return result

    # Parse the returned result. assumes run_single_task returns a dict with
    # "status" and "counts" keys — TODO confirm against TaskExecutor.
    task_status = (task_result.get("status") or "").upper()
    counts = task_result.get("counts") or {}
    result.counts = counts

    logger.info(" 返回状态: %s", task_status)
    logger.info(" counts: fetched=%s inserted=%s updated=%s skipped=%s errors=%s",
                counts.get("fetched", 0), counts.get("inserted", 0),
                counts.get("updated", 0), counts.get("skipped", 0),
                counts.get("errors", 0))

    # Sanity-check the counters.
    fetched = counts.get("fetched", 0)
    inserted = counts.get("inserted", 0)
    updated = counts.get("updated", 0)
    skipped = counts.get("skipped", 0)
    errors = counts.get("errors", 0)

    # Basic invariant: fetched >= inserted + updated + skipped.
    accounted = inserted + updated + skipped
    if fetched > 0 and accounted > fetched:
        result.status = "WARN"
        result.message = f"counts 异常: accounted({accounted}) > fetched({fetched})"
        logger.warning(" ⚠ %s", result.message)

    # Row count after execution.
    if table_name:
        try:
            post_count = _query_table_count(db_conn, table_name, window_start, window_end)
            result.db_row_count = post_count
            logger.info(" 执行后表行数 (窗口内): %d", post_count)

            # Compare the delta: new rows should roughly equal `inserted`.
            if pre_count is not None:
                actual_delta = post_count - pre_count
                # `inserted` is the number of rows the task claims to have inserted.
                if inserted > 0 and actual_delta == 0:
                    # Conflict handling (DO NOTHING / update) can legitimately
                    # leave zero new rows despite a positive inserted count.
                    logger.info(" ℹ 无新增行(可能是冲突处理: DO NOTHING / update)")
                    result.count_match = True  # mark the comparison as done
                logger.info(" 实际新增行数: %d, counts.inserted: %d", actual_delta, inserted)
        except Exception as exc:
            logger.warning(" 查询执行后行数失败: %s", exc)

    # Final verdict — only when no earlier check already set a status.
    if result.status == "":
        if errors > 0:
            result.status = "WARN"
            result.message = f"执行完成但有 {errors} 个错误"
        elif task_status in ("SUCCESS", "PARTIAL"):
            result.status = "PASS"
            result.message = f"执行成功, fetched={fetched}"
        elif task_status == "SKIP":
            result.status = "WARN"
            result.message = "任务被跳过(未启用或不存在)"
        else:
            result.status = "WARN"
            result.message = f"未知状态: {task_status}"

    icon = {"PASS": "✓", "WARN": "⚠", "ERROR": "✗", "FAIL": "✗"}.get(result.status, "?")
    logger.info(" %s 结果: %s - %s (耗时 %.1fs)", icon, result.status, result.message, result.duration_sec)
    return result


# ── 主流程 ────────────────────────────────────────────────────


def run_ods_debug(
    hours: float = 2.0,
    task_filter: list[str] | None = None,
) -> list[DebugResult]:
    """Run the full ODS-layer debug pass against the real API and database.

    Args:
        hours: lookback window size in hours (default 2).
        task_filter: debug only these task codes; None means every ODS task.

    Returns:
        A DebugResult for every executed task, in execution order.
    """
    logger = _setup_logging()
    logger.info("=" * 60)
    logger.info("ODS 层调试开始")
    logger.info("=" * 60)

    # Load configuration (from .env) and derive the shared debug window.
    config = AppConfig.load()
    tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
    window_end = datetime.now(tz)
    window_start = window_end - timedelta(hours=hours)

    logger.info("门店 ID: %s", config.get("app.store_id"))
    logger.info("数据库: %s", config.get("db.name", ""))
    logger.info("API: %s", config.get("api.base_url", ""))
    logger.info("时间窗口: %s ~ %s (%.1f 小时)", window_start, window_end, hours)

    # Force every task onto the same small window via run.window_override.
    override = config.config.setdefault("run", {}).setdefault("window_override", {})
    override["start"] = window_start
    override["end"] = window_end

    # Build components (DB connection, API client, executor).
    db_conn, api_client, db_ops, executor = _build_components(config, logger)

    try:
        # All ODS-layer tasks, optionally narrowed by task_filter.
        all_ods_codes = sorted(default_registry.get_tasks_by_layer("ODS"))
        if task_filter:
            filter_set = {t.upper() for t in task_filter}
            ods_codes = [c for c in all_ods_codes if c in filter_set]
            unknown = filter_set - set(ods_codes)
            if unknown:
                logger.warning("以下任务不在 ODS 层注册表中,已跳过: %s", unknown)
        else:
            ods_codes = all_ods_codes

        logger.info("待调试 ODS 任务: %d 个", len(ods_codes))
        logger.info("任务列表: %s", ", ".join(ods_codes))
        logger.info("")

        # Execute one task at a time so a single failure cannot abort the pass.
        results: list[DebugResult] = []
        for idx, task_code in enumerate(ods_codes, start=1):
            logger.info("[%d/%d] %s", idx, len(ods_codes), task_code)
            try:
                r = debug_single_ods_task(
                    task_code=task_code,
                    executor=executor,
                    db_conn=db_conn,
                    config=config,
                    logger=logger,
                    window_start=window_start,
                    window_end=window_end,
                )
            except Exception as exc:
                r = DebugResult(
                    task_code=task_code,
                    status="ERROR",
                    message=f"未捕获异常: {exc}",
                    error_detail=traceback.format_exc(),
                )
                logger.error(" ✗ 未捕获异常: %s", exc)
            results.append(r)

            # Keep the connection alive across a long-running pass.
            db_conn.ensure_open()

        # Summary.
        _print_summary(results, logger)

        # Persist the machine-readable result next to this script.
        output_dir = _FEIQIU_ROOT / "scripts" / "debug" / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        ts = datetime.now(tz).strftime("%Y%m%d_%H%M%S")
        output_file = output_dir / f"debug_ods_{ts}.json"
        _save_results(results, output_file)
        logger.info("结果已保存: %s", output_file)
    finally:
        # Always release the DB connection — the previous version only closed
        # it on the success path, leaking it when summary/save raised or the
        # loop was interrupted.
        db_conn.close()

    return results


def _print_summary(results: list[DebugResult], logger: logging.Logger):
|
||
"""打印调试汇总。"""
|
||
logger.info("")
|
||
logger.info("=" * 60)
|
||
logger.info("ODS 层调试汇总")
|
||
logger.info("=" * 60)
|
||
|
||
pass_count = sum(1 for r in results if r.status == "PASS")
|
||
warn_count = sum(1 for r in results if r.status == "WARN")
|
||
error_count = sum(1 for r in results if r.status in ("ERROR", "FAIL"))
|
||
total_duration = sum(r.duration_sec for r in results)
|
||
|
||
logger.info("总计: %d 个任务", len(results))
|
||
logger.info(" ✓ PASS: %d", pass_count)
|
||
logger.info(" ⚠ WARN: %d", warn_count)
|
||
logger.info(" ✗ ERROR: %d", error_count)
|
||
logger.info(" 总耗时: %.1f 秒", total_duration)
|
||
logger.info("")
|
||
|
||
# 列出非 PASS 的任务
|
||
non_pass = [r for r in results if r.status != "PASS"]
|
||
if non_pass:
|
||
logger.info("需关注的任务:")
|
||
for r in non_pass:
|
||
logger.info(" [%s] %s: %s", r.status, r.task_code, r.message)
|
||
else:
|
||
logger.info("所有任务均通过 ✓")


def _save_results(results: list[DebugResult], path: Path):
    """Serialize the results to pretty-printed UTF-8 JSON at *path*."""
    payload = [_sanitize_for_json(asdict(item)) for item in results]
    # default=str is a last-resort catch-all for any value inside `counts`
    # that _sanitize_for_json did not already convert.
    text = json.dumps(payload, ensure_ascii=False, indent=2, default=str)
    path.write_text(text, encoding="utf-8")


def _sanitize_for_json(obj):
|
||
"""递归处理不可序列化的值。"""
|
||
if isinstance(obj, dict):
|
||
return {k: _sanitize_for_json(v) for k, v in obj.items()}
|
||
if isinstance(obj, (list, tuple)):
|
||
return [_sanitize_for_json(v) for v in obj]
|
||
if isinstance(obj, datetime):
|
||
return obj.isoformat()
|
||
return obj


# ── CLI 入口 ──────────────────────────────────────────────────


def parse_args(argv: list[str] | None = None):
    """Parse CLI arguments.

    Args:
        argv: argument list to parse. None (the default) makes argparse read
            sys.argv, preserving the original behavior; passing an explicit
            list makes the parser unit-testable.

    Returns:
        argparse.Namespace with ``hours`` (float) and ``tasks`` (str | None).
    """
    parser = argparse.ArgumentParser(description="ODS 层逐任务调试")
    parser.add_argument("--hours", type=float, default=2.0,
                        help="回溯窗口小时数(默认 2)")
    parser.add_argument("--tasks", type=str, default=None,
                        help="仅调试指定任务,逗号分隔(如 ODS_MEMBER,ODS_PAYMENT)")
    return parser.parse_args(argv)


def main():
    """CLI entry point: run the ODS debug pass and exit non-zero on errors."""
    args = parse_args()
    task_filter = (
        [t.strip().upper() for t in args.tasks.split(",") if t.strip()]
        if args.tasks
        else None
    )

    results = run_ods_debug(hours=args.hours, task_filter=task_filter)

    # Exit code: non-zero when any task ended in ERROR/FAIL.
    failed = any(r.status in ("ERROR", "FAIL") for r in results)
    sys.exit(1 if failed else 0)


if __name__ == "__main__":
    main()
|