Database: update data-validation and write logic.
324  etl_billiards/scripts/migrate_snapshot_ods.py  Normal file
@@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
"""
Migrate to a "snapshot-style ODS + DWD SCD2" model:
1) Add a content_hash column to every ODS table and make (business key, content_hash) the new primary key;
2) Compute content_hash from the payload so that unchanged records do not create duplicate snapshots;
3) Add the SCD2 columns to every DWD dimension table and change its primary key to (business key, scd2_start_time).

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --dsn "postgresql://..."

Optional arguments:
    --only-ods / --only-dwd
    --dry-run
    --batch-size 500
"""
from __future__ import annotations

import argparse
import hashlib
import json
import os
from typing import Any, List, Sequence

import psycopg2
from psycopg2.extras import execute_batch, RealDictCursor


def _hash_default(value):
    return value.isoformat() if hasattr(value, "isoformat") else str(value)


def _sanitize_record_for_hash(record: Any) -> Any:
    # Bookkeeping fields that must not influence the content hash.
    exclude = {
        "data",
        "payload",
        "source_file",
        "source_endpoint",
        "fetched_at",
        "content_hash",
        "record_index",
    }

    def _strip(value):
        if isinstance(value, dict):
            cleaned = {}
            for k, v in value.items():
                if isinstance(k, str) and k.lower() in exclude:
                    continue
                cleaned[k] = _strip(v)
            return cleaned
        if isinstance(value, list):
            return [_strip(v) for v in value]
        return value

    return _strip(record or {})


def _compute_content_hash(record: Any) -> str:
    cleaned = _sanitize_record_for_hash(record)
    payload = json.dumps(
        cleaned,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
        default=_hash_default,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


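# Illustrative sketch (hypothetical field names, not part of the migration): two fetches
# of the same business record that differ only in bookkeeping fields such as fetched_at
# hash to the same value, so re-loading it does not create a duplicate snapshot.
#
#   a = {"order_id": 1, "amount": 12.5, "fetched_at": "2024-01-01T00:00:00"}
#   b = {"order_id": 1, "amount": 12.5, "fetched_at": "2024-01-02T08:30:00"}
#   assert _compute_content_hash(a) == _compute_content_hash(b)
#
#   c = {"order_id": 1, "amount": 13.0, "fetched_at": "2024-01-02T08:30:00"}
#   assert _compute_content_hash(a) != _compute_content_hash(c)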
def _fetch_tables(cur, schema: str) -> List[str]:
    cur.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
        """,
        (schema,),
    )
    return [r[0] for r in cur.fetchall()]


def _fetch_columns(cur, schema: str, table: str) -> List[str]:
    cur.execute(
        """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
        """,
        (schema, table),
    )
    cols = []
    for row in cur.fetchall():
        # Works with both plain tuples and RealDictCursor rows.
        if isinstance(row, dict):
            cols.append(row.get("column_name"))
        else:
            cols.append(row[0])
    return [c for c in cols if c]


def _fetch_pk_constraint(cur, schema: str, table: str) -> tuple[str | None, list[str]]:
    cur.execute(
        """
        SELECT tc.constraint_name, kcu.column_name, kcu.ordinal_position
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
        """,
        (schema, table),
    )
    rows = cur.fetchall()
    if not rows:
        return None, []
    if isinstance(rows[0], dict):
        name = rows[0].get("constraint_name")
        cols = [r.get("column_name") for r in rows]
    else:
        name = rows[0][0]
        cols = [r[1] for r in rows]
    return name, [c for c in cols if c]


def _ensure_content_hash_column(cur, schema: str, table: str, dry_run: bool) -> None:
    cols = _fetch_columns(cur, schema, table)
    if any(c.lower() == "content_hash" for c in cols):
        return
    sql = f'ALTER TABLE "{schema}"."{table}" ADD COLUMN content_hash TEXT'
    if dry_run:
        print(f"[DRY] {sql}")
        return
    print(f"[ODS] adding content_hash column: {schema}.{table}")
    cur.execute(sql)


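# Backfill strategy: rows with a NULL content_hash are read in one pass (using the payload
# column when the table has one, otherwise the remaining columns of the row), the hash is
# computed in Python, and the values are written back with a batched UPDATE keyed on the
# existing business primary key. With --dry-run only the affected row count is reported.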
def _backfill_content_hash(conn, schema: str, table: str, batch_size: int, dry_run: bool) -> int:
    updated = 0
    with conn.cursor(cursor_factory=RealDictCursor) as cur:
        cols = _fetch_columns(cur, schema, table)
        if "content_hash" not in [c.lower() for c in cols]:
            return 0
        pk_name, pk_cols = _fetch_pk_constraint(cur, schema, table)
        if not pk_cols:
            return 0
        # Drop content_hash from the key columns used to address rows.
        pk_cols = [c for c in pk_cols if c.lower() != "content_hash"]
        select_cols = [*pk_cols]
        if any(c.lower() == "payload" for c in cols):
            select_cols.append("payload")
        else:
            select_cols.extend([c for c in cols if c.lower() not in {"content_hash"}])
        select_cols_sql = ", ".join(f'"{c}"' for c in select_cols)
        sql = f'SELECT {select_cols_sql} FROM "{schema}"."{table}" WHERE content_hash IS NULL'
        cur.execute(sql)
        rows = cur.fetchall()

    if not rows:
        return 0

    def build_row(row: dict) -> tuple:
        payload = row.get("payload")
        if payload is None:
            payload = {k: v for k, v in row.items() if k.lower() not in {"content_hash", "payload"}}
        content_hash = _compute_content_hash(payload)
        key_vals = [row.get(k) for k in pk_cols]
        return (content_hash, *key_vals)

    updates = [build_row(r) for r in rows]
    if dry_run:
        print(f"[DRY] {schema}.{table}: would update content_hash on {len(updates)} rows")
        return len(updates)

    where_clause = " AND ".join([f'"{c}" = %s' for c in pk_cols])
    update_sql = (
        f'UPDATE "{schema}"."{table}" SET content_hash = %s '
        f'WHERE {where_clause} AND content_hash IS NULL'
    )
    with conn.cursor() as cur2:
        execute_batch(cur2, update_sql, updates, page_size=batch_size)
        # execute_batch only reports the rowcount of its last page, so fall back to len(updates).
        updated = cur2.rowcount or len(updates)
    print(f"[ODS] {schema}.{table}: updated content_hash on {updated} rows")
    return updated


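# The ODS primary key is widened to include content_hash so each distinct content version
# of a business record is kept as its own snapshot row. For a hypothetical table
# billiards_ods.ods_orders whose current key is (order_id), the statements issued below
# would amount to roughly:
#   ALTER TABLE "billiards_ods"."ods_orders" DROP CONSTRAINT "ods_orders_pkey";
#   ALTER TABLE "billiards_ods"."ods_orders" ADD PRIMARY KEY ("order_id", "content_hash");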
def _ensure_ods_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
    name, pk_cols = _fetch_pk_constraint(cur, schema, table)
    if not pk_cols:
        return
    if any(c.lower() == "content_hash" for c in pk_cols):
        return
    new_pk = pk_cols + ["content_hash"]
    drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
    cols_sql = ", ".join([f'"{c}"' for c in new_pk])
    add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
    if dry_run:
        print(f"[DRY] {drop_sql}")
        print(f"[DRY] {add_sql}")
        return
    print(f"[ODS] changing primary key: {schema}.{table} -> ({', '.join(new_pk)})")
    cur.execute(drop_sql)
    cur.execute(add_sql)


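# Tables are migrated one at a time with a commit after each step (add the column, then
# backfill + primary-key change), so a failure on one table does not roll back tables
# that were already migrated.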
def _migrate_ods(conn, schema: str, batch_size: int, dry_run: bool) -> None:
    with conn.cursor() as cur:
        tables = _fetch_tables(cur, schema)
    for table in tables:
        with conn.cursor() as cur:
            _ensure_content_hash_column(cur, schema, table, dry_run)
        conn.commit()
        _backfill_content_hash(conn, schema, table, batch_size, dry_run)
        with conn.cursor() as cur:
            _ensure_ods_primary_key(cur, schema, table, dry_run)
        conn.commit()


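# SCD2 backfill: scd2_start_time falls back to the best available timestamp column
# (updated_at, update_time, created_at, create_time, fetched_at) or now(); open rows get
# the sentinel end time 9999-12-31, scd2_is_current = 1 and scd2_version = 1. Only rows
# where any of these fields is still NULL are touched.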
def _backfill_scd2_fields(cur, schema: str, table: str, columns: Sequence[str], dry_run: bool) -> None:
    lower = {c.lower() for c in columns}
    fallback_cols = [
        "updated_at",
        "update_time",
        "created_at",
        "create_time",
        "fetched_at",
    ]
    fallback = None
    for col in fallback_cols:
        if col in lower:
            fallback = f'"{col}"'
            break
    if fallback is None:
        fallback = "now()"

    sql = (
        f'UPDATE "{schema}"."{table}" '
        f'SET scd2_start_time = COALESCE(scd2_start_time, {fallback}), '
        f"scd2_end_time = COALESCE(scd2_end_time, TIMESTAMPTZ '9999-12-31'), "
        f"scd2_is_current = COALESCE(scd2_is_current, 1), "
        f"scd2_version = COALESCE(scd2_version, 1) "
        f"WHERE scd2_start_time IS NULL OR scd2_end_time IS NULL OR scd2_is_current IS NULL OR scd2_version IS NULL"
    )
    if dry_run:
        print(f"[DRY] {sql}")
        return
    cur.execute(sql)


def _ensure_dwd_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
    name, pk_cols = _fetch_pk_constraint(cur, schema, table)
    if not pk_cols:
        return
    if any(c.lower() == "scd2_start_time" for c in pk_cols):
        return
    new_pk = pk_cols + ["scd2_start_time"]
    drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
    cols_sql = ", ".join([f'"{c}"' for c in new_pk])
    add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
    if dry_run:
        print(f"[DRY] {drop_sql}")
        print(f"[DRY] {add_sql}")
        return
    print(f"[DWD] changing primary key: {schema}.{table} -> ({', '.join(new_pk)})")
    cur.execute(drop_sql)
    cur.execute(add_sql)


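# DWD dimension tables are discovered by the presence of an scd2_start_time column in
# information_schema.columns; each one gets its SCD2 fields backfilled and its primary
# key extended with scd2_start_time.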
def _migrate_dwd(conn, schema: str, dry_run: bool) -> None:
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT DISTINCT table_name
            FROM information_schema.columns
            WHERE table_schema = %s AND column_name ILIKE 'scd2_start_time'
            ORDER BY table_name
            """,
            (schema,),
        )
        tables = [r[0] for r in cur.fetchall()]

    for table in tables:
        with conn.cursor() as cur:
            cols = _fetch_columns(cur, schema, table)
            _backfill_scd2_fields(cur, schema, table, cols, dry_run)
        conn.commit()
        with conn.cursor() as cur:
            _ensure_dwd_primary_key(cur, schema, table, dry_run)
        conn.commit()


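# Example run (a sketch; assumes PG_DSN is exported in the environment):
#   PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --dry-run
#   PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --only-ods --batch-size 1000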
def main() -> int:
    parser = argparse.ArgumentParser(description="Migrate to snapshot ODS + DWD SCD2")
    parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN (the PG_DSN environment variable is also honored)")
    parser.add_argument("--schema-ods", dest="schema_ods", default="billiards_ods")
    parser.add_argument("--schema-dwd", dest="schema_dwd", default="billiards_dwd")
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=500)
    parser.add_argument("--only-ods", dest="only_ods", action="store_true")
    parser.add_argument("--only-dwd", dest="only_dwd", action="store_true")
    parser.add_argument("--dry-run", dest="dry_run", action="store_true")
    args = parser.parse_args()

    dsn = args.dsn or os.environ.get("PG_DSN")
    if not dsn:
        print("Missing DSN (pass --dsn or set the PG_DSN environment variable)")
        return 2

    conn = psycopg2.connect(dsn)
    conn.autocommit = False
    try:
        if not args.only_dwd:
            _migrate_ods(conn, args.schema_ods, args.batch_size, args.dry_run)
        if not args.only_ods:
            _migrate_dwd(conn, args.schema_dwd, args.dry_run)
        return 0
    finally:
        conn.close()


if __name__ == "__main__":
    raise SystemExit(main())