feiqiu-ETL/etl_billiards/quality/integrity_checker.py

# -*- coding: utf-8 -*-
"""Integrity checks across API -> ODS -> DWD."""
from __future__ import annotations

import json
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
from zoneinfo import ZoneInfo

from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.dwd_load_task import DwdLoadTask
from scripts.check_ods_gaps import run_gap_check

AMOUNT_KEYWORDS = ("amount", "money", "fee", "balance")


@dataclass(frozen=True)
class IntegrityWindow:
    start: datetime
    end: datetime
    label: str
    granularity: str


def _ensure_tz(dt: datetime, tz: ZoneInfo) -> datetime:
    if dt.tzinfo is None:
        return dt.replace(tzinfo=tz)
    return dt.astimezone(tz)


def _month_start(day: date) -> date:
    return date(day.year, day.month, 1)


def _next_month(day: date) -> date:
    if day.month == 12:
        return date(day.year + 1, 1, 1)
    return date(day.year, day.month + 1, 1)


def _date_to_start(dt: date, tz: ZoneInfo) -> datetime:
    return datetime.combine(dt, time.min).replace(tzinfo=tz)


def _date_to_end_exclusive(dt: date, tz: ZoneInfo) -> datetime:
    return datetime.combine(dt, time.min).replace(tzinfo=tz) + timedelta(days=1)
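

# A minimal sketch of the half-open boundary convention used throughout this
# module (the date and zone are hypothetical): for day 2026-01-05 in
# Asia/Taipei,
#   _date_to_start(...)         -> 2026-01-05 00:00:00+08:00 (inclusive lower bound)
#   _date_to_end_exclusive(...) -> 2026-01-06 00:00:00+08:00 (exclusive upper bound)
# so consecutive day windows tile the timeline with no overlap and no gap.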


def build_history_windows(start_dt: datetime, end_dt: datetime, tz: ZoneInfo) -> List[IntegrityWindow]:
    """Build weekly windows for the current month and monthly windows for earlier months."""
    start_dt = _ensure_tz(start_dt, tz)
    end_dt = _ensure_tz(end_dt, tz)
    if end_dt <= start_dt:
        return []
    start_date = start_dt.date()
    end_date = end_dt.date()
    current_month_start = _month_start(end_date)
    windows: List[IntegrityWindow] = []
    cur = start_date
    while cur <= end_date:
        month_start = _month_start(cur)
        month_end_exclusive = _next_month(cur)
        range_start = max(cur, month_start)
        range_end = min(end_date, month_end_exclusive - timedelta(days=1))
        if month_start == current_month_start:
            week_start = range_start
            while week_start <= range_end:
                week_end = min(week_start + timedelta(days=6), range_end)
                w_start_dt = _date_to_start(week_start, tz)
                w_end_dt = _date_to_end_exclusive(week_end, tz)
                if w_start_dt < end_dt and w_end_dt > start_dt:
                    windows.append(
                        IntegrityWindow(
                            start=max(w_start_dt, start_dt),
                            end=min(w_end_dt, end_dt),
                            label=f"week_{week_start.isoformat()}",
                            granularity="week",
                        )
                    )
                week_start = week_end + timedelta(days=1)
        else:
            m_start_dt = _date_to_start(range_start, tz)
            m_end_dt = _date_to_end_exclusive(range_end, tz)
            if m_start_dt < end_dt and m_end_dt > start_dt:
                windows.append(
                    IntegrityWindow(
                        start=max(m_start_dt, start_dt),
                        end=min(m_end_dt, end_dt),
                        label=f"month_{month_start.isoformat()}",
                        granularity="month",
                    )
                )
        cur = month_end_exclusive
    return windows
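

# Hypothetical example of the windowing scheme above: for a range from
# 2025-11-10 00:00 to 2026-01-05 00:00 (so the "current" month is January
# 2026), the builder yields, in order:
#   month_2025-11-01  2025-11-10 00:00 -> 2025-12-01 00:00
#   month_2025-12-01  2025-12-01 00:00 -> 2026-01-01 00:00
#   week_2026-01-01   2026-01-01 00:00 -> 2026-01-05 00:00 (clipped at end_dt)
# Earlier months get one coarse window each; only the month containing end_dt
# is split into week-sized windows.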


def _split_table(name: str, default_schema: str) -> Tuple[str, str]:
    if "." in name:
        schema, table = name.split(".", 1)
        return schema, table
    return default_schema, name


def _pick_time_column(dwd_cols: Iterable[str], ods_cols: Iterable[str]) -> str | None:
    lower_cols = {c.lower() for c in dwd_cols} & {c.lower() for c in ods_cols}
    for candidate in DwdLoadTask.FACT_ORDER_CANDIDATES:
        if candidate.lower() in lower_cols:
            return candidate.lower()
    return None


def _fetch_columns(cur, schema: str, table: str) -> Tuple[List[str], Dict[str, str]]:
    cur.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (schema, table),
    )
    cols: List[str] = []
    types: Dict[str, str] = {}
    for name, data_type in cur.fetchall():
        cols.append(name)
        types[name.lower()] = (data_type or "").lower()
    return cols, types


def _amount_columns(cols: List[str], types: Dict[str, str]) -> List[str]:
    numeric_types = {"numeric", "double precision", "integer", "bigint", "smallint", "real", "decimal"}
    out: List[str] = []
    for col in cols:
        lc = col.lower()
        if types.get(lc) not in numeric_types:
            continue
        if any(key in lc for key in AMOUNT_KEYWORDS):
            out.append(lc)
    return out
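

# Example of the detection rule above (column names are hypothetical): given
#   pay_amount numeric, refund_fee bigint, amount_note text, table_id integer
# only "pay_amount" and "refund_fee" are returned -- the name must contain one
# of AMOUNT_KEYWORDS *and* the type must be numeric, so "amount_note" (text)
# and "table_id" (no keyword) are both skipped.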


def _count_table(cur, schema: str, table: str, time_col: str | None, window: IntegrityWindow | None) -> int:
    where = ""
    params: List[Any] = []
    if time_col and window:
        where = f'WHERE "{time_col}" >= %s AND "{time_col}" < %s'
        params = [window.start, window.end]
    sql = f'SELECT COUNT(1) FROM "{schema}"."{table}" {where}'
    cur.execute(sql, params)
    row = cur.fetchone()
    return int(row[0] if row else 0)


def _sum_column(cur, schema: str, table: str, col: str, time_col: str | None, window: IntegrityWindow | None) -> float:
    where = ""
    params: List[Any] = []
    if time_col and window:
        where = f'WHERE "{time_col}" >= %s AND "{time_col}" < %s'
        params = [window.start, window.end]
    sql = f'SELECT COALESCE(SUM("{col}"), 0) FROM "{schema}"."{table}" {where}'
    cur.execute(sql, params)
    row = cur.fetchone()
    return float(row[0] if row else 0)
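

# For reference, the SQL rendered by the two helpers above for a windowed
# check looks like this (table and column identifiers are illustrative):
#   SELECT COUNT(1) FROM "billiards_dwd"."fact_order"
#       WHERE "create_time" >= %s AND "create_time" < %s
#   SELECT COALESCE(SUM("pay_amount"), 0) FROM "billiards_dwd"."fact_order"
#       WHERE "create_time" >= %s AND "create_time" < %s
# with window.start / window.end bound as parameters; when no usable time
# column exists the WHERE clause is dropped and the whole table is measured.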


def run_dwd_vs_ods_check(
    *,
    cfg: AppConfig,
    window: IntegrityWindow | None,
    include_dimensions: bool,
) -> Dict[str, Any]:
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    db_conn = DatabaseConnection(dsn=dsn, session=session)
    try:
        with db_conn.conn.cursor() as cur:
            results: List[Dict[str, Any]] = []
            table_map = DwdLoadTask.TABLE_MAP
            for dwd_table, ods_table in table_map.items():
                if not include_dimensions and ".dim_" in dwd_table:
                    continue
                schema_dwd, name_dwd = _split_table(dwd_table, "billiards_dwd")
                schema_ods, name_ods = _split_table(ods_table, "billiards_ods")
                try:
                    dwd_cols, dwd_types = _fetch_columns(cur, schema_dwd, name_dwd)
                    ods_cols, ods_types = _fetch_columns(cur, schema_ods, name_ods)
                    time_col = _pick_time_column(dwd_cols, ods_cols)
                    count_dwd = _count_table(cur, schema_dwd, name_dwd, time_col, window)
                    count_ods = _count_table(cur, schema_ods, name_ods, time_col, window)
                    dwd_amount_cols = _amount_columns(dwd_cols, dwd_types)
                    ods_amount_cols = _amount_columns(ods_cols, ods_types)
                    common_amount_cols = sorted(set(dwd_amount_cols) & set(ods_amount_cols))
                    amounts: List[Dict[str, Any]] = []
                    for col in common_amount_cols:
                        dwd_sum = _sum_column(cur, schema_dwd, name_dwd, col, time_col, window)
                        ods_sum = _sum_column(cur, schema_ods, name_ods, col, time_col, window)
                        amounts.append(
                            {
                                "column": col,
                                "dwd_sum": dwd_sum,
                                "ods_sum": ods_sum,
                                "diff": dwd_sum - ods_sum,
                            }
                        )
                    results.append(
                        {
                            "dwd_table": dwd_table,
                            "ods_table": ods_table,
                            "windowed": bool(time_col and window),
                            "window_col": time_col,
                            "count": {"dwd": count_dwd, "ods": count_ods, "diff": count_dwd - count_ods},
                            "amounts": amounts,
                        }
                    )
                except Exception as exc:  # noqa: BLE001
                    results.append(
                        {
                            "dwd_table": dwd_table,
                            "ods_table": ods_table,
                            "windowed": bool(window),
                            "window_col": None,
                            "count": {"dwd": None, "ods": None, "diff": None},
                            "amounts": [],
                            "error": f"{type(exc).__name__}: {exc}",
                        }
                    )
            total_count_diff = sum(
                int(item.get("count", {}).get("diff") or 0)
                for item in results
                if isinstance(item.get("count", {}).get("diff"), (int, float))
            )
            return {
                "tables": results,
                "total_count_diff": total_count_diff,
            }
    finally:
        db_conn.close()
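

# Shape of the payload returned above, sketched with made-up numbers (the
# table names are illustrative, not the real DwdLoadTask.TABLE_MAP entries):
#   {
#     "tables": [
#       {"dwd_table": "billiards_dwd.fact_order", "ods_table": "billiards_ods.order",
#        "windowed": True, "window_col": "create_time",
#        "count": {"dwd": 120, "ods": 120, "diff": 0},
#        "amounts": [{"column": "pay_amount", "dwd_sum": 980.0, "ods_sum": 980.0, "diff": 0.0}]},
#       ...
#     ],
#     "total_count_diff": 0,
#   }
# A non-zero "diff" (or an "error" key in a table entry) is what a caller
# should alert on.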


def _default_report_path(prefix: str) -> Path:
    root = Path(__file__).resolve().parents[1]
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return root / "reports" / f"{prefix}_{stamp}.json"


def run_integrity_window(
    *,
    cfg: AppConfig,
    window: IntegrityWindow,
    include_dimensions: bool,
    task_codes: str,
    logger,
    write_report: bool,
    report_path: Path | None = None,
    window_split_unit: str | None = None,
    window_compensation_hours: int | None = None,
) -> Dict[str, Any]:
    total_seconds = max(0, int((window.end - window.start).total_seconds()))
    if total_seconds >= 86400:
        window_days = max(1, total_seconds // 86400)
        window_hours = 0
    else:
        window_days = 0
        window_hours = max(1, total_seconds // 3600)
    ods_payload = run_gap_check(
        cfg=cfg,
        start=window.start,
        end=window.end,
        window_days=window_days,
        window_hours=window_hours,
        page_size=int(cfg.get("api.page_size") or 200),
        chunk_size=500,
        sample_limit=50,
        sleep_per_window=0,
        sleep_per_page=0,
        task_codes=task_codes,
        from_cutoff=False,
        cutoff_overlap_hours=24,
        allow_small_window=True,
        logger=logger,
        window_split_unit=window_split_unit,
        window_compensation_hours=window_compensation_hours,
    )
    dwd_payload = run_dwd_vs_ods_check(
        cfg=cfg,
        window=window,
        include_dimensions=include_dimensions,
    )
    report = {
        "mode": "window",
        "window": {
            "start": window.start.isoformat(),
            "end": window.end.isoformat(),
            "label": window.label,
            "granularity": window.granularity,
        },
        "api_to_ods": ods_payload,
        "ods_to_dwd": dwd_payload,
        "generated_at": datetime.now(ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))).isoformat(),
    }
    if write_report:
        path = report_path or _default_report_path("data_integrity_window")
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        report["report_path"] = str(path)
    return report
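

# Minimal usage sketch, assuming cfg and logger are already constructed by the
# application's own bootstrap (not shown here):
#   tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
#   window = IntegrityWindow(
#       start=datetime(2026, 1, 20, tzinfo=tz),
#       end=datetime(2026, 1, 27, tzinfo=tz),
#       label="week_2026-01-20",
#       granularity="week",
#   )
#   report = run_integrity_window(
#       cfg=cfg, window=window, include_dimensions=False,
#       task_codes="order",  # hypothetical; use the codes your gap check expects
#       logger=logger, write_report=True,
#   )
#   # report["report_path"] points at the JSON written under reports/.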


def run_integrity_history(
    *,
    cfg: AppConfig,
    start_dt: datetime,
    end_dt: datetime,
    include_dimensions: bool,
    task_codes: str,
    logger,
    write_report: bool,
    report_path: Path | None = None,
) -> Dict[str, Any]:
    tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
    windows = build_history_windows(start_dt, end_dt, tz)
    results: List[Dict[str, Any]] = []
    total_missing = 0
    total_errors = 0
    for window in windows:
        logger.info("checking window start=%s end=%s", window.start, window.end)
        payload = run_integrity_window(
            cfg=cfg,
            window=window,
            include_dimensions=include_dimensions,
            task_codes=task_codes,
            logger=logger,
            write_report=False,
        )
        results.append(payload)
        total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
        total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
    report = {
        "mode": "history",
        "start": _ensure_tz(start_dt, tz).isoformat(),
        "end": _ensure_tz(end_dt, tz).isoformat(),
        "windows": results,
        "total_missing": total_missing,
        "total_errors": total_errors,
        "generated_at": datetime.now(tz).isoformat(),
    }
    if write_report:
        path = report_path or _default_report_path("data_integrity_history")
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        report["report_path"] = str(path)
    return report


def compute_last_etl_end(cfg: AppConfig) -> datetime | None:
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    db_conn = DatabaseConnection(dsn=dsn, session=session)
    try:
        rows = db_conn.query(
            "SELECT MAX(window_end) AS mx FROM etl_admin.etl_run WHERE store_id = %s",
            (cfg.get("app.store_id"),),
        )
        mx = rows[0]["mx"] if rows else None
        if isinstance(mx, datetime):
            tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
            return _ensure_tz(mx, tz)
    finally:
        db_conn.close()
    return None
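

# A typical history run might anchor on the last completed ETL window; a
# sketch under the assumption that cfg/logger come from the app bootstrap:
#   last_end = compute_last_etl_end(cfg)
#   if last_end is not None:
#       run_integrity_history(
#           cfg=cfg,
#           start_dt=last_end - timedelta(days=90),  # lookback span is illustrative
#           end_dt=last_end,
#           include_dimensions=False,
#           task_codes="order",  # hypothetical task code
#           logger=logger,
#           write_report=True,
#       )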