Update database data-validation and write logic.
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Data integrity task that checks API -> ODS -> DWD completeness."""
from __future__ import annotations

@@ -7,16 +7,9 @@ from zoneinfo import ZoneInfo

from dateutil import parser as dtparser

import json
from pathlib import Path
from utils.windowing import build_window_segments, calc_window_minutes
from .base_task import BaseTask
from quality.integrity_checker import (
    IntegrityWindow,
    compute_last_etl_end,
    run_integrity_history,
    run_integrity_window,
)
from quality.integrity_service import run_history_flow, run_window_flow, write_report


class DataIntegrityTask(BaseTask):
@@ -31,15 +24,25 @@ class DataIntegrityTask(BaseTask):
        include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
        task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
        auto_backfill = bool(self.config.get("integrity.auto_backfill", False))
        compare_content = self.config.get("integrity.compare_content")
        if compare_content is None:
            compare_content = True
        content_sample_limit = self.config.get("integrity.content_sample_limit")
        backfill_mismatch = self.config.get("integrity.backfill_mismatch")
        if backfill_mismatch is None:
            backfill_mismatch = True
        recheck_after_backfill = self.config.get("integrity.recheck_after_backfill")
        if recheck_after_backfill is None:
            recheck_after_backfill = True

        # Detect whether a time window was passed in via the CLI (window_override).
        # If so, automatically switch to window mode.
        # Switch to window mode when CLI override is provided.
        window_override_start = self.config.get("run.window_override.start")
        window_override_end = self.config.get("run.window_override.end")
        if window_override_start or window_override_end:
            self.logger.info(
                "Detected CLI window override; switching to window mode: %s ~ %s",
                window_override_start, window_override_end
                "Detected CLI window override. Switching to window mode: %s ~ %s",
                window_override_start,
                window_override_end,
            )
            mode = "window"

@@ -57,65 +60,28 @@ class DataIntegrityTask(BaseTask):

            total_segments = len(segments)
            if total_segments > 1:
                self.logger.info("Data integrity check: split execution, %s segments in total", total_segments)
                self.logger.info("Data integrity check split into %s segments.", total_segments)

            window_reports = []
            total_missing = 0
            total_errors = 0
            for idx, (seg_start, seg_end) in enumerate(segments, start=1):
                window = IntegrityWindow(
                    start=seg_start,
                    end=seg_end,
                    label=f"segment_{idx}",
                    granularity="window",
                )
                payload = run_integrity_window(
                    cfg=self.config,
                    window=window,
                    include_dimensions=include_dimensions,
                    task_codes=task_codes,
                    logger=self.logger,
                    write_report=False,
                    window_split_unit="none",
                    window_compensation_hours=0,
                )
                window_reports.append(payload)
                total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
                total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
            report, counts = run_window_flow(
                cfg=self.config,
                windows=segments,
                include_dimensions=include_dimensions,
                task_codes=task_codes,
                logger=self.logger,
                compare_content=bool(compare_content),
                content_sample_limit=content_sample_limit,
                do_backfill=bool(auto_backfill),
                include_mismatch=bool(backfill_mismatch),
                recheck_after_backfill=bool(recheck_after_backfill),
                page_size=int(self.config.get("api.page_size") or 200),
                chunk_size=500,
            )

            overall_start = segments[0][0]
            overall_end = segments[-1][1]
            report = {
                "mode": "window",
                "window": {
                    "start": overall_start.isoformat(),
                    "end": overall_end.isoformat(),
                    "segments": total_segments,
                },
                "windows": window_reports,
                "api_to_ods": {
                    "total_missing": total_missing,
                    "total_errors": total_errors,
                },
                "total_missing": total_missing,
                "total_errors": total_errors,
                "generated_at": datetime.now(tz).isoformat(),
            }
            report_path = self._write_report(report, "data_integrity_window")
            report_path = write_report(report, prefix="data_integrity_window", tz=tz)
            report["report_path"] = report_path

            missing_count = int(total_missing or 0)
            counts = {
                "missing": missing_count,
                "errors": int(total_errors or 0),
            }

            # Auto backfill
            backfill_result = None
            if auto_backfill and missing_count > 0:
                backfill_result = self._run_backfill(base_start, base_end, task_codes)
                counts["backfilled"] = backfill_result.get("backfilled", 0)

            return {
                "status": "SUCCESS",
                "counts": counts,
@@ -125,7 +91,7 @@ class DataIntegrityTask(BaseTask):
                    "minutes": calc_window_minutes(overall_start, overall_end),
                },
                "report_path": report_path,
                "backfill_result": backfill_result,
                "backfill_result": report.get("backfill_result"),
            }

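For orientation, a minimal sketch of the dict the window-mode branch returns after the switch to run_window_flow; only the keys mirror the return statement above, the concrete values are illustrative placeholders.

# Illustrative only: key layout follows the return statement above, values are made up.
example_window_result = {
    "status": "SUCCESS",
    "counts": {"missing": 0, "errors": 0},
    "window": {
        "start": "2025-07-01T00:00:00+08:00",
        "end": "2025-07-02T00:00:00+08:00",
        "minutes": 1440,
    },
    "report_path": "reports/data_integrity_window_20250701_000000.json",
    "backfill_result": None,
}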
        history_start = str(self.config.get("integrity.history_start", "2025-07-01") or "2025-07-01")
@@ -136,77 +102,52 @@ class DataIntegrityTask(BaseTask):
        else:
            start_dt = start_dt.astimezone(tz)

        end_dt = None
        if history_end:
            end_dt = dtparser.parse(history_end)
            if end_dt.tzinfo is None:
                end_dt = end_dt.replace(tzinfo=tz)
            else:
                end_dt = end_dt.astimezone(tz)
        else:
            end_dt = compute_last_etl_end(self.config) or datetime.now(tz)

        report = run_integrity_history(
        report, counts = run_history_flow(
            cfg=self.config,
            start_dt=start_dt,
            end_dt=end_dt,
            include_dimensions=include_dimensions,
            task_codes=task_codes,
            logger=self.logger,
            write_report=True,
            compare_content=bool(compare_content),
            content_sample_limit=content_sample_limit,
            do_backfill=bool(auto_backfill),
            include_mismatch=bool(backfill_mismatch),
            recheck_after_backfill=bool(recheck_after_backfill),
            page_size=int(self.config.get("api.page_size") or 200),
            chunk_size=500,
        )
        missing_count = int(report.get("total_missing") or 0)
        counts = {
            "missing": missing_count,
            "errors": int(report.get("total_errors") or 0),
        }

        # Auto backfill
        backfill_result = None
        if auto_backfill and missing_count > 0:
            backfill_result = self._run_backfill(start_dt, end_dt, task_codes)
            counts["backfilled"] = backfill_result.get("backfilled", 0)

        report_path = write_report(report, prefix="data_integrity_history", tz=tz)
        report["report_path"] = report_path

        end_dt_used = end_dt
        if end_dt_used is None:
            end_str = report.get("end")
            if end_str:
                parsed = dtparser.parse(end_str)
                if parsed.tzinfo is None:
                    end_dt_used = parsed.replace(tzinfo=tz)
                else:
                    end_dt_used = parsed.astimezone(tz)
            if end_dt_used is None:
                end_dt_used = start_dt

        return {
            "status": "SUCCESS",
            "counts": counts,
            "window": {
                "start": start_dt,
                "end": end_dt,
                "minutes": int((end_dt - start_dt).total_seconds() // 60) if end_dt > start_dt else 0,
                "end": end_dt_used,
                "minutes": int((end_dt_used - start_dt).total_seconds() // 60) if end_dt_used > start_dt else 0,
            },
            "report_path": report.get("report_path"),
            "backfill_result": backfill_result,
            "report_path": report_path,
            "backfill_result": report.get("backfill_result"),
        }

    def _write_report(self, report: dict, prefix: str) -> str:
        root = Path(__file__).resolve().parents[1]
        stamp = datetime.now(self.tz).strftime("%Y%m%d_%H%M%S")
        path = root / "reports" / f"{prefix}_{stamp}.json"
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        return str(path)

    def _run_backfill(self, start_dt: datetime, end_dt: datetime, task_codes: str) -> dict:
        """Run data backfill."""
        self.logger.info("Auto backfill started: start=%s end=%s", start_dt, end_dt)
        try:
            from scripts.backfill_missing_data import run_backfill
            result = run_backfill(
                cfg=self.config,
                start=start_dt,
                end=end_dt,
                task_codes=task_codes or None,
                dry_run=False,
                page_size=200,
                chunk_size=500,
                logger=self.logger,
            )
            self.logger.info(
                "Auto backfill finished: backfilled=%s errors=%s",
                result.get("backfilled", 0),
                result.get("errors", 0),
            )
            return result
        except Exception as exc:
            self.logger.exception("Auto backfill failed")
            return {"backfilled": 0, "errors": 1, "error": str(exc)}

@@ -2,8 +2,10 @@
"""DWD load task: incrementally write DWD from ODS (dimensions via SCD2, facts by time increment)."""
from __future__ import annotations

import re
import time
from datetime import datetime
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Dict, Iterable, List, Sequence

from psycopg2.extras import RealDictCursor, execute_batch, execute_values
@@ -77,6 +79,37 @@ class DwdLoadTask(BaseTask):
        "billiards_dwd.dwd_assistant_service_log",
    }

    _NUMERIC_RE = re.compile(r"^[+-]?\d+(?:\.\d+)?$")
    _BOOL_STRINGS = {"true", "false", "1", "0", "yes", "no", "y", "n", "t", "f"}

    def _strip_scd2_keys(self, pk_cols: Sequence[str]) -> list[str]:
        return [c for c in pk_cols if c.lower() not in self.SCD_COLS]

    @staticmethod
    def _pick_snapshot_order_column(ods_cols: Sequence[str]) -> str | None:
        lower_cols = {c.lower() for c in ods_cols}
        for candidate in ("fetched_at", "update_time", "create_time"):
            if candidate in lower_cols:
                return candidate
        return None

    @staticmethod
    def _latest_snapshot_select_sql(
        select_cols_sql: str,
        ods_table_sql: str,
        key_exprs: Sequence[str],
        order_col: str | None,
        where_sql: str = "",
    ) -> str:
        if key_exprs and order_col:
            distinct_on = ", ".join(key_exprs)
            order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST'])
            return (
                f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} "
                f"FROM {ods_table_sql} {where_sql} ORDER BY {order_by}"
            )
        return f"SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}"

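To make the helper above concrete, a small sketch of the SQL it yields; DwdLoadTask is the class defined in this file, while the table and column names are placeholders chosen for the example.

select_cols_sql = '"site_id", "site_name"'
table_sql = "billiards_ods.sites"  # placeholder table name

# With business keys and an order column, the latest snapshot per key is selected:
#   SELECT DISTINCT ON ("site_id") "site_id", "site_name"
#   FROM billiards_ods.sites  ORDER BY "site_id", "fetched_at" DESC NULLS LAST
sql_latest = DwdLoadTask._latest_snapshot_select_sql(select_cols_sql, table_sql, ['"site_id"'], "fetched_at")

# Without keys or an order column it degrades to a plain SELECT:
sql_plain = DwdLoadTask._latest_snapshot_select_sql(select_cols_sql, table_sql, [], None)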
    # Special column mapping: DWD column name -> source column expression (optional CAST)
    FACT_MAPPINGS: dict[str, list[tuple[str, str, str | None]]] = {
        # Dimension tables (fill in primary key / column differences)
@@ -652,9 +685,8 @@ class DwdLoadTask(BaseTask):
        if not pk_cols:
            raise ValueError(f"{dwd_table} has no primary key configured; cannot merge dimension table")

        pk_has_scd = any(pk.lower() in self.SCD_COLS for pk in pk_cols)
        scd_cols_present = any(c.lower() in self.SCD_COLS for c in dwd_cols)
        if scd_cols_present and pk_has_scd:
        if scd_cols_present:
            return self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
        return self._merge_dim_type1_upsert(cur, dwd_table, ods_table, dwd_cols, ods_cols, pk_cols, now)

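A short sketch of the business-key stripping that drives this dispatch; the members of SCD_COLS are not shown in the diff, so the set below is an assumed example.

assumed_scd_cols = {"start_time", "end_time", "is_current"}  # assumption for the example

pk_cols = ["site_id", "start_time"]
business_keys = [c for c in pk_cols if c.lower() not in assumed_scd_cols]
# business_keys == ["site_id"]: the SCD2 merge keys only on the business key,
# while the SCD2 bookkeeping columns are managed by the merge itself.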
@@ -701,12 +733,19 @@ class DwdLoadTask(BaseTask):
        if not select_exprs:
            return 0

        # For dim_site and dim_site_ex, use DISTINCT ON to optimize the query:
        # avoid a full scan of the large table_fee_transactions table and fetch only the latest record per site_id.
        if dwd_table in ("billiards_dwd.dim_site", "billiards_dwd.dim_site_ex"):
            sql = f"SELECT DISTINCT ON (site_id) {', '.join(select_exprs)} FROM {ods_table_sql} ORDER BY site_id, fetched_at DESC NULLS LAST"
        else:
            sql = f"SELECT {', '.join(select_exprs)} FROM {ods_table_sql}"
        order_col = self._pick_snapshot_order_column(ods_cols)
        business_keys = self._strip_scd2_keys(pk_cols)
        key_exprs: list[str] = []
        for key in business_keys:
            lc = key.lower()
            if lc in mapping:
                src, cast_type = mapping[lc]
                key_exprs.append(self._cast_expr(src, cast_type))
            elif lc in ods_set:
                key_exprs.append(f'"{lc}"')

        select_cols_sql = ", ".join(select_exprs)
        sql = self._latest_snapshot_select_sql(select_cols_sql, ods_table_sql, key_exprs, order_col)

        cur.execute(sql)
        rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
@@ -784,7 +823,11 @@ class DwdLoadTask(BaseTask):
        if not pk_cols:
            raise ValueError(f"{dwd_table} has no primary key configured; cannot perform SCD2 merge")

        mapping = self._build_column_mapping(dwd_table, pk_cols, ods_cols)
        business_keys = self._strip_scd2_keys(pk_cols)
        if not business_keys:
            raise ValueError(f"{dwd_table} primary key only contains SCD2 columns; cannot merge")

        mapping = self._build_column_mapping(dwd_table, business_keys, ods_cols)
        ods_set = {c.lower() for c in ods_cols}
        table_sql = self._format_table(ods_table, "billiards_ods")
        # Build SELECT expressions, supporting JSON/expression mappings
@@ -806,7 +849,7 @@ class DwdLoadTask(BaseTask):
                select_exprs.append('"categoryboxes" AS "categoryboxes"')
                added.add("categoryboxes")
        # Fallback: make sure the primary key columns are selected
        for pk in pk_cols:
        for pk in business_keys:
            lc = pk.lower()
            if lc not in added:
                if lc in mapping:
@@ -819,7 +862,18 @@ class DwdLoadTask(BaseTask):
        if not select_exprs:
            return 0

        sql = f"SELECT {', '.join(select_exprs)} FROM {table_sql}"
        order_col = self._pick_snapshot_order_column(ods_cols)
        key_exprs: list[str] = []
        for key in business_keys:
            lc = key.lower()
            if lc in mapping:
                src, cast_type = mapping[lc]
                key_exprs.append(self._cast_expr(src, cast_type))
            elif lc in ods_set:
                key_exprs.append(f'"{lc}"')

        select_cols_sql = ", ".join(select_exprs)
        sql = self._latest_snapshot_select_sql(select_cols_sql, table_sql, key_exprs, order_col)
        cur.execute(sql)
        rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]

@@ -842,11 +896,11 @@ class DwdLoadTask(BaseTask):
                value = row.get(src.lower())
                mapped_row[lc] = value

            pk_key = tuple(mapped_row.get(pk) for pk in pk_cols)
            pk_key = tuple(mapped_row.get(pk) for pk in business_keys)
            if pk_key in seen_pk:
                continue
            if any(v is None for v in pk_key):
                self.logger.warning("Skipping %s: missing primary key %s", dwd_table, dict(zip(pk_cols, pk_key)))
                self.logger.warning("Skipping %s: missing primary key %s", dwd_table, dict(zip(business_keys, pk_key)))
                continue
            seen_pk.add(pk_key)
            src_rows_by_pk[pk_key] = mapped_row
@@ -862,7 +916,7 @@ class DwdLoadTask(BaseTask):
        current_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
        for r in current_rows:
            rr = {k.lower(): v for k, v in r.items()}
            pk_key = tuple(rr.get(pk) for pk in pk_cols)
            pk_key = tuple(rr.get(pk) for pk in business_keys)
            current_by_pk[pk_key] = rr

        # Compute the sets of primary keys that need to be closed / inserted
@@ -881,7 +935,7 @@ class DwdLoadTask(BaseTask):

        # Close old versions first (use a uniform end_time within the batch)
        if to_close:
            self._close_current_dim_bulk(cur, dwd_table, pk_cols, to_close, now)
            self._close_current_dim_bulk(cur, dwd_table, business_keys, to_close, now)

        # Bulk-insert new versions
        if to_insert:
@@ -1031,10 +1085,105 @@ class DwdLoadTask(BaseTask):
            lc = col.lower()
            if lc in self.SCD_COLS:
                continue
            if current.get(lc) != incoming.get(lc):
            if not self._values_equal(current.get(lc), incoming.get(lc)):
                return True
        return False

    def _values_equal(self, current_val: Any, incoming_val: Any) -> bool:
        """Normalize common type mismatches (numeric/text, naive/aware datetime) before compare."""
        current_val = self._normalize_empty(current_val)
        incoming_val = self._normalize_empty(incoming_val)
        if current_val is None and incoming_val is None:
            return True

        # Datetime normalization (naive vs aware)
        if isinstance(current_val, (datetime, date)) or isinstance(incoming_val, (datetime, date)):
            return self._normalize_datetime(current_val) == self._normalize_datetime(incoming_val)

        # Boolean normalization
        if self._looks_bool(current_val) or self._looks_bool(incoming_val):
            cur_bool = self._coerce_bool(current_val)
            inc_bool = self._coerce_bool(incoming_val)
            if cur_bool is not None and inc_bool is not None:
                return cur_bool == inc_bool

        # Numeric normalization (string vs numeric)
        if self._looks_numeric(current_val) or self._looks_numeric(incoming_val):
            cur_num = self._coerce_numeric(current_val)
            inc_num = self._coerce_numeric(incoming_val)
            if cur_num is not None and inc_num is not None:
                return cur_num == inc_num

        return current_val == incoming_val

    def _normalize_empty(self, value: Any) -> Any:
        if isinstance(value, str):
            stripped = value.strip()
            return None if stripped == "" else stripped
        return value

    def _normalize_datetime(self, value: Any) -> Any:
        if value is None:
            return None
        if isinstance(value, date) and not isinstance(value, datetime):
            value = datetime.combine(value, datetime.min.time())
        if not isinstance(value, datetime):
            return value
        if value.tzinfo is None:
            return value.replace(tzinfo=self.tz)
        return value.astimezone(self.tz)

    def _looks_numeric(self, value: Any) -> bool:
        if isinstance(value, (int, float, Decimal)) and not isinstance(value, bool):
            return True
        if isinstance(value, str):
            return bool(self._NUMERIC_RE.match(value.strip()))
        return False

    def _coerce_numeric(self, value: Any) -> Decimal | None:
        value = self._normalize_empty(value)
        if value is None:
            return None
        if isinstance(value, bool):
            return Decimal(int(value))
        if isinstance(value, (int, float, Decimal)):
            try:
                return Decimal(str(value))
            except InvalidOperation:
                return None
        if isinstance(value, str):
            s = value.strip()
            if not self._NUMERIC_RE.match(s):
                return None
            try:
                return Decimal(s)
            except InvalidOperation:
                return None
        return None

    def _looks_bool(self, value: Any) -> bool:
        if isinstance(value, bool):
            return True
        if isinstance(value, str):
            return value.strip().lower() in self._BOOL_STRINGS
        return False

    def _coerce_bool(self, value: Any) -> bool | None:
        value = self._normalize_empty(value)
        if value is None:
            return None
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, Decimal)) and not isinstance(value, bool):
            return bool(int(value))
        if isinstance(value, str):
            s = value.strip().lower()
            if s in {"true", "1", "yes", "y", "t"}:
                return True
            if s in {"false", "0", "no", "n", "f"}:
                return False
        return None

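A standalone illustration of the comparisons _values_equal now treats as equal; it replays the normalization outside the class so it runs on its own, and the timezone is an assumed example value.

from datetime import datetime
from decimal import Decimal
from zoneinfo import ZoneInfo

tz = ZoneInfo("Asia/Shanghai")  # assumed example timezone

# Numeric: a stored Decimal and an incoming string compare equal once both are Decimal.
assert Decimal("1.50") == Decimal("1.5")

# Boolean: "Yes" and "1" both coerce to True before comparing.
truthy = {"true", "1", "yes", "y", "t"}
assert "Yes".strip().lower() in truthy and "1" in truthy

# Datetime: a naive timestamp gets the task timezone attached, so it can match
# the timezone-aware value already stored in the warehouse.
naive = datetime(2025, 7, 1, 12, 0, 0)
aware = datetime(2025, 7, 1, 12, 0, 0, tzinfo=tz)
assert naive.replace(tzinfo=tz) == aware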
    def _merge_fact_increment(
        self,
        cur,
@@ -1052,6 +1201,9 @@ class DwdLoadTask(BaseTask):
        mapping: Dict[str, tuple[str, str | None]] = {
            dst.lower(): (src, cast_type) for dst, src, cast_type in mapping_entries
        }
        ods_set = {c.lower() for c in ods_cols}
        snapshot_mode = "content_hash" in ods_set
        fact_upsert = bool(self.config.get("dwd.fact_upsert", True))

        mapping_dest = [dst for dst, _, _ in mapping_entries]
        insert_cols: List[str] = list(mapping_dest)
@@ -1064,7 +1216,6 @@ class DwdLoadTask(BaseTask):
                insert_cols.append(col)

        pk_cols = self._get_primary_keys(cur, dwd_table)
        ods_set = {c.lower() for c in ods_cols}
        existing_lower = [c.lower() for c in insert_cols]
        for pk in pk_cols:
            pk_lower = pk.lower()
@@ -1092,7 +1243,11 @@ class DwdLoadTask(BaseTask):
            self.logger.warning("Skipping %s: no insertable columns found", dwd_table)
            return 0

        order_col = self._pick_order_column(dwd_table, dwd_cols, ods_cols)
        order_col = (
            self._pick_snapshot_order_column(ods_cols)
            if snapshot_mode
            else self._pick_order_column(dwd_table, dwd_cols, ods_cols)
        )
        where_sql = ""
        params: List[Any] = []
        dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
@@ -1122,12 +1277,40 @@ class DwdLoadTask(BaseTask):

        select_cols_sql = ", ".join(select_exprs)
        insert_cols_sql = ", ".join(f'"{c}"' for c in insert_cols)
        sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
        if snapshot_mode and pk_cols:
            key_exprs: list[str] = []
            for pk in pk_cols:
                pk_lower = pk.lower()
                if pk_lower in mapping:
                    src, cast_type = mapping[pk_lower]
                    key_exprs.append(self._cast_expr(src, cast_type))
                elif pk_lower in ods_set:
                    key_exprs.append(f'"{pk_lower}"')
                elif "id" in ods_set:
                    key_exprs.append('"id"')
            select_sql = self._latest_snapshot_select_sql(
                select_cols_sql,
                ods_table_sql,
                key_exprs,
                order_col,
                where_sql,
            )
            sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) {select_sql}'
        else:
            sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'

        pk_cols = self._get_primary_keys(cur, dwd_table)
        if pk_cols:
            pk_sql = ", ".join(f'"{c}"' for c in pk_cols)
            sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
            pk_lower = {c.lower() for c in pk_cols}
            set_exprs = [f'"{c}" = EXCLUDED."{c}"' for c in insert_cols if c.lower() not in pk_lower]
            if snapshot_mode or fact_upsert:
                if set_exprs:
                    sql += f" ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
                else:
                    sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
            else:
                sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"

        cur.execute(sql, params)
        inserted = cur.rowcount

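A minimal sketch, with assumed placeholder names, of how the ON CONFLICT tail is chosen for fact tables after this change; the condition is collapsed into one expression, but the outcome matches the branches above.

insert_cols = ["order_id", "amount", "update_time"]  # placeholder columns
pk_cols = ["order_id"]
snapshot_mode, fact_upsert = True, True

pk_sql = ", ".join(f'"{c}"' for c in pk_cols)
set_exprs = [f'"{c}" = EXCLUDED."{c}"' for c in insert_cols if c not in pk_cols]
if (snapshot_mode or fact_upsert) and set_exprs:
    tail = f" ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
else:
    tail = f" ON CONFLICT ({pk_sql}) DO NOTHING"
# tail == ' ON CONFLICT ("order_id") DO UPDATE SET "amount" = EXCLUDED."amount", "update_time" = EXCLUDED."update_time"'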
@@ -2,6 +2,7 @@
"""Manual sample-data ingestion: write into ODS following the table structures in schema_ODS_doc.sql."""
from __future__ import annotations

import hashlib
import json
import os
from datetime import datetime
@@ -252,12 +253,17 @@ class ManualIngestTask(BaseTask):
        except Exception:
            pk_index = None

        has_content_hash = any(c[0].lower() == "content_hash" for c in columns_info)

        col_list = ", ".join(f'"{c}"' for c in columns)
        sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
        if pk_col_db:
            update_cols = [c for c in columns if c != pk_col_db]
            set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
            sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
            if has_content_hash:
                sql_prefix += f' ON CONFLICT ("{pk_col_db}", "content_hash") DO NOTHING'
            else:
                update_cols = [c for c in columns if c != pk_col_db]
                set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
                sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'

        params = []
        now = datetime.now()
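For comparison, a sketch of the two INSERT prefixes this branch can now produce, using placeholder table and column names.

pk = "id"  # placeholder primary key

cols_with_hash = ["id", "name", "content_hash"]
col_list = ", ".join(f'"{c}"' for c in cols_with_hash)
prefix_with_hash = (
    f"INSERT INTO billiards_ods.example ({col_list}) VALUES %s"
    f' ON CONFLICT ("{pk}", "content_hash") DO NOTHING'
)

cols_plain = ["id", "name"]
col_list_plain = ", ".join(f'"{c}"' for c in cols_plain)
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in cols_plain if c != pk)
prefix_plain = (
    f"INSERT INTO billiards_ods.example ({col_list_plain}) VALUES %s"
    f' ON CONFLICT ("{pk}") DO UPDATE SET {set_clause}'
)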
@@ -284,6 +290,12 @@ class ManualIngestTask(BaseTask):
            if pk_col and (pk_val is None or pk_val == ""):
                continue

            content_hash = None
            if has_content_hash:
                hash_record = dict(merged_rec)
                hash_record["fetched_at"] = merged_rec.get("fetched_at", now)
                content_hash = self._compute_content_hash(hash_record, include_fetched_at=True)

            row_vals = []
            for col_name, data_type, udt in columns_info:
                col_lower = col_name.lower()
@@ -296,6 +308,9 @@ class ManualIngestTask(BaseTask):
                if col_lower == "fetched_at":
                    row_vals.append(merged_rec.get(col_name, now))
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue

                value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))

@@ -401,3 +416,48 @@ class ManualIngestTask(BaseTask):
        if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
            return value if isinstance(value, str) else None
        return value

    @staticmethod
    def _hash_default(value):
        if isinstance(value, datetime):
            return value.isoformat()
        return str(value)

    @classmethod
    def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
        exclude = {
            "data",
            "payload",
            "source_file",
            "source_endpoint",
            "content_hash",
            "record_index",
        }
        if not include_fetched_at:
            exclude.add("fetched_at")

        def _strip(value):
            if isinstance(value, dict):
                cleaned = {}
                for k, v in value.items():
                    if isinstance(k, str) and k.lower() in exclude:
                        continue
                    cleaned[k] = _strip(v)
                return cleaned
            if isinstance(value, list):
                return [_strip(v) for v in value]
            return value

        return _strip(record or {})

    @classmethod
    def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
        cleaned = cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at)
        payload = json.dumps(
            cleaned,
            ensure_ascii=False,
            sort_keys=True,
            separators=(",", ":"),
            default=cls._hash_default,
        )
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

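A standalone sketch of the hashing contract these helpers implement: bookkeeping keys are dropped, key order is irrelevant because of sort_keys, and non-JSON values fall back to a string default; the sample record is made up.

import hashlib
import json
from datetime import datetime

record = {"id": 7, "name": "demo", "content_hash": "old", "fetched_at": datetime(2025, 7, 1)}
cleaned = {k: v for k, v in record.items() if k not in {"content_hash", "fetched_at"}}
digest = hashlib.sha256(
    json.dumps(cleaned, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str).encode("utf-8")
).hexdigest()

# Re-ordering the business fields does not change the digest:
reordered = {"name": "demo", "id": 7}
digest2 = hashlib.sha256(
    json.dumps(reordered, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str).encode("utf-8")
).hexdigest()
assert digest == digest2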
@@ -2,6 +2,7 @@
"""ODS ingestion tasks."""
from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -317,6 +318,7 @@ class BaseOdsTask(BaseTask):
        db_json_cols_lower = {
            c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }
        needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)

        col_names = [c[0] for c in cols_info]
        quoted_cols = ", ".join(f'\"{c}\"' for c in col_names)
@@ -330,6 +332,7 @@ class BaseOdsTask(BaseTask):

        params: list[tuple] = []
        skipped = 0
        merged_records: list[dict] = []

        root_site_profile = None
        if isinstance(response_payload, dict):
@@ -345,6 +348,7 @@ class BaseOdsTask(BaseTask):
                continue

            merged_rec = self._merge_record_layers(rec)
            merged_records.append({"raw": rec, "merged": merged_rec})
            if table in {"billiards_ods.recharge_settlements", "billiards_ods.settlement_records"}:
                site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile") or root_site_profile
                if isinstance(site_profile, dict):
@@ -363,9 +367,42 @@ class BaseOdsTask(BaseTask):
                    _fill_missing("siteid", [site_profile.get("siteId"), site_profile.get("id")])
                    _fill_missing("sitename", [site_profile.get("shop_name"), site_profile.get("siteName")])

        has_fetched_at = any(c[0].lower() == "fetched_at" for c in cols_info)
        business_keys = [c for c in pk_cols if str(c).lower() != "content_hash"]
        compare_latest = bool(needs_content_hash and has_fetched_at and business_keys)
        latest_compare_hash: dict[tuple[Any, ...], str | None] = {}
        if compare_latest:
            key_values: list[tuple[Any, ...]] = []
            for item in merged_records:
                merged_rec = item["merged"]
                key = tuple(self._get_value_case_insensitive(merged_rec, k) for k in business_keys)
                if any(v is None or v == "" for v in key):
                    continue
                key_values.append(key)

            if key_values:
                with self.db.conn.cursor() as cur:
                    latest_payloads = self._fetch_latest_payloads(cur, table, business_keys, key_values)
                for key, payload in latest_payloads.items():
                    latest_compare_hash[key] = self._compute_compare_hash_from_payload(payload)

        for item in merged_records:
            rec = item["raw"]
            merged_rec = item["merged"]

            content_hash = None
            compare_hash = None
            if needs_content_hash:
                compare_hash = self._compute_content_hash(merged_rec, include_fetched_at=False)
                hash_record = dict(merged_rec)
                hash_record["fetched_at"] = now
                content_hash = self._compute_content_hash(hash_record, include_fetched_at=True)

            if pk_cols:
                missing_pk = False
                for pk in pk_cols:
                    if str(pk).lower() == "content_hash":
                        continue
                    pk_val = self._get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
@@ -374,6 +411,16 @@ class BaseOdsTask(BaseTask):
                    skipped += 1
                    continue

            if compare_latest and compare_hash is not None:
                key = tuple(self._get_value_case_insensitive(merged_rec, k) for k in business_keys)
                if any(v is None or v == "" for v in key):
                    skipped += 1
                    continue
                last_hash = latest_compare_hash.get(key)
                if last_hash is not None and last_hash == compare_hash:
                    skipped += 1
                    continue

            row_vals: list[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
@@ -389,6 +436,9 @@ class BaseOdsTask(BaseTask):
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue

                value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
@@ -472,6 +522,93 @@ class BaseOdsTask(BaseTask):
            return resolver(spec.endpoint)
        return None

    @staticmethod
    def _hash_default(value):
        if isinstance(value, datetime):
            return value.isoformat()
        return str(value)

    @classmethod
    def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
        exclude = {
            "data",
            "payload",
            "source_file",
            "source_endpoint",
            "content_hash",
            "record_index",
        }
        if not include_fetched_at:
            exclude.add("fetched_at")

        def _strip(value):
            if isinstance(value, dict):
                cleaned = {}
                for k, v in value.items():
                    if isinstance(k, str) and k.lower() in exclude:
                        continue
                    cleaned[k] = _strip(v)
                return cleaned
            if isinstance(value, list):
                return [_strip(v) for v in value]
            return value

        return _strip(record or {})

    @classmethod
    def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
        cleaned = cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at)
        payload = json.dumps(
            cleaned,
            ensure_ascii=False,
            sort_keys=True,
            separators=(",", ":"),
            default=cls._hash_default,
        )
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    @staticmethod
    def _compute_compare_hash_from_payload(payload: Any) -> str | None:
        if payload is None:
            return None
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except Exception:
                return None
        if not isinstance(payload, dict):
            return None
        merged = BaseOdsTask._merge_record_layers(payload)
        return BaseOdsTask._compute_content_hash(merged, include_fetched_at=False)

    @staticmethod
    def _fetch_latest_payloads(cur, table: str, business_keys: Sequence[str], key_values: Sequence[tuple]) -> dict:
        if not business_keys or not key_values:
            return {}
        keys_sql = ", ".join(f'"{k}"' for k in business_keys)
        sql = (
            f"WITH keys({keys_sql}) AS (VALUES %s) "
            f"SELECT DISTINCT ON ({keys_sql}) {keys_sql}, payload "
            f"FROM {table} t JOIN keys k USING ({keys_sql}) "
            f"ORDER BY {keys_sql}, fetched_at DESC NULLS LAST"
        )
        unique_keys = list({tuple(k) for k in key_values})
        execute_values(cur, sql, unique_keys, page_size=500)
        rows = cur.fetchall() or []
        result = {}
        if rows and isinstance(rows[0], dict):
            for r in rows:
                key = tuple(r[k] for k in business_keys)
                result[key] = r.get("payload")
            return result

        key_len = len(business_keys)
        for r in rows:
            key = tuple(r[:key_len])
            payload = r[key_len] if len(r) > key_len else None
            result[key] = payload
        return result


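To show the round trip concretely, this rebuilds the statement _fetch_latest_payloads issues for a single business key; the table and key names are placeholders.

business_keys = ["order_id"]          # placeholder key
table = "billiards_ods.orders"        # placeholder table

keys_sql = ", ".join(f'"{k}"' for k in business_keys)
sql = (
    f"WITH keys({keys_sql}) AS (VALUES %s) "
    f"SELECT DISTINCT ON ({keys_sql}) {keys_sql}, payload "
    f"FROM {table} t JOIN keys k USING ({keys_sql}) "
    f"ORDER BY {keys_sql}, fetched_at DESC NULLS LAST"
)
# execute_values(cur, sql, [("A001",)], page_size=500) expands VALUES %s with the
# de-duplicated key tuples, so one round trip returns the latest stored payload
# per key for the subsequent content-hash comparison.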
def _int_col(name: str, *sources: str, required: bool = False) -> ColumnSpec:
    return ColumnSpec(
