# -*- coding: utf-8 -*-
"""Batch verifier for the DWD layer.

Verification logic: compare ODS source data against DWD table data.
- Dimension tables: SCD2 mode; compare against the current version.
- Fact tables: primary-key comparison; backfill gaps via batched UPSERT.
"""
import hashlib
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple
from psycopg2.extras import Json, execute_values
from .base_verifier import BaseVerifier, VerificationFetchError
from tasks.dwd.dwd_load_task import DwdLoadTask


class DwdVerifier(BaseVerifier):
    """Verifier for the DWD layer."""

    def __init__(
        self,
        db_connection: Any,
        logger: Optional[logging.Logger] = None,
        config: Any = None,
    ):
        """
        Initialize the DWD verifier.

        Args:
            db_connection: Database connection wrapper (exposes ``.conn`` and ``.commit()``).
            logger: Logger instance.
            config: Optional configuration object with a ``.get(key, default)`` API,
                used for the fact-table UPSERT tuning knobs.
        """
        super().__init__(db_connection, logger)
        self._table_config = self._load_table_config()
        self.config = config

    @property
    def layer_name(self) -> str:
        return "DWD"

    def _load_table_config(self) -> Dict[str, dict]:
        """Load per-table DWD verification config derived from DwdLoadTask.TABLE_MAP."""
        # ODS primary-key column name per ODS table (usually "id"; special
        # cases are configured individually).
        # Format: ods_table -> ods_pk_column
        ODS_PK_MAP = {
            "table_fee_transactions": "id",
            "site_tables_master": "id",
            "assistant_accounts_master": "id",
            "member_profiles": "id",
            "member_stored_value_cards": "id",
            "tenant_goods_master": "id",
            "store_goods_master": "id",
            "stock_goods_category_tree": "id",
            "group_buy_packages": "id",
            "settlement_records": "id",
            "table_fee_discount_records": "id",
            "store_goods_sales_records": "id",
            "assistant_service_records": "id",
            "assistant_cancellation_records": "id",
            "member_balance_changes": "id",
            "group_buy_redemption_records": "id",
            "platform_coupon_redemption_records": "id",
            "recharge_settlements": "id",  # NOTE: ODS column is id, but the DWD column is recharge_order_id
            "payment_transactions": "id",
            "refund_transactions": "id",
            "goods_stock_summary": "sitegoodsid",  # special: primary key is not id
            "settlement_ticket_details": "ordersettleid",  # special: primary key is not id
        }
        # Special ODS primary-key overrides, keyed by DWD table name.
        # Format: dwd_table -> ods_pk_columns
        ODS_PK_OVERRIDE = {
            "dim_site": ["site_id"],
            "dim_site_ex": ["site_id"],
        }
        # Mapping of ODS primary-key column names to their semantic DWD names
        # (the ODS "id" maps to a descriptive DWD column).
        # Format: dwd_table -> {ods_column: dwd_column}
        ODS_TO_DWD_PK_MAP = {
            # Dimension tables (tables with complex mappings get an empty
            # dict so backfill is skipped for them).
            "dim_site": {"site_id": "site_id"},
            "dim_site_ex": {"site_id": "site_id"},
            "dim_table": {"id": "table_id"},
            "dim_table_ex": {"id": "table_id"},
            "dim_assistant": {"id": "assistant_id"},
            "dim_assistant_ex": {"id": "assistant_id"},
            "dim_member": {"id": "member_id"},
            "dim_member_ex": {"id": "member_id"},
            "dim_member_card_account": {"id": "member_card_id"},
            "dim_member_card_account_ex": {"id": "member_card_id"},
            "dim_tenant_goods": {"id": "tenant_goods_id"},
            "dim_tenant_goods_ex": {"id": "tenant_goods_id"},
            "dim_store_goods": {"id": "site_goods_id"},
            "dim_store_goods_ex": {"id": "site_goods_id"},
            "dim_goods_category": {"id": "category_id"},
            "dim_groupbuy_package": {"id": "groupbuy_package_id"},
            "dim_groupbuy_package_ex": {"id": "groupbuy_package_id"},
            # Fact tables
            "dwd_settlement_head": {"id": "order_settle_id"},
            "dwd_settlement_head_ex": {"id": "order_settle_id"},
            "dwd_table_fee_log": {"id": "table_fee_log_id"},
            "dwd_table_fee_log_ex": {"id": "table_fee_log_id"},
            "dwd_table_fee_adjust": {"id": "table_fee_adjust_id"},
            "dwd_table_fee_adjust_ex": {"id": "table_fee_adjust_id"},
            "dwd_store_goods_sale": {"id": "store_goods_sale_id"},
            "dwd_store_goods_sale_ex": {"id": "store_goods_sale_id"},
            "dwd_assistant_service_log": {"id": "assistant_service_id"},
            "dwd_assistant_service_log_ex": {"id": "assistant_service_id"},
            "dwd_assistant_trash_event": {"id": "assistant_trash_event_id"},
            "dwd_assistant_trash_event_ex": {"id": "assistant_trash_event_id"},
            "dwd_member_balance_change": {"id": "balance_change_id"},
            "dwd_member_balance_change_ex": {"id": "balance_change_id"},
            "dwd_groupbuy_redemption": {"id": "redemption_id"},
            "dwd_groupbuy_redemption_ex": {"id": "redemption_id"},
            "dwd_platform_coupon_redemption": {"id": "platform_coupon_redemption_id"},
            "dwd_platform_coupon_redemption_ex": {"id": "platform_coupon_redemption_id"},
            "dwd_recharge_order": {"id": "recharge_order_id"},
            "dwd_recharge_order_ex": {"id": "recharge_order_id"},
            "dwd_payment": {"id": "payment_id"},
            "dwd_refund": {"id": "refund_id"},
            "dwd_refund_ex": {"id": "refund_id"},
        }
        # Business time column per DWD fact table (used for window filtering).
        DWD_TIME_COL_MAP = {
            "dwd_settlement_head": "pay_time",
            "dwd_settlement_head_ex": "pay_time",
            "dwd_table_fee_log": "start_use_time",
            "dwd_table_fee_log_ex": "start_use_time",
            "dwd_table_fee_adjust": "create_time",
            "dwd_table_fee_adjust_ex": "create_time",
            "dwd_store_goods_sale": "create_time",
            "dwd_store_goods_sale_ex": "create_time",
            "dwd_assistant_service_log": "start_use_time",
            "dwd_assistant_service_log_ex": "start_use_time",
            "dwd_assistant_trash_event": "create_time",
            "dwd_assistant_trash_event_ex": "create_time",
            "dwd_member_balance_change": "create_time",
            "dwd_member_balance_change_ex": "create_time",
            "dwd_groupbuy_redemption": "create_time",
            "dwd_groupbuy_redemption_ex": "create_time",
            "dwd_platform_coupon_redemption": "create_time",
            "dwd_platform_coupon_redemption_ex": "create_time",
            "dwd_recharge_order": "pay_time",
            "dwd_recharge_order_ex": "pay_time",
            "dwd_payment": "pay_time",
            "dwd_refund": "create_time",
            "dwd_refund_ex": "create_time",
        }
        scd2_cols = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}
        try:
            # Import here (in addition to the module-level import) so that
            # environments where the task package is unavailable fall back
            # to the database-driven discovery path below.
            from tasks.dwd.dwd_load_task import DwdLoadTask
            config = {}
            for full_dwd_table, full_ods_table in DwdLoadTask.TABLE_MAP.items():
                # Strip the schema prefix, if any, to get the bare table name.
                if "." in full_dwd_table:
                    dwd_table = full_dwd_table.split(".")[-1]
                else:
                    dwd_table = full_dwd_table
                if "." in full_ods_table:
                    ods_table = full_ods_table.split(".")[-1]
                else:
                    ods_table = full_ods_table
                is_dimension = dwd_table.startswith("dim_")
                # Resolve the ODS primary-key column name (used when querying ODS).
                ods_pk_column = ODS_PK_MAP.get(ods_table, "id")
                ods_pk_columns = ODS_PK_OVERRIDE.get(dwd_table)
                if not ods_pk_columns:
                    ods_pk_columns = [ods_pk_column]
                # Resolve the DWD time column (used for window filtering).
                time_column = DWD_TIME_COL_MAP.get(dwd_table, "fetched_at")
                # Dimension tables always use scd2_start_time.
                if is_dimension:
                    time_column = "scd2_start_time"
                # If no PK mapping is configured and the business PK shares
                # its name with the ODS PK, infer an identity mapping.
                pk_columns = self._get_pk_from_db(dwd_table)
                business_pk_cols = [c for c in pk_columns if c.lower() not in scd2_cols]
                ods_to_dwd_map = ODS_TO_DWD_PK_MAP.get(dwd_table, {})
                if not ods_to_dwd_map and business_pk_cols:
                    if all(pk in ods_pk_columns for pk in business_pk_cols):
                        ods_to_dwd_map = {pk: pk for pk in business_pk_cols}
                config[dwd_table] = {
                    "full_dwd_table": full_dwd_table,
                    "ods_table": ods_table,
                    "full_ods_table": full_ods_table,
                    "is_dimension": is_dimension,
                    "pk_columns": pk_columns,  # DWD table primary key
                    "ods_pk_columns": ods_pk_columns,  # ODS primary key (for querying ODS)
                    "ods_to_dwd_pk_map": ods_to_dwd_map,  # ODS -> DWD primary-key mapping
                    "time_column": time_column,  # DWD time column
                    "ods_time_column": "fetched_at",  # ODS time column
                }
            return config
        except (ImportError, AttributeError) as e:
            self.logger.warning("无法加载 DWD 表映射,使用数据库查询: %s", e)
            return {}

    def _get_pk_from_db(self, table: str) -> List[str]:
        """Fetch the table's primary-key columns from information_schema.

        Falls back to ``["id"]`` when no primary key is found or the query fails.
        """
        sql = """
            SELECT kcu.column_name
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            WHERE tc.constraint_type = 'PRIMARY KEY'
                AND tc.table_schema = 'dwd'
                AND tc.table_name = %s
            ORDER BY kcu.ordinal_position
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (table,))
                result = [row[0] for row in cur.fetchall()]
                return result if result else ["id"]
        except Exception as e:
            self.logger.warning("获取 DWD 主键失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return ["id"]

    def get_tables(self) -> List[str]:
        """Return the list of DWD tables to verify.

        Prefers the preloaded table config; falls back to listing the
        ``dwd`` schema from information_schema.
        """
        if self._table_config:
            return list(self._table_config.keys())
        sql = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'dwd'
                AND table_type = 'BASE TABLE'
            ORDER BY table_name
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql)
                return [row[0] for row in cur.fetchall()]
        except Exception as e:
            self.logger.warning("获取 DWD 表列表失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return []

    def get_dimension_tables(self) -> List[str]:
        """Return the dimension tables (``dim_*``)."""
        return [t for t in self.get_tables() if t.startswith("dim_")]

    def get_fact_tables(self) -> List[str]:
        """Return the fact tables (``dwd_*`` or ``fact_*``)."""
        return [t for t in self.get_tables() if t.startswith("dwd_") or t.startswith("fact_")]

    def get_primary_keys(self, table: str) -> List[str]:
        """Return the table's primary-key columns (config first, then DB lookup)."""
        if table in self._table_config:
            pk_cols = self._table_config[table].get("pk_columns", [])
            if pk_cols:
                return pk_cols
        # Fall back to a database lookup when the config has no/empty PK.
        return self._get_pk_from_db(table)

    def get_time_column(self, table: str) -> Optional[str]:
        """Return the table's time column for window filtering."""
        if table in self._table_config:
            return self._table_config[table].get("time_column", "create_time")
        # Probe the table's schema for a common time column.
        common_time_cols = ["create_time", "pay_time", "start_time", "modify_time", "fetched_at"]
        try:
            sql = """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_schema = 'dwd'
                    AND table_name = %s
                    AND column_name = ANY(%s)
            """
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (table, common_time_cols))
                rows = cur.fetchall()
                if rows:
                    return rows[0][0]
        except Exception:
            pass
        return "create_time"

    def get_ods_table(self, dwd_table: str) -> Optional[str]:
        """Return the ODS source table for the given DWD table.

        Uses the preloaded config; otherwise infers the name by prefix
        substitution (``dim_``/``dwd_`` -> ``ods_``).
        """
        if dwd_table in self._table_config:
            return self._table_config[dwd_table].get("ods_table")
        # Infer the ODS table name.
        if dwd_table.startswith("dim_"):
            ods_name = dwd_table.replace("dim_", "ods_")
        elif dwd_table.startswith("dwd_"):
            ods_name = dwd_table.replace("dwd_", "ods_")
        else:
            ods_name = f"ods_{dwd_table}"
        return ods_name

    def is_dimension_table(self, table: str) -> bool:
        """Return True if the table is a dimension table."""
        if table in self._table_config:
            return self._table_config[table].get("is_dimension", False)
        return table.startswith("dim_")

    def get_ods_pk_columns(self, table: str) -> List[str]:
        """Return the ODS primary-key column names (used to query ODS)."""
        if table in self._table_config:
            return self._table_config[table].get("ods_pk_columns", ["id"])
        return ["id"]

    def get_ods_time_column(self, table: str) -> str:
        """Return the ODS time column name."""
        if table in self._table_config:
            return self._table_config[table].get("ods_time_column", "fetched_at")
        return "fetched_at"

    def get_ods_to_dwd_pk_map(self, table: str) -> Dict[str, str]:
        """Return the ODS-to-DWD primary-key column mapping.

        Returns a ``{ods_column: dwd_column}`` dict; empty when no safe
        mapping exists (backfill is then skipped for the table).
        """
        if table in self._table_config:
            mapping = self._table_config[table].get("ods_to_dwd_pk_map", {})
            if mapping:
                return mapping
            # No explicit mapping configured: fall back to same-named
            # business primary keys if they line up.
            scd2_cols = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}
            pk_cols = self.get_primary_keys(table)
            business_pk_cols = [c for c in pk_cols if c.lower() not in scd2_cols]
            ods_pk_cols = self.get_ods_pk_columns(table)
            if business_pk_cols and all(pk in ods_pk_cols for pk in business_pk_cols):
                return {pk: pk for pk in business_pk_cols}
            return {}
        return {}

    def fetch_source_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Fetch the set of primary keys from the ODS source table.

        NOTE: ODS rows are filtered by fetched_at, so only recently fetched
        ODS records are checked for presence in the DWD table; historical
        data is out of scope.

        Raises:
            VerificationFetchError: when the ODS query fails.
        """
        ods_table = self.get_ods_table(table)
        if not ods_table:
            return set()
        # Use the ODS table's PK column names (not the DWD ones).
        ods_pk_cols = self.get_ods_pk_columns(table)
        # Without a primary-key definition there is nothing to fetch.
        if not ods_pk_cols:
            self.logger.debug("表 %s 没有 ODS 主键配置,跳过获取源主键", table)
            return set()
        # Use the ODS time column.
        ods_time_col = self.get_ods_time_column(table)
        pk_select = ", ".join(ods_pk_cols)
        sql = f"""
            SELECT DISTINCT {pk_select}
            FROM ods.{ods_table}
            WHERE {ods_time_col} >= %s AND {ods_time_col} < %s
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (window_start, window_end))
                return {tuple(row) for row in cur.fetchall()}
        except Exception as e:
            self.logger.warning("获取 ODS 主键失败: %s, error=%s", ods_table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 ODS 主键失败: {ods_table}") from e

    def fetch_target_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Fetch the set of primary keys from the DWD table.

        NOTE: to be comparable with the ODS keys from fetch_source_keys,
        this returns the business keys (the mapped DWD columns, same arity
        as the ODS key). For dimension tables scd2_start_time is excluded.

        Raises:
            VerificationFetchError: when the DWD query fails.
        """
        # Resolve the ODS-to-DWD primary-key mapping.
        ods_to_dwd_map = self.get_ods_to_dwd_pk_map(table)
        # Decide which key columns to select.
        if ods_to_dwd_map:
            # Use the mapped DWD business key columns (same arity as ODS keys).
            dwd_pk_cols = list(ods_to_dwd_map.values())
        else:
            # No mapping: use the raw primary key (may not compare cleanly
            # against ODS keys).
            dwd_pk_cols = self.get_primary_keys(table)
        if not dwd_pk_cols:
            self.logger.debug("表 %s 没有主键配置,跳过获取目标主键", table)
            return set()
        pk_select = ", ".join(dwd_pk_cols)
        # Build the query.
        if self.is_dimension_table(table):
            # Dimension table: only the current SCD2 version.
            sql = f"""
                SELECT DISTINCT {pk_select}
                FROM dwd.{table}
                WHERE scd2_is_current = 1
            """
            params = ()
        else:
            # Fact table: filter by the time window.
            time_col = self.get_time_column(table)
            # Verify the time column actually exists on the table.
            time_col_exists = False
            try:
                check_sql = """
                    SELECT 1
                    FROM information_schema.columns
                    WHERE table_schema = 'dwd'
                        AND table_name = %s
                        AND column_name = %s
                """
                with self.db.conn.cursor() as cur:
                    cur.execute(check_sql, (table, time_col))
                    if cur.fetchone():
                        time_col_exists = True
                    else:
                        # Probe fallback time columns.
                        fallback_cols = ["create_time", "pay_time", "start_use_time"]
                        for fc in fallback_cols:
                            cur.execute(check_sql, (table, fc))
                            if cur.fetchone():
                                time_col = fc
                                time_col_exists = True
                                break
            except Exception:
                pass
            if time_col_exists:
                sql = f"""
                    SELECT DISTINCT {pk_select}
                    FROM dwd.{table}
                    WHERE {time_col} >= %s AND {time_col} < %s
                """
                params = (window_start, window_end)
            else:
                # No usable time column: fetch everything.
                sql = f"""
                    SELECT DISTINCT {pk_select}
                    FROM dwd.{table}
                """
                params = ()
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, params)
                return {tuple(row) for row in cur.fetchall()}
        except Exception as e:
            self.logger.warning("获取 DWD 主键失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 DWD 主键失败: {table}") from e

    def fetch_source_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Fetch a primary-key -> content_hash mapping from the ODS source table.

        Raises:
            VerificationFetchError: when the ODS query fails.
        """
        ods_table = self.get_ods_table(table)
        if not ods_table:
            return {}
        # Use the ODS table's PK column names (not the DWD ones).
        ods_pk_cols = self.get_ods_pk_columns(table)
        # Without a primary-key definition there is nothing to fetch.
        if not ods_pk_cols:
            self.logger.debug("表 %s 没有 ODS 主键配置,跳过获取源哈希", table)
            return {}
        # Use the ODS time column.
        ods_time_col = self.get_ods_time_column(table)
        pk_select = ", ".join(ods_pk_cols)
        sql = f"""
            SELECT {pk_select}, content_hash
            FROM ods.{ods_table}
            WHERE {ods_time_col} >= %s AND {ods_time_col} < %s
        """
        result = {}
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (window_start, window_end))
                for row in cur.fetchall():
                    pk = tuple(row[:-1])
                    content_hash = row[-1]
                    result[pk] = content_hash or ""
        except Exception as e:
            self.logger.warning("获取 ODS hash 失败: %s, error=%s", ods_table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 ODS hash 失败: {ods_table}") from e
        return result

    def fetch_target_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Fetch a primary-key -> computed-row-hash mapping from the DWD table.

        DWD tables may not carry a content_hash column, so the hash is
        computed here as the MD5 of the JSON-serialized row (system/SCD2
        columns excluded).

        Raises:
            VerificationFetchError: when the DWD query fails.
        """
        pk_cols = self.get_primary_keys(table)
        # Without a primary-key definition there is nothing to fetch.
        if not pk_cols:
            self.logger.debug("表 %s 没有主键配置,跳过获取目标哈希", table)
            return {}
        # The DWD table may lack content_hash; compute it from the data
        # columns (all non-system columns).
        exclude_cols = {
            "scd2_start_time", "scd2_end_time", "scd2_is_current",
            "scd2_version", "dwd_insert_time", "dwd_update_time"
        }
        sql = f"""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'dwd'
                AND table_name = %s
            ORDER BY ordinal_position
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (table,))
                all_cols = [row[0] for row in cur.fetchall()]
        except Exception as e:
            self.logger.warning("获取 DWD 表列信息失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            all_cols = pk_cols
        data_cols = [c for c in all_cols if c not in exclude_cols]
        col_select = ", ".join(data_cols)
        # Positions of the PK columns within the selected data columns.
        pk_indices = [data_cols.index(c) for c in pk_cols if c in data_cols]
        if self.is_dimension_table(table):
            sql = f"""
                SELECT {col_select}
                FROM dwd.{table}
                WHERE scd2_is_current = 1
            """
            params = ()
        else:
            # Fact tables use the DWD business time column.
            time_col = self.get_time_column(table)
            # Verify the time column is among the selected data columns.
            if time_col not in data_cols:
                # Time column missing: try the fallbacks.
                fallback_cols = ["create_time", "pay_time", "start_use_time"]
                time_col = None
                for fc in fallback_cols:
                    if fc in data_cols:
                        time_col = fc
                        break
                if not time_col:
                    # No time column at all: fetch everything.
                    sql = f"""
                        SELECT {col_select}
                        FROM dwd.{table}
                    """
                    params = ()
                else:
                    sql = f"""
                        SELECT {col_select}
                        FROM dwd.{table}
                        WHERE {time_col} >= %s AND {time_col} < %s
                    """
                    params = (window_start, window_end)
            else:
                sql = f"""
                    SELECT {col_select}
                    FROM dwd.{table}
                    WHERE {time_col} >= %s AND {time_col} < %s
                """
                params = (window_start, window_end)
        result = {}
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, params)
                for row in cur.fetchall():
                    pk = tuple(row[i] for i in pk_indices)
                    # Hash the whole row deterministically.
                    row_dict = dict(zip(data_cols, row))
                    content_str = json.dumps(row_dict, sort_keys=True, default=str)
                    content_hash = hashlib.md5(content_str.encode()).hexdigest()
                    result[pk] = content_hash
        except Exception as e:
            self.logger.warning("获取 DWD hash 失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 DWD hash 失败: {table}") from e
        return result

    def backfill_missing(
        self,
        table: str,
        missing_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Backfill missing rows in bulk.

        Re-reads the missing ODS records (batched VALUES + JOIN lookup)
        and loads them into the DWD table.

        Returns:
            Number of rows loaded into DWD.
        """
        if not missing_keys:
            return 0
        ods_table = self.get_ods_table(table)
        if not ods_table:
            return 0
        # A primary-key mapping is required to decide whether backfill is possible.
        ods_to_dwd_map = self.get_ods_to_dwd_pk_map(table)
        if not ods_to_dwd_map and self.is_dimension_table(table):
            # Dimension table without a PK mapping likely uses a complex
            # mapping (e.g. extracted from nested JSON); it cannot be
            # backfilled automatically, so skip it.
            self.logger.warning(
                "DWD 表 %s 没有主键映射配置,跳过 backfill(需要完整 ETL 同步)",
                table
            )
            return 0
        pk_cols = self.get_primary_keys(table)  # DWD PK column names
        ods_pk_cols = self.get_ods_pk_columns(table)  # ODS PK column names (usually id)
        ods_time_col = self.get_ods_time_column(table)
        self.logger.info(
            "DWD 补齐缺失: 表=%s, 数量=%d",
            table, len(missing_keys)
        )
        # Make sure the transaction state is clean before proceeding.
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        # Drop keys whose arity does not match the ODS primary key.
        valid_keys = [pk for pk in missing_keys if len(pk) == len(ods_pk_cols)]
        if not valid_keys:
            return 0
        # Re-query ODS in batches via VALUES + JOIN, avoiding very long OR
        # conditions that slow down SQL parsing/execution.
        batch_size = 1000
        records: List[dict] = []
        key_cols_sql = ", ".join(ods_pk_cols)
        join_sql = " AND ".join(f"o.{col} = k.{col}" for col in ods_pk_cols)
        try:
            with self.db.conn.cursor() as cur:
                for i in range(0, len(valid_keys), batch_size):
                    batch_keys = valid_keys[i:i + batch_size]
                    row_placeholder = "(" + ", ".join(["%s"] * len(ods_pk_cols)) + ")"
                    values_sql = ", ".join([row_placeholder] * len(batch_keys))
                    params = [v for pk in batch_keys for v in pk]
                    sql = f"""
                        WITH k ({key_cols_sql}) AS (
                            VALUES {values_sql}
                        )
                        SELECT o.*
                        FROM ods.{ods_table} o
                        JOIN k ON {join_sql}
                        WHERE o.{ods_time_col} >= %s AND o.{ods_time_col} < %s
                    """
                    cur.execute(sql, params + [window_start, window_end])
                    columns = [desc[0] for desc in cur.description]
                    records.extend(dict(zip(columns, row)) for row in cur.fetchall())
        except Exception as e:
            self.logger.error("获取 ODS 记录失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return 0
        if not records:
            return 0
        # Load into DWD.
        return self._load_to_dwd(table, records, pk_cols)

    def backfill_mismatch(
        self,
        table: str,
        mismatch_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Re-load mismatched rows in bulk.

        Dimension tables go through the SCD2 logic; fact tables are
        UPSERTed directly — both paths are handled by backfill_missing.
        """
        return self.backfill_missing(table, mismatch_keys, window_start, window_end)

    def _get_fact_column_map(self, table: str) -> Dict[str, Tuple[str, str | None]]:
        """Return the fact table's DWD->ODS column mapping (used for backfill).

        Only simple identifier sources from DwdLoadTask.FACT_MAPPINGS are
        kept; expression-based mappings are skipped.
        """
        mapping_entries = DwdLoadTask.FACT_MAPPINGS.get(f"dwd.{table}") or []
        result: Dict[str, Tuple[str, str | None]] = {}
        for dwd_col, src, cast_type in mapping_entries:
            if isinstance(src, str) and src.isidentifier():
                result[dwd_col.lower()] = (src.lower(), cast_type)
        return result

    @staticmethod
    def _coerce_bool(value: Any) -> bool | None:
        # Best-effort conversion of arbitrary values to bool (None passes through).
        if value is None:
            return None
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in {"true", "1", "yes", "y", "t"}:
                return True
            if lowered in {"false", "0", "no", "n", "f"}:
                return False
        return bool(value)

    @classmethod
    def _adapt_fact_value(cls, value: Any, cast_type: str | None = None) -> Any:
        """Adapt a value for the fact-table UPSERT, handling JSON fields."""
        if cast_type == "boolean":
            return cls._coerce_bool(value)
        if isinstance(value, (dict, list)):
            return Json(value, dumps=lambda v: json.dumps(v, ensure_ascii=False, default=str))
        return value

    def _load_to_dwd(self, table: str, records: List[dict], pk_cols: List[str]) -> int:
        """Load records into the DWD table (SCD2 merge for dims, UPSERT for facts)."""
        if not records:
            return 0
        is_dim = self.is_dimension_table(table)
        if is_dim:
            # Resolve ODS PK column names and the ODS->DWD mapping.
            ods_pk_cols = self.get_ods_pk_columns(table)
            ods_to_dwd_map = self.get_ods_to_dwd_pk_map(table)
            # Strip SCD2 columns, keeping only the business key — ODS
            # records have no scd2_start_time etc.
            scd2_cols = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}
            business_pk_cols = [c for c in pk_cols if c not in scd2_cols]
            # DEBUG: record how the PK filtering went.
            self.logger.debug(
                "维度表 %s: 原始 pk_cols=%s, 过滤后 business_pk_cols=%s, ods_pk_cols=%s",
                table, pk_cols, business_pk_cols, ods_pk_cols
            )
            if not business_pk_cols:
                self.logger.warning(
                    "维度表 %s: 过滤 SCD2 列后业务主键为空,原始 pk_cols=%s",
                    table, pk_cols
                )
                return 0
            return self._merge_dimension(table, records, business_pk_cols, ods_pk_cols, ods_to_dwd_map)
        else:
            return self._merge_fact(table, records, pk_cols)

    def _merge_dimension(
        self,
        table: str,
        records: List[dict],
        pk_cols: List[str],
        ods_pk_cols: List[str],
        ods_to_dwd_map: Dict[str, str]
    ) -> int:
        """Merge into a dimension table (SCD2).

        Args:
            table: DWD table name.
            records: ODS record dicts.
            pk_cols: DWD primary-key columns (excluding scd2_start_time).
            ods_pk_cols: ODS primary-key columns.
            ods_to_dwd_map: ODS-to-DWD column-name mapping {ods_col: dwd_col}.

        Returns:
            Number of new versions inserted.
        """
        # Fetch the DWD table's columns.
        sql = """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'dwd'
                AND table_name = %s
            ORDER BY ordinal_position
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (table,))
                dwd_cols = [row[0] for row in cur.fetchall()]
        except Exception as e:
            self.logger.error("获取 DWD 表列失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return 0
        # Keep only mappable data columns (drop the SCD2 bookkeeping columns).
        scd2_cols = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}
        data_cols = [c for c in dwd_cols if c not in scd2_cols]
        # Reverse mapping: dwd_col -> ods_col (covers the PK mapping; other
        # columns fall back to same-name lookups).
        dwd_to_ods_map = {v: k for k, v in ods_to_dwd_map.items()}
        # Deduplicate by business key, keeping only the last record.
        # This prevents SCD2 key conflicts when the same business entity
        # appears multiple times in the ODS batch.
        unique_records = {}
        for record in records:
            # Extract the business-key values.
            pk_values = []
            skip = False
            for dwd_pk_col in pk_cols:
                ods_col = dwd_to_ods_map.get(dwd_pk_col, dwd_pk_col)
                value = record.get(ods_col)
                if value is None:
                    value = record.get(dwd_pk_col)
                if value is None:
                    skip = True
                    break
                pk_values.append(value)
            if not skip:
                pk_key = tuple(pk_values)
                unique_records[pk_key] = record  # later records win
        self.logger.debug(
            "维度表 %s: 原始记录数=%d, 去重后=%d",
            table, len(records), len(unique_records)
        )
        count = 0
        for pk_key, record in unique_records.items():
            # pk_key is the business-key tuple extracted during dedup.
            pk_values = pk_key
            record_time = datetime.now(timezone.utc).replace(tzinfo=None)
            # 1. Close the previous current version.
            pk_where = " AND ".join(f"{c} = %s" for c in pk_cols)
            update_sql = f"""
                UPDATE dwd.{table}
                SET scd2_is_current = 0, scd2_end_time = %s
                WHERE {pk_where} AND scd2_is_current = 1
            """
            try:
                with self.db.conn.cursor() as cur:
                    cur.execute(update_sql, (record_time,) + pk_values)
            except Exception as e:
                self.logger.warning("关闭旧版本失败: %s", e)
                try:
                    self.db.conn.rollback()
                except Exception:
                    pass
                continue
            # 2. Prepare the insert values (honoring the column-name mapping).
            insert_cols = []
            values = []
            for dwd_col in data_cols:
                # Resolve the corresponding ODS column name.
                ods_col = dwd_to_ods_map.get(dwd_col, dwd_col)
                # Prefer the ODS column name, then fall back to the DWD name.
                if ods_col in record:
                    insert_cols.append(dwd_col)
                    values.append(record[ods_col])
                elif dwd_col in record:
                    insert_cols.append(dwd_col)
                    values.append(record[dwd_col])
            # Append the SCD2 bookkeeping columns.
            insert_cols.extend(["scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"])
            values.extend([record_time, None, 1, 1])
            col_list = ", ".join(insert_cols)
            placeholders = ", ".join(["%s"] * len(values))
            insert_sql = f"""
                INSERT INTO dwd.{table} ({col_list})
                VALUES ({placeholders})
            """
            try:
                with self.db.conn.cursor() as cur:
                    cur.execute(insert_sql, values)
                    count += 1
            except Exception as e:
                self.logger.warning("插入新版本失败: %s, error=%s", table, e)
                try:
                    self.db.conn.rollback()
                except Exception:
                    pass
        try:
            self.db.commit()
        except Exception as e:
            self.logger.error("提交事务失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
        return count

    def _merge_fact(self, table: str, records: List[dict], pk_cols: List[str]) -> int:
        """Merge into a fact table (UPSERT).

        NOTE: fact-table backfill has limitations:
        - ODS column names may differ from DWD column names.
        - Only the primary-key mapping (plus FACT_MAPPINGS entries) is
          honored; other columns must share names.
        - If no columns can be mapped at all, the backfill is skipped.

        Returns:
            Number of rows actually inserted/updated.
        """
        if not records:
            return 0
        # Resolve the ODS-to-DWD primary-key mapping.
        ods_to_dwd_map = self.get_ods_to_dwd_pk_map(table)
        dwd_to_ods_pk_map = {v.lower(): k.lower() for k, v in ods_to_dwd_map.items()}
        fact_col_map = self._get_fact_column_map(table)
        # Fetch the DWD table's columns.
        sql = """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'dwd'
                AND table_name = %s
            ORDER BY ordinal_position
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (table,))
                dwd_cols = [row[0] for row in cur.fetchall()]
        except Exception as e:
            self.logger.error("获取 DWD 表列失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return 0
        if not records:
            return 0
        # Lowercase all record keys so case differences don't break matching.
        records_lower = [{k.lower(): v for k, v in record.items()} for record in records]
        sample_record = records_lower[0]
        # Find the mappable columns (honoring the column-name mappings).
        mappable_cols = []
        col_source_map = {}  # dwd_col -> (source_key, cast_type)
        for dwd_col in dwd_cols:
            dwd_key = dwd_col.lower()
            ods_col = fact_col_map.get(dwd_key)
            if ods_col and ods_col[0] in sample_record:
                # Prefer the fact-table mapping.
                mappable_cols.append(dwd_col)
                col_source_map[dwd_col] = ods_col
                continue
            ods_col = dwd_to_ods_pk_map.get(dwd_key)
            if ods_col and ods_col in sample_record:
                # Mapped PK column present in the ODS record.
                mappable_cols.append(dwd_col)
                col_source_map[dwd_col] = (ods_col, None)
            elif dwd_key in sample_record:
                # Same-named column present in the ODS record.
                mappable_cols.append(dwd_col)
                col_source_map[dwd_col] = (dwd_key, None)
        if not mappable_cols:
            self.logger.warning(
                "事实表 %s: 无可映射列,跳过 backfill。ODS 列=%s, DWD 列=%s",
                table, list(sample_record.keys())[:10], dwd_cols[:10]
            )
            return 0
        # Make sure every primary-key column is mappable.
        for pk_col in pk_cols:
            if pk_col not in mappable_cols:
                pk_key = pk_col.lower()
                ods_pk = fact_col_map.get(pk_key) or dwd_to_ods_pk_map.get(pk_key)
                if ods_pk:
                    src_key = ods_pk[0] if isinstance(ods_pk, tuple) else ods_pk
                else:
                    src_key = None
                if src_key and src_key in sample_record:
                    mappable_cols.append(pk_col)
                    col_source_map[pk_col] = ods_pk if isinstance(ods_pk, tuple) else (src_key, None)
                else:
                    self.logger.warning(
                        "事实表 %s: 主键列 %s 无法映射,跳过 backfill",
                        table, pk_col
                    )
                    return 0
        # Deduplicate by business key to avoid duplicate keys within one
        # batched UPSERT.
        unique_records = {}
        for record in records_lower:
            pk_values = []
            missing_pk = False
            for pk_col in pk_cols:
                src_key, _ = col_source_map[pk_col]
                value = record.get(src_key)
                if value is None:
                    missing_pk = True
                    break
                pk_values.append(value)
            if missing_pk:
                continue
            unique_records[tuple(pk_values)] = record
        if len(unique_records) != len(records_lower):
            self.logger.info(
                "事实表 %s: 去重记录 %d -> %d",
                table,
                len(records_lower),
                len(unique_records),
            )
        records_lower = list(unique_records.values())
        col_list = ", ".join(mappable_cols)
        pk_list = ", ".join(pk_cols)
        update_cols = [c for c in mappable_cols if c not in pk_cols]
        if update_cols:
            update_set = ", ".join(f"{c} = EXCLUDED.{c}" for c in update_cols)
            # Only update rows whose values actually changed.
            update_where = " OR ".join(
                f"dwd.{table}.{c} IS DISTINCT FROM EXCLUDED.{c}" for c in update_cols
            )
            upsert_sql = (
                f"INSERT INTO dwd.{table} ({col_list}) "
                f"VALUES ({', '.join(['%s'] * len(mappable_cols))}) "
                f"ON CONFLICT ({pk_list}) DO UPDATE SET {update_set} "
                f"WHERE {update_where}"
            )
            upsert_values_sql = (
                f"INSERT INTO dwd.{table} ({col_list}) "
                f"VALUES %s "
                f"ON CONFLICT ({pk_list}) DO UPDATE SET {update_set} "
                f"WHERE {update_where}"
            )
        else:
            # Only PK columns are mappable: use DO NOTHING.
            upsert_sql = (
                f"INSERT INTO dwd.{table} ({col_list}) "
                f"VALUES ({', '.join(['%s'] * len(mappable_cols))}) "
                f"ON CONFLICT ({pk_list}) DO NOTHING"
            )
            upsert_values_sql = (
                f"INSERT INTO dwd.{table} ({col_list}) "
                f"VALUES %s "
                f"ON CONFLICT ({pk_list}) DO NOTHING"
            )
        all_values: List[List[Any]] = []
        for record in records_lower:
            row_values = []
            for col in mappable_cols:
                src_key, cast_type = col_source_map[col]
                row_values.append(self._adapt_fact_value(record.get(src_key), cast_type))
            all_values.append(row_values)
        count = 0
        # Configurable batching parameters to limit lock waits and retry cost.
        batch_size = self._get_fact_upsert_batch_size()
        min_batch_size = self._get_fact_upsert_min_batch_size()
        if min_batch_size > batch_size:
            min_batch_size = batch_size
        max_retries = self._get_fact_upsert_max_retries()
        backoff_sec = self._get_fact_upsert_backoff()
        lock_timeout_ms = self._get_fact_upsert_lock_timeout_ms()

        def _sleep_with_backoff(attempt: int):
            # Sleep per the configured backoff schedule (clamped to its last entry).
            if not backoff_sec:
                return
            idx = min(attempt, len(backoff_sec) - 1)
            wait_sec = backoff_sec[idx]
            if wait_sec > 0:
                time.sleep(wait_sec)

        def _iter_batches(items: List[List[Any]], size: int):
            # Yield consecutive slices of the given size.
            for idx in range(0, len(items), size):
                yield items[idx:idx + size]

        def _commit_batch():
            """Commit per batch to shorten lock hold time."""
            try:
                self.db.commit()
            except Exception as commit_error:
                self.logger.error("提交事务失败: %s", commit_error)
                try:
                    self.db.conn.rollback()
                except Exception:
                    pass
                raise

        def _execute_batch(cur, batch_values: List[List[Any]]):
            # Run one batched UPSERT inside a savepoint; on failure the
            # savepoint is rolled back and the error returned to the caller.
            cur.execute("SAVEPOINT dwd_fact_batch_sp")
            try:
                execute_values(
                    cur,
                    upsert_values_sql,
                    batch_values,
                    page_size=len(batch_values),
                )
                cur.execute("RELEASE SAVEPOINT dwd_fact_batch_sp")
                affected = int(cur.rowcount or 0)
                if affected < 0:
                    affected = 0
                return affected, None
            except Exception as batch_error:
                cur.execute("ROLLBACK TO SAVEPOINT dwd_fact_batch_sp")
                cur.execute("RELEASE SAVEPOINT dwd_fact_batch_sp")
                return 0, batch_error

        def _fallback_rows(cur, batch_values: List[List[Any]]):
            affected_total = 0
            # When a batch fails, degrade to row-by-row so bad rows can be
            # skipped while the rest still land.
            for values in batch_values:
                cur.execute("SAVEPOINT dwd_fact_row_sp")
                try:
                    cur.execute(upsert_sql, values)
                    cur.execute("RELEASE SAVEPOINT dwd_fact_row_sp")
                    affected = int(cur.rowcount or 0)
                    if affected < 0:
                        affected = 0
                    affected_total += affected
                except Exception as row_error:
                    cur.execute("ROLLBACK TO SAVEPOINT dwd_fact_row_sp")
                    cur.execute("RELEASE SAVEPOINT dwd_fact_row_sp")
                    self.logger.warning(
                        "UPSERT 失败: %s, error=%s",
                        table,
                        row_error,
                    )
            return affected_total

        def _process_batch(cur, batch_values: List[List[Any]], current_size: int) -> int:
            # Execute a batch with retry on lock timeouts; halve the batch
            # size on repeated lock contention, and fall back to per-row
            # execution for other failures.
            if not batch_values:
                return 0
            if len(batch_values) > current_size:
                # Split down to the current batch size first.
                total = 0
                for sub_batch in _iter_batches(batch_values, current_size):
                    total += _process_batch(cur, sub_batch, current_size)
                return total
            for attempt in range(max_retries + 1):
                affected, batch_error = _execute_batch(cur, batch_values)
                if batch_error is None:
                    _commit_batch()
                    return affected
                if self._is_lock_timeout_error(batch_error):
                    if current_size > min_batch_size:
                        new_size = max(min_batch_size, current_size // 2)
                        self.logger.warning(
                            "批量 UPSERT 锁超时,缩小批次: table=%s, %d -> %d",
                            table,
                            current_size,
                            new_size,
                        )
                        total = 0
                        for sub_batch in _iter_batches(batch_values, new_size):
                            total += _process_batch(cur, sub_batch, new_size)
                        return total
                    if attempt < max_retries:
                        self.logger.warning(
                            "批量 UPSERT 锁超时,重试: table=%s, attempt=%d/%d",
                            table,
                            attempt + 1,
                            max_retries,
                        )
                        _sleep_with_backoff(attempt)
                        continue
                # Not a lock timeout, or retries exhausted: degrade to per-row.
                self.logger.warning(
                    "批量 UPSERT 失败,回退逐行: table=%s, batch_size=%d, error=%s",
                    table,
                    len(batch_values),
                    batch_error,
                )
                affected_rows = _fallback_rows(cur, batch_values)
                _commit_batch()
                return affected_rows
            return 0

        try:
            with self.db.conn.cursor() as cur:
                if lock_timeout_ms is not None:
                    # Cap lock waits for this transaction to avoid long blocking.
                    cur.execute("SET LOCAL lock_timeout = %s", (int(lock_timeout_ms),))
                for batch_values in _iter_batches(all_values, batch_size):
                    count += _process_batch(cur, batch_values, batch_size)
        except Exception as e:
            self.logger.error("事实表 backfill 失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
        return count

    def _get_fact_upsert_batch_size(self) -> int:
        """Read the fact-table UPSERT batch size (configurable)."""
        return self._get_int_config("dwd.fact_upsert_batch_size", 1000, 10, 5000)

    def _get_fact_upsert_min_batch_size(self) -> int:
        """Read the fact-table UPSERT minimum batch size (configurable)."""
        return self._get_int_config("dwd.fact_upsert_min_batch_size", 100, 1, 2000)

    def _get_fact_upsert_max_retries(self) -> int:
        """Read the fact-table UPSERT maximum retry count (configurable)."""
        return self._get_int_config("dwd.fact_upsert_max_retries", 2, 0, 10)

    def _get_fact_upsert_lock_timeout_ms(self) -> Optional[int]:
        """Read the fact-table UPSERT lock-wait timeout in ms (may be None)."""
        if not self.config:
            return None
        value = self.config.get("dwd.fact_upsert_lock_timeout_ms")
        try:
            return int(value) if value is not None else None
        except Exception:
            return None

    def _get_fact_upsert_backoff(self) -> List[int]:
        """Read the fact-table UPSERT retry backoff schedule (seconds)."""
        if not self.config:
            return [1, 2, 4]
        value = self.config.get("dwd.fact_upsert_retry_backoff_sec", [1, 2, 4])
        if not isinstance(value, list):
            return [1, 2, 4]
        return [int(v) for v in value if isinstance(v, (int, float)) and v >= 0]

    def _get_int_config(self, key: str, default: int, min_value: int, max_value: int) -> int:
        """Read an integer config value and clamp it to a sane range."""
        value = default
        if self.config:
            value = self.config.get(key, default)
        try:
            value = int(value)
        except Exception:
            value = default
        value = max(min_value, min(value, max_value))
        return value

    @staticmethod
    def _is_lock_timeout_error(error: Exception) -> bool:
        """Return True if the error is a lock timeout / lock conflict."""
        # 55P03 = lock_not_available, 57014 = query_canceled (statement canceled
        # due to lock timeout). NOTE(review): 57014 also covers user-initiated
        # cancellations — confirm that treating all of them as lock timeouts is intended.
        pgcode = getattr(error, "pgcode", None)
        if pgcode in ("55P03", "57014"):
            return True
        message = str(error).lower()
        return "lock timeout" in message or "锁超时" in message or "canceling statement due to lock timeout" in message