数据库:数据校验、写入等逻辑更新。

This commit is contained in:
Neo
2026-02-01 03:46:16 +08:00
parent 9948000b71
commit 076f5755ca
128 changed files with 494310 additions and 2819 deletions

View File

@@ -2,8 +2,10 @@
"""DWD 装载任务:从 ODS 增量写入 DWD维度 SCD2事实按时间增量"""
from __future__ import annotations
import re
import time
from datetime import datetime
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Dict, Iterable, List, Sequence
from psycopg2.extras import RealDictCursor, execute_batch, execute_values
@@ -77,6 +79,37 @@ class DwdLoadTask(BaseTask):
"billiards_dwd.dwd_assistant_service_log",
}
_NUMERIC_RE = re.compile(r"^[+-]?\d+(?:\.\d+)?$")
_BOOL_STRINGS = {"true", "false", "1", "0", "yes", "no", "y", "n", "t", "f"}
def _strip_scd2_keys(self, pk_cols: Sequence[str]) -> list[str]:
return [c for c in pk_cols if c.lower() not in self.SCD_COLS]
@staticmethod
def _pick_snapshot_order_column(ods_cols: Sequence[str]) -> str | None:
lower_cols = {c.lower() for c in ods_cols}
for candidate in ("fetched_at", "update_time", "create_time"):
if candidate in lower_cols:
return candidate
return None
@staticmethod
def _latest_snapshot_select_sql(
select_cols_sql: str,
ods_table_sql: str,
key_exprs: Sequence[str],
order_col: str | None,
where_sql: str = "",
) -> str:
if key_exprs and order_col:
distinct_on = ", ".join(key_exprs)
order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST'])
return (
f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} "
f"FROM {ods_table_sql} {where_sql} ORDER BY {order_by}"
)
return f"SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}"
# 特殊列映射dwd 列名 -> 源列表达式(可选 CAST
FACT_MAPPINGS: dict[str, list[tuple[str, str, str | None]]] = {
# 维度表(补齐主键/字段差异)
@@ -652,9 +685,8 @@ class DwdLoadTask(BaseTask):
if not pk_cols:
raise ValueError(f"{dwd_table} 未配置主键,无法执行维表合并")
pk_has_scd = any(pk.lower() in self.SCD_COLS for pk in pk_cols)
scd_cols_present = any(c.lower() in self.SCD_COLS for c in dwd_cols)
if scd_cols_present and pk_has_scd:
if scd_cols_present:
return self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
return self._merge_dim_type1_upsert(cur, dwd_table, ods_table, dwd_cols, ods_cols, pk_cols, now)
@@ -701,12 +733,19 @@ class DwdLoadTask(BaseTask):
if not select_exprs:
return 0
# 对于 dim_site 和 dim_site_ex使用 DISTINCT ON 优化查询
# 避免从大表 table_fee_transactions 全表扫描,只获取每个 site_id 的最新记录
if dwd_table in ("billiards_dwd.dim_site", "billiards_dwd.dim_site_ex"):
sql = f"SELECT DISTINCT ON (site_id) {', '.join(select_exprs)} FROM {ods_table_sql} ORDER BY site_id, fetched_at DESC NULLS LAST"
else:
sql = f"SELECT {', '.join(select_exprs)} FROM {ods_table_sql}"
order_col = self._pick_snapshot_order_column(ods_cols)
business_keys = self._strip_scd2_keys(pk_cols)
key_exprs: list[str] = []
for key in business_keys:
lc = key.lower()
if lc in mapping:
src, cast_type = mapping[lc]
key_exprs.append(self._cast_expr(src, cast_type))
elif lc in ods_set:
key_exprs.append(f'"{lc}"')
select_cols_sql = ", ".join(select_exprs)
sql = self._latest_snapshot_select_sql(select_cols_sql, ods_table_sql, key_exprs, order_col)
cur.execute(sql)
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
@@ -784,7 +823,11 @@ class DwdLoadTask(BaseTask):
if not pk_cols:
raise ValueError(f"{dwd_table} 未配置主键,无法执行 SCD2 合并")
mapping = self._build_column_mapping(dwd_table, pk_cols, ods_cols)
business_keys = self._strip_scd2_keys(pk_cols)
if not business_keys:
raise ValueError(f"{dwd_table} primary key only contains SCD2 columns; cannot merge")
mapping = self._build_column_mapping(dwd_table, business_keys, ods_cols)
ods_set = {c.lower() for c in ods_cols}
table_sql = self._format_table(ods_table, "billiards_ods")
# 构造 SELECT 表达式,支持 JSON/expression 映射
@@ -806,7 +849,7 @@ class DwdLoadTask(BaseTask):
select_exprs.append('"categoryboxes" AS "categoryboxes"')
added.add("categoryboxes")
# 主键兜底确保被选出
for pk in pk_cols:
for pk in business_keys:
lc = pk.lower()
if lc not in added:
if lc in mapping:
@@ -819,7 +862,18 @@ class DwdLoadTask(BaseTask):
if not select_exprs:
return 0
sql = f"SELECT {', '.join(select_exprs)} FROM {table_sql}"
order_col = self._pick_snapshot_order_column(ods_cols)
key_exprs: list[str] = []
for key in business_keys:
lc = key.lower()
if lc in mapping:
src, cast_type = mapping[lc]
key_exprs.append(self._cast_expr(src, cast_type))
elif lc in ods_set:
key_exprs.append(f'"{lc}"')
select_cols_sql = ", ".join(select_exprs)
sql = self._latest_snapshot_select_sql(select_cols_sql, table_sql, key_exprs, order_col)
cur.execute(sql)
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
@@ -842,11 +896,11 @@ class DwdLoadTask(BaseTask):
value = row.get(src.lower())
mapped_row[lc] = value
pk_key = tuple(mapped_row.get(pk) for pk in pk_cols)
pk_key = tuple(mapped_row.get(pk) for pk in business_keys)
if pk_key in seen_pk:
continue
if any(v is None for v in pk_key):
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(business_keys, pk_key)))
continue
seen_pk.add(pk_key)
src_rows_by_pk[pk_key] = mapped_row
@@ -862,7 +916,7 @@ class DwdLoadTask(BaseTask):
current_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
for r in current_rows:
rr = {k.lower(): v for k, v in r.items()}
pk_key = tuple(rr.get(pk) for pk in pk_cols)
pk_key = tuple(rr.get(pk) for pk in business_keys)
current_by_pk[pk_key] = rr
# 计算需要关闭/插入的主键集合
@@ -881,7 +935,7 @@ class DwdLoadTask(BaseTask):
# 先关闭旧版本(同一批次统一 end_time
if to_close:
self._close_current_dim_bulk(cur, dwd_table, pk_cols, to_close, now)
self._close_current_dim_bulk(cur, dwd_table, business_keys, to_close, now)
# 批量插入新版本
if to_insert:
@@ -1031,10 +1085,105 @@ class DwdLoadTask(BaseTask):
lc = col.lower()
if lc in self.SCD_COLS:
continue
if current.get(lc) != incoming.get(lc):
if not self._values_equal(current.get(lc), incoming.get(lc)):
return True
return False
def _values_equal(self, current_val: Any, incoming_val: Any) -> bool:
    """Compare two column values, tolerating common representation mismatches.

    Empty/whitespace-only strings are treated as NULL, naive datetimes are
    localized before comparison, and boolean-ish / numeric-ish values are
    compared by coerced value rather than raw type (e.g. "1" vs 1,
    "true" vs True, "1.50" vs Decimal("1.5")).
    """
    left = self._normalize_empty(current_val)
    right = self._normalize_empty(incoming_val)
    if left is None and right is None:
        return True
    # Temporal values: align naive vs. timezone-aware before comparing.
    if isinstance(left, (datetime, date)) or isinstance(right, (datetime, date)):
        return self._normalize_datetime(left) == self._normalize_datetime(right)
    # Boolean-ish values: compare coerced booleans when both coerce cleanly.
    if self._looks_bool(left) or self._looks_bool(right):
        left_bool = self._coerce_bool(left)
        right_bool = self._coerce_bool(right)
        if left_bool is not None and right_bool is not None:
            return left_bool == right_bool
    # Numeric-ish values: compare as Decimal when both coerce cleanly.
    if self._looks_numeric(left) or self._looks_numeric(right):
        left_num = self._coerce_numeric(left)
        right_num = self._coerce_numeric(right)
        if left_num is not None and right_num is not None:
            return left_num == right_num
    # Fall back to plain equality on the normalized values.
    return left == right
