feiqiu-ETL/etl_billiards/scripts/backfill_missing_data.py
# -*- coding: utf-8 -*-
"""
补全丢失的 ODS 数据
通过运行数据完整性检查,找出 API 与 ODS 之间的差异,
然后重新从 API 获取丢失的数据并插入 ODS。
用法:
python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
get_value_case_insensitive,
merge_record_layers,
normalize_pk_value,
pk_tuple_from_record,
)
def _reconfigure_stdout_utf8() -> None:
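    """Best-effort switch of stdout to UTF-8 (helps when the console encoding is not UTF-8)."""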
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
raw = (value or "").strip()
if not raw:
raise ValueError("empty datetime")
has_time = any(ch in raw for ch in (":", "T"))
dt = dtparser.parse(raw)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=tz)
else:
dt = dt.astimezone(tz)
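    # A date-only value (no ":" or "T") expands to the start of that day,
    # or to 23:59:59 when is_end=True.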
if not has_time:
dt = dt.replace(
hour=23 if is_end else 0,
minute=59 if is_end else 0,
second=59 if is_end else 0,
microsecond=0
)
return dt
def _get_spec(code: str) -> Optional[OdsTaskSpec]:
"""根据任务代码获取 ODS 任务规格"""
for spec in ODS_TASK_SPECS:
if spec.code == code:
return spec
return None
def _merge_record_layers(record: dict) -> dict:
"""Flatten nested data layers into a single dict."""
return merge_record_layers(record)
def _get_value_case_insensitive(record: dict | None, col: str | None):
"""Fetch value without case sensitivity."""
return get_value_case_insensitive(record, col)
def _normalize_pk_value(value):
"""Normalize PK value."""
return normalize_pk_value(value)
def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
"""Extract PK tuple from record."""
return pk_tuple_from_record(record, pk_cols)
def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
"""获取表的主键列"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [r[0] for r in cur.fetchall()]
if include_content_hash:
return cols
return [c for c in cols if c.lower() != "content_hash"]
def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
"""获取表的所有列信息"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
def _fetch_existing_pk_set(
conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
"""获取已存在的 PK 集合"""
if not pk_values:
return set()
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
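    # Probe for existing rows by joining the target table against a VALUES list
    # of PK tuples, sent in chunks via execute_values to keep statements bounded.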
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: Set[Tuple] = set()
with conn.cursor() as cur:
for i in range(0, len(pk_values), chunk_size):
chunk = pk_values[i:i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
def _cast_value(value, data_type: str):
"""类型转换"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _normalize_scalar(value):
"""规范化标量值"""
if value == "" or value == "{}" or value == "[]":
return None
return value
class MissingDataBackfiller:
"""丢失数据补全器"""
def __init__(
self,
cfg: AppConfig,
logger: logging.Logger,
dry_run: bool = False,
):
self.cfg = cfg
self.logger = logger
self.dry_run = dry_run
self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
self.store_id = int(cfg.get("app.store_id") or 0)
        # API client
self.api = APIClient(
base_url=cfg["api"]["base_url"],
token=cfg["api"]["token"],
timeout=int(cfg["api"].get("timeout_sec") or 20),
retry_max=int(cfg["api"].get("retries", {}).get("max_attempts") or 3),
headers_extra=cfg["api"].get("headers_extra") or {},
)
        # Database connection (DatabaseConnection sets autocommit=False at construction time)
self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
def close(self):
"""关闭连接"""
if self.db:
self.db.close()
def _ensure_db(self):
"""确保数据库连接可用"""
if self.db and getattr(self.db, "conn", None) is not None:
if getattr(self.db.conn, "closed", 0) == 0:
return
self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))
def backfill_from_gap_check(
self,
*,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
include_mismatch: bool = False,
page_size: int = 200,
chunk_size: int = 500,
content_sample_limit: int | None = None,
) -> Dict[str, Any]:
"""
运行 gap check 并补全丢失数据
Returns:
补全结果统计
"""
self.logger.info("数据补全开始 起始=%s 结束=%s", start.isoformat(), end.isoformat())
# 计算窗口大小
total_seconds = max(0, int((end - start).total_seconds()))
if total_seconds >= 86400:
window_days = max(1, total_seconds // 86400)
window_hours = 0
else:
window_days = 0
window_hours = max(1, total_seconds // 3600 or 1)
        # Run the gap check
        self.logger.info("Running gap check...")
gap_result = run_gap_check(
cfg=self.cfg,
start=start,
end=end,
window_days=window_days,
window_hours=window_hours,
page_size=page_size,
chunk_size=chunk_size,
            sample_limit=10000,  # fetch all missing samples
sleep_per_window=0,
sleep_per_page=0,
task_codes=task_codes or "",
from_cutoff=False,
cutoff_overlap_hours=24,
allow_small_window=True,
logger=self.logger,
compare_content=include_mismatch,
content_sample_limit=content_sample_limit or 10000,
)
total_missing = gap_result.get("total_missing", 0)
total_mismatch = gap_result.get("total_mismatch", 0)
if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
self.logger.info("Data complete: no missing/mismatch records")
return {"backfilled": 0, "errors": 0, "details": []}
if include_mismatch:
self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
else:
self.logger.info("Missing check done missing=%s", total_missing)
results = []
total_backfilled = 0
total_errors = 0
for task_result in gap_result.get("results", []):
task_code = task_result.get("task_code")
missing = task_result.get("missing", 0)
missing_samples = task_result.get("missing_samples", [])
mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
target_samples = list(missing_samples) + list(mismatch_samples)
if missing == 0 and mismatch == 0:
continue
self.logger.info(
"Start backfill task task=%s missing=%s mismatch=%s samples=%s",
task_code, missing, mismatch, len(target_samples)
)
try:
backfilled = self._backfill_task(
task_code=task_code,
table=task_result.get("table"),
pk_columns=task_result.get("pk_columns", []),
pk_samples=target_samples,
start=start,
end=end,
page_size=page_size,
chunk_size=chunk_size,
)
results.append({
"task_code": task_code,
"missing": missing,
"mismatch": mismatch,
"backfilled": backfilled,
"error": None,
})
total_backfilled += backfilled
except Exception as exc:
                self.logger.exception("Backfill failed task=%s", task_code)
results.append({
"task_code": task_code,
"missing": missing,
"mismatch": mismatch,
"backfilled": 0,
"error": str(exc),
})
total_errors += 1
self.logger.info(
"数据补全完成 总缺失=%s 已补全=%s 错误数=%s",
total_missing, total_backfilled, total_errors
)
return {
"total_missing": total_missing,
"total_mismatch": total_mismatch,
"backfilled": total_backfilled,
"errors": total_errors,
"details": results,
}
def _backfill_task(
self,
*,
task_code: str,
table: str,
pk_columns: List[str],
pk_samples: List[Dict],
start: datetime,
end: datetime,
page_size: int,
chunk_size: int,
) -> int:
"""补全单个任务的丢失数据"""
self._ensure_db()
spec = _get_spec(task_code)
if not spec:
self.logger.warning("未找到任务规格 任务=%s", task_code)
return 0
if not pk_columns:
pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)
conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
if not conflict_columns:
conflict_columns = pk_columns
if not pk_columns:
self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
return 0
        # Collect the missing PK tuples
missing_pks: Set[Tuple] = set()
for sample in pk_samples:
            # Normalize sample PK values the same way as _pk_tuple_from_record so
            # set-membership comparisons line up.
            pk_tuple = tuple(_normalize_pk_value(sample.get(col)) for col in pk_columns)
if all(v is not None for v in pk_tuple):
missing_pks.add(pk_tuple)
if not missing_pks:
self.logger.info("无缺失主键 任务=%s", task_code)
return 0
self.logger.info(
"开始获取数据 任务=%s 缺失主键数=%s",
task_code, len(missing_pks)
)
        # Fetch from the API and filter down to the missing records
params = self._build_params(spec, start, end)
backfilled = 0
cols_info = _get_table_columns(self.db.conn, table)
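        # Columns typed json/jsonb must be wrapped with psycopg2 Json when inserted.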
db_json_cols_lower = {
c[0].lower() for c in cols_info
if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
col_names = [c[0] for c in cols_info]
        # End the read-only transaction so a long API fetch does not hit the idle_in_transaction timeout
try:
self.db.conn.commit()
except Exception:
self.db.conn.rollback()
try:
for page_no, records, _, response_payload in self.api.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
                # Keep only records whose PK is in the missing set
records_to_insert = []
for rec in records:
if not isinstance(rec, dict):
continue
pk_tuple = _pk_tuple_from_record(rec, pk_columns)
if pk_tuple and pk_tuple in missing_pks:
records_to_insert.append(rec)
if not records_to_insert:
continue
                # Insert the missing records
if self.dry_run:
backfilled += len(records_to_insert)
self.logger.info(
"模拟运行 任务=%s 页=%s 将插入=%s",
task_code, page_no, len(records_to_insert)
)
else:
inserted = self._insert_records(
table=table,
records=records_to_insert,
cols_info=cols_info,
pk_columns=pk_columns,
conflict_columns=conflict_columns,
db_json_cols_lower=db_json_cols_lower,
)
backfilled += inserted
                    # Commit per page to avoid long transactions and idle_in_transaction timeouts
self.db.conn.commit()
self.logger.info(
"已插入 任务=%s 页=%s 数量=%s",
task_code, page_no, inserted
)
if not self.dry_run:
self.db.conn.commit()
self.logger.info("任务补全完成 任务=%s 已补全=%s", task_code, backfilled)
return backfilled
except Exception:
self.db.conn.rollback()
raise
def _build_params(
self,
spec: OdsTaskSpec,
start: datetime,
end: datetime,
) -> Dict:
"""构建 API 请求参数"""
base: Dict[str, Any] = {}
if spec.include_site_id:
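            # Most endpoints take a scalar siteId; GetGoodsInventoryList expects a list.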
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [self.store_id]
else:
base["siteId"] = self.store_id
if spec.requires_window and spec.time_fields:
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(start, self.tz)
base[end_key] = TypeParser.format_timestamp(end, self.tz)
        # Merge in shared API parameters from config
common = self.cfg.get("api.params", {}) or {}
if isinstance(common, dict):
merged = {**common, **base}
else:
merged = base
merged.update(spec.extra_params or {})
return merged
def _insert_records(
self,
*,
table: str,
records: List[Dict],
cols_info: List[Tuple[str, str, str]],
pk_columns: List[str],
conflict_columns: List[str],
db_json_cols_lower: Set[str],
) -> int:
"""插入记录到数据库"""
if not records:
return 0
col_names = [c[0] for c in cols_info]
needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
quoted_cols = ", ".join(f'"{c}"' for c in col_names)
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
conflict_cols = conflict_columns or pk_columns
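        # ON CONFLICT ... DO NOTHING keeps the backfill idempotent: rows that already
        # exist with the same key are skipped rather than overwritten.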
if conflict_cols:
pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
now = datetime.now(self.tz)
json_dump = lambda v: json.dumps(v, ensure_ascii=False)
params: List[Tuple] = []
for rec in records:
merged_rec = _merge_record_layers(rec)
            # Skip records missing any primary-key value
if pk_columns:
missing_pk = False
for pk in pk_columns:
if str(pk).lower() == "content_hash":
continue
pk_val = _get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
break
if missing_pk:
continue
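            # content_hash is computed via BaseOdsTask's shared routine, with the
            # fetched_at timestamp folded into the hashed record.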
content_hash = None
if needs_content_hash:
hash_record = dict(merged_rec)
hash_record["fetched_at"] = now
content_hash = BaseOdsTask._compute_content_hash(hash_record, include_fetched_at=True)
row_vals: List[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append("backfill")
continue
if col_lower == "source_endpoint":
row_vals.append("backfill")
continue
if col_lower == "fetched_at":
row_vals.append(now)
continue
if col_lower == "content_hash":
row_vals.append(content_hash)
continue
value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
row_vals.append(_cast_value(value, data_type))
params.append(tuple(row_vals))
if not params:
return 0
inserted = 0
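        # Insert in chunks of 200 rows per statement; rowcount only counts rows
        # actually inserted (rows skipped by ON CONFLICT are excluded).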
with self.db.conn.cursor() as cur:
for i in range(0, len(params), 200):
chunk = params[i:i + 200]
execute_values(cur, sql, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
inserted += int(cur.rowcount)
return inserted
def run_backfill(
*,
cfg: AppConfig,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
include_mismatch: bool = False,
dry_run: bool = False,
page_size: int = 200,
chunk_size: int = 500,
content_sample_limit: int | None = None,
logger: logging.Logger,
) -> Dict[str, Any]:
"""
运行数据补全
Args:
cfg: 应用配置
start: 开始时间
end: 结束时间
task_codes: 指定任务代码(逗号分隔)
dry_run: 是否仅预览
page_size: API 分页大小
chunk_size: 数据库批量大小
logger: 日志记录器
Returns:
补全结果
"""
backfiller = MissingDataBackfiller(cfg, logger, dry_run)
try:
return backfiller.backfill_from_gap_check(
start=start,
end=end,
task_codes=task_codes,
include_mismatch=include_mismatch,
page_size=page_size,
chunk_size=chunk_size,
content_sample_limit=content_sample_limit,
)
finally:
backfiller.close()
def main() -> int:
_reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Backfill missing ODS data")
    ap.add_argument("--start", default="2025-07-01", help="Start date (default: 2025-07-01)")
    ap.add_argument("--end", default="", help="End date (default: now)")
    ap.add_argument("--task-codes", default="", help="Task codes, comma-separated (empty = all)")
    ap.add_argument("--include-mismatch", action="store_true", help="Also backfill records whose content mismatches")
    ap.add_argument("--content-sample-limit", type=int, default=None, help="Max mismatch samples (default: 10000)")
    ap.add_argument("--dry-run", action="store_true", help="Preview only; do not write to the database")
    ap.add_argument("--page-size", type=int, default=200, help="API page size (default: 200)")
    ap.add_argument("--chunk-size", type=int, default=500, help="Database batch size (default: 500)")
    ap.add_argument("--log-file", default="", help="Log file path")
    ap.add_argument("--log-dir", default="", help="Log directory")
    ap.add_argument("--log-level", default="INFO", help="Log level (default: INFO)")
    ap.add_argument("--no-log-console", action="store_true", help="Disable console logging")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
log_console = not args.no_log_console
with configure_logging(
"backfill_missing",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
start = _parse_dt(args.start, tz)
end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)
result = run_backfill(
cfg=cfg,
start=start,
end=end,
task_codes=args.task_codes or None,
include_mismatch=args.include_mismatch,
dry_run=args.dry_run,
page_size=args.page_size,
chunk_size=args.chunk_size,
content_sample_limit=args.content_sample_limit,
logger=logger,
)
logger.info("=" * 60)
logger.info("补全完成!")
logger.info(" 总丢失: %s", result.get("total_missing", 0))
if args.include_mismatch:
logger.info(" 总不一致: %s", result.get("total_mismatch", 0))
logger.info(" 已补全: %s", result.get("backfilled", 0))
logger.info(" 错误数: %s", result.get("errors", 0))
logger.info("=" * 60)
        # Log per-task details
for detail in result.get("details", []):
if detail.get("error"):
logger.error(
" %s: 丢失=%s 不一致=%s 补全=%s 错误=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("mismatch", 0),
detail.get("backfilled"),
detail.get("error"),
)
elif detail.get("backfilled", 0) > 0:
logger.info(
" %s: 丢失=%s 不一致=%s 补全=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("mismatch", 0),
detail.get("backfilled"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())