feiqiu-ETL/etl_billiards/quality/integrity_checker.py

# -*- coding: utf-8 -*-
"""Integrity checks across API -> ODS -> DWD."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
from zoneinfo import ZoneInfo
import json
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.dwd_load_task import DwdLoadTask
from scripts.check_ods_gaps import run_gap_check

AMOUNT_KEYWORDS = ("amount", "money", "fee", "balance")


@dataclass(frozen=True)
class IntegrityWindow:
    start: datetime
    end: datetime
    label: str
    granularity: str


def _ensure_tz(dt: datetime, tz: ZoneInfo) -> datetime:
    if dt.tzinfo is None:
        return dt.replace(tzinfo=tz)
    return dt.astimezone(tz)


def _month_start(day: date) -> date:
    return date(day.year, day.month, 1)


def _next_month(day: date) -> date:
    if day.month == 12:
        return date(day.year + 1, 1, 1)
    return date(day.year, day.month + 1, 1)


def _date_to_start(dt: date, tz: ZoneInfo) -> datetime:
    return datetime.combine(dt, time.min).replace(tzinfo=tz)


def _date_to_end_exclusive(dt: date, tz: ZoneInfo) -> datetime:
    return datetime.combine(dt, time.min).replace(tzinfo=tz) + timedelta(days=1)


def build_history_windows(start_dt: datetime, end_dt: datetime, tz: ZoneInfo) -> List[IntegrityWindow]:
    """Build weekly windows for current month, monthly windows for earlier months."""
    start_dt = _ensure_tz(start_dt, tz)
    end_dt = _ensure_tz(end_dt, tz)
    if end_dt <= start_dt:
        return []
    start_date = start_dt.date()
    end_date = end_dt.date()
    current_month_start = _month_start(end_date)
    windows: List[IntegrityWindow] = []
    cur = start_date
    while cur <= end_date:
        month_start = _month_start(cur)
        month_end_exclusive = _next_month(cur)
        range_start = max(cur, month_start)
        range_end = min(end_date, month_end_exclusive - timedelta(days=1))
        if month_start == current_month_start:
            week_start = range_start
            while week_start <= range_end:
                week_end = min(week_start + timedelta(days=6), range_end)
                w_start_dt = _date_to_start(week_start, tz)
                w_end_dt = _date_to_end_exclusive(week_end, tz)
                if w_start_dt < end_dt and w_end_dt > start_dt:
                    windows.append(
                        IntegrityWindow(
                            start=max(w_start_dt, start_dt),
                            end=min(w_end_dt, end_dt),
                            label=f"week_{week_start.isoformat()}",
                            granularity="week",
                        )
                    )
                week_start = week_end + timedelta(days=1)
        else:
            m_start_dt = _date_to_start(range_start, tz)
            m_end_dt = _date_to_end_exclusive(range_end, tz)
            if m_start_dt < end_dt and m_end_dt > start_dt:
                windows.append(
                    IntegrityWindow(
                        start=max(m_start_dt, start_dt),
                        end=min(m_end_dt, end_dt),
                        label=f"month_{month_start.isoformat()}",
                        granularity="month",
                    )
                )
        cur = month_end_exclusive
    return windows
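

# Illustrative sketch (not part of the original module): how the split behaves for a
# hypothetical range, assuming tz = ZoneInfo("Asia/Taipei") and naive datetimes.
#
#   wins = build_history_windows(datetime(2024, 1, 15), datetime(2024, 3, 10), ZoneInfo("Asia/Taipei"))
#   # -> month_2024-01-01 (Jan 15 .. Feb 1), month_2024-02-01 (Feb 1 .. Mar 1),
#   #    then week_2024-03-01 and week_2024-03-08 clipped to the end datetime,
#   #    because only the month containing end_dt is split into weekly windows.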


def _split_table(name: str, default_schema: str) -> Tuple[str, str]:
    if "." in name:
        schema, table = name.split(".", 1)
        return schema, table
    return default_schema, name


def _pick_time_column(dwd_cols: Iterable[str], ods_cols: Iterable[str]) -> str | None:
    lower_cols = {c.lower() for c in dwd_cols} & {c.lower() for c in ods_cols}
    for candidate in DwdLoadTask.FACT_ORDER_CANDIDATES:
        if candidate.lower() in lower_cols:
            return candidate.lower()
    return None


def _fetch_columns(cur, schema: str, table: str) -> Tuple[List[str], Dict[str, str]]:
    cur.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (schema, table),
    )
    cols = []
    types: Dict[str, str] = {}
    for name, data_type in cur.fetchall():
        cols.append(name)
        types[name.lower()] = (data_type or "").lower()
    return cols, types


def _amount_columns(cols: List[str], types: Dict[str, str]) -> List[str]:
    numeric_types = {"numeric", "double precision", "integer", "bigint", "smallint", "real", "decimal"}
    out = []
    for col in cols:
        lc = col.lower()
        if types.get(lc) not in numeric_types:
            continue
        if any(key in lc for key in AMOUNT_KEYWORDS):
            out.append(lc)
    return out


def _build_hash_expr(alias: str, cols: list[str]) -> str:
    if not cols:
        return "NULL"
    parts = ", ".join([f"COALESCE({alias}.\"{c}\"::text,'')" for c in cols])
    return f"md5(concat_ws('||', {parts}))"
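

# Example of the SQL fragment produced above (column names are illustrative only):
#   _build_hash_expr("o", ["pay_amount", "status"])
#   -> md5(concat_ws('||', COALESCE(o."pay_amount"::text,''), COALESCE(o."status"::text,'')))
# NULLs are coalesced to empty strings, so the digest is stable for missing values.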


def _build_snapshot_subquery(
    schema: str,
    table: str,
    cols: list[str],
    key_cols: list[str],
    order_col: str | None,
    where_sql: str,
) -> str:
    cols_sql = ", ".join([f'"{c}"' for c in cols])
    if key_cols and order_col:
        keys = ", ".join([f'"{c}"' for c in key_cols])
        order_by = ", ".join([*(f'"{c}"' for c in key_cols), f'"{order_col}" DESC NULLS LAST'])
        return (
            f'SELECT DISTINCT ON ({keys}) {cols_sql} '
            f'FROM "{schema}"."{table}" {where_sql} '
            f"ORDER BY {order_by}"
        )
    return f'SELECT {cols_sql} FROM "{schema}"."{table}" {where_sql}'


def _build_snapshot_expr_subquery(
    schema: str,
    table: str,
    select_exprs: list[str],
    key_exprs: list[str],
    order_col: str | None,
    where_sql: str,
) -> str:
    select_cols_sql = ", ".join(select_exprs)
    table_sql = f'"{schema}"."{table}"'
    if key_exprs and order_col:
        distinct_on = ", ".join(key_exprs)
        order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST'])
        return (
            f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} "
            f"FROM {table_sql} {where_sql} "
            f"ORDER BY {order_by}"
        )
    return f"SELECT {select_cols_sql} FROM {table_sql} {where_sql}"
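

# Illustrative shape of the snapshot query emitted above (table and column names are
# hypothetical; the real string is returned on a single line):
#   _build_snapshot_expr_subquery("billiards_ods", "ods_order", ['"order_id" AS "order_id"'],
#                                 ['"order_id"'], "fetched_at", "WHERE ...")
#   -> SELECT DISTINCT ON ("order_id") "order_id" AS "order_id"
#      FROM "billiards_ods"."ods_order" WHERE ...
#      ORDER BY "order_id", "fetched_at" DESC NULLS LAST
# i.e. one row per key, keeping the most recently fetched snapshot.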


def _cast_expr(col: str, cast_type: str | None) -> str:
    if col.upper() == "NULL":
        base = "NULL"
    else:
        is_expr = not col.isidentifier() or "->" in col or "#>>" in col or "::" in col or "'" in col
        base = col if is_expr else f'"{col}"'
    if cast_type:
        cast_lower = cast_type.lower()
        if cast_lower in {"bigint", "integer", "numeric", "decimal"}:
            return f"CAST(NULLIF(CAST({base} AS text), '') AS numeric)::{cast_type}"
        if cast_lower == "timestamptz":
            return f"({base})::timestamptz"
        return f"{base}::{cast_type}"
    return base
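

# A few illustrative inputs/outputs for _cast_expr (identifiers are hypothetical):
#   _cast_expr("member_id", "bigint")        -> CAST(NULLIF(CAST("member_id" AS text), '') AS numeric)::bigint
#   _cast_expr("create_time", "timestamptz") -> ("create_time")::timestamptz
#   _cast_expr("payload->>'id'", None)       -> payload->>'id'   (expressions pass through unquoted)
#   _cast_expr("NULL", "text")               -> NULL::text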


def _fetch_pk_columns(cur, schema: str, table: str) -> List[str]:
    cur.execute(
        """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
        """,
        (schema, table),
    )
    return [r[0] for r in cur.fetchall()]


def _pick_snapshot_order_column(cols: Iterable[str]) -> str | None:
    lower = {c.lower() for c in cols}
    for candidate in ("fetched_at", "update_time", "create_time"):
        if candidate in lower:
            return candidate
    return None


def _count_table(
    cur,
    schema: str,
    table: str,
    time_col: str | None,
    window: IntegrityWindow | None,
    *,
    pk_cols: List[str] | None = None,
    snapshot_order_col: str | None = None,
    current_only: bool = False,
) -> int:
    where_parts: List[str] = []
    params: List[Any] = []
    if current_only:
        where_parts.append("COALESCE(scd2_is_current,1)=1")
    if time_col and window:
        where_parts.append(f'"{time_col}" >= %s AND "{time_col}" < %s')
        params.extend([window.start, window.end])
    where = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
    if pk_cols and snapshot_order_col:
        # Snapshot tables keep one row per fetch; count only the latest snapshot per primary key.
        keys = ", ".join(f'"{c}"' for c in pk_cols)
        order_by = ", ".join([*(f'"{c}"' for c in pk_cols), f'"{snapshot_order_col}" DESC NULLS LAST'])
        sql = (
            f'SELECT COUNT(1) FROM ('
            f'SELECT DISTINCT ON ({keys}) 1 FROM "{schema}"."{table}" {where} '
            f'ORDER BY {order_by}'
            f') t'
        )
    else:
        sql = f'SELECT COUNT(1) FROM "{schema}"."{table}" {where}'
    cur.execute(sql, params)
    row = cur.fetchone()
    return int(row[0] if row else 0)
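

# Illustrative COUNT emitted when a snapshot table is deduplicated (names are hypothetical):
#   SELECT COUNT(1) FROM (
#       SELECT DISTINCT ON ("order_id") 1 FROM "billiards_ods"."ods_order" WHERE ...
#       ORDER BY "order_id", "fetched_at" DESC NULLS LAST
#   ) t
# Without pk_cols/snapshot_order_col it degrades to a plain COUNT(1) over the filtered table.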


def _sum_column(
    cur,
    schema: str,
    table: str,
    col: str,
    time_col: str | None,
    window: IntegrityWindow | None,
    *,
    pk_cols: List[str] | None = None,
    snapshot_order_col: str | None = None,
    current_only: bool = False,
) -> float:
    where_parts: List[str] = []
    params: List[Any] = []
    if current_only:
        where_parts.append("COALESCE(scd2_is_current,1)=1")
    if time_col and window:
        where_parts.append(f'"{time_col}" >= %s AND "{time_col}" < %s')
        params.extend([window.start, window.end])
    where = f"WHERE {' AND '.join(where_parts)}" if where_parts else ""
    if pk_cols and snapshot_order_col:
        keys = ", ".join(f'"{c}"' for c in pk_cols)
        order_by = ", ".join([*(f'"{c}"' for c in pk_cols), f'"{snapshot_order_col}" DESC NULLS LAST'])
        sql = (
            f'SELECT COALESCE(SUM("{col}"), 0) FROM ('
            f'SELECT DISTINCT ON ({keys}) "{col}" FROM "{schema}"."{table}" {where} '
            f'ORDER BY {order_by}'
            f') t'
        )
    else:
        sql = f'SELECT COALESCE(SUM("{col}"), 0) FROM "{schema}"."{table}" {where}'
    cur.execute(sql, params)
    row = cur.fetchone()
    return float(row[0] if row else 0)


def run_dwd_vs_ods_check(
    *,
    cfg: AppConfig,
    window: IntegrityWindow | None,
    include_dimensions: bool,
    compare_content: bool | None = None,
    content_sample_limit: int | None = None,
) -> Dict[str, Any]:
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    db_conn = DatabaseConnection(dsn=dsn, session=session)
    if compare_content is None:
        compare_content = bool(cfg.get("integrity.compare_content", True))
    if content_sample_limit is None:
        content_sample_limit = cfg.get("integrity.content_sample_limit") or 50
    try:
        with db_conn.conn.cursor() as cur:
            results: List[Dict[str, Any]] = []
            table_map = DwdLoadTask.TABLE_MAP
            total_mismatch = 0
            for dwd_table, ods_table in table_map.items():
                if not include_dimensions and ".dim_" in dwd_table:
                    continue
                schema_dwd, name_dwd = _split_table(dwd_table, "billiards_dwd")
                schema_ods, name_ods = _split_table(ods_table, "billiards_ods")
                try:
                    dwd_cols, dwd_types = _fetch_columns(cur, schema_dwd, name_dwd)
                    ods_cols, ods_types = _fetch_columns(cur, schema_ods, name_ods)
                    time_col = _pick_time_column(dwd_cols, ods_cols)
                    pk_dwd = _fetch_pk_columns(cur, schema_dwd, name_dwd)
                    pk_ods_raw = _fetch_pk_columns(cur, schema_ods, name_ods)
                    pk_ods = [c for c in pk_ods_raw if c.lower() != "content_hash"]
                    ods_has_snapshot = any(c.lower() == "content_hash" for c in ods_cols)
                    ods_snapshot_order = _pick_snapshot_order_column(ods_cols) if ods_has_snapshot else None
                    dwd_current_only = any(c.lower() == "scd2_is_current" for c in dwd_cols)
                    # Row counts: DWD restricted to current SCD2 rows, ODS deduplicated to latest snapshots.
                    count_dwd = _count_table(
                        cur,
                        schema_dwd,
                        name_dwd,
                        time_col,
                        window,
                        current_only=dwd_current_only,
                    )
                    count_ods = _count_table(
                        cur,
                        schema_ods,
                        name_ods,
                        time_col,
                        window,
                        pk_cols=pk_ods if ods_has_snapshot else None,
                        snapshot_order_col=ods_snapshot_order if ods_has_snapshot else None,
                    )
                    # Sum the amount-like numeric columns present on both sides.
                    dwd_amount_cols = _amount_columns(dwd_cols, dwd_types)
                    ods_amount_cols = _amount_columns(ods_cols, ods_types)
                    common_amount_cols = sorted(set(dwd_amount_cols) & set(ods_amount_cols))
                    amounts: List[Dict[str, Any]] = []
                    for col in common_amount_cols:
                        dwd_sum = _sum_column(
                            cur,
                            schema_dwd,
                            name_dwd,
                            col,
                            time_col,
                            window,
                            current_only=dwd_current_only,
                        )
                        ods_sum = _sum_column(
                            cur,
                            schema_ods,
                            name_ods,
                            col,
                            time_col,
                            window,
                            pk_cols=pk_ods if ods_has_snapshot else None,
                            snapshot_order_col=ods_snapshot_order if ods_has_snapshot else None,
                        )
                        amounts.append(
                            {
                                "column": col,
                                "dwd_sum": dwd_sum,
                                "ods_sum": ods_sum,
                                "diff": dwd_sum - ods_sum,
                            }
                        )
                    mismatch = None
                    mismatch_samples: list[dict] = []
                    mismatch_error = None
                    if compare_content:
                        # Content comparison: join the latest ODS snapshots to DWD on the business
                        # keys and compare md5 hashes of the shared, comparable columns.
                        dwd_cols_lower = [c.lower() for c in dwd_cols]
                        ods_cols_lower = [c.lower() for c in ods_cols]
                        dwd_col_set = set(dwd_cols_lower)
                        ods_col_set = set(ods_cols_lower)
                        scd_cols = {c.lower() for c in DwdLoadTask.SCD_COLS}
                        ods_exclude = {
                            "payload", "source_file", "source_endpoint", "fetched_at", "content_hash", "record_index"
                        }
                        numeric_types = {
                            "integer",
                            "bigint",
                            "smallint",
                            "numeric",
                            "double precision",
                            "real",
                            "decimal",
                        }
                        text_types = {"text", "character varying", "varchar"}
                        mapping = {
                            dst.lower(): (src, cast_type)
                            for dst, src, cast_type in (DwdLoadTask.FACT_MAPPINGS.get(dwd_table) or [])
                        }
                        business_keys = [c for c in pk_dwd if c.lower() not in scd_cols]

                        def resolve_ods_expr(col: str) -> str | None:
                            # Map a DWD column to the matching ODS expression: explicit mapping,
                            # then same-name column, then an "id" fallback for *_id columns.
                            mapped = mapping.get(col)
                            if mapped:
                                src, cast_type = mapped
                                return _cast_expr(src, cast_type)
                            if col in ods_col_set:
                                d_type = dwd_types.get(col)
                                o_type = ods_types.get(col)
                                if d_type in numeric_types and o_type in text_types:
                                    return _cast_expr(col, d_type)
                                return f'"{col}"'
                            if "id" in ods_col_set and col.endswith("_id"):
                                d_type = dwd_types.get(col)
                                o_type = ods_types.get("id")
                                if d_type in numeric_types and o_type in text_types:
                                    return _cast_expr("id", d_type)
                                return '"id"'
                            return None

                        key_exprs: list[str] = []
                        join_keys: list[str] = []
                        for key in business_keys:
                            key_lower = key.lower()
                            expr = resolve_ods_expr(key_lower)
                            if expr is None:
                                key_exprs = []
                                join_keys = []
                                break
                            key_exprs.append(expr)
                            join_keys.append(key_lower)
                        compare_cols: list[str] = []
                        for col in dwd_col_set:
                            if col in ods_exclude or col in scd_cols:
                                continue
                            if col in {k.lower() for k in business_keys}:
                                continue
                            if dwd_types.get(col) in ("json", "jsonb"):
                                continue
                            if ods_types.get(col) in ("json", "jsonb"):
                                continue
                            if resolve_ods_expr(col) is None:
                                continue
                            compare_cols.append(col)
                        compare_cols = sorted(set(compare_cols))
                        if join_keys and compare_cols:
                            where_parts_dwd: list[str] = []
                            params_dwd: list[Any] = []
                            if dwd_current_only:
                                where_parts_dwd.append("COALESCE(scd2_is_current,1)=1")
                            if time_col and window:
                                where_parts_dwd.append(f"\"{time_col}\" >= %s AND \"{time_col}\" < %s")
                                params_dwd.extend([window.start, window.end])
                            where_dwd = f"WHERE {' AND '.join(where_parts_dwd)}" if where_parts_dwd else ""
                            where_parts_ods: list[str] = []
                            params_ods: list[Any] = []
                            if time_col and window:
                                where_parts_ods.append(f"\"{time_col}\" >= %s AND \"{time_col}\" < %s")
                                params_ods.extend([window.start, window.end])
                            where_ods = f"WHERE {' AND '.join(where_parts_ods)}" if where_parts_ods else ""
                            ods_select_exprs: list[str] = []
                            needed_cols = sorted(set(join_keys + compare_cols))
                            for col in needed_cols:
                                expr = resolve_ods_expr(col)
                                if expr is None:
                                    continue
                                ods_select_exprs.append(f"{expr} AS \"{col}\"")
                            if not ods_select_exprs:
                                mismatch_error = "join_keys_or_compare_cols_unavailable"
                            else:
                                ods_sql = _build_snapshot_expr_subquery(
                                    schema_ods,
                                    name_ods,
                                    ods_select_exprs,
                                    key_exprs,
                                    ods_snapshot_order,
                                    where_ods,
                                )
                                dwd_cols_sql = ", ".join([f"\"{c}\"" for c in needed_cols])
                                dwd_sql = f"SELECT {dwd_cols_sql} FROM \"{schema_dwd}\".\"{name_dwd}\" {where_dwd}"
                                join_cond = " AND ".join([f"d.\"{k}\" = o.\"{k}\"" for k in join_keys])
                                hash_o = _build_hash_expr("o", compare_cols)
                                hash_d = _build_hash_expr("d", compare_cols)
                                mismatch_sql = (
                                    f"WITH ods_latest AS ({ods_sql}), dwd_filtered AS ({dwd_sql}) "
                                    f"SELECT COUNT(1) FROM ("
                                    f"SELECT 1 FROM ods_latest o JOIN dwd_filtered d ON {join_cond} "
                                    f"WHERE {hash_o} <> {hash_d}"
                                    f") t"
                                )
                                params = params_ods + params_dwd
                                cur.execute(mismatch_sql, params)
                                row = cur.fetchone()
                                mismatch = int(row[0] if row and row[0] is not None else 0)
                                total_mismatch += mismatch
                                if content_sample_limit and mismatch > 0:
                                    select_keys_sql = ", ".join([f"d.\"{k}\" AS \"{k}\"" for k in join_keys])
                                    sample_sql = (
                                        f"WITH ods_latest AS ({ods_sql}), dwd_filtered AS ({dwd_sql}) "
                                        f"SELECT {select_keys_sql}, {hash_o} AS ods_hash, {hash_d} AS dwd_hash "
                                        f"FROM ods_latest o JOIN dwd_filtered d ON {join_cond} "
                                        f"WHERE {hash_o} <> {hash_d} LIMIT %s"
                                    )
                                    cur.execute(sample_sql, params + [int(content_sample_limit)])
                                    rows = cur.fetchall() or []
                                    if rows:
                                        columns = [desc[0] for desc in (cur.description or [])]
                                        mismatch_samples = [dict(zip(columns, r)) for r in rows]
                        else:
                            mismatch_error = "join_keys_or_compare_cols_unavailable"
                    results.append(
                        {
                            "dwd_table": dwd_table,
                            "ods_table": ods_table,
                            "windowed": bool(time_col and window),
                            "window_col": time_col,
                            "count": {"dwd": count_dwd, "ods": count_ods, "diff": count_dwd - count_ods},
                            "amounts": amounts,
                            "mismatch": mismatch,
                            "mismatch_samples": mismatch_samples,
                            "mismatch_error": mismatch_error,
                        }
                    )
                except Exception as exc:  # noqa: BLE001
                    results.append(
                        {
                            "dwd_table": dwd_table,
                            "ods_table": ods_table,
                            "windowed": bool(window),
                            "window_col": None,
                            "count": {"dwd": None, "ods": None, "diff": None},
                            "amounts": [],
                            "mismatch": None,
                            "mismatch_samples": [],
                            "error": f"{type(exc).__name__}: {exc}",
                        }
                    )
            total_count_diff = sum(
                int(item.get("count", {}).get("diff") or 0)
                for item in results
                if isinstance(item.get("count", {}).get("diff"), (int, float))
            )
            return {
                "tables": results,
                "total_count_diff": total_count_diff,
                "total_mismatch": total_mismatch,
            }
    finally:
        db_conn.close()
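

# Minimal usage sketch (assumes a loaded AppConfig; the window argument is optional):
#   summary = run_dwd_vs_ods_check(cfg=cfg, window=None, include_dimensions=False)
#   for item in summary["tables"]:
#       if item["count"]["diff"] or item.get("mismatch"):
#           print(item["dwd_table"], item["count"], item.get("mismatch"))
# With window=None each table is compared over its full history; passing an
# IntegrityWindow restricts counts, sums, and hash comparison to that time range.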


def _default_report_path(prefix: str) -> Path:
    root = Path(__file__).resolve().parents[1]
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return root / "reports" / f"{prefix}_{stamp}.json"


def run_integrity_window(
    *,
    cfg: AppConfig,
    window: IntegrityWindow,
    include_dimensions: bool,
    task_codes: str,
    logger,
    write_report: bool,
    compare_content: bool | None = None,
    content_sample_limit: int | None = None,
    report_path: Path | None = None,
    window_split_unit: str | None = None,
    window_compensation_hours: int | None = None,
) -> Dict[str, Any]:
    # Express the window as whole days when it spans at least a day, otherwise as whole
    # hours (minimum 1), for run_gap_check.
    total_seconds = max(0, int((window.end - window.start).total_seconds()))
    if total_seconds >= 86400:
        window_days = max(1, total_seconds // 86400)
        window_hours = 0
    else:
        window_days = 0
        window_hours = max(1, total_seconds // 3600 or 1)
    if compare_content is None:
        compare_content = bool(cfg.get("integrity.compare_content", True))
    if content_sample_limit is None:
        content_sample_limit = cfg.get("integrity.content_sample_limit")
    ods_payload = run_gap_check(
        cfg=cfg,
        start=window.start,
        end=window.end,
        window_days=window_days,
        window_hours=window_hours,
        page_size=int(cfg.get("api.page_size") or 200),
        chunk_size=500,
        sample_limit=50,
        sleep_per_window=0,
        sleep_per_page=0,
        task_codes=task_codes,
        from_cutoff=False,
        cutoff_overlap_hours=24,
        allow_small_window=True,
        logger=logger,
        compare_content=bool(compare_content),
        content_sample_limit=content_sample_limit,
        window_split_unit=window_split_unit,
        window_compensation_hours=window_compensation_hours,
    )
    dwd_payload = run_dwd_vs_ods_check(
        cfg=cfg,
        window=window,
        include_dimensions=include_dimensions,
        compare_content=compare_content,
        content_sample_limit=content_sample_limit,
    )
    report = {
        "mode": "window",
        "window": {
            "start": window.start.isoformat(),
            "end": window.end.isoformat(),
            "label": window.label,
            "granularity": window.granularity,
        },
        "api_to_ods": ods_payload,
        "ods_to_dwd": dwd_payload,
        "generated_at": datetime.now(ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))).isoformat(),
    }
    if write_report:
        path = report_path or _default_report_path("data_integrity_window")
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        report["report_path"] = str(path)
    return report


def run_integrity_history(
    *,
    cfg: AppConfig,
    start_dt: datetime,
    end_dt: datetime,
    include_dimensions: bool,
    task_codes: str,
    logger,
    write_report: bool,
    compare_content: bool | None = None,
    content_sample_limit: int | None = None,
    report_path: Path | None = None,
) -> Dict[str, Any]:
    tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
    windows = build_history_windows(start_dt, end_dt, tz)
    results: List[Dict[str, Any]] = []
    total_missing = 0
    total_mismatch = 0
    total_errors = 0
    for window in windows:
        logger.info("Checking window start=%s end=%s", window.start, window.end)
        payload = run_integrity_window(
            cfg=cfg,
            window=window,
            include_dimensions=include_dimensions,
            task_codes=task_codes,
            logger=logger,
            write_report=False,
            compare_content=compare_content,
            content_sample_limit=content_sample_limit,
        )
        results.append(payload)
        total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
        total_mismatch += int(payload.get("api_to_ods", {}).get("total_mismatch") or 0)
        total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
    report = {
        "mode": "history",
        "start": _ensure_tz(start_dt, tz).isoformat(),
        "end": _ensure_tz(end_dt, tz).isoformat(),
        "windows": results,
        "total_missing": total_missing,
        "total_mismatch": total_mismatch,
        "total_errors": total_errors,
        "generated_at": datetime.now(tz).isoformat(),
    }
    if write_report:
        path = report_path or _default_report_path("data_integrity_history")
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        report["report_path"] = str(path)
    return report


def compute_last_etl_end(cfg: AppConfig) -> datetime | None:
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    db_conn = DatabaseConnection(dsn=dsn, session=session)
    try:
        rows = db_conn.query(
            "SELECT MAX(window_end) AS mx FROM etl_admin.etl_run WHERE store_id = %s",
            (cfg.get("app.store_id"),),
        )
        mx = rows[0]["mx"] if rows else None
        if isinstance(mx, datetime):
            tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
            return _ensure_tz(mx, tz)
    finally:
        db_conn.close()
    return None
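

# End-to-end sketch (illustrative; assumes cfg, a logger, and the referenced tables exist,
# and that the last 90 days is the range you care about):
#   last_end = compute_last_etl_end(cfg)
#   if last_end is not None:
#       report = run_integrity_history(
#           cfg=cfg,
#           start_dt=last_end - timedelta(days=90),
#           end_dt=last_end,
#           include_dimensions=False,
#           task_codes="",  # placeholder; pass whatever run_gap_check expects here
#           logger=logger,
#           write_report=True,
#       )
#       # report["total_missing"] / report["total_mismatch"] summarize the API->ODS checks,
#       # and each window payload carries the ODS->DWD comparison under "ods_to_dwd".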