# -*- coding: utf-8 -*-
"""ODS 层逐任务调试脚本。
连接真实 API 和数据库,逐个执行 23 个 ODS 任务(小窗口),
验证返回结果和 ODS 表实际写入行数的一致性。
用法:
cd apps/etl/connectors/feiqiu
python -m scripts.debug.debug_ods [--hours 2] [--tasks ODS_MEMBER,ODS_PAYMENT]
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time
import traceback
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
# ── Ensure the project root is on sys.path so package imports resolve ──
_FEIQIU_ROOT = Path(__file__).resolve().parents[2]
if str(_FEIQIU_ROOT) not in sys.path:
    sys.path.insert(0, str(_FEIQIU_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from api.client import APIClient
from orchestration.task_registry import default_registry
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_executor import TaskExecutor
@dataclass
class DebugResult:
    """Debug outcome for a single ODS task run."""
    # Data-layer tag; this script only ever produces "ODS".
    layer: str = "ODS"
    # Registry code of the task (e.g. "ODS_MEMBER").
    task_code: str = ""
    status: str = ""  # PASS / FAIL / WARN / ERROR
    # Human-readable summary of the outcome.
    message: str = ""
    # Raw counts dict returned by the task (fetched/inserted/updated/skipped/errors).
    counts: dict = field(default_factory=dict)
    # Windowed row count of the target ODS table after the run (None if unknown).
    db_row_count: int | None = None
    # Whether the observed DB row delta matched the reported counts; None = not compared.
    count_match: bool | None = None
    # Wall-clock duration of the task execution, in seconds.
    duration_sec: float = 0.0
    # Full traceback text when status is ERROR, else None.
    error_detail: str | None = None
    # Schema-qualified target table name ("" when it could not be resolved).
    table_name: str = ""
    # NOTE(review): never assigned anywhere in this script — presumably reserved
    # for auto-fix tooling; confirm before relying on it.
    fix_applied: str | None = None
# ── 工具函数 ──────────────────────────────────────────────────
def _setup_logging() -> logging.Logger:
logger = logging.getLogger("debug_ods")
logger.setLevel(logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S"
))
logger.addHandler(handler)
return logger
def _get_ods_table_name(task_code: str) -> str | None:
    """Resolve the target table name of an ODS task via the task registry.

    Returns None when the task code is unknown or its task class does not
    expose a SPEC object carrying a table_name attribute.
    """
    meta = default_registry.get_metadata(task_code)
    if meta is None:
        return None
    # ODS task classes carry a class-level SPEC with the table name.
    spec = getattr(meta.task_class, "SPEC", None)
    if not spec:
        return None
    return getattr(spec, "table_name", None)
def _query_table_count(db_conn: DatabaseConnection, table_name: str,
window_start: datetime, window_end: datetime) -> int:
"""查询 ODS 表在指定时间窗口内的行数。
优先用 fetched_at 列过滤;若该列不存在则回退到全表 COUNT。
"""
# 先检查 fetched_at 列是否存在
check_sql = """
SELECT 1 FROM information_schema.columns
WHERE table_schema || '.' || table_name = %s
AND column_name = 'fetched_at'
LIMIT 1
"""
schema_table = table_name # 格式: ods.xxx
rows = db_conn.query(check_sql, (schema_table,))
if rows:
count_sql = f"SELECT COUNT(*) AS cnt FROM {table_name} WHERE fetched_at >= %s AND fetched_at < %s"
result = db_conn.query(count_sql, (window_start, window_end))
else:
count_sql = f"SELECT COUNT(*) AS cnt FROM {table_name}"
result = db_conn.query(count_sql)
return int(result[0]["cnt"]) if result else 0
def _build_components(config: AppConfig, logger: logging.Logger):
    """Wire up DB / API / TaskExecutor components, mirroring the CLI main().

    Returns a (db_conn, api_client, db_ops, executor) tuple.
    """
    db_cfg = config["db"]
    api_cfg = config["api"]
    db_conn = DatabaseConnection(
        dsn=db_cfg["dsn"],
        session=db_cfg.get("session"),
        connect_timeout=db_cfg.get("connect_timeout_sec"),
    )
    api_client = APIClient(
        base_url=api_cfg["base_url"],
        token=api_cfg["token"],
        timeout=api_cfg.get("timeout_sec", 20),
        retry_max=api_cfg.get("retries", {}).get("max_attempts", 3),
        headers_extra=api_cfg.get("headers_extra"),
    )
    db_ops = DatabaseOperations(db_conn)
    executor = TaskExecutor(
        config, db_ops, api_client,
        CursorManager(db_conn), RunTracker(db_conn), default_registry, logger,
    )
    return db_conn, api_client, db_ops, executor
# ── 核心调试逻辑 ──────────────────────────────────────────────
def debug_single_ods_task(
    task_code: str,
    executor: TaskExecutor,
    db_conn: DatabaseConnection,
    config: AppConfig,
    logger: logging.Logger,
    window_start: datetime,
    window_end: datetime,
) -> DebugResult:
    """Execute one ODS task and cross-check its reported counts.

    Measures the target table's windowed row count before and after the
    run and compares the observed delta against counts["inserted"].

    Returns:
        A populated DebugResult with status PASS / WARN / ERROR.
    """
    result = DebugResult(task_code=task_code)
    table_name = _get_ods_table_name(task_code)
    result.table_name = table_name or ""
    store_id = int(config.get("app.store_id"))
    run_uuid = f"debug-ods-{task_code.lower()}-{int(time.time())}"
    # Fix: the divider character was stripped as "ambiguous Unicode" (the log
    # line multiplied an empty string); restored as a box-drawing rule.
    logger.info("─" * 60)
    logger.info("▶ 开始调试: %s (表: %s)", task_code, table_name or "未知")
    # Row count before execution, used afterwards for the delta comparison.
    pre_count = None
    if table_name:
        try:
            pre_count = _query_table_count(db_conn, table_name, window_start, window_end)
            logger.info(" 执行前表行数 (窗口内): %d", pre_count)
        except Exception as exc:
            logger.warning(" 查询执行前行数失败: %s", exc)
    # Run the task itself, timing it regardless of success.
    t0 = time.monotonic()
    try:
        task_result = executor.run_single_task(
            task_code=task_code,
            run_uuid=run_uuid,
            store_id=store_id,
            data_source="online",
        )
        result.duration_sec = round(time.monotonic() - t0, 2)
    except Exception as exc:
        result.duration_sec = round(time.monotonic() - t0, 2)
        result.status = "ERROR"
        result.message = f"任务执行异常: {exc}"
        result.error_detail = traceback.format_exc()
        logger.error(" ✗ 执行异常: %s", exc)
        return result
    # Parse the task's return payload.
    task_status = (task_result.get("status") or "").upper()
    counts = task_result.get("counts") or {}
    result.counts = counts
    logger.info(" 返回状态: %s", task_status)
    logger.info(" counts: fetched=%s inserted=%s updated=%s skipped=%s errors=%s",
                counts.get("fetched", 0), counts.get("inserted", 0),
                counts.get("updated", 0), counts.get("skipped", 0),
                counts.get("errors", 0))
    fetched = counts.get("fetched", 0)
    inserted = counts.get("inserted", 0)
    updated = counts.get("updated", 0)
    skipped = counts.get("skipped", 0)
    errors = counts.get("errors", 0)
    # Sanity check: rows accounted for can never exceed rows fetched.
    accounted = inserted + updated + skipped
    if fetched > 0 and accounted > fetched:
        result.status = "WARN"
        result.message = f"counts 异常: accounted({accounted}) > fetched({fetched})"
        logger.warning("%s", result.message)
    # Row count after execution; compare the delta against counts.inserted.
    if table_name:
        try:
            post_count = _query_table_count(db_conn, table_name, window_start, window_end)
            result.db_row_count = post_count
            logger.info(" 执行后表行数 (窗口内): %d", post_count)
            if pre_count is not None:
                actual_delta = post_count - pre_count
                if inserted > 0 and actual_delta == 0:
                    # Conflict handling (ON CONFLICT DO NOTHING / update) can
                    # report inserts without growing the table; treat the
                    # comparison as completed rather than a mismatch.
                    logger.info(" 无新增行(可能是冲突处理: DO NOTHING / update")
                    result.count_match = True
                else:
                    # Fix: count_match was previously never assigned on the
                    # normal path (always None); record the actual comparison.
                    result.count_match = (actual_delta == inserted)
                logger.info(" 实际新增行数: %d, counts.inserted: %d", actual_delta, inserted)
        except Exception as exc:
            logger.warning(" 查询执行后行数失败: %s", exc)
    # Final status — keep an earlier WARN from the counts sanity check.
    if result.status == "":
        if errors > 0:
            result.status = "WARN"
            result.message = f"执行完成但有 {errors} 个错误"
        elif task_status in ("SUCCESS", "PARTIAL"):
            result.status = "PASS"
            result.message = f"执行成功, fetched={fetched}"
        elif task_status == "SKIP":
            result.status = "WARN"
            result.message = "任务被跳过(未启用或不存在)"
        else:
            result.status = "WARN"
            result.message = f"未知状态: {task_status}"
    # Fix: the icon glyphs were stripped as "ambiguous Unicode" (all map values
    # were empty strings); restored to match the summary output (✓ / ⚠ / ✗).
    icon = {"PASS": "✓", "WARN": "⚠", "ERROR": "✗", "FAIL": "✗"}.get(result.status, "?")
    logger.info(" %s 结果: %s - %s (耗时 %.1fs)", icon, result.status, result.message, result.duration_sec)
    return result
# ── 主流程 ────────────────────────────────────────────────────
def run_ods_debug(
    hours: float = 2.0,
    task_filter: list[str] | None = None,
) -> list[DebugResult]:
    """Run the ODS-layer debug pass over all (or selected) ODS tasks.

    Args:
        hours: Look-back window size in hours (default 2).
        task_filter: Task codes to debug; None means every registered ODS task.

    Returns:
        One DebugResult per executed task.
    """
    logger = _setup_logging()
    logger.info("=" * 60)
    logger.info("ODS 层调试开始")
    logger.info("=" * 60)
    # Load configuration (from .env) and compute the debug window.
    config = AppConfig.load()
    tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
    window_end = datetime.now(tz)
    window_start = window_end - timedelta(hours=hours)
    logger.info("门店 ID: %s", config.get("app.store_id"))
    logger.info("数据库: %s", config.get("db.name", ""))
    logger.info("API: %s", config.get("api.base_url", ""))
    logger.info("时间窗口: %s ~ %s (%.1f 小时)", window_start, window_end, hours)
    # Force every task onto the same small window via run.window_override.
    config.config.setdefault("run", {}).setdefault("window_override", {})
    config.config["run"]["window_override"]["start"] = window_start
    config.config["run"]["window_override"]["end"] = window_end
    db_conn, api_client, db_ops, executor = _build_components(config, logger)
    results: list[DebugResult] = []
    # Fix: close the DB connection even when the summary/save phase raises —
    # previously db_conn.close() was unreachable after any late exception.
    try:
        # Collect ODS task codes, optionally narrowed by the caller's filter.
        all_ods_codes = sorted(default_registry.get_tasks_by_layer("ODS"))
        if task_filter:
            filter_set = {t.upper() for t in task_filter}
            ods_codes = [c for c in all_ods_codes if c in filter_set]
            unknown = filter_set - set(ods_codes)
            if unknown:
                logger.warning("以下任务不在 ODS 层注册表中,已跳过: %s", unknown)
        else:
            ods_codes = all_ods_codes
        logger.info("待调试 ODS 任务: %d", len(ods_codes))
        logger.info("任务列表: %s", ", ".join(ods_codes))
        logger.info("")
        # Execute tasks one by one so a failure never aborts the whole pass.
        for idx, task_code in enumerate(ods_codes, start=1):
            logger.info("[%d/%d] %s", idx, len(ods_codes), task_code)
            try:
                r = debug_single_ods_task(
                    task_code=task_code,
                    executor=executor,
                    db_conn=db_conn,
                    config=config,
                    logger=logger,
                    window_start=window_start,
                    window_end=window_end,
                )
            except Exception as exc:
                r = DebugResult(
                    task_code=task_code,
                    status="ERROR",
                    message=f"未捕获异常: {exc}",
                    error_detail=traceback.format_exc(),
                )
                logger.error(" ✗ 未捕获异常: %s", exc)
            results.append(r)
            # Keep the connection alive across long multi-task runs.
            db_conn.ensure_open()
        _print_summary(results, logger)
        # Persist the JSON report next to the script.
        output_dir = _FEIQIU_ROOT / "scripts" / "debug" / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        ts = datetime.now(tz).strftime("%Y%m%d_%H%M%S")
        output_file = output_dir / f"debug_ods_{ts}.json"
        _save_results(results, output_file)
        logger.info("结果已保存: %s", output_file)
    finally:
        db_conn.close()
    return results
def _print_summary(results: list[DebugResult], logger: logging.Logger):
"""打印调试汇总。"""
logger.info("")
logger.info("=" * 60)
logger.info("ODS 层调试汇总")
logger.info("=" * 60)
pass_count = sum(1 for r in results if r.status == "PASS")
warn_count = sum(1 for r in results if r.status == "WARN")
error_count = sum(1 for r in results if r.status in ("ERROR", "FAIL"))
total_duration = sum(r.duration_sec for r in results)
logger.info("总计: %d 个任务", len(results))
logger.info(" ✓ PASS: %d", pass_count)
logger.info(" ⚠ WARN: %d", warn_count)
logger.info(" ✗ ERROR: %d", error_count)
logger.info(" 总耗时: %.1f", total_duration)
logger.info("")
# 列出非 PASS 的任务
non_pass = [r for r in results if r.status != "PASS"]
if non_pass:
logger.info("需关注的任务:")
for r in non_pass:
logger.info(" [%s] %s: %s", r.status, r.task_code, r.message)
else:
logger.info("所有任务均通过 ✓")
def _save_results(results: list[DebugResult], path: Path):
"""将结果序列化为 JSON。"""
data = []
for r in results:
d = asdict(r)
# datetime 不可直接序列化counts 中可能有 datetime
data.append(_sanitize_for_json(d))
path.write_text(json.dumps(data, ensure_ascii=False, indent=2, default=str), encoding="utf-8")
def _sanitize_for_json(obj):
"""递归处理不可序列化的值。"""
if isinstance(obj, dict):
return {k: _sanitize_for_json(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_sanitize_for_json(v) for v in obj]
if isinstance(obj, datetime):
return obj.isoformat()
return obj
# ── CLI 入口 ──────────────────────────────────────────────────
def parse_args():
    """Build and evaluate the CLI argument parser for this debug script."""
    cli = argparse.ArgumentParser(description="ODS 层逐任务调试")
    cli.add_argument(
        "--hours",
        type=float,
        default=2.0,
        help="回溯窗口小时数(默认 2",
    )
    cli.add_argument(
        "--tasks",
        type=str,
        default=None,
        help="仅调试指定任务,逗号分隔(如 ODS_MEMBER,ODS_PAYMENT",
    )
    return cli.parse_args()
def main():
    """CLI entry point: run the ODS debug pass and exit non-zero on failures."""
    args = parse_args()
    selected = None
    if args.tasks:
        selected = [code.strip().upper() for code in args.tasks.split(",") if code.strip()]
    outcomes = run_ods_debug(hours=args.hours, task_filter=selected)
    # Exit code 1 when any task ended in ERROR or FAIL.
    failed = any(item.status in ("ERROR", "FAIL") for item in outcomes)
    sys.exit(1 if failed else 0)
# Script entry point.
if __name__ == "__main__":
    main()