数据库 数据校验写入等逻辑更新。
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
"""DWD 装载任务:从 ODS 增量写入 DWD(维度 SCD2,事实按时间增量)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Any, Dict, Iterable, List, Sequence
|
||||
|
||||
from psycopg2.extras import RealDictCursor, execute_batch, execute_values
|
||||
@@ -77,6 +79,37 @@ class DwdLoadTask(BaseTask):
|
||||
"billiards_dwd.dwd_assistant_service_log",
|
||||
}
|
||||
|
||||
_NUMERIC_RE = re.compile(r"^[+-]?\d+(?:\.\d+)?$")
|
||||
_BOOL_STRINGS = {"true", "false", "1", "0", "yes", "no", "y", "n", "t", "f"}
|
||||
|
||||
def _strip_scd2_keys(self, pk_cols: Sequence[str]) -> list[str]:
|
||||
return [c for c in pk_cols if c.lower() not in self.SCD_COLS]
|
||||
|
||||
@staticmethod
|
||||
def _pick_snapshot_order_column(ods_cols: Sequence[str]) -> str | None:
|
||||
lower_cols = {c.lower() for c in ods_cols}
|
||||
for candidate in ("fetched_at", "update_time", "create_time"):
|
||||
if candidate in lower_cols:
|
||||
return candidate
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _latest_snapshot_select_sql(
|
||||
select_cols_sql: str,
|
||||
ods_table_sql: str,
|
||||
key_exprs: Sequence[str],
|
||||
order_col: str | None,
|
||||
where_sql: str = "",
|
||||
) -> str:
|
||||
if key_exprs and order_col:
|
||||
distinct_on = ", ".join(key_exprs)
|
||||
order_by = ", ".join([*key_exprs, f'"{order_col}" DESC NULLS LAST'])
|
||||
return (
|
||||
f"SELECT DISTINCT ON ({distinct_on}) {select_cols_sql} "
|
||||
f"FROM {ods_table_sql} {where_sql} ORDER BY {order_by}"
|
||||
)
|
||||
return f"SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}"
|
||||
|
||||
# 特殊列映射:dwd 列名 -> 源列表达式(可选 CAST)
|
||||
FACT_MAPPINGS: dict[str, list[tuple[str, str, str | None]]] = {
|
||||
# 维度表(补齐主键/字段差异)
|
||||
@@ -652,9 +685,8 @@ class DwdLoadTask(BaseTask):
|
||||
if not pk_cols:
|
||||
raise ValueError(f"{dwd_table} 未配置主键,无法执行维表合并")
|
||||
|
||||
pk_has_scd = any(pk.lower() in self.SCD_COLS for pk in pk_cols)
|
||||
scd_cols_present = any(c.lower() in self.SCD_COLS for c in dwd_cols)
|
||||
if scd_cols_present and pk_has_scd:
|
||||
if scd_cols_present:
|
||||
return self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
|
||||
return self._merge_dim_type1_upsert(cur, dwd_table, ods_table, dwd_cols, ods_cols, pk_cols, now)
|
||||
|
||||
@@ -701,12 +733,19 @@ class DwdLoadTask(BaseTask):
|
||||
if not select_exprs:
|
||||
return 0
|
||||
|
||||
# 对于 dim_site 和 dim_site_ex,使用 DISTINCT ON 优化查询
|
||||
# 避免从大表 table_fee_transactions 全表扫描,只获取每个 site_id 的最新记录
|
||||
if dwd_table in ("billiards_dwd.dim_site", "billiards_dwd.dim_site_ex"):
|
||||
sql = f"SELECT DISTINCT ON (site_id) {', '.join(select_exprs)} FROM {ods_table_sql} ORDER BY site_id, fetched_at DESC NULLS LAST"
|
||||
else:
|
||||
sql = f"SELECT {', '.join(select_exprs)} FROM {ods_table_sql}"
|
||||
order_col = self._pick_snapshot_order_column(ods_cols)
|
||||
business_keys = self._strip_scd2_keys(pk_cols)
|
||||
key_exprs: list[str] = []
|
||||
for key in business_keys:
|
||||
lc = key.lower()
|
||||
if lc in mapping:
|
||||
src, cast_type = mapping[lc]
|
||||
key_exprs.append(self._cast_expr(src, cast_type))
|
||||
elif lc in ods_set:
|
||||
key_exprs.append(f'"{lc}"')
|
||||
|
||||
select_cols_sql = ", ".join(select_exprs)
|
||||
sql = self._latest_snapshot_select_sql(select_cols_sql, ods_table_sql, key_exprs, order_col)
|
||||
|
||||
cur.execute(sql)
|
||||
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
|
||||
@@ -784,7 +823,11 @@ class DwdLoadTask(BaseTask):
|
||||
if not pk_cols:
|
||||
raise ValueError(f"{dwd_table} 未配置主键,无法执行 SCD2 合并")
|
||||
|
||||
mapping = self._build_column_mapping(dwd_table, pk_cols, ods_cols)
|
||||
business_keys = self._strip_scd2_keys(pk_cols)
|
||||
if not business_keys:
|
||||
raise ValueError(f"{dwd_table} primary key only contains SCD2 columns; cannot merge")
|
||||
|
||||
mapping = self._build_column_mapping(dwd_table, business_keys, ods_cols)
|
||||
ods_set = {c.lower() for c in ods_cols}
|
||||
table_sql = self._format_table(ods_table, "billiards_ods")
|
||||
# 构造 SELECT 表达式,支持 JSON/expression 映射
|
||||
@@ -806,7 +849,7 @@ class DwdLoadTask(BaseTask):
|
||||
select_exprs.append('"categoryboxes" AS "categoryboxes"')
|
||||
added.add("categoryboxes")
|
||||
# 主键兜底确保被选出
|
||||
for pk in pk_cols:
|
||||
for pk in business_keys:
|
||||
lc = pk.lower()
|
||||
if lc not in added:
|
||||
if lc in mapping:
|
||||
@@ -819,7 +862,18 @@ class DwdLoadTask(BaseTask):
|
||||
if not select_exprs:
|
||||
return 0
|
||||
|
||||
sql = f"SELECT {', '.join(select_exprs)} FROM {table_sql}"
|
||||
order_col = self._pick_snapshot_order_column(ods_cols)
|
||||
key_exprs: list[str] = []
|
||||
for key in business_keys:
|
||||
lc = key.lower()
|
||||
if lc in mapping:
|
||||
src, cast_type = mapping[lc]
|
||||
key_exprs.append(self._cast_expr(src, cast_type))
|
||||
elif lc in ods_set:
|
||||
key_exprs.append(f'"{lc}"')
|
||||
|
||||
select_cols_sql = ", ".join(select_exprs)
|
||||
sql = self._latest_snapshot_select_sql(select_cols_sql, table_sql, key_exprs, order_col)
|
||||
cur.execute(sql)
|
||||
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
|
||||
|
||||
@@ -842,11 +896,11 @@ class DwdLoadTask(BaseTask):
|
||||
value = row.get(src.lower())
|
||||
mapped_row[lc] = value
|
||||
|
||||
pk_key = tuple(mapped_row.get(pk) for pk in pk_cols)
|
||||
pk_key = tuple(mapped_row.get(pk) for pk in business_keys)
|
||||
if pk_key in seen_pk:
|
||||
continue
|
||||
if any(v is None for v in pk_key):
|
||||
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
|
||||
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(business_keys, pk_key)))
|
||||
continue
|
||||
seen_pk.add(pk_key)
|
||||
src_rows_by_pk[pk_key] = mapped_row
|
||||
@@ -862,7 +916,7 @@ class DwdLoadTask(BaseTask):
|
||||
current_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
|
||||
for r in current_rows:
|
||||
rr = {k.lower(): v for k, v in r.items()}
|
||||
pk_key = tuple(rr.get(pk) for pk in pk_cols)
|
||||
pk_key = tuple(rr.get(pk) for pk in business_keys)
|
||||
current_by_pk[pk_key] = rr
|
||||
|
||||
# 计算需要关闭/插入的主键集合
|
||||
@@ -881,7 +935,7 @@ class DwdLoadTask(BaseTask):
|
||||
|
||||
# 先关闭旧版本(同一批次统一 end_time)
|
||||
if to_close:
|
||||
self._close_current_dim_bulk(cur, dwd_table, pk_cols, to_close, now)
|
||||
self._close_current_dim_bulk(cur, dwd_table, business_keys, to_close, now)
|
||||
|
||||
# 批量插入新版本
|
||||
if to_insert:
|
||||
@@ -1031,10 +1085,105 @@ class DwdLoadTask(BaseTask):
|
||||
lc = col.lower()
|
||||
if lc in self.SCD_COLS:
|
||||
continue
|
||||
if current.get(lc) != incoming.get(lc):
|
||||
if not self._values_equal(current.get(lc), incoming.get(lc)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _values_equal(self, current_val: Any, incoming_val: Any) -> bool:
|
||||
"""Normalize common type mismatches (numeric/text, naive/aware datetime) before compare."""
|
||||
current_val = self._normalize_empty(current_val)
|
||||
incoming_val = self._normalize_empty(incoming_val)
|
||||
if current_val is None and incoming_val is None:
|
||||
return True
|
||||
|
||||
# Datetime normalization (naive vs aware)
|
||||
if isinstance(current_val, (datetime, date)) or isinstance(incoming_val, (datetime, date)):
|
||||
return self._normalize_datetime(current_val) == self._normalize_datetime(incoming_val)
|
||||
|
||||
# Boolean normalization
|
||||
if self._looks_bool(current_val) or self._looks_bool(incoming_val):
|
||||
cur_bool = self._coerce_bool(current_val)
|
||||
inc_bool = self._coerce_bool(incoming_val)
|
||||
if cur_bool is not None and inc_bool is not None:
|
||||
return cur_bool == inc_bool
|
||||
|
||||
# Numeric normalization (string vs numeric)
|
||||
if self._looks_numeric(current_val) or self._looks_numeric(incoming_val):
|
||||
cur_num = self._coerce_numeric(current_val)
|
||||
inc_num = self._coerce_numeric(incoming_val)
|
||||
if cur_num is not None and inc_num is not None:
|
||||
return cur_num == inc_num
|
||||
|
||||
return current_val == incoming_val
|
||||
|
||||
def _normalize_empty(self, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return None if stripped == "" else stripped
|
||||
return value
|
||||
|
||||
def _normalize_datetime(self, value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, date) and not isinstance(value, datetime):
|
||||
value = datetime.combine(value, datetime.min.time())
|
||||
if not isinstance(value, datetime):
|
||||
return value
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=self.tz)
|
||||
return value.astimezone(self.tz)
|
||||
|
||||
def _looks_numeric(self, value: Any) -> bool:
|
||||
if isinstance(value, (int, float, Decimal)) and not isinstance(value, bool):
|
||||
return True
|
||||
if isinstance(value, str):
|
||||
return bool(self._NUMERIC_RE.match(value.strip()))
|
||||
return False
|
||||
|
||||
def _coerce_numeric(self, value: Any) -> Decimal | None:
|
||||
value = self._normalize_empty(value)
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return Decimal(int(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
try:
|
||||
return Decimal(str(value))
|
||||
except InvalidOperation:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
s = value.strip()
|
||||
if not self._NUMERIC_RE.match(s):
|
||||
return None
|
||||
try:
|
||||
return Decimal(s)
|
||||
except InvalidOperation:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _looks_bool(self, value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return True
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in self._BOOL_STRINGS
|
||||
return False
|
||||
|
||||
def _coerce_bool(self, value: Any) -> bool | None:
|
||||
value = self._normalize_empty(value)
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, Decimal)) and not isinstance(value, bool):
|
||||
return bool(int(value))
|
||||
if isinstance(value, str):
|
||||
s = value.strip().lower()
|
||||
if s in {"true", "1", "yes", "y", "t"}:
|
||||
return True
|
||||
if s in {"false", "0", "no", "n", "f"}:
|
||||
return False
|
||||
return None
|
||||
|
||||
def _merge_fact_increment(
|
||||
self,
|
||||
cur,
|
||||
@@ -1052,6 +1201,9 @@ class DwdLoadTask(BaseTask):
|
||||
mapping: Dict[str, tuple[str, str | None]] = {
|
||||
dst.lower(): (src, cast_type) for dst, src, cast_type in mapping_entries
|
||||
}
|
||||
ods_set = {c.lower() for c in ods_cols}
|
||||
snapshot_mode = "content_hash" in ods_set
|
||||
fact_upsert = bool(self.config.get("dwd.fact_upsert", True))
|
||||
|
||||
mapping_dest = [dst for dst, _, _ in mapping_entries]
|
||||
insert_cols: List[str] = list(mapping_dest)
|
||||
@@ -1064,7 +1216,6 @@ class DwdLoadTask(BaseTask):
|
||||
insert_cols.append(col)
|
||||
|
||||
pk_cols = self._get_primary_keys(cur, dwd_table)
|
||||
ods_set = {c.lower() for c in ods_cols}
|
||||
existing_lower = [c.lower() for c in insert_cols]
|
||||
for pk in pk_cols:
|
||||
pk_lower = pk.lower()
|
||||
@@ -1092,7 +1243,11 @@ class DwdLoadTask(BaseTask):
|
||||
self.logger.warning("跳过 %s:未找到可插入的列", dwd_table)
|
||||
return 0
|
||||
|
||||
order_col = self._pick_order_column(dwd_table, dwd_cols, ods_cols)
|
||||
order_col = (
|
||||
self._pick_snapshot_order_column(ods_cols)
|
||||
if snapshot_mode
|
||||
else self._pick_order_column(dwd_table, dwd_cols, ods_cols)
|
||||
)
|
||||
where_sql = ""
|
||||
params: List[Any] = []
|
||||
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
|
||||
@@ -1122,12 +1277,40 @@ class DwdLoadTask(BaseTask):
|
||||
|
||||
select_cols_sql = ", ".join(select_exprs)
|
||||
insert_cols_sql = ", ".join(f'"{c}"' for c in insert_cols)
|
||||
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
|
||||
if snapshot_mode and pk_cols:
|
||||
key_exprs: list[str] = []
|
||||
for pk in pk_cols:
|
||||
pk_lower = pk.lower()
|
||||
if pk_lower in mapping:
|
||||
src, cast_type = mapping[pk_lower]
|
||||
key_exprs.append(self._cast_expr(src, cast_type))
|
||||
elif pk_lower in ods_set:
|
||||
key_exprs.append(f'"{pk_lower}"')
|
||||
elif "id" in ods_set:
|
||||
key_exprs.append('"id"')
|
||||
select_sql = self._latest_snapshot_select_sql(
|
||||
select_cols_sql,
|
||||
ods_table_sql,
|
||||
key_exprs,
|
||||
order_col,
|
||||
where_sql,
|
||||
)
|
||||
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) {select_sql}'
|
||||
else:
|
||||
sql = f'INSERT INTO {dwd_table_sql} ({insert_cols_sql}) SELECT {select_cols_sql} FROM {ods_table_sql} {where_sql}'
|
||||
|
||||
pk_cols = self._get_primary_keys(cur, dwd_table)
|
||||
if pk_cols:
|
||||
pk_sql = ", ".join(f'"{c}"' for c in pk_cols)
|
||||
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
|
||||
pk_lower = {c.lower() for c in pk_cols}
|
||||
set_exprs = [f'"{c}" = EXCLUDED."{c}"' for c in insert_cols if c.lower() not in pk_lower]
|
||||
if snapshot_mode or fact_upsert:
|
||||
if set_exprs:
|
||||
sql += f" ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
|
||||
else:
|
||||
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
|
||||
else:
|
||||
sql += f" ON CONFLICT ({pk_sql}) DO NOTHING"
|
||||
|
||||
cur.execute(sql, params)
|
||||
inserted = cur.rowcount
|
||||
|
||||
Reference in New Issue
Block a user