Database: update data-validation and write logic.
1  etl_billiards/scripts/Untitled  Normal file
@@ -0,0 +1 @@
check_data_integrity.py
@@ -32,9 +32,15 @@ from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods_tasks import ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from tasks.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
    get_value_case_insensitive,
    merge_record_layers,
    normalize_pk_value,
    pk_tuple_from_record,
)


def _reconfigure_stdout_utf8() -> None:
@@ -74,56 +80,26 @@ def _get_spec(code: str) -> Optional[OdsTaskSpec]:


def _merge_record_layers(record: dict) -> dict:
    """Expand the nested data layers."""
    merged = record
    data_part = merged.get("data")
    while isinstance(data_part, dict):
        merged = {**data_part, **merged}
        data_part = data_part.get("data")
    settle_inner = merged.get("settleList")
    if isinstance(settle_inner, dict):
        merged = {**settle_inner, **merged}
    return merged
    """Flatten nested data layers into a single dict."""
    return merge_record_layers(record)


def _get_value_case_insensitive(record: dict | None, col: str | None):
    """Get a value case-insensitively."""
    if record is None or col is None:
        return None
    if col in record:
        return record.get(col)
    col_lower = col.lower()
    for k, v in record.items():
        if isinstance(k, str) and k.lower() == col_lower:
            return v
    return None
    """Fetch value without case sensitivity."""
    return get_value_case_insensitive(record, col)


def _normalize_pk_value(value):
    """Normalize the PK value."""
    if value is None:
        return None
    if isinstance(value, str) and value.isdigit():
        try:
            return int(value)
        except Exception:
            return value
    return value
    """Normalize PK value."""
    return normalize_pk_value(value)


def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
    """Extract the PK tuple from a record."""
    merged = _merge_record_layers(record)
    values = []
    for col in pk_cols:
        val = _normalize_pk_value(_get_value_case_insensitive(merged, col))
        if val is None or val == "":
            return None
        values.append(val)
    return tuple(values)
    """Extract PK tuple from record."""
    return pk_tuple_from_record(record, pk_cols)


def _get_table_pk_columns(conn, table: str) -> List[str]:
def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
    """Get the table's primary-key columns."""
    if "." in table:
        schema, name = table.split(".", 1)
@@ -142,7 +118,10 @@ def _get_table_pk_columns(conn, table: str) -> List[str]:
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        return [r[0] for r in cur.fetchall()]
        cols = [r[0] for r in cur.fetchall()]
    if include_content_hash:
        return cols
    return [c for c in cols if c.lower() != "content_hash"]


def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
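(Editor's note: the hunk above is the core of the snapshot-ODS change — PK lookups use the business key only, while the conflict target keeps content_hash. A minimal, editor-added sketch of that split follows; the table/column names such as order_id are illustrative and do not come from the commit.)

    # Illustrative only: how the two column sets produced above differ.
    def split_key_columns(declared_pk: list[str]) -> tuple[list[str], list[str]]:
        """Return (lookup_cols, conflict_cols) for a snapshot-style ODS table."""
        conflict_cols = list(declared_pk)                                       # e.g. ["order_id", "content_hash"]
        lookup_cols = [c for c in declared_pk if c.lower() != "content_hash"]   # e.g. ["order_id"]
        return lookup_cols, conflict_cols

    assert split_key_columns(["order_id", "content_hash"]) == (
        ["order_id"],
        ["order_id", "content_hash"],
    )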
@@ -247,6 +226,13 @@ class MissingDataBackfiller:
        """Close the connection."""
        if self.db:
            self.db.close()

    def _ensure_db(self):
        """Ensure the database connection is usable."""
        if self.db and getattr(self.db, "conn", None) is not None:
            if getattr(self.db.conn, "closed", 0) == 0:
                return
        self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))

    def backfill_from_gap_check(
        self,
@@ -254,8 +240,10 @@ class MissingDataBackfiller:
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
task_codes: Optional[str] = None,
|
||||
include_mismatch: bool = False,
|
||||
page_size: int = 200,
|
||||
chunk_size: int = 500,
|
||||
content_sample_limit: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
运行 gap check 并补全丢失数据
|
||||
@@ -292,16 +280,21 @@ class MissingDataBackfiller:
|
||||
cutoff_overlap_hours=24,
|
||||
allow_small_window=True,
|
||||
logger=self.logger,
|
||||
compare_content=include_mismatch,
|
||||
content_sample_limit=content_sample_limit or 10000,
|
||||
)
|
||||
|
||||
total_missing = gap_result.get("total_missing", 0)
|
||||
if total_missing == 0:
|
||||
self.logger.info("数据完整,无缺失记录")
|
||||
total_mismatch = gap_result.get("total_mismatch", 0)
|
||||
if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
|
||||
self.logger.info("Data complete: no missing/mismatch records")
|
||||
return {"backfilled": 0, "errors": 0, "details": []}
|
||||
|
||||
self.logger.info("缺失检查完成 总缺失=%s", total_missing)
|
||||
if include_mismatch:
|
||||
self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
|
||||
else:
|
||||
self.logger.info("Missing check done missing=%s", total_missing)
|
||||
|
||||
# 补全每个任务的丢失数据
|
||||
results = []
|
||||
total_backfilled = 0
|
||||
total_errors = 0
|
||||
@@ -310,13 +303,16 @@ class MissingDataBackfiller:
|
||||
task_code = task_result.get("task_code")
|
||||
missing = task_result.get("missing", 0)
|
||||
missing_samples = task_result.get("missing_samples", [])
|
||||
mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
|
||||
mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
|
||||
target_samples = list(missing_samples) + list(mismatch_samples)
|
||||
|
||||
if missing == 0:
|
||||
if missing == 0 and mismatch == 0:
|
||||
continue
|
||||
|
||||
self.logger.info(
|
||||
"开始补全任务 任务=%s 缺失=%s 样本数=%s",
|
||||
task_code, missing, len(missing_samples)
|
||||
"Start backfill task task=%s missing=%s mismatch=%s samples=%s",
|
||||
task_code, missing, mismatch, len(target_samples)
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -324,7 +320,7 @@ class MissingDataBackfiller:
|
||||
task_code=task_code,
|
||||
table=task_result.get("table"),
|
||||
pk_columns=task_result.get("pk_columns", []),
|
||||
missing_samples=missing_samples,
|
||||
pk_samples=target_samples,
|
||||
start=start,
|
||||
end=end,
|
||||
page_size=page_size,
|
||||
@@ -333,6 +329,7 @@ class MissingDataBackfiller:
|
||||
results.append({
|
||||
"task_code": task_code,
|
||||
"missing": missing,
|
||||
"mismatch": mismatch,
|
||||
"backfilled": backfilled,
|
||||
"error": None,
|
||||
})
|
||||
@@ -342,6 +339,7 @@ class MissingDataBackfiller:
|
||||
results.append({
|
||||
"task_code": task_code,
|
||||
"missing": missing,
|
||||
"mismatch": mismatch,
|
||||
"backfilled": 0,
|
||||
"error": str(exc),
|
||||
})
|
||||
@@ -354,6 +352,7 @@ class MissingDataBackfiller:
|
||||
|
||||
return {
|
||||
"total_missing": total_missing,
|
||||
"total_mismatch": total_mismatch,
|
||||
"backfilled": total_backfilled,
|
||||
"errors": total_errors,
|
||||
"details": results,
|
||||
@@ -365,20 +364,25 @@ class MissingDataBackfiller:
|
||||
task_code: str,
|
||||
table: str,
|
||||
pk_columns: List[str],
|
||||
missing_samples: List[Dict],
|
||||
pk_samples: List[Dict],
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
page_size: int,
|
||||
chunk_size: int,
|
||||
) -> int:
|
||||
"""补全单个任务的丢失数据"""
|
||||
self._ensure_db()
|
||||
spec = _get_spec(task_code)
|
||||
if not spec:
|
||||
self.logger.warning("未找到任务规格 任务=%s", task_code)
|
||||
return 0
|
||||
|
||||
if not pk_columns:
|
||||
pk_columns = _get_table_pk_columns(self.db.conn, table)
|
||||
pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)
|
||||
|
||||
conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
|
||||
if not conflict_columns:
|
||||
conflict_columns = pk_columns
|
||||
|
||||
if not pk_columns:
|
||||
self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
|
||||
@@ -386,7 +390,7 @@ class MissingDataBackfiller:
|
||||
|
||||
# 提取丢失的 PK 值
|
||||
missing_pks: Set[Tuple] = set()
|
||||
for sample in missing_samples:
|
||||
for sample in pk_samples:
|
||||
pk_tuple = tuple(sample.get(col) for col in pk_columns)
|
||||
if all(v is not None for v in pk_tuple):
|
||||
missing_pks.add(pk_tuple)
|
||||
@@ -410,6 +414,12 @@ class MissingDataBackfiller:
|
||||
if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
|
||||
}
|
||||
col_names = [c[0] for c in cols_info]
|
||||
|
||||
# 结束只读事务,避免长时间 API 拉取导致 idle_in_tx 超时
|
||||
try:
|
||||
self.db.conn.commit()
|
||||
except Exception:
|
||||
self.db.conn.rollback()
|
||||
|
||||
try:
|
||||
for page_no, records, _, response_payload in self.api.iter_paginated(
|
||||
@@ -444,9 +454,12 @@ class MissingDataBackfiller:
|
||||
records=records_to_insert,
|
||||
cols_info=cols_info,
|
||||
pk_columns=pk_columns,
|
||||
conflict_columns=conflict_columns,
|
||||
db_json_cols_lower=db_json_cols_lower,
|
||||
)
|
||||
backfilled += inserted
|
||||
# 避免长事务阻塞与 idle_in_tx 超时
|
||||
self.db.conn.commit()
|
||||
self.logger.info(
|
||||
"已插入 任务=%s 页=%s 数量=%s",
|
||||
task_code, page_no, inserted
|
||||
@@ -498,6 +511,7 @@ class MissingDataBackfiller:
|
||||
records: List[Dict],
|
||||
cols_info: List[Tuple[str, str, str]],
|
||||
pk_columns: List[str],
|
||||
conflict_columns: List[str],
|
||||
db_json_cols_lower: Set[str],
|
||||
) -> int:
|
||||
"""插入记录到数据库"""
|
||||
@@ -505,10 +519,12 @@ class MissingDataBackfiller:
|
||||
return 0
|
||||
|
||||
col_names = [c[0] for c in cols_info]
|
||||
needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
|
||||
quoted_cols = ", ".join(f'"{c}"' for c in col_names)
|
||||
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
|
||||
if pk_columns:
|
||||
pk_clause = ", ".join(f'"{c}"' for c in pk_columns)
|
||||
conflict_cols = conflict_columns or pk_columns
|
||||
if conflict_cols:
|
||||
pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
|
||||
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
|
||||
|
||||
now = datetime.now(self.tz)
|
||||
@@ -522,12 +538,20 @@ class MissingDataBackfiller:
|
||||
if pk_columns:
|
||||
missing_pk = False
|
||||
for pk in pk_columns:
|
||||
if str(pk).lower() == "content_hash":
|
||||
continue
|
||||
pk_val = _get_value_case_insensitive(merged_rec, pk)
|
||||
if pk_val is None or pk_val == "":
|
||||
missing_pk = True
|
||||
break
|
||||
if missing_pk:
|
||||
continue
|
||||
|
||||
content_hash = None
|
||||
if needs_content_hash:
|
||||
hash_record = dict(merged_rec)
|
||||
hash_record["fetched_at"] = now
|
||||
content_hash = BaseOdsTask._compute_content_hash(hash_record, include_fetched_at=True)
|
||||
|
||||
row_vals: List[Any] = []
|
||||
for (col_name, data_type, _udt) in cols_info:
|
||||
@@ -544,6 +568,9 @@ class MissingDataBackfiller:
|
||||
if col_lower == "fetched_at":
|
||||
row_vals.append(now)
|
||||
continue
|
||||
if col_lower == "content_hash":
|
||||
row_vals.append(content_hash)
|
||||
continue
|
||||
|
||||
value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
|
||||
if col_lower in db_json_cols_lower:
|
||||
@@ -574,9 +601,11 @@ def run_backfill(
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
task_codes: Optional[str] = None,
|
||||
include_mismatch: bool = False,
|
||||
dry_run: bool = False,
|
||||
page_size: int = 200,
|
||||
chunk_size: int = 500,
|
||||
content_sample_limit: int | None = None,
|
||||
logger: logging.Logger,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -601,8 +630,10 @@ def run_backfill(
|
||||
start=start,
|
||||
end=end,
|
||||
task_codes=task_codes,
|
||||
include_mismatch=include_mismatch,
|
||||
page_size=page_size,
|
||||
chunk_size=chunk_size,
|
||||
content_sample_limit=content_sample_limit,
|
||||
)
|
||||
finally:
|
||||
backfiller.close()
|
||||
@@ -615,6 +646,8 @@ def main() -> int:
|
||||
ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)")
|
||||
ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)")
|
||||
ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)")
|
||||
ap.add_argument("--include-mismatch", action="store_true", help="同时补全内容不一致的记录")
|
||||
ap.add_argument("--content-sample-limit", type=int, default=None, help="不一致样本上限 (默认: 10000)")
|
||||
ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入")
|
||||
ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)")
|
||||
ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)")
|
||||
@@ -646,15 +679,19 @@ def main() -> int:
|
||||
start=start,
|
||||
end=end,
|
||||
task_codes=args.task_codes or None,
|
||||
include_mismatch=args.include_mismatch,
|
||||
dry_run=args.dry_run,
|
||||
page_size=args.page_size,
|
||||
chunk_size=args.chunk_size,
|
||||
content_sample_limit=args.content_sample_limit,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("补全完成!")
|
||||
logger.info(" 总丢失: %s", result.get("total_missing", 0))
|
||||
if args.include_mismatch:
|
||||
logger.info(" 总不一致: %s", result.get("total_mismatch", 0))
|
||||
logger.info(" 已补全: %s", result.get("backfilled", 0))
|
||||
logger.info(" 错误数: %s", result.get("errors", 0))
|
||||
logger.info("=" * 60)
|
||||
@@ -663,17 +700,19 @@ def main() -> int:
|
||||
for detail in result.get("details", []):
|
||||
if detail.get("error"):
|
||||
logger.error(
|
||||
" %s: 丢失=%s 补全=%s 错误=%s",
|
||||
" %s: 丢失=%s 不一致=%s 补全=%s 错误=%s",
|
||||
detail.get("task_code"),
|
||||
detail.get("missing"),
|
||||
detail.get("mismatch", 0),
|
||||
detail.get("backfilled"),
|
||||
detail.get("error"),
|
||||
)
|
||||
elif detail.get("backfilled", 0) > 0:
|
||||
logger.info(
|
||||
" %s: 丢失=%s 补全=%s",
|
||||
" %s: 丢失=%s 不一致=%s 补全=%s",
|
||||
detail.get("task_code"),
|
||||
detail.get("missing"),
|
||||
detail.get("mismatch", 0),
|
||||
detail.get("backfilled"),
|
||||
)
|
||||
|
||||
|
||||
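(Editor's note: to make the ON CONFLICT handling in the backfiller above concrete, here is a minimal, editor-added sketch of the statement it assembles: all columns are inserted in one batch via psycopg2's execute_values, and the conflict target is the snapshot key — business PK plus content_hash — so an identical snapshot is silently skipped. Table and column names below are illustrative, not taken from the commit.)

    from psycopg2.extras import execute_values  # same helper the script uses

    def build_insert_sql(table: str, col_names: list[str], conflict_cols: list[str]) -> str:
        quoted_cols = ", ".join(f'"{c}"' for c in col_names)
        sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
        if conflict_cols:
            conflict_clause = ", ".join(f'"{c}"' for c in conflict_cols)
            sql += f" ON CONFLICT ({conflict_clause}) DO NOTHING"
        return sql

    # With a live psycopg2 cursor (illustrative names):
    # sql = build_insert_sql(
    #     "billiards_ods.orders",
    #     ["order_id", "payload", "fetched_at", "content_hash"],
    #     ["order_id", "content_hash"],
    # )
    # execute_values(cur, sql, rows, page_size=500)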
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -12,12 +11,7 @@ from zoneinfo import ZoneInfo
|
||||
from dateutil import parser as dtparser
|
||||
|
||||
from config.settings import AppConfig
|
||||
from quality.integrity_checker import (
|
||||
IntegrityWindow,
|
||||
compute_last_etl_end,
|
||||
run_integrity_history,
|
||||
run_integrity_window,
|
||||
)
|
||||
from quality.integrity_service import run_history_flow, run_window_flow, write_report
|
||||
from utils.logging_utils import build_log_path, configure_logging
|
||||
from utils.windowing import split_window
|
||||
|
||||
@@ -38,14 +32,37 @@ def main() -> int:
|
||||
|
||||
ap = argparse.ArgumentParser(description="Data integrity checks (API -> ODS -> DWD)")
|
||||
ap.add_argument("--mode", choices=["history", "window"], default="history")
|
||||
ap.add_argument(
|
||||
"--flow",
|
||||
choices=["verify", "update_and_verify"],
|
||||
default="verify",
|
||||
help="verify only or update+verify (auto backfill then optional recheck)",
|
||||
)
|
||||
ap.add_argument("--start", default="2025-07-01", help="history start date (default: 2025-07-01)")
|
||||
ap.add_argument("--end", default="", help="history end datetime (default: last ETL end)")
|
||||
ap.add_argument("--window-start", default="", help="window start datetime (mode=window)")
|
||||
ap.add_argument("--window-end", default="", help="window end datetime (mode=window)")
|
||||
ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
|
||||
ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
|
||||
ap.add_argument("--include-dimensions", action="store_true", help="include dimension tables in ODS->DWD checks")
|
||||
ap.add_argument(
|
||||
"--include-dimensions",
|
||||
action="store_true",
|
||||
default=None,
|
||||
help="include dimension tables in ODS->DWD checks",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-include-dimensions",
|
||||
action="store_true",
|
||||
help="exclude dimension tables in ODS->DWD checks",
|
||||
)
|
||||
ap.add_argument("--ods-task-codes", default="", help="comma-separated ODS task codes for API checks")
|
||||
ap.add_argument("--compare-content", action="store_true", help="compare API vs ODS content hash")
|
||||
ap.add_argument("--no-compare-content", action="store_true", help="disable content comparison even if enabled in config")
|
||||
ap.add_argument("--include-mismatch", action="store_true", help="backfill mismatch records as well")
|
||||
ap.add_argument("--no-include-mismatch", action="store_true", help="disable mismatch backfill")
|
||||
ap.add_argument("--recheck", action="store_true", help="re-run checks after backfill")
|
||||
ap.add_argument("--no-recheck", action="store_true", help="skip recheck after backfill")
|
||||
ap.add_argument("--content-sample-limit", type=int, default=None, help="max mismatch samples per table")
|
||||
ap.add_argument("--out", default="", help="output JSON path")
|
||||
ap.add_argument("--log-file", default="", help="log file path")
|
||||
ap.add_argument("--log-dir", default="", help="log directory")
|
||||
@@ -68,6 +85,39 @@ def main() -> int:
|
||||
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
|
||||
report_path = Path(args.out) if args.out else None
|
||||
|
||||
if args.recheck and args.no_recheck:
|
||||
raise SystemExit("cannot set both --recheck and --no-recheck")
|
||||
if args.include_mismatch and args.no_include_mismatch:
|
||||
raise SystemExit("cannot set both --include-mismatch and --no-include-mismatch")
|
||||
if args.include_dimensions and args.no_include_dimensions:
|
||||
raise SystemExit("cannot set both --include-dimensions and --no-include-dimensions")
|
||||
|
||||
compare_content = None
|
||||
if args.compare_content and args.no_compare_content:
|
||||
raise SystemExit("cannot set both --compare-content and --no-compare-content")
|
||||
if args.compare_content:
|
||||
compare_content = True
|
||||
elif args.no_compare_content:
|
||||
compare_content = False
|
||||
|
||||
include_mismatch = cfg.get("integrity.backfill_mismatch", True)
|
||||
if args.include_mismatch:
|
||||
include_mismatch = True
|
||||
elif args.no_include_mismatch:
|
||||
include_mismatch = False
|
||||
|
||||
recheck_after_backfill = cfg.get("integrity.recheck_after_backfill", True)
|
||||
if args.recheck:
|
||||
recheck_after_backfill = True
|
||||
elif args.no_recheck:
|
||||
recheck_after_backfill = False
|
||||
|
||||
include_dimensions = cfg.get("integrity.include_dimensions", True)
|
||||
if args.include_dimensions:
|
||||
include_dimensions = True
|
||||
elif args.no_include_dimensions:
|
||||
include_dimensions = False
|
||||
|
||||
if args.mode == "window":
|
||||
if not args.window_start or not args.window_end:
|
||||
raise SystemExit("window-start and window-end are required for mode=window")
|
||||
@@ -88,78 +138,52 @@ def main() -> int:
|
||||
if not windows:
|
||||
windows = [(start_dt, end_dt)]
|
||||
|
||||
window_reports = []
|
||||
total_missing = 0
|
||||
total_errors = 0
|
||||
for idx, (seg_start, seg_end) in enumerate(windows, start=1):
|
||||
window = IntegrityWindow(
|
||||
start=seg_start,
|
||||
end=seg_end,
|
||||
label=f"segment_{idx}",
|
||||
granularity="window",
|
||||
)
|
||||
payload = run_integrity_window(
|
||||
cfg=cfg,
|
||||
window=window,
|
||||
include_dimensions=args.include_dimensions,
|
||||
task_codes=args.ods_task_codes,
|
||||
logger=logger,
|
||||
write_report=False,
|
||||
report_path=None,
|
||||
window_split_unit="none",
|
||||
window_compensation_hours=0,
|
||||
)
|
||||
window_reports.append(payload)
|
||||
total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
|
||||
total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
|
||||
|
||||
overall_start = windows[0][0]
|
||||
overall_end = windows[-1][1]
|
||||
report = {
|
||||
"mode": "window",
|
||||
"window": {
|
||||
"start": overall_start.isoformat(),
|
||||
"end": overall_end.isoformat(),
|
||||
"segments": len(windows),
|
||||
},
|
||||
"windows": window_reports,
|
||||
"api_to_ods": {
|
||||
"total_missing": total_missing,
|
||||
"total_errors": total_errors,
|
||||
},
|
||||
"total_missing": total_missing,
|
||||
"total_errors": total_errors,
|
||||
"generated_at": datetime.now(tz).isoformat(),
|
||||
}
|
||||
if report_path is None:
|
||||
root = Path(__file__).resolve().parents[1]
|
||||
stamp = datetime.now(tz).strftime("%Y%m%d_%H%M%S")
|
||||
report_path = root / "reports" / f"data_integrity_window_{stamp}.json"
|
||||
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
report["report_path"] = str(report_path)
|
||||
report, counts = run_window_flow(
|
||||
cfg=cfg,
|
||||
windows=windows,
|
||||
include_dimensions=bool(include_dimensions),
|
||||
task_codes=args.ods_task_codes,
|
||||
logger=logger,
|
||||
compare_content=compare_content,
|
||||
content_sample_limit=args.content_sample_limit,
|
||||
do_backfill=args.flow == "update_and_verify",
|
||||
include_mismatch=bool(include_mismatch),
|
||||
recheck_after_backfill=bool(recheck_after_backfill),
|
||||
page_size=int(cfg.get("api.page_size") or 200),
|
||||
chunk_size=500,
|
||||
)
|
||||
report_path = write_report(report, prefix="data_integrity_window", tz=tz, report_path=report_path)
|
||||
report["report_path"] = report_path
|
||||
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
|
||||
else:
|
||||
start_dt = _parse_dt(args.start, tz)
|
||||
if args.end:
|
||||
end_dt = _parse_dt(args.end, tz)
|
||||
else:
|
||||
end_dt = compute_last_etl_end(cfg) or datetime.now(tz)
|
||||
report = run_integrity_history(
|
||||
end_dt = None
|
||||
report, counts = run_history_flow(
|
||||
cfg=cfg,
|
||||
start_dt=start_dt,
|
||||
end_dt=end_dt,
|
||||
include_dimensions=args.include_dimensions,
|
||||
include_dimensions=bool(include_dimensions),
|
||||
task_codes=args.ods_task_codes,
|
||||
logger=logger,
|
||||
write_report=True,
|
||||
report_path=report_path,
|
||||
compare_content=compare_content,
|
||||
content_sample_limit=args.content_sample_limit,
|
||||
do_backfill=args.flow == "update_and_verify",
|
||||
include_mismatch=bool(include_mismatch),
|
||||
recheck_after_backfill=bool(recheck_after_backfill),
|
||||
page_size=int(cfg.get("api.page_size") or 200),
|
||||
chunk_size=500,
|
||||
)
|
||||
report_path = write_report(report, prefix="data_integrity_history", tz=tz, report_path=report_path)
|
||||
report["report_path"] = report_path
|
||||
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
|
||||
logger.info(
|
||||
"SUMMARY missing=%s errors=%s",
|
||||
report.get("total_missing"),
|
||||
report.get("total_errors"),
|
||||
"SUMMARY missing=%s mismatch=%s errors=%s",
|
||||
counts.get("missing"),
|
||||
counts.get("mismatch"),
|
||||
counts.get("errors"),
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
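(Editor's note: check_data_integrity.py above resolves several booleans from a config default plus paired --foo/--no-foo switches. A small, editor-added sketch of that pattern, using a hypothetical --recheck/--no-recheck pair and an assumed config default of True:)

    import argparse

    def resolve_flag(config_default: bool, enable: bool, disable: bool) -> bool:
        if enable and disable:
            raise SystemExit("cannot set both the enable and the disable switch")
        if enable:
            return True
        if disable:
            return False
        return config_default

    ap = argparse.ArgumentParser()
    ap.add_argument("--recheck", action="store_true")
    ap.add_argument("--no-recheck", action="store_true")
    args = ap.parse_args([])  # no CLI override in this example
    recheck_after_backfill = resolve_flag(True, args.recheck, args.no_recheck)
    assert recheck_after_backfill is True  # falls back to the assumed config default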
@@ -22,6 +22,7 @@ from typing import Iterable, Sequence
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from dateutil import parser as dtparser
|
||||
from psycopg2 import InterfaceError, OperationalError
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
@@ -32,8 +33,14 @@ from api.client import APIClient
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from models.parsers import TypeParser
|
||||
from tasks.ods_tasks import ENABLED_ODS_CODES, ODS_TASK_SPECS
|
||||
from tasks.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS
|
||||
from utils.logging_utils import build_log_path, configure_logging
|
||||
from utils.ods_record_utils import (
|
||||
get_value_case_insensitive,
|
||||
merge_record_layers,
|
||||
normalize_pk_value,
|
||||
pk_tuple_from_record,
|
||||
)
|
||||
from utils.windowing import split_window
|
||||
|
||||
DEFAULT_START = "2025-07-01"
|
||||
@@ -74,38 +81,7 @@ def _iter_windows(start: datetime, end: datetime, window_size: timedelta) -> Ite
|
||||
|
||||
|
||||
def _merge_record_layers(record: dict) -> dict:
|
||||
merged = record
|
||||
data_part = merged.get("data")
|
||||
while isinstance(data_part, dict):
|
||||
merged = {**data_part, **merged}
|
||||
data_part = data_part.get("data")
|
||||
settle_inner = merged.get("settleList")
|
||||
if isinstance(settle_inner, dict):
|
||||
merged = {**settle_inner, **merged}
|
||||
return merged
|
||||
|
||||
|
||||
def _get_value_case_insensitive(record: dict | None, col: str | None):
|
||||
if record is None or col is None:
|
||||
return None
|
||||
if col in record:
|
||||
return record.get(col)
|
||||
col_lower = col.lower()
|
||||
for k, v in record.items():
|
||||
if isinstance(k, str) and k.lower() == col_lower:
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_pk_value(value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return value
|
||||
return value
|
||||
return merge_record_layers(record)
|
||||
|
||||
|
||||
def _chunked(seq: Sequence, size: int) -> Iterable[Sequence]:
|
||||
@@ -133,7 +109,24 @@ def _get_table_pk_columns(conn, table: str) -> list[str]:
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, name))
|
||||
return [r[0] for r in cur.fetchall()]
|
||||
cols = [r[0] for r in cur.fetchall()]
|
||||
return [c for c in cols if c.lower() != "content_hash"]
|
||||
|
||||
|
||||
def _table_has_column(conn, table: str, column: str) -> bool:
|
||||
if "." in table:
|
||||
schema, name = table.split(".", 1)
|
||||
else:
|
||||
schema, name = "public", table
|
||||
sql = """
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s AND column_name = %s
|
||||
LIMIT 1
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, name, column))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
|
||||
def _fetch_existing_pk_set(conn, table: str, pk_cols: Sequence[str], pk_values: list[tuple], chunk_size: int) -> set[tuple]:
|
||||
@@ -155,6 +148,54 @@ def _fetch_existing_pk_set(conn, table: str, pk_cols: Sequence[str], pk_values:
|
||||
return existing
|
||||
|
||||
|
||||
def _fetch_existing_pk_hash_set(
|
||||
conn, table: str, pk_cols: Sequence[str], pk_hash_values: list[tuple], chunk_size: int
|
||||
) -> set[tuple]:
|
||||
if not pk_hash_values:
|
||||
return set()
|
||||
select_cols = ", ".join([*(f't.\"{c}\"' for c in pk_cols), 't.\"content_hash\"'])
|
||||
value_cols = ", ".join([*(f'\"{c}\"' for c in pk_cols), '\"content_hash\"'])
|
||||
join_cond = " AND ".join([*(f't.\"{c}\" = v.\"{c}\"' for c in pk_cols), 't.\"content_hash\" = v.\"content_hash\"'])
|
||||
sql = (
|
||||
f"SELECT {select_cols} FROM {table} t "
|
||||
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
|
||||
)
|
||||
existing: set[tuple] = set()
|
||||
with conn.cursor() as cur:
|
||||
for chunk in _chunked(pk_hash_values, chunk_size):
|
||||
execute_values(cur, sql, chunk, page_size=len(chunk))
|
||||
for row in cur.fetchall():
|
||||
existing.add(tuple(row))
|
||||
return existing
|
||||
|
||||
|
||||
def _init_db_state(cfg: AppConfig) -> dict:
|
||||
db_conn = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
|
||||
try:
|
||||
db_conn.conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
db_conn.conn.autocommit = True
|
||||
return {"db": db_conn, "conn": db_conn.conn}
|
||||
|
||||
|
||||
def _reconnect_db(db_state: dict, cfg: AppConfig, logger: logging.Logger):
|
||||
try:
|
||||
db_state.get("db").close()
|
||||
except Exception:
|
||||
pass
|
||||
db_state.update(_init_db_state(cfg))
|
||||
logger.warning("DB connection reset/reconnected")
|
||||
return db_state["conn"]
|
||||
|
||||
|
||||
def _ensure_db_conn(db_state: dict, cfg: AppConfig, logger: logging.Logger):
|
||||
conn = db_state.get("conn")
|
||||
if conn is None or getattr(conn, "closed", 0):
|
||||
return _reconnect_db(db_state, cfg, logger)
|
||||
return conn
|
||||
|
||||
|
||||
def _merge_common_params(cfg: AppConfig, task_code: str, base: dict) -> dict:
|
||||
merged: dict = {}
|
||||
common = cfg.get("api.params", {}) or {}
|
||||
@@ -182,19 +223,22 @@ def _build_params(cfg: AppConfig, spec, store_id: int, window_start: datetime |
|
||||
return _merge_common_params(cfg, spec.code, base)
|
||||
|
||||
|
||||
def _pk_tuple_from_record(record: dict, pk_cols: Sequence[str]) -> tuple | None:
|
||||
merged = _merge_record_layers(record)
|
||||
def _pk_tuple_from_merged(merged: dict, pk_cols: Sequence[str]) -> tuple | None:
|
||||
values = []
|
||||
for col in pk_cols:
|
||||
val = _normalize_pk_value(_get_value_case_insensitive(merged, col))
|
||||
val = normalize_pk_value(get_value_case_insensitive(merged, col))
|
||||
if val is None or val == "":
|
||||
return None
|
||||
values.append(val)
|
||||
return tuple(values)
|
||||
|
||||
|
||||
def _pk_tuple_from_record(record: dict, pk_cols: Sequence[str]) -> tuple | None:
|
||||
return pk_tuple_from_record(record, pk_cols)
|
||||
|
||||
|
||||
def _pk_tuple_from_ticket_candidate(value) -> tuple | None:
|
||||
val = _normalize_pk_value(value)
|
||||
val = normalize_pk_value(value)
|
||||
if val is None or val == "":
|
||||
return None
|
||||
return (val,)
|
||||
@@ -204,10 +248,17 @@ def _format_missing_sample(pk_cols: Sequence[str], pk_tuple: tuple) -> dict:
|
||||
return {col: pk_tuple[idx] for idx, col in enumerate(pk_cols)}
|
||||
|
||||
|
||||
def _format_mismatch_sample(pk_cols: Sequence[str], pk_tuple: tuple, content_hash: str | None) -> dict:
|
||||
sample = _format_missing_sample(pk_cols, pk_tuple)
|
||||
if content_hash:
|
||||
sample["content_hash"] = content_hash
|
||||
return sample
|
||||
|
||||
|
||||
def _check_spec(
|
||||
*,
|
||||
client: APIClient,
|
||||
db_conn,
|
||||
db_state: dict,
|
||||
cfg: AppConfig,
|
||||
tz: ZoneInfo,
|
||||
logger: logging.Logger,
|
||||
@@ -219,6 +270,8 @@ def _check_spec(
|
||||
page_size: int,
|
||||
chunk_size: int,
|
||||
sample_limit: int,
|
||||
compare_content: bool,
|
||||
content_sample_limit: int,
|
||||
sleep_per_window: float,
|
||||
sleep_per_page: float,
|
||||
) -> dict:
|
||||
@@ -231,19 +284,34 @@ def _check_spec(
|
||||
"records_with_pk": 0,
|
||||
"missing": 0,
|
||||
"missing_samples": [],
|
||||
"mismatch": 0,
|
||||
"mismatch_samples": [],
|
||||
"pages": 0,
|
||||
"skipped_missing_pk": 0,
|
||||
"errors": 0,
|
||||
"error_detail": None,
|
||||
}
|
||||
|
||||
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
|
||||
db_conn = _ensure_db_conn(db_state, cfg, logger)
|
||||
try:
|
||||
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
|
||||
except (OperationalError, InterfaceError):
|
||||
db_conn = _reconnect_db(db_state, cfg, logger)
|
||||
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
|
||||
result["pk_columns"] = pk_cols
|
||||
if not pk_cols:
|
||||
result["errors"] = 1
|
||||
result["error_detail"] = "no primary key columns found"
|
||||
return result
|
||||
|
||||
try:
|
||||
has_content_hash = bool(compare_content and _table_has_column(db_conn, spec.table_name, "content_hash"))
|
||||
except (OperationalError, InterfaceError):
|
||||
db_conn = _reconnect_db(db_state, cfg, logger)
|
||||
has_content_hash = bool(compare_content and _table_has_column(db_conn, spec.table_name, "content_hash"))
|
||||
result["compare_content"] = bool(compare_content)
|
||||
result["content_hash_supported"] = has_content_hash
|
||||
|
||||
if spec.requires_window and spec.time_fields:
|
||||
if not start or not end:
|
||||
result["errors"] = 1
|
||||
@@ -293,24 +361,33 @@ def _check_spec(
|
||||
result["pages"] += 1
|
||||
result["records"] += len(records)
|
||||
pk_tuples: list[tuple] = []
|
||||
pk_hash_tuples: list[tuple] = []
|
||||
for rec in records:
|
||||
if not isinstance(rec, dict):
|
||||
result["skipped_missing_pk"] += 1
|
||||
window_skipped += 1
|
||||
continue
|
||||
pk_tuple = _pk_tuple_from_record(rec, pk_cols)
|
||||
merged = _merge_record_layers(rec)
|
||||
pk_tuple = _pk_tuple_from_merged(merged, pk_cols)
|
||||
if not pk_tuple:
|
||||
result["skipped_missing_pk"] += 1
|
||||
window_skipped += 1
|
||||
continue
|
||||
pk_tuples.append(pk_tuple)
|
||||
if has_content_hash:
|
||||
content_hash = BaseOdsTask._compute_content_hash(merged, include_fetched_at=False)
|
||||
pk_hash_tuples.append((*pk_tuple, content_hash))
|
||||
|
||||
if not pk_tuples:
|
||||
continue
|
||||
|
||||
result["records_with_pk"] += len(pk_tuples)
|
||||
pk_unique = list(dict.fromkeys(pk_tuples))
|
||||
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
|
||||
try:
|
||||
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
|
||||
except (OperationalError, InterfaceError):
|
||||
db_conn = _reconnect_db(db_state, cfg, logger)
|
||||
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
|
||||
for pk_tuple in pk_unique:
|
||||
if pk_tuple in existing:
|
||||
continue
|
||||
@@ -321,6 +398,29 @@ def _check_spec(
|
||||
window_missing += 1
|
||||
if len(result["missing_samples"]) < sample_limit:
|
||||
result["missing_samples"].append(_format_missing_sample(pk_cols, pk_tuple))
|
||||
|
||||
if has_content_hash and pk_hash_tuples:
|
||||
pk_hash_unique = list(dict.fromkeys(pk_hash_tuples))
|
||||
try:
|
||||
existing_hash = _fetch_existing_pk_hash_set(
|
||||
db_conn, spec.table_name, pk_cols, pk_hash_unique, chunk_size
|
||||
)
|
||||
except (OperationalError, InterfaceError):
|
||||
db_conn = _reconnect_db(db_state, cfg, logger)
|
||||
existing_hash = _fetch_existing_pk_hash_set(
|
||||
db_conn, spec.table_name, pk_cols, pk_hash_unique, chunk_size
|
||||
)
|
||||
for pk_hash_tuple in pk_hash_unique:
|
||||
pk_tuple = pk_hash_tuple[:-1]
|
||||
if pk_tuple not in existing:
|
||||
continue
|
||||
if pk_hash_tuple in existing_hash:
|
||||
continue
|
||||
result["mismatch"] += 1
|
||||
if len(result["mismatch_samples"]) < content_sample_limit:
|
||||
result["mismatch_samples"].append(
|
||||
_format_mismatch_sample(pk_cols, pk_tuple, pk_hash_tuple[-1])
|
||||
)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"PAGE task=%s idx=%s page=%s records=%s missing=%s skipped=%s",
|
||||
@@ -369,7 +469,7 @@ def _check_spec(
|
||||
def _check_settlement_tickets(
|
||||
*,
|
||||
client: APIClient,
|
||||
db_conn,
|
||||
db_state: dict,
|
||||
cfg: AppConfig,
|
||||
tz: ZoneInfo,
|
||||
logger: logging.Logger,
|
||||
@@ -380,11 +480,18 @@ def _check_settlement_tickets(
|
||||
page_size: int,
|
||||
chunk_size: int,
|
||||
sample_limit: int,
|
||||
compare_content: bool,
|
||||
content_sample_limit: int,
|
||||
sleep_per_window: float,
|
||||
sleep_per_page: float,
|
||||
) -> dict:
|
||||
table_name = "billiards_ods.settlement_ticket_details"
|
||||
pk_cols = _get_table_pk_columns(db_conn, table_name)
|
||||
db_conn = _ensure_db_conn(db_state, cfg, logger)
|
||||
try:
|
||||
pk_cols = _get_table_pk_columns(db_conn, table_name)
|
||||
except (OperationalError, InterfaceError):
|
||||
db_conn = _reconnect_db(db_state, cfg, logger)
|
||||
pk_cols = _get_table_pk_columns(db_conn, table_name)
|
||||
result = {
|
||||
"task_code": "ODS_SETTLEMENT_TICKET",
|
||||
"table": table_name,
|
||||
@@ -394,6 +501,8 @@ def _check_settlement_tickets(
|
||||
"records_with_pk": 0,
|
||||
"missing": 0,
|
||||
"missing_samples": [],
|
||||
"mismatch": 0,
|
||||
"mismatch_samples": [],
|
||||
"pages": 0,
|
||||
"skipped_missing_pk": 0,
|
||||
"errors": 0,
|
||||
@@ -476,7 +585,11 @@ def _check_settlement_tickets(
|
||||
|
||||
result["records_with_pk"] += len(pk_tuples)
|
||||
pk_unique = list(dict.fromkeys(pk_tuples))
|
||||
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
|
||||
try:
|
||||
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
|
||||
except (OperationalError, InterfaceError):
|
||||
db_conn = _reconnect_db(db_state, cfg, logger)
|
||||
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
|
||||
for pk_tuple in pk_unique:
|
||||
if pk_tuple in existing:
|
||||
continue
|
||||
@@ -585,6 +698,8 @@ def run_gap_check(
|
||||
cutoff_overlap_hours: int,
|
||||
allow_small_window: bool,
|
||||
logger: logging.Logger,
|
||||
compare_content: bool = False,
|
||||
content_sample_limit: int | None = None,
|
||||
window_split_unit: str | None = None,
|
||||
window_compensation_hours: int | None = None,
|
||||
) -> dict:
|
||||
@@ -668,6 +783,9 @@ def run_gap_check(
|
||||
if windows:
|
||||
start, end = windows[0][0], windows[-1][1]
|
||||
|
||||
if content_sample_limit is None:
|
||||
content_sample_limit = sample_limit
|
||||
|
||||
logger.info(
|
||||
"START range=%s~%s window_days=%s window_hours=%s split_unit=%s comp_hours=%s page_size=%s chunk_size=%s",
|
||||
start.isoformat() if isinstance(start, datetime) else None,
|
||||
@@ -690,12 +808,7 @@ def run_gap_check(
|
||||
headers_extra=cfg["api"].get("headers_extra") or {},
|
||||
)
|
||||
|
||||
db_conn = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
|
||||
try:
|
||||
db_conn.conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
db_conn.conn.autocommit = True
|
||||
db_state = _init_db_state(cfg)
|
||||
try:
|
||||
task_filter = {t.strip().upper() for t in (task_codes or "").split(",") if t.strip()}
|
||||
specs = [s for s in ODS_TASK_SPECS if s.code in ENABLED_ODS_CODES]
|
||||
@@ -708,7 +821,7 @@ def run_gap_check(
|
||||
continue
|
||||
result = _check_spec(
|
||||
client=client,
|
||||
db_conn=db_conn.conn,
|
||||
db_state=db_state,
|
||||
cfg=cfg,
|
||||
tz=tz,
|
||||
logger=logger,
|
||||
@@ -720,6 +833,8 @@ def run_gap_check(
|
||||
page_size=page_size,
|
||||
chunk_size=chunk_size,
|
||||
sample_limit=sample_limit,
|
||||
compare_content=compare_content,
|
||||
content_sample_limit=content_sample_limit,
|
||||
sleep_per_window=sleep_per_window,
|
||||
sleep_per_page=sleep_per_page,
|
||||
)
|
||||
@@ -735,7 +850,7 @@ def run_gap_check(
|
||||
if (not task_filter) or ("ODS_SETTLEMENT_TICKET" in task_filter):
|
||||
ticket_result = _check_settlement_tickets(
|
||||
client=client,
|
||||
db_conn=db_conn.conn,
|
||||
db_state=db_state,
|
||||
cfg=cfg,
|
||||
tz=tz,
|
||||
logger=logger,
|
||||
@@ -746,6 +861,8 @@ def run_gap_check(
|
||||
page_size=page_size,
|
||||
chunk_size=chunk_size,
|
||||
sample_limit=sample_limit,
|
||||
compare_content=compare_content,
|
||||
content_sample_limit=content_sample_limit,
|
||||
sleep_per_window=sleep_per_window,
|
||||
sleep_per_page=sleep_per_page,
|
||||
)
|
||||
@@ -759,6 +876,7 @@ def run_gap_check(
|
||||
)
|
||||
|
||||
total_missing = sum(int(r.get("missing") or 0) for r in results)
|
||||
total_mismatch = sum(int(r.get("mismatch") or 0) for r in results)
|
||||
total_errors = sum(int(r.get("errors") or 0) for r in results)
|
||||
|
||||
payload = {
|
||||
@@ -772,16 +890,22 @@ def run_gap_check(
|
||||
"page_size": page_size,
|
||||
"chunk_size": chunk_size,
|
||||
"sample_limit": sample_limit,
|
||||
"compare_content": compare_content,
|
||||
"content_sample_limit": content_sample_limit,
|
||||
"store_id": store_id,
|
||||
"base_url": cfg.get("api.base_url"),
|
||||
"results": results,
|
||||
"total_missing": total_missing,
|
||||
"total_mismatch": total_mismatch,
|
||||
"total_errors": total_errors,
|
||||
"generated_at": datetime.now(tz).isoformat(),
|
||||
}
|
||||
return payload
|
||||
finally:
|
||||
db_conn.close()
|
||||
try:
|
||||
db_state.get("db").close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def main() -> int:
|
||||
@@ -796,6 +920,13 @@ def main() -> int:
|
||||
ap.add_argument("--page-size", type=int, default=200, help="API page size (default: 200)")
|
||||
ap.add_argument("--chunk-size", type=int, default=500, help="DB query chunk size (default: 500)")
|
||||
ap.add_argument("--sample-limit", type=int, default=50, help="max missing PK samples per table")
|
||||
ap.add_argument("--compare-content", action="store_true", help="compare record content hash (mismatch detection)")
|
||||
ap.add_argument(
|
||||
"--content-sample-limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help="max mismatch samples per table (default: same as --sample-limit)",
|
||||
)
|
||||
ap.add_argument("--sleep-per-window-seconds", type=float, default=0, help="sleep seconds after each window")
|
||||
ap.add_argument("--sleep-per-page-seconds", type=float, default=0, help="sleep seconds after each page")
|
||||
ap.add_argument("--task-codes", default="", help="comma-separated task codes to check (optional)")
|
||||
@@ -847,6 +978,8 @@ def main() -> int:
|
||||
cutoff_overlap_hours=args.cutoff_overlap_hours,
|
||||
allow_small_window=args.allow_small_window,
|
||||
logger=logger,
|
||||
compare_content=args.compare_content,
|
||||
content_sample_limit=args.content_sample_limit,
|
||||
window_split_unit=args.window_split_unit or None,
|
||||
window_compensation_hours=args.window_compensation_hours,
|
||||
)
|
||||
@@ -862,8 +995,9 @@ def main() -> int:
|
||||
out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
logger.info("REPORT_WRITTEN path=%s", out_path)
|
||||
logger.info(
|
||||
"SUMMARY missing=%s errors=%s",
|
||||
"SUMMARY missing=%s mismatch=%s errors=%s",
|
||||
payload.get("total_missing"),
|
||||
payload.get("total_mismatch"),
|
||||
payload.get("total_errors"),
|
||||
)
|
||||
|
||||
|
||||
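(Editor's note: a compact, editor-added illustration of the missing/mismatch decision made in _check_spec above — a record is "missing" when its business key is absent from ODS, and a "mismatch" when the key exists but no stored snapshot carries the same content_hash. The keys and hashes below are made up.)

    def classify(pk_tuple, content_hash, existing_pks, existing_pk_hashes):
        if pk_tuple not in existing_pks:
            return "missing"      # key absent from ODS entirely
        if (*pk_tuple, content_hash) in existing_pk_hashes:
            return "ok"           # an identical snapshot is already stored
        return "mismatch"         # key present, but content differs

    existing_pks = {(1001,), (1002,)}
    existing_pk_hashes = {(1001, "abc"), (1002, "def")}
    assert classify((1003,), "zzz", existing_pks, existing_pk_hashes) == "missing"
    assert classify((1001,), "abc", existing_pks, existing_pk_hashes) == "ok"
    assert classify((1002,), "xyz", existing_pks, existing_pk_hashes) == "mismatch"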
324  etl_billiards/scripts/migrate_snapshot_ods.py  Normal file
@@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
"""
Migrate to a "snapshot-style ODS + DWD SCD2" model:
1) Add content_hash to every ODS table and use (business key, content_hash) as the new primary key;
2) Compute content_hash from the payload to avoid duplicate snapshots;
3) Backfill the SCD2 fields on every DWD dimension table and change the primary key to (business key, scd2_start_time).

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.migrate_snapshot_ods --dsn "postgresql://..."

Optional arguments:
    --only-ods / --only-dwd
    --dry-run
    --batch-size 500
"""
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Iterable, List, Sequence
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import execute_batch, RealDictCursor
|
||||
|
||||
|
||||
def _hash_default(value):
|
||||
return value.isoformat() if hasattr(value, "isoformat") else str(value)
|
||||
|
||||
|
||||
def _sanitize_record_for_hash(record: Any) -> Any:
|
||||
exclude = {
|
||||
"data",
|
||||
"payload",
|
||||
"source_file",
|
||||
"source_endpoint",
|
||||
"fetched_at",
|
||||
"content_hash",
|
||||
"record_index",
|
||||
}
|
||||
|
||||
def _strip(value):
|
||||
if isinstance(value, dict):
|
||||
cleaned = {}
|
||||
for k, v in value.items():
|
||||
if isinstance(k, str) and k.lower() in exclude:
|
||||
continue
|
||||
cleaned[k] = _strip(v)
|
||||
return cleaned
|
||||
if isinstance(value, list):
|
||||
return [_strip(v) for v in value]
|
||||
return value
|
||||
|
||||
return _strip(record or {})
|
||||
|
||||
|
||||
def _compute_content_hash(record: Any) -> str:
|
||||
cleaned = _sanitize_record_for_hash(record)
|
||||
payload = json.dumps(
|
||||
cleaned,
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
default=_hash_default,
|
||||
)
|
||||
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _fetch_tables(cur, schema: str) -> List[str]:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = %s AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_name
|
||||
""",
|
||||
(schema,),
|
||||
)
|
||||
return [r[0] for r in cur.fetchall()]
|
||||
|
||||
|
||||
def _fetch_columns(cur, schema: str, table: str) -> List[str]:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
""",
|
||||
(schema, table),
|
||||
)
|
||||
cols = []
|
||||
for row in cur.fetchall():
|
||||
if isinstance(row, dict):
|
||||
cols.append(row.get("column_name"))
|
||||
else:
|
||||
cols.append(row[0])
|
||||
return [c for c in cols if c]
|
||||
|
||||
|
||||
def _fetch_pk_constraint(cur, schema: str, table: str) -> tuple[str | None, list[str]]:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT tc.constraint_name, kcu.column_name, kcu.ordinal_position
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
WHERE tc.constraint_type = 'PRIMARY KEY'
|
||||
AND tc.table_schema = %s
|
||||
AND tc.table_name = %s
|
||||
ORDER BY kcu.ordinal_position
|
||||
""",
|
||||
(schema, table),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
if not rows:
|
||||
return None, []
|
||||
if isinstance(rows[0], dict):
|
||||
name = rows[0].get("constraint_name")
|
||||
cols = [r.get("column_name") for r in rows]
|
||||
else:
|
||||
name = rows[0][0]
|
||||
cols = [r[1] for r in rows]
|
||||
return name, [c for c in cols if c]
|
||||
|
||||
|
||||
def _ensure_content_hash_column(cur, schema: str, table: str, dry_run: bool) -> None:
|
||||
cols = _fetch_columns(cur, schema, table)
|
||||
if any(c.lower() == "content_hash" for c in cols):
|
||||
return
|
||||
sql = f'ALTER TABLE "{schema}"."{table}" ADD COLUMN content_hash TEXT'
|
||||
if dry_run:
|
||||
print(f"[DRY] {sql}")
|
||||
return
|
||||
print(f"[ODS] 添加 content_hash: {schema}.{table}")
|
||||
cur.execute(sql)
|
||||
|
||||
|
||||
def _backfill_content_hash(conn, schema: str, table: str, batch_size: int, dry_run: bool) -> int:
|
||||
updated = 0
|
||||
with conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
cols = _fetch_columns(cur, schema, table)
|
||||
if "content_hash" not in [c.lower() for c in cols]:
|
||||
return 0
|
||||
pk_name, pk_cols = _fetch_pk_constraint(cur, schema, table)
|
||||
if not pk_cols:
|
||||
return 0
|
||||
# 过滤 content_hash
|
||||
pk_cols = [c for c in pk_cols if c.lower() != "content_hash"]
|
||||
select_cols = [*pk_cols]
|
||||
if any(c.lower() == "payload" for c in cols):
|
||||
select_cols.append("payload")
|
||||
else:
|
||||
select_cols.extend([c for c in cols if c.lower() not in {"content_hash"}])
|
||||
select_cols_sql = ", ".join(f'"{c}"' for c in select_cols)
|
||||
sql = f'SELECT {select_cols_sql} FROM "{schema}"."{table}" WHERE content_hash IS NULL'
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
def build_row(row: dict) -> tuple:
|
||||
payload = row.get("payload")
|
||||
if payload is None:
|
||||
payload = {k: v for k, v in row.items() if k.lower() not in {"content_hash", "payload"}}
|
||||
content_hash = _compute_content_hash(payload)
|
||||
key_vals = [row.get(k) for k in pk_cols]
|
||||
return (content_hash, *key_vals)
|
||||
|
||||
updates = [build_row(r) for r in rows]
|
||||
if dry_run:
|
||||
print(f"[DRY] {schema}.{table}: 预计更新 {len(updates)} 行 content_hash")
|
||||
return len(updates)
|
||||
|
||||
where_clause = " AND ".join([f'"{c}" = %s' for c in pk_cols])
|
||||
update_sql = (
|
||||
f'UPDATE "{schema}"."{table}" SET content_hash = %s '
|
||||
f'WHERE {where_clause} AND content_hash IS NULL'
|
||||
)
|
||||
with conn.cursor() as cur2:
|
||||
execute_batch(cur2, update_sql, updates, page_size=batch_size)
|
||||
updated = cur2.rowcount or len(updates)
|
||||
print(f"[ODS] {schema}.{table}: 更新 content_hash {updated} 行")
|
||||
return updated
|
||||
|
||||
|
||||
def _ensure_ods_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
|
||||
name, pk_cols = _fetch_pk_constraint(cur, schema, table)
|
||||
if not pk_cols:
|
||||
return
|
||||
if any(c.lower() == "content_hash" for c in pk_cols):
|
||||
return
|
||||
new_pk = pk_cols + ["content_hash"]
|
||||
drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
|
||||
cols_sql = ", ".join([f'"{c}"' for c in new_pk])
|
||||
add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
|
||||
if dry_run:
|
||||
print(f"[DRY] {drop_sql}")
|
||||
print(f"[DRY] {add_sql}")
|
||||
return
|
||||
print(f"[ODS] 变更主键: {schema}.{table} -> ({', '.join(new_pk)})")
|
||||
cur.execute(drop_sql)
|
||||
cur.execute(add_sql)
|
||||
|
||||
|
||||
def _migrate_ods(conn, schema: str, batch_size: int, dry_run: bool) -> None:
|
||||
with conn.cursor() as cur:
|
||||
tables = _fetch_tables(cur, schema)
|
||||
for table in tables:
|
||||
with conn.cursor() as cur:
|
||||
_ensure_content_hash_column(cur, schema, table, dry_run)
|
||||
conn.commit()
|
||||
_backfill_content_hash(conn, schema, table, batch_size, dry_run)
|
||||
with conn.cursor() as cur:
|
||||
_ensure_ods_primary_key(cur, schema, table, dry_run)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _backfill_scd2_fields(cur, schema: str, table: str, columns: Sequence[str], dry_run: bool) -> None:
|
||||
lower = {c.lower() for c in columns}
|
||||
fallback_cols = [
|
||||
"updated_at",
|
||||
"update_time",
|
||||
"created_at",
|
||||
"create_time",
|
||||
"fetched_at",
|
||||
]
|
||||
fallback = None
|
||||
for col in fallback_cols:
|
||||
if col in lower:
|
||||
fallback = f'"{col}"'
|
||||
break
|
||||
if fallback is None:
|
||||
fallback = "now()"
|
||||
|
||||
sql = (
|
||||
f'UPDATE "{schema}"."{table}" '
|
||||
f'SET scd2_start_time = COALESCE(scd2_start_time, {fallback}), '
|
||||
f"scd2_end_time = COALESCE(scd2_end_time, TIMESTAMPTZ '9999-12-31'), "
|
||||
f"scd2_is_current = COALESCE(scd2_is_current, 1), "
|
||||
f"scd2_version = COALESCE(scd2_version, 1) "
|
||||
f"WHERE scd2_start_time IS NULL OR scd2_end_time IS NULL OR scd2_is_current IS NULL OR scd2_version IS NULL"
|
||||
)
|
||||
if dry_run:
|
||||
print(f"[DRY] {sql}")
|
||||
return
|
||||
cur.execute(sql)
|
||||
|
||||
|
||||
def _ensure_dwd_primary_key(cur, schema: str, table: str, dry_run: bool) -> None:
|
||||
name, pk_cols = _fetch_pk_constraint(cur, schema, table)
|
||||
if not pk_cols:
|
||||
return
|
||||
if any(c.lower() == "scd2_start_time" for c in pk_cols):
|
||||
return
|
||||
new_pk = pk_cols + ["scd2_start_time"]
|
||||
drop_sql = f'ALTER TABLE "{schema}"."{table}" DROP CONSTRAINT "{name}"'
|
||||
cols_sql = ", ".join([f'"{c}"' for c in new_pk])
|
||||
add_sql = f'ALTER TABLE "{schema}"."{table}" ADD PRIMARY KEY ({cols_sql})'
|
||||
if dry_run:
|
||||
print(f"[DRY] {drop_sql}")
|
||||
print(f"[DRY] {add_sql}")
|
||||
return
|
||||
print(f"[DWD] 变更主键: {schema}.{table} -> ({', '.join(new_pk)})")
|
||||
cur.execute(drop_sql)
|
||||
cur.execute(add_sql)
|
||||
|
||||
|
||||
def _migrate_dwd(conn, schema: str, dry_run: bool) -> None:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT table_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND column_name ILIKE 'scd2_start_time'
|
||||
ORDER BY table_name
|
||||
""",
|
||||
(schema,),
|
||||
)
|
||||
tables = [r[0] for r in cur.fetchall()]
|
||||
|
||||
for table in tables:
|
||||
with conn.cursor() as cur:
|
||||
cols = _fetch_columns(cur, schema, table)
|
||||
_backfill_scd2_fields(cur, schema, table, cols, dry_run)
|
||||
conn.commit()
|
||||
with conn.cursor() as cur:
|
||||
_ensure_dwd_primary_key(cur, schema, table, dry_run)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="迁移 ODS 快照 + DWD SCD2")
|
||||
parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN(也可用环境变量 PG_DSN)")
|
||||
parser.add_argument("--schema-ods", dest="schema_ods", default="billiards_ods")
|
||||
parser.add_argument("--schema-dwd", dest="schema_dwd", default="billiards_dwd")
|
||||
parser.add_argument("--batch-size", dest="batch_size", type=int, default=500)
|
||||
parser.add_argument("--only-ods", dest="only_ods", action="store_true")
|
||||
parser.add_argument("--only-dwd", dest="only_dwd", action="store_true")
|
||||
parser.add_argument("--dry-run", dest="dry_run", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
dsn = args.dsn or os.environ.get("PG_DSN")
|
||||
if not dsn:
|
||||
print("缺少 DSN(--dsn 或环境变量 PG_DSN)")
|
||||
return 2
|
||||
|
||||
conn = psycopg2.connect(dsn)
|
||||
conn.autocommit = False
|
||||
try:
|
||||
if not args.only_dwd:
|
||||
_migrate_ods(conn, args.schema_ods, args.batch_size, args.dry_run)
|
||||
if not args.only_ods:
|
||||
_migrate_dwd(conn, args.schema_dwd, args.dry_run)
|
||||
return 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
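(Editor's note: the migration computes content_hash over a record with the bookkeeping fields stripped, so re-fetching an unchanged business record yields the same hash and no new snapshot. Below is a simplified, editor-added re-implementation — non-recursive, unlike _sanitize_record_for_hash above — with made-up data.)

    import hashlib
    import json

    EXCLUDE = {"data", "payload", "source_file", "source_endpoint",
               "fetched_at", "content_hash", "record_index"}

    def content_hash(record: dict) -> str:
        cleaned = {k: v for k, v in record.items()
                   if not (isinstance(k, str) and k.lower() in EXCLUDE)}
        payload = json.dumps(cleaned, ensure_ascii=False, sort_keys=True,
                             separators=(",", ":"), default=str)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    a = {"order_id": 1001, "amount": "12.50", "fetched_at": "2025-07-01T00:00:00"}
    b = {"order_id": 1001, "amount": "12.50", "fetched_at": "2025-08-15T09:30:00"}
    assert content_hash(a) == content_hash(b)  # only an excluded field differs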