Database: update data validation and write logic.

This commit is contained in:
Neo
2026-02-01 03:46:16 +08:00
parent 9948000b71
commit 076f5755ca
128 changed files with 494310 additions and 2819 deletions

View File

@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""Data integrity task that checks API -> ODS -> DWD completeness."""
from __future__ import annotations
@@ -7,16 +7,9 @@ from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
import json
from pathlib import Path
from utils.windowing import build_window_segments, calc_window_minutes
from .base_task import BaseTask
from quality.integrity_checker import (
IntegrityWindow,
compute_last_etl_end,
run_integrity_history,
run_integrity_window,
)
from quality.integrity_service import run_history_flow, run_window_flow, write_report
class DataIntegrityTask(BaseTask):
@@ -31,15 +24,25 @@ class DataIntegrityTask(BaseTask):
include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
auto_backfill = bool(self.config.get("integrity.auto_backfill", False))
compare_content = self.config.get("integrity.compare_content")
if compare_content is None:
compare_content = True
content_sample_limit = self.config.get("integrity.content_sample_limit")
backfill_mismatch = self.config.get("integrity.backfill_mismatch")
if backfill_mismatch is None:
backfill_mismatch = True
recheck_after_backfill = self.config.get("integrity.recheck_after_backfill")
if recheck_after_backfill is None:
recheck_after_backfill = True
# Check whether time-window parameters (window_override) were passed via the CLI
# If so, switch to window mode automatically
# Switch to window mode when CLI override is provided.
window_override_start = self.config.get("run.window_override.start")
window_override_end = self.config.get("run.window_override.end")
if window_override_start or window_override_end:
self.logger.info(
"检测到 CLI 时间窗口参数,自动切换到 window 模式: %s ~ %s",
window_override_start, window_override_end
"Detected CLI window override. Switching to window mode: %s ~ %s",
window_override_start,
window_override_end,
)
mode = "window"
@@ -57,65 +60,28 @@ class DataIntegrityTask(BaseTask):
total_segments = len(segments)
if total_segments > 1:
self.logger.info("数据完整性检查: 分段执行 共%s", total_segments)
self.logger.info("Data integrity check split into %s segments.", total_segments)
window_reports = []
total_missing = 0
total_errors = 0
for idx, (seg_start, seg_end) in enumerate(segments, start=1):
window = IntegrityWindow(
start=seg_start,
end=seg_end,
label=f"segment_{idx}",
granularity="window",
)
payload = run_integrity_window(
cfg=self.config,
window=window,
include_dimensions=include_dimensions,
task_codes=task_codes,
logger=self.logger,
write_report=False,
window_split_unit="none",
window_compensation_hours=0,
)
window_reports.append(payload)
total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
report, counts = run_window_flow(
cfg=self.config,
windows=segments,
include_dimensions=include_dimensions,
task_codes=task_codes,
logger=self.logger,
compare_content=bool(compare_content),
content_sample_limit=content_sample_limit,
do_backfill=bool(auto_backfill),
include_mismatch=bool(backfill_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(self.config.get("api.page_size") or 200),
chunk_size=500,
)
overall_start = segments[0][0]
overall_end = segments[-1][1]
report = {
"mode": "window",
"window": {
"start": overall_start.isoformat(),
"end": overall_end.isoformat(),
"segments": total_segments,
},
"windows": window_reports,
"api_to_ods": {
"total_missing": total_missing,
"total_errors": total_errors,
},
"total_missing": total_missing,
"total_errors": total_errors,
"generated_at": datetime.now(tz).isoformat(),
}
report_path = self._write_report(report, "data_integrity_window")
report_path = write_report(report, prefix="data_integrity_window", tz=tz)
report["report_path"] = report_path
missing_count = int(total_missing or 0)
counts = {
"missing": missing_count,
"errors": int(total_errors or 0),
}
# Auto backfill
backfill_result = None
if auto_backfill and missing_count > 0:
backfill_result = self._run_backfill(base_start, base_end, task_codes)
counts["backfilled"] = backfill_result.get("backfilled", 0)
return {
"status": "SUCCESS",
"counts": counts,
@@ -125,7 +91,7 @@ class DataIntegrityTask(BaseTask):
"minutes": calc_window_minutes(overall_start, overall_end),
},
"report_path": report_path,
"backfill_result": backfill_result,
"backfill_result": report.get("backfill_result"),
}
history_start = str(self.config.get("integrity.history_start", "2025-07-01") or "2025-07-01")
@@ -136,77 +102,52 @@ class DataIntegrityTask(BaseTask):
else:
start_dt = start_dt.astimezone(tz)
end_dt = None
if history_end:
end_dt = dtparser.parse(history_end)
if end_dt.tzinfo is None:
end_dt = end_dt.replace(tzinfo=tz)
else:
end_dt = end_dt.astimezone(tz)
else:
end_dt = compute_last_etl_end(self.config) or datetime.now(tz)
report = run_integrity_history(
report, counts = run_history_flow(
cfg=self.config,
start_dt=start_dt,
end_dt=end_dt,
include_dimensions=include_dimensions,
task_codes=task_codes,
logger=self.logger,
write_report=True,
compare_content=bool(compare_content),
content_sample_limit=content_sample_limit,
do_backfill=bool(auto_backfill),
include_mismatch=bool(backfill_mismatch),
recheck_after_backfill=bool(recheck_after_backfill),
page_size=int(self.config.get("api.page_size") or 200),
chunk_size=500,
)
missing_count = int(report.get("total_missing") or 0)
counts = {
"missing": missing_count,
"errors": int(report.get("total_errors") or 0),
}
# Auto backfill
backfill_result = None
if auto_backfill and missing_count > 0:
backfill_result = self._run_backfill(start_dt, end_dt, task_codes)
counts["backfilled"] = backfill_result.get("backfilled", 0)
report_path = write_report(report, prefix="data_integrity_history", tz=tz)
report["report_path"] = report_path
end_dt_used = end_dt
if end_dt_used is None:
end_str = report.get("end")
if end_str:
parsed = dtparser.parse(end_str)
if parsed.tzinfo is None:
end_dt_used = parsed.replace(tzinfo=tz)
else:
end_dt_used = parsed.astimezone(tz)
if end_dt_used is None:
end_dt_used = start_dt
return {
"status": "SUCCESS",
"counts": counts,
"window": {
"start": start_dt,
"end": end_dt,
"minutes": int((end_dt - start_dt).total_seconds() // 60) if end_dt > start_dt else 0,
"end": end_dt_used,
"minutes": int((end_dt_used - start_dt).total_seconds() // 60) if end_dt_used > start_dt else 0,
},
"report_path": report.get("report_path"),
"backfill_result": backfill_result,
"report_path": report_path,
"backfill_result": report.get("backfill_result"),
}
def _write_report(self, report: dict, prefix: str) -> str:
root = Path(__file__).resolve().parents[1]
stamp = datetime.now(self.tz).strftime("%Y%m%d_%H%M%S")
path = root / "reports" / f"{prefix}_{stamp}.json"
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
return str(path)
def _run_backfill(self, start_dt: datetime, end_dt: datetime, task_codes: str) -> dict:
"""运行数据补全"""
self.logger.info("自动补全开始 起始=%s 结束=%s", start_dt, end_dt)
try:
from scripts.backfill_missing_data import run_backfill
result = run_backfill(
cfg=self.config,
start=start_dt,
end=end_dt,
task_codes=task_codes or None,
dry_run=False,
page_size=200,
chunk_size=500,
logger=self.logger,
)
self.logger.info(
"自动补全完成 已补全=%s 错误数=%s",
result.get("backfilled", 0),
result.get("errors", 0),
)
return result
except Exception as exc:
self.logger.exception("自动补全失败")
return {"backfilled": 0, "errors": 1, "error": str(exc)}

View File

@@ -2,8 +2,10 @@
"""DWD 装载任务:从 ODS 增量写入 DWD维度 SCD2事实按时间增量"""
from __future__ import annotations
import re
import time
from datetime import datetime
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Dict, Iterable, List, Sequence
from psycopg2.extras import RealDictCursor, execute_batch, execute_values
@@ -77,6 +79,37 @@ class DwdLoadTask(BaseTask):
"billiards_dwd.dwd_assistant_service_log",
}
_NUMERIC_RE = re.compile(r"^[+-]?\d+(?:\.\d+)?$")
_BOOL_STRINGS = {"true", "false", "1", "0", "yes", "no", "y", "n", "t", "f"}
def _strip_scd2_keys(self, pk_cols: Sequence[str]) -> list[str]:
return [c for c in pk_cols if c.lower() not in self.SCD_COLS]
@staticmethod
def _pick_snapshot_order_column(ods_cols: Sequence[str]) -> str | None:
lower_cols = {c.lower() for c in ods_cols}
for candidate in ("fetched_at", "update_time", "create_time"):
if candidate in lower_cols:
return candidate
return None
@staticmethod
def _latest_snapshot_select_sql(
select_cols_sql: str,
ods_table_sql: str,
key_exprs: Sequence[str],
order_col: str | None,
where_sql: str = "",
) -> str:
if key_exprs and order_col:
distinct_on = ", ".join(key_exprs)
order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST'])
return (
f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} "
f"FROM {ods_table_sql} {where_sql} ORDER BY {order_by}"
)
return f"SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}"
# Special column mappings: DWD column name -> source column expression (optional CAST)
FACT_MAPPINGS: dict[str, list[tuple[str, str, str | None]]] = {
# Dimension tables (reconcile primary key / column differences)
@@ -652,9 +685,8 @@ class DwdLoadTask(BaseTask):
if not pk_cols:
raise ValueError(f"{dwd_table} 未配置主键,无法执行维表合并")
pk_has_scd = any(pk.lower() in self.SCD_COLS for pk in pk_cols)
scd_cols_present = any(c.lower() in self.SCD_COLS for c in dwd_cols)
if scd_cols_present and pk_has_scd:
if scd_cols_present:
return self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
return self._merge_dim_type1_upsert(cur, dwd_table, ods_table, dwd_cols, ods_cols, pk_cols, now)
@@ -701,12 +733,19 @@ class DwdLoadTask(BaseTask):
if not select_exprs:
return 0
# For dim_site and dim_site_ex, use DISTINCT ON to optimize the query
# and avoid a full scan of the large table_fee_transactions table: fetch only the latest record per site_id
if dwd_table in ("billiards_dwd.dim_site", "billiards_dwd.dim_site_ex"):
sql = f"SELECT DISTINCT ON (site_id) {', '.join(select_exprs)} FROM {ods_table_sql} ORDER BY site_id, fetched_at DESC NULLS LAST"
else:
sql = f"SELECT {', '.join(select_exprs)} FROM {ods_table_sql}"
order_col = self._pick_snapshot_order_column(ods_cols)
business_keys = self._strip_scd2_keys(pk_cols)
key_exprs: list[str] = []
for key in business_keys:
lc = key.lower()
if lc in mapping:
src, cast_type = mapping[lc]
key_exprs.append(self._cast_expr(src, cast_type))
elif lc in ods_set:
key_exprs.append(f'"{lc}"')
select_cols_sql = ", ".join(select_exprs)
sql = self._latest_snapshot_select_sql(select_cols_sql, ods_table_sql, key_exprs, order_col)
cur.execute(sql)
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
@@ -784,7 +823,11 @@ class DwdLoadTask(BaseTask):
if not pk_cols:
raise ValueError(f"{dwd_table} 未配置主键,无法执行 SCD2 合并")
mapping = self._build_column_mapping(dwd_table, pk_cols, ods_cols)
business_keys = self._strip_scd2_keys(pk_cols)
if not business_keys:
raise ValueError(f"{dwd_table} primary key only contains SCD2 columns; cannot merge")
mapping = self._build_column_mapping(dwd_table, business_keys, ods_cols)
ods_set = {c.lower() for c in ods_cols}
table_sql = self._format_table(ods_table, "billiards_ods")
# Build SELECT expressions; supports JSON/expression mappings
@@ -806,7 +849,7 @@ class DwdLoadTask(BaseTask):
select_exprs.append('"categoryboxes" AS "categoryboxes"')
added.add("categoryboxes")
# Primary key fallback: make sure key columns are selected
for pk in pk_cols:
for pk in business_keys:
lc = pk.lower()
if lc not in added:
if lc in mapping:
@@ -819,7 +862,18 @@ class DwdLoadTask(BaseTask):
if not select_exprs:
return 0
sql = f"SELECT {', '.join(select_exprs)} FROM {table_sql}"
order_col = self._pick_snapshot_order_column(ods_cols)
key_exprs: list[str] = []
for key in business_keys:
lc = key.lower()
if lc in mapping:
src, cast_type = mapping[lc]
key_exprs.append(self._cast_expr(src, cast_type))
elif lc in ods_set:
key_exprs.append(f'"{lc}"')
select_cols_sql = ", ".join(select_exprs)
sql = self._latest_snapshot_select_sql(select_cols_sql, table_sql, key_exprs, order_col)
cur.execute(sql)
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
@@ -842,11 +896,11 @@ class DwdLoadTask(BaseTask):
value = row.get(src.lower())
mapped_row[lc] = value
pk_key = tuple(mapped_row.get(pk) for pk in pk_cols)
pk_key = tuple(mapped_row.get(pk) for pk in business_keys)
if pk_key in seen_pk:
continue
if any(v is None for v in pk_key):
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(business_keys, pk_key)))
continue
seen_pk.add(pk_key)
src_rows_by_pk[pk_key] = mapped_row
@@ -862,7 +916,7 @@ class DwdLoadTask(BaseTask):
current_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
for r in current_rows:
rr = {k.lower(): v for k, v in r.items()}
pk_key = tuple(rr.get(pk) for pk in pk_cols)
pk_key = tuple(rr.get(pk) for pk in business_keys)
current_by_pk[pk_key] = rr
# Compute the key sets that need to be closed / inserted
@@ -881,7 +935,7 @@ class DwdLoadTask(BaseTask):
# Close old versions first (uniform end_time within the same batch)
if to_close:
self._close_current_dim_bulk(cur, dwd_table, pk_cols, to_close, now)
self._close_current_dim_bulk(cur, dwd_table, business_keys, to_close, now)
# Bulk-insert the new versions
if to_insert:
@@ -1031,10 +1085,105 @@ class DwdLoadTask(BaseTask):
lc = col.lower()
if lc in self.SCD_COLS:
continue
if current.get(lc) != incoming.get(lc):
if not self._values_equal(current.get(lc), incoming.get(lc)):
return True
return False
def _values_equal(self, current_val: Any, incoming_val: Any) -> bool:
"""Normalize common type mismatches (numeric/text, naive/aware datetime) before compare."""
current_val = self._normalize_empty(current_val)
incoming_val = self._normalize_empty(incoming_val)
if current_val is None and incoming_val is None:
return True
# Datetime normalization (naive vs aware)
if isinstance(current_val, (datetime, date)) or isinstance(incoming_val, (datetime, date)):
return self._normalize_datetime(current_val) == self._normalize_datetime(incoming_val)
# Boolean normalization
if self._looks_bool(current_val) or self._looks_bool(incoming_val):
cur_bool = self._coerce_bool(current_val)
inc_bool = self._coerce_bool(incoming_val)
if cur_bool is not None and inc_bool is not None:
return cur_bool == inc_bool
# Numeric normalization (string vs numeric)
if self._looks_numeric(current_val) or self._looks_numeric(incoming_val):
cur_num = self._coerce_numeric(current_val)
inc_num = self._coerce_numeric(incoming_val)
if cur_num is not None and inc_num is not None:
return cur_num == inc_num
return current_val == incoming_val
def _normalize_empty(self, value: Any) -> Any:
if isinstance(value, str):
stripped = value.strip()
return None if stripped == "" else stripped
return value
def _normalize_datetime(self, value: Any) -> Any:
if value is None:
return None
if isinstance(value, date) and not isinstance(value, datetime):
value = datetime.combine(value, datetime.min.time())
if not isinstance(value, datetime):
return value
if value.tzinfo is None:
return value.replace(tzinfo=self.tz)
return value.astimezone(self.tz)
def _looks_numeric(self, value: Any) -> bool:
if isinstance(value, (int, float, Decimal)) and not isinstance(value, bool):
return True
if isinstance(value, str):
return bool(self._NUMERIC_RE.match(value.strip()))
return False
def _coerce_numeric(self, value: Any) -> Decimal | None:
value = self._normalize_empty(value)
if value is None:
return None
if isinstance(value, bool):
return Decimal(int(value))
if isinstance(value, (int, float, Decimal)):
try:
return Decimal(str(value))
except InvalidOperation:
return None
if isinstance(value, str):
s = value.strip()
if not self._NUMERIC_RE.match(s):
return None
try:
return Decimal(s)
except InvalidOperation:
return None
return None
def _looks_bool(self, value: Any) -> bool:
if isinstance(value, bool):
return True
if isinstance(value, str):
return value.strip().lower() in self._BOOL_STRINGS
return False
def _coerce_bool(self, value: Any) -> bool | None:
value = self._normalize_empty(value)
if value is None:
return None
if isinstance(value, bool):
return value
if isinstance(value, (int, Decimal)) and not isinstance(value, bool):
return bool(int(value))
if isinstance(value, str):
s = value.strip().lower()
if s in {"true", "1", "yes", "y", "t"}:
return True
if s in {"false", "0", "no", "n", "f"}:
return False
return None
def _merge_fact_increment(
self,
cur,
@@ -1052,6 +1201,9 @@ class DwdLoadTask(BaseTask):
mapping: Dict[str, tuple[str, str | None]] = {
dst.lower(): (src, cast_type) for dst, src, cast_type in mapping_entries
}
ods_set = {c.lower() for c in ods_cols}
snapshot_mode = "content_hash" in ods_set
fact_upsert = bool(self.config.get("dwd.fact_upsert", True))
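# snapshot_mode: an ODS table that carries content_hash stores one row per observed snapshot,
# so the fact load below selects only the latest snapshot per DWD primary key and upserts
# (ON CONFLICT ... DO UPDATE) instead of relying on insert-once DO NOTHING semantics.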
mapping_dest = [dst for dst, _, _ in mapping_entries]
insert_cols: List[str] = list(mapping_dest)
@@ -1064,7 +1216,6 @@ class DwdLoadTask(BaseTask):
insert_cols.append(col)
pk_cols = self._get_primary_keys(cur, dwd_table)
ods_set = {c.lower() for c in ods_cols}
existing_lower = [c.lower() for c in insert_cols]
for pk in pk_cols:
pk_lower = pk.lower()
@@ -1092,7 +1243,11 @@ class DwdLoadTask(BaseTask):
self.logger.warning("跳过 %s:未找到可插入的列", dwd_table)
return 0
order_col = self._pick_order_column(dwd_table, dwd_cols, ods_cols)
order_col = (
self._pick_snapshot_order_column(ods_cols)
if snapshot_mode
else self._pick_order_column(dwd_table, dwd_cols, ods_cols)
)
where_sql = ""
params: List[Any] = []
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
@@ -1122,12 +1277,40 @@ class DwdLoadTask(BaseTask):
select_cols_sql = ", ".join(select_exprs)
insert_cols_sql = ", ".join(f'"{c}"' for c in insert_cols)
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
if snapshot_mode and pk_cols:
key_exprs: list[str] = []
for pk in pk_cols:
pk_lower = pk.lower()
if pk_lower in mapping:
src, cast_type = mapping[pk_lower]
key_exprs.append(self._cast_expr(src, cast_type))
elif pk_lower in ods_set:
key_exprs.append(f'"{pk_lower}"')
elif "id" in ods_set:
key_exprs.append('"id"')
select_sql = self._latest_snapshot_select_sql(
select_cols_sql,
ods_table_sql,
key_exprs,
order_col,
where_sql,
)
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) {select_sql}'
else:
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
pk_cols = self._get_primary_keys(cur, dwd_table)
if pk_cols:
pk_sql = ", ".join(f'"{c}"' for c in pk_cols)
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
pk_lower = {c.lower() for c in pk_cols}
set_exprs = [f'"{c}" = EXCLUDED."{c}"' for c in insert_cols if c.lower() not in pk_lower]
if snapshot_mode or fact_upsert:
if set_exprs:
sql += f" ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
else:
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
else:
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
cur.execute(sql, params)
inserted = cur.rowcount
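The DISTINCT ON statement produced by _latest_snapshot_select_sql above can be previewed in isolation. The sketch below copies the helper's logic so it runs standalone; the table and column names in the example call are made up:

from typing import Sequence

def latest_snapshot_select_sql(
    select_cols_sql: str,
    ods_table_sql: str,
    key_exprs: Sequence[str],
    order_col: str | None,
    where_sql: str = "",
) -> str:
    # Same shape as DwdLoadTask._latest_snapshot_select_sql.
    if key_exprs and order_col:
        distinct_on = ", ".join(key_exprs)
        order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST'])
        return (
            f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} "
            f"FROM {ods_table_sql} {where_sql} ORDER BY {order_by}"
        )
    return f"SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}"

print(latest_snapshot_select_sql('"site_id", "site_name"', "billiards_ods.table_fee_transactions", ['"site_id"'], "fetched_at"))
# Prints (one line): SELECT DISTINCT ON ("site_id") "site_id", "site_name"
#   FROM billiards_ods.table_fee_transactions  ORDER BY "site_id", "fetched_at" DESC NULLS LAST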

View File

@@ -2,6 +2,7 @@
"""手工示例数据灌入:按 schema_ODS_doc.sql 的表结构写入 ODS。"""
from __future__ import annotations
import hashlib
import json
import os
from datetime import datetime
@@ -252,12 +253,17 @@ class ManualIngestTask(BaseTask):
except Exception:
pk_index = None
has_content_hash = any(c[0].lower() == "content_hash" for c in columns_info)
col_list = ", ".join(f'"{c}"' for c in columns)
sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
if pk_col_db:
update_cols = [c for c in columns if c != pk_col_db]
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
if has_content_hash:
sql_prefix += f' ON CONFLICT ("{pk_col_db}", "content_hash") DO NOTHING'
else:
update_cols = [c for c in columns if c != pk_col_db]
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
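# With a content_hash column the statement becomes, e.g.:
#   INSERT INTO <table> (...) VALUES %s ON CONFLICT ("<pk>", "content_hash") DO NOTHING
# so re-ingesting identical content is a no-op while changed content lands as a new snapshot row;
# tables without content_hash keep the previous upsert (DO UPDATE SET ...) behavior.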
params = []
now = datetime.now()
@@ -284,6 +290,12 @@ class ManualIngestTask(BaseTask):
if pk_col and (pk_val is None or pk_val == ""):
continue
content_hash = None
if has_content_hash:
hash_record = dict(merged_rec)
hash_record["fetched_at"] = merged_rec.get("fetched_at", now)
content_hash = self._compute_content_hash(hash_record, include_fetched_at=True)
row_vals = []
for col_name, data_type, udt in columns_info:
col_lower = col_name.lower()
@@ -296,6 +308,9 @@ class ManualIngestTask(BaseTask):
if col_lower == "fetched_at":
row_vals.append(merged_rec.get(col_name, now))
continue
if col_lower == "content_hash":
row_vals.append(content_hash)
continue
value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
@@ -401,3 +416,48 @@ class ManualIngestTask(BaseTask):
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, str) else None
return value
@staticmethod
def _hash_default(value):
if isinstance(value, datetime):
return value.isoformat()
return str(value)
@classmethod
def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
exclude = {
"data",
"payload",
"source_file",
"source_endpoint",
"content_hash",
"record_index",
}
if not include_fetched_at:
exclude.add("fetched_at")
def _strip(value):
if isinstance(value, dict):
cleaned = {}
for k, v in value.items():
if isinstance(k, str) and k.lower() in exclude:
continue
cleaned[k] = _strip(v)
return cleaned
if isinstance(value, list):
return [_strip(v) for v in value]
return value
return _strip(record or {})
@classmethod
def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
cleaned = cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at)
payload = json.dumps(
cleaned,
ensure_ascii=False,
sort_keys=True,
separators=(",", ":"),
default=cls._hash_default,
)
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
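The hashing helpers above serialize a cleaned record as compact, sorted-key JSON and hash it with SHA-256, so the result depends only on record content rather than key order or wrapper fields. A simplified standalone illustration of that property (it omits the recursive exclusion list):

import hashlib
import json
from datetime import datetime

def content_hash(record: dict) -> str:
    # Sorted keys plus compact separators make the serialization independent of dict ordering.
    payload = json.dumps(
        record,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
        default=lambda v: v.isoformat() if isinstance(v, datetime) else str(v),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

a = {"siteId": 1, "amount": "12.50", "updateTime": datetime(2026, 2, 1, 3, 0)}
b = {"updateTime": datetime(2026, 2, 1, 3, 0), "amount": "12.50", "siteId": 1}
assert content_hash(a) == content_hash(b)  # key order does not change the hash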

View File

@@ -2,6 +2,7 @@
"""ODS ingestion tasks."""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -317,6 +318,7 @@ class BaseOdsTask(BaseTask):
db_json_cols_lower = {
c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
col_names = [c[0] for c in cols_info]
quoted_cols = ", ".join(f'\"{c}\"' for c in col_names)
@@ -330,6 +332,7 @@ class BaseOdsTask(BaseTask):
params: list[tuple] = []
skipped = 0
merged_records: list[dict] = []
root_site_profile = None
if isinstance(response_payload, dict):
@@ -345,6 +348,7 @@ class BaseOdsTask(BaseTask):
continue
merged_rec = self._merge_record_layers(rec)
merged_records.append({"raw": rec, "merged": merged_rec})
if table in {"billiards_ods.recharge_settlements", "billiards_ods.settlement_records"}:
site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile") or root_site_profile
if isinstance(site_profile, dict):
@@ -363,9 +367,42 @@ class BaseOdsTask(BaseTask):
_fill_missing("siteid", [site_profile.get("siteId"), site_profile.get("id")])
_fill_missing("sitename", [site_profile.get("shop_name"), site_profile.get("siteName")])
has_fetched_at = any(c[0].lower() == "fetched_at" for c in cols_info)
business_keys = [c for c in pk_cols if str(c).lower() != "content_hash"]
compare_latest = bool(needs_content_hash and has_fetched_at and business_keys)
latest_compare_hash: dict[tuple[Any, ...], str | None] = {}
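# compare_latest: before writing a new snapshot, hash the incoming record with fetched_at
# excluded and compare it against the hash of the latest stored payload for the same business
# key; rows whose content is unchanged are skipped further below instead of inserted again.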
if compare_latest:
key_values: list[tuple[Any, ...]] = []
for item in merged_records:
merged_rec = item["merged"]
key = tuple(self._get_value_case_insensitive(merged_rec, k) for k in business_keys)
if any(v is None or v == "" for v in key):
continue
key_values.append(key)
if key_values:
with self.db.conn.cursor() as cur:
latest_payloads = self._fetch_latest_payloads(cur, table, business_keys, key_values)
for key, payload in latest_payloads.items():
latest_compare_hash[key] = self._compute_compare_hash_from_payload(payload)
for item in merged_records:
rec = item["raw"]
merged_rec = item["merged"]
content_hash = None
compare_hash = None
if needs_content_hash:
compare_hash = self._compute_content_hash(merged_rec, include_fetched_at=False)
hash_record = dict(merged_rec)
hash_record["fetched_at"] = now
content_hash = self._compute_content_hash(hash_record, include_fetched_at=True)
if pk_cols:
missing_pk = False
for pk in pk_cols:
if str(pk).lower() == "content_hash":
continue
pk_val = self._get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
@@ -374,6 +411,16 @@ class BaseOdsTask(BaseTask):
skipped += 1
continue
if compare_latest and compare_hash is not None:
key = tuple(self._get_value_case_insensitive(merged_rec, k) for k in business_keys)
if any(v is None or v == "" for v in key):
skipped += 1
continue
last_hash = latest_compare_hash.get(key)
if last_hash is not None and last_hash == compare_hash:
skipped += 1
continue
row_vals: list[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
@@ -389,6 +436,9 @@ class BaseOdsTask(BaseTask):
if col_lower == "fetched_at":
row_vals.append(now)
continue
if col_lower == "content_hash":
row_vals.append(content_hash)
continue
value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
@@ -472,6 +522,93 @@ class BaseOdsTask(BaseTask):
return resolver(spec.endpoint)
return None
@staticmethod
def _hash_default(value):
if isinstance(value, datetime):
return value.isoformat()
return str(value)
@classmethod
def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
exclude = {
"data",
"payload",
"source_file",
"source_endpoint",
"content_hash",
"record_index",
}
if not include_fetched_at:
exclude.add("fetched_at")
def _strip(value):
if isinstance(value, dict):
cleaned = {}
for k, v in value.items():
if isinstance(k, str) and k.lower() in exclude:
continue
cleaned[k] = _strip(v)
return cleaned
if isinstance(value, list):
return [_strip(v) for v in value]
return value
return _strip(record or {})
@classmethod
def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
cleaned = cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at)
payload = json.dumps(
cleaned,
ensure_ascii=False,
sort_keys=True,
separators=(",", ":"),
default=cls._hash_default,
)
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
@staticmethod
def _compute_compare_hash_from_payload(payload: Any) -> str | None:
if payload is None:
return None
if isinstance(payload, str):
try:
payload = json.loads(payload)
except Exception:
return None
if not isinstance(payload, dict):
return None
merged = BaseOdsTask._merge_record_layers(payload)
return BaseOdsTask._compute_content_hash(merged, include_fetched_at=False)
@staticmethod
def _fetch_latest_payloads(cur, table: str, business_keys: Sequence[str], key_values: Sequence[tuple]) -> dict:
if not business_keys or not key_values:
return {}
keys_sql = ", ".join(f'"{k}"' for k in business_keys)
sql = (
f"WITH keys({keys_sql}) AS (VALUES %s) "
f"SELECT DISTINCT ON ({keys_sql}) {keys_sql}, payload "
f"FROM {table} t JOIN keys k USING ({keys_sql}) "
f"ORDER BY {keys_sql}, fetched_at DESC NULLS LAST"
)
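# For a single business key this renders roughly as (key/table names illustrative):
#   WITH keys("orderid") AS (VALUES %s)
#   SELECT DISTINCT ON ("orderid") "orderid", payload
#   FROM billiards_ods.settlement_records t JOIN keys k USING ("orderid")
#   ORDER BY "orderid", fetched_at DESC NULLS LAST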
unique_keys = list({tuple(k) for k in key_values})
execute_values(cur, sql, unique_keys, page_size=500)
rows = cur.fetchall() or []
result = {}
if rows and isinstance(rows[0], dict):
for r in rows:
key = tuple(r[k] for k in business_keys)
result[key] = r.get("payload")
return result
key_len = len(business_keys)
for r in rows:
key = tuple(r[:key_len])
payload = r[key_len] if len(r) > key_len else None
result[key] = payload
return result
def _int_col(name: str, *sources: str, required: bool = False) -> ColumnSpec:
return ColumnSpec(