# -*- coding: utf-8 -*-
"""
Backfill missing ODS data.

Runs the data-integrity gap check to find differences between the API and ODS,
then re-fetches the missing records from the API and inserts them into ODS.

Usage:
    python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo

from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
    get_value_case_insensitive,
    merge_record_layers,
    normalize_pk_value,
    pk_tuple_from_record,
)


def _reconfigure_stdout_utf8() -> None:
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass


def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
    raw = (value or "").strip()
    if not raw:
        raise ValueError("empty datetime")
    has_time = any(ch in raw for ch in (":", "T"))
    dt = dtparser.parse(raw)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=tz)
    else:
        dt = dt.astimezone(tz)
    if not has_time:
        dt = dt.replace(
            hour=23 if is_end else 0,
            minute=59 if is_end else 0,
            second=59 if is_end else 0,
            microsecond=0,
        )
    return dt


def _get_spec(code: str) -> Optional[OdsTaskSpec]:
    """Return the ODS task spec for the given task code."""
    for spec in ODS_TASK_SPECS:
        if spec.code == code:
            return spec
    return None


def _merge_record_layers(record: dict) -> dict:
    """Flatten nested data layers into a single dict."""
    return merge_record_layers(record)


def _get_value_case_insensitive(record: dict | None, col: str | None):
    """Fetch value without case sensitivity."""
    return get_value_case_insensitive(record, col)


def _normalize_pk_value(value):
    """Normalize PK value."""
    return normalize_pk_value(value)


def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
    """Extract PK tuple from record."""
    return pk_tuple_from_record(record, pk_cols)


def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
    """Return the primary-key columns of a table."""
    if "." in table:
        schema, name = table.split(".", 1)
    else:
        schema, name = "public", table
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        cols = [r[0] for r in cur.fetchall()]
    if include_content_hash:
        return cols
    return [c for c in cols if c.lower() != "content_hash"]


def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
    """Return (column_name, data_type, udt_name) for every column of a table."""
    if "." in table:
        schema, name = table.split(".", 1)
    else:
        schema, name = "public", table
    sql = """
        SELECT column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]


