# -*- coding: utf-8 -*-
"""Per-task debug script for the ODS layer.

Connects to the real API and database, runs each ODS task (23 at the time of
writing) one at a time over a small time window, and checks that the counts
each task returns agree with the rows actually written to its ODS table.

Usage:
    cd apps/etl/connectors/feiqiu
    python -m scripts.debug.debug_ods [--hours 2] [--tasks ODS_MEMBER,ODS_PAYMENT]
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
import time
import traceback
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo

# ── Make sure the connector root directory is on sys.path ──
# (must happen before the project imports below)
_FEIQIU_ROOT = Path(__file__).resolve().parents[2]
if str(_FEIQIU_ROOT) not in sys.path:
    sys.path.insert(0, str(_FEIQIU_ROOT))

from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from api.client import APIClient
from orchestration.task_registry import default_registry
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_executor import TaskExecutor

# Executor statuses that are reported as a FAIL verdict.
# NOTE(review): confirm these against the statuses TaskExecutor actually
# returns; only SUCCESS / PARTIAL / SKIP are referenced elsewhere in this file.
_FAILED_TASK_STATUSES = frozenset({"FAIL", "FAILED", "ERROR"})


@dataclass
class DebugResult:
    """Debug outcome for a single ODS task."""

    layer: str = "ODS"
    task_code: str = ""
    status: str = ""  # PASS / FAIL / WARN / ERROR
    message: str = ""
    counts: dict = field(default_factory=dict)  # "counts" dict returned by the executor
    db_row_count: int | None = None  # table row count after the run
    # True once the pre/post row-count comparison has been performed;
    # it does NOT assert that the row delta equals counts["inserted"].
    count_match: bool | None = None
    duration_sec: float = 0.0
    error_detail: str | None = None  # traceback text when the task raised
    table_name: str = ""
    fix_applied: str | None = None  # not populated anywhere in this script


# ── Helpers ──────────────────────────────────────────────────


def _setup_logging() -> logging.Logger:
    """Return the "debug_ods" logger, attaching a stdout handler only once."""
    logger = logging.getLogger("debug_ods")
    logger.setLevel(logging.INFO)
    # Guard against adding duplicate handlers when called repeatedly.
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter(
            "%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"
        ))
        logger.addHandler(handler)
    return logger


def _get_ods_table_name(task_code: str) -> str | None:
    """Look up the ODS table name for *task_code* in the task registry.

    Returns:
        ``SPEC.table_name`` of the registered task class, or None when the
        task is not registered or its class has no ``SPEC.table_name``.
    """
    meta = default_registry.get_metadata(task_code)
    if meta is None:
        return None
    # Read SPEC straight off the task class; no instance is created.
    task_cls = meta.task_class
    spec = getattr(task_cls, "SPEC", None)
    if spec and hasattr(spec, "table_name"):
        return spec.table_name
    return None


def _query_table_count(db_conn: DatabaseConnection, table_name: str,
                       window_start: datetime, window_end: datetime) -> int:
    """Count the rows of *table_name* inside the given time window.

    Filters on the ``fetched_at`` column when information_schema reports it
    for the table; otherwise falls back to a whole-table COUNT.

    Args:
        db_conn: Open database connection.
        table_name: Qualified name in ``schema.table`` form (e.g. ``ods.xxx``).
            The column lookup matches on ``table_schema || '.' || table_name``,
            so an unqualified name never matches and triggers the fallback.
        window_start: Inclusive lower bound for ``fetched_at``.
        window_end: Exclusive upper bound for ``fetched_at``.

    Returns:
        The row count, or 0 if the query returned no rows.
    """
    # First check whether the table has a fetched_at column.
    check_sql = """
        SELECT 1
        FROM information_schema.columns
        WHERE table_schema || '.' || table_name = %s
          AND column_name = 'fetched_at'
        LIMIT 1
    """
    rows = db_conn.query(check_sql, (table_name,))
    # table_name comes from the task class SPEC (not user input), which is why
    # it is interpolated directly into the SQL text.
    if rows:
        # NOTE(review): if fetched_at is stamped at insert time, rows written
        # during the run are later than window_end (captured before the run)
        # and will not be counted here — confirm fetched_at semantics.
        count_sql = f"SELECT COUNT(*) AS cnt FROM {table_name} WHERE fetched_at >= %s AND fetched_at < %s"
        result = db_conn.query(count_sql, (window_start, window_end))
    else:
        count_sql = f"SELECT COUNT(*) AS cnt FROM {table_name}"
        result = db_conn.query(count_sql)
    return int(result[0]["cnt"]) if result else 0


def _build_components(config: AppConfig, logger: logging.Logger):
    """Build the DB / API / TaskExecutor components, mirroring the CLI main().

    Returns:
        Tuple ``(db_conn, api_client, db_ops, executor)``.
    """
    db_conn = DatabaseConnection(
        dsn=config["db"]["dsn"],
        session=config["db"].get("session"),
        connect_timeout=config["db"].get("connect_timeout_sec"),
    )
    api_client = APIClient(
        base_url=config["api"]["base_url"],
        token=config["api"]["token"],
        timeout=config["api"].get("timeout_sec", 20),
        retry_max=config["api"].get("retries", {}).get("max_attempts", 3),
        headers_extra=config["api"].get("headers_extra"),
    )
    db_ops = DatabaseOperations(db_conn)
    cursor_mgr = CursorManager(db_conn)
    run_tracker = RunTracker(db_conn)
    executor = TaskExecutor(
        config, db_ops, api_client, cursor_mgr, run_tracker,
        default_registry, logger,
    )
    return db_conn, api_client, db_ops, executor


# ── Core debug logic ──────────────────────────────────────────


def debug_single_ods_task(
    task_code: str,
    executor: TaskExecutor,
    db_conn: DatabaseConnection,
    config: AppConfig,
    logger: logging.Logger,
    window_start: datetime,
    window_end: datetime,
) -> DebugResult:
    """Run one ODS task and validate its result.

    Counts the target table's rows before and after the run, executes the
    task through ``executor.run_single_task``, sanity-checks the returned
    ``counts`` and derives a PASS / WARN / FAIL / ERROR verdict.

    Args:
        task_code: Registry code of the ODS task.
        executor: Executor used to run the task.
        db_conn: Connection used for the row-count queries.
        config: Loaded configuration; ``app.store_id`` is read from it.
        logger: Logger for progress output.
        window_start: Start of the row-count window (inclusive).
        window_end: End of the row-count window (exclusive).

    Returns:
        The populated DebugResult. An exception raised by the task run itself
        is captured as status ERROR instead of being propagated.
    """
    result = DebugResult(task_code=task_code)
    table_name = _get_ods_table_name(task_code)
    result.table_name = table_name or ""
    store_id = int(config.get("app.store_id"))
    # Run id tagged with the task code and the current Unix timestamp.
    run_uuid = f"debug-ods-{task_code.lower()}-{int(time.time())}"

    logger.info("━" * 60)
    logger.info("▶ 开始调试: %s (表: %s)", task_code, table_name or "未知")

    # Row count before the run, used to compute the delta afterwards.
    pre_count = None
    if table_name:
        try:
            pre_count = _query_table_count(db_conn, table_name, window_start, window_end)
            logger.info(" 执行前表行数 (窗口内): %d", pre_count)
        except Exception as exc:
            # Best-effort: a failing count query must not prevent the run.
            logger.warning(" 查询执行前行数失败: %s", exc)

    # Execute the task.
    t0 = time.monotonic()
    try:
        task_result = executor.run_single_task(
            task_code=task_code,
            run_uuid=run_uuid,
            store_id=store_id,
            data_source="online",
        )
    except Exception as exc:
        result.duration_sec = round(time.monotonic() - t0, 2)
        result.status = "ERROR"
        result.message = f"任务执行异常: {exc}"
        result.error_detail = traceback.format_exc()
        logger.error(" ✗ 执行异常: %s", exc)
        return result
    result.duration_sec = round(time.monotonic() - t0, 2)

    # Parse the returned result.
    task_status = (task_result.get("status") or "").upper()
    counts = task_result.get("counts") or {}
    result.counts = counts

    logger.info(" 返回状态: %s", task_status)
    logger.info(" counts: fetched=%s inserted=%s updated=%s skipped=%s errors=%s",
                counts.get("fetched", 0), counts.get("inserted", 0),
                counts.get("updated", 0), counts.get("skipped", 0),
                counts.get("errors", 0))

    fetched = counts.get("fetched", 0)
    inserted = counts.get("inserted", 0)
    updated = counts.get("updated", 0)
    skipped = counts.get("skipped", 0)
    errors = counts.get("errors", 0)

    # Basic sanity check: fetched >= inserted + updated + skipped.
    accounted = inserted + updated + skipped
    if fetched > 0 and accounted > fetched:
        result.status = "WARN"
        result.message = f"counts 异常: accounted({accounted}) > fetched({fetched})"
        logger.warning(" ⚠ %s", result.message)

    # Row count after the run.
    if table_name:
        try:
            post_count = _query_table_count(db_conn, table_name, window_start, window_end)
            result.db_row_count = post_count
            logger.info(" 执行后表行数 (窗口内): %d", post_count)

            # Compare the row delta with counts["inserted"] (informational only).
            if pre_count is not None:
                actual_delta = post_count - pre_count
                if inserted > 0 and actual_delta == 0:
                    # Possibly due to conflict handling (DO NOTHING / update).
                    logger.info(" ℹ 无新增行(可能是冲突处理: DO NOTHING / update)")
                # Marks that the comparison ran; not a delta == inserted check.
                result.count_match = True
                logger.info(" 实际新增行数: %d, counts.inserted: %d",
                            actual_delta, inserted)
        except Exception as exc:
            logger.warning(" 查询执行后行数失败: %s", exc)

    # Final verdict. A failure reported by the executor takes precedence over
    # any earlier counts warning so it is not masked as WARN (exit code 0).
    if task_status in _FAILED_TASK_STATUSES:
        result.status = "FAIL"
        result.message = f"任务执行失败: {task_status}"
    elif result.status == "":
        if errors > 0:
            result.status = "WARN"
            result.message = f"执行完成但有 {errors} 个错误"
        elif task_status in ("SUCCESS", "PARTIAL"):
            result.status = "PASS"
            result.message = f"执行成功, fetched={fetched}"
        elif task_status == "SKIP":
            result.status = "WARN"
            result.message = "任务被跳过(未启用或不存在)"
        else:
            result.status = "WARN"
            result.message = f"未知状态: {task_status}"

    icon = {"PASS": "✓", "WARN": "⚠", "ERROR": "✗", "FAIL": "✗"}.get(result.status, "?")
    logger.info(" %s 结果: %s - %s (耗时 %.1fs)", icon, result.status,
                result.message, result.duration_sec)
    return result


# ── Main flow ─────────────────────────────────────────────────


def run_ods_debug(
    hours: float = 2.0,
    task_filter: list[str] | None = None,
) -> list[DebugResult]:
    """Debug every ODS-layer task (or a filtered subset).

    Args:
        hours: Look-back window in hours (default 2).
        task_filter: Only debug these task codes; None means all ODS tasks.

    Returns:
        One DebugResult per debugged task. Results are also written as JSON to
        ``scripts/debug/output/debug_ods_<timestamp>.json``. The database
        connection is closed even if an error escapes.
    """
    logger = _setup_logging()
    logger.info("=" * 60)
    logger.info("ODS 层调试开始")
    logger.info("=" * 60)

    # Load configuration (from .env).
    config = AppConfig.load()
    tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
    window_end = datetime.now(tz)
    window_start = window_end - timedelta(hours=hours)

    logger.info("门店 ID: %s", config.get("app.store_id"))
    logger.info("数据库: %s", config.get("db.name", ""))
    logger.info("API: %s", config.get("api.base_url", ""))
    logger.info("时间窗口: %s ~ %s (%.1f 小时)", window_start, window_end, hours)

    # Use window_override so every task runs over the same small window.
    config.config.setdefault("run", {}).setdefault("window_override", {})
    config.config["run"]["window_override"]["start"] = window_start
    config.config["run"]["window_override"]["end"] = window_end

    db_conn, _api_client, _db_ops, executor = _build_components(config, logger)
    try:
        # Collect all ODS-layer tasks.
        all_ods_codes = sorted(default_registry.get_tasks_by_layer("ODS"))
        if task_filter:
            filter_set = {t.upper() for t in task_filter}
            ods_codes = [c for c in all_ods_codes if c in filter_set]
            unknown = filter_set - set(ods_codes)
            if unknown:
                # Sorted so the warning is stable from run to run.
                logger.warning("以下任务不在 ODS 层注册表中,已跳过: %s",
                               ", ".join(sorted(unknown)))
        else:
            ods_codes = all_ods_codes

        logger.info("待调试 ODS 任务: %d 个", len(ods_codes))
        logger.info("任务列表: %s", ", ".join(ods_codes))
        logger.info("")

        # Run the tasks one by one.
        results: list[DebugResult] = []
        for idx, task_code in enumerate(ods_codes, start=1):
            logger.info("[%d/%d] %s", idx, len(ods_codes), task_code)
            try:
                r = debug_single_ods_task(
                    task_code=task_code,
                    executor=executor,
                    db_conn=db_conn,
                    config=config,
                    logger=logger,
                    window_start=window_start,
                    window_end=window_end,
                )
            except Exception as exc:
                r = DebugResult(
                    task_code=task_code,
                    status="ERROR",
                    message=f"未捕获异常: {exc}",
                    error_detail=traceback.format_exc(),
                )
                logger.error(" ✗ 未捕获异常: %s", exc)
            results.append(r)

            # Make sure the connection is still usable (guards against drops
            # during long runs).
            db_conn.ensure_open()

        # Summary.
        _print_summary(results, logger)

        # Write the JSON results file.
        output_dir = _FEIQIU_ROOT / "scripts" / "debug" / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        ts = datetime.now(tz).strftime("%Y%m%d_%H%M%S")
        output_file = output_dir / f"debug_ods_{ts}.json"
        _save_results(results, output_file)
        logger.info("结果已保存: %s", output_file)

        return results
    finally:
        # Always release the connection, even when something above raised.
        db_conn.close()


def _print_summary(results: list[DebugResult], logger: logging.Logger):
    """Log per-status totals, total duration and every non-PASS task."""
    logger.info("")
    logger.info("=" * 60)
    logger.info("ODS 层调试汇总")
    logger.info("=" * 60)

    pass_count = sum(1 for r in results if r.status == "PASS")
    warn_count = sum(1 for r in results if r.status == "WARN")
    error_count = sum(1 for r in results if r.status in ("ERROR", "FAIL"))
    total_duration = sum(r.duration_sec for r in results)

    logger.info("总计: %d 个任务", len(results))
    logger.info(" ✓ PASS: %d", pass_count)
    logger.info(" ⚠ WARN: %d", warn_count)
    logger.info(" ✗ ERROR: %d", error_count)
    logger.info(" 总耗时: %.1f 秒", total_duration)
    logger.info("")

    # List every task that did not PASS.
    non_pass = [r for r in results if r.status != "PASS"]
    if non_pass:
        logger.info("需关注的任务:")
        for r in non_pass:
            logger.info(" [%s] %s: %s", r.status, r.task_code, r.message)
    else:
        logger.info("所有任务均通过 ✓")


def _save_results(results: list[DebugResult], path: Path):
    """Serialize *results* as a UTF-8, pretty-printed JSON list to *path*."""
    data = []
    for r in results:
        d = asdict(r)
        # datetime values (e.g. inside counts) are not JSON-serializable as-is.
        data.append(_sanitize_for_json(d))
    path.write_text(json.dumps(data, ensure_ascii=False, indent=2, default=str),
                    encoding="utf-8")


def _sanitize_for_json(obj):
    """Recursively convert datetimes to ISO strings and tuples to lists."""
    if isinstance(obj, dict):
        return {k: _sanitize_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_sanitize_for_json(v) for v in obj]
    if isinstance(obj, datetime):
        return obj.isoformat()
    return obj


# ── CLI entry point ───────────────────────────────────────────


def parse_args():
    """Parse the --hours and --tasks command-line options."""
    parser = argparse.ArgumentParser(description="ODS 层逐任务调试")
    parser.add_argument("--hours", type=float, default=2.0,
                        help="回溯窗口小时数(默认 2)")
    parser.add_argument("--tasks", type=str, default=None,
                        help="仅调试指定任务,逗号分隔(如 ODS_MEMBER,ODS_PAYMENT)")
    return parser.parse_args()


def main():
    """CLI entry point; exits with status 1 if any task ended in ERROR or FAIL."""
    args = parse_args()
    task_filter = None
    if args.tasks:
        task_filter = [t.strip().upper() for t in args.tasks.split(",") if t.strip()]

    results = run_ods_debug(hours=args.hours, task_filter=task_filter)

    # Non-zero exit code when any task ended in ERROR / FAIL.
    has_error = any(r.status in ("ERROR", "FAIL") for r in results)
    sys.exit(1 if has_error else 0)


if __name__ == "__main__":
    main()