This commit is contained in:
Neo
2026-01-27 22:45:50 +08:00
parent a6ad343092
commit 4c192e921c
476 changed files with 381543 additions and 5819 deletions

View File

@@ -0,0 +1 @@
# Script helpers package marker.

View File

@@ -0,0 +1,684 @@
# -*- coding: utf-8 -*-
"""
补全丢失的 ODS 数据
通过运行数据完整性检查,找出 API 与 ODS 之间的差异,
然后重新从 API 获取丢失的数据并插入 ODS。
用法:
python -m scripts.backfill_missing_data --start 2025-07-01 --end 2026-01-19
python -m scripts.backfill_missing_data --from-report reports/ods_gap_check_xxx.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2.extras import Json, execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods_tasks import ENABLED_ODS_CODES, ODS_TASK_SPECS, OdsTaskSpec
from scripts.check_ods_gaps import run_gap_check
from utils.logging_utils import build_log_path, configure_logging
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool = False) -> datetime:
raw = (value or "").strip()
if not raw:
raise ValueError("empty datetime")
has_time = any(ch in raw for ch in (":", "T"))
dt = dtparser.parse(raw)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=tz)
else:
dt = dt.astimezone(tz)
if not has_time:
dt = dt.replace(
hour=23 if is_end else 0,
minute=59 if is_end else 0,
second=59 if is_end else 0,
microsecond=0
)
return dt
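# Illustrative behaviour (hypothetical inputs, assuming tz=Asia/Taipei):
#   _parse_dt("2025-07-01", tz)               -> 2025-07-01 00:00:00+08:00
#   _parse_dt("2025-07-01", tz, is_end=True)  -> 2025-07-01 23:59:59+08:00
#   _parse_dt("2025-07-01 12:30", tz)         -> 2025-07-01 12:30:00+08:00 (explicit time kept as-is)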
def _get_spec(code: str) -> Optional[OdsTaskSpec]:
"""根据任务代码获取 ODS 任务规格"""
for spec in ODS_TASK_SPECS:
if spec.code == code:
return spec
return None
def _merge_record_layers(record: dict) -> dict:
"""展开嵌套的 data 层"""
merged = record
data_part = merged.get("data")
while isinstance(data_part, dict):
merged = {**data_part, **merged}
data_part = data_part.get("data")
settle_inner = merged.get("settleList")
if isinstance(settle_inner, dict):
merged = {**settle_inner, **merged}
return merged
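# Illustrative example with a hypothetical API record:
#   {"orderId": 1, "data": {"amount": 5, "data": {"tenantId": 9}}}
# merges into a flat dict containing tenantId, amount and orderId; because the merge is
# {**data_part, **merged}, keys from the outer layer win when both layers define them.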
def _get_value_case_insensitive(record: dict | None, col: str | None):
"""不区分大小写地获取值"""
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
def _normalize_pk_value(value):
"""规范化 PK 值"""
if value is None:
return None
if isinstance(value, str) and value.isdigit():
try:
return int(value)
except Exception:
return value
return value
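# Illustrative: "12345" -> 12345 while 12345 stays 12345, so string ids returned by the API
# and integer ids read back from the database compare equal inside the PK sets.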
def _pk_tuple_from_record(record: dict, pk_cols: List[str]) -> Optional[Tuple]:
"""从记录中提取 PK 元组"""
merged = _merge_record_layers(record)
values = []
for col in pk_cols:
val = _normalize_pk_value(_get_value_case_insensitive(merged, col))
if val is None or val == "":
return None
values.append(val)
return tuple(values)
def _get_table_pk_columns(conn, table: str) -> List[str]:
"""获取表的主键列"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [r[0] for r in cur.fetchall()]
def _get_table_columns(conn, table: str) -> List[Tuple[str, str, str]]:
"""获取表的所有列信息"""
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
def _fetch_existing_pk_set(
conn, table: str, pk_cols: List[str], pk_values: List[Tuple], chunk_size: int
) -> Set[Tuple]:
"""获取已存在的 PK 集合"""
if not pk_values:
return set()
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: Set[Tuple] = set()
with conn.cursor() as cur:
for i in range(0, len(pk_values), chunk_size):
chunk = pk_values[i:i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
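# Illustrative SQL generated for a hypothetical table with PK (site_id, order_settle_id):
#   SELECT t."site_id", t."order_settle_id" FROM <table> t
#   JOIN (VALUES %s) AS v("site_id", "order_settle_id")
#     ON t."site_id" = v."site_id" AND t."order_settle_id" = v."order_settle_id"
# execute_values expands %s into the chunk of candidate PK tuples, so only those keys are probed.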
def _cast_value(value, data_type: str):
"""类型转换"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _normalize_scalar(value):
"""规范化标量值"""
if value == "" or value == "{}" or value == "[]":
return None
return value
class MissingDataBackfiller:
"""丢失数据补全器"""
def __init__(
self,
cfg: AppConfig,
logger: logging.Logger,
dry_run: bool = False,
):
self.cfg = cfg
self.logger = logger
self.dry_run = dry_run
self.tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
self.store_id = int(cfg.get("app.store_id") or 0)
# API client
self.api = APIClient(
base_url=cfg["api"]["base_url"],
token=cfg["api"]["token"],
timeout=int(cfg["api"].get("timeout_sec") or 20),
retry_max=int(cfg["api"].get("retries", {}).get("max_attempts") or 3),
headers_extra=cfg["api"].get("headers_extra") or {},
)
# Database connection (DatabaseConnection sets autocommit=False at construction time)
self.db = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
def close(self):
"""关闭连接"""
if self.db:
self.db.close()
def backfill_from_gap_check(
self,
*,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
page_size: int = 200,
chunk_size: int = 500,
) -> Dict[str, Any]:
"""
运行 gap check 并补全丢失数据
Returns:
补全结果统计
"""
self.logger.info("数据补全开始 起始=%s 结束=%s", start.isoformat(), end.isoformat())
# 计算窗口大小
total_seconds = max(0, int((end - start).total_seconds()))
if total_seconds >= 86400:
window_days = max(1, total_seconds // 86400)
window_hours = 0
else:
window_days = 0
window_hours = max(1, total_seconds // 3600 or 1)
# Run the gap check
self.logger.info("Running gap check...")
gap_result = run_gap_check(
cfg=self.cfg,
start=start,
end=end,
window_days=window_days,
window_hours=window_hours,
page_size=page_size,
chunk_size=chunk_size,
sample_limit=10000,  # fetch all missing samples
sleep_per_window=0,
sleep_per_page=0,
task_codes=task_codes or "",
from_cutoff=False,
cutoff_overlap_hours=24,
allow_small_window=True,
logger=self.logger,
)
total_missing = gap_result.get("total_missing", 0)
if total_missing == 0:
self.logger.info("数据完整,无缺失记录")
return {"backfilled": 0, "errors": 0, "details": []}
self.logger.info("缺失检查完成 总缺失=%s", total_missing)
# 补全每个任务的丢失数据
results = []
total_backfilled = 0
total_errors = 0
for task_result in gap_result.get("results", []):
task_code = task_result.get("task_code")
missing = task_result.get("missing", 0)
missing_samples = task_result.get("missing_samples", [])
if missing == 0:
continue
self.logger.info(
"开始补全任务 任务=%s 缺失=%s 样本数=%s",
task_code, missing, len(missing_samples)
)
try:
backfilled = self._backfill_task(
task_code=task_code,
table=task_result.get("table"),
pk_columns=task_result.get("pk_columns", []),
missing_samples=missing_samples,
start=start,
end=end,
page_size=page_size,
chunk_size=chunk_size,
)
results.append({
"task_code": task_code,
"missing": missing,
"backfilled": backfilled,
"error": None,
})
total_backfilled += backfilled
except Exception as exc:
self.logger.exception("补全失败 任务=%s", task_code)
results.append({
"task_code": task_code,
"missing": missing,
"backfilled": 0,
"error": str(exc),
})
total_errors += 1
self.logger.info(
"数据补全完成 总缺失=%s 已补全=%s 错误数=%s",
total_missing, total_backfilled, total_errors
)
return {
"total_missing": total_missing,
"backfilled": total_backfilled,
"errors": total_errors,
"details": results,
}
def _backfill_task(
self,
*,
task_code: str,
table: str,
pk_columns: List[str],
missing_samples: List[Dict],
start: datetime,
end: datetime,
page_size: int,
chunk_size: int,
) -> int:
"""补全单个任务的丢失数据"""
spec = _get_spec(task_code)
if not spec:
self.logger.warning("未找到任务规格 任务=%s", task_code)
return 0
if not pk_columns:
pk_columns = _get_table_pk_columns(self.db.conn, table)
if not pk_columns:
self.logger.warning("未找到主键列 任务=%s 表=%s", task_code, table)
return 0
# Extract the missing PK values
missing_pks: Set[Tuple] = set()
for sample in missing_samples:
pk_tuple = tuple(sample.get(col) for col in pk_columns)
if all(v is not None for v in pk_tuple):
missing_pks.add(pk_tuple)
if not missing_pks:
self.logger.info("无缺失主键 任务=%s", task_code)
return 0
self.logger.info(
"开始获取数据 任务=%s 缺失主键数=%s",
task_code, len(missing_pks)
)
# Fetch data from the API and keep only the missing records
params = self._build_params(spec, start, end)
backfilled = 0
cols_info = _get_table_columns(self.db.conn, table)
db_json_cols_lower = {
c[0].lower() for c in cols_info
if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
col_names = [c[0] for c in cols_info]
try:
for page_no, records, _, response_payload in self.api.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
# Keep only the missing records
records_to_insert = []
for rec in records:
if not isinstance(rec, dict):
continue
pk_tuple = _pk_tuple_from_record(rec, pk_columns)
if pk_tuple and pk_tuple in missing_pks:
records_to_insert.append(rec)
if not records_to_insert:
continue
# Insert the missing records
if self.dry_run:
backfilled += len(records_to_insert)
self.logger.info(
"模拟运行 任务=%s 页=%s 将插入=%s",
task_code, page_no, len(records_to_insert)
)
else:
inserted = self._insert_records(
table=table,
records=records_to_insert,
cols_info=cols_info,
pk_columns=pk_columns,
db_json_cols_lower=db_json_cols_lower,
)
backfilled += inserted
self.logger.info(
"已插入 任务=%s 页=%s 数量=%s",
task_code, page_no, inserted
)
if not self.dry_run:
self.db.conn.commit()
self.logger.info("任务补全完成 任务=%s 已补全=%s", task_code, backfilled)
return backfilled
except Exception:
self.db.conn.rollback()
raise
def _build_params(
self,
spec: OdsTaskSpec,
start: datetime,
end: datetime,
) -> Dict:
"""构建 API 请求参数"""
base: Dict[str, Any] = {}
if spec.include_site_id:
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [self.store_id]
else:
base["siteId"] = self.store_id
if spec.requires_window and spec.time_fields:
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(start, self.tz)
base[end_key] = TypeParser.format_timestamp(end, self.tz)
# Merge in common parameters
common = self.cfg.get("api.params", {}) or {}
if isinstance(common, dict):
merged = {**common, **base}
else:
merged = base
merged.update(spec.extra_params or {})
return merged
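# Illustrative result for a hypothetical windowed spec with time_fields ("StartTime", "EndTime")
# and store_id=1001: {"siteId": 1001, "StartTime": <formatted start>, "EndTime": <formatted end>, ...},
# laid over cfg "api.params", with spec.extra_params applied last so it overrides both.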
def _insert_records(
self,
*,
table: str,
records: List[Dict],
cols_info: List[Tuple[str, str, str]],
pk_columns: List[str],
db_json_cols_lower: Set[str],
) -> int:
"""插入记录到数据库"""
if not records:
return 0
col_names = [c[0] for c in cols_info]
quoted_cols = ", ".join(f'"{c}"' for c in col_names)
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
if pk_columns:
pk_clause = ", ".join(f'"{c}"' for c in pk_columns)
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
now = datetime.now(self.tz)
json_dump = lambda v: json.dumps(v, ensure_ascii=False)
params: List[Tuple] = []
for rec in records:
merged_rec = _merge_record_layers(rec)
# Check the PK
if pk_columns:
missing_pk = False
for pk in pk_columns:
pk_val = _get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
break
if missing_pk:
continue
row_vals: List[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append("backfill")
continue
if col_lower == "source_endpoint":
row_vals.append("backfill")
continue
if col_lower == "fetched_at":
row_vals.append(now)
continue
value = _normalize_scalar(_get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
row_vals.append(_cast_value(value, data_type))
params.append(tuple(row_vals))
if not params:
return 0
inserted = 0
with self.db.conn.cursor() as cur:
for i in range(0, len(params), 200):
chunk = params[i:i + 200]
execute_values(cur, sql, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
inserted += int(cur.rowcount)
return inserted
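# Illustrative statement shape for a hypothetical table with PK (site_id, pay_id):
#   INSERT INTO <table> ("site_id", "pay_id", ..., "payload", "fetched_at") VALUES %s
#   ON CONFLICT ("site_id", "pay_id") DO NOTHING
# so re-running a backfill never duplicates rows that already exist.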
def run_backfill(
*,
cfg: AppConfig,
start: datetime,
end: datetime,
task_codes: Optional[str] = None,
dry_run: bool = False,
page_size: int = 200,
chunk_size: int = 500,
logger: logging.Logger,
) -> Dict[str, Any]:
"""
运行数据补全
Args:
cfg: 应用配置
start: 开始时间
end: 结束时间
task_codes: 指定任务代码(逗号分隔)
dry_run: 是否仅预览
page_size: API 分页大小
chunk_size: 数据库批量大小
logger: 日志记录器
Returns:
补全结果
"""
backfiller = MissingDataBackfiller(cfg, logger, dry_run)
try:
return backfiller.backfill_from_gap_check(
start=start,
end=end,
task_codes=task_codes,
page_size=page_size,
chunk_size=chunk_size,
)
finally:
backfiller.close()
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="补全丢失的 ODS 数据")
ap.add_argument("--start", default="2025-07-01", help="开始日期 (默认: 2025-07-01)")
ap.add_argument("--end", default="", help="结束日期 (默认: 当前时间)")
ap.add_argument("--task-codes", default="", help="指定任务代码(逗号分隔,留空=全部)")
ap.add_argument("--dry-run", action="store_true", help="仅预览,不实际写入")
ap.add_argument("--page-size", type=int, default=200, help="API 分页大小 (默认: 200)")
ap.add_argument("--chunk-size", type=int, default=500, help="数据库批量大小 (默认: 500)")
ap.add_argument("--log-file", default="", help="日志文件路径")
ap.add_argument("--log-dir", default="", help="日志目录")
ap.add_argument("--log-level", default="INFO", help="日志级别 (默认: INFO)")
ap.add_argument("--no-log-console", action="store_true", help="禁用控制台日志")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "backfill_missing")
log_console = not args.no_log_console
with configure_logging(
"backfill_missing",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
start = _parse_dt(args.start, tz)
end = _parse_dt(args.end, tz, is_end=True) if args.end else datetime.now(tz)
result = run_backfill(
cfg=cfg,
start=start,
end=end,
task_codes=args.task_codes or None,
dry_run=args.dry_run,
page_size=args.page_size,
chunk_size=args.chunk_size,
logger=logger,
)
logger.info("=" * 60)
logger.info("补全完成!")
logger.info(" 总丢失: %s", result.get("total_missing", 0))
logger.info(" 已补全: %s", result.get("backfilled", 0))
logger.info(" 错误数: %s", result.get("errors", 0))
logger.info("=" * 60)
# 输出详细结果
for detail in result.get("details", []):
if detail.get("error"):
logger.error(
" %s: 丢失=%s 补全=%s 错误=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("backfilled"),
detail.get("error"),
)
elif detail.get("backfilled", 0) > 0:
logger.info(
" %s: 丢失=%s 补全=%s",
detail.get("task_code"),
detail.get("missing"),
detail.get("backfilled"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""Populate PRD DWD tables from ODS payload snapshots."""
from __future__ import annotations
@@ -16,9 +16,9 @@ SQL_STEPS: list[tuple[str, str]] = [
INSERT INTO billiards_dwd.dim_tenant (tenant_id, tenant_name, status)
SELECT DISTINCT tenant_id, 'default' AS tenant_name, 'active' AS status
FROM (
SELECT tenant_id FROM billiards_ods.ods_order_settle
SELECT tenant_id FROM billiards_ods.settlement_records
UNION SELECT tenant_id FROM billiards_ods.ods_order_receipt_detail
UNION SELECT tenant_id FROM billiards_ods.ods_member_profile
UNION SELECT tenant_id FROM billiards_ods.member_profiles
) s
WHERE tenant_id IS NOT NULL
ON CONFLICT (tenant_id) DO UPDATE SET updated_at = now();
@@ -30,7 +30,7 @@ SQL_STEPS: list[tuple[str, str]] = [
INSERT INTO billiards_dwd.dim_site (site_id, tenant_id, site_name, status)
SELECT DISTINCT site_id, MAX(tenant_id) AS tenant_id, 'default' AS site_name, 'active' AS status
FROM (
SELECT site_id, tenant_id FROM billiards_ods.ods_order_settle
SELECT site_id, tenant_id FROM billiards_ods.settlement_records
UNION SELECT site_id, tenant_id FROM billiards_ods.ods_order_receipt_detail
UNION SELECT site_id, tenant_id FROM billiards_ods.ods_table_info
) s
@@ -84,7 +84,7 @@ SQL_STEPS: list[tuple[str, str]] = [
"""
INSERT INTO billiards_dwd.dim_member_card_type (card_type_id, card_type_name, discount_rate)
SELECT DISTINCT card_type_id, card_type_name, discount_rate
FROM billiards_ods.ods_member_card
FROM billiards_ods.member_stored_value_cards
WHERE card_type_id IS NOT NULL
ON CONFLICT (card_type_id) DO UPDATE SET
card_type_name = EXCLUDED.card_type_name,
@@ -119,10 +119,10 @@ SQL_STEPS: list[tuple[str, str]] = [
prof.wechat_id,
prof.alipay_id,
prof.remarks
FROM billiards_ods.ods_member_profile prof
FROM billiards_ods.member_profiles prof
LEFT JOIN (
SELECT DISTINCT site_id, member_id, card_type_id AS member_type_id, card_type_name AS member_type_name
FROM billiards_ods.ods_member_card
FROM billiards_ods.member_stored_value_cards
) card
ON prof.site_id = card.site_id AND prof.member_id = card.member_id
WHERE prof.member_id IS NOT NULL
@@ -167,7 +167,7 @@ SQL_STEPS: list[tuple[str, str]] = [
"""
INSERT INTO billiards_dwd.dim_assistant (assistant_id, assistant_name, mobile, status)
SELECT DISTINCT assistant_id, assistant_name, mobile, status
FROM billiards_ods.ods_assistant_account
FROM billiards_ods.assistant_accounts_master
WHERE assistant_id IS NOT NULL
ON CONFLICT (assistant_id) DO UPDATE SET
assistant_name = EXCLUDED.assistant_name,
@@ -181,7 +181,7 @@ SQL_STEPS: list[tuple[str, str]] = [
"""
INSERT INTO billiards_dwd.dim_pay_method (pay_method_code, pay_method_name, is_stored_value, status)
SELECT DISTINCT pay_method_code, pay_method_name, FALSE AS is_stored_value, 'active' AS status
FROM billiards_ods.ods_payment_record
FROM billiards_ods.payment_transactions
WHERE pay_method_code IS NOT NULL
ON CONFLICT (pay_method_code) DO UPDATE SET
pay_method_name = EXCLUDED.pay_method_name,
@@ -250,7 +250,7 @@ SQL_STEPS: list[tuple[str, str]] = [
final_table_fee,
FALSE AS is_canceled,
NULL::TIMESTAMPTZ AS cancel_time
FROM billiards_ods.ods_table_use_log
FROM billiards_ods.table_fee_transactions_log
ON CONFLICT (site_id, ledger_id) DO NOTHING;
""",
),
@@ -325,7 +325,7 @@ SQL_STEPS: list[tuple[str, str]] = [
pay_time,
relate_type,
relate_id
FROM billiards_ods.ods_payment_record
FROM billiards_ods.payment_transactions
ON CONFLICT (site_id, pay_id) DO NOTHING;
""",
),
@@ -346,7 +346,7 @@ SQL_STEPS: list[tuple[str, str]] = [
refund_amount,
refund_time,
status
FROM billiards_ods.ods_refund_record
FROM billiards_ods.refund_transactions
ON CONFLICT (site_id, refund_id) DO NOTHING;
""",
),
@@ -369,7 +369,7 @@ SQL_STEPS: list[tuple[str, str]] = [
balance_before,
balance_after,
change_time
FROM billiards_ods.ods_balance_change
FROM billiards_ods.member_balance_changes
ON CONFLICT (site_id, change_id) DO NOTHING;
""",
),
@@ -423,3 +423,4 @@ def main() -> int:
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""Recompute billiards_dws.dws_order_summary from DWD fact tables."""
"""Recompute billiards_dws.dws_order_summary from DWD tables (dwd_*)."""
from __future__ import annotations
import argparse
@@ -15,119 +15,90 @@ from database.connection import DatabaseConnection # noqa: E402
SQL_BUILD_SUMMARY = r"""
WITH table_fee AS (
WITH base AS (
SELECT
sh.site_id,
sh.order_settle_id,
sh.order_trade_no,
COALESCE(sh.pay_time, sh.create_time)::date AS order_date,
sh.tenant_id,
sh.member_id,
COALESCE(sh.is_bind_member, FALSE) AS member_flag,
(COALESCE(sh.consume_money, 0) = 0 AND COALESCE(sh.pay_amount, 0) > 0) AS recharge_order_flag,
COALESCE(sh.member_discount_amount, 0) AS member_discount_amount,
COALESCE(sh.adjust_amount, 0) AS manual_discount_amount,
COALESCE(sh.pay_amount, 0) AS total_paid_amount,
COALESCE(sh.balance_amount, 0) + COALESCE(sh.recharge_card_amount, 0) + COALESCE(sh.gift_card_amount, 0) AS stored_card_deduct,
COALESCE(sh.coupon_amount, 0) AS total_coupon_deduction,
COALESCE(sh.table_charge_money, 0) AS settle_table_fee_amount,
COALESCE(sh.assistant_pd_money, 0) + COALESCE(sh.assistant_cx_money, 0) AS settle_assistant_service_amount,
COALESCE(sh.real_goods_money, 0) AS settle_goods_amount
FROM billiards_dwd.dwd_settlement_head sh
WHERE (%(site_id)s IS NULL OR sh.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR COALESCE(sh.pay_time, sh.create_time)::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR COALESCE(sh.pay_time, sh.create_time)::date <= %(end_date)s)
),
table_fee AS (
SELECT
site_id,
order_settle_id,
order_trade_no,
MIN(member_id) AS member_id,
SUM(COALESCE(final_table_fee, 0)) AS table_fee_amount,
SUM(COALESCE(member_discount_amount, 0)) AS member_discount_amount,
SUM(COALESCE(manual_discount_amount, 0)) AS manual_discount_amount,
SUM(COALESCE(original_table_fee, 0)) AS original_table_fee,
MIN(start_time) AS first_time
FROM billiards_dwd.fact_table_usage
WHERE (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_time::date <= %(end_date)s)
AND COALESCE(is_canceled, FALSE) = FALSE
GROUP BY site_id, order_settle_id, order_trade_no
SUM(COALESCE(real_table_charge_money, 0)) AS table_fee_amount
FROM billiards_dwd.dwd_table_fee_log
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_use_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_use_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
assistant_fee AS (
SELECT
site_id,
order_settle_id,
order_trade_no,
MIN(member_id) AS member_id,
SUM(COALESCE(final_fee, 0)) AS assistant_service_amount,
SUM(COALESCE(member_discount_amount, 0)) AS member_discount_amount,
SUM(COALESCE(manual_discount_amount, 0)) AS manual_discount_amount,
SUM(COALESCE(original_fee, 0)) AS original_fee,
MIN(start_time) AS first_time
FROM billiards_dwd.fact_assistant_service
WHERE (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_time::date <= %(end_date)s)
AND COALESCE(is_canceled, FALSE) = FALSE
GROUP BY site_id, order_settle_id, order_trade_no
SUM(COALESCE(ledger_amount, 0)) AS assistant_service_amount
FROM billiards_dwd.dwd_assistant_service_log
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_use_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_use_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
goods_fee AS (
SELECT
site_id,
order_settle_id,
order_trade_no,
MIN(member_id) AS member_id,
SUM(COALESCE(final_amount, 0)) FILTER (WHERE COALESCE(is_gift, FALSE) = FALSE) AS goods_amount,
SUM(COALESCE(discount_amount, 0)) FILTER (WHERE COALESCE(is_gift, FALSE) = FALSE) AS goods_discount_amount,
SUM(COALESCE(original_amount, 0)) FILTER (WHERE COALESCE(is_gift, FALSE) = FALSE) AS goods_original_amount,
COUNT(*) FILTER (WHERE COALESCE(is_gift, FALSE) = FALSE) AS item_count,
SUM(COALESCE(quantity, 0)) FILTER (WHERE COALESCE(is_gift, FALSE) = FALSE) AS total_item_quantity,
MIN(sale_time) AS first_time
FROM billiards_dwd.fact_sale_item
WHERE (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR sale_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR sale_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id, order_trade_no
COUNT(*) AS item_count,
SUM(COALESCE(ledger_count, 0)) AS total_item_quantity,
SUM(COALESCE(real_goods_money, 0)) AS goods_amount
FROM billiards_dwd.dwd_store_goods_sale
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR create_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR create_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
coupon_usage AS (
group_fee AS (
SELECT
site_id,
order_settle_id,
order_trade_no,
MIN(member_id) AS member_id,
SUM(COALESCE(deduct_amount, 0)) AS coupon_deduction,
SUM(COALESCE(settle_price, 0)) AS settle_price,
MIN(used_time) AS first_time
FROM billiards_dwd.fact_coupon_usage
WHERE (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR used_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR used_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id, order_trade_no
),
payments AS (
SELECT
fp.site_id,
fp.order_settle_id,
fp.order_trade_no,
MIN(fp.member_id) AS member_id,
SUM(COALESCE(fp.pay_amount, 0)) AS total_paid_amount,
SUM(COALESCE(fp.pay_amount, 0)) FILTER (WHERE COALESCE(pm.is_stored_value, FALSE)) AS stored_card_deduct,
SUM(COALESCE(fp.pay_amount, 0)) FILTER (WHERE NOT COALESCE(pm.is_stored_value, FALSE)) AS external_paid_amount,
MIN(fp.pay_time) AS first_time
FROM billiards_dwd.fact_payment fp
LEFT JOIN billiards_dwd.dim_pay_method pm ON fp.pay_method_code = pm.pay_method_code
WHERE (%(site_id)s IS NULL OR fp.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR fp.pay_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR fp.pay_time::date <= %(end_date)s)
GROUP BY fp.site_id, fp.order_settle_id, fp.order_trade_no
SUM(COALESCE(ledger_amount, 0)) AS group_amount
FROM billiards_dwd.dwd_groupbuy_redemption
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR create_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR create_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
refunds AS (
SELECT
site_id,
order_settle_id,
order_trade_no,
SUM(COALESCE(refund_amount, 0)) AS refund_amount
FROM billiards_dwd.fact_refund
WHERE (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR refund_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR refund_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id, order_trade_no
),
combined_ids AS (
SELECT site_id, order_settle_id, order_trade_no FROM table_fee
UNION
SELECT site_id, order_settle_id, order_trade_no FROM assistant_fee
UNION
SELECT site_id, order_settle_id, order_trade_no FROM goods_fee
UNION
SELECT site_id, order_settle_id, order_trade_no FROM coupon_usage
UNION
SELECT site_id, order_settle_id, order_trade_no FROM payments
UNION
SELECT site_id, order_settle_id, order_trade_no FROM refunds
),
site_dim AS (
SELECT site_id, tenant_id FROM billiards_dwd.dim_site
r.site_id,
r.relate_id AS order_settle_id,
SUM(COALESCE(rx.refund_amount, 0)) AS refund_amount
FROM billiards_dwd.dwd_refund r
LEFT JOIN billiards_dwd.dwd_refund_ex rx ON r.refund_id = rx.refund_id
WHERE (%(site_id)s IS NULL OR r.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR r.pay_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR r.pay_time::date <= %(end_date)s)
GROUP BY r.site_id, r.relate_id
)
INSERT INTO billiards_dws.dws_order_summary (
site_id,
@@ -166,58 +137,50 @@ INSERT INTO billiards_dws.dws_order_summary (
updated_at
)
SELECT
c.site_id,
c.order_settle_id,
c.order_trade_no,
COALESCE(tf.first_time, af.first_time, gf.first_time, pay.first_time, cu.first_time)::date AS order_date,
sd.tenant_id,
COALESCE(tf.member_id, af.member_id, gf.member_id, cu.member_id, pay.member_id) AS member_id,
COALESCE(tf.member_id, af.member_id, gf.member_id, cu.member_id, pay.member_id) IS NOT NULL AS member_flag,
-- recharge flag: no consumption side but has payments
(COALESCE(tf.table_fee_amount, 0) + COALESCE(af.assistant_service_amount, 0) + COALESCE(gf.goods_amount, 0) + COALESCE(cu.settle_price, 0) = 0)
AND COALESCE(pay.total_paid_amount, 0) > 0 AS recharge_order_flag,
b.site_id,
b.order_settle_id,
b.order_trade_no::text AS order_trade_no,
b.order_date,
b.tenant_id,
b.member_id,
b.member_flag,
b.recharge_order_flag,
COALESCE(gf.item_count, 0) AS item_count,
COALESCE(gf.total_item_quantity, 0) AS total_item_quantity,
COALESCE(tf.table_fee_amount, 0) AS table_fee_amount,
COALESCE(af.assistant_service_amount, 0) AS assistant_service_amount,
COALESCE(gf.goods_amount, 0) AS goods_amount,
COALESCE(cu.settle_price, 0) AS group_amount,
COALESCE(cu.coupon_deduction, 0) AS total_coupon_deduction,
COALESCE(tf.member_discount_amount, 0) + COALESCE(af.member_discount_amount, 0) + COALESCE(gf.goods_discount_amount, 0) AS member_discount_amount,
COALESCE(tf.manual_discount_amount, 0) + COALESCE(af.manual_discount_amount, 0) AS manual_discount_amount,
COALESCE(tf.original_table_fee, 0) + COALESCE(af.original_fee, 0) + COALESCE(gf.goods_original_amount, 0) AS order_original_amount,
COALESCE(tf.table_fee_amount, 0) + COALESCE(af.assistant_service_amount, 0) + COALESCE(gf.goods_amount, 0) + COALESCE(cu.settle_price, 0) - COALESCE(cu.coupon_deduction, 0) AS order_final_amount,
COALESCE(pay.stored_card_deduct, 0) AS stored_card_deduct,
COALESCE(pay.external_paid_amount, 0) AS external_paid_amount,
COALESCE(pay.total_paid_amount, 0) AS total_paid_amount,
COALESCE(tf.table_fee_amount, 0) AS book_table_flow,
COALESCE(af.assistant_service_amount, 0) AS book_assistant_flow,
COALESCE(gf.goods_amount, 0) AS book_goods_flow,
COALESCE(cu.settle_price, 0) AS book_group_flow,
COALESCE(tf.table_fee_amount, 0) + COALESCE(af.assistant_service_amount, 0) + COALESCE(gf.goods_amount, 0) + COALESCE(cu.settle_price, 0) AS book_order_flow,
CASE
WHEN (COALESCE(tf.table_fee_amount, 0) + COALESCE(af.assistant_service_amount, 0) + COALESCE(gf.goods_amount, 0) + COALESCE(cu.settle_price, 0) = 0)
THEN 0
ELSE COALESCE(pay.external_paid_amount, 0)
END AS order_effective_consume_cash,
CASE
WHEN (COALESCE(tf.table_fee_amount, 0) + COALESCE(af.assistant_service_amount, 0) + COALESCE(gf.goods_amount, 0) + COALESCE(cu.settle_price, 0) = 0)
THEN COALESCE(pay.external_paid_amount, 0)
ELSE 0
END AS order_effective_recharge_cash,
COALESCE(pay.external_paid_amount, 0) + COALESCE(cu.settle_price, 0) AS order_effective_flow,
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount) AS table_fee_amount,
COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount) AS assistant_service_amount,
COALESCE(gf.goods_amount, b.settle_goods_amount) AS goods_amount,
COALESCE(gr.group_amount, 0) AS group_amount,
b.total_coupon_deduction AS total_coupon_deduction,
b.member_discount_amount AS member_discount_amount,
b.manual_discount_amount AS manual_discount_amount,
-- approximate original amount: final + discounts/coupon
(b.total_paid_amount + b.total_coupon_deduction + b.member_discount_amount + b.manual_discount_amount) AS order_original_amount,
b.total_paid_amount AS order_final_amount,
b.stored_card_deduct AS stored_card_deduct,
GREATEST(b.total_paid_amount - b.stored_card_deduct, 0) AS external_paid_amount,
b.total_paid_amount AS total_paid_amount,
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount) AS book_table_flow,
COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount) AS book_assistant_flow,
COALESCE(gf.goods_amount, b.settle_goods_amount) AS book_goods_flow,
COALESCE(gr.group_amount, 0) AS book_group_flow,
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount)
+ COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount)
+ COALESCE(gf.goods_amount, b.settle_goods_amount)
+ COALESCE(gr.group_amount, 0) AS book_order_flow,
GREATEST(b.total_paid_amount - b.stored_card_deduct, 0) AS order_effective_consume_cash,
0 AS order_effective_recharge_cash,
b.total_paid_amount AS order_effective_flow,
COALESCE(rf.refund_amount, 0) AS refund_amount,
(COALESCE(pay.external_paid_amount, 0) + COALESCE(cu.settle_price, 0)) - COALESCE(rf.refund_amount, 0) AS net_income,
b.total_paid_amount - COALESCE(rf.refund_amount, 0) AS net_income,
now() AS created_at,
now() AS updated_at
FROM combined_ids c
LEFT JOIN table_fee tf ON c.site_id = tf.site_id AND c.order_settle_id = tf.order_settle_id
LEFT JOIN assistant_fee af ON c.site_id = af.site_id AND c.order_settle_id = af.order_settle_id
LEFT JOIN goods_fee gf ON c.site_id = gf.site_id AND c.order_settle_id = gf.order_settle_id
LEFT JOIN coupon_usage cu ON c.site_id = cu.site_id AND c.order_settle_id = cu.order_settle_id
LEFT JOIN payments pay ON c.site_id = pay.site_id AND c.order_settle_id = pay.order_settle_id
LEFT JOIN refunds rf ON c.site_id = rf.site_id AND c.order_settle_id = rf.order_settle_id
LEFT JOIN site_dim sd ON c.site_id = sd.site_id
FROM base b
LEFT JOIN table_fee tf ON b.site_id = tf.site_id AND b.order_settle_id = tf.order_settle_id
LEFT JOIN assistant_fee af ON b.site_id = af.site_id AND b.order_settle_id = af.order_settle_id
LEFT JOIN goods_fee gf ON b.site_id = gf.site_id AND b.order_settle_id = gf.order_settle_id
LEFT JOIN group_fee gr ON b.site_id = gr.site_id AND b.order_settle_id = gr.order_settle_id
LEFT JOIN refunds rf ON b.site_id = rf.site_id AND b.order_settle_id = rf.order_settle_id
ON CONFLICT (site_id, order_settle_id) DO UPDATE SET
order_trade_no = EXCLUDED.order_trade_no,
order_date = EXCLUDED.order_date,

View File

@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""Run data integrity checks across API -> ODS -> DWD."""
from __future__ import annotations
import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from config.settings import AppConfig
from quality.integrity_checker import (
IntegrityWindow,
compute_last_etl_end,
run_integrity_history,
run_integrity_window,
)
from utils.logging_utils import build_log_path, configure_logging
from utils.windowing import split_window
def _parse_dt(value: str, tz: ZoneInfo) -> datetime:
dt = dtparser.parse(value)
if dt.tzinfo is None:
return dt.replace(tzinfo=tz)
return dt.astimezone(tz)
def main() -> int:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
ap = argparse.ArgumentParser(description="Data integrity checks (API -> ODS -> DWD)")
ap.add_argument("--mode", choices=["history", "window"], default="history")
ap.add_argument("--start", default="2025-07-01", help="history start date (default: 2025-07-01)")
ap.add_argument("--end", default="", help="history end datetime (default: last ETL end)")
ap.add_argument("--window-start", default="", help="window start datetime (mode=window)")
ap.add_argument("--window-end", default="", help="window end datetime (mode=window)")
ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
ap.add_argument("--include-dimensions", action="store_true", help="include dimension tables in ODS->DWD checks")
ap.add_argument("--ods-task-codes", default="", help="comma-separated ODS task codes for API checks")
ap.add_argument("--out", default="", help="output JSON path")
ap.add_argument("--log-file", default="", help="log file path")
ap.add_argument("--log-dir", default="", help="log directory")
ap.add_argument("--log-level", default="INFO", help="log level")
ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (Path(__file__).resolve().parent / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "data_integrity")
log_console = not args.no_log_console
with configure_logging(
"data_integrity",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
report_path = Path(args.out) if args.out else None
if args.mode == "window":
if not args.window_start or not args.window_end:
raise SystemExit("window-start and window-end are required for mode=window")
start_dt = _parse_dt(args.window_start, tz)
end_dt = _parse_dt(args.window_end, tz)
split_unit = (args.window_split_unit or cfg.get("run.window_split.unit", "month") or "month").strip()
comp_hours = args.window_compensation_hours
if comp_hours is None:
comp_hours = cfg.get("run.window_split.compensation_hours", 0)
windows = split_window(
start_dt,
end_dt,
tz=tz,
split_unit=split_unit,
compensation_hours=comp_hours,
)
if not windows:
windows = [(start_dt, end_dt)]
window_reports = []
total_missing = 0
total_errors = 0
for idx, (seg_start, seg_end) in enumerate(windows, start=1):
window = IntegrityWindow(
start=seg_start,
end=seg_end,
label=f"segment_{idx}",
granularity="window",
)
payload = run_integrity_window(
cfg=cfg,
window=window,
include_dimensions=args.include_dimensions,
task_codes=args.ods_task_codes,
logger=logger,
write_report=False,
report_path=None,
window_split_unit="none",
window_compensation_hours=0,
)
window_reports.append(payload)
total_missing += int(payload.get("api_to_ods", {}).get("total_missing") or 0)
total_errors += int(payload.get("api_to_ods", {}).get("total_errors") or 0)
overall_start = windows[0][0]
overall_end = windows[-1][1]
report = {
"mode": "window",
"window": {
"start": overall_start.isoformat(),
"end": overall_end.isoformat(),
"segments": len(windows),
},
"windows": window_reports,
"api_to_ods": {
"total_missing": total_missing,
"total_errors": total_errors,
},
"total_missing": total_missing,
"total_errors": total_errors,
"generated_at": datetime.now(tz).isoformat(),
}
if report_path is None:
root = Path(__file__).resolve().parents[1]
stamp = datetime.now(tz).strftime("%Y%m%d_%H%M%S")
report_path = root / "reports" / f"data_integrity_window_{stamp}.json"
report_path.parent.mkdir(parents=True, exist_ok=True)
report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
report["report_path"] = str(report_path)
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
else:
start_dt = _parse_dt(args.start, tz)
if args.end:
end_dt = _parse_dt(args.end, tz)
else:
end_dt = compute_last_etl_end(cfg) or datetime.now(tz)
report = run_integrity_history(
cfg=cfg,
start_dt=start_dt,
end_dt=end_dt,
include_dimensions=args.include_dimensions,
task_codes=args.ods_task_codes,
logger=logger,
write_report=True,
report_path=report_path,
)
logger.info("REPORT_WRITTEN path=%s", report.get("report_path"))
logger.info(
"SUMMARY missing=%s errors=%s",
report.get("total_missing"),
report.get("total_errors"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,874 @@
# -*- coding: utf-8 -*-
"""
Check missing ODS records by comparing API primary keys vs ODS table primary keys.
Default range:
start = 2025-07-01 00:00:00
end = now
For update runs, use --from-cutoff to derive the start time from ODS max(fetched_at),
then backtrack by --cutoff-overlap-hours.
"""
from __future__ import annotations
import argparse
import json
import logging
import time as time_mod
import sys
from datetime import datetime, time, timedelta
from pathlib import Path
from typing import Iterable, Sequence
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from psycopg2.extras import execute_values
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.client import APIClient
from config.settings import AppConfig
from database.connection import DatabaseConnection
from models.parsers import TypeParser
from tasks.ods_tasks import ENABLED_ODS_CODES, ODS_TASK_SPECS
from utils.logging_utils import build_log_path, configure_logging
from utils.windowing import split_window
DEFAULT_START = "2025-07-01"
MIN_COMPLETENESS_WINDOW_DAYS = 30
def _reconfigure_stdout_utf8() -> None:
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool) -> datetime:
raw = (value or "").strip()
if not raw:
raise ValueError("empty datetime")
has_time = any(ch in raw for ch in (":", "T"))
dt = dtparser.parse(raw)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=tz)
else:
dt = dt.astimezone(tz)
if not has_time:
dt = dt.replace(hour=23 if is_end else 0, minute=59 if is_end else 0, second=59 if is_end else 0, microsecond=0)
return dt
def _iter_windows(start: datetime, end: datetime, window_size: timedelta) -> Iterable[tuple[datetime, datetime]]:
if window_size.total_seconds() <= 0:
raise ValueError("window_size must be > 0")
cur = start
while cur < end:
nxt = min(cur + window_size, end)
yield cur, nxt
cur = nxt
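# Illustrative: start=2025-07-01, end=2025-07-04 with window_size=timedelta(days=1)
# yields (07-01, 07-02), (07-02, 07-03), (07-03, 07-04); the last window is clipped to `end`.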
def _merge_record_layers(record: dict) -> dict:
merged = record
data_part = merged.get("data")
while isinstance(data_part, dict):
merged = {**data_part, **merged}
data_part = data_part.get("data")
settle_inner = merged.get("settleList")
if isinstance(settle_inner, dict):
merged = {**settle_inner, **merged}
return merged
def _get_value_case_insensitive(record: dict | None, col: str | None):
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
def _normalize_pk_value(value):
if value is None:
return None
if isinstance(value, str) and value.isdigit():
try:
return int(value)
except Exception:
return value
return value
def _chunked(seq: Sequence, size: int) -> Iterable[Sequence]:
if size <= 0:
size = 500
for i in range(0, len(seq), size):
yield seq[i : i + size]
def _get_table_pk_columns(conn, table: str) -> list[str]:
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with conn.cursor() as cur:
cur.execute(sql, (schema, name))
return [r[0] for r in cur.fetchall()]
def _fetch_existing_pk_set(conn, table: str, pk_cols: Sequence[str], pk_values: list[tuple], chunk_size: int) -> set[tuple]:
if not pk_values:
return set()
select_cols = ", ".join(f't."{c}"' for c in pk_cols)
value_cols = ", ".join(f'"{c}"' for c in pk_cols)
join_cond = " AND ".join(f't."{c}" = v."{c}"' for c in pk_cols)
sql = (
f"SELECT {select_cols} FROM {table} t "
f"JOIN (VALUES %s) AS v({value_cols}) ON {join_cond}"
)
existing: set[tuple] = set()
with conn.cursor() as cur:
for chunk in _chunked(pk_values, chunk_size):
execute_values(cur, sql, chunk, page_size=len(chunk))
for row in cur.fetchall():
existing.add(tuple(row))
return existing
def _merge_common_params(cfg: AppConfig, task_code: str, base: dict) -> dict:
merged: dict = {}
common = cfg.get("api.params", {}) or {}
if isinstance(common, dict):
merged.update(common)
scoped = cfg.get(f"api.params.{task_code.lower()}", {}) or {}
if isinstance(scoped, dict):
merged.update(scoped)
merged.update(base)
return merged
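# Precedence (later wins): cfg "api.params" < cfg "api.params.<task_code lower-cased>" < base,
# so a hypothetical global pageSize could be overridden per task, while window/site parameters
# built into `base` always take priority.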
def _build_params(cfg: AppConfig, spec, store_id: int, window_start: datetime | None, window_end: datetime | None) -> dict:
base: dict = {}
if spec.include_site_id:
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [store_id]
else:
base["siteId"] = store_id
if spec.requires_window and spec.time_fields and window_start and window_end:
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(window_start, ZoneInfo(cfg.get("app.timezone", "Asia/Taipei")))
base[end_key] = TypeParser.format_timestamp(window_end, ZoneInfo(cfg.get("app.timezone", "Asia/Taipei")))
base.update(spec.extra_params or {})
return _merge_common_params(cfg, spec.code, base)
def _pk_tuple_from_record(record: dict, pk_cols: Sequence[str]) -> tuple | None:
merged = _merge_record_layers(record)
values = []
for col in pk_cols:
val = _normalize_pk_value(_get_value_case_insensitive(merged, col))
if val is None or val == "":
return None
values.append(val)
return tuple(values)
def _pk_tuple_from_ticket_candidate(value) -> tuple | None:
val = _normalize_pk_value(value)
if val is None or val == "":
return None
return (val,)
def _format_missing_sample(pk_cols: Sequence[str], pk_tuple: tuple) -> dict:
return {col: pk_tuple[idx] for idx, col in enumerate(pk_cols)}
def _check_spec(
*,
client: APIClient,
db_conn,
cfg: AppConfig,
tz: ZoneInfo,
logger: logging.Logger,
spec,
store_id: int,
start: datetime | None,
end: datetime | None,
windows: list[tuple[datetime, datetime]] | None,
page_size: int,
chunk_size: int,
sample_limit: int,
sleep_per_window: float,
sleep_per_page: float,
) -> dict:
result = {
"task_code": spec.code,
"table": spec.table_name,
"endpoint": spec.endpoint,
"pk_columns": [],
"records": 0,
"records_with_pk": 0,
"missing": 0,
"missing_samples": [],
"pages": 0,
"skipped_missing_pk": 0,
"errors": 0,
"error_detail": None,
}
pk_cols = _get_table_pk_columns(db_conn, spec.table_name)
result["pk_columns"] = pk_cols
if not pk_cols:
result["errors"] = 1
result["error_detail"] = "no primary key columns found"
return result
if spec.requires_window and spec.time_fields:
if not start or not end:
result["errors"] = 1
result["error_detail"] = "missing start/end for windowed endpoint"
return result
windows = list(windows or [(start, end)])
else:
windows = [(None, None)]
logger.info(
"CHECK_START task=%s table=%s windows=%s start=%s end=%s",
spec.code,
spec.table_name,
len(windows),
start.isoformat() if start else None,
end.isoformat() if end else None,
)
missing_seen: set[tuple] = set()
for window_idx, (window_start, window_end) in enumerate(windows, start=1):
window_label = (
f"{window_start.isoformat()}~{window_end.isoformat()}"
if window_start and window_end
else "FULL"
)
logger.info(
"WINDOW_START task=%s idx=%s window=%s",
spec.code,
window_idx,
window_label,
)
window_pages = 0
window_records = 0
window_missing = 0
window_skipped = 0
params = _build_params(cfg, spec, store_id, window_start, window_end)
try:
for page_no, records, _, _ in client.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
window_pages += 1
window_records += len(records)
result["pages"] += 1
result["records"] += len(records)
pk_tuples: list[tuple] = []
for rec in records:
if not isinstance(rec, dict):
result["skipped_missing_pk"] += 1
window_skipped += 1
continue
pk_tuple = _pk_tuple_from_record(rec, pk_cols)
if not pk_tuple:
result["skipped_missing_pk"] += 1
window_skipped += 1
continue
pk_tuples.append(pk_tuple)
if not pk_tuples:
continue
result["records_with_pk"] += len(pk_tuples)
pk_unique = list(dict.fromkeys(pk_tuples))
existing = _fetch_existing_pk_set(db_conn, spec.table_name, pk_cols, pk_unique, chunk_size)
for pk_tuple in pk_unique:
if pk_tuple in existing:
continue
if pk_tuple in missing_seen:
continue
missing_seen.add(pk_tuple)
result["missing"] += 1
window_missing += 1
if len(result["missing_samples"]) < sample_limit:
result["missing_samples"].append(_format_missing_sample(pk_cols, pk_tuple))
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"PAGE task=%s idx=%s page=%s records=%s missing=%s skipped=%s",
spec.code,
window_idx,
page_no,
len(records),
window_missing,
window_skipped,
)
if sleep_per_page > 0:
time_mod.sleep(sleep_per_page)
except Exception as exc:
result["errors"] += 1
result["error_detail"] = f"{type(exc).__name__}: {exc}"
logger.exception(
"WINDOW_ERROR task=%s idx=%s window=%s error=%s",
spec.code,
window_idx,
window_label,
result["error_detail"],
)
break
logger.info(
"WINDOW_DONE task=%s idx=%s window=%s pages=%s records=%s missing=%s skipped=%s",
spec.code,
window_idx,
window_label,
window_pages,
window_records,
window_missing,
window_skipped,
)
if sleep_per_window > 0:
logger.debug(
"SLEEP_WINDOW task=%s idx=%s seconds=%.2f",
spec.code,
window_idx,
sleep_per_window,
)
time_mod.sleep(sleep_per_window)
return result
def _check_settlement_tickets(
*,
client: APIClient,
db_conn,
cfg: AppConfig,
tz: ZoneInfo,
logger: logging.Logger,
store_id: int,
start: datetime | None,
end: datetime | None,
windows: list[tuple[datetime, datetime]] | None,
page_size: int,
chunk_size: int,
sample_limit: int,
sleep_per_window: float,
sleep_per_page: float,
) -> dict:
table_name = "billiards_ods.settlement_ticket_details"
pk_cols = _get_table_pk_columns(db_conn, table_name)
result = {
"task_code": "ODS_SETTLEMENT_TICKET",
"table": table_name,
"endpoint": "/Order/GetOrderSettleTicketNew",
"pk_columns": pk_cols,
"records": 0,
"records_with_pk": 0,
"missing": 0,
"missing_samples": [],
"pages": 0,
"skipped_missing_pk": 0,
"errors": 0,
"error_detail": None,
"source_endpoint": "/PayLog/GetPayLogListPage",
}
if not pk_cols:
result["errors"] = 1
result["error_detail"] = "no primary key columns found"
return result
if not start or not end:
result["errors"] = 1
result["error_detail"] = "missing start/end for ticket check"
return result
missing_seen: set[tuple] = set()
pay_endpoint = "/PayLog/GetPayLogListPage"
windows = list(windows or [(start, end)])
logger.info(
"CHECK_START task=%s table=%s windows=%s start=%s end=%s",
result["task_code"],
table_name,
len(windows),
start.isoformat() if start else None,
end.isoformat() if end else None,
)
for window_idx, (window_start, window_end) in enumerate(windows, start=1):
window_label = f"{window_start.isoformat()}~{window_end.isoformat()}"
logger.info(
"WINDOW_START task=%s idx=%s window=%s",
result["task_code"],
window_idx,
window_label,
)
window_pages = 0
window_records = 0
window_missing = 0
window_skipped = 0
base = {
"siteId": store_id,
"StartPayTime": TypeParser.format_timestamp(window_start, tz),
"EndPayTime": TypeParser.format_timestamp(window_end, tz),
}
params = _merge_common_params(cfg, "ODS_PAYMENT", base)
try:
for page_no, records, _, _ in client.iter_paginated(
endpoint=pay_endpoint,
params=params,
page_size=page_size,
data_path=("data",),
list_key=None,
):
window_pages += 1
window_records += len(records)
result["pages"] += 1
result["records"] += len(records)
pk_tuples: list[tuple] = []
for rec in records:
if not isinstance(rec, dict):
result["skipped_missing_pk"] += 1
window_skipped += 1
continue
relate_id = TypeParser.parse_int(
(rec or {}).get("relateId")
or (rec or {}).get("orderSettleId")
or (rec or {}).get("order_settle_id")
)
pk_tuple = _pk_tuple_from_ticket_candidate(relate_id)
if not pk_tuple:
result["skipped_missing_pk"] += 1
window_skipped += 1
continue
pk_tuples.append(pk_tuple)
if not pk_tuples:
continue
result["records_with_pk"] += len(pk_tuples)
pk_unique = list(dict.fromkeys(pk_tuples))
existing = _fetch_existing_pk_set(db_conn, table_name, pk_cols, pk_unique, chunk_size)
for pk_tuple in pk_unique:
if pk_tuple in existing:
continue
if pk_tuple in missing_seen:
continue
missing_seen.add(pk_tuple)
result["missing"] += 1
window_missing += 1
if len(result["missing_samples"]) < sample_limit:
result["missing_samples"].append(_format_missing_sample(pk_cols, pk_tuple))
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"PAGE task=%s idx=%s page=%s records=%s missing=%s skipped=%s",
result["task_code"],
window_idx,
page_no,
len(records),
window_missing,
window_skipped,
)
if sleep_per_page > 0:
time_mod.sleep(sleep_per_page)
except Exception as exc:
result["errors"] += 1
result["error_detail"] = f"{type(exc).__name__}: {exc}"
logger.exception(
"WINDOW_ERROR task=%s idx=%s window=%s error=%s",
result["task_code"],
window_idx,
window_label,
result["error_detail"],
)
break
logger.info(
"WINDOW_DONE task=%s idx=%s window=%s pages=%s records=%s missing=%s skipped=%s",
result["task_code"],
window_idx,
window_label,
window_pages,
window_records,
window_missing,
window_skipped,
)
if sleep_per_window > 0:
logger.debug(
"SLEEP_WINDOW task=%s idx=%s seconds=%.2f",
result["task_code"],
window_idx,
sleep_per_window,
)
time_mod.sleep(sleep_per_window)
return result
def _compute_ods_cutoff(conn, ods_tables: Sequence[str]) -> datetime | None:
values: list[datetime] = []
with conn.cursor() as cur:
for table in ods_tables:
try:
cur.execute(f"SELECT MAX(fetched_at) FROM {table}")
row = cur.fetchone()
if row and row[0]:
values.append(row[0])
except Exception:
continue
if not values:
return None
return min(values)
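# Note: the cutoff is the minimum of the per-table MAX(fetched_at) values, i.e. the most
# conservative point that every checked ODS table has reached; tables whose MAX(fetched_at)
# query fails are simply skipped.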
def _resolve_window_from_cutoff(
*,
conn,
ods_tables: Sequence[str],
tz: ZoneInfo,
overlap_hours: int,
) -> tuple[datetime, datetime, datetime | None]:
cutoff = _compute_ods_cutoff(conn, ods_tables)
now = datetime.now(tz)
if cutoff is None:
start = now - timedelta(hours=max(1, overlap_hours))
return start, now, None
if cutoff.tzinfo is None:
cutoff = cutoff.replace(tzinfo=tz)
else:
cutoff = cutoff.astimezone(tz)
start = cutoff - timedelta(hours=max(0, overlap_hours))
return start, now, cutoff
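# Illustrative: with a cutoff of 2026-01-19 06:00 and overlap_hours=24 the re-check window
# becomes 2026-01-18 06:00 .. now; if no cutoff can be computed, the window falls back to
# the last max(1, overlap_hours) hours before now.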
def run_gap_check(
*,
cfg: AppConfig | None,
start: datetime | str | None,
end: datetime | str | None,
window_days: int,
window_hours: int,
page_size: int,
chunk_size: int,
sample_limit: int,
sleep_per_window: float,
sleep_per_page: float,
task_codes: str,
from_cutoff: bool,
cutoff_overlap_hours: int,
allow_small_window: bool,
logger: logging.Logger,
window_split_unit: str | None = None,
window_compensation_hours: int | None = None,
) -> dict:
cfg = cfg or AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
store_id = int(cfg.get("app.store_id") or 0)
if not cfg.get("api.token"):
raise ValueError("missing api.token; please set API_TOKEN in .env")
window_days = int(window_days)
window_hours = int(window_hours)
split_unit = (window_split_unit or cfg.get("run.window_split.unit", "month") or "month").strip()
comp_hours = window_compensation_hours
if comp_hours is None:
comp_hours = cfg.get("run.window_split.compensation_hours", 0)
use_split = split_unit.lower() not in ("", "none", "off", "false", "0")
if not use_split and not from_cutoff and not allow_small_window:
min_hours = MIN_COMPLETENESS_WINDOW_DAYS * 24
if window_hours > 0:
if window_hours < min_hours:
logger.warning(
"window_hours=%s too small for completeness check; adjust to %s",
window_hours,
min_hours,
)
window_hours = min_hours
elif window_days < MIN_COMPLETENESS_WINDOW_DAYS:
logger.warning(
"window_days=%s too small for completeness check; adjust to %s",
window_days,
MIN_COMPLETENESS_WINDOW_DAYS,
)
window_days = MIN_COMPLETENESS_WINDOW_DAYS
cutoff = None
if from_cutoff:
db_tmp = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
ods_tables = [s.table_name for s in ODS_TASK_SPECS if s.code in ENABLED_ODS_CODES]
start, end, cutoff = _resolve_window_from_cutoff(
conn=db_tmp.conn,
ods_tables=ods_tables,
tz=tz,
overlap_hours=cutoff_overlap_hours,
)
db_tmp.close()
else:
if not start:
start = DEFAULT_START
if not end:
end = datetime.now(tz)
if isinstance(start, str):
start = _parse_dt(start, tz, is_end=False)
if isinstance(end, str):
end = _parse_dt(end, tz, is_end=True)
windows = None
if use_split:
windows = split_window(
start,
end,
tz=tz,
split_unit=split_unit,
compensation_hours=comp_hours,
)
else:
adjusted = split_window(
start,
end,
tz=tz,
split_unit="none",
compensation_hours=comp_hours,
)
if adjusted:
start, end = adjusted[0]
window_size = timedelta(hours=window_hours) if window_hours > 0 else timedelta(days=window_days)
windows = list(_iter_windows(start, end, window_size))
if windows:
start, end = windows[0][0], windows[-1][1]
logger.info(
"START range=%s~%s window_days=%s window_hours=%s split_unit=%s comp_hours=%s page_size=%s chunk_size=%s",
start.isoformat() if isinstance(start, datetime) else None,
end.isoformat() if isinstance(end, datetime) else None,
window_days,
window_hours,
split_unit,
comp_hours,
page_size,
chunk_size,
)
if cutoff:
logger.info("CUTOFF=%s overlap_hours=%s", cutoff.isoformat(), cutoff_overlap_hours)
client = APIClient(
base_url=cfg["api"]["base_url"],
token=cfg["api"]["token"],
timeout=int(cfg["api"].get("timeout_sec") or 20),
retry_max=int(cfg["api"].get("retries", {}).get("max_attempts") or 3),
headers_extra=cfg["api"].get("headers_extra") or {},
)
db_conn = DatabaseConnection(dsn=cfg["db"]["dsn"], session=cfg["db"].get("session"))
try:
db_conn.conn.rollback()
except Exception:
pass
db_conn.conn.autocommit = True
try:
task_filter = {t.strip().upper() for t in (task_codes or "").split(",") if t.strip()}
specs = [s for s in ODS_TASK_SPECS if s.code in ENABLED_ODS_CODES]
if task_filter:
specs = [s for s in specs if s.code in task_filter]
results: list[dict] = []
for spec in specs:
if spec.code == "ODS_SETTLEMENT_TICKET":
continue
result = _check_spec(
client=client,
db_conn=db_conn.conn,
cfg=cfg,
tz=tz,
logger=logger,
spec=spec,
store_id=store_id,
start=start,
end=end,
windows=windows,
page_size=page_size,
chunk_size=chunk_size,
sample_limit=sample_limit,
sleep_per_window=sleep_per_window,
sleep_per_page=sleep_per_page,
)
results.append(result)
logger.info(
"CHECK_DONE task=%s missing=%s records=%s errors=%s",
result.get("task_code"),
result.get("missing"),
result.get("records"),
result.get("errors"),
)
if (not task_filter) or ("ODS_SETTLEMENT_TICKET" in task_filter):
ticket_result = _check_settlement_tickets(
client=client,
db_conn=db_conn.conn,
cfg=cfg,
tz=tz,
logger=logger,
store_id=store_id,
start=start,
end=end,
windows=windows,
page_size=page_size,
chunk_size=chunk_size,
sample_limit=sample_limit,
sleep_per_window=sleep_per_window,
sleep_per_page=sleep_per_page,
)
results.append(ticket_result)
logger.info(
"CHECK_DONE task=%s missing=%s records=%s errors=%s",
ticket_result.get("task_code"),
ticket_result.get("missing"),
ticket_result.get("records"),
ticket_result.get("errors"),
)
total_missing = sum(int(r.get("missing") or 0) for r in results)
total_errors = sum(int(r.get("errors") or 0) for r in results)
payload = {
"window_split_unit": split_unit,
"window_compensation_hours": comp_hours,
"start": start.isoformat() if isinstance(start, datetime) else None,
"end": end.isoformat() if isinstance(end, datetime) else None,
"cutoff": cutoff.isoformat() if cutoff else None,
"window_days": window_days,
"window_hours": window_hours,
"page_size": page_size,
"chunk_size": chunk_size,
"sample_limit": sample_limit,
"store_id": store_id,
"base_url": cfg.get("api.base_url"),
"results": results,
"total_missing": total_missing,
"total_errors": total_errors,
"generated_at": datetime.now(tz).isoformat(),
}
return payload
finally:
db_conn.close()
def main() -> int:
_reconfigure_stdout_utf8()
ap = argparse.ArgumentParser(description="Check missing ODS records by comparing API vs ODS PKs.")
ap.add_argument("--start", default=DEFAULT_START, help="start datetime (default: 2025-07-01)")
ap.add_argument("--end", default="", help="end datetime (default: now)")
ap.add_argument("--window-days", type=int, default=1, help="days per API window (default: 1)")
ap.add_argument("--window-hours", type=int, default=0, help="hours per API window (default: 0)")
ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
ap.add_argument("--page-size", type=int, default=200, help="API page size (default: 200)")
ap.add_argument("--chunk-size", type=int, default=500, help="DB query chunk size (default: 500)")
ap.add_argument("--sample-limit", type=int, default=50, help="max missing PK samples per table")
ap.add_argument("--sleep-per-window-seconds", type=float, default=0, help="sleep seconds after each window")
ap.add_argument("--sleep-per-page-seconds", type=float, default=0, help="sleep seconds after each page")
ap.add_argument("--task-codes", default="", help="comma-separated task codes to check (optional)")
ap.add_argument("--out", default="", help="output JSON path (optional)")
ap.add_argument("--tag", default="", help="tag suffix for output filename")
ap.add_argument("--from-cutoff", action="store_true", help="derive start from ODS cutoff")
ap.add_argument(
"--cutoff-overlap-hours",
type=int,
default=24,
help="overlap hours when using --from-cutoff (default: 24)",
)
ap.add_argument(
"--allow-small-window",
action="store_true",
help="allow windows smaller than default completeness guard",
)
ap.add_argument("--log-file", default="", help="log file path (default: logs/check_ods_gaps_YYYYMMDD_HHMMSS.log)")
ap.add_argument("--log-dir", default="", help="log directory (default: logs)")
ap.add_argument("--log-level", default="INFO", help="log level (default: INFO)")
ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "check_ods_gaps", args.tag)
log_console = not args.no_log_console
with configure_logging(
"ods_gap_check",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
payload = run_gap_check(
cfg=cfg,
start=args.start,
end=args.end,
window_days=args.window_days,
window_hours=args.window_hours,
page_size=args.page_size,
chunk_size=args.chunk_size,
sample_limit=args.sample_limit,
sleep_per_window=args.sleep_per_window_seconds,
sleep_per_page=args.sleep_per_page_seconds,
task_codes=args.task_codes,
from_cutoff=args.from_cutoff,
cutoff_overlap_hours=args.cutoff_overlap_hours,
allow_small_window=args.allow_small_window,
logger=logger,
window_split_unit=args.window_split_unit or None,
window_compensation_hours=args.window_compensation_hours,
)
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
if args.out:
out_path = Path(args.out)
else:
tag = f"_{args.tag}" if args.tag else ""
stamp = datetime.now(tz).strftime("%Y%m%d_%H%M%S")
out_path = PROJECT_ROOT / "reports" / f"ods_gap_check{tag}_{stamp}.json"
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
logger.info("REPORT_WRITTEN path=%s", out_path)
logger.info(
"SUMMARY missing=%s errors=%s",
payload.get("total_missing"),
payload.get("total_errors"),
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
ODS JSON field reconciliation script: for each ODS table currently in the database, check whether
the sample JSON files (default directory: export/test-json-doc) contain keys with the same names,
and print the unmatched columns per table so mappings can be added or the absence of a source field confirmed.
Usage:
set PG_DSN=postgresql://...  # as configured in .env
python -m etl_billiards.scripts.check_ods_json_vs_table
"""
from __future__ import annotations
import json
import os
import pathlib
from typing import Dict, Iterable, Set, Tuple
import psycopg2
from etl_billiards.tasks.manual_ingest_task import ManualIngestTask
def _flatten_keys(obj, prefix: str = "") -> Set[str]:
"""递归展开 JSON 所有键路径,返回形如 data.assistantInfos.id 的集合。列表不保留索引,仅继续向下展开。"""
keys: Set[str] = set()
if isinstance(obj, dict):
for k, v in obj.items():
new_prefix = f"{prefix}.{k}" if prefix else k
keys.add(new_prefix)
keys |= _flatten_keys(v, new_prefix)
elif isinstance(obj, list):
for item in obj:
keys |= _flatten_keys(item, prefix)
return keys
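# Illustrative example (assumed input): _flatten_keys({"data": {"assistantInfos": [{"id": 1}]}})
# returns {"data", "data.assistantInfos", "data.assistantInfos.id"}; list items inherit the
# parent prefix because indexes are not recorded.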
def _load_json_keys(path: pathlib.Path) -> Tuple[Set[str], dict[str, Set[str]]]:
"""读取单个 JSON 文件并返回展开后的键集合以及末段->路径列表映射,若文件不存在或无法解析则返回空集合。"""
if not path.exists():
return set(), {}
data = json.loads(path.read_text(encoding="utf-8"))
paths = _flatten_keys(data)
last_map: dict[str, Set[str]] = {}
for p in paths:
last = p.split(".")[-1].lower()
last_map.setdefault(last, set()).add(p)
return paths, last_map
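# Illustrative example (assumed paths): if the flattened paths are {"data.id", "data.items.id"},
# last_map becomes {"id": {"data.id", "data.items.id"}}, so a bare column name such as "id" can be
# matched against every JSON path that ends with it.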
def _load_ods_columns(dsn: str) -> Dict[str, Set[str]]:
"""从数据库读取 billiards_ods.* 的列名集合,按表返回。"""
conn = psycopg2.connect(dsn)
cur = conn.cursor()
cur.execute(
"""
SELECT table_name, column_name
FROM information_schema.columns
WHERE table_schema='billiards_ods'
ORDER BY table_name, ordinal_position
"""
)
result: Dict[str, Set[str]] = {}
for table, col in cur.fetchall():
result.setdefault(table, set()).add(col.lower())
cur.close()
conn.close()
return result
def main() -> None:
"""主流程:遍历 FILE_MAPPING 中的 ODS 表,检查 JSON 键覆盖情况并打印报告。"""
dsn = os.environ.get("PG_DSN")
json_dir = pathlib.Path(os.environ.get("JSON_DOC_DIR", "export/test-json-doc"))
ods_cols_map = _load_ods_columns(dsn)
print(f"使用 JSON 目录: {json_dir}")
print(f"连接 DSN: {dsn}")
print("=" * 80)
for keywords, ods_table in ManualIngestTask.FILE_MAPPING:
table = ods_table.split(".")[-1]
cols = ods_cols_map.get(table, set())
file_name = f"{keywords[0]}.json"
file_path = json_dir / file_name
keys_full, path_map = _load_json_keys(file_path)
key_last_parts = set(path_map.keys())
missing: Set[str] = set()
extra_keys: Set[str] = set()
present: Set[str] = set()
for col in sorted(cols):
if col in key_last_parts:
present.add(col)
else:
missing.add(col)
for k in key_last_parts:
if k not in cols:
extra_keys.add(k)
print(f"[{table}] 文件={file_name} 列数={len(cols)} JSON键(末段)覆盖={len(present)}/{len(cols)}")
if missing:
print(" 未命中列:", ", ".join(sorted(missing)))
else:
print(" 未命中列: 无")
if extra_keys:
extras = []
for k in sorted(extra_keys):
paths = ", ".join(sorted(path_map.get(k, [])))
extras.append(f"{k} ({paths})")
print(" JSON 仅有(表无此列):", "; ".join(extras))
else:
print(" JSON 仅有(表无此列): 无")
print("-" * 80)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""
Rebuild the ETL-related schemas in one go and run ODS → DWD.
This script targets the "offline sample-JSON replay" dev/ops scenario and uses the task implementations already in this project:
1) (optional) DROP and recreate the schemas `etl_admin` / `billiards_ods` / `billiards_dwd`
2) Run `INIT_ODS_SCHEMA`: create the `etl_admin` metadata tables + execute `schema_ODS_doc.sql` (it performs light cleanup internally)
3) Run `INIT_DWD_SCHEMA`: execute `schema_dwd_doc.sql`
4) Run `MANUAL_INGEST`: load ODS from the local JSON directory
5) Run `DWD_LOAD_FROM_ODS`: load DWD from ODS
Usage (recommended):
python -m etl_billiards.scripts.rebuild_db_and_run_ods_to_dwd ^
--dsn "postgresql://user:pwd@host:5432/db" ^
--store-id 1 ^
--json-dir "export/test-json-doc" ^
--drop-schemas
Environment variables (optional):
PG_DSN, STORE_ID, INGEST_SOURCE_DIR
Logging:
Console and file output are both enabled by default; the log file path is `io.log_root/rebuild_db_<timestamp>.log`.
"""
from __future__ import annotations
import argparse
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import psycopg2
from etl_billiards.config.settings import AppConfig
from etl_billiards.database.connection import DatabaseConnection
from etl_billiards.database.operations import DatabaseOperations
from etl_billiards.tasks.dwd_load_task import DwdLoadTask
from etl_billiards.tasks.init_dwd_schema_task import InitDwdSchemaTask
from etl_billiards.tasks.init_schema_task import InitOdsSchemaTask
from etl_billiards.tasks.manual_ingest_task import ManualIngestTask
DEFAULT_JSON_DIR = "export/test-json-doc"
@dataclass(frozen=True)
class RunArgs:
"""脚本参数对象(用于减少散落的参数传递)。"""
dsn: str
store_id: int
json_dir: str
drop_schemas: bool
terminate_own_sessions: bool
demo: bool
only_files: list[str]
only_dwd_tables: list[str]
stop_after: str | None
def _attach_file_logger(log_root: str | Path, filename: str, logger: logging.Logger) -> logging.Handler | None:
"""
Attach a file logging handler (UTF-8) to the root logger.
Notes:
- The root logger is used so that differently named loggers across the project (including third-party / sub-modules) are covered.
- If handler creation fails, only a warning is logged and the main flow is not interrupted.
Returns:
The handler on success (the caller is responsible for removeHandler/close), or None on failure.
"""
log_dir = Path(log_root)
try:
log_dir.mkdir(parents=True, exist_ok=True)
except Exception as exc: # noqa: BLE001
logger.warning("创建日志目录失败:%s%s", log_dir, exc)
return None
log_path = log_dir / filename
try:
handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
except Exception as exc: # noqa: BLE001
logger.warning("创建文件日志失败:%s%s", log_path, exc)
return None
handler.setLevel(logging.INFO)
handler.setFormatter(
logging.Formatter(
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
logging.getLogger().addHandler(handler)
logger.info("文件日志已启用:%s", log_path)
return handler
def _parse_args() -> RunArgs:
"""解析命令行/环境变量参数。"""
parser = argparse.ArgumentParser(description="重建 Schema 并执行 ODS→DWD离线 JSON 回放)")
parser.add_argument("--dsn", default=os.environ.get("PG_DSN"), help="PostgreSQL DSN默认读取 PG_DSN")
parser.add_argument(
"--store-id",
type=int,
default=int(os.environ.get("STORE_ID") or 1),
help="门店/租户 store_id默认读取 STORE_ID否则为 1",
)
parser.add_argument(
"--json-dir",
default=os.environ.get("INGEST_SOURCE_DIR") or DEFAULT_JSON_DIR,
help=f"示例 JSON 目录(默认 {DEFAULT_JSON_DIR},也可读 INGEST_SOURCE_DIR",
)
parser.add_argument(
"--drop-schemas",
action=argparse.BooleanOptionalAction,
default=True,
help="是否先 DROP 并重建 etl_admin/billiards_ods/billiards_dwd默认",
)
parser.add_argument(
"--terminate-own-sessions",
action=argparse.BooleanOptionalAction,
default=True,
help="执行 DROP 前是否终止当前用户的 idle-in-transaction 会话(默认:是)",
)
parser.add_argument(
"--demo",
action=argparse.BooleanOptionalAction,
default=False,
help="运行最小 Demo仅导入 member_profiles 并生成 dim_member/dim_member_ex",
)
parser.add_argument(
"--only-files",
default="",
help="仅处理指定 JSON 文件(逗号分隔,不含 .json例如member_profiles,settlement_records",
)
parser.add_argument(
"--only-dwd-tables",
default="",
help="仅处理指定 DWD 表逗号分隔支持完整名或表名例如billiards_dwd.dim_member,dim_member_ex",
)
parser.add_argument(
"--stop-after",
default="",
help="在指定阶段后停止可选DROP_SCHEMAS/INIT_ODS_SCHEMA/INIT_DWD_SCHEMA/MANUAL_INGEST/DWD_LOAD_FROM_ODS/BASIC_VALIDATE",
)
args = parser.parse_args()
if not args.dsn:
raise SystemExit("缺少 DSN请传入 --dsn 或设置环境变量 PG_DSN")
only_files = [x.strip().lower() for x in str(args.only_files or "").split(",") if x.strip()]
only_dwd_tables = [x.strip().lower() for x in str(args.only_dwd_tables or "").split(",") if x.strip()]
stop_after = str(args.stop_after or "").strip().upper() or None
return RunArgs(
dsn=args.dsn,
store_id=args.store_id,
json_dir=str(args.json_dir),
drop_schemas=bool(args.drop_schemas),
terminate_own_sessions=bool(args.terminate_own_sessions),
demo=bool(args.demo),
only_files=only_files,
only_dwd_tables=only_dwd_tables,
stop_after=stop_after,
)
def _build_config(args: RunArgs) -> AppConfig:
"""构建本次执行所需的最小配置覆盖。"""
manual_cfg: dict[str, Any] = {}
dwd_cfg: dict[str, Any] = {}
if args.demo:
manual_cfg["include_files"] = ["member_profiles"]
dwd_cfg["only_tables"] = ["billiards_dwd.dim_member", "billiards_dwd.dim_member_ex"]
if args.only_files:
manual_cfg["include_files"] = args.only_files
if args.only_dwd_tables:
dwd_cfg["only_tables"] = args.only_dwd_tables
overrides: dict[str, Any] = {
"app": {"store_id": args.store_id},
"pipeline": {"flow": "INGEST_ONLY", "ingest_source_dir": args.json_dir},
"manual": manual_cfg,
"dwd": dwd_cfg,
# Offline replay / warehouse rebuilds can run long; disable statement_timeout so the default 30s limit does not abort them.
# Also disable lock_timeout so DROP/DDL does not fail just because a lock takes slightly longer to acquire.
"db": {"dsn": args.dsn, "session": {"statement_timeout_ms": 0, "lock_timeout_ms": 0}},
}
return AppConfig.load(overrides)
def _drop_schemas(db: DatabaseOperations, logger: logging.Logger) -> None:
"""删除并重建 ETL 相关 schema具备破坏性请谨慎"""
with db.conn.cursor() as cur:
# Do not wait forever on locks held by other sessions; if the schemas really are in use, ask the user to release/terminate the blocking sessions first.
cur.execute("SET lock_timeout TO '5s'")
for schema in ("billiards_dwd", "billiards_ods", "etl_admin"):
logger.info("DROP SCHEMA IF EXISTS %s CASCADE ...", schema)
cur.execute(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE;')
def _terminate_own_idle_in_tx(db: DatabaseOperations, logger: logging.Logger) -> int:
"""终止当前用户在本库中处于 idle-in-transaction 的会话,避免阻塞 DROP/DDL。"""
with db.conn.cursor() as cur:
cur.execute(
"""
SELECT pid
FROM pg_stat_activity
WHERE datname = current_database()
AND usename = current_user
AND pid <> pg_backend_pid()
AND state = 'idle in transaction'
"""
)
pids = [r[0] for r in cur.fetchall()]
killed = 0
for pid in pids:
cur.execute("SELECT pg_terminate_backend(%s)", (pid,))
ok = bool(cur.fetchone()[0])
logger.info("终止会话 pid=%s ok=%s", pid, ok)
killed += 1 if ok else 0
return killed
def _run_task(task, logger: logging.Logger) -> dict:
"""统一运行任务并打印关键结果。"""
result = task.execute(None)
logger.info("%s: status=%s counts=%s", task.get_task_code(), result.get("status"), result.get("counts"))
return result
def _basic_validate(db: DatabaseOperations, logger: logging.Logger) -> None:
"""做最基础的可用性校验schema 存在、关键表行数可查询。"""
checks = [
("billiards_ods", "member_profiles"),
("billiards_ods", "settlement_records"),
("billiards_dwd", "dim_member"),
("billiards_dwd", "dwd_settlement_head"),
]
for schema, table in checks:
try:
rows = db.query(f'SELECT COUNT(1) AS cnt FROM "{schema}"."{table}"')
logger.info("校验行数:%s.%s = %s", schema, table, (rows[0] or {}).get("cnt") if rows else None)
except Exception as exc: # noqa: BLE001
logger.warning("校验失败:%s.%s%s", schema, table, exc)
def _connect_db_with_retry(cfg: AppConfig, logger: logging.Logger) -> DatabaseConnection:
"""创建数据库连接(带重试),避免短暂网络抖动导致脚本直接失败。"""
dsn = cfg["db"]["dsn"]
session = cfg["db"].get("session")
connect_timeout = cfg["db"].get("connect_timeout_sec")
backoffs = [1, 2, 4, 8, 16]
last_exc: Exception | None = None
for attempt, wait_sec in enumerate([0] + backoffs, start=1):
if wait_sec:
time.sleep(wait_sec)
try:
return DatabaseConnection(dsn=dsn, session=session, connect_timeout=connect_timeout)
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.warning("数据库连接失败(第 %s 次):%s", attempt, exc)
raise last_exc or RuntimeError("数据库连接失败")
def _is_connection_error(exc: Exception) -> bool:
"""判断是否为连接断开/服务端异常导致的可重试错误。"""
return isinstance(exc, (psycopg2.OperationalError, psycopg2.InterfaceError))
def _run_stage_with_reconnect(
cfg: AppConfig,
logger: logging.Logger,
stage_name: str,
fn,
max_attempts: int = 3,
) -> dict | None:
"""
Run a single stage; on failure (especially a dropped connection) reconnect and retry automatically.
fn: (db_ops) -> dict | None
"""
last_exc: Exception | None = None
for attempt in range(1, max_attempts + 1):
db_conn = _connect_db_with_retry(cfg, logger)
db_ops = DatabaseOperations(db_conn)
try:
logger.info("阶段开始:%s(第 %s/%s 次)", stage_name, attempt, max_attempts)
result = fn(db_ops)
logger.info("阶段完成:%s", stage_name)
return result
except Exception as exc: # noqa: BLE001
last_exc = exc
logger.exception("阶段失败:%s(第 %s/%s 次):%s", stage_name, attempt, max_attempts, exc)
# 连接类错误允许重试;非连接错误直接抛出,避免掩盖逻辑问题。
if not _is_connection_error(exc):
raise
time.sleep(min(2**attempt, 10))
finally:
try:
db_ops.close() # type: ignore[attr-defined]
except Exception:
pass
try:
db_conn.close()
except Exception:
pass
raise last_exc or RuntimeError(f"阶段失败:{stage_name}")
def main() -> int:
"""脚本主入口:按顺序重建并跑通 ODS→DWD。"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("etl_billiards.rebuild_db")
args = _parse_args()
cfg = _build_config(args)
# Enable file logging by default for later troubleshooting (flush to disk early, even if the run fails).
file_handler = _attach_file_logger(
log_root=cfg["io"]["log_root"],
filename=time.strftime("rebuild_db_%Y%m%d-%H%M%S.log"),
logger=logger,
)
try:
json_dir = Path(args.json_dir)
if not json_dir.exists():
logger.error("示例 JSON 目录不存在:%s", json_dir)
return 2
def stage_drop(db_ops: DatabaseOperations):
if not args.drop_schemas:
return None
if args.terminate_own_sessions:
killed = _terminate_own_idle_in_tx(db_ops, logger)
if killed:
db_ops.commit()
_drop_schemas(db_ops, logger)
db_ops.commit()
return None
def stage_init_ods(db_ops: DatabaseOperations):
return _run_task(InitOdsSchemaTask(cfg, db_ops, None, logger), logger)
def stage_init_dwd(db_ops: DatabaseOperations):
return _run_task(InitDwdSchemaTask(cfg, db_ops, None, logger), logger)
def stage_manual_ingest(db_ops: DatabaseOperations):
logger.info("开始执行MANUAL_INGESTjson_dir=%s", json_dir)
return _run_task(ManualIngestTask(cfg, db_ops, None, logger), logger)
def stage_dwd_load(db_ops: DatabaseOperations):
logger.info("开始执行DWD_LOAD_FROM_ODS")
return _run_task(DwdLoadTask(cfg, db_ops, None, logger), logger)
_run_stage_with_reconnect(cfg, logger, "DROP_SCHEMAS", stage_drop, max_attempts=3)
if args.stop_after == "DROP_SCHEMAS":
return 0
_run_stage_with_reconnect(cfg, logger, "INIT_ODS_SCHEMA", stage_init_ods, max_attempts=3)
if args.stop_after == "INIT_ODS_SCHEMA":
return 0
_run_stage_with_reconnect(cfg, logger, "INIT_DWD_SCHEMA", stage_init_dwd, max_attempts=3)
if args.stop_after == "INIT_DWD_SCHEMA":
return 0
_run_stage_with_reconnect(cfg, logger, "MANUAL_INGEST", stage_manual_ingest, max_attempts=5)
if args.stop_after == "MANUAL_INGEST":
return 0
_run_stage_with_reconnect(cfg, logger, "DWD_LOAD_FROM_ODS", stage_dwd_load, max_attempts=5)
if args.stop_after == "DWD_LOAD_FROM_ODS":
return 0
# The validation stage only needs one fresh connection.
_run_stage_with_reconnect(
cfg,
logger,
"BASIC_VALIDATE",
lambda db_ops: _basic_validate(db_ops, logger),
max_attempts=3,
)
if args.stop_after == "BASIC_VALIDATE":
return 0
return 0
finally:
if file_handler is not None:
try:
logging.getLogger().removeHandler(file_handler)
except Exception:
pass
try:
file_handler.close()
except Exception:
pass
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -7,7 +7,7 @@
Environment variables:
PG_DSN PostgreSQL connection string (required)
PG_CONNECT_TIMEOUT optional, seconds, default 10
JSON_DOC_DIR optional, JSON directory, default C:\\dev\\LLTQ\\export\\test-json-doc
JSON_DOC_DIR optional, JSON directory, default export/test-json-doc
ODS_INCLUDE_FILES optional, comma-separated file names (without .json)
ODS_DROP_SCHEMA_FIRST optional, true/false, default true
"""
@@ -26,7 +26,7 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
DEFAULT_JSON_DIR = r"C:\dev\LLTQ\export\test-json-doc"
DEFAULT_JSON_DIR = "export/test-json-doc"
SPECIAL_LIST_PATHS: dict[str, tuple[str, ...]] = {
"assistant_accounts_master": ("data", "assistantInfos"),
"assistant_cancellation_records": ("data", "abolitionAssistants"),

View File

@@ -0,0 +1,224 @@
# -*- coding: utf-8 -*-
"""
Reload ODS tasks by fixed time windows with optional sleep between windows.
"""
from __future__ import annotations
import argparse
import logging
import subprocess
import sys
import time as time_mod
from datetime import datetime, timedelta
from pathlib import Path
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from config.settings import AppConfig
from utils.windowing import split_window
from utils.logging_utils import build_log_path, configure_logging
MIN_RELOAD_WINDOW_DAYS = 30
def _parse_dt(value: str, tz: ZoneInfo, *, is_end: bool) -> datetime:
raw = (value or "").strip()
if not raw:
raise ValueError("empty datetime")
has_time = any(ch in raw for ch in (":", "T"))
dt = dtparser.parse(raw)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=tz)
else:
dt = dt.astimezone(tz)
if not has_time:
dt = dt.replace(hour=23 if is_end else 0, minute=59 if is_end else 0, second=59 if is_end else 0, microsecond=0)
return dt
def _iter_windows(start: datetime, end: datetime, window_size: timedelta):
if window_size.total_seconds() <= 0:
raise ValueError("window_size must be > 0")
cur = start
while cur < end:
nxt = min(cur + window_size, end)
yield cur, nxt
cur = nxt
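# Illustrative example (assumed values): a 2025-07-01 00:00 ~ 2025-07-03 12:00 range with a
# one-day window yields three slices: 07-01 00:00~07-02 00:00, 07-02 00:00~07-03 00:00,
# and a final partial slice 07-03 00:00~07-03 12:00.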
def _run_task_window(
task_code: str,
window_start: datetime,
window_end: datetime,
api_page_size: int,
api_timeout: int,
logger: logging.Logger,
window_split_unit: str | None = "none",
window_compensation_hours: int | None = 0,
) -> None:
cmd = [
sys.executable,
"-m",
"cli.main",
"--pipeline-flow",
"FULL",
"--tasks",
task_code,
"--window-start",
window_start.strftime("%Y-%m-%d %H:%M:%S"),
"--window-end",
window_end.strftime("%Y-%m-%d %H:%M:%S"),
"--force-window-override",
"--window-split-unit",
str(window_split_unit or "none"),
"--window-compensation-hours",
str(int(window_compensation_hours or 0)),
]
if api_page_size > 0:
cmd += ["--api-page-size", str(api_page_size)]
if api_timeout > 0:
cmd += ["--api-timeout", str(api_timeout)]
logger.info(
"RUN_TASK task=%s window_start=%s window_end=%s",
task_code,
window_start.isoformat(),
window_end.isoformat(),
)
logger.debug("CMD %s", " ".join(cmd))
subprocess.run(cmd, check=True, cwd=str(PROJECT_ROOT))
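# The subprocess call above is roughly equivalent to running (task code and window values assumed for illustration):
#   python -m cli.main --pipeline-flow FULL --tasks ODS_XXX \
#       --window-start "2025-07-01 00:00:00" --window-end "2025-07-02 00:00:00" \
#       --force-window-override --window-split-unit none --window-compensation-hours 0 \
#       --api-page-size 200 --api-timeout 20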
def main() -> int:
ap = argparse.ArgumentParser(description="Reload ODS tasks by window slices.")
ap.add_argument("--tasks", required=True, help="comma-separated ODS task codes")
ap.add_argument("--start", required=True, help="start datetime, e.g. 2025-07-01")
ap.add_argument("--end", default="", help="end datetime (default: now)")
ap.add_argument("--window-days", type=int, default=1, help="days per window (default: 1)")
ap.add_argument("--window-hours", type=int, default=0, help="hours per window (default: 0)")
ap.add_argument("--window-split-unit", default="", help="split unit (month/none), default from config")
ap.add_argument("--window-compensation-hours", type=int, default=None, help="window compensation hours, default from config")
ap.add_argument("--sleep-seconds", type=float, default=0, help="sleep seconds after each window")
ap.add_argument("--api-page-size", type=int, default=200, help="API page size override")
ap.add_argument("--api-timeout", type=int, default=20, help="API timeout seconds override")
ap.add_argument("--log-file", default="", help="log file path (default: logs/reload_ods_windowed_YYYYMMDD_HHMMSS.log)")
ap.add_argument("--log-dir", default="", help="log directory (default: logs)")
ap.add_argument("--log-level", default="INFO", help="log level (default: INFO)")
ap.add_argument("--no-log-console", action="store_true", help="disable console logging")
args = ap.parse_args()
log_dir = Path(args.log_dir) if args.log_dir else (PROJECT_ROOT / "logs")
log_file = Path(args.log_file) if args.log_file else build_log_path(log_dir, "reload_ods_windowed")
log_console = not args.no_log_console
with configure_logging(
"reload_ods_windowed",
log_file,
level=args.log_level,
console=log_console,
tee_std=True,
) as logger:
cfg = AppConfig.load({})
tz = ZoneInfo(cfg.get("app.timezone", "Asia/Taipei"))
start = _parse_dt(args.start, tz, is_end=False)
end = datetime.now(tz) if not args.end else _parse_dt(args.end, tz, is_end=True)
window_days = int(args.window_days)
window_hours = int(args.window_hours)
split_unit = (args.window_split_unit or cfg.get("run.window_split.unit", "month") or "month").strip()
comp_hours = args.window_compensation_hours
if comp_hours is None:
comp_hours = cfg.get("run.window_split.compensation_hours", 0)
use_split = split_unit.lower() not in ("", "none", "off", "false", "0")
if use_split:
windows = split_window(
start,
end,
tz=tz,
split_unit=split_unit,
compensation_hours=comp_hours,
)
else:
min_hours = MIN_RELOAD_WINDOW_DAYS * 24
if window_hours > 0:
if window_hours < min_hours:
logger.warning(
"window_hours=%s too small; adjust to %s",
window_hours,
min_hours,
)
window_hours = min_hours
elif window_days < MIN_RELOAD_WINDOW_DAYS:
logger.warning(
"window_days=%s too small; adjust to %s",
window_days,
MIN_RELOAD_WINDOW_DAYS,
)
window_days = MIN_RELOAD_WINDOW_DAYS
adjusted = split_window(
start,
end,
tz=tz,
split_unit="none",
compensation_hours=comp_hours,
)
if adjusted:
start, end = adjusted[0]
window_size = timedelta(hours=window_hours) if window_hours > 0 else timedelta(days=window_days)
windows = list(_iter_windows(start, end, window_size))
if windows:
start, end = windows[0][0], windows[-1][1]
task_codes = [t.strip().upper() for t in args.tasks.split(",") if t.strip()]
if not task_codes:
raise SystemExit("no tasks specified")
logger.info(
"START range=%s~%s window_days=%s window_hours=%s split_unit=%s comp_hours=%s sleep=%.2f",
start.isoformat(),
end.isoformat(),
window_days,
window_hours,
split_unit,
comp_hours,
args.sleep_seconds,
)
for task_code in task_codes:
logger.info("TASK_START task=%s", task_code)
for window_start, window_end in windows:
start_ts = time_mod.monotonic()
_run_task_window(
task_code=task_code,
window_start=window_start,
window_end=window_end,
api_page_size=args.api_page_size,
api_timeout=args.api_timeout,
logger=logger,
window_split_unit="none",
window_compensation_hours=0,
)
elapsed = time_mod.monotonic() - start_ts
logger.info(
"WINDOW_DONE task=%s window_start=%s window_end=%s elapsed=%.2fs",
task_code,
window_start.isoformat(),
window_end.isoformat(),
elapsed,
)
if args.sleep_seconds > 0:
logger.debug("SLEEP seconds=%.2f", args.sleep_seconds)
time_mod.sleep(args.sleep_seconds)
logger.info("TASK_DONE task=%s", task_code)
return 0
if __name__ == "__main__":
raise SystemExit(main())
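# Example invocation (task codes and module path assumed for illustration):
#   python -m scripts.reload_ods_windowed --tasks ODS_ORDER,ODS_SETTLE \
#       --start 2025-07-01 --end 2025-07-31 --window-days 1 --sleep-seconds 2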

View File

@@ -0,0 +1,267 @@
# -*- coding: utf-8 -*-
"""PostgreSQL connection performance test (ASCII-only output)."""
from __future__ import annotations
import argparse
import math
import os
import statistics
import sys
import time
from typing import Dict, Iterable, List
from psycopg2.extensions import make_dsn, parse_dsn
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
from database.connection import DatabaseConnection
def _load_env() -> Dict[str, str]:
env: Dict[str, str] = {}
try:
from config.env_parser import _load_dotenv_values
except Exception:
_load_dotenv_values = None
if _load_dotenv_values:
try:
env.update(_load_dotenv_values())
except Exception:
pass
env.update(os.environ)
return env
def _apply_dsn_overrides(dsn: str, host: str | None, port: int | None) -> str:
overrides = {}
if host:
overrides["host"] = host
if port:
overrides["port"] = str(port)
if not overrides:
return dsn
return make_dsn(dsn, **overrides)
def _build_dsn_from_env(
host: str,
port: int,
user: str | None,
password: str | None,
dbname: str | None,
) -> str | None:
if not user or not dbname:
return None
params = {
"host": host,
"port": str(port),
"user": user,
"dbname": dbname,
}
if password:
params["password"] = password
return make_dsn("", **params)
def _safe_dsn_summary(dsn: str, host: str | None, port: int | None) -> str:
try:
info = parse_dsn(dsn)
except Exception:
info = {}
if host:
info["host"] = host
if port:
info["port"] = str(port)
info.pop("password", None)
if not info:
return "dsn=(hidden)"
items = " ".join(f"{k}={info[k]}" for k in sorted(info.keys()))
return items
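# Illustrative example (assumed DSN): for "postgresql://etl:secret@100.64.0.4:5432/billiards" this
# returns something like "dbname=billiards host=100.64.0.4 port=5432 user=etl"; the password is
# always dropped before the summary is built.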
def _percentile(values: List[float], pct: float) -> float:
if not values:
return 0.0
ordered = sorted(values)
if len(ordered) == 1:
return ordered[0]
rank = (len(ordered) - 1) * (pct / 100.0)
low = int(math.floor(rank))
high = int(math.ceil(rank))
if low == high:
return ordered[low]
return ordered[low] + (ordered[high] - ordered[low]) * (rank - low)
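# Worked example (assumed samples): _percentile([10.0, 20.0, 30.0, 40.0], 95) computes
# rank = 3 * 0.95 = 2.85, then interpolates 30 + 0.85 * (40 - 30) = 38.5.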
def _format_stats(label: str, values: Iterable[float]) -> str:
data = list(values)
if not data:
return f"{label}: no samples"
avg = statistics.mean(data)
stdev = statistics.stdev(data) if len(data) > 1 else 0.0
return (
f"{label}: count={len(data)} "
f"min={min(data):.2f}ms avg={avg:.2f}ms "
f"p50={_percentile(data, 50):.2f}ms "
f"p95={_percentile(data, 95):.2f}ms "
f"max={max(data):.2f}ms stdev={stdev:.2f}ms"
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="PostgreSQL connection performance test")
parser.add_argument("--dsn", help="Override PG_DSN/TEST_DB_DSN/.env value")
parser.add_argument(
"--host",
default="100.64.0.4",
help="Override host in DSN (default: 100.64.0.4)",
)
parser.add_argument("--port", type=int, help="Override port in DSN")
parser.add_argument("--user", help="User when building DSN from PG_* env")
parser.add_argument("--password", help="Password when building DSN from PG_* env")
parser.add_argument("--dbname", help="Database name when building DSN from PG_* env")
parser.add_argument("--rounds", type=int, default=20, help="Measured connection rounds")
parser.add_argument("--warmup", type=int, default=2, help="Warmup rounds (not recorded)")
parser.add_argument("--query", default="SELECT 1", help="SQL to run after connect")
parser.add_argument(
"--query-repeat",
type=int,
default=1,
help="Query repetitions per connection (0 to skip)",
)
parser.add_argument(
"--connect-timeout",
type=int,
default=10,
help="connect_timeout seconds (capped at 20, default: 10)",
)
parser.add_argument(
"--statement-timeout-ms",
type=int,
help="Optional statement_timeout applied per connection",
)
parser.add_argument(
"--sleep-ms",
type=int,
default=0,
help="Sleep between rounds in milliseconds",
)
parser.add_argument(
"--continue-on-error",
action="store_true",
help="Continue even if a round fails",
)
parser.add_argument("--verbose", action="store_true", help="Print per-round timings")
return parser.parse_args()
def _run_round(
dsn: str,
timeout: int,
query: str,
query_repeat: int,
session: Dict[str, int] | None,
) -> tuple[float, List[float]]:
start = time.perf_counter()
conn = DatabaseConnection(dsn, connect_timeout=timeout, session=session)
connect_ms = (time.perf_counter() - start) * 1000.0
query_times: List[float] = []
try:
for _ in range(query_repeat):
q_start = time.perf_counter()
conn.query(query)
query_times.append((time.perf_counter() - q_start) * 1000.0)
return connect_ms, query_times
finally:
try:
conn.rollback()
except Exception:
pass
conn.close()
def main() -> int:
args = parse_args()
if args.rounds < 0 or args.warmup < 0 or args.query_repeat < 0:
print("rounds/warmup/query-repeat must be >= 0", file=sys.stderr)
return 2
env = _load_env()
dsn = args.dsn or env.get("PG_DSN") or env.get("TEST_DB_DSN")
host = args.host
port = args.port
if not dsn:
user = args.user or env.get("PG_USER")
password = args.password if args.password is not None else env.get("PG_PASSWORD")
dbname = args.dbname or env.get("PG_NAME")
try:
resolved_port = port or int(env.get("PG_PORT", "5432"))
except ValueError:
resolved_port = port or 5432
dsn = _build_dsn_from_env(host, resolved_port, user, password, dbname)
if not dsn:
print(
"Missing DSN. Provide --dsn or set PG_DSN/TEST_DB_DSN, or PG_USER + PG_NAME.",
file=sys.stderr,
)
return 2
dsn = _apply_dsn_overrides(dsn, host, port)
timeout = max(1, min(int(args.connect_timeout), 20))
session = None
if args.statement_timeout_ms is not None:
session = {"statement_timeout_ms": int(args.statement_timeout_ms)}
print("Target:", _safe_dsn_summary(dsn, host, port))
print(
f"Rounds: {args.rounds} (warmup {args.warmup}), "
f"query_repeat={args.query_repeat}, timeout={timeout}s"
)
if args.query_repeat > 0:
print("Query:", args.query)
connect_times: List[float] = []
query_times: List[float] = []
failures: List[str] = []
total = args.warmup + args.rounds
for idx in range(total):
is_warmup = idx < args.warmup
try:
c_ms, q_times = _run_round(
dsn, timeout, args.query, args.query_repeat, session
)
if not is_warmup:
connect_times.append(c_ms)
query_times.extend(q_times)
if args.verbose:
tag = "warmup" if is_warmup else "sample"
q_msg = ""
if args.query_repeat > 0:
q_avg = statistics.mean(q_times) if q_times else 0.0
q_msg = f", query_avg={q_avg:.2f}ms"
print(f"[{tag} {idx + 1}/{total}] connect={c_ms:.2f}ms{q_msg}")
except Exception as exc:
msg = f"round {idx + 1}: {exc}"
failures.append(msg)
print("Failure:", msg, file=sys.stderr)
if not args.continue_on_error:
break
if args.sleep_ms > 0:
time.sleep(args.sleep_ms / 1000.0)
if connect_times:
print(_format_stats("Connect", connect_times))
if args.query_repeat > 0:
print(_format_stats("Query", query_times))
if failures:
print(f"Failures: {len(failures)}", file=sys.stderr)
return 1
return 0
if __name__ == "__main__":
raise SystemExit(main())