# ZQYY.FQ-ETL/scripts/check/check_ods_content_hash.py
# -*- coding: utf-8 -*-
"""
Validate that ODS payload content matches stored content_hash.
Usage:
PYTHONPATH=. python -m scripts.check.check_ods_content_hash
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --schema billiards_ods
PYTHONPATH=. python -m scripts.check.check_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence
from psycopg2.extras import RealDictCursor
# This module is run as ``scripts.check.check_ods_content_hash`` (see the module
# docstring), i.e. it lives two package levels below the repository root:
# parents[0] is ``scripts/check``, parents[1] is ``scripts`` and parents[2] is
# the directory the ``config`` / ``database`` / ``tasks`` imports below are
# resolved from. Using parents[1] pointed at ``scripts/`` instead of the root.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Make the project packages importable even when PYTHONPATH is not set.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _fetch_row_count(conn, schema: str, table: str) -> int:
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _iter_rows(
    conn,
    schema: str,
    table: str,
    select_cols: Sequence[str],
    batch_size: int,
) -> Iterable[dict]:
    """Yield the rows of ``schema.table`` one at a time as dict-like records.

    The query runs on a named cursor created with ``RealDictCursor`` so rows
    are fetched in batches of ``itersize`` rather than all at once.

    Args:
        conn: Open DB connection whose ``cursor()`` accepts ``name`` and
            ``cursor_factory``.
        schema: Schema name; interpolated as a quoted identifier.
        table: Table name; interpolated as a quoted identifier.
        select_cols: Column names to select, in output order.
        batch_size: Rows per fetch; falsy values fall back to 500 and the
            result is clamped to at least 1.
    """

    def _quote(identifier: str) -> str:
        # Double embedded double quotes so the name stays one quoted identifier.
        return '"' + identifier.replace('"', '""') + '"'

    cols_sql = ", ".join(_quote(c) for c in select_cols)
    sql = f"SELECT {cols_sql} FROM {_quote(schema)}.{_quote(table)}"
    with conn.cursor(name=f"ods_hash_{table}", cursor_factory=RealDictCursor) as cur:
        cur.itersize = max(1, int(batch_size or 500))
        cur.execute(sql)
        yield from cur
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_content_hash_check_{ts}.json"
def _print_progress(
table_label: str,
processed: int,
total: int,
mismatched: int,
missing_hash: int,
invalid_payload: int,
) -> None:
if total:
msg = (
f"[{table_label}] checked {processed}/{total} "
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
else:
msg = (
f"[{table_label}] checked {processed} "
f"mismatch={mismatched} missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
print(msg, flush=True)
def _check_table(
    conn,
    schema: str,
    table: str,
    *,
    batch_size: int,
    progress_every: int,
    sample_limit: int,
) -> dict[str, Any] | None:
    """Scan one table and return its report entry, or ``None`` if skipped.

    A table is skipped (with a printed notice) when it has no ``payload`` or
    no ``content_hash`` column. Otherwise every row's stored ``content_hash``
    is compared with the hash recomputed from ``payload``.

    Args:
        conn: Open DB connection.
        schema: Schema containing the table.
        table: Table name.
        batch_size: Fetch batch size forwarded to ``_iter_rows``.
        progress_every: Print progress every N rows; ``0`` disables the
            periodic lines (a single final line is still printed).
        sample_limit: Maximum number of mismatching rows kept as samples.
    """
    table_label = f"{schema}.{table}"
    cols_lower = {c.lower() for c in _fetch_columns(conn, schema, table)}
    if "payload" not in cols_lower or "content_hash" not in cols_lower:
        print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
        return None

    total = _fetch_row_count(conn, schema, table)
    pk_cols = _fetch_pk_columns(conn, schema, table)
    select_cols = ["content_hash", "payload", *pk_cols]

    processed = 0
    mismatched = 0
    missing_hash = 0
    invalid_payload = 0
    samples: list[dict[str, Any]] = []

    print(f"[{table_label}] start: total_rows={total}", flush=True)
    for row in _iter_rows(conn, schema, table, select_cols, batch_size):
        processed += 1
        content_hash = row.get("content_hash")
        # NOTE(review): a falsy result is treated as "payload could not be
        # hashed" -- confirm against BaseOdsTask.
        recomputed = BaseOdsTask._compute_compare_hash_from_payload(row.get("payload"))

        # Missing hash and unhashable payload both count as mismatches too.
        if not content_hash:
            missing_hash += 1
            row_mismatch = True
        elif not recomputed:
            invalid_payload += 1
            row_mismatch = True
        else:
            row_mismatch = content_hash != recomputed

        if row_mismatch:
            mismatched += 1
            if len(samples) < sample_limit:
                sample = {k: row.get(k) for k in pk_cols}
                sample["content_hash"] = content_hash
                sample["recomputed_hash"] = recomputed
                samples.append(sample)

        if progress_every and processed % progress_every == 0:
            _print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)

    # Final line, unless the last periodic line already reported this count.
    if processed and (not progress_every or processed % progress_every != 0):
        _print_progress(table_label, processed, total, mismatched, missing_hash, invalid_payload)

    return {
        "table": table_label,
        "total_rows": total,
        "checked_rows": processed,
        "mismatch_rows": mismatched,
        "missing_hash_rows": missing_hash,
        "invalid_payload_rows": invalid_payload,
        "sample_mismatches": samples,
    }


def main() -> int:
    """Check every selected ODS table and write a JSON report.

    Parses the command line, scans each table of the schema (optionally
    restricted by ``--tables``), prints progress, and writes the aggregated
    report to the path chosen by ``_build_report_path``.

    Returns:
        ``0`` as the process exit code; mismatches are reported, not signalled
        through the exit status.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Validate ODS payload vs content_hash consistency")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
    ap.add_argument("--sample-limit", type=int, default=5, help="sample mismatch rows per table")
    ap.add_argument("--out", default="", help="output report JSON path")
    args = ap.parse_args()

    # Normalised once here instead of on every row.
    sample_limit = max(0, int(args.sample_limit or 0))
    progress_every = int(args.progress_every or 0)

    report: dict[str, Any] = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": 0,
            "checked_tables": 0,
            "total_rows": 0,
            "checked_rows": 0,
            "mismatch_rows": 0,
            "missing_hash_rows": 0,
            "invalid_payload_rows": 0,
        },
    }
    summary = report["summary"]

    cfg = AppConfig.load({})
    db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    conn = db.conn
    try:
        tables = _fetch_tables(conn, args.schema)
        if args.tables.strip():
            whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
            # A typo in --tables used to be dropped silently, yielding an
            # empty but apparently successful report.
            unknown = sorted(whitelist.difference(tables))
            if unknown:
                print(
                    f"[WARN] tables not found in schema {args.schema}: {', '.join(unknown)}",
                    flush=True,
                )
            tables = [t for t in tables if t in whitelist]

        for table in tables:
            entry = _check_table(
                conn,
                args.schema,
                table,
                batch_size=args.batch_size,
                progress_every=progress_every,
                sample_limit=sample_limit,
            )
            if entry is None:
                continue
            report["tables"].append(entry)
            summary["checked_tables"] += 1
            for key in (
                "total_rows",
                "checked_rows",
                "mismatch_rows",
                "missing_hash_rows",
                "invalid_payload_rows",
            ):
                summary[key] += entry[key]
        summary["total_tables"] = len(tables)
    finally:
        # Release the connection on both the success and the failure path.
        conn.close()

    out_path = _build_report_path(args.out)
    # Sample rows carry raw primary-key values; default=str keeps values that
    # JSON cannot encode natively (e.g. Decimal, datetime, UUID) from aborting
    # the report after the whole scan has finished.
    out_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2, default=str),
        encoding="utf-8",
    )
    print(f"[REPORT] {out_path}", flush=True)
    return 0
# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())