# -*- coding: utf-8 -*-
"""Integrity checks across API -> ODS -> DWD."""
from __future__ import annotations

from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple
from zoneinfo import ZoneInfo
import json

from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.dwd_load_task import DwdLoadTask
from scripts.check_ods_gaps import run_gap_check

# Name fragments that mark a numeric column as a monetary measure.
AMOUNT_KEYWORDS = ("amount", "money", "fee", "balance")


@dataclass(frozen=True)
class IntegrityWindow:
    start: datetime
    end: datetime
    label: str
    granularity: str


def _ensure_tz(dt: datetime, tz: ZoneInfo) -> datetime:
    """Attach *tz* to naive datetimes; convert aware ones into *tz*."""
    if dt.tzinfo is None:
        return dt.replace(tzinfo=tz)
    return dt.astimezone(tz)


def _month_start(day: date) -> date:
    return date(day.year, day.month, 1)


def _next_month(day: date) -> date:
    if day.month == 12:
        return date(day.year + 1, 1, 1)
    return date(day.year, day.month + 1, 1)


def _date_to_start(dt: date, tz: ZoneInfo) -> datetime:
    return datetime.combine(dt, time.min).replace(tzinfo=tz)


def _date_to_end_exclusive(dt: date, tz: ZoneInfo) -> datetime:
    return datetime.combine(dt, time.min).replace(tzinfo=tz) + timedelta(days=1)


def build_history_windows(start_dt: datetime, end_dt: datetime, tz: ZoneInfo) -> List[IntegrityWindow]:
    """Build weekly windows for the current month, monthly windows for earlier months."""
    start_dt = _ensure_tz(start_dt, tz)
    end_dt = _ensure_tz(end_dt, tz)
    if end_dt <= start_dt:
        return []
    start_date = start_dt.date()
    end_date = end_dt.date()
    current_month_start = _month_start(end_date)
    windows: List[IntegrityWindow] = []
    cur = start_date
    while cur <= end_date:
        month_start = _month_start(cur)
        month_end_exclusive = _next_month(cur)
        range_start = max(cur, month_start)
        range_end = min(end_date, month_end_exclusive - timedelta(days=1))
        if month_start == current_month_start:
            # Current month: split into week-sized chunks of at most 7 days.
            week_start = range_start
            while week_start <= range_end:
                week_end = min(week_start + timedelta(days=6), range_end)
                w_start_dt = _date_to_start(week_start, tz)
                w_end_dt = _date_to_end_exclusive(week_end, tz)
                if w_start_dt < end_dt and w_end_dt > start_dt:
                    windows.append(
                        IntegrityWindow(
                            start=max(w_start_dt, start_dt),
                            end=min(w_end_dt, end_dt),
                            label=f"week_{week_start.isoformat()}",
                            granularity="week",
                        )
                    )
                week_start = week_end + timedelta(days=1)
        else:
            # Earlier months: one window covering the whole (clamped) month.
            m_start_dt = _date_to_start(range_start, tz)
            m_end_dt = _date_to_end_exclusive(range_end, tz)
            if m_start_dt < end_dt and m_end_dt > start_dt:
                windows.append(
                    IntegrityWindow(
                        start=max(m_start_dt, start_dt),
                        end=min(m_end_dt, end_dt),
                        label=f"month_{month_start.isoformat()}",
                        granularity="month",
                    )
                )
        cur = month_end_exclusive
    return windows
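

# A minimal, self-contained sketch (debug aid, not part of the check pipeline)
# showing how build_history_windows splits a range: earlier months collapse to
# one "month_*" window each, while the last (current) month is cut into
# "week_*" windows. The dates below are illustrative only.
def _demo_build_history_windows() -> None:
    """Print the windows produced for a sample two-month range."""
    tz = ZoneInfo("Asia/Taipei")
    start = datetime(2024, 3, 10, tzinfo=tz)
    end = datetime(2024, 5, 2, tzinfo=tz)
    for w in build_history_windows(start, end, tz):
        # Expect: month_2024-03-01, month_2024-04-01, then weekly windows in May.
        print(f"{w.label:<18} {w.granularity:<6} {w.start.isoformat()} -> {w.end.isoformat()}")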


def _split_table(name: str, default_schema: str) -> Tuple[str, str]:
    """Split ``schema.table`` names; fall back to *default_schema* when unqualified."""
    if "." in name:
        schema, table = name.split(".", 1)
        return schema, table
    return default_schema, name


def _pick_time_column(dwd_cols: Iterable[str], ods_cols: Iterable[str]) -> str | None:
    """Pick the first candidate time column present in both tables (lowercased)."""
    lower_cols = {c.lower() for c in dwd_cols} & {c.lower() for c in ods_cols}
    for candidate in DwdLoadTask.FACT_ORDER_CANDIDATES:
        if candidate.lower() in lower_cols:
            return candidate.lower()
    return None


def _fetch_columns(cur, schema: str, table: str) -> Tuple[List[str], Dict[str, str]]:
    """Return column names (in ordinal order) and a lowercased name -> type map."""
    cur.execute(
        """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (schema, table),
    )
    cols: List[str] = []
    types: Dict[str, str] = {}
    for name, data_type in cur.fetchall():
        cols.append(name)
        types[name.lower()] = (data_type or "").lower()
    return cols, types


def _amount_columns(cols: List[str], types: Dict[str, str]) -> List[str]:
    """Return numeric columns whose names look like monetary measures."""
    numeric_types = {"numeric", "double precision", "integer", "bigint", "smallint", "real", "decimal"}
    out: List[str] = []
    for col in cols:
        lc = col.lower()
        if types.get(lc) not in numeric_types:
            continue
        if any(key in lc for key in AMOUNT_KEYWORDS):
            out.append(lc)
    return out


def _count_table(cur, schema: str, table: str, time_col: str | None, window: IntegrityWindow | None) -> int:
    where = ""
    params: List[Any] = []
    if time_col and window:
        where = f'WHERE "{time_col}" >= %s AND "{time_col}" < %s'
        params = [window.start, window.end]
    sql = f'SELECT COUNT(1) FROM "{schema}"."{table}" {where}'
    cur.execute(sql, params)
    row = cur.fetchone()
    return int(row[0] if row else 0)


def _sum_column(cur, schema: str, table: str, col: str, time_col: str | None, window: IntegrityWindow | None) -> float:
    where = ""
    params: List[Any] = []
    if time_col and window:
        where = f'WHERE "{time_col}" >= %s AND "{time_col}" < %s'
        params = [window.start, window.end]
    sql = f'SELECT COALESCE(SUM("{col}"), 0) FROM "{schema}"."{table}" {where}'
    cur.execute(sql, params)
    row = cur.fetchone()
    return float(row[0] if row else 0)
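

# A minimal sketch of how the helpers above combine for a single table pair.
# "fact_orders" / "ods_orders" and the resulting time column are illustrative
# assumptions, not entries read from DwdLoadTask.TABLE_MAP.
def _demo_single_table_diff(cur, window: IntegrityWindow) -> Dict[str, Any]:
    """Count rows on both sides of one hypothetical DWD/ODS mapping within *window*."""
    dwd_cols, _ = _fetch_columns(cur, "billiards_dwd", "fact_orders")
    ods_cols, _ = _fetch_columns(cur, "billiards_ods", "ods_orders")
    time_col = _pick_time_column(dwd_cols, ods_cols)  # None -> counts are unwindowed
    dwd = _count_table(cur, "billiards_dwd", "fact_orders", time_col, window)
    ods = _count_table(cur, "billiards_ods", "ods_orders", time_col, window)
    return {"dwd": dwd, "ods": ods, "diff": dwd - ods}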


def run_dwd_vs_ods_check(
    *,
    cfg: AppConfig,
    window: IntegrityWindow | None,
    include_dimensions: bool,
) -> Dict[str, Any]:
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    db_conn = DatabaseConnection(dsn=dsn, session=session)
    try:
        with db_conn.conn.cursor() as cur:
            results: List[Dict[str, Any]] = []
            table_map = DwdLoadTask.TABLE_MAP
            for dwd_table, ods_table in table_map.items():
                if not include_dimensions and ".dim_" in dwd_table:
                    continue
                schema_dwd, name_dwd = _split_table(dwd_table, "billiards_dwd")
                schema_ods, name_ods = _split_table(ods_table, "billiards_ods")
                try:
                    dwd_cols, dwd_types = _fetch_columns(cur, schema_dwd, name_dwd)
                    ods_cols, ods_types = _fetch_columns(cur, schema_ods, name_ods)
                    time_col = _pick_time_column(dwd_cols, ods_cols)
                    count_dwd = _count_table(cur, schema_dwd, name_dwd, time_col, window)
                    count_ods = _count_table(cur, schema_ods, name_ods, time_col, window)
                    dwd_amount_cols = _amount_columns(dwd_cols, dwd_types)
                    ods_amount_cols = _amount_columns(ods_cols, ods_types)
                    common_amount_cols = sorted(set(dwd_amount_cols) & set(ods_amount_cols))
                    amounts: List[Dict[str, Any]] = []
                    for col in common_amount_cols:
                        dwd_sum = _sum_column(cur, schema_dwd, name_dwd, col, time_col, window)
                        ods_sum = _sum_column(cur, schema_ods, name_ods, col, time_col, window)
                        amounts.append(
                            {
                                "column": col,
                                "dwd_sum": dwd_sum,
                                "ods_sum": ods_sum,
                                "diff": dwd_sum - ods_sum,
                            }
                        )
                    results.append(
                        {
                            "dwd_table": dwd_table,
                            "ods_table": ods_table,
                            "windowed": bool(time_col and window),
                            "window_col": time_col,
                            "count": {"dwd": count_dwd, "ods": count_ods, "diff": count_dwd - count_ods},
                            "amounts": amounts,
                        }
                    )
                except Exception as exc:  # noqa: BLE001
                    results.append(
                        {
                            "dwd_table": dwd_table,
                            "ods_table": ods_table,
                            "windowed": bool(window),
                            "window_col": None,
                            "count": {"dwd": None, "ods": None, "diff": None},
                            "amounts": [],
                            "error": f"{type(exc).__name__}: {exc}",
                        }
                    )
            total_count_diff = sum(
                int(item.get("count", {}).get("diff") or 0)
                for item in results
                if isinstance(item.get("count", {}).get("diff"), (int, float))
            )
            return {
                "tables": results,
                "total_count_diff": total_count_diff,
            }
    finally:
        db_conn.close()


def _default_report_path(prefix: str) -> Path:
    root = Path(__file__).resolve().parents[1]
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return root / "reports" / f"{prefix}_{stamp}.json"


def run_integrity_window(
    *,
    cfg: AppConfig,
    window: IntegrityWindow,
    include_dimensions: bool,
    task_codes: str,
    logger,
    write_report: bool,
    report_path: Path | None = None,
    window_split_unit: str | None = None,
    window_compensation_hours: int | None = None,
) -> Dict[str, Any]:
    total_seconds = max(0, int((window.end - window.start).total_seconds()))
    if total_seconds >= 86400:
        window_days = max(1, total_seconds // 86400)
        window_hours = 0
    else:
        window_days = 0
        window_hours = max(1, total_seconds // 3600)
    ods_payload = run_gap_check(
        cfg=cfg,
        start=window.start,
        end=window.end,
        window_days=window_days,
        window_hours=window_hours,
        page_size=int(cfg.get("api.page_size") or 200),
        chunk_size=500,
        sample_limit=50,
        sleep_per_window=0,
        sleep_per_page=0,
        task_codes=task_codes,
        from_cutoff=False,
        cutoff_overlap_hours=24,
        allow_small_window=True,
        logger=logger,
        window_split_unit=window_split_unit,
        window_compensation_hours=window_compensation_hours,
    )
    dwd_payload = run_dwd_vs_ods_check(
        cfg=cfg,
        window=window,
        include_dimensions=include_dimensions,
    )
    report = {
        "mode": "window",
        "window": {
            "start": window.start.isoformat(),
            "end": window.end.isoformat(),
            "label": window.label,
            "granularity": window.granularity,
        },
        "api_to_ods": ods_payload,
        "ods_to_dwd": dwd_payload,
        "generated_at": datetime.now(ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))).isoformat(),
    }
    if write_report:
        path = report_path or _default_report_path("data_integrity_window")
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        report["report_path"] = str(path)
    return report
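

# Convenience sketch for eyeballing a window report as produced above; it only
# reads keys that run_integrity_window and run_dwd_vs_ods_check actually emit.
def _summarize_window_report(report: Dict[str, Any]) -> str:
    """Render ``"<label> count_diff=N missing=M"`` for one window report."""
    label = report.get("window", {}).get("label", "?")
    count_diff = report.get("ods_to_dwd", {}).get("total_count_diff")
    missing = report.get("api_to_ods", {}).get("total_missing")
    return f"{label} count_diff={count_diff} missing={missing}"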


def run_integrity_history(
    *,
    cfg: AppConfig,
    start_dt: datetime,
    end_dt: datetime,
    include_dimensions: bool,
    task_codes: str,
    logger,
    write_report: bool,
    report_path: Path | None = None,
) -> Dict[str, Any]:
    tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
    windows = build_history_windows(start_dt, end_dt, tz)
    results: List[Dict[str, Any]] = []
    total_missing = 0
    total_errors = 0
    for window in windows:
        logger.info("Checking window start=%s end=%s", window.start, window.end)
        payload = run_integrity_window(
            cfg=cfg,
            window=window,
            include_dimensions=include_dimensions,
            task_codes=task_codes,
            logger=logger,
            write_report=False,
        )
        results.append(payload)
        total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
        total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
    report = {
        "mode": "history",
        "start": _ensure_tz(start_dt, tz).isoformat(),
        "end": _ensure_tz(end_dt, tz).isoformat(),
        "windows": results,
        "total_missing": total_missing,
        "total_errors": total_errors,
        "generated_at": datetime.now(tz).isoformat(),
    }
    if write_report:
        path = report_path or _default_report_path("data_integrity_history")
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        report["report_path"] = str(path)
    return report


def compute_last_etl_end(cfg: AppConfig) -> datetime | None:
    dsn = cfg["db"]["dsn"]
    session = cfg["db"].get("session")
    db_conn = DatabaseConnection(dsn=dsn, session=session)
    try:
        rows = db_conn.query(
            "SELECT MAX(window_end) AS mx FROM etl_admin.etl_run WHERE store_id = %s",
            (cfg.get("app.store_id"),),
        )
        mx = rows[0]["mx"] if rows else None
        if isinstance(mx, datetime):
            tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
            return _ensure_tz(mx, tz)
    finally:
        db_conn.close()
    return None
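

if __name__ == "__main__":
    # Minimal manual smoke test (a sketch, not a supported CLI). Assumptions:
    # AppConfig() can be constructed with defaults, the database behind
    # cfg["db"]["dsn"] is reachable, and an empty task_codes selects your
    # deployment's default task set -- adjust these for your environment.
    import logging

    logging.basicConfig(level=logging.INFO)
    _logger = logging.getLogger("data_integrity")
    _cfg = AppConfig()
    # Anchor the range to the last recorded ETL window, falling back to "now".
    _end = compute_last_etl_end(_cfg) or datetime.now(ZoneInfo(_cfg.get("app.timezone", "Asia/Taipei")))
    _start = _end - timedelta(days=30)
    run_integrity_history(
        cfg=_cfg,
        start_dt=_start,
        end_dt=_end,
        include_dimensions=False,
        task_codes="",
        logger=_logger,
        write_report=True,
    )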