# -*- coding: utf-8 -*- """ Repair ODS content_hash values by recomputing from payload. Usage: PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema billiards_ods PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders """ from __future__ import annotations import argparse import json import sys from datetime import datetime from pathlib import Path from typing import Any, Iterable, Sequence import psycopg2 from psycopg2.extras import RealDictCursor PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from config.settings import AppConfig from database.connection import DatabaseConnection from tasks.ods.ods_tasks import BaseOdsTask def _reconfigure_stdout_utf8() -> None: if hasattr(sys.stdout, "reconfigure"): try: sys.stdout.reconfigure(encoding="utf-8") except Exception: pass def _fetch_tables(conn, schema: str) -> list[str]: sql = """ SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_type = 'BASE TABLE' ORDER BY table_name """ with conn.cursor() as cur: cur.execute(sql, (schema,)) return [r[0] for r in cur.fetchall()] def _fetch_columns(conn, schema: str, table: str) -> list[str]: sql = """ SELECT column_name FROM information_schema.columns WHERE table_schema = %s AND table_name = %s ORDER BY ordinal_position """ with conn.cursor() as cur: cur.execute(sql, (schema, table)) cols = [r[0] for r in cur.fetchall()] return [c for c in cols if c] def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]: sql = """ SELECT kcu.column_name FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = %s AND tc.table_name = %s ORDER BY kcu.ordinal_position """ with conn.cursor() as cur: cur.execute(sql, (schema, table)) cols = [r[0] for r in cur.fetchall()] return [c for c in cols if c.lower() != "content_hash"] def _fetch_row_count(conn, schema: str, table: str) -> int: sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"' with conn.cursor() as cur: cur.execute(sql) row = cur.fetchone() return int(row[0] if row else 0) def _iter_rows( conn, schema: str, table: str, select_cols: Sequence[str], batch_size: int, ) -> Iterable[dict]: cols_sql = ", ".join("ctid" if c == "ctid" else f'"{c}"' for c in select_cols) sql = f'SELECT {cols_sql} FROM "{schema}"."{table}"' with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cur: cur.itersize = max(1, int(batch_size or 500)) cur.execute(sql) for row in cur: yield row def _build_report_path(out_arg: str | None) -> Path: if out_arg: return Path(out_arg) reports_dir = PROJECT_ROOT / "reports" reports_dir.mkdir(parents=True, exist_ok=True) ts = datetime.now().strftime("%Y%m%d_%H%M%S") return reports_dir / f"ods_content_hash_repair_{ts}.json" def _print_progress( table_label: str, processed: int, total: int, updated: int, skipped: int, conflicts: int, errors: int, missing_hash: int, invalid_payload: int, ) -> None: if total: msg = ( f"[{table_label}] checked {processed}/{total} " f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} " f"missing_hash={missing_hash} invalid_payload={invalid_payload}" ) else: msg = ( f"[{table_label}] checked {processed} " f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} " f"missing_hash={missing_hash} invalid_payload={invalid_payload}" ) print(msg, flush=True) def main() -> int: _reconfigure_stdout_utf8() ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload") ap.add_argument("--schema", default="billiards_ods", help="ODS schema name") ap.add_argument("--tables", default="", help="comma-separated table names (optional)") ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size") ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows") ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table") ap.add_argument("--out", default="", help="output report JSON path") ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update") args = ap.parse_args() cfg = AppConfig.load({}) db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session")) db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session")) try: db_write.conn.rollback() except Exception: pass db_write.conn.autocommit = True tables = _fetch_tables(db_read.conn, args.schema) if args.tables.strip(): whitelist = {t.strip() for t in args.tables.split(",") if t.strip()} tables = [t for t in tables if t in whitelist] report = { "schema": args.schema, "tables": [], "summary": { "total_tables": len(tables), "checked_tables": 0, "total_rows": 0, "checked_rows": 0, "updated_rows": 0, "skipped_rows": 0, "conflict_rows": 0, "error_rows": 0, "missing_hash_rows": 0, "invalid_payload_rows": 0, }, } for table in tables: table_label = f"{args.schema}.{table}" cols = _fetch_columns(db_read.conn, args.schema, table) cols_lower = {c.lower() for c in cols} if "payload" not in cols_lower or "content_hash" not in cols_lower: print(f"[{table_label}] skip: missing payload/content_hash", flush=True) continue total = _fetch_row_count(db_read.conn, args.schema, table) pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table) select_cols = ["ctid", "content_hash", "payload", *pk_cols] processed = 0 updated = 0 skipped = 0 conflicts = 0 errors = 0 missing_hash = 0 invalid_payload = 0 samples: list[dict[str, Any]] = [] print(f"[{table_label}] start: total_rows={total}", flush=True) for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size): processed += 1 content_hash = row.get("content_hash") payload = row.get("payload") recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload) row_ctid = row.get("ctid") if not content_hash: missing_hash += 1 if not recomputed: invalid_payload += 1 if not recomputed: skipped += 1 elif content_hash == recomputed: skipped += 1 else: if args.dry_run: updated += 1 else: try: with db_write.conn.cursor() as cur: cur.execute( f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s', (recomputed, row_ctid), ) updated += 1 except psycopg2.errors.UniqueViolation: conflicts += 1 if len(samples) < max(0, int(args.sample_limit or 0)): sample = {k: row.get(k) for k in pk_cols} sample["content_hash"] = content_hash sample["recomputed_hash"] = recomputed samples.append(sample) except psycopg2.Error: errors += 1 if args.progress_every and processed % int(args.progress_every) == 0: _print_progress( table_label, processed, total, updated, skipped, conflicts, errors, missing_hash, invalid_payload, ) if processed and (not args.progress_every or processed % int(args.progress_every) != 0): _print_progress( table_label, processed, total, updated, skipped, conflicts, errors, missing_hash, invalid_payload, ) report["tables"].append( { "table": table_label, "total_rows": total, "checked_rows": processed, "updated_rows": updated, "skipped_rows": skipped, "conflict_rows": conflicts, "error_rows": errors, "missing_hash_rows": missing_hash, "invalid_payload_rows": invalid_payload, "conflict_samples": samples, } ) report["summary"]["checked_tables"] += 1 report["summary"]["total_rows"] += total report["summary"]["checked_rows"] += processed report["summary"]["updated_rows"] += updated report["summary"]["skipped_rows"] += skipped report["summary"]["conflict_rows"] += conflicts report["summary"]["error_rows"] += errors report["summary"]["missing_hash_rows"] += missing_hash report["summary"]["invalid_payload_rows"] += invalid_payload out_path = _build_report_path(args.out) out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") print(f"[REPORT] {out_path}", flush=True) return 0 if __name__ == "__main__": raise SystemExit(main())