Commit all pending changes before setting up the environment.
@@ -0,0 +1,717 @@
# -*- coding: utf-8 -*-
"""
Backfill missing ODS data.

Run the data-integrity check to find the differences between the API and ODS,
then re-fetch the missing data from the API and insert it into ODS.

Usage:
    python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
    python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo

from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from api.recording_client import build_recording_client
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
    get_value_case_insensitive,
    merge_record_layers,
    normalize_pk_value,
    pk_tuple_from_record,
)


def _reconfigure_stdout_utf8() -> None:
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass


def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
    raw = (value or "").strip()
    if not raw:
        raise ValueError("empty datetime")
    has_time = any(ch in raw for ch in (":", "T"))
    dt = dtparser.parse(raw)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=tz)
    else:
        dt = dt.astimezone(tz)
    if not has_time:
        dt = dt.replace(
            hour=23 if is_end else 0,
            minute=59 if is_end else 0,
            second=59 if is_end else 0,
            microsecond=0,
        )
    return dt
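

# A sketch of _parse_dt's behaviour (assuming Asia/Shanghai as the configured
# timezone): date-only inputs are widened to day boundaries, while inputs that
# already carry a time component are kept as given.
#
#   >>> tz = ZoneInfo("Asia/Shanghai")
#   >>> _parse_dt("2025-07-01", tz).isoformat()
#   '2025-07-01T00:00:00+08:00'
#   >>> _parse_dt("2025-07-01", tz, is_end=True).isoformat()
#   '2025-07-01T23:59:59+08:00'
#   >>> _parse_dt("2025-07-01 12:30", tz).isoformat()
#   '2025-07-01T12:30:00+08:00'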


def _get_spec(code: str) -> Optional[OdsTaskSpec]:
    """Return the ODS task spec for a given task code."""
    for spec in ODS_TASK_SPECS:
        if spec.code == code:
            return spec
    return None


def _merge_record_layers(record: dict) -> dict:
    """Flatten nested data layers into a single dict."""
    return merge_record_layers(record)


def _get_value_case_insensitive(record: dict | None, col: str | None):
    """Fetch a value by column name, case-insensitively."""
    return get_value_case_insensitive(record, col)


def _normalize_pk_value(value):
    """Normalize a primary-key value."""
    return normalize_pk_value(value)


def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
    """Extract the PK tuple from a record."""
    return pk_tuple_from_record(record, pk_cols)


def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
    """Return the primary-key columns of a table."""
    if "." in table:
        schema, name = table.split(".", 1)
    else:
        schema, name = "public", table
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        cols = [r[0] for r in cur.fetchall()]
    if include_content_hash:
        return cols
    return [c for c in cols if c.lower() != "content_hash"]


def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
    """Return (name, data_type, udt_name) for every column of a table."""
    if "." in table:
        schema, name = table.split(".", 1)
    else:
        schema, name = "public", table
    sql = """
        SELECT column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]


def _fetch_existing_pk_set(
    conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
    """Return the subset of pk_values that already exists in the table."""
    if not pk_values:
        return set()
    select_cols = ", ".join(f't."{c}"' for c in pk_cols)
    value_cols = ", ".join(f'"{c}"' for c in pk_cols)
    join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
    sql = (
        f"SELECT {select_cols} FROM {table} t "
        f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
    )
    existing: Set[Tuple] = set()
    with conn.cursor() as cur:
        for i in range(0, len(pk_values), chunk_size):
            chunk = pk_values[i:i + chunk_size]
            execute_values(cur, sql, chunk, page_size=len(chunk))
            for row in cur.fetchall():
                existing.add(tuple(row))
    return existing
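

# Sketch of the SQL _fetch_existing_pk_set builds for a two-column PK such as
# ["order_id", "site_id"] (hypothetical table and column names): the candidate
# tuples are joined in as an inline VALUES list, chunk by chunk.
#
#   SELECT t."order_id", t."site_id" FROM ods.orders t
#   JOIN (VALUES %s) AS v("order_id", "site_id")
#     ON t."order_id" = v."order_id" AND t."site_id" = v."site_id"
#
# execute_values expands %s into the literal tuples of each chunk.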


def _cast_value(value, data_type: str):
    """Best-effort cast of an API value to the column's PostgreSQL type."""
    if value is None:
        return None
    dt = (data_type or "").lower()
    if dt in ("integer", "bigint", "smallint"):
        if isinstance(value, bool):
            return int(value)
        try:
            return int(value)
        except Exception:
            return None
    if dt in ("numeric", "double precision", "real", "decimal"):
        if isinstance(value, bool):
            return int(value)
        try:
            return float(value)
        except Exception:
            return None
    if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
        return value if isinstance(value, (str, datetime)) else None
    return value
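

# Doctest-style sketch of the defensive casting above:
#
#   >>> _cast_value("42", "integer")
#   42
#   >>> _cast_value(True, "numeric")
#   1
#   >>> _cast_value("abc", "bigint") is None
#   True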


def _normalize_scalar(value):
    """Map empty strings and empty JSON literals to NULL."""
    if value == "" or value == "{}" or value == "[]":
        return None
    return value


class MissingDataBackfiller:
    """Backfills ODS records that are missing relative to the API."""

    def __init__(
        self,
        cfg: AppConfig,
        logger: logging.Logger,
        dry_run: bool = False,
    ):
        self.cfg = cfg
        self.logger = logger
        self.dry_run = dry_run
        self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))
        self.store_id = int(cfg.get("app.store_id") or 0)

        # API client
        self.api = build_recording_client(cfg, task_code="BACKFILL_MISSING_DATA")

        # Database connection (DatabaseConnection sets autocommit=False at construction)
        self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))

    def close(self):
        """Close the database connection."""
        if self.db:
            self.db.close()

    def _ensure_db(self):
        """Reconnect if the database connection has been closed."""
        if self.db and getattr(self.db, "conn", None) is not None:
            if getattr(self.db.conn, "closed", 0) == 0:
                return
        self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))

    def backfill_from_gap_check(
        self,
        *,
        start: datetime,
        end: datetime,
        task_codes: Optional[str] = None,
        include_mismatch: bool = False,
        page_size: int = 200,
        chunk_size: int = 500,
        content_sample_limit: int | None = None,
    ) -> Dict[str, Any]:
        """
        Run the gap check and backfill the missing data.

        Returns:
            Backfill result statistics.
        """
        self.logger.info("Backfill started start=%s end=%s", start.isoformat(), end.isoformat())

        # Work out the check window size
        total_seconds = max(0, int((end - start).total_seconds()))
        if total_seconds >= 86400:
            window_days = max(1, total_seconds // 86400)
            window_hours = 0
        else:
            window_days = 0
            window_hours = max(1, total_seconds // 3600 or 1)

        # Run the gap check
        self.logger.info("Running the gap check...")
        gap_result = run_gap_check(
            cfg=self.cfg,
            start=start,
            end=end,
            window_days=window_days,
            window_hours=window_hours,
            page_size=page_size,
            chunk_size=chunk_size,
            sample_limit=10000,  # collect every missing sample
            sleep_per_window=0,
            sleep_per_page=0,
            task_codes=task_codes or "",
            from_cutoff=False,
            cutoff_overlap_hours=24,
            allow_small_window=True,
            logger=self.logger,
            compare_content=include_mismatch,
            content_sample_limit=content_sample_limit or 10000,
        )

        total_missing = gap_result.get("total_missing", 0)
        total_mismatch = gap_result.get("total_mismatch", 0)
        if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
            self.logger.info("Data complete: no missing/mismatch records")
            return {"backfilled": 0, "errors": 0, "details": []}

        if include_mismatch:
            self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
        else:
            self.logger.info("Missing check done missing=%s", total_missing)

        results = []
        total_backfilled = 0
        total_errors = 0

        for task_result in gap_result.get("results", []):
            task_code = task_result.get("task_code")
            missing = task_result.get("missing", 0)
            missing_samples = task_result.get("missing_samples", [])
            mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
            mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
            target_samples = list(missing_samples) + list(mismatch_samples)

            if missing == 0 and mismatch == 0:
                continue

            self.logger.info(
                "Backfilling task=%s missing=%s mismatch=%s samples=%s",
                task_code, missing, mismatch, len(target_samples)
            )

            try:
                backfilled = self._backfill_task(
                    task_code=task_code,
                    table=task_result.get("table"),
                    pk_columns=task_result.get("pk_columns", []),
                    pk_samples=target_samples,
                    start=start,
                    end=end,
                    page_size=page_size,
                    chunk_size=chunk_size,
                )
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": backfilled,
                    "error": None,
                })
                total_backfilled += backfilled
            except Exception as exc:
                self.logger.exception("Backfill failed task=%s", task_code)
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": 0,
                    "error": str(exc),
                })
                total_errors += 1

        self.logger.info(
            "Backfill finished total_missing=%s backfilled=%s errors=%s",
            total_missing, total_backfilled, total_errors
        )

        return {
            "total_missing": total_missing,
            "total_mismatch": total_mismatch,
            "backfilled": total_backfilled,
            "errors": total_errors,
            "details": results,
        }
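
    # Shape of the summary returned above (illustrative values only):
    #   {
    #       "total_missing": 12,
    #       "total_mismatch": 0,
    #       "backfilled": 12,
    #       "errors": 0,
    #       "details": [{"task_code": "...", "missing": 12, "mismatch": 0,
    #                    "backfilled": 12, "error": None}],
    #   }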

    def _backfill_task(
        self,
        *,
        task_code: str,
        table: str,
        pk_columns: List[str],
        pk_samples: List[Dict],
        start: datetime,
        end: datetime,
        page_size: int,
        chunk_size: int,
    ) -> int:
        """Backfill the missing records of a single task."""
        self._ensure_db()
        spec = _get_spec(task_code)
        if not spec:
            self.logger.warning("No task spec found task=%s", task_code)
            return 0

        if not pk_columns:
            pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)

        conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
        if not conflict_columns:
            conflict_columns = pk_columns

        if not pk_columns:
            self.logger.warning("No primary-key columns found task=%s table=%s", task_code, table)
            return 0

        # Extract the missing PK values
        missing_pks: Set[Tuple] = set()
        for sample in pk_samples:
            pk_tuple = tuple(sample.get(col) for col in pk_columns)
            if all(v is not None for v in pk_tuple):
                missing_pks.add(pk_tuple)

        if not missing_pks:
            self.logger.info("No missing PKs task=%s", task_code)
            return 0

        self.logger.info(
            "Fetching data task=%s missing_pks=%s",
            task_code, len(missing_pks)
        )

        # Fetch from the API and filter down to the missing records
        params = self._build_params(spec, start, end)

        backfilled = 0
        cols_info = _get_table_columns(self.db.conn, table)
        db_json_cols_lower = {
            c[0].lower() for c in cols_info
            if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }
        col_names = [c[0] for c in cols_info]

        # End the read-only transaction so a long API pull does not hit the
        # idle-in-transaction timeout
        try:
            self.db.conn.commit()
        except Exception:
            self.db.conn.rollback()

        try:
            for page_no, records, _, response_payload in self.api.iter_paginated(
                endpoint=spec.endpoint,
                params=params,
                page_size=page_size,
                data_path=spec.data_path,
                list_key=spec.list_key,
            ):
                # Keep only the missing records
                records_to_insert = []
                for rec in records:
                    if not isinstance(rec, dict):
                        continue
                    pk_tuple = _pk_tuple_from_record(rec, pk_columns)
                    if pk_tuple and pk_tuple in missing_pks:
                        records_to_insert.append(rec)

                if not records_to_insert:
                    continue

                # Insert the missing records
                if self.dry_run:
                    backfilled += len(records_to_insert)
                    self.logger.info(
                        "Dry run task=%s page=%s would_insert=%s",
                        task_code, page_no, len(records_to_insert)
                    )
                else:
                    inserted = self._insert_records(
                        table=table,
                        records=records_to_insert,
                        cols_info=cols_info,
                        pk_columns=pk_columns,
                        conflict_columns=conflict_columns,
                        db_json_cols_lower=db_json_cols_lower,
                    )
                    backfilled += inserted
                    # Commit per page to avoid long transactions and
                    # idle-in-transaction timeouts
                    self.db.conn.commit()
                    self.logger.info(
                        "Inserted task=%s page=%s count=%s",
                        task_code, page_no, inserted
                    )

            if not self.dry_run:
                self.db.conn.commit()

            self.logger.info("Task backfill finished task=%s backfilled=%s", task_code, backfilled)
            return backfilled

        except Exception:
            self.db.conn.rollback()
            raise

    def _build_params(
        self,
        spec: OdsTaskSpec,
        start: datetime,
        end: datetime,
    ) -> Dict:
        """Build the API request parameters."""
        base: Dict[str, Any] = {}
        if spec.include_site_id:
            if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
                base["siteId"] = [self.store_id]
            else:
                base["siteId"] = self.store_id

        if spec.requires_window and spec.time_fields:
            start_key, end_key = spec.time_fields
            base[start_key] = TypeParser.format_timestamp(start, self.tz)
            base[end_key] = TypeParser.format_timestamp(end, self.tz)

        # Merge the shared API parameters
        common = self.cfg.get("api.params", {}) or {}
        if isinstance(common, dict):
            merged = {**common, **base}
        else:
            merged = base

        merged.update(spec.extra_params or {})
        return merged
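
    # Example of a merged parameter dict for a windowed task (hypothetical key
    # names; the real ones come from spec.time_fields and cfg["api.params"]):
    #   {"siteId": 1001,
    #    "beginTime": "2025-07-01 00:00:00",
    #    "endTime": "2025-07-31 23:59:59"}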

    def _insert_records(
        self,
        *,
        table: str,
        records: List[Dict],
        cols_info: List[Tuple[str, str, str]],
        pk_columns: List[str],
        conflict_columns: List[str],
        db_json_cols_lower: Set[str],
    ) -> int:
        """Insert records into the database and return the number of rows inserted."""
        if not records:
            return 0

        col_names = [c[0] for c in cols_info]
        needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
        quoted_cols = ", ".join(f'"{c}"' for c in col_names)
        sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
        conflict_cols = conflict_columns or pk_columns
        if conflict_cols:
            pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
            sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"

        now = datetime.now(self.tz)

        def json_dump(v):
            return json.dumps(v, ensure_ascii=False)

        params: List[Tuple] = []
        for rec in records:
            merged_rec = _merge_record_layers(rec)

            # Verify the PK is present
            if pk_columns:
                missing_pk = False
                for pk in pk_columns:
                    if str(pk).lower() == "content_hash":
                        continue
                    pk_val = _get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
                        break
                if missing_pk:
                    continue

            content_hash = None
            if needs_content_hash:
                content_hash = BaseOdsTask._compute_content_hash(
                    merged_rec, include_fetched_at=False
                )

            row_vals: List[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
                if col_lower == "payload":
                    row_vals.append(Json(rec, dumps=json_dump))
                    continue
                if col_lower == "source_file":
                    row_vals.append("backfill")
                    continue
                if col_lower == "source_endpoint":
                    row_vals.append("backfill")
                    continue
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue

                value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
                    row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                    continue

                row_vals.append(_cast_value(value, data_type))

            params.append(tuple(row_vals))

        if not params:
            return 0

        inserted = 0
        with self.db.conn.cursor() as cur:
            for i in range(0, len(params), 200):
                chunk = params[i:i + 200]
                execute_values(cur, sql, chunk, page_size=len(chunk))
                if cur.rowcount is not None and cur.rowcount > 0:
                    inserted += int(cur.rowcount)

        return inserted
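
    # Rendered INSERT for a table whose conflict target is (id, content_hash)
    # (hypothetical table and columns, for illustration):
    #   INSERT INTO ods.orders ("id", "payload", "fetched_at", "content_hash")
    #   VALUES %s
    #   ON CONFLICT ("id", "content_hash") DO NOTHING
    # execute_values expands %s into the row tuples of each chunk of 200.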


def run_backfill(
    *,
    cfg: AppConfig,
    start: datetime,
    end: datetime,
    task_codes: Optional[str] = None,
    include_mismatch: bool = False,
    dry_run: bool = False,
    page_size: int = 200,
    chunk_size: int = 500,
    content_sample_limit: int | None = None,
    logger: logging.Logger,
) -> Dict[str, Any]:
    """
    Run the data backfill.

    Args:
        cfg: Application configuration.
        start: Start time.
        end: End time.
        task_codes: Task codes to run (comma-separated).
        include_mismatch: Also backfill records whose content differs.
        dry_run: Preview only; write nothing.
        page_size: API page size.
        chunk_size: Database batch size.
        content_sample_limit: Cap on mismatch samples to collect.
        logger: Logger.

    Returns:
        Backfill result summary.
    """
    backfiller = MissingDataBackfiller(cfg, logger, dry_run)
    try:
        return backfiller.backfill_from_gap_check(
            start=start,
            end=end,
            task_codes=task_codes,
            include_mismatch=include_mismatch,
            page_size=page_size,
            chunk_size=chunk_size,
            content_sample_limit=content_sample_limit,
        )
    finally:
        backfiller.close()
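

# Programmatic use, as a sketch (cfg/logger construction elided):
#
#   result = run_backfill(cfg=cfg, start=start, end=end, dry_run=True, logger=logger)
#   print(result["backfilled"], result["errors"])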


def main() -> int:
    _reconfigure_stdout_utf8()

    ap = argparse.ArgumentParser(description="Backfill missing ODS data")
    ap.add_argument("--start", default="2025-07-01", help="start date (default: 2025-07-01)")
    ap.add_argument("--end", default="", help="end date (default: now)")
    ap.add_argument("--task-codes", default="", help="task codes, comma-separated (empty = all)")
    ap.add_argument("--include-mismatch", action="store_true", help="also backfill records whose content differs")
    ap.add_argument("--content-sample-limit", type=int, default=None, help="max mismatch samples (default: 10000)")
    ap.add_argument("--dry-run", action="store_true", help="preview only, write nothing")
    ap.add_argument("--page-size", type=int, default=200, help="API page size (default: 200)")
    ap.add_argument("--chunk-size", type=int, default=500, help="database batch size (default: 500)")
    ap.add_argument("--log-file", default="", help="log file path")
    ap.add_argument("--log-dir", default="", help="log directory")
    ap.add_argument("--log-level", default="INFO", help="log level (default: INFO)")
    ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
    args = ap.parse_args()

    log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
    log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
    log_console = not args.no_log_console

    with configure_logging(
        "backfill_missing",
        log_file,
        level=args.log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg = AppConfig.load({})
        tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))

        start = _parse_dt(args.start, tz)
        end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)

        result = run_backfill(
            cfg=cfg,
            start=start,
            end=end,
            task_codes=args.task_codes or None,
            include_mismatch=args.include_mismatch,
            dry_run=args.dry_run,
            page_size=args.page_size,
            chunk_size=args.chunk_size,
            content_sample_limit=args.content_sample_limit,
            logger=logger,
        )

        logger.info("=" * 60)
        logger.info("Backfill complete!")
        logger.info("  total missing: %s", result.get("total_missing", 0))
        if args.include_mismatch:
            logger.info("  total mismatched: %s", result.get("total_mismatch", 0))
        logger.info("  backfilled: %s", result.get("backfilled", 0))
        logger.info("  errors: %s", result.get("errors", 0))
        logger.info("=" * 60)

        # Per-task details
        for detail in result.get("details", []):
            if detail.get("error"):
                logger.error(
                    "  %s: missing=%s mismatch=%s backfilled=%s error=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                    detail.get("error"),
                )
            elif detail.get("backfilled", 0) > 0:
                logger.info(
                    "  %s: missing=%s mismatch=%s backfilled=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                )

        return 0


if __name__ == "__main__":
    raise SystemExit(main())
@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
"""
Deduplicate ODS snapshots by (business PK, content_hash).
Keep the latest row by fetched_at (tie-breaker: ctid desc).

Usage:
    PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots
    PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --schema ods
    PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --tables member_profiles,orders
"""
from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Iterable, Sequence

import psycopg2

# Two levels up (scripts/repair/ -> project root), matching the module path
# used in the usage examples above.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.settings import AppConfig
from database.connection import DatabaseConnection


def _reconfigure_stdout_utf8() -> None:
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass


def _quote_ident(name: str) -> str:
    return '"' + str(name).replace('"', '""') + '"'
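

# Doctest-style sketch: embedded double quotes are doubled, SQL-style.
#
#   >>> _quote_ident('weird"name')
#   '"weird""name"'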


def _fetch_tables(conn, schema: str) -> list[str]:
    sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema,))
        return [r[0] for r in cur.fetchall()]


def _fetch_columns(conn, schema: str, table: str) -> list[str]:
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        return [r[0] for r in cur.fetchall()]


def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        cols = [r[0] for r in cur.fetchall()]
    return [c for c in cols if c.lower() != "content_hash"]


def _build_report_path(out_arg: str | None) -> Path:
    if out_arg:
        return Path(out_arg)
    reports_dir = PROJECT_ROOT / "reports"
    reports_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    return reports_dir / f"ods_snapshot_dedupe_{ts}.json"


def _print_progress(
    table_label: str,
    deleted: int,
    total: int,
    errors: int,
) -> None:
    if total:
        msg = f"[{table_label}] deleted {deleted}/{total} errors={errors}"
    else:
        msg = f"[{table_label}] deleted {deleted} errors={errors}"
    print(msg, flush=True)


def _count_duplicates(conn, schema: str, table: str, key_cols: Sequence[str]) -> int:
    keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
    table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
    sql = f"""
        SELECT COUNT(*) FROM (
            SELECT 1
            FROM (
                SELECT ROW_NUMBER() OVER (
                    PARTITION BY {keys_sql}
                    ORDER BY fetched_at DESC NULLS LAST, ctid DESC
                ) AS rn
                FROM {table_sql}
            ) t
            WHERE rn > 1
        ) s
    """
    with conn.cursor() as cur:
        cur.execute(sql)
        row = cur.fetchone()
    return int(row[0] if row else 0)


def _delete_duplicate_batch(
    conn,
    schema: str,
    table: str,
    key_cols: Sequence[str],
    batch_size: int,
) -> int:
    keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
    table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
    sql = f"""
        WITH dupes AS (
            SELECT ctid
            FROM (
                SELECT ctid,
                       ROW_NUMBER() OVER (
                           PARTITION BY {keys_sql}
                           ORDER BY fetched_at DESC NULLS LAST, ctid DESC
                       ) AS rn
                FROM {table_sql}
            ) s
            WHERE rn > 1
            LIMIT %s
        )
        DELETE FROM {table_sql} t
        USING dupes d
        WHERE t.ctid = d.ctid
        RETURNING 1
    """
    with conn.cursor() as cur:
        cur.execute(sql, (int(batch_size),))
        rows = cur.fetchall()
    return len(rows or [])
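

# How the batched delete works, in short: ROW_NUMBER() partitions rows by
# (business PK, content_hash) ordered newest-first, so rn > 1 marks every
# older duplicate, and ctid (the physical row address) lets us delete those
# exact rows without needing a surrogate key. LIMIT keeps each DELETE small.
# A sketch of the driving loop, with a hypothetical table and PK:
#
#   while _delete_duplicate_batch(conn, "ods", "orders", ["id"], 1000) > 0:
#       pass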


def main() -> int:
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Deduplicate ODS snapshot rows by PK+content_hash")
    ap.add_argument("--schema", default="ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=1000, help="delete batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N deletions")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute duplicate counts")
    args = ap.parse_args()

    cfg = AppConfig.load({})
    db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    try:
        db.conn.rollback()
    except Exception:
        pass
    db.conn.autocommit = True

    tables = _fetch_tables(db.conn, args.schema)
    if args.tables.strip():
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]

    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_duplicates": 0,
            "deleted_rows": 0,
            "error_rows": 0,
            "skipped_tables": 0,
        },
    }

    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(db.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        if "content_hash" not in cols_lower or "fetched_at" not in cols_lower:
            print(f"[{table_label}] skip: missing content_hash/fetched_at", flush=True)
            report["summary"]["skipped_tables"] += 1
            continue

        key_cols = _fetch_pk_columns(db.conn, args.schema, table)
        if not key_cols:
            print(f"[{table_label}] skip: missing primary key", flush=True)
            report["summary"]["skipped_tables"] += 1
            continue

        total_dupes = _count_duplicates(db.conn, args.schema, table, key_cols)
        print(f"[{table_label}] duplicates={total_dupes}", flush=True)
        deleted = 0
        errors = 0

        if not args.dry_run and total_dupes:
            while True:
                try:
                    batch_deleted = _delete_duplicate_batch(
                        db.conn,
                        args.schema,
                        table,
                        key_cols,
                        args.batch_size,
                    )
                except psycopg2.Error:
                    errors += 1
                    break
                if batch_deleted <= 0:
                    break
                deleted += batch_deleted
                if args.progress_every and deleted % int(args.progress_every) == 0:
                    _print_progress(table_label, deleted, total_dupes, errors)

            if deleted and (not args.progress_every or deleted % int(args.progress_every) != 0):
                _print_progress(table_label, deleted, total_dupes, errors)

        report["tables"].append(
            {
                "table": table_label,
                "duplicate_rows": total_dupes,
                "deleted_rows": deleted,
                "error_rows": errors,
            }
        )
        report["summary"]["checked_tables"] += 1
        report["summary"]["total_duplicates"] += total_dupes
        report["summary"]["deleted_rows"] += deleted
        report["summary"]["error_rows"] += errors

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""Repair the user_id column of the dim_assistant table."""
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations

config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)

print("=== Repairing dim_assistant.user_id ===")

# Approach: update the DWD table's user_id from the ODS table,
# joining on id (ODS) = assistant_id (DWD).

# 1. Check the current state
print("\nBefore repair:")
sql_before = """
    SELECT
        COUNT(*) as total,
        COUNT(CASE WHEN user_id > 0 THEN 1 END) as has_user_id
    FROM dwd.dim_assistant
    WHERE scd2_is_current = 1
"""
r = dict(db.query(sql_before)[0])
print(f"  total rows: {r['total']}, with user_id: {r['has_user_id']}")

# 2. Run the update
print("\nRunning the update...")
update_sql = """
    UPDATE dwd.dim_assistant d
    SET user_id = o.user_id
    FROM (
        SELECT DISTINCT ON (id) id, user_id
        FROM ods.assistant_accounts_master
        WHERE user_id > 0
        ORDER BY id, fetched_at DESC
    ) o
    WHERE d.assistant_id = o.id
      AND (d.user_id IS NULL OR d.user_id = 0)
"""
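
# DISTINCT ON (id) with ORDER BY id, fetched_at DESC keeps exactly one row
# per id: the one with the latest fetched_at. Illustrative effect:
#
#   id | user_id | fetched_at           id | user_id
#    7 |     123 | 2025-07-02 10:00 ->   7 |     123
#    7 |      99 | 2025-07-01 09:00     (older row dropped)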
with db_conn.conn.cursor() as cur:
    cur.execute(update_sql)
    updated = cur.rowcount
    print(f"  updated {updated} rows")
db_conn.conn.commit()

# 3. Check the state after the repair
print("\nAfter repair:")
r2 = dict(db.query(sql_before)[0])
print(f"  total rows: {r2['total']}, with user_id: {r2['has_user_id']}")

# 4. Show sample rows
print("\nSample rows:")
sql_sample = """
    SELECT assistant_id, user_id, assistant_no, nickname
    FROM dwd.dim_assistant
    WHERE scd2_is_current = 1
    ORDER BY assistant_no::int
    LIMIT 10
"""
for row in db.query(sql_sample):
    r = dict(row)
    print(f"  assistant_id={r['assistant_id']}, user_id={r['user_id']}, no={r['assistant_no']}, nickname={r['nickname']}")

# 5. Verify the join against the service log
print("\nVerifying the join against the service log:")
sql_verify = """
    SELECT
        COUNT(DISTINCT s.user_id) as service_unique_users,
        COUNT(DISTINCT CASE WHEN d.assistant_id IS NOT NULL THEN s.user_id END) as matched_users
    FROM dwd.dwd_assistant_service_log s
    LEFT JOIN dwd.dim_assistant d
        ON s.user_id = d.user_id AND d.scd2_is_current = 1
    WHERE s.is_delete = 0 AND s.user_id > 0
"""
r3 = dict(db.query(sql_verify)[0])
print(f"  unique user_ids in the service log: {r3['service_unique_users']}")
print(f"  matched to dim_assistant: {r3['matched_users']}")
match_rate = r3['matched_users'] / r3['service_unique_users'] * 100 if r3['service_unique_users'] > 0 else 0
print(f"  match rate: {match_rate:.1f}%")

db_conn.close()
print("\nDone!")
@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""
Repair ODS content_hash values by recomputing from payload.

Usage:
    PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash
    PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema ods
    PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence

import psycopg2
from psycopg2.extras import RealDictCursor

# Two levels up (scripts/repair/ -> project root), matching the module path
# used in the usage examples above.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask


def _reconfigure_stdout_utf8() -> None:
    if hasattr(sys.stdout, "reconfigure"):
        try:
            sys.stdout.reconfigure(encoding="utf-8")
        except Exception:
            pass


def _fetch_tables(conn, schema: str) -> list[str]:
    sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s AND table_type = 'BASE TABLE'
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema,))
        return [r[0] for r in cur.fetchall()]


def _fetch_columns(conn, schema: str, table: str) -> list[str]:
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        cols = [r[0] for r in cur.fetchall()]
    return [c for c in cols if c]


def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
    sql = """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
          AND tc.table_schema = %s
          AND tc.table_name = %s
        ORDER BY kcu.ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (schema, table))
        cols = [r[0] for r in cur.fetchall()]
    return [c for c in cols if c.lower() != "content_hash"]


def _fetch_row_count(conn, schema: str, table: str) -> int:
    sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
    with conn.cursor() as cur:
        cur.execute(sql)
        row = cur.fetchone()
    return int(row[0] if row else 0)


def _iter_rows(
    conn,
    schema: str,
    table: str,
    select_cols: Sequence[str],
    batch_size: int,
) -> Iterable[dict]:
    cols_sql = ", ".join("ctid" if c == "ctid" else f'"{c}"' for c in select_cols)
    sql = f'SELECT {cols_sql} FROM "{schema}"."{table}"'
    with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cur:
        cur.itersize = max(1, int(batch_size or 500))
        cur.execute(sql)
        for row in cur:
            yield row
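

# _iter_rows uses a named (server-side) cursor, so rows stream in batches of
# itersize instead of being loaded into memory at once. The same pattern in
# isolation (hypothetical table):
#
#   with conn.cursor(name="big_scan") as cur:
#       cur.itersize = 500
#       cur.execute('SELECT ctid, content_hash, payload FROM "ods"."orders"')
#       for row in cur:
#           ...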


def _build_report_path(out_arg: str | None) -> Path:
    if out_arg:
        return Path(out_arg)
    reports_dir = PROJECT_ROOT / "reports"
    reports_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    return reports_dir / f"ods_content_hash_repair_{ts}.json"


def _print_progress(
    table_label: str,
    processed: int,
    total: int,
    updated: int,
    skipped: int,
    conflicts: int,
    errors: int,
    missing_hash: int,
    invalid_payload: int,
) -> None:
    if total:
        msg = (
            f"[{table_label}] checked {processed}/{total} "
            f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
            f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
        )
    else:
        msg = (
            f"[{table_label}] checked {processed} "
            f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
            f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
        )
    print(msg, flush=True)


def main() -> int:
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload")
    ap.add_argument("--schema", default="ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
    ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update")
    args = ap.parse_args()

    cfg = AppConfig.load({})
    db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    try:
        db_write.conn.rollback()
    except Exception:
        pass
    db_write.conn.autocommit = True

    tables = _fetch_tables(db_read.conn, args.schema)
    if args.tables.strip():
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]

    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_rows": 0,
            "checked_rows": 0,
            "updated_rows": 0,
            "skipped_rows": 0,
            "conflict_rows": 0,
            "error_rows": 0,
            "missing_hash_rows": 0,
            "invalid_payload_rows": 0,
        },
    }

    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(db_read.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        if "payload" not in cols_lower or "content_hash" not in cols_lower:
            print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
            continue

        total = _fetch_row_count(db_read.conn, args.schema, table)
        pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table)
        select_cols = ["ctid", "content_hash", "payload", *pk_cols]

        processed = 0
        updated = 0
        skipped = 0
        conflicts = 0
        errors = 0
        missing_hash = 0
        invalid_payload = 0
        samples: list[dict[str, Any]] = []

        print(f"[{table_label}] start: total_rows={total}", flush=True)

        for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size):
            processed += 1
            content_hash = row.get("content_hash")
            payload = row.get("payload")
            recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
            row_ctid = row.get("ctid")

            if not content_hash:
                missing_hash += 1
            if not recomputed:
                invalid_payload += 1

            if not recomputed:
                skipped += 1
            elif content_hash == recomputed:
                skipped += 1
            else:
                if args.dry_run:
                    updated += 1
                else:
                    try:
                        with db_write.conn.cursor() as cur:
                            cur.execute(
                                f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s',
                                (recomputed, row_ctid),
                            )
                        updated += 1
                    except psycopg2.errors.UniqueViolation:
                        conflicts += 1
                        if len(samples) < max(0, int(args.sample_limit or 0)):
                            sample = {k: row.get(k) for k in pk_cols}
                            sample["content_hash"] = content_hash
                            sample["recomputed_hash"] = recomputed
                            samples.append(sample)
                    except psycopg2.Error:
                        errors += 1

            if args.progress_every and processed % int(args.progress_every) == 0:
                _print_progress(
                    table_label,
                    processed,
                    total,
                    updated,
                    skipped,
                    conflicts,
                    errors,
                    missing_hash,
                    invalid_payload,
                )

        if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
            _print_progress(
                table_label,
                processed,
                total,
                updated,
                skipped,
                conflicts,
                errors,
                missing_hash,
                invalid_payload,
            )

        report["tables"].append(
            {
                "table": table_label,
                "total_rows": total,
                "checked_rows": processed,
                "updated_rows": updated,
                "skipped_rows": skipped,
                "conflict_rows": conflicts,
                "error_rows": errors,
                "missing_hash_rows": missing_hash,
                "invalid_payload_rows": invalid_payload,
                "conflict_samples": samples,
            }
        )

        report["summary"]["checked_tables"] += 1
        report["summary"]["total_rows"] += total
        report["summary"]["checked_rows"] += processed
        report["summary"]["updated_rows"] += updated
        report["summary"]["skipped_rows"] += skipped
        report["summary"]["conflict_rows"] += conflicts
        report["summary"]["error_rows"] += errors
        report["summary"]["missing_hash_rows"] += missing_hash
        report["summary"]["invalid_payload_rows"] += invalid_payload

    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Create performance indexes for integrity verification and run ANALYZE.

Usage:
    python -m scripts.tune_integrity_indexes
    python -m scripts.tune_integrity_indexes --dry-run
"""

from __future__ import annotations

import argparse
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Sequence, Set, Tuple

import psycopg2
from psycopg2 import sql

from config.settings import AppConfig


TIME_CANDIDATES = (
    "pay_time",
    "create_time",
    "start_use_time",
    "scd2_start_time",
    "calc_time",
    "order_date",
    "fetched_at",
)


@dataclass(frozen=True)
class IndexPlan:
    schema: str
    table: str
    index_name: str
    columns: Tuple[str, ...]


def _short_index_name(table: str, tag: str, columns: Sequence[str]) -> str:
    raw = f"idx_{table}_{tag}_{'_'.join(columns)}"
    if len(raw) <= 63:
        return raw
    digest = hashlib.md5(raw.encode("utf-8")).hexdigest()[:8]
    shortened = f"idx_{table}_{tag}_{digest}"
    return shortened[:63]
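

# PostgreSQL truncates identifiers at 63 bytes, so over-long generated names
# collapse to a stable md5-suffixed form instead. Doctest-style sketch:
#
#   >>> _short_index_name("t", "time", ["pay_time"])
#   'idx_t_time_pay_time'
#
# A longer combination yields idx_<table>_<tag>_<8-hex-chars>, which stays
# within the limit and is effectively unique per (table, tag, column set).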


def _load_table_columns(cur, schema: str, table: str) -> Set[str]:
    cur.execute(
        """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        """,
        (schema, table),
    )
    return {r[0] for r in cur.fetchall()}


def _load_pk_columns(cur, schema: str, table: str) -> List[str]:
    cur.execute(
        """
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON tc.constraint_name = kcu.constraint_name
         AND tc.table_schema = kcu.table_schema
         AND tc.table_name = kcu.table_name
        WHERE tc.table_schema = %s
          AND tc.table_name = %s
          AND tc.constraint_type = 'PRIMARY KEY'
        ORDER BY kcu.ordinal_position
        """,
        (schema, table),
    )
    return [r[0] for r in cur.fetchall()]


def _load_tables(cur, schema: str) -> List[str]:
    cur.execute(
        """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
          AND table_type = 'BASE TABLE'
        ORDER BY table_name
        """,
        (schema,),
    )
    return [r[0] for r in cur.fetchall()]


def _plan_indexes(cur, schema: str, table: str) -> List[IndexPlan]:
    plans: List[IndexPlan] = []
    cols = _load_table_columns(cur, schema, table)
    pk_cols = _load_pk_columns(cur, schema, table)

    if schema == "ods":
        if "fetched_at" in cols:
            plans.append(
                IndexPlan(
                    schema=schema,
                    table=table,
                    index_name=_short_index_name(table, "fetched_at", ("fetched_at",)),
                    columns=("fetched_at",),
                )
            )
            if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
                comp_cols = ("fetched_at", *pk_cols)
                plans.append(
                    IndexPlan(
                        schema=schema,
                        table=table,
                        index_name=_short_index_name(table, "fetched_pk", comp_cols),
                        columns=comp_cols,
                    )
                )

    if schema == "dwd":
        if pk_cols and "scd2_is_current" in cols and len(pk_cols) <= 4:
            comp_cols = (*pk_cols, "scd2_is_current")
            plans.append(
                IndexPlan(
                    schema=schema,
                    table=table,
                    index_name=_short_index_name(table, "pk_current", comp_cols),
                    columns=comp_cols,
                )
            )

        for tcol in TIME_CANDIDATES:
            if tcol in cols:
                plans.append(
                    IndexPlan(
                        schema=schema,
                        table=table,
                        index_name=_short_index_name(table, "time", (tcol,)),
                        columns=(tcol,),
                    )
                )
                if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
                    comp_cols = (tcol, *pk_cols)
                    plans.append(
                        IndexPlan(
                            schema=schema,
                            table=table,
                            index_name=_short_index_name(table, "time_pk", comp_cols),
                            columns=comp_cols,
                        )
                    )

    # De-duplicate by index name
    dedup: Dict[str, IndexPlan] = {}
    for p in plans:
        dedup[p.index_name] = p
    return list(dedup.values())


def _create_index(cur, plan: IndexPlan) -> None:
    stmt = sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {sch}.{tbl} ({cols})").format(
        idx=sql.Identifier(plan.index_name),
        sch=sql.Identifier(plan.schema),
        tbl=sql.Identifier(plan.table),
        cols=sql.SQL(", ").join(sql.Identifier(c) for c in plan.columns),
    )
    cur.execute(stmt)
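

# psycopg2.sql composition renders each identifier safely quoted; for a
# hypothetical plan the statement above comes out as:
#
#   CREATE INDEX IF NOT EXISTS "idx_orders_time_pay_time"
#   ON "ods"."orders" ("pay_time")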


def _analyze_table(cur, schema: str, table: str) -> None:
    stmt = sql.SQL("ANALYZE {sch}.{tbl}").format(
        sch=sql.Identifier(schema),
        tbl=sql.Identifier(table),
    )
    cur.execute(stmt)


def main() -> int:
    ap = argparse.ArgumentParser(description="Tune indexes for integrity verification.")
    ap.add_argument("--dry-run", action="store_true", help="Print planned SQL only.")
    ap.add_argument(
        "--skip-analyze",
        action="store_true",
        help="Create indexes but skip ANALYZE.",
    )
    args = ap.parse_args()

    cfg = AppConfig.load({})
    dsn = cfg.get("db.dsn")
    timeout_sec = int(cfg.get("db.connect_timeout_sec", 10) or 10)

    with psycopg2.connect(dsn, connect_timeout=timeout_sec) as conn:
        conn.autocommit = False
        with conn.cursor() as cur:
            all_plans: List[IndexPlan] = []
            for schema in ("ods", "dwd"):
                for table in _load_tables(cur, schema):
                    all_plans.extend(_plan_indexes(cur, schema, table))

            touched_tables: Set[Tuple[str, str]] = set()
            print(f"planned indexes: {len(all_plans)}")
            for plan in all_plans:
                cols = ", ".join(plan.columns)
                print(f"[INDEX] {plan.schema}.{plan.table} ({cols}) -> {plan.index_name}")
                if not args.dry_run:
                    _create_index(cur, plan)
                    touched_tables.add((plan.schema, plan.table))

            if not args.skip_analyze:
                if args.dry_run:
                    for schema, table in sorted({(p.schema, p.table) for p in all_plans}):
                        print(f"[ANALYZE] {schema}.{table}")
                else:
                    for schema, table in sorted(touched_tables):
                        _analyze_table(cur, schema, table)
                        print(f"[ANALYZE] {schema}.{table}")

        if args.dry_run:
            conn.rollback()
            print("dry-run complete; transaction rolled back")
        else:
            conn.commit()
            print("index tuning complete")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())