init: 项目初始提交 - NeoZQYY Monorepo 完整代码

This commit is contained in:
Neo
2026-02-15 14:58:14 +08:00
commit ded6dfb9d8
769 changed files with 182616 additions and 0 deletions

View File

@@ -0,0 +1,717 @@
# -*- coding: utf-8 -*-
"""
补全丢失的 ODS 数据
通过运行数据完整性检查,找出 API 与 ODS 之间的差异,
然后重新从 API 获取丢失的数据并插入 ODS。
用法:
python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.recording_client import build_recording_client
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods.ods_tasks import BaseOdsTask, ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
from utils.ods_record_utils import (
get_value_case_insensitive,
merge_record_layers,
normalize_pk_value,
pk_tuple_from_record,
)
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
    """Parse a date/datetime string into an aware datetime in *tz*.

    A value that carries no time-of-day component (no ":" and no "T") is
    expanded to the start of the day, or to 23:59:59 when *is_end* is set.

    Raises:
        ValueError: when *value* is empty or whitespace only.
    """
    raw = (value or "").strip()
    if not raw:
        raise ValueError("empty datetime")
    has_time = (":" in raw) or ("T" in raw)
    dt = dtparser.parse(raw)
    dt = dt.replace(tzinfo=tz) if dt.tzinfo is None else dt.astimezone(tz)
    if has_time:
        return dt
    if is_end:
        return dt.replace(hour=23, minute=59, second=59, microsecond=0)
    return dt.replace(hour=0, minute=0, second=0, microsecond=0)
def _get_spec(code: str) -> Optional[OdsTaskSpec]:
    """Return the ODS task spec whose ``code`` matches, or None if unknown."""
    return next((spec for spec in ODS_TASK_SPECS if spec.code == code), None)
def _merge_record_layers(record: dict) -> dict:
    """Flatten nested data layers into a single dict.

    Module-local alias for utils.ods_record_utils.merge_record_layers.
    """
    return merge_record_layers(record)


def _get_value_case_insensitive(record: dict | None, col: str | None):
    """Fetch a value from *record* by column name, ignoring key case.

    Module-local alias for utils.ods_record_utils.get_value_case_insensitive.
    """
    return get_value_case_insensitive(record, col)


def _normalize_pk_value(value):
    """Normalize a primary-key value (delegates to the shared util)."""
    return normalize_pk_value(value)


def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
    """Extract the PK tuple for *record*, or None when unavailable.

    Module-local alias for utils.ods_record_utils.pk_tuple_from_record.
    """
    return pk_tuple_from_record(record, pk_cols)
def _get_table_pk_columns(conn, table: str, *, include_content_hash: bool = False) -> List[str]:
"""获取表的主键列"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [r[0] for r in cur.fetchall()]
if include_content_hash:
return cols
return [c for c in cols if c.lower() != "content_hash"]
def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
"""获取表的所有列信息"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
def _fetch_existing_pk_set(
conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
"""获取已存在的 PK 集合"""
if not pk_values:
return set()
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: Set[Tuple] = set()
with conn.cursor() as cur:
for i in range(0, len(pk_values), chunk_size):
chunk = pk_values[i:i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
def _cast_value(value, data_type: str):
"""类型转换"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _normalize_scalar(value):
"""规范化标量值"""
if value == "" or value == "{}" or value == "[]":
return None
return value
class MissingDataBackfiller:
    """Backfills ODS rows that exist in the upstream API but are missing locally.

    Workflow: run the gap check (scripts.check.check_ods_gaps) over a time
    window, collect the missing / mismatching primary keys per task, re-fetch
    the relevant API pages and insert only the rows whose PK is in that set.
    """

    def __init__(
        self,
        cfg: AppConfig,
        logger: logging.Logger,
        dry_run: bool = False,
    ):
        self.cfg = cfg
        self.logger = logger
        # When True, rows are counted but nothing is written to the database.
        self.dry_run = dry_run
        self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))
        self.store_id = int(cfg.get("app.store_id") or 0)
        # API client
        self.api = build_recording_client(cfg, task_code="BACKFILL_MISSING_DATA")
        # Database connection; DatabaseConnection is constructed with autocommit=False
        self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))

    def close(self):
        """Close the database connection (the API client has no close here)."""
        if self.db:
            self.db.close()

    def _ensure_db(self):
        """Reconnect to the database if the connection has been closed."""
        if self.db and getattr(self.db, "conn", None) is not None:
            if getattr(self.db.conn, "closed", 0) == 0:
                return
        self.db = DatabaseConnection(dsn=self.cfg["db"]["dsn"], session=self.cfg["db"].get("session"))

    def backfill_from_gap_check(
        self,
        *,
        start: datetime,
        end: datetime,
        task_codes: Optional[str] = None,
        include_mismatch: bool = False,
        page_size: int = 200,
        chunk_size: int = 500,
        content_sample_limit: int | None = None,
    ) -> Dict[str, Any]:
        """Run the gap check and backfill the missing data.

        Args:
            start, end: aware datetimes bounding the window to check.
            task_codes: comma-separated task codes; None/empty means all.
            include_mismatch: also re-insert content-mismatching records.
            page_size: API page size.
            chunk_size: DB batch size forwarded to the gap check.
            content_sample_limit: cap on mismatch samples (default 10000).

        Returns:
            Summary dict with totals and per-task detail entries.
        """
        self.logger.info("数据补全开始 起始=%s 结束=%s", start.isoformat(), end.isoformat())
        # Derive the gap-check window granularity from the requested range:
        # whole days when the range spans >= 1 day, otherwise hours.
        total_seconds = max(0, int((end - start).total_seconds()))
        if total_seconds >= 86400:
            window_days = max(1, total_seconds // 86400)
            window_hours = 0
        else:
            window_days = 0
            window_hours = max(1, total_seconds // 3600 or 1)
        # Run the gap check
        self.logger.info("正在执行缺失检查...")
        gap_result = run_gap_check(
            cfg=self.cfg,
            start=start,
            end=end,
            window_days=window_days,
            window_hours=window_hours,
            page_size=page_size,
            chunk_size=chunk_size,
            sample_limit=10000,  # collect (effectively) all missing samples
            sleep_per_window=0,
            sleep_per_page=0,
            task_codes=task_codes or "",
            from_cutoff=False,
            cutoff_overlap_hours=24,
            allow_small_window=True,
            logger=self.logger,
            compare_content=include_mismatch,
            content_sample_limit=content_sample_limit or 10000,
        )
        total_missing = gap_result.get("total_missing", 0)
        total_mismatch = gap_result.get("total_mismatch", 0)
        if total_missing == 0 and (not include_mismatch or total_mismatch == 0):
            self.logger.info("Data complete: no missing/mismatch records")
            return {"backfilled": 0, "errors": 0, "details": []}
        if include_mismatch:
            self.logger.info("Missing/mismatch check done missing=%s mismatch=%s", total_missing, total_mismatch)
        else:
            self.logger.info("Missing check done missing=%s", total_missing)
        results = []
        total_backfilled = 0
        total_errors = 0
        for task_result in gap_result.get("results", []):
            task_code = task_result.get("task_code")
            missing = task_result.get("missing", 0)
            missing_samples = task_result.get("missing_samples", [])
            mismatch = task_result.get("mismatch", 0) if include_mismatch else 0
            mismatch_samples = task_result.get("mismatch_samples", []) if include_mismatch else []
            target_samples = list(missing_samples) + list(mismatch_samples)
            if missing == 0 and mismatch == 0:
                continue
            self.logger.info(
                "Start backfill task task=%s missing=%s mismatch=%s samples=%s",
                task_code, missing, mismatch, len(target_samples)
            )
            try:
                backfilled = self._backfill_task(
                    task_code=task_code,
                    table=task_result.get("table"),
                    pk_columns=task_result.get("pk_columns", []),
                    pk_samples=target_samples,
                    start=start,
                    end=end,
                    page_size=page_size,
                    chunk_size=chunk_size,
                )
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": backfilled,
                    "error": None,
                })
                total_backfilled += backfilled
            except Exception as exc:
                # One failing task must not abort the rest; record and continue.
                self.logger.exception("补全失败 任务=%s", task_code)
                results.append({
                    "task_code": task_code,
                    "missing": missing,
                    "mismatch": mismatch,
                    "backfilled": 0,
                    "error": str(exc),
                })
                total_errors += 1
        self.logger.info(
            "数据补全完成 总缺失=%s 已补全=%s 错误数=%s",
            total_missing, total_backfilled, total_errors
        )
        return {
            "total_missing": total_missing,
            "total_mismatch": total_mismatch,
            "backfilled": total_backfilled,
            "errors": total_errors,
            "details": results,
        }

    def _backfill_task(
        self,
        *,
        task_code: str,
        table: str,
        pk_columns: List[str],
        pk_samples: List[Dict],
        start: datetime,
        end: datetime,
        page_size: int,
        chunk_size: int,
    ) -> int:
        """Backfill the missing rows for one task.

        Re-fetches the task's API pages over [start, end], keeps only records
        whose PK tuple appears in *pk_samples*, and inserts them. Commits per
        inserted page; rolls back and re-raises on failure.

        Returns:
            Number of rows backfilled (or that would be, under dry_run).
        """
        self._ensure_db()
        spec = _get_spec(task_code)
        if not spec:
            self.logger.warning("未找到任务规格 任务=%s", task_code)
            return 0
        if not pk_columns:
            pk_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=False)
        conflict_columns = _get_table_pk_columns(self.db.conn, table, include_content_hash=True)
        if not conflict_columns:
            conflict_columns = pk_columns
        if not pk_columns:
            self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
            return 0
        # Collect the missing PK tuples reported by the gap check.
        missing_pks: Set[Tuple] = set()
        for sample in pk_samples:
            pk_tuple = tuple(sample.get(col) for col in pk_columns)
            if all(v is not None for v in pk_tuple):
                missing_pks.add(pk_tuple)
        if not missing_pks:
            self.logger.info("无缺失主键 任务=%s", task_code)
            return 0
        self.logger.info(
            "开始获取数据 任务=%s 缺失主键数=%s",
            task_code, len(missing_pks)
        )
        # Fetch from the API and filter down to the missing records.
        params = self._build_params(spec, start, end)
        backfilled = 0
        cols_info = _get_table_columns(self.db.conn, table)
        db_json_cols_lower = {
            c[0].lower() for c in cols_info
            if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }
        # NOTE(review): unused — _insert_records derives its own column list.
        col_names = [c[0] for c in cols_info]
        # End the read-only transaction so a long API pull does not trip the
        # idle-in-transaction timeout.
        try:
            self.db.conn.commit()
        except Exception:
            self.db.conn.rollback()
        try:
            for page_no, records, _, response_payload in self.api.iter_paginated(
                endpoint=spec.endpoint,
                params=params,
                page_size=page_size,
                data_path=spec.data_path,
                list_key=spec.list_key,
            ):
                # Keep only the records whose PK is in the missing set.
                records_to_insert = []
                for rec in records:
                    if not isinstance(rec, dict):
                        continue
                    pk_tuple = _pk_tuple_from_record(rec, pk_columns)
                    if pk_tuple and pk_tuple in missing_pks:
                        records_to_insert.append(rec)
                if not records_to_insert:
                    continue
                # Insert the missing records.
                if self.dry_run:
                    backfilled += len(records_to_insert)
                    self.logger.info(
                        "模拟运行 任务=%s 页=%s 将插入=%s",
                        task_code, page_no, len(records_to_insert)
                    )
                else:
                    inserted = self._insert_records(
                        table=table,
                        records=records_to_insert,
                        cols_info=cols_info,
                        pk_columns=pk_columns,
                        conflict_columns=conflict_columns,
                        db_json_cols_lower=db_json_cols_lower,
                    )
                    backfilled += inserted
                    # Commit per page to avoid long transactions and
                    # idle-in-transaction timeouts.
                    self.db.conn.commit()
                    self.logger.info(
                        "已插入 任务=%s 页=%s 数量=%s",
                        task_code, page_no, inserted
                    )
            if not self.dry_run:
                self.db.conn.commit()
            self.logger.info("任务补全完成 任务=%s 已补全=%s", task_code, backfilled)
            return backfilled
        except Exception:
            self.db.conn.rollback()
            raise

    def _build_params(
        self,
        spec: OdsTaskSpec,
        start: datetime,
        end: datetime,
    ) -> Dict:
        """Build the API request parameters for *spec* over [start, end]."""
        base: Dict[str, Any] = {}
        if spec.include_site_id:
            # This endpoint expects siteId as a list; the others take a scalar.
            if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
                base["siteId"] = [self.store_id]
            else:
                base["siteId"] = self.store_id
        if spec.requires_window and spec.time_fields:
            start_key, end_key = spec.time_fields
            base[start_key] = TypeParser.format_timestamp(start, self.tz)
            base[end_key] = TypeParser.format_timestamp(end, self.tz)
        # Merge the shared config parameters; per-call values win over them,
        # and spec.extra_params wins over everything.
        common = self.cfg.get("api.params", {}) or {}
        if isinstance(common, dict):
            merged = {**common, **base}
        else:
            merged = base
        merged.update(spec.extra_params or {})
        return merged

    def _insert_records(
        self,
        *,
        table: str,
        records: List[Dict],
        cols_info: List[Tuple[str, str, str]],
        pk_columns: List[str],
        conflict_columns: List[str],
        db_json_cols_lower: Set[str],
    ) -> int:
        """Insert *records* into *table* with ON CONFLICT DO NOTHING.

        Does not commit — the caller owns the transaction. Records missing a
        business PK value are skipped. Returns the number of rows actually
        inserted (conflicting rows are not counted).
        """
        if not records:
            return 0
        col_names = [c[0] for c in cols_info]
        needs_content_hash = any(c[0].lower() == "content_hash" for c in cols_info)
        quoted_cols = ", ".join(f'"{c}"' for c in col_names)
        sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
        conflict_cols = conflict_columns or pk_columns
        if conflict_cols:
            pk_clause = ", ".join(f'"{c}"' for c in conflict_cols)
            sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
        now = datetime.now(self.tz)
        json_dump = lambda v: json.dumps(v, ensure_ascii=False)
        params: List[Tuple] = []
        for rec in records:
            merged_rec = _merge_record_layers(rec)
            # Skip records lacking any business PK value (content_hash excluded).
            if pk_columns:
                missing_pk = False
                for pk in pk_columns:
                    if str(pk).lower() == "content_hash":
                        continue
                    pk_val = _get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
                        break
                if missing_pk:
                    continue
            content_hash = None
            if needs_content_hash:
                content_hash = BaseOdsTask._compute_content_hash(
                    merged_rec, include_fetched_at=False
                )
            row_vals: List[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
                if col_lower == "payload":
                    # The raw (unmerged) record is stored as the JSON payload.
                    row_vals.append(Json(rec, dumps=json_dump))
                    continue
                if col_lower == "source_file":
                    row_vals.append("backfill")
                    continue
                if col_lower == "source_endpoint":
                    row_vals.append("backfill")
                    continue
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue
                if col_lower == "content_hash":
                    row_vals.append(content_hash)
                    continue
                value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
                    row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                    continue
                row_vals.append(_cast_value(value, data_type))
            params.append(tuple(row_vals))
        if not params:
            return 0
        inserted = 0
        with self.db.conn.cursor() as cur:
            for i in range(0, len(params), 200):
                chunk = params[i:i + 200]
                execute_values(cur, sql, chunk, page_size=len(chunk))
                if cur.rowcount is not None and cur.rowcount > 0:
                    inserted += int(cur.rowcount)
        return inserted
def run_backfill(
    *,
    cfg: AppConfig,
    start: datetime,
    end: datetime,
    task_codes: Optional[str] = None,
    include_mismatch: bool = False,
    dry_run: bool = False,
    page_size: int = 200,
    chunk_size: int = 500,
    content_sample_limit: int | None = None,
    logger: logging.Logger,
) -> Dict[str, Any]:
    """Run the missing-data backfill over [start, end].

    Args:
        cfg: application configuration.
        start: window start (aware datetime).
        end: window end (aware datetime).
        task_codes: comma-separated task codes; None means all tasks.
        include_mismatch: also re-insert content-mismatching records.
        dry_run: preview only, write nothing.
        page_size: API page size.
        chunk_size: database batch size.
        content_sample_limit: cap on mismatch samples collected.
        logger: logger to report progress to.

    Returns:
        The backfill summary dict produced by MissingDataBackfiller.
    """
    worker = MissingDataBackfiller(cfg, logger, dry_run)
    try:
        result = worker.backfill_from_gap_check(
            start=start,
            end=end,
            task_codes=task_codes,
            include_mismatch=include_mismatch,
            page_size=page_size,
            chunk_size=chunk_size,
            content_sample_limit=content_sample_limit,
        )
    finally:
        # Always release the DB connection, even when the backfill raises.
        worker.close()
    return result
def main() -> int:
    """CLI entry point: parse arguments, configure logging, run the backfill.

    Returns 0 on completion; per-task errors are reported in the log, not via
    the exit code.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="补全丢失的 ODS 数据")
    ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)")
    ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)")
    ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)")
    ap.add_argument("--include-mismatch", action="store_true", help="同时补全内容不一致的记录")
    ap.add_argument("--content-sample-limit", type=int, default=None, help="不一致样本上限 (默认: 10000)")
    ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入")
    ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)")
    ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)")
    ap.add_argument("--log-file", default="", help="日志文件路径")
    ap.add_argument("--log-dir", default="", help="日志目录")
    ap.add_argument("--log-level", default="INFO", help="日志级别 (默认: INFO)")
    ap.add_argument("--no-log-console", action="store_true", help="禁用控制台日志")
    args = ap.parse_args()
    # Resolve the log destination: explicit --log-file wins, otherwise a
    # timestamped file under --log-dir (default: <project>/logs).
    log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
    log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
    log_console = not args.no_log_console
    with configure_logging(
        "backfill_missing",
        log_file,
        level=args.log_level,
        console=log_console,
        tee_std=True,
    ) as logger:
        cfg = AppConfig.load({})
        tz = ZoneInfo(cfg.get("app.timezone", "Asia/Shanghai"))
        start = _parse_dt(args.start, tz)
        # Default end: "now" in the configured timezone when --end is omitted.
        end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)
        result = run_backfill(
            cfg=cfg,
            start=start,
            end=end,
            task_codes=args.task_codes or None,
            include_mismatch=args.include_mismatch,
            dry_run=args.dry_run,
            page_size=args.page_size,
            chunk_size=args.chunk_size,
            content_sample_limit=args.content_sample_limit,
            logger=logger,
        )
        logger.info("=" * 60)
        logger.info("补全完成!")
        logger.info(" 总丢失: %s", result.get("total_missing", 0))
        if args.include_mismatch:
            logger.info(" 总不一致: %s", result.get("total_mismatch", 0))
        logger.info(" 已补全: %s", result.get("backfilled", 0))
        logger.info(" 错误数: %s", result.get("errors", 0))
        logger.info("=" * 60)
        # Emit the per-task details.
        for detail in result.get("details", []):
            if detail.get("error"):
                logger.error(
                    " %s: 丢失=%s 不一致=%s 补全=%s 错误=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                    detail.get("error"),
                )
            elif detail.get("backfilled", 0) > 0:
                logger.info(
                    " %s: 丢失=%s 不一致=%s 补全=%s",
                    detail.get("task_code"),
                    detail.get("missing"),
                    detail.get("mismatch", 0),
                    detail.get("backfilled"),
                )
    return 0
if __name__ == "__main__":
    # Exit with main()'s return code when run as a script.
    raise SystemExit(main())

View File

@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
"""
Deduplicate ODS snapshots by (business PK, content_hash).
Keep the latest row by fetched_at (tie-breaker: ctid desc).
Usage:
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --schema billiards_ods
PYTHONPATH=. python -m scripts.repair.dedupe_ods_snapshots --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Iterable, Sequence
import psycopg2
# This module lives at scripts/repair/ (see the usage examples in the module
# docstring: `python -m scripts.repair.dedupe_ods_snapshots`), so the
# repository root is three levels up: parents[0] is scripts/repair,
# parents[1] is scripts, parents[2] is the project root. parents[1] would put
# scripts/ on sys.path and make `config`/`database` unimportable.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _quote_ident(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
return [r[0] for r in cur.fetchall()]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_snapshot_dedupe_{ts}.json"
def _print_progress(
table_label: str,
deleted: int,
total: int,
errors: int,
) -> None:
if total:
msg = f"[{table_label}] deleted {deleted}/{total} errors={errors}"
else:
msg = f"[{table_label}] deleted {deleted} errors={errors}"
print(msg, flush=True)
def _count_duplicates(conn, schema: str, table: str, key_cols: Sequence[str]) -> int:
keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
sql = f"""
SELECT COUNT(*) FROM (
SELECT 1
FROM (
SELECT ROW_NUMBER() OVER (
PARTITION BY {keys_sql}
ORDER BY fetched_at DESC NULLS LAST, ctid DESC
) AS rn
FROM {table_sql}
) t
WHERE rn > 1
) s
"""
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _delete_duplicate_batch(
conn,
schema: str,
table: str,
key_cols: Sequence[str],
batch_size: int,
) -> int:
keys_sql = ", ".join(_quote_ident(c) for c in [*key_cols, "content_hash"])
table_sql = f"{_quote_ident(schema)}.{_quote_ident(table)}"
sql = f"""
WITH dupes AS (
SELECT ctid
FROM (
SELECT ctid,
ROW_NUMBER() OVER (
PARTITION BY {keys_sql}
ORDER BY fetched_at DESC NULLS LAST, ctid DESC
) AS rn
FROM {table_sql}
) s
WHERE rn > 1
LIMIT %s
)
DELETE FROM {table_sql} t
USING dupes d
WHERE t.ctid = d.ctid
RETURNING 1
"""
with conn.cursor() as cur:
cur.execute(sql, (int(batch_size),))
rows = cur.fetchall()
return len(rows or [])
def main() -> int:
    """CLI entry point: scan ODS tables and delete duplicate snapshot rows.

    Tables without both ``content_hash`` and ``fetched_at`` columns, or
    without a primary key, are skipped. Deletion happens in autocommit mode,
    batch by batch, and a JSON report is written at the end.
    Always returns 0; per-table errors are counted in the report.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Deduplicate ODS snapshot rows by PK+content_hash")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=1000, help="delete batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N deletions")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute duplicate counts")
    args = ap.parse_args()
    cfg = AppConfig.load({})
    db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    # Clear any transaction the connection may have opened, then switch to
    # autocommit so each delete batch commits on its own.
    try:
        db.conn.rollback()
    except Exception:
        pass
    db.conn.autocommit = True
    tables = _fetch_tables(db.conn, args.schema)
    if args.tables.strip():
        # Restrict to the requested tables (silently dropping unknown names).
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]
    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_duplicates": 0,
            "deleted_rows": 0,
            "error_rows": 0,
            "skipped_tables": 0,
        },
    }
    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(db.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        if "content_hash" not in cols_lower or "fetched_at" not in cols_lower:
            print(f"[{table_label}] skip: missing content_hash/fetched_at", flush=True)
            report["summary"]["skipped_tables"] += 1
            continue
        key_cols = _fetch_pk_columns(db.conn, args.schema, table)
        if not key_cols:
            print(f"[{table_label}] skip: missing primary key", flush=True)
            report["summary"]["skipped_tables"] += 1
            continue
        total_dupes = _count_duplicates(db.conn, args.schema, table, key_cols)
        print(f"[{table_label}] duplicates={total_dupes}", flush=True)
        deleted = 0
        errors = 0
        if not args.dry_run and total_dupes:
            # Delete in batches until a batch removes nothing (or errors out).
            while True:
                try:
                    batch_deleted = _delete_duplicate_batch(
                        db.conn,
                        args.schema,
                        table,
                        key_cols,
                        args.batch_size,
                    )
                except psycopg2.Error:
                    # NOTE(review): the DB error is counted but not logged;
                    # the loop simply stops for this table.
                    errors += 1
                    break
                if batch_deleted <= 0:
                    break
                deleted += batch_deleted
                if args.progress_every and deleted % int(args.progress_every) == 0:
                    _print_progress(table_label, deleted, total_dupes, errors)
            # Final progress line unless the last batch already printed one.
            if deleted and (not args.progress_every or deleted % int(args.progress_every) != 0):
                _print_progress(table_label, deleted, total_dupes, errors)
        report["tables"].append(
            {
                "table": table_label,
                "duplicate_rows": total_dupes,
                "deleted_rows": deleted,
                "error_rows": errors,
            }
        )
        report["summary"]["checked_tables"] += 1
        report["summary"]["total_duplicates"] += total_dupes
        report["summary"]["deleted_rows"] += deleted
        report["summary"]["error_rows"] += errors
    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    # NOTE(review): the DB connection is never closed; it is released when
    # the process exits.
    return 0
if __name__ == "__main__":
    # Exit with main()'s return code when run as a script.
    raise SystemExit(main())

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""修复 dim_assistant 表中的 user_id 字段"""
import sys
sys.path.insert(0, '.')
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
config = AppConfig.load()
db_conn = DatabaseConnection(config.config['db']['dsn'])
db = DatabaseOperations(db_conn)
print("=== 修复 dim_assistant.user_id ===")
# 方案:从 ODS 表更新 DWD 表的 user_id
# 通过 id (ODS) = assistant_id (DWD) 关联
# 1. 先检查当前状态
print("\n修复前:")
sql_before = """
SELECT
COUNT(*) as total,
COUNT(CASE WHEN user_id > 0 THEN 1 END) as has_user_id
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
"""
r = dict(db.query(sql_before)[0])
print(f" 总记录: {r['total']}, 有user_id: {r['has_user_id']}")
# 2. 执行更新
print("\n执行更新...")
update_sql = """
UPDATE billiards_dwd.dim_assistant d
SET user_id = o.user_id
FROM (
SELECT DISTINCT ON (id) id, user_id
FROM billiards_ods.assistant_accounts_master
WHERE user_id > 0
ORDER BY id, fetched_at DESC
) o
WHERE d.assistant_id = o.id
AND (d.user_id IS NULL OR d.user_id = 0)
"""
with db_conn.conn.cursor() as cur:
cur.execute(update_sql)
updated = cur.rowcount
print(f" 更新了 {updated} 条记录")
db_conn.conn.commit()
# 3. 检查修复后状态
print("\n修复后:")
r2 = dict(db.query(sql_before)[0])
print(f" 总记录: {r2['total']}, 有user_id: {r2['has_user_id']}")
# 4. 显示样本数据
print("\n样本数据:")
sql_sample = """
SELECT assistant_id, user_id, assistant_no, nickname
FROM billiards_dwd.dim_assistant
WHERE scd2_is_current = 1
ORDER BY assistant_no::int
LIMIT 10
"""
for row in db.query(sql_sample):
r = dict(row)
print(f" assistant_id={r['assistant_id']}, user_id={r['user_id']}, no={r['assistant_no']}, nickname={r['nickname']}")
# 5. 验证与服务日志的关联
print("\n验证与服务日志的关联:")
sql_verify = """
SELECT
COUNT(DISTINCT s.user_id) as service_unique_users,
COUNT(DISTINCT CASE WHEN d.assistant_id IS NOT NULL THEN s.user_id END) as matched_users
FROM billiards_dwd.dwd_assistant_service_log s
LEFT JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id AND d.scd2_is_current = 1
WHERE s.is_delete = 0 AND s.user_id > 0
"""
r3 = dict(db.query(sql_verify)[0])
print(f" 服务日志唯一user_id: {r3['service_unique_users']}")
print(f" 能匹配到dim_assistant: {r3['matched_users']}")
match_rate = r3['matched_users'] / r3['service_unique_users'] * 100 if r3['service_unique_users'] > 0 else 0
print(f" 匹配率: {match_rate:.1f}%")
db_conn.close()
print("\n完成!")

View File

@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""
Repair ODS content_hash values by recomputing from payload.
Usage:
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --schema billiards_ods
PYTHONPATH=. python -m scripts.repair.repair_ods_content_hash --tables member_profiles,orders
"""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Sequence
import psycopg2
from psycopg2.extras import RealDictCursor
# This module lives at scripts/repair/ (see the usage examples in the module
# docstring: `python -m scripts.repair.repair_ods_content_hash`), so the
# repository root is three levels up: parents[0] is scripts/repair,
# parents[1] is scripts, parents[2] is the project root. parents[1] would put
# scripts/ on sys.path and make `config`/`database`/`tasks` unimportable.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from database.connection import DatabaseConnection
from tasks.ods.ods_tasks import BaseOdsTask
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _fetch_tables(conn, schema: str) -> list[str]:
sql = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
with conn.cursor() as cur:
cur.execute(sql, (schema,))
return [r[0] for r in cur.fetchall()]
def _fetch_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c]
def _fetch_pk_columns(conn, schema: str, table: str) -> list[str]:
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, table))
cols = [r[0] for r in cur.fetchall()]
return [c for c in cols if c.lower() != "content_hash"]
def _fetch_row_count(conn, schema: str, table: str) -> int:
sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
with conn.cursor() as cur:
cur.execute(sql)
row = cur.fetchone()
return int(row[0] if row else 0)
def _iter_rows(
    conn,
    schema: str,
    table: str,
    select_cols: Sequence[str],
    batch_size: int,
) -> Iterable[dict]:
    """Stream rows of schema.table as dicts via a named (server-side) cursor.

    ``ctid`` is selected unquoted (it is a system column); all other columns
    are double-quoted. ``itersize`` bounds client-side memory per fetch.
    """
    rendered = [col if col == "ctid" else f'"{col}"' for col in select_cols]
    sql = f'SELECT {", ".join(rendered)} FROM "{schema}"."{table}"'
    with conn.cursor(name=f"ods_hash_fix_{table}", cursor_factory=RealDictCursor) as cur:
        cur.itersize = max(1, int(batch_size or 500))
        cur.execute(sql)
        yield from cur
def _build_report_path(out_arg: str | None) -> Path:
if out_arg:
return Path(out_arg)
reports_dir = PROJECT_ROOT / "reports"
reports_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return reports_dir / f"ods_content_hash_repair_{ts}.json"
def _print_progress(
table_label: str,
processed: int,
total: int,
updated: int,
skipped: int,
conflicts: int,
errors: int,
missing_hash: int,
invalid_payload: int,
) -> None:
if total:
msg = (
f"[{table_label}] checked {processed}/{total} "
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
else:
msg = (
f"[{table_label}] checked {processed} "
f"updated={updated} skipped={skipped} conflicts={conflicts} errors={errors} "
f"missing_hash={missing_hash} invalid_payload={invalid_payload}"
)
print(msg, flush=True)
def main() -> int:
    """Scan ODS tables and repair ``content_hash`` values from ``payload``.

    For every base table in the target schema that has both a ``payload``
    and a ``content_hash`` column, stream all rows, recompute the compare
    hash from the payload and UPDATE any row whose stored hash differs.
    Writes a per-table JSON report and returns 0 as the process exit code.
    """
    _reconfigure_stdout_utf8()
    ap = argparse.ArgumentParser(description="Repair ODS content_hash using payload")
    ap.add_argument("--schema", default="billiards_ods", help="ODS schema name")
    ap.add_argument("--tables", default="", help="comma-separated table names (optional)")
    ap.add_argument("--batch-size", type=int, default=500, help="DB fetch batch size")
    ap.add_argument("--progress-every", type=int, default=100, help="print progress every N rows")
    ap.add_argument("--sample-limit", type=int, default=10, help="sample conflicts per table")
    ap.add_argument("--out", default="", help="output report JSON path")
    ap.add_argument("--dry-run", action="store_true", help="only compute stats, do not update")
    args = ap.parse_args()
    cfg = AppConfig.load({})
    # Two connections: db_read streams rows through a server-side cursor
    # while db_write issues row-by-row autocommitted UPDATEs, so a failed
    # write cannot invalidate the open read cursor.
    db_read = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    db_write = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
    try:
        # Clear any transaction the connection may have opened before
        # switching to autocommit (changing the flag mid-transaction fails).
        db_write.conn.rollback()
    except Exception:
        pass
    db_write.conn.autocommit = True
    tables = _fetch_tables(db_read.conn, args.schema)
    if args.tables.strip():
        # Optional whitelist: restrict the scan to the named tables only.
        whitelist = {t.strip() for t in args.tables.split(",") if t.strip()}
        tables = [t for t in tables if t in whitelist]
    report = {
        "schema": args.schema,
        "tables": [],
        "summary": {
            "total_tables": len(tables),
            "checked_tables": 0,
            "total_rows": 0,
            "checked_rows": 0,
            "updated_rows": 0,
            "skipped_rows": 0,
            "conflict_rows": 0,
            "error_rows": 0,
            "missing_hash_rows": 0,
            "invalid_payload_rows": 0,
        },
    }
    for table in tables:
        table_label = f"{args.schema}.{table}"
        cols = _fetch_columns(db_read.conn, args.schema, table)
        cols_lower = {c.lower() for c in cols}
        # Only tables carrying both payload and content_hash are repairable.
        if "payload" not in cols_lower or "content_hash" not in cols_lower:
            print(f"[{table_label}] skip: missing payload/content_hash", flush=True)
            continue
        total = _fetch_row_count(db_read.conn, args.schema, table)
        pk_cols = _fetch_pk_columns(db_read.conn, args.schema, table)
        # ctid identifies the physical row for the UPDATE; pk columns are
        # fetched only to label conflict samples in the report.
        select_cols = ["ctid", "content_hash", "payload", *pk_cols]
        # Per-table counters.
        processed = 0
        updated = 0
        skipped = 0
        conflicts = 0
        errors = 0
        missing_hash = 0
        invalid_payload = 0
        samples: list[dict[str, Any]] = []
        print(f"[{table_label}] start: total_rows={total}", flush=True)
        for row in _iter_rows(db_read.conn, args.schema, table, select_cols, args.batch_size):
            processed += 1
            content_hash = row.get("content_hash")
            payload = row.get("payload")
            recomputed = BaseOdsTask._compute_compare_hash_from_payload(payload)
            row_ctid = row.get("ctid")
            # Diagnostics counters are independent of the update decision.
            if not content_hash:
                missing_hash += 1
            if not recomputed:
                invalid_payload += 1
            if not recomputed:
                # Payload could not be hashed: nothing to compare against.
                skipped += 1
            elif content_hash == recomputed:
                # Stored hash already correct.
                skipped += 1
            else:
                if args.dry_run:
                    # Count what WOULD be updated without touching the DB.
                    updated += 1
                else:
                    try:
                        with db_write.conn.cursor() as cur:
                            cur.execute(
                                f'UPDATE "{args.schema}"."{table}" SET content_hash = %s WHERE ctid = %s',
                                (recomputed, row_ctid),
                            )
                        updated += 1
                    except psycopg2.errors.UniqueViolation:
                        # Another row already owns the recomputed hash;
                        # record a bounded number of samples for the report.
                        conflicts += 1
                        if len(samples) < max(0, int(args.sample_limit or 0)):
                            sample = {k: row.get(k) for k in pk_cols}
                            sample["content_hash"] = content_hash
                            sample["recomputed_hash"] = recomputed
                            samples.append(sample)
                    except psycopg2.Error:
                        errors += 1
            if args.progress_every and processed % int(args.progress_every) == 0:
                _print_progress(
                    table_label,
                    processed,
                    total,
                    updated,
                    skipped,
                    conflicts,
                    errors,
                    missing_hash,
                    invalid_payload,
                )
        # Emit a final progress line unless the last iteration just did.
        if processed and (not args.progress_every or processed % int(args.progress_every) != 0):
            _print_progress(
                table_label,
                processed,
                total,
                updated,
                skipped,
                conflicts,
                errors,
                missing_hash,
                invalid_payload,
            )
        report["tables"].append(
            {
                "table": table_label,
                "total_rows": total,
                "checked_rows": processed,
                "updated_rows": updated,
                "skipped_rows": skipped,
                "conflict_rows": conflicts,
                "error_rows": errors,
                "missing_hash_rows": missing_hash,
                "invalid_payload_rows": invalid_payload,
                "conflict_samples": samples,
            }
        )
        # Roll the per-table counters into the overall summary.
        report["summary"]["checked_tables"] += 1
        report["summary"]["total_rows"] += total
        report["summary"]["checked_rows"] += processed
        report["summary"]["updated_rows"] += updated
        report["summary"]["skipped_rows"] += skipped
        report["summary"]["conflict_rows"] += conflicts
        report["summary"]["error_rows"] += errors
        report["summary"]["missing_hash_rows"] += missing_hash
        report["summary"]["invalid_payload_rows"] += invalid_payload
    out_path = _build_report_path(args.out)
    out_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[REPORT] {out_path}", flush=True)
    return 0
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Create performance indexes for integrity verification and run ANALYZE.
Usage:
python -m scripts.tune_integrity_indexes
python -m scripts.tune_integrity_indexes --dry-run
"""
from __future__ import annotations
import argparse
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Sequence, Set, Tuple
import psycopg2
from psycopg2 import sql
from config.settings import AppConfig
# Candidate timestamp columns considered when planning time-based indexes.
# For each table, every candidate actually present in the table gets an index.
TIME_CANDIDATES = (
    "pay_time",
    "create_time",
    "start_use_time",
    "scd2_start_time",
    "calc_time",
    "order_date",
    "fetched_at",
)
@dataclass(frozen=True)
class IndexPlan:
    """Immutable description of one index to create."""

    schema: str  # schema owning the target table
    table: str  # table to index
    index_name: str  # precomputed identifier, capped at PostgreSQL's 63 chars
    columns: Tuple[str, ...]  # indexed columns, in order
def _short_index_name(table: str, tag: str, columns: Sequence[str]) -> str:
raw = f"idx_{table}_{tag}_{'_'.join(columns)}"
if len(raw) <= 63:
return raw
digest = hashlib.md5(raw.encode("utf-8")).hexdigest()[:8]
shortened = f"idx_{table}_{tag}_{digest}"
return shortened[:63]
def _load_table_columns(cur, schema: str, table: str) -> Set[str]:
cur.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
""",
(schema, table),
)
return {r[0] for r in cur.fetchall()}
def _load_pk_columns(cur, schema: str, table: str) -> List[str]:
cur.execute(
"""
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
AND tc.table_name = kcu.table_name
WHERE tc.table_schema = %s
AND tc.table_name = %s
AND tc.constraint_type = 'PRIMARY KEY'
ORDER BY kcu.ordinal_position
""",
(schema, table),
)
return [r[0] for r in cur.fetchall()]
def _load_tables(cur, schema: str) -> List[str]:
cur.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s
AND table_type = 'BASE TABLE'
ORDER BY table_name
""",
(schema,),
)
return [r[0] for r in cur.fetchall()]
def _plan_indexes(cur, schema: str, table: str) -> List[IndexPlan]:
    """Plan the indexes to create for one table, deduplicated by index name.

    ODS tables get a ``fetched_at`` index plus a composite
    ``(fetched_at, pk...)`` index; DWD tables get a ``(pk..., scd2_is_current)``
    index plus one index (and a composite with the pk) for each time column
    from TIME_CANDIDATES actually present on the table.  Composite plans are
    only emitted for small primary keys whose columns all exist.
    """
    plans: List[IndexPlan] = []
    cols = _load_table_columns(cur, schema, table)
    pk_cols = _load_pk_columns(cur, schema, table)
    if schema == "billiards_ods":
        if "fetched_at" in cols:
            plans.append(
                IndexPlan(
                    schema=schema,
                    table=table,
                    index_name=_short_index_name(table, "fetched_at", ("fetched_at",)),
                    columns=("fetched_at",),
                )
            )
            # Composite (fetched_at, pk...) — only when the pk is small and
            # every pk column really exists on the table.
            if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
                comp_cols = ("fetched_at", *pk_cols)
                plans.append(
                    IndexPlan(
                        schema=schema,
                        table=table,
                        index_name=_short_index_name(table, "fetched_pk", comp_cols),
                        columns=comp_cols,
                    )
                )
    if schema == "billiards_dwd":
        # Supports SCD2 "current row" lookups by primary key.
        if pk_cols and "scd2_is_current" in cols and len(pk_cols) <= 4:
            comp_cols = (*pk_cols, "scd2_is_current")
            plans.append(
                IndexPlan(
                    schema=schema,
                    table=table,
                    index_name=_short_index_name(table, "pk_current", comp_cols),
                    columns=comp_cols,
                )
            )
        for tcol in TIME_CANDIDATES:
            if tcol in cols:
                plans.append(
                    IndexPlan(
                        schema=schema,
                        table=table,
                        index_name=_short_index_name(table, "time", (tcol,)),
                        columns=(tcol,),
                    )
                )
                if pk_cols and len(pk_cols) <= 3 and all(c in cols for c in pk_cols):
                    comp_cols = (tcol, *pk_cols)
                    plans.append(
                        IndexPlan(
                            schema=schema,
                            table=table,
                            index_name=_short_index_name(table, "time_pk", comp_cols),
                            columns=comp_cols,
                        )
                    )
    # 按索引名去重 — dedupe by index name (later plan wins, order preserved).
    dedup: Dict[str, IndexPlan] = {}
    for p in plans:
        dedup[p.index_name] = p
    return list(dedup.values())
