106 lines
4.9 KiB
Python
106 lines
4.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""DWD 质量核对任务:按 dwd_quality_check.md 输出行数/金额对照报表。"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterable, List, Sequence, Tuple
|
|
|
|
from psycopg2.extras import RealDictCursor
|
|
|
|
from .base_task import BaseTask, TaskContext
|
|
from .dwd_load_task import DwdLoadTask
|
|
|
|
|
|
class DwdQualityTask(BaseTask):
    """Reconcile ODS and DWD tables (row counts and amount totals) into a JSON report.

    The table pairs come from ``DwdLoadTask.TABLE_MAP`` (DWD table -> ODS table).
    ``self.db`` and ``self.logger`` are used but not defined in this class; they
    are presumably supplied by ``BaseTask`` -- confirm against the base class.
    """

    # Report destination. The path is relative, so it resolves against the
    # working directory of the running process.
    REPORT_PATH = Path("etl_billiards/reports/dwd_quality_report.json")
    # A numeric column is treated as an "amount" column when its lowercased
    # name contains any of these substrings.
    AMOUNT_KEYWORDS = ("amount", "money", "fee", "balance")

    def get_task_code(self) -> str:
        """Return the task code identifying this task."""
        return "DWD_QUALITY_CHECK"

    def extract(self, context: TaskContext) -> dict[str, Any]:
        """Build the runtime context consumed by :meth:`load`.

        Returns:
            A dict with a single key ``"now"`` holding the current naive local
            ``datetime``; it becomes the report's ``generated_at`` stamp.
        """
        return {"now": datetime.now()}

    def load(self, extracted: dict[str, Any], context: TaskContext) -> dict[str, Any]:
        """Compare every DWD/ODS table pair and write the report to ``REPORT_PATH``.

        Args:
            extracted: Output of :meth:`extract`; only ``extracted["now"]`` is read.
            context: Task context; not used by this method.

        Returns:
            ``{"report_path": <REPORT_PATH as str>}``.
        """
        report: Dict[str, Any] = {
            "generated_at": extracted["now"].isoformat(),
            "tables": [],
            "note": "行数/金额核对,金额字段基于列名包含 amount/money/fee/balance 的数值列自动扫描。",
        }

        with self.db.conn.cursor(cursor_factory=RealDictCursor) as cur:
            for dwd_table, ods_table in DwdLoadTask.TABLE_MAP.items():
                count_info = self._compare_counts(cur, dwd_table, ods_table)
                amount_info = self._compare_amounts(cur, dwd_table, ods_table)
                report["tables"].append(
                    {
                        "dwd_table": dwd_table,
                        "ods_table": ods_table,
                        "count": count_info,
                        "amounts": amount_info,
                    }
                )

        self.REPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
        self.REPORT_PATH.write_text(
            json.dumps(report, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        self.logger.info("DWD 质检报表已生成:%s", self.REPORT_PATH)
        return {"report_path": str(self.REPORT_PATH)}

    # ---------------------- helpers ----------------------
    def _compare_counts(self, cur, dwd_table: str, ods_table: str) -> Dict[str, Any]:
        """Count rows on both sides and return the difference.

        Args:
            cur: Open cursor whose rows are addressable by column name.
            dwd_table: DWD table name, optionally ``schema.table``.
            ods_table: ODS table name, optionally ``schema.table``.

        Returns:
            ``{"dwd": <count>, "ods": <count>, "diff": dwd - ods}``.
        """
        dwd_ref = self._qualified_name(dwd_table, default_schema="billiards_dwd")
        ods_ref = self._qualified_name(ods_table, default_schema="billiards_ods")

        cur.execute(f"SELECT COUNT(1) AS cnt FROM {dwd_ref}")
        dwd_cnt = cur.fetchone()["cnt"]
        cur.execute(f"SELECT COUNT(1) AS cnt FROM {ods_ref}")
        ods_cnt = cur.fetchone()["cnt"]
        return {"dwd": dwd_cnt, "ods": ods_cnt, "diff": dwd_cnt - ods_cnt}

    def _compare_amounts(self, cur, dwd_table: str, ods_table: str) -> List[Dict[str, Any]]:
        """Sum every amount-like column present on both sides and compare totals.

        Columns are paired by lowercased name, but each side is queried with
        its own exact column spelling: quoted identifiers are case-sensitive in
        PostgreSQL, so quoting a lowercased name would miss a mixed-case column.

        Args:
            cur: Open cursor whose rows are addressable by column name.
            dwd_table: DWD table name, optionally ``schema.table``.
            ods_table: ODS table name, optionally ``schema.table``.

        Returns:
            One dict per shared column, sorted by column name, with keys
            ``"column"``, ``"dwd_sum"``, ``"ods_sum"`` and ``"diff"`` (floats).
        """
        dwd_schema, dwd_name = self._split_table_name(dwd_table, default_schema="billiards_dwd")
        ods_schema, ods_name = self._split_table_name(ods_table, default_schema="billiards_ods")
        dwd_ref = f"{self._quote_ident(dwd_schema)}.{self._quote_ident(dwd_name)}"
        ods_ref = f"{self._quote_ident(ods_schema)}.{self._quote_ident(ods_name)}"

        dwd_cols = self._get_amount_column_map(cur, dwd_schema, dwd_name)
        ods_cols = self._get_amount_column_map(cur, ods_schema, ods_name)

        results: List[Dict[str, Any]] = []
        for col in sorted(dwd_cols.keys() & ods_cols.keys()):
            cur.execute(
                f"SELECT COALESCE(SUM({self._quote_ident(dwd_cols[col])}),0) AS val FROM {dwd_ref}"
            )
            dwd_sum = cur.fetchone()["val"] or 0
            cur.execute(
                f"SELECT COALESCE(SUM({self._quote_ident(ods_cols[col])}),0) AS val FROM {ods_ref}"
            )
            ods_sum = cur.fetchone()["val"] or 0
            results.append(
                {
                    "column": col,
                    "dwd_sum": float(dwd_sum),
                    "ods_sum": float(ods_sum),
                    "diff": self._numeric_diff(dwd_sum, ods_sum),
                }
            )
        return results

    def _get_amount_column_map(self, cur, schema: str, table: str) -> Dict[str, str]:
        """Map lowercased amount-column names to their exact catalog spelling.

        Only numeric-typed columns whose lowercased name contains one of
        ``AMOUNT_KEYWORDS`` are included. If two columns differ only by case,
        the first one returned by the catalog wins.
        """
        cur.execute(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = %s
              AND table_name = %s
              AND data_type IN ('numeric','double precision','integer','bigint','smallint','real','decimal')
            """,
            (schema, table),
        )
        mapping: Dict[str, str] = {}
        for row in cur.fetchall():
            actual = row["column_name"]
            lowered = actual.lower()
            if any(key in lowered for key in self.AMOUNT_KEYWORDS):
                mapping.setdefault(lowered, actual)
        return mapping

    def _get_numeric_amount_columns(self, cur, schema: str, table: str) -> List[str]:
        """Return the lowercased names of numeric columns matching ``AMOUNT_KEYWORDS``.

        The names are lowercased for matching only; use
        :meth:`_get_amount_column_map` when the exact spelling is needed to
        build SQL.
        """
        return list(self._get_amount_column_map(cur, schema, table))

    def _qualified_name(self, name: str, default_schema: str) -> str:
        """Return ``name`` as a quoted ``"schema"."table"`` SQL reference."""
        schema, table = self._split_table_name(name, default_schema=default_schema)
        return f"{self._quote_ident(schema)}.{self._quote_ident(table)}"

    @staticmethod
    def _quote_ident(identifier: str) -> str:
        """Double-quote an SQL identifier, doubling any embedded double quotes."""
        return '"' + identifier.replace('"', '""') + '"'

    @staticmethod
    def _numeric_diff(left: Any, right: Any) -> float:
        """Return ``left - right`` as a float, subtracting before converting.

        Subtracting in the native type (e.g. ``Decimal`` for NUMERIC sums)
        avoids rounding two large totals to float first, which could hide or
        invent small differences. Operand types that cannot be mixed (such as
        ``Decimal`` and ``float``) fall back to float arithmetic.
        """
        try:
            return float(left - right)
        except TypeError:
            return float(left) - float(right)

    def _split_table_name(self, name: str, default_schema: str) -> Tuple[str, str]:
        """Split ``schema.table`` into its parts.

        A name that does not consist of exactly two dot-separated parts is
        returned unchanged as the table name, paired with ``default_schema``.
        """
        parts = name.split(".")
        if len(parts) == 2:
            return parts[0], parts[1]
        return default_schema, name
|