初始提交:飞球 ETL 系统全量代码

This commit is contained in:
Neo
2026-02-13 08:05:34 +08:00
commit 3c51f5485d
441 changed files with 117631 additions and 0 deletions

View File

@@ -0,0 +1,871 @@
# -*- coding: utf-8 -*-
"""ODS 层批量校验器
校验逻辑:对比 API 源数据与 ODS 表数据
- 主键 + content_hash 对比
- 批量 UPSERT 补齐缺失/不一致数据
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from psycopg2.extras import execute_values
from api.local_json_client import LocalJsonClient
from .base_verifier import BaseVerifier, VerificationFetchError
class OdsVerifier(BaseVerifier):
"""ODS 层校验器"""
def __init__(
    self,
    db_connection: Any,
    api_client: Any = None,
    logger: Optional[logging.Logger] = None,
    fetch_from_api: bool = False,
    local_dump_dirs: Optional[Dict[str, str]] = None,
    use_local_json: bool = False,
):
    """Initialize the ODS-layer verifier.

    Args:
        db_connection: Database connection wrapper.
        api_client: API client used to re-fetch source data (optional).
        logger: Logger instance; falls back to the base verifier default.
        fetch_from_api: When True, verify against source data from the
            API; when False, only check ODS-internal consistency.
        local_dump_dirs: Mapping of task_code -> local JSON dump directory.
        use_local_json: Prefer local JSON dumps as the source when set.
    """
    super().__init__(db_connection, logger)
    self.api_client = api_client
    self.fetch_from_api = fetch_from_api
    self.local_dump_dirs = local_dump_dirs or {}
    # A non-empty dump-dir mapping implicitly enables local-JSON mode.
    self.use_local_json = bool(use_local_json or self.local_dump_dirs)
    # Per-run caches for data fetched from the source (avoid re-calls).
    self._api_data_cache: Dict[str, List[dict]] = {}
    self._api_key_cache: Dict[str, Set[Tuple]] = {}
    self._api_hash_cache: Dict[str, Dict[Tuple, str]] = {}
    # Metadata caches: column existence and DB primary keys per table.
    self._table_column_cache: Dict[Tuple[str, str], bool] = {}
    self._table_pk_cache: Dict[str, List[str]] = {}
    self._local_json_clients: Dict[str, LocalJsonClient] = {}
    # ODS table config: {table: {pk_columns, time_column, api_endpoint}}
    self._table_config = self._load_table_config()
@property
def layer_name(self) -> str:
return "ODS"
def _load_table_config(self) -> Dict[str, dict]:
"""加载 ODS 表配置"""
# 从任务定义中动态获取配置
try:
from tasks.ods.ods_tasks import ODS_TASK_SPECS
config = {}
for spec in ODS_TASK_SPECS:
# time_fields 是一个元组 (start_field, end_field),取第一个作为时间列
# 或者使用 fetched_at 作为默认
time_column = "fetched_at"
# 使用 table_name 属性(不是 table
table_name = spec.table_name
# 提取不带 schema 前缀的表名作为 key
if "." in table_name:
table_key = table_name.split(".")[-1]
else:
table_key = table_name
# 从 sources 中提取 ODS 表的实际主键列名
# sources 格式如 ("settleList.id", "id"),最后一个简单名称是 ODS 列名
pk_columns = []
for col in spec.pk_columns:
ods_col_name = self._extract_ods_column_name(col)
pk_columns.append(ods_col_name)
# 如果 pk_columns 为空,尝试使用 conflict_columns_override 或跳过校验
# 一些特殊表(如 goods_stock_summary, settlement_ticket_details没有标准主键
if not pk_columns:
# 跳过没有明确主键定义的表
self.logger.debug("%s 没有定义主键列,跳过校验配置", table_key)
continue
config[table_key] = {
"full_table_name": table_name,
"pk_columns": pk_columns,
"time_column": time_column,
"api_endpoint": spec.endpoint,
"task_code": spec.code,
}
return config
except ImportError:
self.logger.warning("无法加载 ODS 任务定义,使用默认配置")
return {}
def _extract_ods_column_name(self, col) -> str:
"""
从 ColumnSpec 中提取 ODS 表的实际列名
ODS 表使用原始 JSON 字段名(小写),而 col.column 是 DWD 层的命名。
sources 中的最后一个简单字段名通常就是 ODS 表的列名。
"""
# 如果 sources 为空,使用 column假设 column 就是 ODS 列名)
if not col.sources:
return col.column
# 遍历 sources找到最简单的字段名不含点号的
for source in reversed(col.sources):
if "." not in source:
return source.lower() # ODS 列名通常是小写
# 如果都是复杂路径,取最后一个路径的最后一部分
last_source = col.sources[-1]
if "." in last_source:
return last_source.split(".")[-1].lower()
return last_source.lower()
def get_tables(self) -> List[str]:
"""获取需要校验的 ODS 表列表"""
if self._table_config:
return list(self._table_config.keys())
# 从数据库查询 ODS schema 中的表
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'billiards_ods'
AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
try:
with self.db.conn.cursor() as cur:
cur.execute(sql)
return [row[0] for row in cur.fetchall()]
except Exception as e:
self.logger.warning("获取 ODS 表列表失败: %s", e)
try:
self.db.conn.rollback()
except Exception:
pass
return []
def get_primary_keys(self, table: str) -> List[str]:
"""获取表的主键列"""
if table in self._table_config:
return self._table_config[table].get("pk_columns", [])
# 表不在配置中,返回空列表表示无法校验
return []
def get_time_column(self, table: str) -> Optional[str]:
"""获取表的时间列"""
if table in self._table_config:
return self._table_config[table].get("time_column", "fetched_at")
return "fetched_at"
def _get_full_table_name(self, table: str) -> str:
"""获取完整的表名(包含 schema"""
if table in self._table_config:
return self._table_config[table].get("full_table_name", f"billiards_ods.{table}")
# 如果表名已经包含 schema直接返回
if "." in table:
return table
return f"billiards_ods.{table}"
def fetch_source_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""
从源获取主键集合
根据 fetch_from_api 参数决定数据来源:
- fetch_from_api=True: 从 API 获取数据(真正的源到目标校验)
- fetch_from_api=False: 从 ODS 表获取ODS 内部一致性校验)
"""
if self._has_external_source():
return self._fetch_keys_from_api(table, window_start, window_end)
else:
# ODS 内部校验:直接从 ODS 表获取
return self._fetch_keys_from_db(table, window_start, window_end)
def _fetch_keys_from_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 API 获取源数据主键集合"""
# 尝试获取缓存的 API 数据
cache_key = f"{table}_{window_start}_{window_end}"
if cache_key in self._api_key_cache:
return self._api_key_cache[cache_key]
if cache_key not in self._api_data_cache:
# 调用 API 获取数据
api_records = self._call_api_for_table(table, window_start, window_end)
self._api_data_cache[cache_key] = api_records
api_records = self._api_data_cache.get(cache_key, [])
if not api_records:
self.logger.debug("%s 从 API 未获取到数据", table)
return set()
# 获取主键列
pk_cols = self.get_primary_keys(table)
if not pk_cols:
self.logger.debug("%s 没有主键配置,跳过 API 校验", table)
return set()
# 提取主键
keys = set()
for record in api_records:
pk_values = []
for col in pk_cols:
# API 返回的字段名可能是原始格式(如 id, Id, ID
# 尝试多种格式
value = record.get(col)
if value is None:
value = record.get(col.lower())
if value is None:
value = record.get(col.upper())
pk_values.append(value)
if all(v is not None for v in pk_values):
keys.add(tuple(pk_values))
self.logger.info("%s 从源数据获取 %d 条记录,%d 个唯一主键", table, len(api_records), len(keys))
self._api_key_cache[cache_key] = keys
return keys
def _call_api_for_table(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> List[dict]:
    """Fetch raw source records for *table* from the API or a local JSON dump.

    Looks up the task spec matching the table's task_code, builds the
    request parameters (site id, time window, extra params) and pages
    through the endpoint, concatenating all page records.

    Raises:
        VerificationFetchError: when the source call fails.
    """
    config = self._table_config.get(table, {})
    task_code = config.get("task_code")
    endpoint = config.get("api_endpoint")
    if not task_code or not endpoint:
        self.logger.warning(
            "%s 没有完整的任务配置task_code=%s, endpoint=%s),无法获取源数据",
            table, task_code, endpoint
        )
        return []
    source_client = self._get_source_client(task_code)
    if not source_client:
        self.logger.warning("%s 未找到可用源API/本地JSON跳过获取源数据", table)
        return []
    source_label = "本地 JSON" if self._is_using_local_json(task_code) else "API"
    self.logger.info(
        "%s 获取数据: 表=%s, 端点=%s, 时间窗口=%s ~ %s",
        source_label, table, endpoint, window_start, window_end
    )
    try:
        # Load the ODS task specs to get the correct request parameters.
        from tasks.ods.ods_tasks import ODS_TASK_SPECS
        # Find the spec matching this task code.
        spec = None
        for s in ODS_TASK_SPECS:
            if s.code == task_code:
                spec = s
                break
        if not spec:
            self.logger.warning("未找到任务规格: %s", task_code)
            return []
        # Build the API request parameters.
        params = {}
        if spec.include_site_id:
            # Take store_id from the API client when available.
            store_id = getattr(self.api_client, 'store_id', None)
            if store_id:
                params["siteId"] = store_id
        if spec.requires_window and spec.time_fields:
            start_key, end_key = spec.time_fields
            # Format the window boundaries the way the API expects.
            params[start_key] = window_start.strftime("%Y-%m-%d %H:%M:%S")
            params[end_key] = window_end.strftime("%Y-%m-%d %H:%M:%S")
        # Merge endpoint-specific extra parameters.
        params.update(spec.extra_params)
        # Page through the source until exhausted.
        all_records = []
        for _, page_records, _, _ in source_client.iter_paginated(
            endpoint=spec.endpoint,
            params=params,
            page_size=200,
            data_path=spec.data_path,
            list_key=spec.list_key,
        ):
            all_records.extend(page_records)
        self.logger.info("源数据返回 %d 条原始记录", len(all_records))
        return all_records
    except Exception as e:
        self.logger.warning("获取源数据失败: 表=%s, error=%s", table, e)
        import traceback
        self.logger.debug("调用栈: %s", traceback.format_exc())
        raise VerificationFetchError(f"获取源数据失败: {table}") from e
def fetch_target_keys(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Set[Tuple]:
"""从 ODS 表获取目标数据主键集合"""
if self._has_external_source():
cache_key = f"{table}_{window_start}_{window_end}"
api_keys = self._api_key_cache.get(cache_key)
if api_keys is None:
api_keys = self._fetch_keys_from_api(table, window_start, window_end)
return self._fetch_keys_from_db_by_keys(table, api_keys)
return self._fetch_keys_from_db(table, window_start, window_end)
def _fetch_keys_from_db(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> Set[Tuple]:
    """Return the set of primary-key tuples stored in ODS for the window.

    Raises:
        VerificationFetchError: when the query fails (after rollback).
    """
    pk_cols = self.get_primary_keys(table)
    # Without configured PK columns the rows cannot be keyed; skip.
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过获取主键", table)
        return set()
    time_col = self.get_time_column(table)
    full_table = self._get_full_table_name(table)
    pk_select = ", ".join(pk_cols)
    sql = f"""
    SELECT {pk_select}
    FROM {full_table}
    WHERE {time_col} >= %s AND {time_col} < %s
    """
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (window_start, window_end))
            return {tuple(row) for row in cur.fetchall()}
    except Exception as e:
        self.logger.warning("获取 ODS 主键失败: %s, error=%s", table, e)
        # Roll back so the failed statement does not poison the connection.
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"获取 ODS 主键失败: {table}") from e
def _fetch_keys_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Set[Tuple]:
    """Check which of *keys* exist in the ODS table (no time window).

    Joins the key tuples in as a VALUES list, chunked by 500 to bound
    each statement's size.

    Raises:
        VerificationFetchError: when the lookup fails (after rollback).
    """
    if not keys:
        return set()
    pk_cols = self.get_primary_keys(table)
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过按主键反查", table)
        return set()
    full_table = self._get_full_table_name(table)
    select_cols = ", ".join(f't."{c}"' for c in pk_cols)
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {full_table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    existing: Set[Tuple] = set()
    try:
        with self.db.conn.cursor() as cur:
            for chunk in self._chunked(list(keys), 500):
                # page_size=len(chunk) keeps execute_values to a single
                # statement per chunk, so fetchall() sees all its rows.
                execute_values(cur, sql, chunk, page_size=len(chunk))
                for row in cur.fetchall():
                    existing.add(tuple(row))
    except Exception as e:
        self.logger.warning("按主键反查 ODS 失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"按主键反查 ODS 失败: {table}") from e
    return existing
def fetch_source_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取源数据的主键->content_hash 映射"""
if self._has_external_source():
return self._fetch_hashes_from_api(table, window_start, window_end)
else:
# ODS 表自带 content_hash 列
return self._fetch_hashes_from_db(table, window_start, window_end)
def _fetch_hashes_from_api(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> Dict[Tuple, str]:
    """Return pk -> comparison hash computed from the API source records.

    Results are cached per (table, window). Raw records come from the
    shared record cache, with a fresh source call when absent or empty.
    """
    cache_key = f"{table}_{window_start}_{window_end}"
    if cache_key in self._api_hash_cache:
        return self._api_hash_cache[cache_key]
    api_records = self._api_data_cache.get(cache_key, [])
    if not api_records:
        # Not cached yet (or cached empty): fetch from the source.
        api_records = self._call_api_for_table(table, window_start, window_end)
        self._api_data_cache[cache_key] = api_records
    if not api_records:
        return {}
    pk_cols = self.get_primary_keys(table)
    if not pk_cols:
        return {}
    result = {}
    for record in api_records:
        # Extract the primary key, tolerating field-name casing
        # differences (id / Id / ID) in the API payload.
        pk_values = []
        for col in pk_cols:
            value = record.get(col)
            if value is None:
                value = record.get(col.lower())
            if value is None:
                value = record.get(col.upper())
            pk_values.append(value)
        if all(v is not None for v in pk_values):
            pk = tuple(pk_values)
            # Hash with the same algorithm used at ODS ingestion time.
            content_hash = self._compute_hash(record)
            result[pk] = content_hash
    self._api_hash_cache[cache_key] = result
    # Keep the key cache in sync so key- and hash-based paths agree.
    if cache_key not in self._api_key_cache:
        self._api_key_cache[cache_key] = set(result.keys())
    return result
def fetch_target_hashes(
self,
table: str,
window_start: datetime,
window_end: datetime,
) -> Dict[Tuple, str]:
"""获取目标数据的主键->content_hash 映射"""
if self.fetch_from_api and self.api_client:
cache_key = f"{table}_{window_start}_{window_end}"
api_hashes = self._api_hash_cache.get(cache_key)
if api_hashes is None:
api_hashes = self._fetch_hashes_from_api(table, window_start, window_end)
api_keys = set(api_hashes.keys())
return self._fetch_hashes_from_db_by_keys(table, api_keys)
return self._fetch_hashes_from_db(table, window_start, window_end)
def _fetch_hashes_from_db(
    self,
    table: str,
    window_start: datetime,
    window_end: datetime,
) -> Dict[Tuple, str]:
    """Return pk -> stored content_hash for ODS rows inside the window.

    Raises:
        VerificationFetchError: when the query fails (after rollback).
    """
    pk_cols = self.get_primary_keys(table)
    # Without configured PK columns the rows cannot be keyed; skip.
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过获取哈希", table)
        return {}
    time_col = self.get_time_column(table)
    full_table = self._get_full_table_name(table)
    pk_select = ", ".join(pk_cols)
    sql = f"""
    SELECT {pk_select}, content_hash
    FROM {full_table}
    WHERE {time_col} >= %s AND {time_col} < %s
    """
    result = {}
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (window_start, window_end))
            for row in cur.fetchall():
                pk = tuple(row[:-1])
                content_hash = row[-1]
                # NULL hashes are normalized to "" for comparison.
                result[pk] = content_hash or ""
    except Exception as e:
        # Roll back on failure so later queries are not affected.
        self.logger.warning("获取 ODS hash 失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"获取 ODS hash 失败: {table}") from e
    return result
def _fetch_hashes_from_db_by_keys(self, table: str, keys: Set[Tuple]) -> Dict[Tuple, str]:
    """Look up comparison hashes in ODS for exactly *keys* (no time window).

    When the table stores the raw payload, the hash is recomputed from it
    with the ODS ingestion algorithm; otherwise the stored content_hash
    column is used as-is.

    Raises:
        VerificationFetchError: when the lookup fails (after rollback).
    """
    if not keys:
        return {}
    pk_cols = self.get_primary_keys(table)
    if not pk_cols:
        self.logger.debug("%s 没有主键配置,跳过按主键反查 hash", table)
        return {}
    full_table = self._get_full_table_name(table)
    has_payload = self._table_has_column(full_table, "payload")
    select_tail = 't."payload"' if has_payload else 't."content_hash"'
    select_cols = ", ".join([*(f't."{c}"' for c in pk_cols), select_tail])
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {full_table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    result: Dict[Tuple, str] = {}
    try:
        with self.db.conn.cursor() as cur:
            for chunk in self._chunked(list(keys), 500):
                # One statement per chunk so fetchall() returns all rows.
                execute_values(cur, sql, chunk, page_size=len(chunk))
                for row in cur.fetchall():
                    pk = tuple(row[:-1])
                    tail_value = row[-1]
                    if has_payload:
                        # Recompute from payload to match source-side hashing.
                        compare_hash = self._compute_compare_hash_from_payload(tail_value)
                        result[pk] = compare_hash or ""
                    else:
                        result[pk] = tail_value or ""
    except Exception as e:
        self.logger.warning("按主键反查 ODS hash 失败: %s, error=%s", table, e)
        try:
            self.db.conn.rollback()
        except Exception:
            pass
        raise VerificationFetchError(f"按主键反查 ODS hash 失败: {table}") from e
    return result
@staticmethod
def _chunked(items: List[Tuple], chunk_size: int) -> List[List[Tuple]]:
"""将列表按固定大小分块"""
if chunk_size <= 0:
return [items]
return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]
def backfill_missing(
self,
table: str,
missing_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量补齐缺失数据
ODS 层补齐需要重新从 API 获取数据
"""
if not self._has_external_source():
self.logger.warning("未配置 API/本地JSON 源,无法补齐 ODS 缺失数据")
return 0
if not missing_keys:
return 0
# 获取表配置
config = self._table_config.get(table, {})
task_code = config.get("task_code")
if not task_code:
self.logger.warning("未找到表 %s 的任务配置,跳过补齐", table)
return 0
self.logger.info(
"ODS 补齐缺失: 表=%s, 数量=%d, 任务=%s",
table, len(missing_keys), task_code
)
# ODS 层的补齐实际上是重新执行 ODS 任务从 API 获取数据
# 但由于 ODS 任务已经在 "校验前先从 API 获取数据" 步骤执行过了,
# 这里补齐失败是预期的(数据已经在 ODS 表中,只是校验窗口可能不一致)
#
# 实际的 ODS 补齐应该在 verify_only 模式下启用 fetch_before_verify 选项,
# 这会先执行 ODS 任务获取 API 数据,然后再校验。
#
# 如果仍然有缺失,说明:
# 1. API 返回的数据时间窗口与校验窗口不完全匹配
# 2. 或者 ODS 任务的时间参数配置问题
self.logger.info(
"ODS 补齐提示: 表=%s%d 条缺失记录,建议使用 '校验前先从 API 获取数据' 选项获取完整数据",
table, len(missing_keys)
)
return 0
def backfill_mismatch(
self,
table: str,
mismatch_keys: Set[Tuple],
window_start: datetime,
window_end: datetime,
) -> int:
"""
批量更新不一致数据
ODS 层更新也需要重新从 API 获取
"""
# 与 backfill_missing 类似,重新获取数据会自动 UPSERT
return self.backfill_missing(table, mismatch_keys, window_start, window_end)
def _has_external_source(self) -> bool:
return bool(self.fetch_from_api and (self.api_client or self.use_local_json))
def _is_using_local_json(self, task_code: str) -> bool:
return bool(self.use_local_json and task_code in self.local_dump_dirs)
def _get_local_json_client(self, task_code: str) -> Optional[LocalJsonClient]:
if task_code in self._local_json_clients:
return self._local_json_clients[task_code]
dump_dir = self.local_dump_dirs.get(task_code)
if not dump_dir:
return None
try:
client = LocalJsonClient(dump_dir)
except Exception as exc: # noqa: BLE001
self.logger.warning(
"本地 JSON 目录不可用: task=%s, dir=%s, error=%s",
task_code, dump_dir, exc,
)
return None
self._local_json_clients[task_code] = client
return client
def _get_source_client(self, task_code: str):
if self.use_local_json:
return self._get_local_json_client(task_code)
return self.api_client
def verify_against_api(
self,
table: str,
window_start: datetime,
window_end: datetime,
auto_backfill: bool = False,
) -> Dict[str, Any]:
"""
与 API 源数据对比校验
这是更严格的校验,直接调用 API 获取数据进行对比
"""
if not self.api_client:
return {"error": "未配置 API 客户端"}
config = self._table_config.get(table, {})
endpoint = config.get("api_endpoint")
if not endpoint:
return {"error": f"未找到表 {table} 的 API 端点配置"}
self.logger.info("开始与 API 对比校验: %s", table)
# 1. 从 API 获取数据
try:
api_records = self.api_client.fetch_records(
endpoint=endpoint,
start_time=window_start,
end_time=window_end,
)
except Exception as e:
return {"error": f"API 调用失败: {e}"}
# 2. 从 ODS 获取数据
ods_hashes = self.fetch_target_hashes(table, window_start, window_end)
# 3. 计算 API 数据的 hash
pk_cols = self.get_primary_keys(table)
api_hashes = {}
for record in api_records:
pk = tuple(record.get(col) for col in pk_cols)
content_hash = self._compute_hash(record)
api_hashes[pk] = content_hash
# 4. 对比
api_keys = set(api_hashes.keys())
ods_keys = set(ods_hashes.keys())
missing = api_keys - ods_keys
extra = ods_keys - api_keys
mismatch = {
k for k in (api_keys & ods_keys)
if api_hashes[k] != ods_hashes[k]
}
result = {
"table": table,
"api_count": len(api_hashes),
"ods_count": len(ods_hashes),
"missing_count": len(missing),
"extra_count": len(extra),
"mismatch_count": len(mismatch),
"is_consistent": len(missing) == 0 and len(mismatch) == 0,
}
# 5. 自动补齐
if auto_backfill and (missing or mismatch):
# 需要重新获取的主键
keys_to_refetch = missing | mismatch
# 筛选需要重新插入的记录
records_to_upsert = [
r for r in api_records
if tuple(r.get(col) for col in pk_cols) in keys_to_refetch
]
if records_to_upsert:
backfilled = self._batch_upsert(table, records_to_upsert)
result["backfilled_count"] = backfilled
return result
def _table_has_column(self, full_table: str, column: str) -> bool:
    """Return True when *full_table* has *column* (cached per table/column).

    Query failures are treated as "column absent" after rolling back, so
    verification degrades gracefully instead of raising here.
    """
    cache_key = (full_table, column)
    if cache_key in self._table_column_cache:
        return self._table_column_cache[cache_key]
    # Split "schema.table"; default to public when unqualified.
    schema = "public"
    table = full_table
    if "." in full_table:
        schema, table = full_table.split(".", 1)
    sql = """
    SELECT 1
    FROM information_schema.columns
    WHERE table_schema = %s AND table_name = %s AND column_name = %s
    LIMIT 1
    """
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (schema, table, column))
            exists = cur.fetchone() is not None
    except Exception:
        exists = False
        try:
            self.db.conn.rollback()
        except Exception:
            pass
    self._table_column_cache[cache_key] = exists
    return exists
def _get_db_primary_keys(self, full_table: str) -> List[str]:
    """Read primary key columns from database metadata (ordered).

    Results are cached per table; failures yield an empty list after a
    rollback, which callers treat as "no usable primary key".
    """
    if full_table in self._table_pk_cache:
        return self._table_pk_cache[full_table]
    # Split "schema.table"; default to public when unqualified.
    schema = "public"
    table = full_table
    if "." in full_table:
        schema, table = full_table.split(".", 1)
    sql = """
    SELECT kcu.column_name
    FROM information_schema.table_constraints tc
    JOIN information_schema.key_column_usage kcu
    ON tc.constraint_name = kcu.constraint_name
    AND tc.table_schema = kcu.table_schema
    AND tc.table_name = kcu.table_name
    WHERE tc.table_schema = %s
    AND tc.table_name = %s
    AND tc.constraint_type = 'PRIMARY KEY'
    ORDER BY kcu.ordinal_position
    """
    try:
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (schema, table))
            rows = cur.fetchall()
            # Rows may be tuples or dicts depending on the cursor factory.
            cols = [r[0] if not isinstance(r, dict) else r.get("column_name") for r in rows]
            result = [c for c in cols if c]
    except Exception:
        result = []
        try:
            self.db.conn.rollback()
        except Exception:
            pass
    self._table_pk_cache[full_table] = result
    return result
def _compute_compare_hash_from_payload(self, payload: Any) -> Optional[str]:
"""使用 ODS 任务的算法计算对比哈希"""
try:
from tasks.ods.ods_tasks import BaseOdsTask
return BaseOdsTask._compute_compare_hash_from_payload(payload)
except Exception:
return None
def _compute_hash(self, record: dict) -> str:
"""计算记录的对比哈希(与 ODS 入库一致,不包含 fetched_at"""
compare_hash = self._compute_compare_hash_from_payload(record)
return compare_hash or ""
def _batch_upsert(self, table: str, records: List[dict]) -> int:
"""Batch backfill in snapshot-safe mode (insert-only on PK conflict)."""
if not records:
return 0
full_table = self._get_full_table_name(table)
db_pk_cols = self._get_db_primary_keys(full_table)
if not db_pk_cols:
self.logger.warning("%s 未找到主键,跳过回填", full_table)
return 0
has_content_hash_col = self._table_has_column(full_table, "content_hash")
# 获取所有列(从第一条记录),并在存在 content_hash 列时补齐该列。
all_cols = list(records[0].keys())
if has_content_hash_col and "content_hash" not in all_cols:
all_cols.append("content_hash")
# Snapshot-safe strategy: never update historical rows; only insert new snapshots.
col_list = ", ".join(all_cols)
placeholders = ", ".join(["%s"] * len(all_cols))
pk_list = ", ".join(db_pk_cols)
sql = f"""
INSERT INTO {full_table} ({col_list})
VALUES ({placeholders})
ON CONFLICT ({pk_list}) DO NOTHING
"""
count = 0
with self.db.conn.cursor() as cur:
for record in records:
row = dict(record)
if has_content_hash_col:
row["content_hash"] = self._compute_hash(record)
values = [row.get(col) for col in all_cols]
try:
cur.execute(sql, values)
affected = int(cur.rowcount or 0)
if affected > 0:
count += affected
except Exception as e:
self.logger.warning("UPSERT 失败: %s, error=%s", record, e)
self.db.commit()
return count