def _fetch_existing_pk_set(
    conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
    """Return the subset of the given PK tuples that already exist in the table."""
    if not pk_values:
        return set()
    select_cols = ", ".join(f't."{c}"' for c in pk_cols)
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    existing: Set[Tuple] = set()
    with conn.cursor() as cur:
        for i in range(0, len(pk_values), chunk_size):
            chunk = pk_values[i:i + chunk_size]
            execute_values(cur, sql, chunk, page_size=len(chunk))
            for row in cur.fetchall():
                existing.add(tuple(row))
    return existing


def _cast_value(value, data_type: str):
    """Coerce an API value to match the target column's data type."""
    if value is None:
        return None
    dt = (data_type or "").lower()
    if dt in ("integer", "bigint", "smallint"):
        if isinstance(value, bool):
            return int(value)
        try:
            return int(value)
        except Exception:
            return None
    if dt in ("numeric", "double precision", "real", "decimal"):
        if isinstance(value, bool):
            return int(value)
        try:
            return float(value)
        except Exception:
            return None
    if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
        return value if isinstance(value, (str, datetime)) else None
    return value


def _normalize_scalar(value):
    """Normalize a scalar: treat empty strings and empty containers as NULL."""
    if value == "" or value == "{}" or value == "[]":
        return None
    return value


class MissingDataBackfiller:
    """Backfills records that exist in the API but are missing from ODS."""

    def __init__(
        self,
        cfg: AppConfig,
        logger: logging.Logger,
        dry_run: bool = False,
    ):
        self.cfg = cfg
        self.logger = logger
        self.dry_run = dry_run
        self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
        self.store_id = int(cfg.get("app.store_id") or 0)

        # API client
        self.api = APIClient(
            base_url=cfg["api"]["base_url"],
            token=cfg["api"]["token"],
            timeout=int(cfg["api"].get("timeout_sec") or 20),
            retry_max=int(cfg["api"].get("retries", {}).get("max_attempts") or 3),
            headers_extra=cfg["api"].get("headers_extra") or {},
        )

        # Database connection (DatabaseConnection sets autocommit=False at construction)
        self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))

    def close(self):
        """Close the database connection."""
        if self.db:
            self.db.close()

    def _ensure_db(self):
        """Ensure the database connection is usable, reconnecting if needed."""
        if self.db and getattr(self.db, "conn", None) is not None:
            if getattr(self.db.conn, "closed", 0) == 0:
                return
        self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))

    def backfill_from_gap_check(
        self,
        *,
        start: datetime,
        end: datetime,
        task_codes: Optional[str] = None,
        include_mismatch: bool = False,
        page_size: int = 200,
        chunk_size: int = 500,
        content_sample_limit: int | None = None,
    ) -> Dict[str, Any]:
        """
        Run the gap check and backfill the missing data.

        Returns:
            Backfill result statistics.
        """
        self.logger.info("Backfill started start=%s end=%s", start.isoformat(), end.isoformat())

        # Derive the gap-check window size from the requested range
        total_seconds = max(0, int((end - start).total_seconds()))
        if total_seconds >= 86400:
            window_days = max(1, total_seconds // 86400)
            window_hours = 0
        else:
            window_days = 0
            window_hours = max(1, total_seconds // 3600 or 1)

        # Run the gap check
        self.logger.info("Running gap check...")
        gap_result = run_gap_check(
            cfg=self.cfg,
            start=start,
            end=end,
            window_days=window_days,
            window_hours=window_hours,
            page_size=page_size,
            chunk_size=chunk_size,
            sample_limit=10000,  # collect all missing samples
            sleep_per_window=0,
            sleep_per_page=0,
            task_codes=task_codes or "",
            from_cutoff=False,
            cutoff_overlap_hours=24,
            allow_small_window=True,
            logger=self.logger,
            compare_content=include_mismatch,
            content_sample_limit=content_sample_limit or 10000,
        )

        total_missing = gap_result.get("total_missing", 0)
        total_mismatch = gap_result.get("total_mismatch", 0)
        if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
            self.logger.info("Data complete: no missing/mismatch records")
            return {"backfilled": 0, "errors": 0, "details": []}

        if include_mismatch:
            self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
        else:
            self.logger.info("Missing check done missing=%s", total_missing)

        results = []
        total_backfilled = 0
        total_errors = 0

        for task_result in gap_result.get("results", []):
            task_code = task_result.get("task_code")
            missing = task_result.get("missing", 0)
            missing_samples = task_result.get("missing_samples", [])
            mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
            mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
            target_samples = list(missing_samples) + list(mismatch_samples)

            if missing == 0 and mismatch == 0:
                continue

            self.logger.info(
                "Start backfill task task=%s missing=%s mismatch=%s samples=%s",
                task_code, missing, mismatch, len(target_samples)
            )

            try:
                backfilled = self._backfill_task(
                    task_code=task_code,
                    table=task_result.get("table"),
                    pk_columns=task_result.get("pk_columns", []),
                    pk_samples=target_samples,
                    start=start,
                    end=end,
                    page_size=page_size,
                    chunk_size=chunk_size,
                )
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": backfilled,
                    "error": None,
                })
                total_backfilled += backfilled
            except Exception as exc:
                self.logger.exception("Backfill failed task=%s", task_code)
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": 0,
                    "error": str(exc),
                })
                total_errors += 1

        self.logger.info(
            "Backfill finished total_missing=%s backfilled=%s errors=%s",
            total_missing, total_backfilled, total_errors
        )

        return {
            "total_missing": total_missing,
            "total_mismatch": total_mismatch,
            "backfilled": total_backfilled,
            "errors": total_errors,
            "details": results,
        }

    def _backfill_task(
        self,
        *,
        task_code: str,
        table: str,
        pk_columns: List[str],
        pk_samples: List[Dict],
        start: datetime,
        end: datetime,
        page_size: int,
        chunk_size: int,
    ) -> int:
        """Backfill the missing records of a single task."""
        self._ensure_db()

        spec = _get_spec(task_code)
        if not spec:
            self.logger.warning("Task spec not found task=%s", task_code)
            return 0

        if not pk_columns:
            pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)
        conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
        if not conflict_columns:
            conflict_columns = pk_columns
        if not pk_columns:
            self.logger.warning("Primary key columns not found task=%s table=%s", task_code, table)
            return 0

        # Extract the missing PK values
        missing_pks: Set[Tuple] = set()
        for sample in pk_samples:
            pk_tuple = tuple(sample.get(col) for col in pk_columns)
            if all(v is not None for v in pk_tuple):
                missing_pks.add(pk_tuple)

        if not missing_pks:
            self.logger.info("No missing primary keys task=%s", task_code)
            return 0

        self.logger.info(
            "Fetching data task=%s missing_pks=%s", task_code, len(missing_pks)
        )

        # Fetch from the API and filter down to the missing records
        params = self._build_params(spec, start, end)
        backfilled = 0

        cols_info = _get_table_columns(self.db.conn, table)
        db_json_cols_lower = {
            c[0].lower() for c in cols_info
            if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }
        col_names = [c[0] for c in cols_info]

        # End the read-only transaction so a long API fetch does not hit the idle_in_tx timeout
        try:
            self.db.conn.commit()
        except Exception:
            self.db.conn.rollback()

        try:
            for page_no, records, _, response_payload in self.api.iter_paginated(
                endpoint=spec.endpoint,
                params=params,
                page_size=page_size,
                data_path=spec.data_path,
                list_key=spec.list_key,
            ):
                # Keep only the records whose PK is in the missing set
                records_to_insert = []
                for rec in records:
                    if not isinstance(rec, dict):
                        continue
                    pk_tuple = _pk_tuple_from_record(rec, pk_columns)
                    if pk_tuple and pk_tuple in missing_pks:
                        records_to_insert.append(rec)

                if not records_to_insert:
                    continue

                # Insert the missing records
                if self.dry_run:
                    backfilled += len(records_to_insert)
                    self.logger.info(
                        "Dry run task=%s page=%s would_insert=%s",
                        task_code, page_no, len(records_to_insert)
                    )
                else:
                    inserted = self._insert_records(
                        table=table,
                        records=records_to_insert,
                        cols_info=cols_info,
                        pk_columns=pk_columns,
                        conflict_columns=conflict_columns,
                        db_json_cols_lower=db_json_cols_lower,
                    )
                    backfilled += inserted
                    # Commit per page to avoid long transactions and idle_in_tx timeouts
                    self.db.conn.commit()
                    self.logger.info(
                        "Inserted task=%s page=%s count=%s",
                        task_code, page_no, inserted
                    )

            if not self.dry_run:
                self.db.conn.commit()

            self.logger.info("Task backfill done task=%s backfilled=%s", task_code, backfilled)
            return backfilled
        except Exception:
            self.db.conn.rollback()
            raise

    def _build_params(
        self,
        spec: OdsTaskSpec,
        start: datetime,
        end: datetime,
    ) -> Dict:
        """Build the API request parameters."""
        base: Dict[str, Any] = {}
        if spec.include_site_id:
            if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
                base["siteId"] = [self.store_id]
            else:
                base["siteId"] = self.store_id
        if spec.requires_window and spec.time_fields:
            start_key, end_key = spec.time_fields
            base[start_key] = TypeParser.format_timestamp(start, self.tz)
            base[end_key] = TypeParser.format_timestamp(end, self.tz)

        # Merge the shared parameters from config
        common = self.cfg.get("api.params", {}) or {}
        if isinstance(common, dict):
            merged = {**common, **base}
        else:
            merged = base
        merged.update(spec.extra_params or {})
        return merged

    def _insert_records(
        self,
        *,
        table: str,
        records: List[Dict],
        cols_info: List[Tuple[str, str, str]],
        pk_columns: List[str],
        conflict_columns: List[str],
        db_json_cols_lower: Set[str],
    ) -> int:
        """Insert records into the database."""
        if not records:
            return 0

        col_names = [c[0] for c in cols_info]
        needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
        quoted_cols = ", ".join(f'"{c}"' for c in col_names)
        sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
        conflict_cols = conflict_columns or pk_columns
        if conflict_cols:
            pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
            sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"

        now = datetime.now(self.tz)
        json_dump = lambda v: json.dumps(v, ensure_ascii=False)

        params: List[Tuple] = []
        for rec in records:
            merged_rec = _merge_record_layers(rec)

            # Skip records that are missing any PK value
            if pk_columns:
                missing_pk = False
                for pk in pk_columns:
                    if str(pk).lower() == "content_hash":
                        continue
                    pk_val = _get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
                        break
                if missing_pk:
                    continue

            content_hash = None
            if needs_content_hash:
                hash_record = dict(merged_rec)
                hash_record["fetched_at"] = now
                content_hash = BaseOdsTask._compute_content_hash(hash_record, include_fetched_at=True)

            row_vals: List[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
                if col_lower == "payload":
                    row_vals.append(Json(rec, dumps=json_dump))
                    continue
                if col_lower == "source_file":
                    row_vals.append("backfill")
                    continue
                if col_lower == "source_endpoint":
                    row_vals.append("backfill")
                    continue
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue
                value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
                    row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                    continue
                row_vals.append(_cast_value(value, data_type))
            params.append(tuple(row_vals))

        if not params:
            return 0

        inserted = 0
        with self.db.conn.cursor() as cur:
            for i in range(0, len(params), 200):
                chunk = params[i:i + 200]
                execute_values(cur, sql, chunk, page_size=len(chunk))
                if cur.rowcount is not None and cur.rowcount > 0:
                    inserted += int(cur.rowcount)
        return inserted


def run_backfill(
    *,
    cfg: AppConfig,
    start: datetime,
    end: datetime,
    task_codes: Optional[str] = None,
    include_mismatch: bool = False,
    dry_run: bool = False,
    page_size: int = 200,
    chunk_size: int = 500,
    content_sample_limit: int | None = None,
    logger: logging.Logger,
) -> Dict[str, Any]:
    """
    Run the data backfill.

    Args:
        cfg: Application config.
        start: Start datetime.
        end: End datetime.
        task_codes: Task codes to run (comma separated).
        include_mismatch: Also backfill records whose content differs.
        dry_run: Preview only, do not write.
        page_size: API page size.
        chunk_size: Database batch size.
        content_sample_limit: Cap on mismatch samples.
        logger: Logger.

    Returns:
        Backfill result.
    """
    backfiller = MissingDataBackfiller(cfg, logger, dry_run)
    try:
        return backfiller.backfill_from_gap_check(
            start=start,
            end=end,
            task_codes=task_codes,
            include_mismatch=include_mismatch,
            page_size=page_size,
            chunk_size=chunk_size,
            content_sample_limit=content_sample_limit,
        )
    finally:
        backfiller.close()


def main() -> int:
    _reconfigure_stdout_utf8()

    ap = argparse.ArgumentParser(description="Backfill missing ODS data")
    ap.add_argument("--start", default="2025-07-01", help="Start date (default: 2025-07-01)")
    ap.add_argument("--end", default="", help="End date (default: now)")
    ap.add_argument("--task-codes", default="", help="Task codes, comma separated (empty = all)")
    ap.add_argument("--include-mismatch", action="store_true", help="Also backfill records with mismatched content")
    ap.add_argument("--content-sample-limit", type=int, default=None, help="Cap on mismatch samples (default: 10000)")
    ap.add_argument("--dry-run", action="store_true", help="Preview only, do not write")
    ap.add_argument("--page-size", type=int, default=200, help="API page size (default: 200)")
    ap.add_argument("--chunk-size", type=int, default=500, help="Database batch size (default: 500)")
    ap.add_argument("--log-file", default="", help="Log file path")
    ap.add_argument("--log-dir", default="", help="Log directory")
    ap.add_argument("--log-level", default="INFO", help="Log level (default: INFO)")
    ap.add_argument("--no-log-console", action="store_true", help="Disable console logging")
    args = ap.parse_args()

    log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
    log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
    log_console = not args.no_log_console

    with configure_logging(
        "backfill_missing",
        log_file,
        level=args.log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg = AppConfig.load({})
        tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
        start = _parse_dt(args.start, tz)
        end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)

        result = run_backfill(
            cfg=cfg,
            start=start,
            end=end,
            task_codes=args.task_codes or None,
            include_mismatch=args.include_mismatch,
            dry_run=args.dry_run,
            page_size=args.page_size,
            chunk_size=args.chunk_size,
            content_sample_limit=args.content_sample_limit,
            logger=logger,
        )

        logger.info("=" * 60)
        logger.info("Backfill complete!")
        logger.info("  Total missing: %s", result.get("total_missing", 0))
        if args.include_mismatch:
            logger.info("  Total mismatch: %s", result.get("total_mismatch", 0))
        logger.info("  Backfilled: %s", result.get("backfilled", 0))
        logger.info("  Errors: %s", result.get("errors", 0))
        logger.info("=" * 60)

        # Log per-task details
        for detail in result.get("details", []):
            if detail.get("error"):
                logger.error(
                    "  %s: missing=%s mismatch=%s backfilled=%s error=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                    detail.get("error"),
                )
            elif detail.get("backfilled", 0) > 0:
                logger.info(
                    "  %s: missing=%s mismatch=%s backfilled=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                )

    return 0


if __name__ == "__main__":
    raise SystemExit(main())