# -*- coding: utf-8 -*-
"""
Repair ODS content_hash values by recomputing from payload.

Usage:
  PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash
  PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema billiards_ods
  PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence

import psycopg2
from psycopg2.extras import RealDictCursor

# Make the project root importable when this file is run as a plain script
# (the project-local imports below depend on this insertion happening first).
# NOTE(review): parents[1] is two levels above this file; for a module living
# at scripts/repair/... (per the usage examples) the project root would be
# parents[2] — confirm against the actual file location.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
|
|
if hasattr(sys.stdout, "reconfigure"):
|
|
try:
|
|
sys.stdout.reconfigure(encoding="utf-8")
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def _fetch_tables(conn, schema: str) -> list[str]:
|
|
sql = """
|
|
SELECT table_name
|
|
FROM information_schema.tables
|
|
WHERE table_schema = %s AND table_type = 'BASE TABLE'
|
|
ORDER BY table_name
|
|
"""
|
|
with conn.cursor() as cur:
|
|
cur.execute(sql, (schema,))
|
|
return [r[0] for r in cur.fetchall()]
|
|
|
|
|
|
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
|
|
sql = """
|
|
SELECT column_name
|
|
FROM information_schema.columns
|
|
WHERE table_schema = %s AND table_name = %s
|
|
ORDER BY ordinal_position
|
|
"""
|
|
with conn.cursor() as cur:
|
|
cur.execute(sql, (schema, table))
|
|
cols = [r[0] for r in cur.fetchall()]
|
|
return [c for c in cols if c]
|
|
|
|
|
|
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
|
|
sql = """
|
|
SELECT kcu.column_name
|
|
FROM information_schema.table_constraints tc
|
|
JOIN information_schema.key_column_usage kcu
|
|
ON tc.constraint_name = kcu.constraint_name
|
|
AND tc.table_schema = kcu.table_schema
|
|
WHERE tc.constraint_type = 'PRIMARY KEY'
|
|
AND tc.table_schema = %s
|
|
AND tc.table_name = %s
|
|
ORDER BY kcu.ordinal_position
|
|
"""
|
|
with conn.cursor() as cur:
|
|
cur.execute(sql, (schema, table))
|
|
cols = [r[0] for r in cur.fetchall()]
|
|
return [c for c in cols if c.lower() != "content_hash"]
|
|
|
|
|
|
def _fetch_row_count(conn, schema: str, table: str) -> int:
|
|
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
|
|
with conn.cursor() as cur:
|
|
cur.execute(sql)
|
|
row = cur.fetchone()
|
|
return int(row[0] if row else 0)
|
|
|
|
|
|
def _iter_rows(
    conn,
    schema: str,
    table: str,
    select_cols: Sequence[str],
    batch_size: int,
) -> Iterable[dict]:
    """Stream dict rows of *select_cols* via a named (server-side) cursor.

    ctid is selected bare; every other column is double-quoted. The named
    cursor keeps memory bounded by fetching *batch_size* rows at a time.
    """
    quoted = [col if col == "ctid" else f'"{col}"' for col in select_cols]
    query = f'SELECT {", ".join(quoted)} FROM "{schema}"."{table}"'
    with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cursor:
        cursor.itersize = max(1, int(batch_size or 500))
        cursor.execute(query)
        yield from cursor
def _build_report_path(out_arg: str | None) -> Path:
|
|
if out_arg:
|
|
return Path(out_arg)
|
|
reports_dir = PROJECT_ROOT / "reports"
|
|
reports_dir.mkdir(parents=True, exist_ok=True)
|
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
return reports_dir / f"ods_content_hash_repair_{ts}.json"
|
|
|
|
|
|
def _print_progress(
|
|
table_label: str,
|
|
processed: int,
|
|
total: int,
|
|
updated: int,
|
|
skipped: int,
|
|
conflicts: int,
|
|
errors: int,
|
|
missing_hash: int,
|
|
invalid_payload: int,
|
|
) -> None:
|
|
if total:
|
|
msg = (
|
|
f"[{table_label}] checked {processed}/{total} "
|
|
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
|
|
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
|
|
)
|
|
else:
|
|
msg = (
|
|
f"[{table_label}] checked {processed} "
|
|
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
|
|
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
|
|
)
|
|
print(msg, flush=True)
|
|
|
|
|
|
def main() -> int:
    """Recompute content_hash from payload for every qualifying ODS table.

    Scans each base table in the target schema that has both a ``payload``
    and a ``content_hash`` column, recomputes the hash via
    BaseOdsTask._compute_compare_hash_from_payload, and rewrites mismatching
    rows (addressed by ctid). Writes a JSON report and returns 0.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
    ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update")
    args = ap.parse_args()

    cfg = AppConfig.load({})
    # Two connections: db_read streams rows through a server-side cursor
    # (which pins a transaction), while db_write applies UPDATEs in
    # autocommit mode so each repair commits independently.
    db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    try:
        # Clear any transaction the connection helper may have opened;
        # autocommit cannot be toggled inside an active transaction.
        db_write.conn.rollback()
    except Exception:
        pass
    db_write.conn.autocommit = True

    tables = _fetch_tables(db_read.conn, args.schema)
    if args.tables.strip():
        # Optional whitelist: keep only the explicitly requested tables.
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]

    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_rows": 0,
            "checked_rows": 0,
            "updated_rows": 0,
            "skipped_rows": 0,
            "conflict_rows": 0,
            "error_rows": 0,
            "missing_hash_rows": 0,
            "invalid_payload_rows": 0,
        },
    }

    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(db_read.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        # Only tables carrying both columns can be repaired.
        if "payload" not in cols_lower or "content_hash" not in cols_lower:
            print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
            continue

        total = _fetch_row_count(db_read.conn, args.schema, table)
        pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table)
        # ctid identifies the physical row for the UPDATE; pk columns are
        # only fetched so conflict samples are attributable.
        select_cols = ["ctid", "content_hash", "payload", *pk_cols]

        # Per-table counters.
        processed = 0
        updated = 0
        skipped = 0
        conflicts = 0
        errors = 0
        missing_hash = 0
        invalid_payload = 0
        samples: list[dict[str, Any]] = []

        print(f"[{table_label}] start: total_rows={total}", flush=True)

        for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size):
            processed += 1
            content_hash = row.get("content_hash")
            payload = row.get("payload")
            recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
            row_ctid = row.get("ctid")

            # Diagnostic counters (do not affect the update decision).
            # NOTE(review): source formatting was ambiguous as to whether the
            # invalid_payload count nests under missing_hash; counted
            # independently here — confirm intended semantics.
            if not content_hash:
                missing_hash += 1
            if not recomputed:
                invalid_payload += 1

            if not recomputed:
                # No hash derivable from payload: nothing safe to write.
                skipped += 1
            elif content_hash == recomputed:
                # Stored hash already correct.
                skipped += 1
            else:
                if args.dry_run:
                    # Dry run counts what WOULD be updated.
                    updated += 1
                else:
                    try:
                        with db_write.conn.cursor() as cur:
                            cur.execute(
                                f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s',
                                (recomputed, row_ctid),
                            )
                        updated += 1
                    except psycopg2.errors.UniqueViolation:
                        # Another row already owns this hash; record a
                        # bounded sample of conflicts for the report.
                        conflicts += 1
                        if len(samples) < max(0, int(args.sample_limit or 0)):
                            sample = {k: row.get(k) for k in pk_cols}
                            sample["content_hash"] = content_hash
                            sample["recomputed_hash"] = recomputed
                            samples.append(sample)
                    except psycopg2.Error:
                        # Autocommit mode: a failed statement does not
                        # poison subsequent updates.
                        errors += 1

            if args.progress_every and processed % int(args.progress_every) == 0:
                _print_progress(
                    table_label,
                    processed,
                    total,
                    updated,
                    skipped,
                    conflicts,
                    errors,
                    missing_hash,
                    invalid_payload,
                )

        # Final flush, unless the last row already landed on a progress tick.
        if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
            _print_progress(
                table_label,
                processed,
                total,
                updated,
                skipped,
                conflicts,
                errors,
                missing_hash,
                invalid_payload,
            )

        report["tables"].append(
            {
                "table": table_label,
                "total_rows": total,
                "checked_rows": processed,
                "updated_rows": updated,
                "skipped_rows": skipped,
                "conflict_rows": conflicts,
                "error_rows": errors,
                "missing_hash_rows": missing_hash,
                "invalid_payload_rows": invalid_payload,
                "conflict_samples": samples,
            }
        )

        # Roll per-table counters into the run-wide summary.
        report["summary"]["checked_tables"] += 1
        report["summary"]["total_rows"] += total
        report["summary"]["checked_rows"] += processed
        report["summary"]["updated_rows"] += updated
        report["summary"]["skipped_rows"] += skipped
        report["summary"]["conflict_rows"] += conflicts
        report["summary"]["error_rows"] += errors
        report["summary"]["missing_hash_rows"] += missing_hash
        report["summary"]["invalid_payload_rows"] += invalid_payload

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())