# -*- coding: utf-8 -*- """ 补全丢失的 ODS 数据 通过运行数据完整性检查,找出 API 与 ODS 之间的差异, 然后重新从 API 获取丢失的数据并插入 ODS。 用法: python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19 python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json """ from __future__ import annotations import argparse import json import logging import sys import time as time_mod from datetime import datetime, timedelta from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple from zoneinfo import ZoneInfo from dateutil import parser as dtparser from psycopg2.extras import Json, execute_values PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from api.client import APIClient from config.settings import AppConfig from database.connection import DatabaseConnection from models.parsers import TypeParser from tasks.ods_tasks import ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec from scripts.check_ods_gaps import run_gap_check from utils.logging_utils import build_log_path, configure_logging def _reconfigure_stdout_utf8() -> None: if hasattr(sys.stdout, "reconfigure"): try: sys.stdout.reconfigure(encoding="utf-8") except Exception: pass def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime: raw = (value or "").strip() if not raw: raise ValueError("empty datetime") has_time = any(ch in raw for ch in (":", "T")) dt = dtparser.parse(raw) if dt.tzinfo is None: dt = dt.replace(tzinfo=tz) else: dt = dt.astimezone(tz) if not has_time: dt = dt.replace( hour=23 if is_end else 0, minute=59 if is_end else 0, second=59 if is_end else 0, microsecond=0 ) return dt def _get_spec(code: str) -> Optional[OdsTaskSpec]: """根据任务代码获取 ODS 任务规格""" for spec in ODS_TASK_SPECS: if spec.code == code: return spec return None def _merge_record_layers(record: dict) -> dict: """展开嵌套的 data 层""" merged = record data_part = merged.get("data") while isinstance(data_part, dict): merged = {**data_part, **merged} data_part = data_part.get("data") settle_inner = merged.get("settleList") if isinstance(settle_inner, dict): merged = {**settle_inner, **merged} return merged def _get_value_case_insensitive(record: dict | None, col: str | None): """不区分大小写地获取值""" if record is None or col is None: return None if col in record: return record.get(col) col_lower = col.lower() for k, v in record.items(): if isinstance(k, str) and k.lower() == col_lower: return v return None def _normalize_pk_value(value): """规范化 PK 值""" if value is None: return None if isinstance(value, str) and value.isdigit(): try: return int(value) except Exception: return value return value def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]: """从记录中提取 PK 元组""" merged = _merge_record_layers(record) values = [] for col in pk_cols: val = _normalize_pk_value(_get_value_case_insensitive(merged, col)) if val is None or val == "": return None values.append(val) return tuple(values) def _get_table_pk_columns(conn, table: str) -> List[str]: """获取表的主键列""" if "." in table: schema, name = table.split(".", 1) else: schema, name = "public", table sql = """ SELECT kcu.column_name FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = %s AND tc.table_name = %s ORDER BY kcu.ordinal_position """ with conn.cursor() as cur: cur.execute(sql, (schema, name)) return [r[0] for r in cur.fetchall()] def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]: """获取表的所有列信息""" if "." in table: schema, name = table.split(".", 1) else: schema, name = "public", table sql = """ SELECT column_name, data_type, udt_name FROM information_schema.columns WHERE table_schema = %s AND table_name = %s ORDER BY ordinal_position """ with conn.cursor() as cur: cur.execute(sql, (schema, name)) return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()] def _fetch_existing_pk_set( conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int ) -> Set[Tuple]: """获取已存在的 PK 集合""" if not pk_values: return set() select_cols = ", ".join(f't."{c}"' for c in pk_cols) value_cols = ", ".join(f'"{c}"' for c in pk_cols) join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols) sql = ( f"SELECT {select_cols} FROM {table} t " f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}" ) existing: Set[Tuple] = set() with conn.cursor() as cur: for i in range(0, len(pk_values), chunk_size): chunk = pk_values[i:i + chunk_size] execute_values(cur, sql, chunk, page_size=len(chunk)) for row in cur.fetchall(): existing.add(tuple(row)) return existing def _cast_value(value, data_type: str): """类型转换""" if value is None: return None dt = (data_type or "").lower() if dt in ("integer", "bigint", "smallint"): if isinstance(value, bool): return int(value) try: return int(value) except Exception: return None if dt in ("numeric", "double precision", "real", "decimal"): if isinstance(value, bool): return int(value) try: return float(value) except Exception: return None if dt.startswith("timestamp") or dt in ("date", "time", "interval"): return value if isinstance(value, (str, datetime)) else None return value def _normalize_scalar(value): """规范化标量值""" if value == "" or value == "{}" or value == "[]": return None return value class MissingDataBackfiller: """丢失数据补全器""" def __init__( self, cfg: AppConfig, logger: logging.Logger, dry_run: bool = False, ): self.cfg = cfg self.logger = logger self.dry_run = dry_run self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei")) self.store_id = int(cfg.get("app.store_id") or 0) # API 客户端 self.api = APIClient( base_url=cfg["api"]["base_url"], token=cfg["api"]["token"], timeout=int(cfg["api"].get("timeout_sec") or 20), retry_max=int(cfg["api"].get("retries", {}).get("max_attempts") or 3), headers_extra=cfg["api"].get("headers_extra") or {}, ) # 数据库连接 self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session")) self.db.conn.autocommit = False def close(self): """关闭连接""" if self.db: self.db.close() def backfill_from_gap_check( self, *, start: datetime, end: datetime, task_codes: Optional[str] = None, page_size: int = 200, chunk_size: int = 500, ) -> Dict[str, Any]: """ 运行 gap check 并补全丢失数据 Returns: 补全结果统计 """ self.logger.info("BACKFILL_START start=%s end=%s", start.isoformat(), end.isoformat()) # 计算窗口大小 total_seconds = max(0, int((end - start).total_seconds())) if total_seconds >= 86400: window_days = max(1, total_seconds // 86400) window_hours = 0 else: window_days = 0 window_hours = max(1, total_seconds // 3600 or 1) # 运行 gap check self.logger.info("RUNNING_GAP_CHECK...") gap_result = run_gap_check( cfg=self.cfg, start=start, end=end, window_days=window_days, window_hours=window_hours, page_size=page_size, chunk_size=chunk_size, sample_limit=10000, # 获取所有丢失样本 sleep_per_window=0, sleep_per_page=0, task_codes=task_codes or "", from_cutoff=False, cutoff_overlap_hours=24, allow_small_window=True, logger=self.logger, ) total_missing = gap_result.get("total_missing", 0) if total_missing == 0: self.logger.info("NO_MISSING_DATA") return {"backfilled": 0, "errors": 0, "details": []} self.logger.info("GAP_CHECK_DONE total_missing=%s", total_missing) # 补全每个任务的丢失数据 results = [] total_backfilled = 0 total_errors = 0 for task_result in gap_result.get("results", []): task_code = task_result.get("task_code") missing = task_result.get("missing", 0) missing_samples = task_result.get("missing_samples", []) if missing == 0: continue self.logger.info( "BACKFILL_TASK task=%s missing=%s samples=%s", task_code, missing, len(missing_samples) ) try: backfilled = self._backfill_task( task_code=task_code, table=task_result.get("table"), pk_columns=task_result.get("pk_columns", []), missing_samples=missing_samples, start=start, end=end, page_size=page_size, chunk_size=chunk_size, ) results.append({ "task_code": task_code, "missing": missing, "backfilled": backfilled, "error": None, }) total_backfilled += backfilled except Exception as exc: self.logger.exception("BACKFILL_ERROR task=%s", task_code) results.append({ "task_code": task_code, "missing": missing, "backfilled": 0, "error": str(exc), }) total_errors += 1 self.logger.info( "BACKFILL_DONE total_missing=%s backfilled=%s errors=%s", total_missing, total_backfilled, total_errors ) return { "total_missing": total_missing, "backfilled": total_backfilled, "errors": total_errors, "details": results, } def _backfill_task( self, *, task_code: str, table: str, pk_columns: List[str], missing_samples: List[Dict], start: datetime, end: datetime, page_size: int, chunk_size: int, ) -> int: """补全单个任务的丢失数据""" spec = _get_spec(task_code) if not spec: self.logger.warning("SPEC_NOT_FOUND task=%s", task_code) return 0 if not pk_columns: pk_columns = _get_table_pk_columns(self.db.conn, table) if not pk_columns: self.logger.warning("NO_PK_COLUMNS task=%s table=%s", task_code, table) return 0 # 提取丢失的 PK 值 missing_pks: Set[Tuple] = set() for sample in missing_samples: pk_tuple = tuple(sample.get(col) for col in pk_columns) if all(v is not None for v in pk_tuple): missing_pks.add(pk_tuple) if not missing_pks: self.logger.info("NO_MISSING_PKS task=%s", task_code) return 0 self.logger.info( "BACKFILL_FETCHING task=%s missing_pks=%s", task_code, len(missing_pks) ) # 从 API 获取数据并过滤出丢失的记录 params = self._build_params(spec, start, end) backfilled = 0 cols_info = _get_table_columns(self.db.conn, table) db_json_cols_lower = { c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb") } col_names = [c[0] for c in cols_info] try: for page_no, records, _, response_payload in self.api.iter_paginated( endpoint=spec.endpoint, params=params, page_size=page_size, data_path=spec.data_path, list_key=spec.list_key, ): # 过滤出丢失的记录 records_to_insert = [] for rec in records: if not isinstance(rec, dict): continue pk_tuple = _pk_tuple_from_record(rec, pk_columns) if pk_tuple and pk_tuple in missing_pks: records_to_insert.append(rec) if not records_to_insert: continue # 插入丢失的记录 if self.dry_run: backfilled += len(records_to_insert) self.logger.info( "DRY_RUN task=%s page=%s would_insert=%s", task_code, page_no, len(records_to_insert) ) else: inserted = self._insert_records( table=table, records=records_to_insert, cols_info=cols_info, pk_columns=pk_columns, db_json_cols_lower=db_json_cols_lower, ) backfilled += inserted self.logger.info( "INSERTED task=%s page=%s count=%s", task_code, page_no, inserted ) if not self.dry_run: self.db.conn.commit() self.logger.info("BACKFILL_TASK_DONE task=%s backfilled=%s", task_code, backfilled) return backfilled except Exception: self.db.conn.rollback() raise def _build_params( self, spec: OdsTaskSpec, start: datetime, end: datetime, ) -> Dict: """构建 API 请求参数""" base: Dict[str, Any] = {} if spec.include_site_id: if spec.endpoint == "/TenantGoods/GetGoodsInventoryList": base["siteId"] = [self.store_id] else: base["siteId"] = self.store_id if spec.requires_window and spec.time_fields: start_key, end_key = spec.time_fields base[start_key] = TypeParser.format_timestamp(start, self.tz) base[end_key] = TypeParser.format_timestamp(end, self.tz) # 合并公共参数 common = self.cfg.get("api.params", {}) or {} if isinstance(common, dict): merged = {**common, **base} else: merged = base merged.update(spec.extra_params or {}) return merged def _insert_records( self, *, table: str, records: List[Dict], cols_info: List[Tuple[str, str, str]], pk_columns: List[str], db_json_cols_lower: Set[str], ) -> int: """插入记录到数据库""" if not records: return 0 col_names = [c[0] for c in cols_info] quoted_cols = ", ".join(f'"{c}"' for c in col_names) sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s" if pk_columns: pk_clause = ", ".join(f'"{c}"' for c in pk_columns) sql += f" ON CONFLICT ({pk_clause}) DO NOTHING" now = datetime.now(self.tz) json_dump = lambda v: json.dumps(v, ensure_ascii=False) params: List[Tuple] = [] for rec in records: merged_rec = _merge_record_layers(rec) # 检查 PK if pk_columns: missing_pk = False for pk in pk_columns: pk_val = _get_value_case_insensitive(merged_rec, pk) if pk_val is None or pk_val == "": missing_pk = True break if missing_pk: continue row_vals: List[Any] = [] for (col_name, data_type, _udt) in cols_info: col_lower = col_name.lower() if col_lower == "payload": row_vals.append(Json(rec, dumps=json_dump)) continue if col_lower == "source_file": row_vals.append("backfill") continue if col_lower == "source_endpoint": row_vals.append("backfill") continue if col_lower == "fetched_at": row_vals.append(now) continue value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name)) if col_lower in db_json_cols_lower: row_vals.append(Json(value, dumps=json_dump) if value is not None else None) continue row_vals.append(_cast_value(value, data_type)) params.append(tuple(row_vals)) if not params: return 0 inserted = 0 with self.db.conn.cursor() as cur: for i in range(0, len(params), 200): chunk = params[i:i + 200] execute_values(cur, sql, chunk, page_size=len(chunk)) if cur.rowcount is not None and cur.rowcount > 0: inserted += int(cur.rowcount) return inserted def run_backfill( *, cfg: AppConfig, start: datetime, end: datetime, task_codes: Optional[str] = None, dry_run: bool = False, page_size: int = 200, chunk_size: int = 500, logger: logging.Logger, ) -> Dict[str, Any]: """ 运行数据补全 Args: cfg: 应用配置 start: 开始时间 end: 结束时间 task_codes: 指定任务代码(逗号分隔) dry_run: 是否仅预览 page_size: API 分页大小 chunk_size: 数据库批量大小 logger: 日志记录器 Returns: 补全结果 """ backfiller = MissingDataBackfiller(cfg, logger, dry_run) try: return backfiller.backfill_from_gap_check( start=start, end=end, task_codes=task_codes, page_size=page_size, chunk_size=chunk_size, ) finally: backfiller.close() def main() -> int: _reconfigure_stdout_utf8() ap = argparse.ArgumentParser(description="补全丢失的 ODS 数据") ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)") ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)") ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)") ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入") ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)") ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)") ap.add_argument("--log-file", default="", help="日志文件路径") ap.add_argument("--log-dir", default="", help="日志目录") ap.add_argument("--log-level", default="INFO", help="日志级别 (默认: INFO)") ap.add_argument("--no-log-console", action="store_true", help="禁用控制台日志") args = ap.parse_args() log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs") log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing") log_console = not args.no_log_console with configure_logging( "backfill_missing", log_file, level=args.log_level, console=log_console, tee_std=True, ) as logger: cfg = AppConfig.load({}) tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei")) start = _parse_dt(args.start, tz) end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz) result = run_backfill( cfg=cfg, start=start, end=end, task_codes=args.task_codes or None, dry_run=args.dry_run, page_size=args.page_size, chunk_size=args.chunk_size, logger=logger, ) logger.info("=" * 60) logger.info("补全完成!") logger.info(" 总丢失: %s", result.get("total_missing", 0)) logger.info(" 已补全: %s", result.get("backfilled", 0)) logger.info(" 错误数: %s", result.get("errors", 0)) logger.info("=" * 60) # 输出详细结果 for detail in result.get("details", []): if detail.get("error"): logger.error( " %s: 丢失=%s 补全=%s 错误=%s", detail.get("task_code"), detail.get("missing"), detail.get("backfilled"), detail.get("error"), ) elif detail.get("backfilled", 0) > 0: logger.info( " %s: 丢失=%s 补全=%s", detail.get("task_code"), detail.get("missing"), detail.get("backfilled"), ) return 0 if __name__ == "__main__": raise SystemExit(main())