初始提交:飞球 ETL 系统全量代码

This commit is contained in:
Neo
2026-02-13 08:05:34 +08:00
commit 3c51f5485d
441 changed files with 117631 additions and 0 deletions

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""批量后置校验框架
提供各层数据的批量校验和补齐功能:
- ODS 层:主键 + content_hash 对比,批量 UPSERT
- DWD 层:维度 SCD2 / 事实主键对比,批量 UPSERT
- DWS 层:聚合对比,批量重算 UPSERT
- INDEX 层:实体覆盖对比,批量重算 UPSERT
"""
from .models import (
VerificationResult,
VerificationSummary,
VerificationStatus,
WindowSegment,
build_window_segments,
filter_verify_tables,
)
from .base_verifier import BaseVerifier
from .ods_verifier import OdsVerifier
from .dwd_verifier import DwdVerifier
from .dws_verifier import DwsVerifier
from .index_verifier import IndexVerifier
__all__ = [
# 模型
"VerificationResult",
"VerificationSummary",
"VerificationStatus",
"WindowSegment",
"build_window_segments",
"filter_verify_tables",
# 校验器
"BaseVerifier",
"OdsVerifier",
"DwdVerifier",
"DwsVerifier",
"IndexVerifier",
]
def get_verifier_for_layer(layer: str, db_connection, logger=None, **kwargs):
    """
    Return a verifier instance for the given data layer.

    Args:
        layer: Layer name ("ODS", "DWD", "DWS", "INDEX"); case-insensitive.
        db_connection: Database connection handle forwarded to the verifier.
        logger: Optional logger forwarded to the verifier.
        **kwargs: Extra keyword arguments. The ODS layer additionally supports:
            - api_client: API client (ODS only)
            - fetch_from_api: whether to fetch source data from the API (ODS only)
            - local_dump_dirs: mapping of local JSON dump directories (ODS only)
            - use_local_json: whether to prefer local JSON files (ODS only)

    Returns:
        The verifier instance for the requested layer.

    Raises:
        ValueError: if `layer` is not one of the known layers.
    """
    verifier_map = {
        "ODS": OdsVerifier,
        "DWD": DwdVerifier,
        "DWS": DwsVerifier,
        "INDEX": IndexVerifier,
    }
    # Normalize once so the map lookup and the ODS special-case agree.
    normalized_layer = layer.upper()
    verifier_class = verifier_map.get(normalized_layer)
    if verifier_class is None:
        raise ValueError(f"未知的数据层: {layer}")
    # The ODS verifier accepts extra source-fetching parameters; pop them so
    # they are passed explicitly (with defaults) rather than left in kwargs.
    if normalized_layer == "ODS":
        api_client = kwargs.pop("api_client", None)
        fetch_from_api = kwargs.pop("fetch_from_api", False)
        local_dump_dirs = kwargs.pop("local_dump_dirs", None)
        use_local_json = kwargs.pop("use_local_json", False)
        return verifier_class(
            db_connection,
            api_client=api_client,
            logger=logger,
            fetch_from_api=fetch_from_api,
            local_dump_dirs=local_dump_dirs,
            use_local_json=use_local_json,
            **kwargs
        )
    return verifier_class(db_connection, logger=logger, **kwargs)

View File

@@ -0,0 +1,382 @@
# -*- coding: utf-8 -*-
"""批量校验基类"""
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from .models import (
VerificationResult,
VerificationSummary,
VerificationStatus,
WindowSegment,
build_window_segments,
)
class VerificationFetchError(RuntimeError):
    """Raised when fetching verification data fails (used to explicitly mark a result as ERROR)."""
class BaseVerifier(ABC):
    """Base class for batch verification.

    Provides a unified verification workflow:
    1. Split the time range into windows
    2. Bulk-read source data
    3. Bulk-read target data
    4. Compare in memory
    5. Bulk backfill
    """
    def __init__(
        self,
        db_connection: Any,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the verifier.

        Args:
            db_connection: Database connection wrapper; expected to expose
                `.conn` (raw connection, see `_ensure_connection`) and
                `.commit()` — TODO confirm against the project's DB wrapper.
            logger: Logger; defaults to one named after the concrete class.
        """
        self.db = db_connection
        self.logger = logger or logging.getLogger(self.__class__.__name__)
    @property
    @abstractmethod
    def layer_name(self) -> str:
        """Name of the data layer (e.g. "ODS", "DWD", "DWS", "INDEX")."""
        pass
    @abstractmethod
    def get_tables(self) -> List[str]:
        """Return the list of tables to verify."""
        pass
    @abstractmethod
    def get_primary_keys(self, table: str) -> List[str]:
        """Return the table's primary-key columns."""
        pass
    @abstractmethod
    def get_time_column(self, table: str) -> Optional[str]:
        """Return the table's time column (used for window filtering)."""
        pass
    @abstractmethod
    def fetch_source_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Bulk-fetch the set of source-data primary keys in the window."""
        pass
    @abstractmethod
    def fetch_target_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Bulk-fetch the set of target-data primary keys in the window."""
        pass
    @abstractmethod
    def fetch_source_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Bulk-fetch a mapping of source primary key -> content hash."""
        pass
    @abstractmethod
    def fetch_target_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Bulk-fetch a mapping of target primary key -> content hash."""
        pass
    @abstractmethod
    def backfill_missing(
        self,
        table: str,
        missing_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Bulk-backfill missing rows; return the number of rows backfilled."""
        pass
    @abstractmethod
    def backfill_mismatch(
        self,
        table: str,
        mismatch_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Bulk-update inconsistent rows; return the number of rows updated."""
        pass
    def verify_table(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
        auto_backfill: bool = False,
        compare_content: bool = True,
    ) -> VerificationResult:
        """
        Verify a single table over one window.

        Args:
            table: Table name
            window_start: Window start
            window_end: Window end
            auto_backfill: Whether to backfill discrepancies automatically
            compare_content: Whether to compare content (True = compare
                hashes, False = compare primary keys only)

        Returns:
            The verification result. On any exception the result carries
            status ERROR and the error message; a VerificationFetchError
            additionally sets details["fatal"] = True.
        """
        start_time = time.time()
        result = VerificationResult(
            layer=self.layer_name,
            table=table,
            window_start=window_start,
            window_end=window_end,
        )
        try:
            # Ensure the connection is usable first, so that a
            # "connection already closed" error cannot be mistaken for OK.
            self._ensure_connection()
            self.logger.info(
                "%s 校验开始: %s [%s ~ %s]",
                self.layer_name, table,
                window_start.strftime("%Y-%m-%d %H:%M"),
                window_end.strftime("%Y-%m-%d %H:%M")
            )
            if compare_content:
                # Compare content hashes keyed by primary key.
                source_hashes = self.fetch_source_hashes(table, window_start, window_end)
                target_hashes = self.fetch_target_hashes(table, window_start, window_end)
                result.source_count = len(source_hashes)
                result.target_count = len(target_hashes)
                source_keys = set(source_hashes.keys())
                target_keys = set(target_hashes.keys())
                # Missing: present in source but absent from target.
                missing_keys = source_keys - target_keys
                result.missing_count = len(missing_keys)
                # Mismatch: present on both sides but hashes differ.
                common_keys = source_keys & target_keys
                mismatch_keys = {
                    k for k in common_keys
                    if source_hashes[k] != target_hashes[k]
                }
                result.mismatch_count = len(mismatch_keys)
            else:
                # Compare primary keys only (no content comparison).
                source_keys = self.fetch_source_keys(table, window_start, window_end)
                target_keys = self.fetch_target_keys(table, window_start, window_end)
                result.source_count = len(source_keys)
                result.target_count = len(target_keys)
                missing_keys = source_keys - target_keys
                result.missing_count = len(missing_keys)
                mismatch_keys = set()
            # Derive the status (a successful backfill below overrides it).
            if result.missing_count > 0:
                result.status = VerificationStatus.MISSING
            elif result.mismatch_count > 0:
                result.status = VerificationStatus.MISMATCH
            else:
                result.status = VerificationStatus.OK
            # Automatic backfill of missing / mismatched rows.
            if auto_backfill and (missing_keys or mismatch_keys):
                backfill_missing_count = 0
                backfill_mismatch_count = 0
                if missing_keys:
                    self.logger.info(
                        "%s 补齐缺失: %s, 数量=%d",
                        self.layer_name, table, len(missing_keys)
                    )
                    backfill_missing_count += self.backfill_missing(
                        table, missing_keys, window_start, window_end
                    )
                if mismatch_keys:
                    self.logger.info(
                        "%s 更新不一致: %s, 数量=%d",
                        self.layer_name, table, len(mismatch_keys)
                    )
                    backfill_mismatch_count += self.backfill_mismatch(
                        table, mismatch_keys, window_start, window_end
                    )
                result.backfilled_missing_count = backfill_missing_count
                result.backfilled_mismatch_count = backfill_mismatch_count
                result.backfilled_count = backfill_missing_count + backfill_mismatch_count
                if result.backfilled_count > 0:
                    result.status = VerificationStatus.BACKFILLED
            self.logger.info(
                "%s 校验完成: %s, 源=%d, 目标=%d, 缺失=%d, 不一致=%d, 补齐=%d(缺失=%d, 不一致=%d)",
                self.layer_name, table,
                result.source_count, result.target_count,
                result.missing_count, result.mismatch_count, result.backfilled_count,
                result.backfilled_missing_count, result.backfilled_mismatch_count
            )
        except Exception as e:
            result.status = VerificationStatus.ERROR
            result.error_message = str(e)
            if isinstance(e, VerificationFetchError):
                # Fatal condition (e.g. unusable connection); callers should
                # stop processing further tables/windows.
                result.details["fatal"] = True
            self.logger.exception("%s 校验失败: %s, error=%s", self.layer_name, table, e)
            # Roll back so a PostgreSQL "current transaction is aborted"
            # state does not poison subsequent queries.
            try:
                self.db.conn.rollback()
            except Exception:
                pass  # ignore rollback errors
        result.elapsed_seconds = time.time() - start_time
        return result
    def verify_and_backfill(
        self,
        window_start: datetime,
        window_end: datetime,
        split_unit: str = "month",
        tables: Optional[List[str]] = None,
        auto_backfill: bool = True,
        compare_content: bool = True,
    ) -> VerificationSummary:
        """
        Run batch verification, splitting the range into time windows.

        Args:
            window_start: Start time
            window_end: End time
            split_unit: Split unit ("none", "day", "week", "month")
            tables: Tables to verify; None means all tables
            auto_backfill: Whether to backfill discrepancies automatically
            compare_content: Whether to compare content hashes

        Returns:
            The aggregated verification summary.
        """
        summary = VerificationSummary(
            layer=self.layer_name,
            window_start=window_start,
            window_end=window_end,
        )
        # Resolve the table list.
        all_tables = tables or self.get_tables()
        # Split the overall range into window segments.
        segments = build_window_segments(window_start, window_end, split_unit)
        self.logger.info(
            "%s 批量校验开始: 表数=%d, 窗口切分=%d",
            self.layer_name, len(all_tables), len(segments)
        )
        fatal_error = False
        for segment in segments:
            # Check the connection before each segment and abort immediately
            # on failure, to avoid a long run of no-op iterations.
            self._ensure_connection()
            self.logger.info(
                "%s 处理窗口 [%d/%d]: %s",
                self.layer_name, segment.index + 1, segment.total, segment.label
            )
            for table in all_tables:
                result = self.verify_table(
                    table=table,
                    window_start=segment.start,
                    window_end=segment.end,
                    auto_backfill=auto_backfill,
                    compare_content=compare_content,
                )
                summary.add_result(result)
                if result.details.get("fatal"):
                    fatal_error = True
                    break
            # Commit once per completed segment.
            try:
                self.db.commit()
            except Exception as e:
                self.logger.warning("提交失败: %s", e)
            if fatal_error:
                self.logger.warning("%s 校验中止:连接不可用或发生致命错误", self.layer_name)
                break
        self.logger.info(summary.format_summary())
        return summary
    def _ensure_connection(self):
        """Ensure the database connection is usable, reconnecting if possible.

        Raises:
            VerificationFetchError: if no connection is bound, or the
                connection is closed and cannot be reopened.
        """
        if not hasattr(self.db, "conn"):
            raise VerificationFetchError("校验器未绑定有效数据库连接")
        if getattr(self.db.conn, "closed", 0):
            # Prefer the wrapper's own reconnect capability when available.
            if hasattr(self.db, "ensure_open"):
                if not self.db.ensure_open():
                    raise VerificationFetchError("数据库连接已关闭,无法继续校验")
            else:
                raise VerificationFetchError("数据库连接已关闭,无法继续校验")
    def quick_check(
        self,
        window_start: datetime,
        window_end: datetime,
        tables: Optional[List[str]] = None,
    ) -> Dict[str, dict]:
        """
        Quick check (compares counts only, not content).

        Args:
            window_start: Start time
            window_end: End time
            tables: Tables to check; None means all tables

        Returns:
            {table_name: {source_count, target_count, diff}}
        """
        all_tables = tables or self.get_tables()
        results = {}
        for table in all_tables:
            source_keys = self.fetch_source_keys(table, window_start, window_end)
            target_keys = self.fetch_target_keys(table, window_start, window_end)
            results[table] = {
                "source_count": len(source_keys),
                "target_count": len(target_keys),
                "diff": len(source_keys) - len(target_keys),
            }
        return results

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,455 @@
# -*- coding: utf-8 -*-
"""DWS 汇总层批量校验器
校验逻辑:对比 DWD 聚合数据与 DWS 表数据
- 按日期/门店聚合对比
- 对比数值一致性
- 批量重算 UPSERT 补齐
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_verifier import BaseVerifier, VerificationFetchError
class DwsVerifier(BaseVerifier):
    """DWS summary-layer verifier.

    Compares aggregations computed from DWD fact tables against the stored
    DWS summary tables, and backfills by recomputing + UPSERT.
    """
    def __init__(
        self,
        db_connection: Any,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the DWS verifier.

        Args:
            db_connection: Database connection wrapper
            logger: Logger
        """
        super().__init__(db_connection, logger)
        self._table_config = self._load_table_config()
    @property
    def layer_name(self) -> str:
        return "DWS"
    def _load_table_config(self) -> Dict[str, dict]:
        """Load the DWS summary-table configuration.

        Each entry describes:
        - pk_columns: primary-key columns of the DWS table
        - time_column: time column used for window filtering
        - source_table / source_time_column: the DWD source
        - agg_sql: aggregation query over the DWD source (two %s
          placeholders: window start and end)
        - compare_columns: numeric columns compared between the
          recomputed aggregation and the stored DWS rows
        """
        # DWS summary tables generally have:
        # - primary key: site_id + stat_date (or similar combination)
        # - value columns: various aggregated statistics
        # - source: the corresponding DWD fact table
        return {
            # Finance daily summary - settlement, table fee, goods and
            # assistant totals.
            # NOTE: the actual DWS table uses columns gross_amount,
            # table_fee_amount, goods_amount, etc.
            "dws_finance_daily_summary": {
                "pk_columns": ["site_id", "stat_date"],
                "time_column": "stat_date",
                "source_table": "billiards_dwd.dwd_settlement_head",
                "source_time_column": "pay_time",
                "agg_sql": """
                    SELECT
                        site_id,
                        tenant_id,
                        DATE(pay_time) as stat_date,
                        COALESCE(SUM(pay_amount), 0) as cash_pay_amount,
                        COALESCE(SUM(table_charge_money), 0) as table_fee_amount,
                        COALESCE(SUM(goods_money), 0) as goods_amount,
                        COALESCE(SUM(table_charge_money) + SUM(goods_money) + COALESCE(SUM(assistant_pd_money), 0) + COALESCE(SUM(assistant_cx_money), 0), 0) as gross_amount
                    FROM billiards_dwd.dwd_settlement_head
                    WHERE pay_time >= %s AND pay_time < %s
                    GROUP BY site_id, tenant_id, DATE(pay_time)
                """,
                "compare_columns": ["cash_pay_amount", "table_fee_amount", "goods_amount", "gross_amount"],
            },
            # Assistant daily detail - service count, duration and amount
            # aggregated per assistant + date.
            # NOTE: the DWD table uses site_assistant_id; the DWS table uses
            # assistant_id.
            "dws_assistant_daily_detail": {
                "pk_columns": ["site_id", "assistant_id", "stat_date"],
                "time_column": "stat_date",
                "source_table": "billiards_dwd.dwd_assistant_service_log",
                "source_time_column": "start_use_time",
                "agg_sql": """
                    SELECT
                        site_id,
                        tenant_id,
                        site_assistant_id as assistant_id,
                        DATE(start_use_time) as stat_date,
                        COUNT(*) as total_service_count,
                        COALESCE(SUM(income_seconds), 0) as total_seconds,
                        COALESCE(SUM(ledger_amount), 0) as total_ledger_amount
                    FROM billiards_dwd.dwd_assistant_service_log
                    WHERE start_use_time >= %s AND start_use_time < %s
                        AND is_delete = 0
                    GROUP BY site_id, tenant_id, site_assistant_id, DATE(start_use_time)
                """,
                "compare_columns": ["total_service_count", "total_seconds", "total_ledger_amount"],
            },
            # Member visit detail - one row per member + settlement order.
            # NOTE: the DWD table's primary key is order_settle_id, not id.
            "dws_member_visit_detail": {
                "pk_columns": ["site_id", "member_id", "order_settle_id"],
                "time_column": "visit_date",
                "source_table": "billiards_dwd.dwd_settlement_head",
                "source_time_column": "pay_time",
                "agg_sql": """
                    SELECT
                        site_id,
                        tenant_id,
                        member_id,
                        order_settle_id,
                        DATE(pay_time) as visit_date,
                        COALESCE(table_charge_money, 0) as table_fee,
                        COALESCE(goods_money, 0) as goods_amount,
                        COALESCE(pay_amount, 0) as actual_pay
                    FROM billiards_dwd.dwd_settlement_head
                    WHERE pay_time >= %s AND pay_time < %s
                        AND member_id > 0
                """,
                "compare_columns": ["table_fee", "goods_amount", "actual_pay"],
            },
        }
    def get_tables(self) -> List[str]:
        """Return the DWS summary tables to verify.

        Uses the static config when present; otherwise falls back to
        discovering dws_* tables from information_schema.
        """
        if self._table_config:
            return list(self._table_config.keys())
        sql = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'billiards_dws'
                AND table_type = 'BASE TABLE'
                AND table_name LIKE 'dws_%'
                AND table_name NOT LIKE 'cfg_%'
            ORDER BY table_name
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql)
                return [row[0] for row in cur.fetchall()]
        except Exception as e:
            self.logger.warning("获取 DWS 表列表失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return []
    def get_primary_keys(self, table: str) -> List[str]:
        """Return the table's primary-key columns (default: site_id, stat_date)."""
        if table in self._table_config:
            return self._table_config[table].get("pk_columns", ["site_id", "stat_date"])
        return ["site_id", "stat_date"]
    def get_time_column(self, table: str) -> Optional[str]:
        """Return the table's time column (default: stat_date)."""
        if table in self._table_config:
            return self._table_config[table].get("time_column", "stat_date")
        return "stat_date"
    def fetch_source_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Compute the expected primary-key set by aggregating DWD data.

        Raises:
            VerificationFetchError: if the aggregation query fails.
        """
        config = self._table_config.get(table, {})
        agg_sql = config.get("agg_sql")
        if not agg_sql:
            return set()
        pk_cols = self.get_primary_keys(table)
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(agg_sql, (window_start, window_end))
                columns = [desc[0] for desc in cur.description]
                # Map configured pk columns onto the query's result columns.
                pk_indices = [columns.index(c) for c in pk_cols if c in columns]
                return {tuple(row[i] for i in pk_indices) for row in cur.fetchall()}
        except Exception as e:
            self.logger.warning("获取 DWD 聚合主键失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 DWD 聚合主键失败: {table}") from e
    def fetch_target_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Read the stored primary-key set from the DWS table.

        Raises:
            VerificationFetchError: if the query fails.
        """
        pk_cols = self.get_primary_keys(table)
        time_col = self.get_time_column(table)
        pk_select = ", ".join(pk_cols)
        sql = f"""
            SELECT {pk_select}
            FROM billiards_dws.{table}
            WHERE {time_col} >= %s AND {time_col} < %s
        """
        try:
            with self.db.conn.cursor() as cur:
                # DWS time columns are dates, so filter by date values.
                cur.execute(sql, (window_start.date(), window_end.date()))
                return {tuple(row) for row in cur.fetchall()}
        except Exception as e:
            self.logger.warning("获取 DWS 主键失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 DWS 主键失败: {table}") from e
    def fetch_source_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Aggregate DWD data and return pk -> stringified compare values.

        Raises:
            VerificationFetchError: if the aggregation query fails.
        """
        config = self._table_config.get(table, {})
        agg_sql = config.get("agg_sql")
        compare_cols = config.get("compare_columns", [])
        if not agg_sql:
            return {}
        pk_cols = self.get_primary_keys(table)
        result = {}
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(agg_sql, (window_start, window_end))
                columns = [desc[0] for desc in cur.description]
                pk_indices = [columns.index(c) for c in pk_cols if c in columns]
                value_indices = [columns.index(c) for c in compare_cols if c in columns]
                for row in cur.fetchall():
                    pk = tuple(row[i] for i in pk_indices)
                    values = tuple(row[i] for i in value_indices)
                    # str(tuple) serves as a cheap content "hash".
                    result[pk] = str(values)
        except Exception as e:
            self.logger.warning("获取 DWD 聚合数据失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 DWD 聚合数据失败: {table}") from e
        return result
    def fetch_target_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Read DWS rows and return pk -> stringified compare values.

        Raises:
            VerificationFetchError: if the query fails.
        """
        config = self._table_config.get(table, {})
        compare_cols = config.get("compare_columns", [])
        pk_cols = self.get_primary_keys(table)
        time_col = self.get_time_column(table)
        all_cols = pk_cols + compare_cols
        col_select = ", ".join(all_cols)
        sql = f"""
            SELECT {col_select}
            FROM billiards_dws.{table}
            WHERE {time_col} >= %s AND {time_col} < %s
        """
        result = {}
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (window_start.date(), window_end.date()))
                for row in cur.fetchall():
                    # Row layout is pk columns first, then compare columns.
                    pk = tuple(row[:len(pk_cols)])
                    values = tuple(row[len(pk_cols):])
                    result[pk] = str(values)
        except Exception as e:
            self.logger.warning("获取 DWS 数据失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 DWS 数据失败: {table}") from e
        return result
    def backfill_missing(
        self,
        table: str,
        missing_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Backfill missing rows by recomputing the aggregation and inserting."""
        if not missing_keys:
            return 0
        self.logger.info(
            "DWS 补齐缺失: 表=%s, 数量=%d",
            table, len(missing_keys)
        )
        # Ensure a clean transaction state before writing.
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        # Recompute the summary rows and UPSERT them.
        return self._recalculate_and_upsert(table, window_start, window_end, missing_keys)
    def backfill_mismatch(
        self,
        table: str,
        mismatch_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Update inconsistent rows by recomputing the aggregation and upserting."""
        if not mismatch_keys:
            return 0
        self.logger.info(
            "DWS 更新不一致: 表=%s, 数量=%d",
            table, len(mismatch_keys)
        )
        # Ensure a clean transaction state before writing.
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        # Recompute the summary rows and UPSERT them.
        return self._recalculate_and_upsert(table, window_start, window_end, mismatch_keys)
    def _recalculate_and_upsert(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
        target_keys: Optional[Set[Tuple]] = None,
    ) -> int:
        """Re-run the aggregation for the window and UPSERT the rows.

        Args:
            table: DWS table name
            window_start: Window start
            window_end: Window end
            target_keys: When given, only rows whose primary key is in this
                set are written; otherwise all recomputed rows are upserted.

        Returns:
            Number of rows successfully upserted.
        """
        config = self._table_config.get(table, {})
        agg_sql = config.get("agg_sql")
        if not agg_sql:
            return 0
        pk_cols = self.get_primary_keys(table)
        # Run the aggregation query over the window.
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(agg_sql, (window_start, window_end))
                columns = [desc[0] for desc in cur.description]
                records = [dict(zip(columns, row)) for row in cur.fetchall()]
        except Exception as e:
            self.logger.error("聚合查询失败: %s", e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return 0
        if not records:
            return 0
        # When target keys are specified, restrict to those rows only.
        if target_keys:
            records = [
                r for r in records
                if tuple(r.get(c) for c in pk_cols) in target_keys
            ]
        if not records:
            return 0
        # Build the UPSERT statement from the aggregation's column list.
        col_list = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        pk_list = ", ".join(pk_cols)
        update_cols = [c for c in columns if c not in pk_cols]
        update_set = ", ".join(f"{c} = EXCLUDED.{c}" for c in update_cols)
        upsert_sql = f"""
            INSERT INTO billiards_dws.{table} ({col_list})
            VALUES ({placeholders})
            ON CONFLICT ({pk_list}) DO UPDATE SET {update_set}
        """
        count = 0
        with self.db.conn.cursor() as cur:
            for record in records:
                values = [record.get(c) for c in columns]
                try:
                    cur.execute(upsert_sql, values)
                    count += 1
                except Exception as e:
                    # Row-level failure: log and continue with the rest.
                    self.logger.warning("UPSERT 失败: %s", e)
        self.db.commit()
        return count
    def verify_aggregation(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[str, Any]:
        """
        Detailed aggregation check for a single table.

        Returns a dict with counts, a consistency flag, and samples (up to
        10) of missing keys and mismatch details.
        """
        config = self._table_config.get(table, {})
        compare_cols = config.get("compare_columns", [])
        source_hashes = self.fetch_source_hashes(table, window_start, window_end)
        target_hashes = self.fetch_target_hashes(table, window_start, window_end)
        source_keys = set(source_hashes.keys())
        target_keys = set(target_hashes.keys())
        missing = source_keys - target_keys
        extra = target_keys - source_keys
        # Collect value-level differences for keys present on both sides.
        mismatch_details = []
        for key in source_keys & target_keys:
            if source_hashes[key] != target_hashes[key]:
                mismatch_details.append({
                    "key": key,
                    "source": source_hashes[key],
                    "target": target_hashes[key],
                })
        return {
            "table": table,
            "window": f"{window_start.date()} ~ {window_end.date()}",
            "source_count": len(source_hashes),
            "target_count": len(target_hashes),
            "missing_count": len(missing),
            "extra_count": len(extra),
            "mismatch_count": len(mismatch_details),
            "is_consistent": len(missing) == 0 and len(mismatch_details) == 0,
            "missing_keys": list(missing)[:10],  # sample only the first 10
            "mismatch_details": mismatch_details[:10],
        }

View File

@@ -0,0 +1,348 @@
# -*- coding: utf-8 -*-
"""INDEX 层批量校验器。"""
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple
from .base_verifier import BaseVerifier, VerificationFetchError
class IndexVerifier(BaseVerifier):
    """INDEX-layer verifier (entity-coverage check + recompute backfill)."""
    def __init__(
        self,
        db_connection: Any,
        logger: Optional[logging.Logger] = None,
        lookback_days: int = 60,
        config: Any = None,
    ):
        """
        Initialize the INDEX verifier.

        Args:
            db_connection: Database connection wrapper
            logger: Logger
            lookback_days: How many days before the window end to scan for
                source entities (INDEX tables are snapshot-style)
            config: Application config used to construct backfill tasks;
                loaded lazily from config.settings.AppConfig when None
        """
        super().__init__(db_connection, logger)
        self.lookback_days = lookback_days
        self.config = config
        self._table_config = self._load_table_config()
    @property
    def layer_name(self) -> str:
        return "INDEX"
    def _load_table_config(self) -> Dict[str, dict]:
        """Load the INDEX table configuration.

        Each entry describes:
        - pk_columns: entity key columns
        - time_column: snapshot time column
        - entity_sql: query producing the set of entities that SHOULD be
          covered (two %s placeholders: window start and end)
        - task_code / task_codes: recompute task(s) to run for backfill
        """
        return {
            "v_member_recall_priority": {
                "pk_columns": ["site_id", "member_id"],
                "time_column": "calc_time",
                "entity_sql": """
                    WITH params AS (
                        SELECT %s::timestamp AS start_time, %s::timestamp AS end_time
                    ),
                    visit_members AS (
                        SELECT DISTINCT s.site_id, s.member_id
                        FROM billiards_dwd.dwd_settlement_head s
                        CROSS JOIN params p
                        WHERE s.pay_time >= p.start_time
                          AND s.pay_time < p.end_time
                          AND s.member_id > 0
                          AND (
                              s.settle_type = 1
                              OR (
                                  s.settle_type = 3
                                  AND EXISTS (
                                      SELECT 1
                                      FROM billiards_dwd.dwd_assistant_service_log asl
                                      JOIN billiards_dws.cfg_skill_type st
                                        ON asl.skill_id = st.skill_id
                                       AND st.course_type_code = 'BONUS'
                                       AND st.is_active = TRUE
                                      WHERE asl.order_settle_id = s.order_settle_id
                                        AND asl.site_id = s.site_id
                                        AND asl.tenant_member_id = s.member_id
                                        AND asl.is_delete = 0
                                  )
                              )
                          )
                    ),
                    recharge_members AS (
                        SELECT DISTINCT r.site_id, r.member_id
                        FROM billiards_dwd.dwd_recharge_order r
                        CROSS JOIN params p
                        WHERE r.pay_time >= p.start_time
                          AND r.pay_time < p.end_time
                          AND r.member_id > 0
                          AND r.settle_type = 5
                    )
                    SELECT site_id, member_id FROM visit_members
                    UNION
                    SELECT site_id, member_id FROM recharge_members
                """,
                # This view is produced jointly by WBI + NCI, so both
                # recompute tasks must be triggered on missing entities.
                "task_codes": ["DWS_WINBACK_INDEX", "DWS_NEWCONV_INDEX"],
                "description": "客户召回/转化优先级视图",
            },
            "dws_member_assistant_relation_index": {
                "pk_columns": ["site_id", "member_id", "assistant_id"],
                "time_column": "calc_time",
                "entity_sql": """
                    WITH params AS (
                        SELECT %s::timestamp AS start_time, %s::timestamp AS end_time
                    ),
                    service_pairs AS (
                        SELECT DISTINCT
                            s.site_id,
                            s.tenant_member_id AS member_id,
                            d.assistant_id
                        FROM billiards_dwd.dwd_assistant_service_log s
                        JOIN billiards_dwd.dim_assistant d
                          ON s.user_id = d.user_id
                         AND d.scd2_is_current = 1
                         AND COALESCE(d.is_delete, 0) = 0
                        CROSS JOIN params p
                        WHERE s.last_use_time >= p.start_time
                          AND s.last_use_time < p.end_time
                          AND s.tenant_member_id > 0
                          AND s.user_id > 0
                          AND s.is_delete = 0
                    ),
                    manual_pairs AS (
                        SELECT DISTINCT
                            m.site_id,
                            m.member_id,
                            m.assistant_id
                        FROM billiards_dws.dws_ml_manual_order_alloc m
                        CROSS JOIN params p
                        WHERE m.pay_time >= p.start_time
                          AND m.pay_time < p.end_time
                          AND m.member_id > 0
                          AND m.assistant_id > 0
                    )
                    SELECT site_id, member_id, assistant_id FROM service_pairs
                    UNION
                    SELECT site_id, member_id, assistant_id FROM manual_pairs
                """,
                "task_code": "DWS_RELATION_INDEX",
                "description": "客户-助教关系指数",
            },
        }
    def get_tables(self) -> List[str]:
        """Return the configured INDEX tables/views."""
        return list(self._table_config.keys())
    def get_primary_keys(self, table: str) -> List[str]:
        """Return the entity key columns; empty list if unconfigured."""
        if table in self._table_config:
            return self._table_config[table].get("pk_columns", [])
        self.logger.warning("%s 未在 INDEX 校验配置中定义,跳过", table)
        return []
    def get_time_column(self, table: str) -> Optional[str]:
        """Return the snapshot time column (default: calc_time)."""
        if table in self._table_config:
            return self._table_config[table].get("time_column", "calc_time")
        return "calc_time"
    def fetch_source_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Return the set of entities that should be covered by the index.

        NOTE: `window_start` is ignored; the scan always looks back
        `lookback_days` from `window_end`.

        Raises:
            VerificationFetchError: if the entity query fails.
        """
        config = self._table_config.get(table, {})
        entity_sql = config.get("entity_sql")
        if not entity_sql:
            return set()
        actual_start = window_end - timedelta(days=self.lookback_days)
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(entity_sql, (actual_start, window_end))
                return {tuple(row) for row in cur.fetchall()}
        except Exception as exc:
            self.logger.warning("获取源实体失败: table=%s error=%s", table, exc)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取源实体失败: {table}") from exc
    def fetch_target_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Return the set of entities currently present in the index table.

        NOTE: the window arguments are not applied — the whole table's
        distinct keys are read.

        Raises:
            VerificationFetchError: if the query fails.
        """
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            self.logger.debug("%s 没有主键配置,跳过目标读取", table)
            return set()
        pk_select = ", ".join(pk_cols)
        sql = f"""
            SELECT DISTINCT {pk_select}
            FROM billiards_dws.{table}
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql)
                return {tuple(row) for row in cur.fetchall()}
        except Exception as exc:
            self.logger.warning("获取目标实体失败: table=%s error=%s", table, exc)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取目标实体失败: {table}") from exc
    def fetch_source_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Coverage-only check: every source entity hashes to the constant "1"."""
        keys = self.fetch_source_keys(table, window_start, window_end)
        return {k: "1" for k in keys}
    def fetch_target_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Coverage-only check: every indexed entity hashes to the constant "1"."""
        keys = self.fetch_target_keys(table, window_start, window_end)
        return {k: "1" for k in keys}
    def backfill_missing(
        self,
        table: str,
        missing_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Backfill by rerunning the configured recompute task(s).

        Returns the sum of inserted + updated record counts reported by
        the tasks, or 0 on failure/misconfiguration.
        """
        if not missing_keys:
            return 0
        config = self._table_config.get(table, {})
        # Prefer the plural "task_codes"; fall back to single "task_code".
        task_codes = config.get("task_codes")
        if not task_codes:
            task_code = config.get("task_code")
            task_codes = [task_code] if task_code else []
        if not task_codes:
            self.logger.warning("未找到补齐任务配置: table=%s", table)
            return 0
        self.logger.info(
            "INDEX 补齐: table=%s missing=%d task_codes=%s",
            table,
            len(missing_keys),
            ",".join(task_codes),
        )
        # Ensure a clean transaction state before the tasks run.
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        try:
            task_config = self.config
            if task_config is None:
                from config.settings import AppConfig
                task_config = AppConfig.load()
            inserted_total = 0
            for task_code in task_codes:
                # Lazy imports avoid circular dependencies with the task modules.
                if task_code == "DWS_RECALL_INDEX":
                    from tasks.dws.index.recall_index_task import RecallIndexTask
                    task = RecallIndexTask(task_config, self.db, None, self.logger)
                elif task_code == "DWS_WINBACK_INDEX":
                    from tasks.dws.index.winback_index_task import WinbackIndexTask
                    task = WinbackIndexTask(task_config, self.db, None, self.logger)
                elif task_code == "DWS_NEWCONV_INDEX":
                    from tasks.dws.index.newconv_index_task import NewconvIndexTask
                    task = NewconvIndexTask(task_config, self.db, None, self.logger)
                elif task_code == "DWS_INTIMACY_INDEX":
                    from tasks.dws.index.intimacy_index_task import IntimacyIndexTask
                    task = IntimacyIndexTask(task_config, self.db, None, self.logger)
                elif task_code == "DWS_RELATION_INDEX":
                    from tasks.dws.index.relation_index_task import RelationIndexTask
                    task = RelationIndexTask(task_config, self.db, None, self.logger)
                else:
                    self.logger.warning("未知 INDEX 任务代码,跳过: %s", task_code)
                    continue
                self.logger.info("执行 INDEX 补齐任务: %s", task_code)
                result = task.execute(None)
                inserted_total += result.get("records_inserted", 0) + result.get("records_updated", 0)
            return inserted_total
        except Exception as exc:
            self.logger.error("INDEX 补齐失败: %s", exc)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return 0
    def backfill_mismatch(
        self,
        table: str,
        mismatch_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """No-op: coverage hashes are constant, so mismatches cannot occur."""
        return 0
    def verify_coverage(
        self,
        table: str,
        window_end: Optional[datetime] = None,
    ) -> Dict[str, Any]:
        """Compute coverage statistics for one index table.

        Args:
            table: Index table/view name
            window_end: End of the lookback window; defaults to now

        Returns:
            Dict with entity counts, coverage rate (%), completeness flag
            and a sample (up to 10) of missing entity keys.
        """
        if window_end is None:
            window_end = datetime.now()
        window_start = window_end - timedelta(days=self.lookback_days)
        config = self._table_config.get(table, {})
        description = config.get("description", table)
        source_keys = self.fetch_source_keys(table, window_start, window_end)
        target_keys = self.fetch_target_keys(table, window_start, window_end)
        missing = source_keys - target_keys
        extra = target_keys - source_keys
        # An empty source set counts as 100% coverage.
        coverage_rate = len(target_keys & source_keys) / len(source_keys) * 100 if source_keys else 100.0
        return {
            "table": table,
            "description": description,
            "lookback_days": self.lookback_days,
            "window": f"{window_start.date()} ~ {window_end.date()}",
            "source_entities": len(source_keys),
            "indexed_entities": len(target_keys),
            "missing_count": len(missing),
            "extra_count": len(extra),
            "coverage_rate": round(coverage_rate, 2),
            "is_complete": len(missing) == 0,
            "missing_sample": list(missing)[:10],
        }
    def verify_all_indices(
        self,
        window_end: Optional[datetime] = None,
    ) -> Dict[str, dict]:
        """Run verify_coverage for every configured index table."""
        results = {}
        for table in self.get_tables():
            results[table] = self.verify_coverage(table, window_end)
        return results
    def get_missing_entities(
        self,
        table: str,
        limit: int = 100,
        window_end: Optional[datetime] = None,
    ) -> List[Tuple]:
        """Return up to `limit` entity keys missing from the index table."""
        if window_end is None:
            window_end = datetime.now()
        window_start = window_end - timedelta(days=self.lookback_days)
        source_keys = self.fetch_source_keys(table, window_start, window_end)
        target_keys = self.fetch_target_keys(table, window_start, window_end)
        missing = source_keys - target_keys
        return list(missing)[:limit]

View File

@@ -0,0 +1,283 @@
# -*- coding: utf-8 -*-
"""校验结果数据模型"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import List, Optional, Dict, Any
class VerificationStatus(Enum):
    """Possible outcomes of a verification run."""

    OK = "OK"                  # source and target are fully consistent
    MISSING = "MISSING"        # target is missing rows present in the source
    MISMATCH = "MISMATCH"      # rows exist on both sides but content differs
    BACKFILLED = "BACKFILLED"  # discrepancies were found and backfilled
    ERROR = "ERROR"            # the verification itself failed
@dataclass
class VerificationResult:
    """Verification result for a single table/window."""
    layer: str                          # data layer: "ODS" / "DWD" / "DWS" / "INDEX"
    table: str                          # table name
    window_start: datetime              # verification window start
    window_end: datetime                # verification window end
    source_count: int = 0               # number of source rows
    target_count: int = 0               # number of target rows
    missing_count: int = 0              # rows missing from the target
    mismatch_count: int = 0             # rows present on both sides but different
    backfilled_count: int = 0           # rows backfilled (missing + mismatch)
    backfilled_missing_count: int = 0   # backfilled rows that were missing
    backfilled_mismatch_count: int = 0  # backfilled rows that were mismatched
    status: VerificationStatus = VerificationStatus.OK
    elapsed_seconds: float = 0.0        # elapsed time in seconds
    error_message: Optional[str] = None # error message, if any
    details: Dict[str, Any] = field(default_factory=dict)  # extra details (e.g. "fatal")
    @property
    def is_consistent(self) -> bool:
        """True when the data is fully consistent (status OK)."""
        return self.status == VerificationStatus.OK
    @property
    def needs_backfill(self) -> bool:
        """True when there are missing or mismatched rows to backfill."""
        return self.missing_count > 0 or self.mismatch_count > 0
    def to_dict(self) -> dict:
        """Serialize to a plain dict (datetimes as ISO strings, status as its value)."""
        return {
            "layer": self.layer,
            "table": self.table,
            "window_start": self.window_start.isoformat() if self.window_start else None,
            "window_end": self.window_end.isoformat() if self.window_end else None,
            "source_count": self.source_count,
            "target_count": self.target_count,
            "missing_count": self.missing_count,
            "mismatch_count": self.mismatch_count,
            "backfilled_count": self.backfilled_count,
            "backfilled_missing_count": self.backfilled_missing_count,
            "backfilled_mismatch_count": self.backfilled_mismatch_count,
            "status": self.status.value,
            "elapsed_seconds": self.elapsed_seconds,
            "error_message": self.error_message,
            "details": self.details,
        }
    def format_summary(self) -> str:
        """Format a human-readable multi-line summary of this result."""
        lines = [
            f"表: {self.table}",
            f"层: {self.layer}",
            f"窗口: {self.window_start.strftime('%Y-%m-%d %H:%M')} ~ {self.window_end.strftime('%Y-%m-%d %H:%M')}",
            f"源数据量: {self.source_count:,}",
            f"目标数据量: {self.target_count:,}",
            f"缺失: {self.missing_count:,}",
            f"不一致: {self.mismatch_count:,}",
            f"缺失补齐: {self.backfilled_missing_count:,}",
            f"不一致补齐: {self.backfilled_mismatch_count:,}",
            f"已补齐: {self.backfilled_count:,}",
            f"状态: {self.status.value}",
            f"耗时: {self.elapsed_seconds:.2f}s",
        ]
        if self.error_message:
            lines.append(f"错误: {self.error_message}")
        return "\n".join(lines)
@dataclass
class VerificationSummary:
    """Aggregated verification outcome across all tables of one layer."""
    layer: str  # data layer name
    window_start: datetime  # verification window start
    window_end: datetime  # verification window end
    total_tables: int = 0  # number of tables verified
    consistent_tables: int = 0  # tables that verified clean
    inconsistent_tables: int = 0  # tables with gaps/mismatches or errors
    total_source_count: int = 0  # total source row count
    total_target_count: int = 0  # total target row count
    total_missing: int = 0  # total missing rows
    total_mismatch: int = 0  # total mismatched rows
    total_backfilled: int = 0  # total backfilled rows
    total_backfilled_missing: int = 0  # backfilled rows that were missing
    total_backfilled_mismatch: int = 0  # backfilled rows that were mismatched
    error_tables: int = 0  # tables whose verification raised an error
    elapsed_seconds: float = 0.0  # total elapsed time (seconds)
    results: List[VerificationResult] = field(default_factory=list)  # per-table results
    status: VerificationStatus = VerificationStatus.OK  # aggregated status
    def add_result(self, result: VerificationResult):
        """Fold one table's result into the running totals and aggregate status."""
        self.results.append(result)
        self.total_tables += 1
        self.total_source_count += result.source_count
        self.total_target_count += result.target_count
        self.total_missing += result.missing_count
        self.total_mismatch += result.mismatch_count
        self.total_backfilled += result.backfilled_count
        self.total_backfilled_missing += result.backfilled_missing_count
        self.total_backfilled_mismatch += result.backfilled_mismatch_count
        self.elapsed_seconds += result.elapsed_seconds
        if result.status == VerificationStatus.ERROR:
            # An errored table counts as inconsistent as well.
            self.error_tables += 1
            self.inconsistent_tables += 1
            # ERROR has the highest priority and overrides the aggregate status.
            self.status = VerificationStatus.ERROR
        elif result.is_consistent:
            self.consistent_tables += 1
        else:
            self.inconsistent_tables += 1
            # Only upgrade from OK; never downgrade an already-set ERROR.
            if self.status == VerificationStatus.OK:
                self.status = result.status
    @property
    def is_all_consistent(self) -> bool:
        """Whether every verified table came back consistent."""
        return self.inconsistent_tables == 0
    def to_dict(self) -> dict:
        """Serialize the summary (including nested per-table results) to a dict."""
        return {
            "layer": self.layer,
            "window_start": self.window_start.isoformat() if self.window_start else None,
            "window_end": self.window_end.isoformat() if self.window_end else None,
            "total_tables": self.total_tables,
            "consistent_tables": self.consistent_tables,
            "inconsistent_tables": self.inconsistent_tables,
            "total_source_count": self.total_source_count,
            "total_target_count": self.total_target_count,
            "total_missing": self.total_missing,
            "total_mismatch": self.total_mismatch,
            "total_backfilled": self.total_backfilled,
            "total_backfilled_missing": self.total_backfilled_missing,
            "total_backfilled_mismatch": self.total_backfilled_mismatch,
            "error_tables": self.error_tables,
            "elapsed_seconds": self.elapsed_seconds,
            "status": self.status.value,
            "results": [r.to_dict() for r in self.results],
        }
    def format_summary(self) -> str:
        """Render a banner-framed, human-readable summary block."""
        lines = [
            f"{'=' * 60}",
            f"校验汇总 - {self.layer}",
            f"{'=' * 60}",
            f"窗口: {self.window_start.strftime('%Y-%m-%d %H:%M')} ~ {self.window_end.strftime('%Y-%m-%d %H:%M')}",
            f"表数: {self.total_tables} (一致: {self.consistent_tables}, 不一致: {self.inconsistent_tables})",
            f"源数据量: {self.total_source_count:,}",
            f"目标数据量: {self.total_target_count:,}",
            f"总缺失: {self.total_missing:,}",
            f"总不一致: {self.total_mismatch:,}",
            f"总补齐: {self.total_backfilled:,} (缺失: {self.total_backfilled_missing:,}, 不一致: {self.total_backfilled_mismatch:,})",
            f"错误表数: {self.error_tables}",
            f"总耗时: {self.elapsed_seconds:.2f}s",
            f"状态: {self.status.value}",
            f"{'=' * 60}",
        ]
        return "\n".join(lines)
@dataclass
class WindowSegment:
    """One slice of a verification time window.

    Attributes:
        start: slice start (inclusive).
        end: slice end (exclusive).
        index: zero-based position of this slice within the full window.
        total: total number of slices the window was split into.
    """
    start: datetime
    end: datetime
    index: int = 0
    total: int = 1
    @property
    def label(self) -> str:
        """Human-readable date range for logs, e.g. ``2024-01-01 ~ 2024-01-31``."""
        parts = (self.start.strftime("%Y-%m-%d"), self.end.strftime("%Y-%m-%d"))
        return " ~ ".join(parts)
def build_window_segments(
    window_start: datetime,
    window_end: datetime,
    split_unit: str = "month",
) -> List[WindowSegment]:
    """
    Split a time window into segments of the requested granularity.

    Args:
        window_start: window start (inclusive)
        window_end: window end (exclusive)
        split_unit: granularity ("none", "day", "week", "month")

    Returns:
        Ordered list of WindowSegment covering [window_start, window_end);
        each segment carries its index and the total segment count.
    """
    if not split_unit or split_unit == "none":
        return [WindowSegment(start=window_start, end=window_end, index=0, total=1)]

    def _boundary_after(moment):
        """Next calendar boundary strictly after *moment* for split_unit."""
        midnight = moment.replace(hour=0, minute=0, second=0, microsecond=0)
        if split_unit == "day":
            return midnight + timedelta(days=1)
        if split_unit == "week":
            # Weeks start on Monday; a Monday jumps a full week ahead.
            days_ahead = (7 - moment.weekday()) % 7 or 7
            return midnight + timedelta(days=days_ahead)
        if split_unit == "month":
            if moment.month == 12:
                return moment.replace(year=moment.year + 1, month=1, day=1,
                                      hour=0, minute=0, second=0, microsecond=0)
            return moment.replace(month=moment.month + 1, day=1,
                                  hour=0, minute=0, second=0, microsecond=0)
        # Unknown unit: keep the window whole.
        return window_end

    segments: List[WindowSegment] = []
    cursor = window_start
    while cursor < window_end:
        segment_end = min(_boundary_after(cursor), window_end)
        segments.append(WindowSegment(start=cursor, end=segment_end))
        cursor = segment_end
    # Stamp each segment with its position and the overall count.
    total = len(segments)
    for position, segment in enumerate(segments):
        segment.index = position
        segment.total = total
    return segments
def filter_verify_tables(layer: str, tables: list[str] | None) -> list[str] | None:
"""按层过滤校验表名,避免非目标层全量校验。
Args:
layer: 数据层名称("ODS" / "DWD" / "DWS" / "INDEX"
tables: 待过滤的表名列表,为 None 或空时直接返回 None
Returns:
过滤后的表名列表,或 None
"""
if not tables:
return None
layer_upper = layer.upper()
normalized = [t.strip().lower() for t in tables if t and t.strip()]
if layer_upper == "DWD":
return [t for t in normalized if t.startswith(("dwd_", "dim_", "fact_"))]
if layer_upper == "DWS":
return [t for t in normalized if t.startswith("dws_")]
if layer_upper == "INDEX":
return [t for t in normalized if t.startswith("v_") or t.endswith("_index")]
if layer_upper == "ODS":
return [t for t in normalized if t.startswith("ods_")]
return normalized
# NOTE: late import — timedelta is bound at module import time, before any
# build_window_segments call can run, but it belongs with the top-of-file imports.
from datetime import timedelta

# --- file boundary: ods_verifier.py (871 lines) ---
# -*- coding: utf-8 -*-
"""ODS 层批量校验器
校验逻辑:对比 API 源数据与 ODS 表数据
- 主键 + content_hash 对比
- 批量 UPSERT 补齐缺失/不一致数据
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from psycopg2.extras import execute_values
from api.local_json_client import LocalJsonClient
from .base_verifier import BaseVerifier, VerificationFetchError
class OdsVerifier(BaseVerifier):
"""ODS 层校验器"""
    def __init__(
        self,
        db_connection: Any,
        api_client: Any = None,
        logger: Optional[logging.Logger] = None,
        fetch_from_api: bool = False,
        local_dump_dirs: Optional[Dict[str, str]] = None,
        use_local_json: bool = False,
    ):
        """
        Initialize the ODS verifier.

        Args:
            db_connection: database connection wrapper
            api_client: API client used to re-fetch source rows
            logger: logger instance
            fetch_from_api: verify against the external source (default False:
                only check ODS internal consistency)
            local_dump_dirs: mapping of task_code -> local JSON dump directory
            use_local_json: prefer local JSON dumps as the source of truth
        """
        super().__init__(db_connection, logger)
        self.api_client = api_client
        self.fetch_from_api = fetch_from_api
        self.local_dump_dirs = local_dump_dirs or {}
        # Supplying dump dirs implicitly enables local-JSON mode.
        self.use_local_json = bool(use_local_json or self.local_dump_dirs)
        # Per-window caches so one verification run never fetches the same API window twice.
        self._api_data_cache: Dict[str, List[dict]] = {}
        self._api_key_cache: Dict[str, Set[Tuple]] = {}
        self._api_hash_cache: Dict[str, Dict[Tuple, str]] = {}
        self._table_column_cache: Dict[Tuple[str, str], bool] = {}
        self._table_pk_cache: Dict[str, List[str]] = {}
        self._local_json_clients: Dict[str, LocalJsonClient] = {}
        # {table: {pk_columns, time_column, api_endpoint, ...}} built from task specs.
        self._table_config = self._load_table_config()
@property
def layer_name(self) -> str:
return "ODS"
def _load_table_config(self) -> Dict[str, dict]:
"""加载 ODS 表配置"""
# 从任务定义中动态获取配置
try:
from tasks.ods.ods_tasks import ODS_TASK_SPECS
config = {}
for spec in ODS_TASK_SPECS:
# time_fields 是一个元组 (start_field, end_field),取第一个作为时间列
# 或者使用 fetched_at 作为默认
time_column = "fetched_at"
# 使用 table_name 属性(不是 table
table_name = spec.table_name
# 提取不带 schema 前缀的表名作为 key
if "." in table_name:
table_key = table_name.split(".")[-1]
else:
table_key = table_name
# 从 sources 中提取 ODS 表的实际主键列名
# sources 格式如 ("settleList.id", "id"),最后一个简单名称是 ODS 列名
pk_columns = []
for col in spec.pk_columns:
ods_col_name = self._extract_ods_column_name(col)
pk_columns.append(ods_col_name)
# 如果 pk_columns 为空,尝试使用 conflict_columns_override 或跳过校验
# 一些特殊表(如 goods_stock_summary, settlement_ticket_details没有标准主键
if not pk_columns:
# 跳过没有明确主键定义的表
self.logger.debug("%s 没有定义主键列,跳过校验配置", table_key)
continue
config[table_key] = {
"full_table_name": table_name,
"pk_columns": pk_columns,
"time_column": time_column,
"api_endpoint": spec.endpoint,
"task_code": spec.code,
}
return config
except ImportError:
self.logger.warning("无法加载 ODS 任务定义,使用默认配置")
return {}
def _extract_ods_column_name(self, col) -> str:
"""
从 ColumnSpec 中提取 ODS 表的实际列名
ODS 表使用原始 JSON 字段名(小写),而 col.column 是 DWD 层的命名。
sources 中的最后一个简单字段名通常就是 ODS 表的列名。
"""
# 如果 sources 为空,使用 column假设 column 就是 ODS 列名)
if not col.sources:
return col.column
# 遍历 sources找到最简单的字段名不含点号的
for source in reversed(col.sources):
if "." not in source:
return source.lower() # ODS 列名通常是小写
# 如果都是复杂路径,取最后一个路径的最后一部分
last_source = col.sources[-1]
if "." in last_source:
return last_source.split(".")[-1].lower()
return last_source.lower()
    def get_tables(self) -> List[str]:
        """Return the list of ODS tables to verify.

        Prefers the task-spec driven config; otherwise falls back to listing
        base tables in the ``billiards_ods`` schema. Returns [] on failure.
        """
        if self._table_config:
            return list(self._table_config.keys())
        # Fallback: enumerate tables straight from the ODS schema.
        sql = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'billiards_ods'
            AND table_type = 'BASE TABLE'
            ORDER BY table_name
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql)
                return [row[0] for row in cur.fetchall()]
        except Exception as e:
            self.logger.warning("获取 ODS 表列表失败: %s", e)
            # Roll back so the failed query does not leave the transaction aborted.
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            return []
def get_primary_keys(self, table: str) -> List[str]:
"""获取表的主键列"""
if table in self._table_config:
return self._table_config[table].get("pk_columns", [])
# 表不在配置中,返回空列表表示无法校验
return []
def get_time_column(self, table: str) -> Optional[str]:
"""获取表的时间列"""
if table in self._table_config:
return self._table_config[table].get("time_column", "fetched_at")
return "fetched_at"
def _get_full_table_name(self, table: str) -> str:
"""获取完整的表名(包含 schema"""
if table in self._table_config:
return self._table_config[table].get("full_table_name", f"billiards_ods.{table}")
# 如果表名已经包含 schema直接返回
if "." in table:
return table
return f"billiards_ods.{table}"
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""
从源获取主键集合
根据 fetch_from_api 参数决定数据来源:
- fetch_from_api=True: 从 API 获取数据(真正的源到目标校验)
- fetch_from_api=False: 从 ODS 表获取ODS 内部一致性校验)
"""
if self._has_external_source():
return self._fetch_keys_from_api(table, window_start, window_end)
else:
# ODS 内部校验:直接从 ODS 表获取
return self._fetch_keys_from_db(table, window_start, window_end)
    def _fetch_keys_from_api(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Fetch the set of source primary keys from the API / local JSON dumps.

        Results are memoized per (table, window) so the same window is never
        fetched or re-parsed twice within one verification run.
        """
        # Reuse previously extracted keys for this exact window if available.
        cache_key = f"{table}_{window_start}_{window_end}"
        if cache_key in self._api_key_cache:
            return self._api_key_cache[cache_key]
        if cache_key not in self._api_data_cache:
            # First touch of this window: pull the raw records from the source.
            api_records = self._call_api_for_table(table, window_start, window_end)
            self._api_data_cache[cache_key] = api_records
        api_records = self._api_data_cache.get(cache_key, [])
        if not api_records:
            self.logger.debug("%s 从 API 未获取到数据", table)
            return set()
        # Resolve the configured primary-key columns.
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            self.logger.debug("%s 没有主键配置,跳过 API 校验", table)
            return set()
        # Extract PK tuples; rows with any NULL key component are skipped.
        keys = set()
        for record in api_records:
            pk_values = []
            for col in pk_cols:
                # Source field names may vary in case (id / Id / ID), so probe
                # the raw, lower-cased and upper-cased variants in turn.
                value = record.get(col)
                if value is None:
                    value = record.get(col.lower())
                if value is None:
                    value = record.get(col.upper())
                pk_values.append(value)
            if all(v is not None for v in pk_values):
                keys.add(tuple(pk_values))
        self.logger.info("%s 从源数据获取 %d 条记录,%d 个唯一主键", table, len(api_records), len(keys))
        self._api_key_cache[cache_key] = keys
        return keys
    def _call_api_for_table(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> List[dict]:
        """Pull raw source records for *table* from the API or local JSON dumps.

        Returns [] when the table has no usable task config or no source
        client; raises VerificationFetchError when the fetch itself fails.
        """
        config = self._table_config.get(table, {})
        task_code = config.get("task_code")
        endpoint = config.get("api_endpoint")
        if not task_code or not endpoint:
            self.logger.warning(
                "%s 没有完整的任务配置task_code=%s, endpoint=%s),无法获取源数据",
                table, task_code, endpoint
            )
            return []
        source_client = self._get_source_client(task_code)
        if not source_client:
            self.logger.warning("%s 未找到可用源API/本地JSON跳过获取源数据", table)
            return []
        source_label = "本地 JSON" if self._is_using_local_json(task_code) else "API"
        self.logger.info(
            "%s 获取数据: 表=%s, 端点=%s, 时间窗口=%s ~ %s",
            source_label, table, endpoint, window_start, window_end
        )
        try:
            # Look up the ODS task spec to build the correct request parameters.
            from tasks.ods.ods_tasks import ODS_TASK_SPECS
            spec = None
            for s in ODS_TASK_SPECS:
                if s.code == task_code:
                    spec = s
                    break
            if not spec:
                self.logger.warning("未找到任务规格: %s", task_code)
                return []
            # Assemble the request parameters from the spec.
            params = {}
            if spec.include_site_id:
                # Use the API client's store_id when available.
                store_id = getattr(self.api_client, 'store_id', None)
                if store_id:
                    params["siteId"] = store_id
            if spec.requires_window and spec.time_fields:
                start_key, end_key = spec.time_fields
                # Timestamps are sent as formatted strings.
                params[start_key] = window_start.strftime("%Y-%m-%d %H:%M:%S")
                params[end_key] = window_end.strftime("%Y-%m-%d %H:%M:%S")
            # Merge any spec-defined extra parameters.
            params.update(spec.extra_params)
            # Page through the source until exhausted.
            all_records = []
            for _, page_records, _, _ in source_client.iter_paginated(
                endpoint=spec.endpoint,
                params=params,
                page_size=200,
                data_path=spec.data_path,
                list_key=spec.list_key,
            ):
                all_records.extend(page_records)
            self.logger.info("源数据返回 %d 条原始记录", len(all_records))
            return all_records
        except Exception as e:
            self.logger.warning("获取源数据失败: 表=%s, error=%s", table, e)
            import traceback
            self.logger.debug("调用栈: %s", traceback.format_exc())
            raise VerificationFetchError(f"获取源数据失败: {table}") from e
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 ODS 表获取目标数据主键集合"""
if self._has_external_source():
cache_key = f"{table}_{window_start}_{window_end}"
api_keys = self._api_key_cache.get(cache_key)
if api_keys is None:
api_keys = self._fetch_keys_from_api(table, window_start, window_end)
return self._fetch_keys_from_db_by_keys(table, api_keys)
return self._fetch_keys_from_db(table, window_start, window_end)
    def _fetch_keys_from_db(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Read the set of primary-key tuples from the ODS table for the window.

        Raises VerificationFetchError on query failure (after rolling back).
        """
        pk_cols = self.get_primary_keys(table)
        # Without configured PK columns the table cannot be verified.
        if not pk_cols:
            self.logger.debug("%s 没有主键配置,跳过获取主键", table)
            return set()
        time_col = self.get_time_column(table)
        full_table = self._get_full_table_name(table)
        # NOTE(review): identifiers are interpolated into the SQL; they come
        # from task-spec config, not user input — confirm that stays true.
        pk_select = ", ".join(pk_cols)
        sql = f"""
            SELECT {pk_select}
            FROM {full_table}
            WHERE {time_col} >= %s AND {time_col} < %s
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (window_start, window_end))
                return {tuple(row) for row in cur.fetchall()}
        except Exception as e:
            self.logger.warning("获取 ODS 主键失败: %s, error=%s", table, e)
            # Roll back so the failed query does not leave the transaction aborted.
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 ODS 主键失败: {table}") from e
    def _fetch_keys_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Set[Tuple]:
        """Return the subset of *keys* that already exist in the ODS table.

        Uses a VALUES join, chunked 500 keys at a time, so the lookup does not
        depend on the time window. Raises VerificationFetchError on failure.
        """
        if not keys:
            return set()
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            self.logger.debug("%s 没有主键配置,跳过按主键反查", table)
            return set()
        full_table = self._get_full_table_name(table)
        select_cols = ", ".join(f't."{c}"' for c in pk_cols)
        value_cols = ", ".join(f'"{c}"' for c in pk_cols)
        join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
        sql = (
            f"SELECT {select_cols} FROM {full_table} t "
            f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
        )
        existing: Set[Tuple] = set()
        try:
            with self.db.conn.cursor() as cur:
                # Chunking keeps statement size bounded for large key sets.
                for chunk in self._chunked(list(keys), 500):
                    execute_values(cur, sql, chunk, page_size=len(chunk))
                    for row in cur.fetchall():
                        existing.add(tuple(row))
        except Exception as e:
            self.logger.warning("按主键反查 ODS 失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"按主键反查 ODS 失败: {table}") from e
        return existing
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取源数据的主键->content_hash 映射"""
if self._has_external_source():
return self._fetch_hashes_from_api(table, window_start, window_end)
else:
# ODS 表自带 content_hash 列
return self._fetch_hashes_from_db(table, window_start, window_end)
    def _fetch_hashes_from_api(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Compute {pk_tuple: content_hash} from source records (memoized)."""
        cache_key = f"{table}_{window_start}_{window_end}"
        if cache_key in self._api_hash_cache:
            return self._api_hash_cache[cache_key]
        api_records = self._api_data_cache.get(cache_key, [])
        if not api_records:
            # Window not fetched yet — pull it now and cache the raw records.
            api_records = self._call_api_for_table(table, window_start, window_end)
            self._api_data_cache[cache_key] = api_records
        if not api_records:
            return {}
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            return {}
        result = {}
        for record in api_records:
            # Probe raw / lower / upper field-name variants for each PK column.
            pk_values = []
            for col in pk_cols:
                value = record.get(col)
                if value is None:
                    value = record.get(col.lower())
                if value is None:
                    value = record.get(col.upper())
                pk_values.append(value)
            if all(v is not None for v in pk_values):
                pk = tuple(pk_values)
                # Hash with the same algorithm the ODS ingest uses.
                content_hash = self._compute_hash(record)
                result[pk] = content_hash
        self._api_hash_cache[cache_key] = result
        # Seed the key cache too, so fetch_source_keys can reuse this work.
        if cache_key not in self._api_key_cache:
            self._api_key_cache[cache_key] = set(result.keys())
        return result
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取目标数据的主键->content_hash 映射"""
if self.fetch_from_api and self.api_client:
cache_key = f"{table}_{window_start}_{window_end}"
api_hashes = self._api_hash_cache.get(cache_key)
if api_hashes is None:
api_hashes = self._fetch_hashes_from_api(table, window_start, window_end)
api_keys = set(api_hashes.keys())
return self._fetch_hashes_from_db_by_keys(table, api_keys)
return self._fetch_hashes_from_db(table, window_start, window_end)
    def _fetch_hashes_from_db(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Read {pk_tuple: content_hash} from the ODS table for the window.

        Raises VerificationFetchError on query failure (after rolling back).
        """
        pk_cols = self.get_primary_keys(table)
        # Without configured PK columns the table cannot be verified.
        if not pk_cols:
            self.logger.debug("%s 没有主键配置,跳过获取哈希", table)
            return {}
        time_col = self.get_time_column(table)
        full_table = self._get_full_table_name(table)
        pk_select = ", ".join(pk_cols)
        sql = f"""
            SELECT {pk_select}, content_hash
            FROM {full_table}
            WHERE {time_col} >= %s AND {time_col} < %s
        """
        result = {}
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (window_start, window_end))
                for row in cur.fetchall():
                    # Last column is content_hash; the rest form the PK tuple.
                    pk = tuple(row[:-1])
                    content_hash = row[-1]
                    result[pk] = content_hash or ""
        except Exception as e:
            # Roll back so the failed query does not poison later statements.
            self.logger.warning("获取 ODS hash 失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"获取 ODS hash 失败: {table}") from e
        return result
    def _fetch_hashes_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Dict[Tuple, str]:
        """Look up comparison hashes for *keys* without a time-window filter.

        When the table stores the raw ``payload`` column, the comparison hash
        is recomputed from it with the ingest algorithm; otherwise the stored
        ``content_hash`` column is used directly.
        """
        if not keys:
            return {}
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            self.logger.debug("%s 没有主键配置,跳过按主键反查 hash", table)
            return {}
        full_table = self._get_full_table_name(table)
        has_payload = self._table_has_column(full_table, "payload")
        # Select payload when available so the hash matches the source side.
        select_tail = 't."payload"' if has_payload else 't."content_hash"'
        select_cols = ", ".join([*(f't."{c}"' for c in pk_cols), select_tail])
        value_cols = ", ".join(f'"{c}"' for c in pk_cols)
        join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
        sql = (
            f"SELECT {select_cols} FROM {full_table} t "
            f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
        )
        result: Dict[Tuple, str] = {}
        try:
            with self.db.conn.cursor() as cur:
                for chunk in self._chunked(list(keys), 500):
                    execute_values(cur, sql, chunk, page_size=len(chunk))
                    for row in cur.fetchall():
                        pk = tuple(row[:-1])
                        tail_value = row[-1]
                        if has_payload:
                            compare_hash = self._compute_compare_hash_from_payload(tail_value)
                            result[pk] = compare_hash or ""
                        else:
                            result[pk] = tail_value or ""
        except Exception as e:
            self.logger.warning("按主键反查 ODS hash 失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"按主键反查 ODS hash 失败: {table}") from e
        return result
@staticmethod
def _chunked(items: List[Tuple], chunk_size: int) -> List[List[Tuple]]:
"""将列表按固定大小分块"""
if chunk_size <= 0:
return [items]
return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量补齐缺失数据
ODS 层补齐需要重新从 API 获取数据
"""
if not self._has_external_source():
self.logger.warning("未配置 API/本地JSON 源,无法补齐 ODS 缺失数据")
return 0
if not missing_keys:
return 0
# 获取表配置
config = self._table_config.get(table, {})
task_code = config.get("task_code")
if not task_code:
self.logger.warning("未找到表 %s 的任务配置,跳过补齐", table)
return 0
self.logger.info(
"ODS 补齐缺失: 表=%s, 数量=%d, 任务=%s",
table, len(missing_keys), task_code
)
# ODS 层的补齐实际上是重新执行 ODS 任务从 API 获取数据
# 但由于 ODS 任务已经在 "校验前先从 API 获取数据" 步骤执行过了,
# 这里补齐失败是预期的(数据已经在 ODS 表中,只是校验窗口可能不一致)
#
# 实际的 ODS 补齐应该在 verify_only 模式下启用 fetch_before_verify 选项,
# 这会先执行 ODS 任务获取 API 数据,然后再校验。
#
# 如果仍然有缺失,说明:
# 1. API 返回的数据时间窗口与校验窗口不完全匹配
# 2. 或者 ODS 任务的时间参数配置问题
self.logger.info(
"ODS 补齐提示: 表=%s%d 条缺失记录,建议使用 '校验前先从 API 获取数据' 选项获取完整数据",
table, len(missing_keys)
)
return 0
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量更新不一致数据
ODS 层更新也需要重新从 API 获取
"""
# 与 backfill_missing 类似,重新获取数据会自动 UPSERT
return self.backfill_missing(table, mismatch_keys, window_start, window_end)
def _has_external_source(self) -> bool:
return bool(self.fetch_from_api and (self.api_client or self.use_local_json))
def _is_using_local_json(self, task_code: str) -> bool:
return bool(self.use_local_json and task_code in self.local_dump_dirs)
def _get_local_json_client(self, task_code: str) -> Optional[LocalJsonClient]:
if task_code in self._local_json_clients:
return self._local_json_clients[task_code]
dump_dir = self.local_dump_dirs.get(task_code)
if not dump_dir:
return None
try:
client = LocalJsonClient(dump_dir)
except Exception as exc: # noqa: BLE001
self.logger.warning(
"本地 JSON 目录不可用: task=%s, dir=%s, error=%s",
task_code, dump_dir, exc,
)
return None
self._local_json_clients[task_code] = client
return client
def _get_source_client(self, task_code: str):
if self.use_local_json:
return self._get_local_json_client(task_code)
return self.api_client
    def verify_against_api(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
        auto_backfill: bool = False,
    ) -> Dict[str, Any]:
        """Strict source-vs-ODS comparison driven directly by the API.

        Returns a stats dict (missing / extra / mismatch counts, consistency
        flag) or {"error": ...} when the API cannot be used.

        NOTE(review): this path calls ``api_client.fetch_records`` — a
        different interface from the ``iter_paginated`` used elsewhere in
        this class; confirm the client actually provides it.
        """
        if not self.api_client:
            return {"error": "未配置 API 客户端"}
        config = self._table_config.get(table, {})
        endpoint = config.get("api_endpoint")
        if not endpoint:
            return {"error": f"未找到表 {table} 的 API 端点配置"}
        self.logger.info("开始与 API 对比校验: %s", table)
        # 1. Pull source rows from the API.
        try:
            api_records = self.api_client.fetch_records(
                endpoint=endpoint,
                start_time=window_start,
                end_time=window_end,
            )
        except Exception as e:
            return {"error": f"API 调用失败: {e}"}
        # 2. Pull target-side hashes from ODS.
        ods_hashes = self.fetch_target_hashes(table, window_start, window_end)
        # 3. Hash the API rows with the same algorithm as ingest.
        pk_cols = self.get_primary_keys(table)
        api_hashes = {}
        for record in api_records:
            pk = tuple(record.get(col) for col in pk_cols)
            content_hash = self._compute_hash(record)
            api_hashes[pk] = content_hash
        # 4. Diff the two sides.
        api_keys = set(api_hashes.keys())
        ods_keys = set(ods_hashes.keys())
        missing = api_keys - ods_keys
        extra = ods_keys - api_keys
        mismatch = {
            k for k in (api_keys & ods_keys)
            if api_hashes[k] != ods_hashes[k]
        }
        result = {
            "table": table,
            "api_count": len(api_hashes),
            "ods_count": len(ods_hashes),
            "missing_count": len(missing),
            "extra_count": len(extra),
            "mismatch_count": len(mismatch),
            "is_consistent": len(missing) == 0 and len(mismatch) == 0,
        }
        # 5. Optionally upsert rows that were missing or mismatched.
        if auto_backfill and (missing or mismatch):
            keys_to_refetch = missing | mismatch
            # Re-insert only the records whose keys need fixing.
            records_to_upsert = [
                r for r in api_records
                if tuple(r.get(col) for col in pk_cols) in keys_to_refetch
            ]
            if records_to_upsert:
                backfilled = self._batch_upsert(table, records_to_upsert)
                result["backfilled_count"] = backfilled
        return result
    def _table_has_column(self, full_table: str, column: str) -> bool:
        """Check (with caching) whether *full_table* has *column*.

        NOTE(review): the rollback below runs even after a successful probe,
        ending any open transaction on this connection — presumably to avoid
        idle-in-transaction; confirm callers never probe with uncommitted
        writes pending.
        """
        cache_key = (full_table, column)
        if cache_key in self._table_column_cache:
            return self._table_column_cache[cache_key]
        # Split an optional schema qualifier; default to public.
        schema = "public"
        table = full_table
        if "." in full_table:
            schema, table = full_table.split(".", 1)
        sql = """
            SELECT 1
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s AND column_name = %s
            LIMIT 1
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (schema, table, column))
                exists = cur.fetchone() is not None
        except Exception:
            # Probe failure is treated (and cached) as "column absent".
            exists = False
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        self._table_column_cache[cache_key] = exists
        return exists
    def _get_db_primary_keys(self, full_table: str) -> List[str]:
        """Read primary key columns from database metadata (ordered).

        Results are cached per table; [] is cached on failure.

        NOTE(review): like _table_has_column, the rollback runs even on
        success, ending any open transaction on this connection.
        """
        if full_table in self._table_pk_cache:
            return self._table_pk_cache[full_table]
        schema = "public"
        table = full_table
        if "." in full_table:
            schema, table = full_table.split(".", 1)
        sql = """
            SELECT kcu.column_name
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
            WHERE tc.table_schema = %s
            AND tc.table_name = %s
            AND tc.constraint_type = 'PRIMARY KEY'
            ORDER BY kcu.ordinal_position
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (schema, table))
                rows = cur.fetchall()
                # Support both tuple and dict row factories.
                cols = [r[0] if not isinstance(r, dict) else r.get("column_name") for r in rows]
                result = [c for c in cols if c]
        except Exception:
            result = []
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        self._table_pk_cache[full_table] = result
        return result
def _compute_compare_hash_from_payload(self, payload: Any) -> Optional[str]:
"""使用 ODS 任务的算法计算对比哈希"""
try:
from tasks.ods.ods_tasks import BaseOdsTask
return BaseOdsTask._compute_compare_hash_from_payload(payload)
except Exception:
return None
def _compute_hash(self, record: dict) -> str:
"""计算记录的对比哈希(与 ODS 入库一致,不包含 fetched_at"""
compare_hash = self._compute_compare_hash_from_payload(record)
return compare_hash or ""
def _batch_upsert(self, table: str, records: List[dict]) -> int:
"""Batch backfill in snapshot-safe mode (insert-only on PK conflict)."""
if not records:
return 0
full_table = self._get_full_table_name(table)
db_pk_cols = self._get_db_primary_keys(full_table)
if not db_pk_cols:
self.logger.warning("%s 未找到主键,跳过回填", full_table)
return 0
has_content_hash_col = self._table_has_column(full_table, "content_hash")
# 获取所有列(从第一条记录),并在存在 content_hash 列时补齐该列。
all_cols = list(records[0].keys())
if has_content_hash_col and "content_hash" not in all_cols:
all_cols.append("content_hash")
# Snapshot-safe strategy: never update historical rows; only insert new snapshots.
col_list = ", ".join(all_cols)
placeholders = ", ".join(["%s"] * len(all_cols))
pk_list = ", ".join(db_pk_cols)
sql = f"""
INSERT INTO {full_table} ({col_list})
VALUES ({placeholders})
ON CONFLICT ({pk_list}) DO NOTHING
"""
count = 0
with self.db.conn.cursor() as cur:
for record in records:
row = dict(record)
if has_content_hash_col:
row["content_hash"] = self._compute_hash(record)
values = [row.get(col) for col in all_cols]
try:
cur.execute(sql, values)
affected = int(cur.rowcount or 0)
if affected > 0:
count += affected
except Exception as e:
self.logger.warning("UPSERT 失败: %s, error=%s", record, e)
self.db.commit()
return count