872 lines
33 KiB
Python
872 lines
33 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""ODS 层批量校验器
|
||
|
||
校验逻辑:对比 API 源数据与 ODS 表数据
|
||
- 主键 + content_hash 对比
|
||
- 批量 UPSERT 补齐缺失/不一致数据
|
||
"""
|
||
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||
|
||
from psycopg2.extras import execute_values
|
||
|
||
from api.local_json_client import LocalJsonClient
|
||
|
||
from .base_verifier import BaseVerifier, VerificationFetchError
|
||
|
||
|
||
class OdsVerifier(BaseVerifier):
|
||
"""ODS 层校验器"""
|
||
|
||
    def __init__(
        self,
        db_connection: Any,
        api_client: Any = None,
        logger: Optional[logging.Logger] = None,
        fetch_from_api: bool = False,
        local_dump_dirs: Optional[Dict[str, str]] = None,
        use_local_json: bool = False,
    ):
        """
        Initialize the ODS verifier.

        Args:
            db_connection: database connection wrapper (exposes .conn / .commit).
            api_client: API client used to (re-)fetch source data.
            logger: optional logger.
            fetch_from_api: when True, verify against the external source;
                when False, only check ODS internal consistency.
            local_dump_dirs: mapping of task_code -> local JSON dump directory.
            use_local_json: prefer local JSON dumps as the source of truth.
        """
        super().__init__(db_connection, logger)
        self.api_client = api_client
        self.fetch_from_api = fetch_from_api
        self.local_dump_dirs = local_dump_dirs or {}
        # Any configured dump dir implicitly enables local-JSON mode.
        self.use_local_json = bool(use_local_json or self.local_dump_dirs)

        # Caches for data fetched from the API, keyed by "table_start_end",
        # to avoid repeated remote calls within one verification run.
        self._api_data_cache: Dict[str, List[dict]] = {}
        self._api_key_cache: Dict[str, Set[Tuple]] = {}
        self._api_hash_cache: Dict[str, Dict[Tuple, str]] = {}
        self._table_column_cache: Dict[Tuple[str, str], bool] = {}
        self._table_pk_cache: Dict[str, List[str]] = {}
        self._local_json_clients: Dict[str, LocalJsonClient] = {}

        # ODS table config: {table: {pk_columns, time_column, api_endpoint, ...}}
        self._table_config = self._load_table_config()
|
||
|
||
    @property
    def layer_name(self) -> str:
        """Name of the warehouse layer this verifier covers."""
        return "ODS"
|
||
|
||
def _load_table_config(self) -> Dict[str, dict]:
|
||
"""加载 ODS 表配置"""
|
||
# 从任务定义中动态获取配置
|
||
try:
|
||
from tasks.ods.ods_tasks import ODS_TASK_SPECS
|
||
config = {}
|
||
for spec in ODS_TASK_SPECS:
|
||
# time_fields 是一个元组 (start_field, end_field),取第一个作为时间列
|
||
# 或者使用 fetched_at 作为默认
|
||
time_column = "fetched_at"
|
||
|
||
# 使用 table_name 属性(不是 table)
|
||
table_name = spec.table_name
|
||
# 提取不带 schema 前缀的表名作为 key
|
||
if "." in table_name:
|
||
table_key = table_name.split(".")[-1]
|
||
else:
|
||
table_key = table_name
|
||
|
||
# 从 sources 中提取 ODS 表的实际主键列名
|
||
# sources 格式如 ("settleList.id", "id"),最后一个简单名称是 ODS 列名
|
||
pk_columns = []
|
||
for col in spec.pk_columns:
|
||
ods_col_name = self._extract_ods_column_name(col)
|
||
pk_columns.append(ods_col_name)
|
||
|
||
# 如果 pk_columns 为空,尝试使用 conflict_columns_override 或跳过校验
|
||
# 一些特殊表(如 goods_stock_summary, settlement_ticket_details)没有标准主键
|
||
if not pk_columns:
|
||
# 跳过没有明确主键定义的表
|
||
self.logger.debug("表 %s 没有定义主键列,跳过校验配置", table_key)
|
||
continue
|
||
|
||
config[table_key] = {
|
||
"full_table_name": table_name,
|
||
"pk_columns": pk_columns,
|
||
"time_column": time_column,
|
||
"api_endpoint": spec.endpoint,
|
||
"task_code": spec.code,
|
||
}
|
||
return config
|
||
except ImportError:
|
||
self.logger.warning("无法加载 ODS 任务定义,使用默认配置")
|
||
return {}
|
||
|
||
def _extract_ods_column_name(self, col) -> str:
|
||
"""
|
||
从 ColumnSpec 中提取 ODS 表的实际列名
|
||
|
||
ODS 表使用原始 JSON 字段名(小写),而 col.column 是 DWD 层的命名。
|
||
sources 中的最后一个简单字段名通常就是 ODS 表的列名。
|
||
"""
|
||
# 如果 sources 为空,使用 column(假设 column 就是 ODS 列名)
|
||
if not col.sources:
|
||
return col.column
|
||
|
||
# 遍历 sources,找到最简单的字段名(不含点号的)
|
||
for source in reversed(col.sources):
|
||
if "." not in source:
|
||
return source.lower() # ODS 列名通常是小写
|
||
|
||
# 如果都是复杂路径,取最后一个路径的最后一部分
|
||
last_source = col.sources[-1]
|
||
if "." in last_source:
|
||
return last_source.split(".")[-1].lower()
|
||
return last_source.lower()
|
||
|
||
def get_tables(self) -> List[str]:
|
||
"""获取需要校验的 ODS 表列表"""
|
||
if self._table_config:
|
||
return list(self._table_config.keys())
|
||
|
||
# 从数据库查询 ODS schema 中的表
|
||
sql = """
|
||
SELECT table_name
|
||
FROM information_schema.tables
|
||
WHERE table_schema = 'billiards_ods'
|
||
AND table_type = 'BASE TABLE'
|
||
ORDER BY table_name
|
||
"""
|
||
try:
|
||
with self.db.conn.cursor() as cur:
|
||
cur.execute(sql)
|
||
return [row[0] for row in cur.fetchall()]
|
||
except Exception as e:
|
||
self.logger.warning("获取 ODS 表列表失败: %s", e)
|
||
try:
|
||
self.db.conn.rollback()
|
||
except Exception:
|
||
pass
|
||
return []
|
||
|
||
def get_primary_keys(self, table: str) -> List[str]:
|
||
"""获取表的主键列"""
|
||
if table in self._table_config:
|
||
return self._table_config[table].get("pk_columns", [])
|
||
# 表不在配置中,返回空列表表示无法校验
|
||
return []
|
||
|
||
def get_time_column(self, table: str) -> Optional[str]:
|
||
"""获取表的时间列"""
|
||
if table in self._table_config:
|
||
return self._table_config[table].get("time_column", "fetched_at")
|
||
return "fetched_at"
|
||
|
||
def _get_full_table_name(self, table: str) -> str:
|
||
"""获取完整的表名(包含 schema)"""
|
||
if table in self._table_config:
|
||
return self._table_config[table].get("full_table_name", f"billiards_ods.{table}")
|
||
# 如果表名已经包含 schema,直接返回
|
||
if "." in table:
|
||
return table
|
||
return f"billiards_ods.{table}"
|
||
|
||
def fetch_source_keys(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Set[Tuple]:
|
||
"""
|
||
从源获取主键集合
|
||
|
||
根据 fetch_from_api 参数决定数据来源:
|
||
- fetch_from_api=True: 从 API 获取数据(真正的源到目标校验)
|
||
- fetch_from_api=False: 从 ODS 表获取(ODS 内部一致性校验)
|
||
"""
|
||
if self._has_external_source():
|
||
return self._fetch_keys_from_api(table, window_start, window_end)
|
||
else:
|
||
# ODS 内部校验:直接从 ODS 表获取
|
||
return self._fetch_keys_from_db(table, window_start, window_end)
|
||
|
||
def _fetch_keys_from_api(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Set[Tuple]:
|
||
"""从 API 获取源数据主键集合"""
|
||
# 尝试获取缓存的 API 数据
|
||
cache_key = f"{table}_{window_start}_{window_end}"
|
||
if cache_key in self._api_key_cache:
|
||
return self._api_key_cache[cache_key]
|
||
if cache_key not in self._api_data_cache:
|
||
# 调用 API 获取数据
|
||
api_records = self._call_api_for_table(table, window_start, window_end)
|
||
self._api_data_cache[cache_key] = api_records
|
||
|
||
api_records = self._api_data_cache.get(cache_key, [])
|
||
|
||
if not api_records:
|
||
self.logger.debug("表 %s 从 API 未获取到数据", table)
|
||
return set()
|
||
|
||
# 获取主键列
|
||
pk_cols = self.get_primary_keys(table)
|
||
if not pk_cols:
|
||
self.logger.debug("表 %s 没有主键配置,跳过 API 校验", table)
|
||
return set()
|
||
|
||
# 提取主键
|
||
keys = set()
|
||
for record in api_records:
|
||
pk_values = []
|
||
for col in pk_cols:
|
||
# API 返回的字段名可能是原始格式(如 id, Id, ID)
|
||
# 尝试多种格式
|
||
value = record.get(col)
|
||
if value is None:
|
||
value = record.get(col.lower())
|
||
if value is None:
|
||
value = record.get(col.upper())
|
||
pk_values.append(value)
|
||
if all(v is not None for v in pk_values):
|
||
keys.add(tuple(pk_values))
|
||
|
||
self.logger.info("表 %s 从源数据获取 %d 条记录,%d 个唯一主键", table, len(api_records), len(keys))
|
||
self._api_key_cache[cache_key] = keys
|
||
return keys
|
||
|
||
    def _call_api_for_table(
        self,
        table: str,
        window_start: datetime,
        window_end: datetime,
    ) -> List[dict]:
        """
        Fetch raw source records for *table* from the configured source.

        Resolves the ODS task spec for the table, builds the request
        parameters (site id, time window, extra params) and pages through
        the endpoint via the API client or local-JSON client.

        Raises:
            VerificationFetchError: when the source call fails.
        """
        config = self._table_config.get(table, {})
        task_code = config.get("task_code")
        endpoint = config.get("api_endpoint")

        if not task_code or not endpoint:
            self.logger.warning(
                "表 %s 没有完整的任务配置(task_code=%s, endpoint=%s),无法获取源数据",
                table, task_code, endpoint
            )
            return []

        source_client = self._get_source_client(task_code)
        if not source_client:
            self.logger.warning("表 %s 未找到可用源(API/本地JSON),跳过获取源数据", table)
            return []

        source_label = "本地 JSON" if self._is_using_local_json(task_code) else "API"
        self.logger.info(
            "从 %s 获取数据: 表=%s, 端点=%s, 时间窗口=%s ~ %s",
            source_label, table, endpoint, window_start, window_end
        )

        try:
            # The ODS task spec carries the endpoint's parameter conventions
            # (pagination keys, time fields, extra params).
            from tasks.ods.ods_tasks import ODS_TASK_SPECS

            # Find the spec matching this table's task code.
            spec = None
            for s in ODS_TASK_SPECS:
                if s.code == task_code:
                    spec = s
                    break

            if not spec:
                self.logger.warning("未找到任务规格: %s", task_code)
                return []

            # Build the request parameters.
            params = {}
            if spec.include_site_id:
                # Site/store scope comes from the API client when available.
                store_id = getattr(self.api_client, 'store_id', None)
                if store_id:
                    params["siteId"] = store_id

            if spec.requires_window and spec.time_fields:
                start_key, end_key = spec.time_fields
                # The upstream expects "YYYY-MM-DD HH:MM:SS" timestamps.
                params[start_key] = window_start.strftime("%Y-%m-%d %H:%M:%S")
                params[end_key] = window_end.strftime("%Y-%m-%d %H:%M:%S")

            # Endpoint-specific fixed parameters.
            params.update(spec.extra_params)

            # Page through the source and accumulate every record.
            all_records = []
            for _, page_records, _, _ in source_client.iter_paginated(
                endpoint=spec.endpoint,
                params=params,
                page_size=200,
                data_path=spec.data_path,
                list_key=spec.list_key,
            ):
                all_records.extend(page_records)

            self.logger.info("源数据返回 %d 条原始记录", len(all_records))
            return all_records

        except Exception as e:
            self.logger.warning("获取源数据失败: 表=%s, error=%s", table, e)
            import traceback
            self.logger.debug("调用栈: %s", traceback.format_exc())
            raise VerificationFetchError(f"获取源数据失败: {table}") from e
|
||
|
||
def fetch_target_keys(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Set[Tuple]:
|
||
"""从 ODS 表获取目标数据主键集合"""
|
||
if self._has_external_source():
|
||
cache_key = f"{table}_{window_start}_{window_end}"
|
||
api_keys = self._api_key_cache.get(cache_key)
|
||
if api_keys is None:
|
||
api_keys = self._fetch_keys_from_api(table, window_start, window_end)
|
||
return self._fetch_keys_from_db_by_keys(table, api_keys)
|
||
return self._fetch_keys_from_db(table, window_start, window_end)
|
||
|
||
def _fetch_keys_from_db(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Set[Tuple]:
|
||
"""从数据库获取主键集合"""
|
||
pk_cols = self.get_primary_keys(table)
|
||
|
||
# 如果没有主键列配置,跳过校验
|
||
if not pk_cols:
|
||
self.logger.debug("表 %s 没有主键配置,跳过获取主键", table)
|
||
return set()
|
||
|
||
time_col = self.get_time_column(table)
|
||
full_table = self._get_full_table_name(table)
|
||
|
||
pk_select = ", ".join(pk_cols)
|
||
sql = f"""
|
||
SELECT {pk_select}
|
||
FROM {full_table}
|
||
WHERE {time_col} >= %s AND {time_col} < %s
|
||
"""
|
||
|
||
try:
|
||
with self.db.conn.cursor() as cur:
|
||
cur.execute(sql, (window_start, window_end))
|
||
return {tuple(row) for row in cur.fetchall()}
|
||
except Exception as e:
|
||
self.logger.warning("获取 ODS 主键失败: %s, error=%s", table, e)
|
||
try:
|
||
self.db.conn.rollback()
|
||
except Exception:
|
||
pass
|
||
raise VerificationFetchError(f"获取 ODS 主键失败: {table}") from e
|
||
|
||
    def _fetch_keys_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Set[Tuple]:
        """
        Probe ODS for the given primary keys (independent of time window).

        Joins the table against a VALUES list of the requested keys and
        returns the subset that actually exists.

        Raises:
            VerificationFetchError: when the lookup query fails.
        """
        if not keys:
            return set()
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            self.logger.debug("表 %s 没有主键配置,跳过按主键反查", table)
            return set()
        full_table = self._get_full_table_name(table)
        # Quote identifiers: raw JSON field names may be mixed-case.
        select_cols = ", ".join(f't."{c}"' for c in pk_cols)
        value_cols = ", ".join(f'"{c}"' for c in pk_cols)
        join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
        sql = (
            f"SELECT {select_cols} FROM {full_table} t "
            f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
        )
        existing: Set[Tuple] = set()
        try:
            with self.db.conn.cursor() as cur:
                # Probe in chunks so the VALUES list stays bounded.
                for chunk in self._chunked(list(keys), 500):
                    execute_values(cur, sql, chunk, page_size=len(chunk))
                    for row in cur.fetchall():
                        existing.add(tuple(row))
        except Exception as e:
            self.logger.warning("按主键反查 ODS 失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"按主键反查 ODS 失败: {table}") from e
        return existing
|
||
|
||
def fetch_source_hashes(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Dict[Tuple, str]:
|
||
"""获取源数据的主键->content_hash 映射"""
|
||
if self._has_external_source():
|
||
return self._fetch_hashes_from_api(table, window_start, window_end)
|
||
else:
|
||
# ODS 表自带 content_hash 列
|
||
return self._fetch_hashes_from_db(table, window_start, window_end)
|
||
|
||
def _fetch_hashes_from_api(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Dict[Tuple, str]:
|
||
"""从 API 数据计算哈希"""
|
||
cache_key = f"{table}_{window_start}_{window_end}"
|
||
if cache_key in self._api_hash_cache:
|
||
return self._api_hash_cache[cache_key]
|
||
api_records = self._api_data_cache.get(cache_key, [])
|
||
|
||
if not api_records:
|
||
# 尝试从 API 获取
|
||
api_records = self._call_api_for_table(table, window_start, window_end)
|
||
self._api_data_cache[cache_key] = api_records
|
||
|
||
if not api_records:
|
||
return {}
|
||
|
||
pk_cols = self.get_primary_keys(table)
|
||
if not pk_cols:
|
||
return {}
|
||
|
||
result = {}
|
||
for record in api_records:
|
||
# 提取主键
|
||
pk_values = []
|
||
for col in pk_cols:
|
||
value = record.get(col)
|
||
if value is None:
|
||
value = record.get(col.lower())
|
||
if value is None:
|
||
value = record.get(col.upper())
|
||
pk_values.append(value)
|
||
|
||
if all(v is not None for v in pk_values):
|
||
pk = tuple(pk_values)
|
||
# 计算内容哈希
|
||
content_hash = self._compute_hash(record)
|
||
result[pk] = content_hash
|
||
|
||
self._api_hash_cache[cache_key] = result
|
||
if cache_key not in self._api_key_cache:
|
||
self._api_key_cache[cache_key] = set(result.keys())
|
||
return result
|
||
|
||
def fetch_target_hashes(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Dict[Tuple, str]:
|
||
"""获取目标数据的主键->content_hash 映射"""
|
||
if self.fetch_from_api and self.api_client:
|
||
cache_key = f"{table}_{window_start}_{window_end}"
|
||
api_hashes = self._api_hash_cache.get(cache_key)
|
||
if api_hashes is None:
|
||
api_hashes = self._fetch_hashes_from_api(table, window_start, window_end)
|
||
api_keys = set(api_hashes.keys())
|
||
return self._fetch_hashes_from_db_by_keys(table, api_keys)
|
||
return self._fetch_hashes_from_db(table, window_start, window_end)
|
||
|
||
def _fetch_hashes_from_db(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> Dict[Tuple, str]:
|
||
"""从数据库获取主键->hash 映射"""
|
||
pk_cols = self.get_primary_keys(table)
|
||
|
||
# 如果没有主键列配置,跳过校验
|
||
if not pk_cols:
|
||
self.logger.debug("表 %s 没有主键配置,跳过获取哈希", table)
|
||
return {}
|
||
|
||
time_col = self.get_time_column(table)
|
||
full_table = self._get_full_table_name(table)
|
||
|
||
pk_select = ", ".join(pk_cols)
|
||
sql = f"""
|
||
SELECT {pk_select}, content_hash
|
||
FROM {full_table}
|
||
WHERE {time_col} >= %s AND {time_col} < %s
|
||
"""
|
||
|
||
result = {}
|
||
try:
|
||
with self.db.conn.cursor() as cur:
|
||
cur.execute(sql, (window_start, window_end))
|
||
for row in cur.fetchall():
|
||
pk = tuple(row[:-1])
|
||
content_hash = row[-1]
|
||
result[pk] = content_hash or ""
|
||
except Exception as e:
|
||
# 查询失败时回滚事务,避免影响后续查询
|
||
self.logger.warning("获取 ODS hash 失败: %s, error=%s", table, e)
|
||
try:
|
||
self.db.conn.rollback()
|
||
except Exception:
|
||
pass
|
||
raise VerificationFetchError(f"获取 ODS hash 失败: {table}") from e
|
||
|
||
return result
|
||
|
||
    def _fetch_hashes_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Dict[Tuple, str]:
        """
        Look up comparison hashes in ODS for the given primary keys
        (independent of the time window).

        When the table stores the raw payload, the comparison hash is
        recomputed from it with the same algorithm the ODS ingest uses;
        otherwise the stored content_hash column value is returned.

        Raises:
            VerificationFetchError: when the lookup query fails.
        """
        if not keys:
            return {}
        pk_cols = self.get_primary_keys(table)
        if not pk_cols:
            self.logger.debug("表 %s 没有主键配置,跳过按主键反查 hash", table)
            return {}
        full_table = self._get_full_table_name(table)
        has_payload = self._table_has_column(full_table, "payload")
        # Prefer recomputing from payload so the hash matches the
        # source-side algorithm; fall back to the stored content_hash.
        select_tail = 't."payload"' if has_payload else 't."content_hash"'
        select_cols = ", ".join([*(f't."{c}"' for c in pk_cols), select_tail])
        value_cols = ", ".join(f'"{c}"' for c in pk_cols)
        join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
        sql = (
            f"SELECT {select_cols} FROM {full_table} t "
            f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
        )
        result: Dict[Tuple, str] = {}
        try:
            with self.db.conn.cursor() as cur:
                # Probe in chunks so the VALUES list stays bounded.
                for chunk in self._chunked(list(keys), 500):
                    execute_values(cur, sql, chunk, page_size=len(chunk))
                    for row in cur.fetchall():
                        pk = tuple(row[:-1])
                        tail_value = row[-1]
                        if has_payload:
                            compare_hash = self._compute_compare_hash_from_payload(tail_value)
                            result[pk] = compare_hash or ""
                        else:
                            result[pk] = tail_value or ""
        except Exception as e:
            self.logger.warning("按主键反查 ODS hash 失败: %s, error=%s", table, e)
            try:
                self.db.conn.rollback()
            except Exception:
                pass
            raise VerificationFetchError(f"按主键反查 ODS hash 失败: {table}") from e
        return result
|
||
|
||
@staticmethod
|
||
def _chunked(items: List[Tuple], chunk_size: int) -> List[List[Tuple]]:
|
||
"""将列表按固定大小分块"""
|
||
if chunk_size <= 0:
|
||
return [items]
|
||
return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
|
||
|
||
def backfill_missing(
|
||
self,
|
||
table: str,
|
||
missing_keys: Set[Tuple],
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> int:
|
||
"""
|
||
批量补齐缺失数据
|
||
|
||
ODS 层补齐需要重新从 API 获取数据
|
||
"""
|
||
if not self._has_external_source():
|
||
self.logger.warning("未配置 API/本地JSON 源,无法补齐 ODS 缺失数据")
|
||
return 0
|
||
|
||
if not missing_keys:
|
||
return 0
|
||
|
||
# 获取表配置
|
||
config = self._table_config.get(table, {})
|
||
task_code = config.get("task_code")
|
||
|
||
if not task_code:
|
||
self.logger.warning("未找到表 %s 的任务配置,跳过补齐", table)
|
||
return 0
|
||
|
||
self.logger.info(
|
||
"ODS 补齐缺失: 表=%s, 数量=%d, 任务=%s",
|
||
table, len(missing_keys), task_code
|
||
)
|
||
|
||
# ODS 层的补齐实际上是重新执行 ODS 任务从 API 获取数据
|
||
# 但由于 ODS 任务已经在 "校验前先从 API 获取数据" 步骤执行过了,
|
||
# 这里补齐失败是预期的(数据已经在 ODS 表中,只是校验窗口可能不一致)
|
||
#
|
||
# 实际的 ODS 补齐应该在 verify_only 模式下启用 fetch_before_verify 选项,
|
||
# 这会先执行 ODS 任务获取 API 数据,然后再校验。
|
||
#
|
||
# 如果仍然有缺失,说明:
|
||
# 1. API 返回的数据时间窗口与校验窗口不完全匹配
|
||
# 2. 或者 ODS 任务的时间参数配置问题
|
||
self.logger.info(
|
||
"ODS 补齐提示: 表=%s 有 %d 条缺失记录,建议使用 '校验前先从 API 获取数据' 选项获取完整数据",
|
||
table, len(missing_keys)
|
||
)
|
||
return 0
|
||
|
||
def backfill_mismatch(
|
||
self,
|
||
table: str,
|
||
mismatch_keys: Set[Tuple],
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
) -> int:
|
||
"""
|
||
批量更新不一致数据
|
||
|
||
ODS 层更新也需要重新从 API 获取
|
||
"""
|
||
# 与 backfill_missing 类似,重新获取数据会自动 UPSERT
|
||
return self.backfill_missing(table, mismatch_keys, window_start, window_end)
|
||
|
||
def _has_external_source(self) -> bool:
|
||
return bool(self.fetch_from_api and (self.api_client or self.use_local_json))
|
||
|
||
def _is_using_local_json(self, task_code: str) -> bool:
|
||
return bool(self.use_local_json and task_code in self.local_dump_dirs)
|
||
|
||
def _get_local_json_client(self, task_code: str) -> Optional[LocalJsonClient]:
|
||
if task_code in self._local_json_clients:
|
||
return self._local_json_clients[task_code]
|
||
dump_dir = self.local_dump_dirs.get(task_code)
|
||
if not dump_dir:
|
||
return None
|
||
try:
|
||
client = LocalJsonClient(dump_dir)
|
||
except Exception as exc: # noqa: BLE001
|
||
self.logger.warning(
|
||
"本地 JSON 目录不可用: task=%s, dir=%s, error=%s",
|
||
task_code, dump_dir, exc,
|
||
)
|
||
return None
|
||
self._local_json_clients[task_code] = client
|
||
return client
|
||
|
||
def _get_source_client(self, task_code: str):
|
||
if self.use_local_json:
|
||
return self._get_local_json_client(task_code)
|
||
return self.api_client
|
||
|
||
def verify_against_api(
|
||
self,
|
||
table: str,
|
||
window_start: datetime,
|
||
window_end: datetime,
|
||
auto_backfill: bool = False,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
与 API 源数据对比校验
|
||
|
||
这是更严格的校验,直接调用 API 获取数据进行对比
|
||
"""
|
||
if not self.api_client:
|
||
return {"error": "未配置 API 客户端"}
|
||
|
||
config = self._table_config.get(table, {})
|
||
endpoint = config.get("api_endpoint")
|
||
|
||
if not endpoint:
|
||
return {"error": f"未找到表 {table} 的 API 端点配置"}
|
||
|
||
self.logger.info("开始与 API 对比校验: %s", table)
|
||
|
||
# 1. 从 API 获取数据
|
||
try:
|
||
api_records = self.api_client.fetch_records(
|
||
endpoint=endpoint,
|
||
start_time=window_start,
|
||
end_time=window_end,
|
||
)
|
||
except Exception as e:
|
||
return {"error": f"API 调用失败: {e}"}
|
||
|
||
# 2. 从 ODS 获取数据
|
||
ods_hashes = self.fetch_target_hashes(table, window_start, window_end)
|
||
|
||
# 3. 计算 API 数据的 hash
|
||
pk_cols = self.get_primary_keys(table)
|
||
api_hashes = {}
|
||
for record in api_records:
|
||
pk = tuple(record.get(col) for col in pk_cols)
|
||
content_hash = self._compute_hash(record)
|
||
api_hashes[pk] = content_hash
|
||
|
||
# 4. 对比
|
||
api_keys = set(api_hashes.keys())
|
||
ods_keys = set(ods_hashes.keys())
|
||
|
||
missing = api_keys - ods_keys
|
||
extra = ods_keys - api_keys
|
||
mismatch = {
|
||
k for k in (api_keys & ods_keys)
|
||
if api_hashes[k] != ods_hashes[k]
|
||
}
|
||
|
||
result = {
|
||
"table": table,
|
||
"api_count": len(api_hashes),
|
||
"ods_count": len(ods_hashes),
|
||
"missing_count": len(missing),
|
||
"extra_count": len(extra),
|
||
"mismatch_count": len(mismatch),
|
||
"is_consistent": len(missing) == 0 and len(mismatch) == 0,
|
||
}
|
||
|
||
# 5. 自动补齐
|
||
if auto_backfill and (missing or mismatch):
|
||
# 需要重新获取的主键
|
||
keys_to_refetch = missing | mismatch
|
||
|
||
# 筛选需要重新插入的记录
|
||
records_to_upsert = [
|
||
r for r in api_records
|
||
if tuple(r.get(col) for col in pk_cols) in keys_to_refetch
|
||
]
|
||
|
||
if records_to_upsert:
|
||
backfilled = self._batch_upsert(table, records_to_upsert)
|
||
result["backfilled_count"] = backfilled
|
||
|
||
return result
|
||
|
||
    def _table_has_column(self, full_table: str, column: str) -> bool:
        """
        Check (with caching) whether *full_table* has *column*.

        Query failures are treated as "column absent" after rolling back,
        so callers can degrade gracefully.
        """
        cache_key = (full_table, column)
        if cache_key in self._table_column_cache:
            return self._table_column_cache[cache_key]
        # Split an optional schema prefix; default to the public schema.
        schema = "public"
        table = full_table
        if "." in full_table:
            schema, table = full_table.split(".", 1)
        sql = """
        SELECT 1
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s AND column_name = %s
        LIMIT 1
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (schema, table, column))
                exists = cur.fetchone() is not None
        except Exception:
            exists = False
            # Recover the transaction so later queries still work.
            try:
                self.db.conn.rollback()
            except Exception:
                pass
        self._table_column_cache[cache_key] = exists
        return exists
|
||
|
||
    def _get_db_primary_keys(self, full_table: str) -> List[str]:
        """
        Read primary key columns from database metadata (ordered).

        Results are cached per table; lookup failures roll back and yield
        an empty list, which makes callers skip key-based operations.
        """
        if full_table in self._table_pk_cache:
            return self._table_pk_cache[full_table]

        # Split an optional schema prefix; default to the public schema.
        schema = "public"
        table = full_table
        if "." in full_table:
            schema, table = full_table.split(".", 1)

        sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.table_schema = %s
          AND tc.table_name = %s
          AND tc.constraint_type = 'PRIMARY KEY'
        ORDER BY kcu.ordinal_position
        """
        try:
            with self.db.conn.cursor() as cur:
                cur.execute(sql, (schema, table))
                rows = cur.fetchall()
                # The cursor may return plain tuples or dict rows.
                cols = [r[0] if not isinstance(r, dict) else r.get("column_name") for r in rows]
                result = [c for c in cols if c]
        except Exception:
            result = []
            # Recover the transaction so later queries still work.
            try:
                self.db.conn.rollback()
            except Exception:
                pass

        self._table_pk_cache[full_table] = result
        return result
