# -*- coding: utf-8 -*- """Integrity checks across API -> ODS -> DWD.""" from __future__ import annotations from dataclasses import dataclass from datetime import date, datetime, time, timedelta from pathlib import Path from typing import Any, Dict, Iterable, List, Tuple from zoneinfo import ZoneInfo import json from config.settings import AppConfig from database.connection import DatabaseConnection from tasks.dwd_load_task import DwdLoadTask from scripts.check_ods_gaps import run_gap_check AMOUNT_KEYWORDS = ("amount", "money", "fee", "balance") @dataclass(frozen=True) class IntegrityWindow: start: datetime end: datetime label: str granularity: str def _ensure_tz(dt: datetime, tz: ZoneInfo) -> datetime: if dt.tzinfo is None: return dt.replace(tzinfo=tz) return dt.astimezone(tz) def _month_start(day: date) -> date: return date(day.year, day.month, 1) def _next_month(day: date) -> date: if day.month == 12: return date(day.year + 1, 1, 1) return date(day.year, day.month + 1, 1) def _date_to_start(dt: date, tz: ZoneInfo) -> datetime: return datetime.combine(dt, time.min).replace(tzinfo=tz) def _date_to_end_exclusive(dt: date, tz: ZoneInfo) -> datetime: return datetime.combine(dt, time.min).replace(tzinfo=tz) + timedelta(days=1) def build_history_windows(start_dt: datetime, end_dt: datetime, tz: ZoneInfo) -> List[IntegrityWindow]: """Build weekly windows for current month, monthly windows for earlier months.""" start_dt = _ensure_tz(start_dt, tz) end_dt = _ensure_tz(end_dt, tz) if end_dt <= start_dt: return [] start_date = start_dt.date() end_date = end_dt.date() current_month_start = _month_start(end_date) windows: List[IntegrityWindow] = [] cur = start_date while cur <= end_date: month_start = _month_start(cur) month_end_exclusive = _next_month(cur) range_start = max(cur, month_start) range_end = min(end_date, month_end_exclusive - timedelta(days=1)) if month_start == current_month_start: week_start = range_start while week_start <= range_end: week_end = min(week_start + timedelta(days=6), range_end) w_start_dt = _date_to_start(week_start, tz) w_end_dt = _date_to_end_exclusive(week_end, tz) if w_start_dt < end_dt and w_end_dt > start_dt: windows.append( IntegrityWindow( start=max(w_start_dt, start_dt), end=min(w_end_dt, end_dt), label=f"week_{week_start.isoformat()}", granularity="week", ) ) week_start = week_end + timedelta(days=1) else: m_start_dt = _date_to_start(range_start, tz) m_end_dt = _date_to_end_exclusive(range_end, tz) if m_start_dt < end_dt and m_end_dt > start_dt: windows.append( IntegrityWindow( start=max(m_start_dt, start_dt), end=min(m_end_dt, end_dt), label=f"month_{month_start.isoformat()}", granularity="month", ) ) cur = month_end_exclusive return windows def _split_table(name: str, default_schema: str) -> Tuple[str, str]: if "." in name: schema, table = name.split(".", 1) return schema, table return default_schema, name def _pick_time_column(dwd_cols: Iterable[str], ods_cols: Iterable[str]) -> str | None: lower_cols = {c.lower() for c in dwd_cols} & {c.lower() for c in ods_cols} for candidate in DwdLoadTask.FACT_ORDER_CANDIDATES: if candidate.lower() in lower_cols: return candidate.lower() return None def _fetch_columns(cur, schema: str, table: str) -> Tuple[List[str], Dict[str, str]]: cur.execute( """ SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = %s AND table_name = %s ORDER BY ordinal_position """, (schema, table), ) cols = [] types: Dict[str, str] = {} for name, data_type in cur.fetchall(): cols.append(name) types[name.lower()] = (data_type or "").lower() return cols, types def _amount_columns(cols: List[str], types: Dict[str, str]) -> List[str]: numeric_types = {"numeric", "double precision", "integer", "bigint", "smallint", "real", "decimal"} out = [] for col in cols: lc = col.lower() if types.get(lc) not in numeric_types: continue if any(key in lc for key in AMOUNT_KEYWORDS): out.append(lc) return out def _build_hash_expr(alias: str, cols: list[str]) -> str: if not cols: return "NULL" parts = ", ".join([f"COALESCE({alias}.\"{c}\"::text,'')" for c in cols]) return f"md5(concat_ws('||', {parts}))" def _build_snapshot_subquery( schema: str, table: str, cols: list[str], key_cols: list[str], order_col: str | None, where_sql: str, ) -> str: cols_sql = ", ".join([f'"{c}"' for c in cols]) if key_cols and order_col: keys = ", ".join([f'"{c}"' for c in key_cols]) order_by = ", ".join([*(f'"{c}"' for c in key_cols), f'"{order_col}" DESC NULLS LAST']) return ( f'SELECT DISTINCT ON ({keys}) {cols_sql} ' f'FROM "{schema}"."{table}" {where_sql} ' f"ORDER BY {order_by}" ) return f'SELECT {cols_sql} FROM "{schema}"."{table}" {where_sql}' def _build_snapshot_expr_subquery( schema: str, table: str, select_exprs: list[str], key_exprs: list[str], order_col: str | None, where_sql: str, ) -> str: select_cols_sql = ", ".join(select_exprs) table_sql = f'"{schema}"."{table}"' if key_exprs and order_col: distinct_on = ", ".join(key_exprs) order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST']) return ( f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} " f"FROM {table_sql} {where_sql} " f"ORDER BY {order_by}" ) return f"SELECT {select_cols_sql} FROM {table_sql} {where_sql}" def _cast_expr(col: str, cast_type: str | None) -> str: if col.upper() == "NULL": base = "NULL" else: is_expr = not col.isidentifier() or "->" in col or "#>>" in col or "::" in col or "'" in col base = col if is_expr else f'"{col}"' if cast_type: cast_lower = cast_type.lower() if cast_lower in {"bigint", "integer", "numeric", "decimal"}: return f"CAST(NULLIF(CAST({base} AS text), '') AS numeric):: {cast_type}" if cast_lower == "timestamptz": return f"({base})::timestamptz" return f"{base}::{cast_type}" return base def _fetch_pk_columns(cur, schema: str, table: str) -> List[str]: cur.execute( """ SELECT kcu.column_name FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = %s AND tc.table_name = %s ORDER BY kcu.ordinal_position """, (schema, table), ) return [r[0] for r in cur.fetchall()] def _pick_snapshot_order_column(cols: Iterable[str]) -> str | None: lower = {c.lower() for c in cols} for candidate in ("fetched_at", "update_time", "create_time"): if candidate in lower: return candidate return None def _count_table( cur, schema: str, table: str, time_col: str | None, window: IntegrityWindow | None, *, pk_cols: List[str] | None = None, snapshot_order_col: str | None = None, current_only: bool = False, ) -> int: where_parts: List[str] = [] params: List[Any] = [] if current_only: where_parts.append("COALESCE(scd2_is_current,1)=1") if time_col and window: where_parts.append(f'"{time_col}" >= %s AND "{time_col}" < %s') params.extend([window.start, window.end]) where = f"WHERE {' AND '.join(where_parts)}" if where_parts else "" if pk_cols and snapshot_order_col: keys = ", ".join(f'"{c}"' for c in pk_cols) order_by = ", ".join([*(f'"{c}"' for c in pk_cols), f'"{snapshot_order_col}" DESC NULLS LAST']) sql = ( f'SELECT COUNT(1) FROM (' f'SELECT DISTINCT ON ({keys}) 1 FROM "{schema}"."{table}" {where} ' f'ORDER BY {order_by}' f') t' ) else: sql = f'SELECT COUNT(1) FROM "{schema}"."{table}" {where}' cur.execute(sql, params) row = cur.fetchone() return int(row[0] if row else 0) def _sum_column( cur, schema: str, table: str, col: str, time_col: str | None, window: IntegrityWindow | None, *, pk_cols: List[str] | None = None, snapshot_order_col: str | None = None, current_only: bool = False, ) -> float: where_parts: List[str] = [] params: List[Any] = [] if current_only: where_parts.append("COALESCE(scd2_is_current,1)=1") if time_col and window: where_parts.append(f'"{time_col}" >= %s AND "{time_col}" < %s') params.extend([window.start, window.end]) where = f"WHERE {' AND '.join(where_parts)}" if where_parts else "" if pk_cols and snapshot_order_col: keys = ", ".join(f'"{c}"' for c in pk_cols) order_by = ", ".join([*(f'"{c}"' for c in pk_cols), f'"{snapshot_order_col}" DESC NULLS LAST']) sql = ( f'SELECT COALESCE(SUM("{col}"), 0) FROM (' f'SELECT DISTINCT ON ({keys}) "{col}" FROM "{schema}"."{table}" {where} ' f'ORDER BY {order_by}' f') t' ) else: sql = f'SELECT COALESCE(SUM("{col}"), 0) FROM "{schema}"."{table}" {where}' cur.execute(sql, params) row = cur.fetchone() return float(row[0] if row else 0) def run_dwd_vs_ods_check( *, cfg: AppConfig, window: IntegrityWindow | None, include_dimensions: bool, compare_content: bool | None = None, content_sample_limit: int | None = None, ) -> Dict[str, Any]: dsn = cfg["db"]["dsn"] session = cfg["db"].get("session") db_conn = DatabaseConnection(dsn=dsn, session=session) if compare_content is None: compare_content = bool(cfg.get("integrity.compare_content", True)) if content_sample_limit is None: content_sample_limit = cfg.get("integrity.content_sample_limit") or 50 try: with db_conn.conn.cursor() as cur: results: List[Dict[str, Any]] = [] table_map = DwdLoadTask.TABLE_MAP total_mismatch = 0 for dwd_table, ods_table in table_map.items(): if not include_dimensions and ".dim_" in dwd_table: continue schema_dwd, name_dwd = _split_table(dwd_table, "billiards_dwd") schema_ods, name_ods = _split_table(ods_table, "billiards_ods") try: dwd_cols, dwd_types = _fetch_columns(cur, schema_dwd, name_dwd) ods_cols, ods_types = _fetch_columns(cur, schema_ods, name_ods) time_col = _pick_time_column(dwd_cols, ods_cols) pk_dwd = _fetch_pk_columns(cur, schema_dwd, name_dwd) pk_ods_raw = _fetch_pk_columns(cur, schema_ods, name_ods) pk_ods = [c for c in pk_ods_raw if c.lower() != "content_hash"] ods_has_snapshot = any(c.lower() == "content_hash" for c in ods_cols) ods_snapshot_order = _pick_snapshot_order_column(ods_cols) if ods_has_snapshot else None dwd_current_only = any(c.lower() == "scd2_is_current" for c in dwd_cols) count_dwd = _count_table( cur, schema_dwd, name_dwd, time_col, window, current_only=dwd_current_only, ) count_ods = _count_table( cur, schema_ods, name_ods, time_col, window, pk_cols=pk_ods if ods_has_snapshot else None, snapshot_order_col=ods_snapshot_order if ods_has_snapshot else None, ) dwd_amount_cols = _amount_columns(dwd_cols, dwd_types) ods_amount_cols = _amount_columns(ods_cols, ods_types) common_amount_cols = sorted(set(dwd_amount_cols) & set(ods_amount_cols)) amounts: List[Dict[str, Any]] = [] for col in common_amount_cols: dwd_sum = _sum_column( cur, schema_dwd, name_dwd, col, time_col, window, current_only=dwd_current_only, ) ods_sum = _sum_column( cur, schema_ods, name_ods, col, time_col, window, pk_cols=pk_ods if ods_has_snapshot else None, snapshot_order_col=ods_snapshot_order if ods_has_snapshot else None, ) amounts.append( { "column": col, "dwd_sum": dwd_sum, "ods_sum": ods_sum, "diff": dwd_sum - ods_sum, } ) mismatch = None mismatch_samples: list[dict] = [] mismatch_error = None if compare_content: dwd_cols_lower = [c.lower() for c in dwd_cols] ods_cols_lower = [c.lower() for c in ods_cols] dwd_col_set = set(dwd_cols_lower) ods_col_set = set(ods_cols_lower) scd_cols = {c.lower() for c in DwdLoadTask.SCD_COLS} ods_exclude = { "payload", "source_file", "source_endpoint", "fetched_at", "content_hash", "record_index" } numeric_types = { "integer", "bigint", "smallint", "numeric", "double precision", "real", "decimal", } text_types = {"text", "character varying", "varchar"} mapping = { dst.lower(): (src, cast_type) for dst, src, cast_type in (DwdLoadTask.FACT_MAPPINGS.get(dwd_table) or []) } business_keys = [c for c in pk_dwd if c.lower() not in scd_cols] def resolve_ods_expr(col: str) -> str | None: mapped = mapping.get(col) if mapped: src, cast_type = mapped return _cast_expr(src, cast_type) if col in ods_col_set: d_type = dwd_types.get(col) o_type = ods_types.get(col) if d_type in numeric_types and o_type in text_types: return _cast_expr(col, d_type) return f'"{col}"' if "id" in ods_col_set and col.endswith("_id"): d_type = dwd_types.get(col) o_type = ods_types.get("id") if d_type in numeric_types and o_type in text_types: return _cast_expr("id", d_type) return '"id"' return None key_exprs: list[str] = [] join_keys: list[str] = [] for key in business_keys: key_lower = key.lower() expr = resolve_ods_expr(key_lower) if expr is None: key_exprs = [] join_keys = [] break key_exprs.append(expr) join_keys.append(key_lower) compare_cols: list[str] = [] for col in dwd_col_set: if col in ods_exclude or col in scd_cols: continue if col in {k.lower() for k in business_keys}: continue if dwd_types.get(col) in ("json", "jsonb"): continue if ods_types.get(col) in ("json", "jsonb"): continue if resolve_ods_expr(col) is None: continue compare_cols.append(col) compare_cols = sorted(set(compare_cols)) if join_keys and compare_cols: where_parts_dwd: list[str] = [] params_dwd: list[Any] = [] if dwd_current_only: where_parts_dwd.append("COALESCE(scd2_is_current,1)=1") if time_col and window: where_parts_dwd.append(f"\"{time_col}\" >= %s AND \"{time_col}\" < %s") params_dwd.extend([window.start, window.end]) where_dwd = f"WHERE {' AND '.join(where_parts_dwd)}" if where_parts_dwd else "" where_parts_ods: list[str] = [] params_ods: list[Any] = [] if time_col and window: where_parts_ods.append(f"\"{time_col}\" >= %s AND \"{time_col}\" < %s") params_ods.extend([window.start, window.end]) where_ods = f"WHERE {' AND '.join(where_parts_ods)}" if where_parts_ods else "" ods_select_exprs: list[str] = [] needed_cols = sorted(set(join_keys + compare_cols)) for col in needed_cols: expr = resolve_ods_expr(col) if expr is None: continue ods_select_exprs.append(f"{expr} AS \"{col}\"") if not ods_select_exprs: mismatch_error = "join_keys_or_compare_cols_unavailable" else: ods_sql = _build_snapshot_expr_subquery( schema_ods, name_ods, ods_select_exprs, key_exprs, ods_snapshot_order, where_ods, ) dwd_cols_sql = ", ".join([f"\"{c}\"" for c in needed_cols]) dwd_sql = f"SELECT {dwd_cols_sql} FROM \"{schema_dwd}\".\"{name_dwd}\" {where_dwd}" join_cond = " AND ".join([f"d.\"{k}\" = o.\"{k}\"" for k in join_keys]) hash_o = _build_hash_expr("o", compare_cols) hash_d = _build_hash_expr("d", compare_cols) mismatch_sql = ( f"WITH ods_latest AS ({ods_sql}), dwd_filtered AS ({dwd_sql}) " f"SELECT COUNT(1) FROM (" f"SELECT 1 FROM ods_latest o JOIN dwd_filtered d ON {join_cond} " f"WHERE {hash_o} <> {hash_d}" f") t" ) params = params_ods + params_dwd cur.execute(mismatch_sql, params) row = cur.fetchone() mismatch = int(row[0] if row and row[0] is not None else 0) total_mismatch += mismatch if content_sample_limit and mismatch > 0: select_keys_sql = ", ".join([f"d.\"{k}\" AS \"{k}\"" for k in join_keys]) sample_sql = ( f"WITH ods_latest AS ({ods_sql}), dwd_filtered AS ({dwd_sql}) " f"SELECT {select_keys_sql}, {hash_o} AS ods_hash, {hash_d} AS dwd_hash " f"FROM ods_latest o JOIN dwd_filtered d ON {join_cond} " f"WHERE {hash_o} <> {hash_d} LIMIT %s" ) cur.execute(sample_sql, params + [int(content_sample_limit)]) rows = cur.fetchall() or [] if rows: columns = [desc[0] for desc in (cur.description or [])] mismatch_samples = [dict(zip(columns, r)) for r in rows] else: mismatch_error = "join_keys_or_compare_cols_unavailable" results.append( { "dwd_table": dwd_table, "ods_table": ods_table, "windowed": bool(time_col and window), "window_col": time_col, "count": {"dwd": count_dwd, "ods": count_ods, "diff": count_dwd - count_ods}, "amounts": amounts, "mismatch": mismatch, "mismatch_samples": mismatch_samples, "mismatch_error": mismatch_error, } ) except Exception as exc: # noqa: BLE001 results.append( { "dwd_table": dwd_table, "ods_table": ods_table, "windowed": bool(window), "window_col": None, "count": {"dwd": None, "ods": None, "diff": None}, "amounts": [], "mismatch": None, "mismatch_samples": [], "error": f"{type(exc).__name__}: {exc}", } ) total_count_diff = sum( int(item.get("count", {}).get("diff") or 0) for item in results if isinstance(item.get("count", {}).get("diff"), (int, float)) ) return { "tables": results, "total_count_diff": total_count_diff, "total_mismatch": total_mismatch, } finally: db_conn.close() def _default_report_path(prefix: str) -> Path: root = Path(__file__).resolve().parents[1] stamp = datetime.now().strftime("%Y%m%d_%H%M%S") return root / "reports" / f"{prefix}_{stamp}.json" def run_integrity_window( *, cfg: AppConfig, window: IntegrityWindow, include_dimensions: bool, task_codes: str, logger, write_report: bool, compare_content: bool | None = None, content_sample_limit: int | None = None, report_path: Path | None = None, window_split_unit: str | None = None, window_compensation_hours: int | None = None, ) -> Dict[str, Any]: total_seconds = max(0, int((window.end - window.start).total_seconds())) if total_seconds >= 86400: window_days = max(1, total_seconds // 86400) window_hours = 0 else: window_days = 0 window_hours = max(1, total_seconds // 3600 or 1) if compare_content is None: compare_content = bool(cfg.get("integrity.compare_content", True)) if content_sample_limit is None: content_sample_limit = cfg.get("integrity.content_sample_limit") ods_payload = run_gap_check( cfg=cfg, start=window.start, end=window.end, window_days=window_days, window_hours=window_hours, page_size=int(cfg.get("api.page_size") or 200), chunk_size=500, sample_limit=50, sleep_per_window=0, sleep_per_page=0, task_codes=task_codes, from_cutoff=False, cutoff_overlap_hours=24, allow_small_window=True, logger=logger, compare_content=bool(compare_content), content_sample_limit=content_sample_limit, window_split_unit=window_split_unit, window_compensation_hours=window_compensation_hours, ) dwd_payload = run_dwd_vs_ods_check( cfg=cfg, window=window, include_dimensions=include_dimensions, compare_content=compare_content, content_sample_limit=content_sample_limit, ) report = { "mode": "window", "window": { "start": window.start.isoformat(), "end": window.end.isoformat(), "label": window.label, "granularity": window.granularity, }, "api_to_ods": ods_payload, "ods_to_dwd": dwd_payload, "generated_at": datetime.now(ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))).isoformat(), } if write_report: path = report_path or _default_report_path("data_integrity_window") path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") report["report_path"] = str(path) return report def run_integrity_history( *, cfg: AppConfig, start_dt: datetime, end_dt: datetime, include_dimensions: bool, task_codes: str, logger, write_report: bool, compare_content: bool | None = None, content_sample_limit: int | None = None, report_path: Path | None = None, ) -> Dict[str, Any]: tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei")) windows = build_history_windows(start_dt, end_dt, tz) results: List[Dict[str, Any]] = [] total_missing = 0 total_mismatch = 0 total_errors = 0 for window in windows: logger.info("校验窗口 起始=%s 结束=%s", window.start, window.end) payload = run_integrity_window( cfg=cfg, window=window, include_dimensions=include_dimensions, task_codes=task_codes, logger=logger, write_report=False, compare_content=compare_content, content_sample_limit=content_sample_limit, ) results.append(payload) total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0) total_mismatch += int(payload.get("api_to_ods", {}).get("total_mismatch") or 0) total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0) report = { "mode": "history", "start": _ensure_tz(start_dt, tz).isoformat(), "end": _ensure_tz(end_dt, tz).isoformat(), "windows": results, "total_missing": total_missing, "total_mismatch": total_mismatch, "total_errors": total_errors, "generated_at": datetime.now(tz).isoformat(), } if write_report: path = report_path or _default_report_path("data_integrity_history") path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") report["report_path"] = str(path) return report def compute_last_etl_end(cfg: AppConfig) -> datetime | None: dsn = cfg["db"]["dsn"] session = cfg["db"].get("session") db_conn = DatabaseConnection(dsn=dsn, session=session) try: rows = db_conn.query( "SELECT MAX(window_end) AS mx FROM etl_admin.etl_run WHERE store_id = %s", (cfg.get("app.store_id"),), ) mx = rows[0]["mx"] if rows else None if isinstance(mx, datetime): tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei")) return _ensure_tz(mx, tz) finally: db_conn.close() return None