# -*- coding: utf-8 -*-
"""
Deduplicate ODS snapshots by (business PK, content_hash).

Keep the latest row by fetched_at (tie-breaker: ctid desc).

Usage:
    PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots
    PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --schema billiards_ods
    PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --tables member_profiles,orders
"""

from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Iterable, Sequence

import psycopg2

# NOTE(review): if this file lives at scripts/repair/ (as the usage above
# suggests), parents[1] resolves to scripts/, not the project root, so
# reports land in scripts/reports. parents[2] may be intended -- confirm.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.settings import AppConfig
from database.connection import DatabaseConnection

# Column names the generated SQL relies on. They are emitted as exact
# (lowercase) identifiers, so existence checks must compare exactly too.
CONTENT_HASH_COL = "content_hash"
FETCHED_AT_COL = "fetched_at"


def _reconfigure_stdout_utf8() -> None:
    """Switch stdout to UTF-8 where supported; failures are deliberately ignored."""
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            # Best effort only: keep whatever encoding stdout already has.
            pass


def _quote_ident(name: str) -> str:
    """Quote *name* as a SQL identifier, doubling any embedded double quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def _fetch_tables(conn, schema: str) -> list[str]:
    """Return the names of all base tables in *schema*, sorted by name."""
    sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
          AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema,))
        return [r[0] for r in cur.fetchall()]


def _fetch_columns(conn, schema: str, table: str) -> list[str]:
    """Return the column names of *schema*.*table* in ordinal order."""
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s
          AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        return [r[0] for r in cur.fetchall()]


def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
    """Return the primary-key columns of *schema*.*table*, excluding content_hash.

    content_hash is dropped because the dedupe key always appends it
    explicitly; an empty list means the table has no usable primary key.
    """
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        cols = [r[0] for r in cur.fetchall()]
    # Exact match: the dedupe SQL references the lowercase column by name.
    return [c for c in cols if c != CONTENT_HASH_COL]


def _build_report_path(out_arg: str | None) -> Path:
    """Return the JSON report path, creating its parent directory if needed.

    Uses *out_arg* when given, otherwise a timestamped file under
    PROJECT_ROOT/reports.
    """
    if out_arg:
        path = Path(out_arg)
    else:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        path = PROJECT_ROOT / "reports" / f"ods_snapshot_dedupe_{ts}.json"
    # Also covers a user-supplied --out whose directory does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    return path


def _print_progress(
    table_label: str,
    deleted: int,
    total: int,
    errors: int,
) -> None:
    """Print one progress line; *total* is shown only when non-zero."""
    if total:
        msg = f"[{table_label}] deleted {deleted}/{total} errors={errors}"
    else:
        msg = f"[{table_label}] deleted {deleted} errors={errors}"
    print(msg, flush=True)


def _ranked_rows_sql(schema: str, table: str, key_cols: Sequence[str]) -> str:
    """Build a SELECT numbering rows within each (key_cols, content_hash) group.

    rn = 1 is the row that is kept: latest fetched_at (NULLs last), then the
    highest ctid as tie-breaker. Count and delete share this so they agree.
    """
    keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, CONTENT_HASH_COL])
    table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
    return f"""
        SELECT ctid, ROW_NUMBER() OVER (
            PARTITION BY {keys_sql}
            ORDER BY {_quote_ident(FETCHED_AT_COL)} DESC NULLS LAST, ctid DESC
        ) AS rn
        FROM {table_sql}
    """


def _count_duplicates(conn, schema: str, table: str, key_cols: Sequence[str]) -> int:
    """Return how many rows would be deleted (every row with rn > 1)."""
    ranked_sql = _ranked_rows_sql(schema, table, key_cols)
    sql = f"SELECT COUNT(*) FROM ({ranked_sql}) s WHERE s.rn > 1"
    with conn.cursor() as cur:
        cur.execute(sql)
        row = cur.fetchone()
    return int(row[0] if row else 0)


def _delete_duplicate_batch(
    conn,
    schema: str,
    table: str,
    key_cols: Sequence[str],
    batch_size: int,
) -> int:
    """Delete at most *batch_size* duplicate rows and return how many were deleted."""
    ranked_sql = _ranked_rows_sql(schema, table, key_cols)
    table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
    sql = f"""
        WITH dupes AS (
            SELECT ctid
            FROM ({ranked_sql}) s
            WHERE s.rn > 1
            LIMIT %s
        )
        DELETE FROM {table_sql} t
        USING dupes d
        WHERE t.ctid = d.ctid
    """
    with conn.cursor() as cur:
        cur.execute(sql, (int(batch_size),))
        # rowcount is the number of deleted rows; -1 would mean "unknown".
        return max(cur.rowcount, 0)


def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse and validate command-line arguments."""
    ap = argparse.ArgumentParser(description="Deduplicate ODS snapshot rows by PK+content_hash")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=1000, help="delete batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N deletions")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute duplicate counts")
    args = ap.parse_args(argv)
    if args.batch_size < 1:
        # LIMIT 0 would silently delete nothing; a negative LIMIT is a SQL error.
        ap.error("--batch-size must be >= 1")
    return args