def _create_index(cur, plan: IndexPlan) -> None:
    """Execute CREATE INDEX IF NOT EXISTS for ``plan`` with quoted identifiers."""
    column_list = sql.SQL(", ").join(sql.Identifier(name) for name in plan.columns)
    stmt = sql.SQL("CREATE INDEX IF NOT EXISTS {idx} ON {sch}.{tbl} ({cols})").format(
        idx=sql.Identifier(plan.index_name),
        sch=sql.Identifier(plan.schema),
        tbl=sql.Identifier(plan.table),
        cols=column_list,
    )
    cur.execute(stmt)
def _analyze_table(cur, schema: str, table: str) -> None:
    """Run ANALYZE on ``schema.table`` to refresh planner statistics."""
    cur.execute(
        sql.SQL("ANALYZE {sch}.{tbl}").format(
            sch=sql.Identifier(schema),
            tbl=sql.Identifier(table),
        )
    )
def main() -> int:
    """Plan and create integrity-verification indexes, then ANALYZE.

    Scans the ``billiards_ods`` and ``billiards_dwd`` schemas, prints every
    planned index, creates them unless --dry-run, and (unless --skip-analyze)
    runs ANALYZE on each touched table.  All work happens in one transaction
    that is rolled back under --dry-run and committed otherwise.  Returns 0.
    """
    ap = argparse.ArgumentParser(description="Tune indexes for integrity verification.")
    ap.add_argument("--dry-run", action="store_true", help="Print planned SQL only.")
    ap.add_argument(
        "--skip-analyze",
        action="store_true",
        help="Create indexes but skip ANALYZE.",
    )
    args = ap.parse_args()
    cfg = AppConfig.load({})
    dsn = cfg.get("db.dsn")
    timeout_sec = int(cfg.get("db.connect_timeout_sec", 10) or 10)
    with psycopg2.connect(dsn, connect_timeout=timeout_sec) as conn:
        # Explicit transaction so --dry-run can roll everything back.
        conn.autocommit = False
        with conn.cursor() as cur:
            # Phase 1: plan all indexes across both schemas.
            all_plans: List[IndexPlan] = []
            for schema in ("billiards_ods", "billiards_dwd"):
                for table in _load_tables(cur, schema):
                    all_plans.extend(_plan_indexes(cur, schema, table))
            touched_tables: Set[Tuple[str, str]] = set()
            print(f"planned indexes: {len(all_plans)}")
            # Phase 2: print every plan; create only when not a dry run.
            for plan in all_plans:
                cols = ", ".join(plan.columns)
                print(f"[INDEX] {plan.schema}.{plan.table} ({cols}) -> {plan.index_name}")
                if not args.dry_run:
                    _create_index(cur, plan)
                    touched_tables.add((plan.schema, plan.table))
            # Phase 3: ANALYZE touched tables (or just list them in dry-run).
            if not args.skip_analyze:
                if args.dry_run:
                    for schema, table in sorted({(p.schema, p.table) for p in all_plans}):
                        print(f"[ANALYZE] {schema}.{table}")
                else:
                    for schema, table in sorted(touched_tables):
                        _analyze_table(cur, schema, table)
                        print(f"[ANALYZE] {schema}.{table}")
            if args.dry_run:
                conn.rollback()
                print("dry-run complete; transaction rolled back")
            else:
                conn.commit()
                print("index tuning complete")
    return 0
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())