|
||
|
||
def _compute_compare_hash_from_payload(self, payload: Any) -> Optional[str]:
|
||
"""使用 ODS 任务的算法计算对比哈希"""
|
||
try:
|
||
from tasks.ods.ods_tasks import BaseOdsTask
|
||
return BaseOdsTask._compute_compare_hash_from_payload(payload)
|
||
except Exception:
|
||
return None
|
||
|
||
def _compute_hash(self, record: dict) -> str:
|
||
"""计算记录的对比哈希(与 ODS 入库一致,不包含 fetched_at)"""
|
||
compare_hash = self._compute_compare_hash_from_payload(record)
|
||
return compare_hash or ""
|
||
|
||
def _batch_upsert(self, table: str, records: List[dict]) -> int:
|
||
"""Batch backfill in snapshot-safe mode (insert-only on PK conflict)."""
|
||
if not records:
|
||
return 0
|
||
|
||
full_table = self._get_full_table_name(table)
|
||
db_pk_cols = self._get_db_primary_keys(full_table)
|
||
if not db_pk_cols:
|
||
self.logger.warning("表 %s 未找到主键,跳过回填", full_table)
|
||
return 0
|
||
has_content_hash_col = self._table_has_column(full_table, "content_hash")
|
||
|
||
# 获取所有列(从第一条记录),并在存在 content_hash 列时补齐该列。
|
||
all_cols = list(records[0].keys())
|
||
if has_content_hash_col and "content_hash" not in all_cols:
|
||
all_cols.append("content_hash")
|
||
|
||
# Snapshot-safe strategy: never update historical rows; only insert new snapshots.
|
||
col_list = ", ".join(all_cols)
|
||
placeholders = ", ".join(["%s"] * len(all_cols))
|
||
pk_list = ", ".join(db_pk_cols)
|
||
|
||
sql = f"""
|
||
INSERT INTO {full_table} ({col_list})
|
||
VALUES ({placeholders})
|
||
ON CONFLICT ({pk_list}) DO NOTHING
|
||
"""
|
||
|
||
count = 0
|
||
with self.db.conn.cursor() as cur:
|
||
for record in records:
|
||
row = dict(record)
|
||
if has_content_hash_col:
|
||
row["content_hash"] = self._compute_hash(record)
|
||
values = [row.get(col) for col in all_cols]
|
||
try:
|
||
cur.execute(sql, values)
|
||
affected = int(cur.rowcount or 0)
|
||
if affected > 0:
|
||
count += affected
|
||
except Exception as e:
|
||
self.logger.warning("UPSERT 失败: %s, error=%s", record, e)
|
||
|
||
self.db.commit()
|
||
return count
|