# -*- coding: utf-8 -*- """ 迁移到“快照型 ODS + DWD SCD2”: 1) 为所有 ODS 表补充 content_hash,并以 (业务主键, content_hash) 作为新主键; 2) 基于 payload 计算 content_hash,避免重复快照; 3) 为所有 DWD 维度表补齐 SCD2 字段,并调整主键为 (业务主键, scd2_start_time)。 用法: PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --dsn "postgresql://..." 可选参数: --only-ods / --only-dwd --dry-run --batch-size 500 """ from __future__ import annotations import argparse import hashlib import json import os from typing import Any, Iterable, List, Sequence import psycopg2 from psycopg2.extras import execute_batch, RealDictCursor def _hash_default(value): return value.isoformat() if hasattr(value, "isoformat") else str(value) def _sanitize_record_for_hash(record: Any) -> Any: exclude = { "data", "payload", "source_file", "source_endpoint", "fetched_at", "content_hash", "record_index", } def _strip(value): if isinstance(value, dict): cleaned = {} for k, v in value.items(): if isinstance(k, str) and k.lower() in exclude: continue cleaned[k] = _strip(v) return cleaned if isinstance(value, list): return [_strip(v) for v in value] return value return _strip(record or {}) def _compute_content_hash(record: Any) -> str: cleaned = _sanitize_record_for_hash(record) payload = json.dumps( cleaned, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=_hash_default, ) return hashlib.sha256(payload.encode("utf-8")).hexdigest() def _fetch_tables(cur, schema: str) -> List[str]: cur.execute( """ SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_type = 'BASE TABLE' ORDER BY table_name """, (schema,), ) return [r[0] for r in cur.fetchall()] def _fetch_columns(cur, schema: str, table: str) -> List[str]: cur.execute( """ SELECT column_name FROM information_schema.columns WHERE table_schema = %s AND table_name = %s ORDER BY ordinal_position """, (schema, table), ) cols = [] for row in cur.fetchall(): if isinstance(row, dict): cols.append(row.get("column_name")) else: cols.append(row[0]) return [c for c in cols if c] def _fetch_pk_constraint(cur, schema: str, table: str) -> tuple[str | None, list[str]]: cur.execute( """ SELECT tc.constraint_name, kcu.column_name, kcu.ordinal_position FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = %s AND tc.table_name = %s ORDER BY kcu.ordinal_position """, (schema, table), ) rows = cur.fetchall() if not rows: return None, [] if isinstance(rows[0], dict): name = rows[0].get("constraint_name") cols = [r.get("column_name") for r in rows] else: name = rows[0][0] cols = [r[1] for r in rows] return name, [c for c in cols if c] def _ensure_content_hash_column(cur, schema: str, table: str, dry_run: bool) -> None: cols = _fetch_columns(cur, schema, table) if any(c.lower() == "content_hash" for c in cols): return sql = f'ALTER TABLE "{schema}"."{table}" ADD COLUMN content_hash TEXT' if dry_run: print(f"[DRY] {sql}") return print(f"[ODS] 添加 content_hash: {schema}.{table}") cur.execute(sql) def _backfill_content_hash(conn, schema: str, table: str, batch_size: int, dry_run: bool) -> int: updated = 0 with conn.cursor(cursor_factory=RealDictCursor) as cur: cols = _fetch_columns(cur, schema, table) if "content_hash" not in [c.lower() for c in cols]: return 0 pk_name, pk_cols = _fetch_pk_constraint(cur, schema, table) if not pk_cols: return 0 # 过滤 content_hash pk_cols = [c for c in pk_cols if c.lower() != "content_hash"] select_cols = [*pk_cols] if any(c.lower() == "payload" for c in cols): select_cols.append("payload") else: select_cols.extend([c for c in cols if c.lower() not in {"content_hash"}]) select_cols_sql = ", ".join(f'"{c}"' for c in select_cols) sql = f'SELECT {select_cols_sql} FROM "{schema}"."{table}" WHERE content_hash IS NULL' cur.execute(sql) rows = cur.fetchall() if not rows: return 0 def build_row(row: dict) -> tuple: payload = row.get("payload") if payload is None: payload = {k: v for k, v in row.items() if k.lower() not in {"content_hash", "payload"}} content_hash = _compute_content_hash(payload) key_vals = [row.get(k) for k in pk_cols] return (content_hash, *key_vals) updates = [build_row(r) for r in rows] if dry_run: print(f"[DRY] {schema}.{table}: 预计更新 {len(updates)} 行 content_hash") return len(updates) where_clause = " AND ".join([f'"{c}" = %s' for c in pk_cols]) update_sql = ( f'UPDATE "{schema}"."{table}" SET content_hash = %s ' f'WHERE {where_clause} AND content_hash IS NULL' ) with conn.cursor() as cur2: execute_batch(cur2, update_sql, updates, page_size=batch_size) updated = cur2.rowcount or len(updates) print(f"[ODS] {schema}.{table}: 更新 content_hash {updated} 行") return updated def _ensure_ods_primary_key(cur, schema: str, table: str, dry_run: bool) -> None: name, pk_cols = _fetch_pk_constraint(cur, schema, table) if not pk_cols: return if any(c.lower() == "content_hash" for c in pk_cols): return new_pk = pk_cols + ["content_hash"] drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"' cols_sql = ", ".join([f'"{c}"' for c in new_pk]) add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})' if dry_run: print(f"[DRY] {drop_sql}") print(f"[DRY] {add_sql}") return print(f"[ODS] 变更主键: {schema}.{table} -> ({', '.join(new_pk)})") cur.execute(drop_sql) cur.execute(add_sql) def _migrate_ods(conn, schema: str, batch_size: int, dry_run: bool) -> None: with conn.cursor() as cur: tables = _fetch_tables(cur, schema) for table in tables: with conn.cursor() as cur: _ensure_content_hash_column(cur, schema, table, dry_run) conn.commit() _backfill_content_hash(conn, schema, table, batch_size, dry_run) with conn.cursor() as cur: _ensure_ods_primary_key(cur, schema, table, dry_run) conn.commit() def _backfill_scd2_fields(cur, schema: str, table: str, columns: Sequence[str], dry_run: bool) -> None: lower = {c.lower() for c in columns} fallback_cols = [ "updated_at", "update_time", "created_at", "create_time", "fetched_at", ] fallback = None for col in fallback_cols: if col in lower: fallback = f'"{col}"' break if fallback is None: fallback = "now()" sql = ( f'UPDATE "{schema}"."{table}" ' f'SET scd2_start_time = COALESCE(scd2_start_time, {fallback}), ' f"scd2_end_time = COALESCE(scd2_end_time, TIMESTAMPTZ '9999-12-31'), " f"scd2_is_current = COALESCE(scd2_is_current, 1), " f"scd2_version = COALESCE(scd2_version, 1) " f"WHERE scd2_start_time IS NULL OR scd2_end_time IS NULL OR scd2_is_current IS NULL OR scd2_version IS NULL" ) if dry_run: print(f"[DRY] {sql}") return cur.execute(sql) def _ensure_dwd_primary_key(cur, schema: str, table: str, dry_run: bool) -> None: name, pk_cols = _fetch_pk_constraint(cur, schema, table) if not pk_cols: return if any(c.lower() == "scd2_start_time" for c in pk_cols): return new_pk = pk_cols + ["scd2_start_time"] drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"' cols_sql = ", ".join([f'"{c}"' for c in new_pk]) add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})' if dry_run: print(f"[DRY] {drop_sql}") print(f"[DRY] {add_sql}") return print(f"[DWD] 变更主键: {schema}.{table} -> ({', '.join(new_pk)})") cur.execute(drop_sql) cur.execute(add_sql) def _migrate_dwd(conn, schema: str, dry_run: bool) -> None: with conn.cursor() as cur: cur.execute( """ SELECT DISTINCT table_name FROM information_schema.columns WHERE table_schema = %s AND column_name ILIKE 'scd2_start_time' ORDER BY table_name """, (schema,), ) tables = [r[0] for r in cur.fetchall()] for table in tables: with conn.cursor() as cur: cols = _fetch_columns(cur, schema, table) _backfill_scd2_fields(cur, schema, table, cols, dry_run) conn.commit() with conn.cursor() as cur: _ensure_dwd_primary_key(cur, schema, table, dry_run) conn.commit() def main() -> int: parser = argparse.ArgumentParser(description="迁移 ODS 快照 + DWD SCD2") parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN(也可用环境变量 PG_DSN)") parser.add_argument("--schema-ods", dest="schema_ods", default="billiards_ods") parser.add_argument("--schema-dwd", dest="schema_dwd", default="billiards_dwd") parser.add_argument("--batch-size", dest="batch_size", type=int, default=500) parser.add_argument("--only-ods", dest="only_ods", action="store_true") parser.add_argument("--only-dwd", dest="only_dwd", action="store_true") parser.add_argument("--dry-run", dest="dry_run", action="store_true") args = parser.parse_args() dsn = args.dsn or os.environ.get("PG_DSN") if not dsn: print("缺少 DSN(--dsn 或环境变量 PG_DSN)") return 2 conn = psycopg2.connect(dsn) conn.autocommit = False try: if not args.only_dwd: _migrate_ods(conn, args.schema_ods, args.batch_size, args.dry_run) if not args.only_ods: _migrate_dwd(conn, args.schema_dwd, args.dry_run) return 0 finally: conn.close() if __name__ == "__main__": raise SystemExit(main())