# feiqiu-ETL/etl_billiards/scripts/migrate_snapshot_ods.py
# -*- coding: utf-8 -*-
"""
Migrate to a "snapshot-style ODS + DWD SCD2" model:
1) Add a content_hash column to every ODS table and switch the primary key to
   (business key, content_hash);
2) Compute content_hash from the payload so duplicate snapshots are avoided;
3) Backfill SCD2 fields on every DWD dimension table and change the primary key to
   (business key, scd2_start_time).

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --dsn "postgresql://..."

Optional flags:
    --only-ods / --only-dwd
    --dry-run
    --batch-size 500
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
from typing import Any, Iterable, List, Sequence
import psycopg2
from psycopg2.extras import execute_batch, RealDictCursor

def _hash_default(value):
    return value.isoformat() if hasattr(value, "isoformat") else str(value)

def _sanitize_record_for_hash(record: Any) -> Any:
    exclude = {
        "data",
        "payload",
        "source_file",
        "source_endpoint",
        "fetched_at",
        "content_hash",
        "record_index",
    }

    def _strip(value):
        if isinstance(value, dict):
            cleaned = {}
            for k, v in value.items():
                if isinstance(k, str) and k.lower() in exclude:
                    continue
                cleaned[k] = _strip(v)
            return cleaned
        if isinstance(value, list):
            return [_strip(v) for v in value]
        return value

    return _strip(record or {})

def _compute_content_hash(record: Any) -> str:
    cleaned = _sanitize_record_for_hash(record)
    payload = json.dumps(
        cleaned,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
        default=_hash_default,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
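
# Note: the hash is computed over the sanitized payload, so two snapshots that differ
# only in transport metadata (e.g. fetched_at, source_file, record_index) hash to the
# same value and are treated as duplicates.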

def _fetch_tables(cur, schema: str) -> List[str]:
    cur.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
        """,
        (schema,),
    )
    return [r[0] for r in cur.fetchall()]

def _fetch_columns(cur, schema: str, table: str) -> List[str]:
    cur.execute(
        """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (schema, table),
    )
    cols = []
    for row in cur.fetchall():
        if isinstance(row, dict):
            cols.append(row.get("column_name"))
        else:
            cols.append(row[0])
    return [c for c in cols if c]

def _fetch_pk_constraint(cur, schema: str, table: str) -> tuple[str | None, list[str]]:
    cur.execute(
        """
        SELECT tc.constraint_name, kcu.column_name, kcu.ordinal_position
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
        """,
        (schema, table),
    )
    rows = cur.fetchall()
    if not rows:
        return None, []
    if isinstance(rows[0], dict):
        name = rows[0].get("constraint_name")
        cols = [r.get("column_name") for r in rows]
    else:
        name = rows[0][0]
        cols = [r[1] for r in rows]
    return name, [c for c in cols if c]
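
# The helpers above return rows either as plain tuples or as RealDictCursor dicts,
# because they are called with both cursor types below; hence the isinstance(row, dict)
# checks.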

def _ensure_content_hash_column(cur, schema: str, table: str, dry_run: bool) -> None:
    cols = _fetch_columns(cur, schema, table)
    if any(c.lower() == "content_hash" for c in cols):
        return
    sql = f'ALTER TABLE "{schema}"."{table}" ADD COLUMN content_hash TEXT'
    if dry_run:
        print(f"[DRY] {sql}")
        return
    print(f"[ODS] adding content_hash column: {schema}.{table}")
    cur.execute(sql)

def _backfill_content_hash(conn, schema: str, table: str, batch_size: int, dry_run: bool) -> int:
    updated = 0
    with conn.cursor(cursor_factory=RealDictCursor) as cur:
        cols = _fetch_columns(cur, schema, table)
        if "content_hash" not in [c.lower() for c in cols]:
            return 0
        pk_name, pk_cols = _fetch_pk_constraint(cur, schema, table)
        if not pk_cols:
            return 0
        # Exclude content_hash itself from the key columns
        pk_cols = [c for c in pk_cols if c.lower() != "content_hash"]
        select_cols = [*pk_cols]
        if any(c.lower() == "payload" for c in cols):
            select_cols.append("payload")
        else:
            select_cols.extend([c for c in cols if c.lower() not in {"content_hash"}])
        select_cols_sql = ", ".join(f'"{c}"' for c in select_cols)
        sql = f'SELECT {select_cols_sql} FROM "{schema}"."{table}" WHERE content_hash IS NULL'
        cur.execute(sql)
        rows = cur.fetchall()
        if not rows:
            return 0

        def build_row(row: dict) -> tuple:
            payload = row.get("payload")
            if payload is None:
                payload = {k: v for k, v in row.items() if k.lower() not in {"content_hash", "payload"}}
            content_hash = _compute_content_hash(payload)
            key_vals = [row.get(k) for k in pk_cols]
            return (content_hash, *key_vals)

        updates = [build_row(r) for r in rows]
        if dry_run:
            print(f"[DRY] {schema}.{table}: would update content_hash on {len(updates)} rows")
            return len(updates)
        where_clause = " AND ".join([f'"{c}" = %s' for c in pk_cols])
        update_sql = (
            f'UPDATE "{schema}"."{table}" SET content_hash = %s '
            f'WHERE {where_clause} AND content_hash IS NULL'
        )
        with conn.cursor() as cur2:
            execute_batch(cur2, update_sql, updates, page_size=batch_size)
            updated = cur2.rowcount or len(updates)
        print(f"[ODS] {schema}.{table}: backfilled content_hash on {updated} rows")
    return updated
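
# Caveat: after execute_batch, cur2.rowcount only reflects the last statement of the
# last page, so the reported count is approximate; len(updates) is the number of rows
# that were scheduled for update.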

def _ensure_ods_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
    name, pk_cols = _fetch_pk_constraint(cur, schema, table)
    if not pk_cols:
        return
    if any(c.lower() == "content_hash" for c in pk_cols):
        return
    new_pk = pk_cols + ["content_hash"]
    drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
    cols_sql = ", ".join([f'"{c}"' for c in new_pk])
    add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
    if dry_run:
        print(f"[DRY] {drop_sql}")
        print(f"[DRY] {add_sql}")
        return
    print(f"[ODS] changing primary key: {schema}.{table} -> ({', '.join(new_pk)})")
    cur.execute(drop_sql)
    cur.execute(add_sql)

def _migrate_ods(conn, schema: str, batch_size: int, dry_run: bool) -> None:
    with conn.cursor() as cur:
        tables = _fetch_tables(cur, schema)
    for table in tables:
        with conn.cursor() as cur:
            _ensure_content_hash_column(cur, schema, table, dry_run)
        conn.commit()
        _backfill_content_hash(conn, schema, table, batch_size, dry_run)
        with conn.cursor() as cur:
            _ensure_ods_primary_key(cur, schema, table, dry_run)
        conn.commit()
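
# Commits happen per table (after the column add, and again after the backfill plus PK
# swap), so a failure part-way through leaves earlier tables fully migrated and later
# ones untouched.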

def _backfill_scd2_fields(cur, schema: str, table: str, columns: Sequence[str], dry_run: bool) -> None:
    lower = {c.lower() for c in columns}
    fallback_cols = [
        "updated_at",
        "update_time",
        "created_at",
        "create_time",
        "fetched_at",
    ]
    fallback = None
    for col in fallback_cols:
        if col in lower:
            fallback = f'"{col}"'
            break
    if fallback is None:
        fallback = "now()"
    sql = (
        f'UPDATE "{schema}"."{table}" '
        f'SET scd2_start_time = COALESCE(scd2_start_time, {fallback}), '
        f"scd2_end_time = COALESCE(scd2_end_time, TIMESTAMPTZ '9999-12-31'), "
        f"scd2_is_current = COALESCE(scd2_is_current, 1), "
        f"scd2_version = COALESCE(scd2_version, 1) "
        f"WHERE scd2_start_time IS NULL OR scd2_end_time IS NULL OR scd2_is_current IS NULL OR scd2_version IS NULL"
    )
    if dry_run:
        print(f"[DRY] {sql}")
        return
    cur.execute(sql)
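
# Assumption: the scd2_* columns (scd2_start_time, scd2_end_time, scd2_is_current,
# scd2_version) already exist on the DWD tables; this function only backfills NULL
# values (open-ended rows get the 9999-12-31 sentinel end time) and does not add
# the columns.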

def _ensure_dwd_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
    name, pk_cols = _fetch_pk_constraint(cur, schema, table)
    if not pk_cols:
        return
    if any(c.lower() == "scd2_start_time" for c in pk_cols):
        return
    new_pk = pk_cols + ["scd2_start_time"]
    drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
    cols_sql = ", ".join([f'"{c}"' for c in new_pk])
    add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
    if dry_run:
        print(f"[DRY] {drop_sql}")
        print(f"[DRY] {add_sql}")
        return
    print(f"[DWD] changing primary key: {schema}.{table} -> ({', '.join(new_pk)})")
    cur.execute(drop_sql)
    cur.execute(add_sql)

def _migrate_dwd(conn, schema: str, dry_run: bool) -> None:
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT DISTINCT table_name
            FROM information_schema.columns
            WHERE table_schema = %s AND column_name ILIKE 'scd2_start_time'
            ORDER BY table_name
            """,
            (schema,),
        )
        tables = [r[0] for r in cur.fetchall()]
    for table in tables:
        with conn.cursor() as cur:
            cols = _fetch_columns(cur, schema, table)
            _backfill_scd2_fields(cur, schema, table, cols, dry_run)
        conn.commit()
        with conn.cursor() as cur:
            _ensure_dwd_primary_key(cur, schema, table, dry_run)
        conn.commit()
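
# Only DWD tables that already have an scd2_start_time column are picked up here;
# tables without SCD2 columns are left untouched.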

def main() -> int:
    parser = argparse.ArgumentParser(description="Migrate to snapshot ODS + DWD SCD2")
    parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN (or set the PG_DSN environment variable)")
    parser.add_argument("--schema-ods", dest="schema_ods", default="billiards_ods")
    parser.add_argument("--schema-dwd", dest="schema_dwd", default="billiards_dwd")
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=500)
    parser.add_argument("--only-ods", dest="only_ods", action="store_true")
    parser.add_argument("--only-dwd", dest="only_dwd", action="store_true")
    parser.add_argument("--dry-run", dest="dry_run", action="store_true")
    args = parser.parse_args()

    dsn = args.dsn or os.environ.get("PG_DSN")
    if not dsn:
        print("Missing DSN: pass --dsn or set the PG_DSN environment variable")
        return 2

    conn = psycopg2.connect(dsn)
    conn.autocommit = False
    try:
        if not args.only_dwd:
            _migrate_ods(conn, args.schema_ods, args.batch_size, args.dry_run)
        if not args.only_ods:
            _migrate_dwd(conn, args.schema_dwd, args.dry_run)
        return 0
    finally:
        conn.close()

if __name__ == "__main__":
    raise SystemExit(main())