Files
Neo-ZQYY/apps/etl/pipelines/feiqiu/tasks/verification/ods_verifier.py

872 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""ODS 层批量校验器
校验逻辑:对比 API 源数据与 ODS 表数据
- 主键 + content_hash 对比
- 批量 UPSERT 补齐缺失/不一致数据
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from psycopg2.extras import execute_values
from api.local_json_client import LocalJsonClient
from .base_verifier import BaseVerifier, VerificationFetchError
class OdsVerifier(BaseVerifier):
    """ODS layer verifier.

    Compares source data (API or local JSON dumps) against ODS tables by
    primary key + content_hash, per the module-level contract.
    """

    def __init__(
        self,
        db_connection: Any,
        api_client: Any = None,
        logger: Optional[logging.Logger] = None,
        fetch_from_api: bool = False,
        local_dump_dirs: Optional[Dict[str, str]] = None,
        use_local_json: bool = False,
    ):
        """
        Initialize the ODS verifier.

        Args:
            db_connection: Database connection wrapper (exposes ``.conn``).
            api_client: API client used to re-fetch source data.
            logger: Logger instance.
            fetch_from_api: Whether to fetch source data from the API for
                verification (default False: only ODS internal consistency
                is checked).
            local_dump_dirs: Mapping of task_code -> local JSON dump directory.
            use_local_json: Whether to prefer local JSON files as the source.
        """
        super().__init__(db_connection, logger)
        self.api_client = api_client
        self.fetch_from_api = fetch_from_api
        self.local_dump_dirs = local_dump_dirs or {}
        # Supplying dump dirs implies local-JSON mode even if the flag is off.
        self.use_local_json = bool(use_local_json or self.local_dump_dirs)
        # Caches for API-fetched data, keyed by "{table}_{start}_{end}",
        # to avoid repeated source calls within one verification run.
        self._api_data_cache: Dict[str, List[dict]] = {}
        self._api_key_cache: Dict[str, Set[Tuple]] = {}
        self._api_hash_cache: Dict[str, Dict[Tuple, str]] = {}
        # (full_table, column) -> column-exists flag.
        self._table_column_cache: Dict[Tuple[str, str], bool] = {}
        # full_table -> primary key columns read from DB metadata.
        self._table_pk_cache: Dict[str, List[str]] = {}
        self._local_json_clients: Dict[str, LocalJsonClient] = {}
        # ODS table config: {table: {pk_columns, time_column, api_endpoint, ...}}
        self._table_config = self._load_table_config()
