# -*- coding: utf-8 -*-
"""
Migrate to "snapshot-style ODS + DWD SCD2":
1) Add a content_hash column to every ODS table and make (business key, content_hash) the new primary key;
2) Compute content_hash from the payload so identical snapshots are not stored twice;
3) Backfill the SCD2 columns on every DWD dimension table and change the primary key to (business key, scd2_start_time).

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --dsn "postgresql://..."

Optional arguments:
    --only-ods / --only-dwd
    --dry-run
    --batch-size 500
"""
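# Example (illustrative): preview the ODS changes only, without writing anything.
# The DSN may come from --dsn or from the PG_DSN environment variable (see main()).
#
#   PG_DSN="postgresql://..." PYTHONPATH=. \
#       python -m etl_billiards.scripts.migrate_snapshot_ods --only-ods --dry-run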
from __future__ import annotations

import argparse
import hashlib
import json
import os
from typing import Any, Iterable, List, Sequence

import psycopg2
from psycopg2.extras import execute_batch, RealDictCursor

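# json.dumps fallback for values that are not natively JSON-serializable:
# anything with isoformat() (datetime/date) becomes its isoformat() string,
# everything else becomes str(value), so hashing stays deterministic and never
# raises TypeError.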
def _hash_default(value):
    return value.isoformat() if hasattr(value, "isoformat") else str(value)

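# Fields dropped before hashing: load metadata and the hash itself, so that
# re-fetching an unchanged record (different fetched_at / source_file) still
# produces the same content_hash and therefore no new snapshot row.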
def _sanitize_record_for_hash(record: Any) -> Any:
    exclude = {
        "data",
        "payload",
        "source_file",
        "source_endpoint",
        "fetched_at",
        "content_hash",
        "record_index",
    }

    def _strip(value):
        if isinstance(value, dict):
            cleaned = {}
            for k, v in value.items():
                if isinstance(k, str) and k.lower() in exclude:
                    continue
                cleaned[k] = _strip(v)
            return cleaned
        if isinstance(value, list):
            return [_strip(v) for v in value]
        return value

    return _strip(record or {})

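# Canonical JSON (sorted keys, compact separators) hashed with SHA-256.
# Illustrative example: the two records below hash identically because
# fetched_at is in the exclusion set above.
#   _compute_content_hash({"id": 1, "name": "A", "fetched_at": "2024-01-01"})
#   == _compute_content_hash({"id": 1, "name": "A", "fetched_at": "2024-02-01"})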
def _compute_content_hash(record: Any) -> str:
    cleaned = _sanitize_record_for_hash(record)
    payload = json.dumps(
        cleaned,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
        default=_hash_default,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

def _fetch_tables(cur, schema: str) -> List[str]:
    cur.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
        """,
        (schema,),
    )
    return [r[0] for r in cur.fetchall()]

def _fetch_columns(cur, schema: str, table: str) -> List[str]:
    cur.execute(
        """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (schema, table),
    )
    cols = []
    for row in cur.fetchall():
        if isinstance(row, dict):
            cols.append(row.get("column_name"))
        else:
            cols.append(row[0])
    return [c for c in cols if c]

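# Returns (constraint_name, [primary key columns in ordinal order]), or (None, [])
# when the table has no primary key. Rows are handled both as tuples and as dicts
# so the helper works with the default cursor as well as RealDictCursor.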
def _fetch_pk_constraint(cur, schema: str, table: str) -> tuple[str | None, list[str]]:
    cur.execute(
        """
        SELECT tc.constraint_name, kcu.column_name, kcu.ordinal_position
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
        """,
        (schema, table),
    )
    rows = cur.fetchall()
    if not rows:
        return None, []
    if isinstance(rows[0], dict):
        name = rows[0].get("constraint_name")
        cols = [r.get("column_name") for r in rows]
    else:
        name = rows[0][0]
        cols = [r[1] for r in rows]
    return name, [c for c in cols if c]

def _ensure_content_hash_column(cur, schema: str, table: str, dry_run: bool) -> None:
    cols = _fetch_columns(cur, schema, table)
    if any(c.lower() == "content_hash" for c in cols):
        return
    sql = f'ALTER TABLE "{schema}"."{table}" ADD COLUMN content_hash TEXT'
    if dry_run:
        print(f"[DRY] {sql}")
        return
    print(f"[ODS] adding content_hash: {schema}.{table}")
    cur.execute(sql)

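# Backfill flow for one ODS table: select every row whose content_hash is still NULL,
# hash its payload column (or the full row when the table has no payload column),
# then UPDATE in batches keyed by the existing primary key columns. Returns the
# number of rows updated (or that would be updated in dry-run mode).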
def _backfill_content_hash(conn, schema: str, table: str, batch_size: int, dry_run: bool) -> int:
    updated = 0
    with conn.cursor(cursor_factory=RealDictCursor) as cur:
        cols = _fetch_columns(cur, schema, table)
        if "content_hash" not in [c.lower() for c in cols]:
            return 0
        pk_name, pk_cols = _fetch_pk_constraint(cur, schema, table)
        if not pk_cols:
            return 0
        # content_hash may already be part of the PK; keep only the business key columns
        pk_cols = [c for c in pk_cols if c.lower() != "content_hash"]
        select_cols = [*pk_cols]
        if any(c.lower() == "payload" for c in cols):
            select_cols.append("payload")
        else:
            select_cols.extend([c for c in cols if c.lower() not in {"content_hash"}])
        select_cols_sql = ", ".join(f'"{c}"' for c in select_cols)
        sql = f'SELECT {select_cols_sql} FROM "{schema}"."{table}" WHERE content_hash IS NULL'
        cur.execute(sql)
        rows = cur.fetchall()

    if not rows:
        return 0

    def build_row(row: dict) -> tuple:
        payload = row.get("payload")
        if payload is None:
            payload = {k: v for k, v in row.items() if k.lower() not in {"content_hash", "payload"}}
        content_hash = _compute_content_hash(payload)
        key_vals = [row.get(k) for k in pk_cols]
        return (content_hash, *key_vals)

    updates = [build_row(r) for r in rows]
    if dry_run:
        print(f"[DRY] {schema}.{table}: would update content_hash on {len(updates)} rows")
        return len(updates)

    where_clause = " AND ".join([f'"{c}" = %s' for c in pk_cols])
    update_sql = (
        f'UPDATE "{schema}"."{table}" SET content_hash = %s '
        f'WHERE {where_clause} AND content_hash IS NULL'
    )
    with conn.cursor() as cur2:
        execute_batch(cur2, update_sql, updates, page_size=batch_size)
        # rowcount after execute_batch may reflect only the last page, hence the fallback
        updated = cur2.rowcount or len(updates)
    print(f"[ODS] {schema}.{table}: content_hash updated on {updated} rows")
    return updated

def _ensure_ods_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
    name, pk_cols = _fetch_pk_constraint(cur, schema, table)
    if not pk_cols:
        return
    if any(c.lower() == "content_hash" for c in pk_cols):
        return
    new_pk = pk_cols + ["content_hash"]
    drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
    cols_sql = ", ".join([f'"{c}"' for c in new_pk])
    add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
    if dry_run:
        print(f"[DRY] {drop_sql}")
        print(f"[DRY] {add_sql}")
        return
    print(f"[ODS] changing primary key: {schema}.{table} -> ({', '.join(new_pk)})")
    cur.execute(drop_sql)
    cur.execute(add_sql)

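# Per ODS table the migration runs three committed steps: add the content_hash
# column, backfill it for existing rows, then rebuild the primary key as
# (business key, content_hash). Note that ADD PRIMARY KEY requires content_hash
# to be non-NULL on every row, so the backfill must have covered the whole table.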
def _migrate_ods(conn, schema: str, batch_size: int, dry_run: bool) -> None:
    with conn.cursor() as cur:
        tables = _fetch_tables(cur, schema)
    for table in tables:
        with conn.cursor() as cur:
            _ensure_content_hash_column(cur, schema, table, dry_run)
        conn.commit()
        _backfill_content_hash(conn, schema, table, batch_size, dry_run)
        with conn.cursor() as cur:
            _ensure_ods_primary_key(cur, schema, table, dry_run)
        conn.commit()

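# SCD2 backfill defaults for rows created before the migration: the validity
# start falls back to the best available timestamp column (updated_at,
# update_time, created_at, create_time, fetched_at) or now(); the end time is
# the open-ended sentinel 9999-12-31; the row is flagged current with version 1.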
def _backfill_scd2_fields(cur, schema: str, table: str, columns: Sequence[str], dry_run: bool) -> None:
    lower = {c.lower() for c in columns}
    fallback_cols = [
        "updated_at",
        "update_time",
        "created_at",
        "create_time",
        "fetched_at",
    ]
    fallback = None
    for col in fallback_cols:
        if col in lower:
            fallback = f'"{col}"'
            break
    if fallback is None:
        fallback = "now()"

    sql = (
        f'UPDATE "{schema}"."{table}" '
        f'SET scd2_start_time = COALESCE(scd2_start_time, {fallback}), '
        f"scd2_end_time = COALESCE(scd2_end_time, TIMESTAMPTZ '9999-12-31'), "
        f"scd2_is_current = COALESCE(scd2_is_current, 1), "
        f"scd2_version = COALESCE(scd2_version, 1) "
        f"WHERE scd2_start_time IS NULL OR scd2_end_time IS NULL OR scd2_is_current IS NULL OR scd2_version IS NULL"
    )
    if dry_run:
        print(f"[DRY] {sql}")
        return
    cur.execute(sql)

def _ensure_dwd_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
    name, pk_cols = _fetch_pk_constraint(cur, schema, table)
    if not pk_cols:
        return
    if any(c.lower() == "scd2_start_time" for c in pk_cols):
        return
    new_pk = pk_cols + ["scd2_start_time"]
    drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
    cols_sql = ", ".join([f'"{c}"' for c in new_pk])
    add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
    if dry_run:
        print(f"[DRY] {drop_sql}")
        print(f"[DRY] {add_sql}")
        return
    print(f"[DWD] changing primary key: {schema}.{table} -> ({', '.join(new_pk)})")
    cur.execute(drop_sql)
    cur.execute(add_sql)

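# DWD tables are discovered by the presence of an scd2_start_time column; each
# one gets its SCD2 fields backfilled and its primary key rebuilt as
# (business key, scd2_start_time), committing after each step.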
def _migrate_dwd(conn, schema: str, dry_run: bool) -> None:
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT DISTINCT table_name
            FROM information_schema.columns
            WHERE table_schema = %s AND column_name ILIKE 'scd2_start_time'
            ORDER BY table_name
            """,
            (schema,),
        )
        tables = [r[0] for r in cur.fetchall()]

    for table in tables:
        with conn.cursor() as cur:
            cols = _fetch_columns(cur, schema, table)
            _backfill_scd2_fields(cur, schema, table, cols, dry_run)
        conn.commit()
        with conn.cursor() as cur:
            _ensure_dwd_primary_key(cur, schema, table, dry_run)
        conn.commit()

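# Entry point. The connection runs with autocommit off; commits happen per table
# inside _migrate_ods/_migrate_dwd, so a failure only discards the uncommitted
# step for the table currently being migrated. In --dry-run mode the SQL is
# printed but never executed.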
def main() -> int:
    parser = argparse.ArgumentParser(description="Migrate ODS snapshots + DWD SCD2")
    parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN (the PG_DSN environment variable also works)")
    parser.add_argument("--schema-ods", dest="schema_ods", default="billiards_ods")
    parser.add_argument("--schema-dwd", dest="schema_dwd", default="billiards_dwd")
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=500)
    parser.add_argument("--only-ods", dest="only_ods", action="store_true")
    parser.add_argument("--only-dwd", dest="only_dwd", action="store_true")
    parser.add_argument("--dry-run", dest="dry_run", action="store_true")
    args = parser.parse_args()

    dsn = args.dsn or os.environ.get("PG_DSN")
    if not dsn:
        print("Missing DSN (use --dsn or the PG_DSN environment variable)")
        return 2

    conn = psycopg2.connect(dsn)
    conn.autocommit = False
    try:
        if not args.only_dwd:
            _migrate_ods(conn, args.schema_ods, args.batch_size, args.dry_run)
        if not args.only_ods:
            _migrate_dwd(conn, args.schema_dwd, args.dry_run)
        return 0
    finally:
        conn.close()


if __name__ == "__main__":
    raise SystemExit(main())