def _delete_duplicates(
    conn,
    schema: str,
    table: str,
    key_cols: Sequence[str],
    *,
    batch_size: int,
    progress_every: int,
    total: int,
) -> tuple[int, int]:
    """Delete duplicates in batches until none remain or an error occurs.

    Progress is printed each time the cumulative count crosses a multiple of
    *progress_every* (<= 0 disables periodic output), plus once at the end.
    Returns ``(deleted, errors)``.
    """
    table_label = f"{schema}.{table}"
    deleted = 0
    errors = 0
    last_printed = 0
    next_report = progress_every if progress_every > 0 else None

    while True:
        try:
            batch_deleted = _delete_duplicate_batch(conn, schema, table, key_cols, batch_size)
        except psycopg2.Error as exc:
            errors += 1
            print(f"[{table_label}] delete failed: {str(exc).strip()}", flush=True)
            break
        if batch_deleted <= 0:
            break
        deleted += batch_deleted
        # Threshold-based so milestones are not skipped when batch sizes
        # do not line up with progress_every.
        if next_report is not None and deleted >= next_report:
            _print_progress(table_label, deleted, total, errors)
            last_printed = deleted
            next_report = (deleted // progress_every + 1) * progress_every

    if deleted and deleted != last_printed:
        _print_progress(table_label, deleted, total, errors)
    return deleted, errors


def _process_table(conn, schema: str, table: str, args: argparse.Namespace) -> dict | None:
    """Count (and unless --dry-run, delete) duplicates in one table.

    Returns the per-table report entry, or None when the table is skipped.
    """
    table_label = f"{schema}.{table}"
    cols = set(_fetch_columns(conn, schema, table))
    if CONTENT_HASH_COL not in cols or FETCHED_AT_COL not in cols:
        print(f"[{table_label}] skip: missing content_hash/fetched_at", flush=True)
        return None

    key_cols = _fetch_pk_columns(conn, schema, table)
    if not key_cols:
        print(f"[{table_label}] skip: missing primary key", flush=True)
        return None

    total_dupes = _count_duplicates(conn, schema, table, key_cols)
    print(f"[{table_label}] duplicates={total_dupes}", flush=True)

    deleted = 0
    errors = 0
    if not args.dry_run and total_dupes:
        deleted, errors = _delete_duplicates(
            conn,
            schema,
            table,
            key_cols,
            batch_size=args.batch_size,
            progress_every=args.progress_every,
            total=total_dupes,
        )

    return {
        "table": table_label,
        "duplicate_rows": total_dupes,
        "deleted_rows": deleted,
        "error_rows": errors,
    }


def main() -> int:
    """Entry point: dedupe every eligible table in the schema and write a JSON report."""
    _reconfigure_stdout_utf8()
    args = _parse_args()

    cfg = AppConfig.load({})
    db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    try:
        # Best effort: clear any transaction left open before switching modes.
        db.conn.rollback()
    except Exception:
        pass
    # Autocommit: each batch DELETE commits on its own, so a failure only
    # loses the failing batch.
    db.conn.autocommit = True

    tables = _fetch_tables(db.conn, args.schema)
    if args.tables.strip():
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        missing = sorted(whitelist.difference(tables))
        if missing:
            print(f"[WARN] tables not found in {args.schema}: {', '.join(missing)}", flush=True)
        tables = [t for t in tables if t in whitelist]

    summary = {
        "total_tables": len(tables),
        "checked_tables": 0,
        "total_duplicates": 0,
        "deleted_rows": 0,
        "error_rows": 0,
        "skipped_tables": 0,
    }
    report = {"schema": args.schema, "tables": [], "summary": summary}

    for table in tables:
        entry = _process_table(db.conn, args.schema, table, args)
        if entry is None:
            summary["skipped_tables"] += 1
            continue
        report["tables"].append(entry)
        summary["checked_tables"] += 1
        summary["total_duplicates"] += entry["duplicate_rows"]
        summary["deleted_rows"] += entry["deleted_rows"]
        summary["error_rows"] += entry["error_rows"]

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())