def _normalize_empty(self, value: Any) -> Any:
if isinstance(value, str):
stripped = value.strip()
return None if stripped == "" else stripped
return value
def _normalize_datetime(self, value: Any) -> Any:
if value is None:
return None
if isinstance(value, date) and not isinstance(value, datetime):
value = datetime.combine(value, datetime.min.time())
if not isinstance(value, datetime):
return value
if value.tzinfo is None:
return value.replace(tzinfo=self.tz)
return value.astimezone(self.tz)
def _looks_numeric(self, value: Any) -> bool:
if isinstance(value, (int, float, Decimal)) and not isinstance(value, bool):
return True
if isinstance(value, str):
return bool(self._NUMERIC_RE.match(value.strip()))
return False
def _coerce_numeric(self, value: Any) -> Decimal | None:
value = self._normalize_empty(value)
if value is None:
return None
if isinstance(value, bool):
return Decimal(int(value))
if isinstance(value, (int, float, Decimal)):
try:
return Decimal(str(value))
except InvalidOperation:
return None
if isinstance(value, str):
s = value.strip()
if not self._NUMERIC_RE.match(s):
return None
try:
return Decimal(s)
except InvalidOperation:
return None
return None
def _looks_bool(self, value: Any) -> bool:
if isinstance(value, bool):
return True
if isinstance(value, str):
return value.strip().lower() in self._BOOL_STRINGS
return False
def _coerce_bool(self, value: Any) -> bool | None:
value = self._normalize_empty(value)
if value is None:
return None
if isinstance(value, bool):
return value
if isinstance(value, (int, Decimal)) and not isinstance(value, bool):
return bool(int(value))
if isinstance(value, str):
s = value.strip().lower()
if s in {"true", "1", "yes", "y", "t"}:
return True
if s in {"false", "0", "no", "n", "f"}:
return False
return None
def _merge_fact_increment(
self,
cur,
@@ -1052,6 +1201,9 @@ class DwdLoadTask(BaseTask):
mapping: Dict[str, tuple[str, str | None]] = {
dst.lower(): (src, cast_type) for dst, src, cast_type in mapping_entries
}
ods_set = {c.lower() for c in ods_cols}
snapshot_mode = "content_hash" in ods_set
fact_upsert = bool(self.config.get("dwd.fact_upsert", True))
mapping_dest = [dst for dst, _, _ in mapping_entries]
insert_cols: List[str] = list(mapping_dest)
@@ -1064,7 +1216,6 @@ class DwdLoadTask(BaseTask):
insert_cols.append(col)
pk_cols = self._get_primary_keys(cur, dwd_table)
ods_set = {c.lower() for c in ods_cols}
existing_lower = [c.lower() for c in insert_cols]
for pk in pk_cols:
pk_lower = pk.lower()
@@ -1092,7 +1243,11 @@ class DwdLoadTask(BaseTask):
self.logger.warning("跳过 %s:未找到可插入的列", dwd_table)
return 0
order_col = self._pick_order_column(dwd_table, dwd_cols, ods_cols)
order_col = (
self._pick_snapshot_order_column(ods_cols)
if snapshot_mode
else self._pick_order_column(dwd_table, dwd_cols, ods_cols)
)
where_sql = ""
params: List[Any] = []
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
@@ -1122,12 +1277,40 @@ class DwdLoadTask(BaseTask):
select_cols_sql = ", ".join(select_exprs)
insert_cols_sql = ", ".join(f'"{c}"' for c in insert_cols)
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
if snapshot_mode and pk_cols:
key_exprs: list[str] = []
for pk in pk_cols:
pk_lower = pk.lower()
if pk_lower in mapping:
src, cast_type = mapping[pk_lower]
key_exprs.append(self._cast_expr(src, cast_type))
elif pk_lower in ods_set:
key_exprs.append(f'"{pk_lower}"')
elif "id" in ods_set:
key_exprs.append('"id"')
select_sql = self._latest_snapshot_select_sql(
select_cols_sql,
ods_table_sql,
key_exprs,
order_col,
where_sql,
)
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) {select_sql}'
else:
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
pk_cols = self._get_primary_keys(cur, dwd_table)
if pk_cols:
pk_sql = ", ".join(f'"{c}"' for c in pk_cols)
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
pk_lower = {c.lower() for c in pk_cols}
set_exprs = [f'"{c}" = EXCLUDED."{c}"' for c in insert_cols if c.lower() not in pk_lower]
if snapshot_mode or fact_upsert:
if set_exprs:
sql += f" ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
else:
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
else:
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
cur.execute(sql, params)
inserted = cur.rowcount