@property
def layer_name(self) -> str:
return "ODS"
def _load_table_config(self) -> Dict[str, dict]:
"""加载 ODS 表配置"""
# 从任务定义中动态获取配置
try:
from tasks.ods.ods_tasks import ODS_TASK_SPECS
config = {}
for spec in ODS_TASK_SPECS:
# time_fields 是一个元组 (start_field, end_field),取第一个作为时间列
# 或者使用 fetched_at 作为默认
time_column = "fetched_at"
# 使用 table_name 属性(不是 table
table_name = spec.table_name
# 提取不带 schema 前缀的表名作为 key
if "." in table_name:
table_key = table_name.split(".")[-1]
else:
table_key = table_name
# 从 sources 中提取 ODS 表的实际主键列名
# sources 格式如 ("settleList.id", "id"),最后一个简单名称是 ODS 列名
pk_columns = []
for col in spec.pk_columns:
ods_col_name = self._extract_ods_column_name(col)
pk_columns.append(ods_col_name)
# 如果 pk_columns 为空,尝试使用 conflict_columns_override 或跳过校验
# 一些特殊表(如 goods_stock_summary, settlement_ticket_details没有标准主键
if not pk_columns:
# 跳过没有明确主键定义的表
self.logger.debug("%s 没有定义主键列,跳过校验配置", table_key)
continue
config[table_key] = {
"full_table_name": table_name,
"pk_columns": pk_columns,
"time_column": time_column,
"api_endpoint": spec.endpoint,
"task_code": spec.code,
}
return config
except ImportError:
self.logger.warning("无法加载 ODS 任务定义,使用默认配置")
return {}
def _extract_ods_column_name(self, col) -> str:
"""
从 ColumnSpec 中提取 ODS 表的实际列名
ODS 表使用原始 JSON 字段名(小写),而 col.column 是 DWD 层的命名。
sources 中的最后一个简单字段名通常就是 ODS 表的列名。
"""
# 如果 sources 为空,使用 column假设 column 就是 ODS 列名)
if not col.sources:
return col.column
# 遍历 sources找到最简单的字段名不含点号的
for source in reversed(col.sources):
if "." not in source:
return source.lower() # ODS 列名通常是小写
# 如果都是复杂路径,取最后一个路径的最后一部分
last_source = col.sources[-1]
if "." in last_source:
return last_source.split(".")[-1].lower()
return last_source.lower()
def get_tables(self) -> List[str]:
"""获取需要校验的 ODS 表列表"""
if self._table_config:
return list(self._table_config.keys())
# 从数据库查询 ODS schema 中的表
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'billiards_ods'
AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql)
return [row[0] for row in cur.fetchall()]
except Exception as e:
self.logger.warning("获取 ODS 表列表失败: %s", e)
try:
self.db.conn.rollback()
except Exception:
pass
return []
def get_primary_keys(self, table: str) -> List[str]:
"""获取表的主键列"""
if table in self._table_config:
return self._table_config[table].get("pk_columns", [])
# 表不在配置中,返回空列表表示无法校验
return []
def get_time_column(self, table: str) -> Optional[str]:
"""获取表的时间列"""
if table in self._table_config:
return self._table_config[table].get("time_column", "fetched_at")
return "fetched_at"
def _get_full_table_name(self, table: str) -> str:
"""获取完整的表名(包含 schema"""
if table in self._table_config:
return self._table_config[table].get("full_table_name", f"billiards_ods.{table}")
# 如果表名已经包含 schema直接返回
if "." in table:
return table
return f"billiards_ods.{table}"
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""
从源获取主键集合
根据 fetch_from_api 参数决定数据来源:
- fetch_from_api=True: 从 API 获取数据(真正的源到目标校验)
- fetch_from_api=False: 从 ODS 表获取ODS 内部一致性校验)
"""
if self._has_external_source():
return self._fetch_keys_from_api(table, window_start, window_end)
else:
# ODS 内部校验:直接从 ODS 表获取
return self._fetch_keys_from_db(table, window_start, window_end)
def _fetch_keys_from_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 API 获取源数据主键集合"""
# 尝试获取缓存的 API 数据
cache_key = f"{table}_{window_start}_{window_end}"
if cache_key in self._api_key_cache:
return self._api_key_cache[cache_key]
if cache_key not in self._api_data_cache:
# 调用 API 获取数据
api_records = self._call_api_for_table(table, window_start, window_end)
self._api_data_cache[cache_key] = api_records
api_records = self._api_data_cache.get(cache_key, [])
if not api_records:
self.logger.debug("%s 从 API 未获取到数据", table)
return set()
# 获取主键列
pk_cols = self.get_primary_keys(table)
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过 API 校验", table)
return set()
# 提取主键
keys = set()
for record in api_records:
pk_values = []
for col in pk_cols:
# API 返回的字段名可能是原始格式(如 id, Id, ID
# 尝试多种格式
value = record.get(col)
if value is None:
value = record.get(col.lower())
if value is None:
value = record.get(col.upper())
pk_values.append(value)
if all(v is not None for v in pk_values):
keys.add(tuple(pk_values))
self.logger.info("%s 从源数据获取 %d 条记录,%d 个唯一主键", table, len(api_records), len(keys))
self._api_key_cache[cache_key] = keys
return keys
def _call_api_for_table(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> List[dict]:
    """Fetch the raw source records backing *table* for the given window.

    Resolves the task spec for the table's task_code, builds the request
    parameters (site id, time window, extra params) and pages through the
    source client (API or local JSON). Returns [] when configuration is
    incomplete or no source client is available.

    Raises:
        VerificationFetchError: when the paginated fetch itself fails.
    """
    config = self._table_config.get(table, {})
    task_code = config.get("task_code")
    endpoint = config.get("api_endpoint")
    if not task_code or not endpoint:
        self.logger.warning(
            "%s 没有完整的任务配置task_code=%s, endpoint=%s),无法获取源数据",
            table, task_code, endpoint
        )
        return []
    source_client = self._get_source_client(task_code)
    if not source_client:
        self.logger.warning("%s 未找到可用源API/本地JSON跳过获取源数据", table)
        return []
    source_label = "本地 JSON" if self._is_using_local_json(task_code) else "API"
    self.logger.info(
        "%s 获取数据: 表=%s, 端点=%s, 时间窗口=%s ~ %s",
        source_label, table, endpoint, window_start, window_end
    )
    try:
        # The ODS task specs hold the authoritative parameter configuration.
        from tasks.ods.ods_tasks import ODS_TASK_SPECS
        # Locate the spec matching this task code.
        spec = None
        for s in ODS_TASK_SPECS:
            if s.code == task_code:
                spec = s
                break
        if not spec:
            self.logger.warning("未找到任务规格: %s", task_code)
            return []
        # Build the request parameters.
        params = {}
        if spec.include_site_id:
            # Pull store_id from the API client when available.
            store_id = getattr(self.api_client, 'store_id', None)
            if store_id:
                params["siteId"] = store_id
        if spec.requires_window and spec.time_fields:
            start_key, end_key = spec.time_fields
            # Format the window bounds as "YYYY-MM-DD HH:MM:SS" timestamps.
            params[start_key] = window_start.strftime("%Y-%m-%d %H:%M:%S")
            params[end_key] = window_end.strftime("%Y-%m-%d %H:%M:%S")
        # Merge spec-level extra parameters last.
        params.update(spec.extra_params)
        # Page through the source until exhausted.
        all_records = []
        for _, page_records, _, _ in source_client.iter_paginated(
            endpoint=spec.endpoint,
            params=params,
            page_size=200,
            data_path=spec.data_path,
            list_key=spec.list_key,
        ):
            all_records.extend(page_records)
        self.logger.info("源数据返回 %d 条原始记录", len(all_records))
        return all_records
    except Exception as e:
        self.logger.warning("获取源数据失败: 表=%s, error=%s", table, e)
        import traceback
        self.logger.debug("调用栈: %s", traceback.format_exc())
        raise VerificationFetchError(f"获取源数据失败: {table}") from e
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 ODS 表获取目标数据主键集合"""
if self._has_external_source():
cache_key = f"{table}_{window_start}_{window_end}"
api_keys = self._api_key_cache.get(cache_key)
if api_keys is None:
api_keys = self._fetch_keys_from_api(table, window_start, window_end)
return self._fetch_keys_from_db_by_keys(table, api_keys)
return self._fetch_keys_from_db(table, window_start, window_end)
def _fetch_keys_from_db(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> Set[Tuple]:
    """Read the set of primary-key tuples from the ODS table for the window.

    Raises:
        VerificationFetchError: when the query fails (the transaction is
            rolled back first so later queries are unaffected).
    """
    pk_cols = self.get_primary_keys(table)
    # Without a configured primary key the table cannot be verified.
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过获取主键", table)
        return set()
    time_col = self.get_time_column(table)
    full_table = self._get_full_table_name(table)
    pk_select = ", ".join(pk_cols)
    # Identifiers are interpolated into the SQL; they come from trusted
    # task specs, not user input.
    sql = f"""
    SELECT {pk_select}
    FROM {full_table}
    WHERE {time_col} >= %s AND {time_col} < %s
    """
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (window_start, window_end))
            return {tuple(row) for row in cur.fetchall()}
    except Exception as e:
        self.logger.warning("获取 ODS 主键失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"获取 ODS 主键失败: {table}") from e
def _fetch_keys_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Set[Tuple]:
    """Probe the ODS table for the given primary keys (window-independent).

    Joins a VALUES list of the requested keys against the table in chunks
    of 500 and returns the subset of keys that exist.

    Raises:
        VerificationFetchError: when the lookup fails (after rollback).
    """
    if not keys:
        return set()
    pk_cols = self.get_primary_keys(table)
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过按主键反查", table)
        return set()
    full_table = self._get_full_table_name(table)
    # Quoted identifiers: ODS column names are lowercase raw JSON fields.
    select_cols = ", ".join(f't."{c}"' for c in pk_cols)
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {full_table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    existing: Set[Tuple] = set()
    try:
        with self.db.conn.cursor() as cur:
            # Chunk the VALUES list to keep each statement bounded.
            for chunk in self._chunked(list(keys), 500):
                execute_values(cur, sql, chunk, page_size=len(chunk))
                for row in cur.fetchall():
                    existing.add(tuple(row))
    except Exception as e:
        self.logger.warning("按主键反查 ODS 失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"按主键反查 ODS 失败: {table}") from e
    return existing
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取源数据的主键->content_hash 映射"""
if self._has_external_source():
return self._fetch_hashes_from_api(table, window_start, window_end)
else:
# ODS 表自带 content_hash 列
return self._fetch_hashes_from_db(table, window_start, window_end)
def _fetch_hashes_from_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""从 API 数据计算哈希"""
cache_key = f"{table}_{window_start}_{window_end}"
if cache_key in self._api_hash_cache:
return self._api_hash_cache[cache_key]
api_records = self._api_data_cache.get(cache_key, [])
if not api_records:
# 尝试从 API 获取
api_records = self._call_api_for_table(table, window_start, window_end)
self._api_data_cache[cache_key] = api_records
if not api_records:
return {}
pk_cols = self.get_primary_keys(table)
if not pk_cols:
return {}
result = {}
for record in api_records:
# 提取主键
pk_values = []
for col in pk_cols:
value = record.get(col)
if value is None:
value = record.get(col.lower())
if value is None:
value = record.get(col.upper())
pk_values.append(value)
if all(v is not None for v in pk_values):
pk = tuple(pk_values)
# 计算内容哈希
content_hash = self._compute_hash(record)
result[pk] = content_hash
self._api_hash_cache[cache_key] = result
if cache_key not in self._api_key_cache:
self._api_key_cache[cache_key] = set(result.keys())
return result
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取目标数据的主键->content_hash 映射"""
if self.fetch_from_api and self.api_client:
cache_key = f"{table}_{window_start}_{window_end}"
api_hashes = self._api_hash_cache.get(cache_key)
if api_hashes is None:
api_hashes = self._fetch_hashes_from_api(table, window_start, window_end)
api_keys = set(api_hashes.keys())
return self._fetch_hashes_from_db_by_keys(table, api_keys)
return self._fetch_hashes_from_db(table, window_start, window_end)
def _fetch_hashes_from_db(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> Dict[Tuple, str]:
    """Read pk -> content_hash from the ODS table for the time window.

    Raises:
        VerificationFetchError: when the query fails (after rollback).
    """
    pk_cols = self.get_primary_keys(table)
    # Without a configured primary key the table cannot be verified.
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过获取哈希", table)
        return {}
    time_col = self.get_time_column(table)
    full_table = self._get_full_table_name(table)
    pk_select = ", ".join(pk_cols)
    sql = f"""
    SELECT {pk_select}, content_hash
    FROM {full_table}
    WHERE {time_col} >= %s AND {time_col} < %s
    """
    result = {}
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (window_start, window_end))
            for row in cur.fetchall():
                pk = tuple(row[:-1])
                content_hash = row[-1]
                # NULL hashes are normalized to "".
                result[pk] = content_hash or ""
    except Exception as e:
        # Roll back on failure so the transaction stays usable afterwards.
        self.logger.warning("获取 ODS hash 失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"获取 ODS hash 失败: {table}") from e
    return result
def _fetch_hashes_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Dict[Tuple, str]:
    """Read pk -> compare hash for the given keys (window-independent).

    When the table has a ``payload`` column the hash is recomputed from
    the payload with the ODS task's algorithm (so it matches the hash
    computed for API records); otherwise the stored ``content_hash``
    column is used as-is.

    Raises:
        VerificationFetchError: when the lookup fails (after rollback).
    """
    if not keys:
        return {}
    pk_cols = self.get_primary_keys(table)
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过按主键反查 hash", table)
        return {}
    full_table = self._get_full_table_name(table)
    has_payload = self._table_has_column(full_table, "payload")
    select_tail = 't."payload"' if has_payload else 't."content_hash"'
    select_cols = ", ".join([*(f't."{c}"' for c in pk_cols), select_tail])
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {full_table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    result: Dict[Tuple, str] = {}
    try:
        with self.db.conn.cursor() as cur:
            # Chunk the VALUES list to keep each statement bounded.
            for chunk in self._chunked(list(keys), 500):
                execute_values(cur, sql, chunk, page_size=len(chunk))
                for row in cur.fetchall():
                    pk = tuple(row[:-1])
                    tail_value = row[-1]
                    if has_payload:
                        compare_hash = self._compute_compare_hash_from_payload(tail_value)
                        result[pk] = compare_hash or ""
                    else:
                        result[pk] = tail_value or ""
    except Exception as e:
        self.logger.warning("按主键反查 ODS hash 失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"按主键反查 ODS hash 失败: {table}") from e
    return result
@staticmethod
def _chunked(items: List[Tuple], chunk_size: int) -> List[List[Tuple]]:
"""将列表按固定大小分块"""
if chunk_size <= 0:
return [items]
return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量补齐缺失数据
ODS 层补齐需要重新从 API 获取数据
"""
if not self._has_external_source():
self.logger.warning("未配置 API/本地JSON 源,无法补齐 ODS 缺失数据")
return 0
if not missing_keys:
return 0
# 获取表配置
config = self._table_config.get(table, {})
task_code = config.get("task_code")
if not task_code:
self.logger.warning("未找到表 %s 的任务配置,跳过补齐", table)
return 0
self.logger.info(
"ODS 补齐缺失: 表=%s, 数量=%d, 任务=%s",
table, len(missing_keys), task_code
)
# ODS 层的补齐实际上是重新执行 ODS 任务从 API 获取数据
# 但由于 ODS 任务已经在 "校验前先从 API 获取数据" 步骤执行过了,
# 这里补齐失败是预期的(数据已经在 ODS 表中,只是校验窗口可能不一致)
#
# 实际的 ODS 补齐应该在 verify_only 模式下启用 fetch_before_verify 选项,
# 这会先执行 ODS 任务获取 API 数据,然后再校验。
#
# 如果仍然有缺失,说明:
# 1. API 返回的数据时间窗口与校验窗口不完全匹配
# 2. 或者 ODS 任务的时间参数配置问题
self.logger.info(
"ODS 补齐提示: 表=%s%d 条缺失记录,建议使用 '校验前先从 API 获取数据' 选项获取完整数据",
table, len(missing_keys)
)
return 0
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量更新不一致数据
ODS 层更新也需要重新从 API 获取
"""
# 与 backfill_missing 类似,重新获取数据会自动 UPSERT
return self.backfill_missing(table, mismatch_keys, window_start, window_end)
def _has_external_source(self) -> bool:
return bool(self.fetch_from_api and (self.api_client or self.use_local_json))
def _is_using_local_json(self, task_code: str) -> bool:
return bool(self.use_local_json and task_code in self.local_dump_dirs)
def _get_local_json_client(self, task_code: str) -> Optional[LocalJsonClient]:
if task_code in self._local_json_clients:
return self._local_json_clients[task_code]
dump_dir = self.local_dump_dirs.get(task_code)
if not dump_dir:
return None
try:
client = LocalJsonClient(dump_dir)
except Exception as exc: # noqa: BLE001
self.logger.warning(
"本地 JSON 目录不可用: task=%s, dir=%s, error=%s",
task_code, dump_dir, exc,
)
return None
self._local_json_clients[task_code] = client
return client
def _get_source_client(self, task_code: str):
if self.use_local_json:
return self._get_local_json_client(task_code)
return self.api_client
def verify_against_api(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
    auto_backfill: bool = False,
) -> Dict[str, Any]:
    """Strict source-vs-target check: fetch from the API and diff with ODS.

    Args:
        table: Bare ODS table name.
        window_start: Inclusive lower bound of the window.
        window_end: Exclusive upper bound of the window.
        auto_backfill: When True, re-insert missing/mismatched records.

    Returns:
        A summary dict with counts and ``is_consistent`` (plus
        ``backfilled_count`` when a backfill actually ran), or
        ``{"error": ...}`` when the API client/endpoint is unavailable
        or the API call fails.
    """
    if not self.api_client:
        return {"error": "未配置 API 客户端"}
    config = self._table_config.get(table, {})
    endpoint = config.get("api_endpoint")
    if not endpoint:
        return {"error": f"未找到表 {table} 的 API 端点配置"}
    self.logger.info("开始与 API 对比校验: %s", table)
    # 1. Fetch records from the API.
    try:
        api_records = self.api_client.fetch_records(
            endpoint=endpoint,
            start_time=window_start,
            end_time=window_end,
        )
    except Exception as e:
        return {"error": f"API 调用失败: {e}"}
    # 2. Read ODS-side hashes.
    ods_hashes = self.fetch_target_hashes(table, window_start, window_end)
    # 3. Hash the API records.
    #    NOTE(review): unlike _fetch_keys_from_api, the pk lookup here is
    #    case-sensitive and tuples may contain None — confirm intended.
    pk_cols = self.get_primary_keys(table)
    api_hashes = {}
    for record in api_records:
        pk = tuple(record.get(col) for col in pk_cols)
        content_hash = self._compute_hash(record)
        api_hashes[pk] = content_hash
    # 4. Diff key sets and hashes.
    api_keys = set(api_hashes.keys())
    ods_keys = set(ods_hashes.keys())
    missing = api_keys - ods_keys
    extra = ods_keys - api_keys
    mismatch = {
        k for k in (api_keys & ods_keys)
        if api_hashes[k] != ods_hashes[k]
    }
    result = {
        "table": table,
        "api_count": len(api_hashes),
        "ods_count": len(ods_hashes),
        "missing_count": len(missing),
        "extra_count": len(extra),
        "mismatch_count": len(mismatch),
        # Extra ODS rows do not count against consistency.
        "is_consistent": len(missing) == 0 and len(mismatch) == 0,
    }
    # 5. Optional backfill of missing/mismatched rows.
    if auto_backfill and (missing or mismatch):
        # Keys whose records must be re-inserted.
        keys_to_refetch = missing | mismatch
        # Select the API records matching those keys.
        records_to_upsert = [
            r for r in api_records
            if tuple(r.get(col) for col in pk_cols) in keys_to_refetch
        ]
        if records_to_upsert:
            backfilled = self._batch_upsert(table, records_to_upsert)
            result["backfilled_count"] = backfilled
    return result
def _table_has_column(self, full_table: str, column: str) -> bool:
    """Check (with caching) whether *full_table* has *column*.

    Any query failure is treated as "column absent" after rolling back
    the transaction, and the negative result is cached too.
    """
    cache_key = (full_table, column)
    if cache_key in self._table_column_cache:
        return self._table_column_cache[cache_key]
    # Split "schema.table"; default schema is public.
    schema = "public"
    table = full_table
    if "." in full_table:
        schema, table = full_table.split(".", 1)
    sql = """
    SELECT 1
    FROM information_schema.columns
    WHERE table_schema = %s AND table_name = %s AND column_name = %s
    LIMIT 1
    """
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (schema, table, column))
            exists = cur.fetchone() is not None
    except Exception:
        exists = False
        try:
            self.db.conn.rollback()
        except Exception:
            pass
    self._table_column_cache[cache_key] = exists
    return exists
def _get_db_primary_keys(self, full_table: str) -> List[str]:
    """Read primary key columns from database metadata (ordered).

    Results (including an empty list on failure) are cached per table so
    the information_schema query runs at most once per table.
    """
    if full_table in self._table_pk_cache:
        return self._table_pk_cache[full_table]
    # Split "schema.table"; default schema is public.
    schema = "public"
    table = full_table
    if "." in full_table:
        schema, table = full_table.split(".", 1)
    sql = """
    SELECT kcu.column_name
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
        AND tc.table_name = kcu.table_name
    WHERE tc.table_schema = %s
        AND tc.table_name = %s
        AND tc.constraint_type = 'PRIMARY KEY'
    ORDER BY kcu.ordinal_position
    """
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (schema, table))
            rows = cur.fetchall()
            # Support both tuple and dict cursors.
            cols = [r[0] if not isinstance(r, dict) else r.get("column_name") for r in rows]
            result = [c for c in cols if c]
    except Exception:
        result = []
        try:
            self.db.conn.rollback()
        except Exception:
            pass
    self._table_pk_cache[full_table] = result
    return result
def _compute_compare_hash_from_payload(self, payload: Any) -> Optional[str]:
"""使用 ODS 任务的算法计算对比哈希"""
try:
from tasks.ods.ods_tasks import BaseOdsTask
return BaseOdsTask._compute_compare_hash_from_payload(payload)
except Exception:
return None
def _compute_hash(self, record: dict) -> str:
"""计算记录的对比哈希(与 ODS 入库一致,不包含 fetched_at"""
compare_hash = self._compute_compare_hash_from_payload(record)
return compare_hash or ""
def _batch_upsert(self, table: str, records: List[dict]) -> int:
"""Batch backfill in snapshot-safe mode (insert-only on PK conflict)."""
if not records:
return 0
full_table = self._get_full_table_name(table)
db_pk_cols = self._get_db_primary_keys(full_table)
if not db_pk_cols:
self.logger.warning("%s 未找到主键,跳过回填", full_table)
return 0
has_content_hash_col = self._table_has_column(full_table, "content_hash")
# 获取所有列(从第一条记录),并在存在 content_hash 列时补齐该列。
all_cols = list(records[0].keys())
if has_content_hash_col and "content_hash" not in all_cols:
all_cols.append("content_hash")
# Snapshot-safe strategy: never update historical rows; only insert new snapshots.
col_list = ", ".join(all_cols)
placeholders = ", ".join(["%s"] * len(all_cols))
pk_list = ", ".join(db_pk_cols)
sql = f"""
INSERT INTO {full_table} ({col_list})
VALUES ({placeholders})
ON CONFLICT ({pk_list}) DO NOTHING
"""
count = 0
with self.db.conn.cursor() as cur:
for record in records:
row = dict(record)
if has_content_hash_col:
row["content_hash"] = self._compute_hash(record)
values = [row.get(col) for col in all_cols]
try:
cur.execute(sql, values)
affected = int(cur.rowcount or 0)
if affected > 0:
count += affected
except Exception as e:
self.logger.warning("UPSERT 失败: %s, error=%s", record, e)
self.db.commit()
return count