392 lines
13 KiB
Python
392 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""批量校验基类"""
|
||
|
||
import logging
|
||
import time
|
||
from abc import ABC, abstractmethod
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||
|
||
from .models import (
|
||
VerificationResult,
|
||
VerificationSummary,
|
||
VerificationStatus,
|
||
WindowSegment,
|
||
build_window_segments,
|
||
)
|
||
|
||
|
||
class VerificationFetchError(RuntimeError):
    """Raised when fetching verification data fails (used to explicitly mark ERROR)."""
class BaseVerifier(ABC):
    """Base class for batch verification.

    Provides a unified verification pipeline:
    1. Split the time range into windows
    2. Bulk-read source data
    3. Bulk-read target data
    4. Compare in memory
    5. Bulk backfill differences
    """

    def __init__(
        self,
        db_connection: Any,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the verifier.

        Args:
            db_connection: Database connection wrapper. NOTE(review): the rest of
                this class expects it to expose ``conn`` (a DB-API connection),
                ``commit()`` and optionally ``ensure_open()`` — confirm against
                the concrete connection type.
            logger: Logger to use; defaults to a logger named after the subclass.
        """
        self.db = db_connection
        self.logger = logger or logging.getLogger(self.__class__.__name__)
    @property
    @abstractmethod
    def layer_name(self) -> str:
        """Name of the data layer this verifier covers (used in logs/results)."""
        pass

    @abstractmethod
    def get_tables(self) -> List[str]:
        """Return the list of tables that need verification."""
        pass

    @abstractmethod
    def get_primary_keys(self, table: str) -> List[str]:
        """Return the primary-key columns of *table*."""
        pass

    @abstractmethod
    def get_time_column(self, table: str) -> Optional[str]:
        """Return the time column of *table* used for window filtering, if any."""
        pass

    @abstractmethod
    def fetch_source_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Bulk-fetch the set of source-side primary-key tuples in the window."""
        pass

    @abstractmethod
    def fetch_target_keys(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Set[Tuple]:
        """Bulk-fetch the set of target-side primary-key tuples in the window."""
        pass

    @abstractmethod
    def fetch_source_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Bulk-fetch the source-side primary-key -> content-hash mapping."""
        pass

    @abstractmethod
    def fetch_target_hashes(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> Dict[Tuple, str]:
        """Bulk-fetch the target-side primary-key -> content-hash mapping."""
        pass

    @abstractmethod
    def backfill_missing(
        self,
        table: str,
        missing_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Bulk-backfill missing rows; return the number of records written."""
        pass

    @abstractmethod
    def backfill_mismatch(
        self,
        table: str,
        mismatch_keys: Set[Tuple],
        window_start: datetime,
        window_end: datetime,
    ) -> int:
        """Bulk-update mismatched rows; return the number of records updated."""
        pass
    def verify_table(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
        auto_backfill: bool = False,
        compare_content: bool = True,
    ) -> VerificationResult:
        """
        Verify a single table over one time window.

        Args:
            table: Table name.
            window_start: Window start time.
            window_end: Window end time.
            auto_backfill: Whether to automatically backfill differences.
            compare_content: True = compare content hashes;
                False = compare primary keys only.

        Returns:
            The verification result for this table/window. On failure the
            result carries ERROR status; fatal (non-data-quality) failures
            also set ``result.details["fatal"] = True`` so callers can abort.
        """
        start_time = time.time()
        result = VerificationResult(
            layer=self.layer_name,
            table=table,
            window_start=window_start,
            window_end=window_end,
        )

        try:
            # Make sure the connection is usable first, so a
            # "connection already closed" error cannot be misread as OK.
            self._ensure_connection()
            self.logger.info(
                "%s 校验开始: %s [%s ~ %s]",
                self.layer_name, table,
                window_start.strftime("%Y-%m-%d %H:%M"),
                window_end.strftime("%Y-%m-%d %H:%M")
            )

            if compare_content:
                # Compare content hashes on both sides.
                source_hashes = self.fetch_source_hashes(table, window_start, window_end)
                target_hashes = self.fetch_target_hashes(table, window_start, window_end)

                result.source_count = len(source_hashes)
                result.target_count = len(target_hashes)

                source_keys = set(source_hashes.keys())
                target_keys = set(target_hashes.keys())

                # Missing: present in source but absent from target.
                missing_keys = source_keys - target_keys
                result.missing_count = len(missing_keys)

                # Mismatched: present on both sides but hashes differ.
                common_keys = source_keys & target_keys
                mismatch_keys = {
                    k for k in common_keys
                    if source_hashes[k] != target_hashes[k]
                }
                result.mismatch_count = len(mismatch_keys)
            else:
                # Primary-key-only comparison (no content check).
                source_keys = self.fetch_source_keys(table, window_start, window_end)
                target_keys = self.fetch_target_keys(table, window_start, window_end)

                result.source_count = len(source_keys)
                result.target_count = len(target_keys)

                missing_keys = source_keys - target_keys
                result.missing_count = len(missing_keys)
                mismatch_keys = set()

            # Derive status; MISSING takes precedence over MISMATCH.
            if result.missing_count > 0:
                result.status = VerificationStatus.MISSING
            elif result.mismatch_count > 0:
                result.status = VerificationStatus.MISMATCH
            else:
                result.status = VerificationStatus.OK

            # Optional automatic backfill of the detected differences.
            if auto_backfill and (missing_keys or mismatch_keys):
                backfill_missing_count = 0
                backfill_mismatch_count = 0

                if missing_keys:
                    self.logger.info(
                        "%s 补齐缺失: %s, 数量=%d",
                        self.layer_name, table, len(missing_keys)
                    )
                    backfill_missing_count += self.backfill_missing(
                        table, missing_keys, window_start, window_end
                    )

                if mismatch_keys:
                    self.logger.info(
                        "%s 更新不一致: %s, 数量=%d",
                        self.layer_name, table, len(mismatch_keys)
                    )
                    backfill_mismatch_count += self.backfill_mismatch(
                        table, mismatch_keys, window_start, window_end
                    )

                result.backfilled_missing_count = backfill_missing_count
                result.backfilled_mismatch_count = backfill_mismatch_count
                result.backfilled_count = backfill_missing_count + backfill_mismatch_count
                if result.backfilled_count > 0:
                    result.status = VerificationStatus.BACKFILLED

            self.logger.info(
                "%s 校验完成: %s, 源=%d, 目标=%d, 缺失=%d, 不一致=%d, 补齐=%d(缺失=%d, 不一致=%d)",
                self.layer_name, table,
                result.source_count, result.target_count,
                result.missing_count, result.mismatch_count, result.backfilled_count,
                result.backfilled_missing_count, result.backfilled_mismatch_count
            )

        except Exception as e:
            result.status = VerificationStatus.ERROR
            result.error_message = str(e)
            if isinstance(e, VerificationFetchError):
                # CHANGE [2026-02-19] intent: distinguish fatal connection-class
                # errors from data-quality errors; the latter must not abort the
                # whole verification run.
                # assumptions: ValueError/TypeError usually indicates dirty data
                # (e.g. an illegal date such as year=-1) and should not stop the
                # remaining tables.
                root_cause = e.__cause__
                is_data_quality_error = isinstance(root_cause, (ValueError, TypeError, OverflowError))
                if is_data_quality_error:
                    self.logger.warning(
                        "%s 数据质量问题(非致命): %s, error=%s", self.layer_name, table, root_cause
                    )
                else:
                    # Connection unusable or similar fatal error: signal the
                    # caller that subsequent work should be aborted.
                    result.details["fatal"] = True
            self.logger.exception("%s 校验失败: %s, error=%s", self.layer_name, table, e)
            # Roll back so a PostgreSQL "current transaction is aborted" state
            # does not poison subsequent queries on the same connection.
            try:
                self.db.conn.rollback()
            except Exception:
                pass  # best-effort rollback; ignore failures

        result.elapsed_seconds = time.time() - start_time
        return result
    def verify_and_backfill(
        self,
        window_start: datetime,
        window_end: datetime,
        split_unit: str = "month",
        tables: Optional[List[str]] = None,
        auto_backfill: bool = True,
        compare_content: bool = True,
    ) -> VerificationSummary:
        """
        Run batch verification, splitting the time range into windows.

        Args:
            window_start: Start time.
            window_end: End time.
            split_unit: Window split unit ("none", "day", "week", "month").
            tables: Tables to verify; ``None`` means all tables.
            auto_backfill: Whether to backfill differences automatically.
            compare_content: Whether to compare content hashes.

        Returns:
            Aggregated verification summary across all windows and tables.
        """
        summary = VerificationSummary(
            layer=self.layer_name,
            window_start=window_start,
            window_end=window_end,
        )

        # Resolve the set of tables to verify.
        all_tables = tables or self.get_tables()

        # Split the overall range into window segments.
        segments = build_window_segments(window_start, window_end, split_unit)

        self.logger.info(
            "%s 批量校验开始: 表数=%d, 窗口切分=%d段",
            self.layer_name, len(all_tables), len(segments)
        )

        fatal_error = False
        for segment in segments:
            # Check the connection before each segment and abort immediately if
            # it is unusable, instead of spinning through every table for nothing.
            self._ensure_connection()
            self.logger.info(
                "%s 处理窗口 [%d/%d]: %s",
                self.layer_name, segment.index + 1, segment.total, segment.label
            )

            for table in all_tables:
                result = self.verify_table(
                    table=table,
                    window_start=segment.start,
                    window_end=segment.end,
                    auto_backfill=auto_backfill,
                    compare_content=compare_content,
                )
                summary.add_result(result)
                # verify_table marks fatal (connection-class) failures; stop
                # the table loop so we can commit and then abort the run.
                if result.details.get("fatal"):
                    fatal_error = True
                    break

            # Commit after each completed segment.
            try:
                self.db.commit()
            except Exception as e:
                self.logger.warning("提交失败: %s", e)
            if fatal_error:
                self.logger.warning("%s 校验中止:连接不可用或发生致命错误", self.layer_name)
                break

        self.logger.info(summary.format_summary())
        return summary
def _ensure_connection(self):
|
||
"""确保数据库连接可用,必要时尝试重连。"""
|
||
if not hasattr(self.db, "conn"):
|
||
raise VerificationFetchError("校验器未绑定有效数据库连接")
|
||
if getattr(self.db.conn, "closed", 0):
|
||
# 优先使用连接对象的重连能力
|
||
if hasattr(self.db, "ensure_open"):
|
||
if not self.db.ensure_open():
|
||
raise VerificationFetchError("数据库连接已关闭,无法继续校验")
|
||
else:
|
||
raise VerificationFetchError("数据库连接已关闭,无法继续校验")
|
||
|
||
def quick_check(
|
||
self,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
tables: Optional[List[str]] = None,
|
||
) -> Dict[str, dict]:
|
||
"""
|
||
快速检查(仅对比数量,不对比内容)
|
||
|
||
Args:
|
||
window_start: 开始时间
|
||
window_end: 结束时间
|
||
tables: 指定表,None 表示全部
|
||
|
||
Returns:
|
||
{表名: {source_count, target_count, diff}}
|
||
"""
|
||
all_tables = tables or self.get_tables()
|
||
results = {}
|
||
|
||
for table in all_tables:
|
||
source_keys = self.fetch_source_keys(table, window_start, window_end)
|
||
target_keys = self.fetch_target_keys(table, window_start, window_end)
|
||
|
||
results[table] = {
|
||
"source_count": len(source_keys),
|
||
"target_count": len(target_keys),
|
||
"diff": len(source_keys) - len(target_keys),
|
||
}
|
||
|
||
return results
|