ETL complete

This commit is contained in:
Neo
2026-01-18 22:37:38 +08:00
parent 8da6cb6563
commit 7ca19a4a2c
159 changed files with 31225 additions and 467 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
@dataclass(frozen=True)
@@ -92,6 +93,36 @@ class BaseTask:
"""计算时间窗口"""
now = datetime.now(self.tz)
override_start = self.config.get("run.window_override.start")
override_end = self.config.get("run.window_override.end")
if override_start or override_end:
if not (override_start and override_end):
raise ValueError("run.window_override.start/end 需要同时提供")
window_start = override_start
if isinstance(window_start, str):
window_start = dtparser.parse(window_start)
if isinstance(window_start, datetime) and window_start.tzinfo is None:
window_start = window_start.replace(tzinfo=self.tz)
elif isinstance(window_start, datetime):
window_start = window_start.astimezone(self.tz)
window_end = override_end
if isinstance(window_end, str):
window_end = dtparser.parse(window_end)
if isinstance(window_end, datetime) and window_end.tzinfo is None:
window_end = window_end.replace(tzinfo=self.tz)
elif isinstance(window_end, datetime):
window_end = window_end.astimezone(self.tz)
if not isinstance(window_start, datetime) or not isinstance(window_end, datetime):
raise ValueError("run.window_override.start/end 解析失败")
if window_end <= window_start:
raise ValueError("run.window_override.end 必须大于 start")
window_minutes = max(1, int((window_end - window_start).total_seconds() // 60))
return window_start, window_end, window_minutes
idle_start = self.config.get("run.idle_window.start", "04:00")
idle_end = self.config.get("run.idle_window.end", "16:00")
is_idle = self._is_in_idle_window(now, idle_start, idle_end)
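
A minimal, self-contained sketch of the override normalization implemented above; the timezone and the sample timestamps are illustrative assumptions, not values taken from the repository config.

from datetime import datetime
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser

def normalize_override(value, tz: ZoneInfo) -> datetime:
    # Accept either a string or a datetime; force the result into the task timezone.
    if isinstance(value, str):
        value = dtparser.parse(value)
    if not isinstance(value, datetime):
        raise ValueError("run.window_override.start/end could not be parsed")
    return value.replace(tzinfo=tz) if value.tzinfo is None else value.astimezone(tz)

tz = ZoneInfo("Asia/Shanghai")  # assumed timezone for illustration
start = normalize_override("2026-01-17 04:00:00", tz)
end = normalize_override("2026-01-18 04:00:00", tz)
if end <= start:
    raise ValueError("run.window_override.end must be later than start")
window_minutes = max(1, int((end - start).total_seconds() // 60))  # 1440 here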

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
"""Task: report last successful cursor cutoff times from etl_admin."""
from __future__ import annotations
from typing import Any
from .base_task import BaseTask
class CheckCutoffTask(BaseTask):
"""Report per-task cursor cutoff times (etl_admin.etl_cursor.last_end)."""
def get_task_code(self) -> str:
return "CHECK_CUTOFF"
def execute(self, cursor_data: dict | None = None) -> dict:
store_id = int(self.config.get("app.store_id"))
filter_codes = self.config.get("run.cutoff_task_codes") or None
if isinstance(filter_codes, str):
filter_codes = [c.strip().upper() for c in filter_codes.split(",") if c.strip()]
sql = """
SELECT
t.task_code,
c.last_start,
c.last_end,
c.last_id,
c.last_run_id,
c.updated_at
FROM etl_admin.etl_task t
LEFT JOIN etl_admin.etl_cursor c
ON c.task_id = t.task_id AND c.store_id = t.store_id
WHERE t.store_id = %s
AND t.enabled = TRUE
ORDER BY t.task_code
"""
rows = self.db.query(sql, (store_id,))
if filter_codes:
wanted = {str(c).upper() for c in filter_codes}
rows = [r for r in rows if str(r.get("task_code", "")).upper() in wanted]
def _ts(v: Any) -> str:
return "-" if not v else str(v)
self.logger.info("CHECK_CUTOFF: store_id=%s enabled_tasks=%s", store_id, len(rows))
for r in rows:
self.logger.info(
"CHECK_CUTOFF: %-24s last_end=%s last_start=%s last_run_id=%s",
str(r.get("task_code") or ""),
_ts(r.get("last_end")),
_ts(r.get("last_start")),
_ts(r.get("last_run_id")),
)
cutoff_candidates = [
r.get("last_end")
for r in rows
if r.get("last_end") is not None and not str(r.get("task_code", "")).upper().startswith("INIT_")
]
cutoff = min(cutoff_candidates) if cutoff_candidates else None
self.logger.info("CHECK_CUTOFF: overall_cutoff(min last_end, excl INIT_*)=%s", _ts(cutoff))
ods_fetched = self._probe_ods_fetched_at(store_id)
if ods_fetched:
non_null = [v["max_fetched_at"] for v in ods_fetched.values() if v.get("max_fetched_at") is not None]
ods_cutoff = min(non_null) if non_null else None
self.logger.info("CHECK_CUTOFF: ODS cutoff(min MAX(fetched_at))=%s", _ts(ods_cutoff))
worst = sorted(
((k, v.get("max_fetched_at")) for k, v in ods_fetched.items()),
key=lambda kv: (kv[1] is None, kv[1]),
)[:8]
for table, mx in worst:
self.logger.info("CHECK_CUTOFF: ODS table=%s max_fetched_at=%s", table, _ts(mx))
dw_checks = self._probe_dw_time_columns()
for name, value in dw_checks.items():
self.logger.info("CHECK_CUTOFF: %s=%s", name, _ts(value))
return {
"status": "SUCCESS",
"counts": {"fetched": len(rows), "inserted": 0, "updated": 0, "skipped": 0, "errors": 0},
"window": None,
"request_params": {"store_id": store_id, "filter_task_codes": filter_codes or []},
"report": {
"rows": rows,
"overall_cutoff": cutoff,
"ods_fetched_at": ods_fetched,
"dw_max_times": dw_checks,
},
}
def _probe_ods_fetched_at(self, store_id: int) -> dict[str, dict[str, Any]]:
try:
from tasks.dwd_load_task import DwdLoadTask # local import to avoid circulars
except Exception:
return {}
ods_tables = sorted({str(t) for t in DwdLoadTask.TABLE_MAP.values() if str(t).startswith("billiards_ods.")})
results: dict[str, dict[str, Any]] = {}
for table in ods_tables:
try:
row = self.db.query(f"SELECT MAX(fetched_at) AS mx, COUNT(*) AS cnt FROM {table}")[0]
results[table] = {"max_fetched_at": row.get("mx"), "count": row.get("cnt")}
except Exception as exc: # noqa: BLE001
results[table] = {"max_fetched_at": None, "count": None, "error": str(exc)}
return results
def _probe_dw_time_columns(self) -> dict[str, Any]:
checks: dict[str, Any] = {}
probes = {
"DWD.max_settlement_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_settlement_head",
"DWD.max_payment_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_payment",
"DWD.max_refund_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_refund",
"DWS.max_order_date": "SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary",
"DWS.max_updated_at": "SELECT MAX(updated_at) AS mx FROM billiards_dws.dws_order_summary",
}
for name, sql2 in probes.items():
try:
row = self.db.query(sql2)[0]
checks[name] = row.get("mx")
except Exception as exc: # noqa: BLE001
checks[name] = f"ERROR: {exc}"
return checks
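
A small runnable sketch of how the overall cutoff above is derived: take the minimum non-null last_end across enabled tasks while ignoring one-off INIT_* tasks. The rows below are made-up examples, not real cursor data.

from datetime import datetime

rows = [
    {"task_code": "INIT_DWS_SCHEMA", "last_end": datetime(2026, 1, 10, 4, 0)},
    {"task_code": "ODS_PAY_LOG", "last_end": datetime(2026, 1, 18, 4, 0)},
    {"task_code": "ODS_SETTLEMENT", "last_end": None},
    {"task_code": "DWD_LOAD", "last_end": datetime(2026, 1, 17, 4, 0)},
]
candidates = [
    r["last_end"]
    for r in rows
    if r["last_end"] is not None and not str(r["task_code"]).upper().startswith("INIT_")
]
overall_cutoff = min(candidates) if candidates else None  # 2026-01-17 04:00 in this example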

View File

@@ -2,10 +2,11 @@
"""DWD 装载任务:从 ODS 增量写入 DWD维度 SCD2事实按时间增量"""
from __future__ import annotations
import time
from datetime import datetime
from typing import Any, Dict, Iterable, List, Sequence
from psycopg2.extras import RealDictCursor
from psycopg2.extras import RealDictCursor, execute_batch, execute_values
from .base_task import BaseTask, TaskContext
@@ -61,14 +62,15 @@ class DwdLoadTask(BaseTask):
}
SCD_COLS = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}
# Incremental/window filtering prefers business-time columns; fetched_at (ingest time) goes last so a backfill window is not skewed by the current ingest time (see the sketch after this list).
FACT_ORDER_CANDIDATES = [
"fetched_at",
"pay_time",
"create_time",
"update_time",
"occur_time",
"settle_time",
"start_use_time",
"fetched_at",
]
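
A hypothetical helper showing how this priority list would typically be consumed: the first candidate that actually exists among the fact table's columns becomes the ordering/filter column, so business times win and fetched_at is the last resort. The helper name and usage are assumptions for illustration.

FACT_ORDER_CANDIDATES = [
    "pay_time", "create_time", "update_time", "occur_time",
    "settle_time", "start_use_time", "fetched_at",
]

def pick_order_column(candidates: list[str], table_columns: list[str]) -> str | None:
    cols = {c.lower() for c in table_columns}
    return next((c for c in candidates if c in cols), None)

# A table exposing create_time and fetched_at picks create_time, not fetched_at.
assert pick_order_column(FACT_ORDER_CANDIDATES, ["id", "create_time", "fetched_at"]) == "create_time"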
# Special column mappings: DWD column name -> source column expression (optional CAST)
@@ -457,30 +459,69 @@ class DwdLoadTask(BaseTask):
return {"now": datetime.now()}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict[str, Any]:
"""遍历映射关系,维度执行 SCD2 合并,事实表按时间增量插入。"""
"""
遍历映射关系,维度执行 SCD2 合并,事实表按时间增量插入。
说明:
- 为避免长事务导致锁堆积/中断后遗留 idle-in-tx本任务按“每张表一次事务”提交
- 单表失败会回滚该表并继续后续表,最终在结果中汇总错误信息。
"""
now = extracted["now"]
summary: List[Dict[str, Any]] = []
errors: List[Dict[str, Any]] = []
only_tables_cfg = self.config.get("dwd.only_tables") or []
only_tables = {str(t).strip().lower() for t in only_tables_cfg if str(t).strip()} if only_tables_cfg else set()
with self.db.conn.cursor(cursor_factory=RealDictCursor) as cur:
for dwd_table, ods_table in self.TABLE_MAP.items():
dwd_cols = self._get_columns(cur, dwd_table)
ods_cols = self._get_columns(cur, ods_table)
if not dwd_cols:
self.logger.warning("跳过 %s,未能获取 DWD 列信息", dwd_table)
if only_tables and dwd_table.lower() not in only_tables and self._table_base(dwd_table).lower() not in only_tables:
continue
started = time.monotonic()
self.logger.info("DWD 装载开始:%s <= %s", dwd_table, ods_table)
try:
dwd_cols = self._get_columns(cur, dwd_table)
ods_cols = self._get_columns(cur, ods_table)
if not dwd_cols:
self.logger.warning("跳过 %s:未能获取 DWD 列信息", dwd_table)
continue
if self._table_base(dwd_table).startswith("dim_"):
processed = self._merge_dim(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
self.db.conn.commit()
summary.append({"table": dwd_table, "mode": "SCD2", "processed": processed})
else:
dwd_types = self._get_column_types(cur, dwd_table, "billiards_dwd")
ods_types = self._get_column_types(cur, ods_table, "billiards_ods")
use_window = bool(
self.config.get("run.window_override.start")
and self.config.get("run.window_override.end")
)
inserted = self._merge_fact_increment(
cur,
dwd_table,
ods_table,
dwd_cols,
ods_cols,
dwd_types,
ods_types,
window_start=context.window_start if use_window else None,
window_end=context.window_end if use_window else None,
)
self.db.conn.commit()
summary.append({"table": dwd_table, "mode": "INCREMENT", "inserted": inserted})
elapsed = time.monotonic() - started
self.logger.info("DWD 装载完成:%s,用时 %.2fs", dwd_table, elapsed)
except Exception as exc: # noqa: BLE001
try:
self.db.conn.rollback()
except Exception:
pass
elapsed = time.monotonic() - started
self.logger.exception("DWD 装载失败:%s,用时 %.2fserr=%s", dwd_table, elapsed, exc)
errors.append({"table": dwd_table, "error": str(exc)})
continue
if self._table_base(dwd_table).startswith("dim_"):
processed = self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
summary.append({"table": dwd_table, "mode": "SCD2", "processed": processed})
else:
dwd_types = self._get_column_types(cur, dwd_table, "billiards_dwd")
ods_types = self._get_column_types(cur, ods_table, "billiards_ods")
inserted = self._merge_fact_increment(
cur, dwd_table, ods_table, dwd_cols, ods_cols, dwd_types, ods_types
)
summary.append({"table": dwd_table, "mode": "INCREMENT", "inserted": inserted})
self.db.conn.commit()
return {"tables": summary}
return {"tables": summary, "errors": errors}
# ---------------------- helpers ----------------------
def _get_columns(self, cur, table: str) -> List[str]:
@@ -589,6 +630,135 @@ class DwdLoadTask(BaseTask):
expanded.append(child_row)
return expanded
def _merge_dim(
self,
cur,
dwd_table: str,
ods_table: str,
dwd_cols: Sequence[str],
ods_cols: Sequence[str],
now: datetime,
) -> int:
"""
维表合并策略:
- 若主键包含 scd2 列(如 scd2_start_time/scd2_version执行真正的 SCD2关闭旧版+插入新版)。
- 否则(多数现有表主键仅为业务主键),执行 Type1 Upsert避免重复键异常并保证可重复回放。
"""
pk_cols = self._get_primary_keys(cur, dwd_table)
if not pk_cols:
raise ValueError(f"{dwd_table} 未配置主键,无法执行维表合并")
pk_has_scd = any(pk.lower() in self.SCD_COLS for pk in pk_cols)
scd_cols_present = any(c.lower() in self.SCD_COLS for c in dwd_cols)
if scd_cols_present and pk_has_scd:
return self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
return self._merge_dim_type1_upsert(cur, dwd_table, ods_table, dwd_cols, ods_cols, pk_cols, now)
def _merge_dim_type1_upsert(
self,
cur,
dwd_table: str,
ods_table: str,
dwd_cols: Sequence[str],
ods_cols: Sequence[str],
pk_cols: Sequence[str],
now: datetime,
) -> int:
"""维表 Type1 Upsert主键冲突则更新兼容带 scd2 字段但主键不支持多版本的表。"""
mapping = self._build_column_mapping(dwd_table, pk_cols, ods_cols)
ods_set = {c.lower() for c in ods_cols}
ods_table_sql = self._format_table(ods_table, "billiards_ods")
select_exprs: list[str] = []
added: set[str] = set()
for col in dwd_cols:
lc = col.lower()
if lc in self.SCD_COLS:
continue
if lc in mapping:
src, cast_type = mapping[lc]
select_exprs.append(f"{self._cast_expr(src, cast_type)} AS \"{lc}\"")
added.add(lc)
elif lc in ods_set:
select_exprs.append(f'\"{lc}\" AS \"{lc}\"')
added.add(lc)
for pk in pk_cols:
lc = pk.lower()
if lc in added:
continue
if lc in mapping:
src, cast_type = mapping[lc]
select_exprs.append(f"{self._cast_expr(src, cast_type)} AS \"{lc}\"")
elif lc in ods_set:
select_exprs.append(f'\"{lc}\" AS \"{lc}\"')
added.add(lc)
if not select_exprs:
return 0
cur.execute(f"SELECT {', '.join(select_exprs)} FROM {ods_table_sql}")
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
if dwd_table == "billiards_dwd.dim_goods_category":
rows = self._expand_goods_category_rows(rows)
# Deduplicate by primary key
seen_pk: set[tuple[Any, ...]] = set()
src_rows: list[Dict[str, Any]] = []
pk_lower = [c.lower() for c in pk_cols]
for row in rows:
pk_key = tuple(row.get(pk) for pk in pk_lower)
if pk_key in seen_pk:
continue
if any(v is None for v in pk_key):
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
continue
seen_pk.add(pk_key)
src_rows.append(row)
if not src_rows:
return 0
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
sorted_cols = [c.lower() for c in sorted(dwd_cols)]
insert_cols_sql = ", ".join(f'\"{c}\"' for c in sorted_cols)
def build_row(src_row: Dict[str, Any]) -> list[Any]:
values: list[Any] = []
for c in sorted_cols:
if c == "scd2_start_time":
values.append(now)
elif c == "scd2_end_time":
values.append(datetime(9999, 12, 31, 0, 0, 0))
elif c == "scd2_is_current":
values.append(1)
elif c == "scd2_version":
values.append(1)
else:
values.append(src_row.get(c))
return values
pk_sql = ", ".join(f'\"{c.lower()}\"' for c in pk_cols)
pk_lower_set = {c.lower() for c in pk_cols}
set_exprs: list[str] = []
for c in sorted_cols:
if c in pk_lower_set:
continue
if c == "scd2_start_time":
set_exprs.append(f'\"{c}\" = COALESCE({dwd_table_sql}.\"{c}\", EXCLUDED.\"{c}\")')
elif c == "scd2_version":
set_exprs.append(f'\"{c}\" = COALESCE({dwd_table_sql}.\"{c}\", EXCLUDED.\"{c}\")')
else:
set_exprs.append(f'\"{c}\" = EXCLUDED.\"{c}\"')
upsert_sql = (
f"INSERT INTO {dwd_table_sql} ({insert_cols_sql}) VALUES %s "
f"ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
)
execute_values(cur, upsert_sql, [build_row(r) for r in src_rows], page_size=500)
return len(src_rows)
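
For orientation, this is roughly the shape of the statement the Type1 upsert builds; the table and columns below are hypothetical (the real statement is assembled from information_schema metadata), with the SCD2 bookkeeping columns kept via COALESCE so an existing first version is not overwritten.

from psycopg2.extras import execute_values  # same helper the task uses

upsert_sql = (
    'INSERT INTO billiards_dwd.dim_example ("category_id", "category_name", "scd2_start_time", "scd2_version") '
    "VALUES %s "
    'ON CONFLICT ("category_id") DO UPDATE SET '
    '"category_name" = EXCLUDED."category_name", '
    '"scd2_start_time" = COALESCE(billiards_dwd.dim_example."scd2_start_time", EXCLUDED."scd2_start_time"), '
    '"scd2_version" = COALESCE(billiards_dwd.dim_example."scd2_version", EXCLUDED."scd2_version")'
)
# execute_values(cur, upsert_sql, rows, page_size=500)  # rows: tuples in the column order above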
def _merge_dim_scd2(
self,
cur,
@@ -646,8 +816,9 @@ class DwdLoadTask(BaseTask):
if dwd_table == "billiards_dwd.dim_goods_category":
rows = self._expand_goods_category_rows(rows)
inserted_or_updated = 0
# Normalize source rows and deduplicate by primary key
seen_pk = set()
src_rows_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
for row in rows:
mapped_row: Dict[str, Any] = {}
for col in dwd_cols:
@@ -663,10 +834,110 @@ class DwdLoadTask(BaseTask):
pk_key = tuple(mapped_row.get(pk) for pk in pk_cols)
if pk_key in seen_pk:
continue
if any(v is None for v in pk_key):
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
continue
seen_pk.add(pk_key)
if self._upsert_scd2_row(cur, dwd_table, dwd_cols, pk_cols, mapped_row, now):
inserted_or_updated += 1
return len(rows)
src_rows_by_pk[pk_key] = mapped_row
if not src_rows_by_pk:
return 0
# Preload current versions (scd2_is_current=1) to avoid a SELECT round-trip per row
table_sql_dwd = self._format_table(dwd_table, "billiards_dwd")
where_current = " AND ".join([f"COALESCE(scd2_is_current,1)=1"])
cur.execute(f"SELECT * FROM {table_sql_dwd} WHERE {where_current}")
current_rows = cur.fetchall() or []
current_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
for r in current_rows:
rr = {k.lower(): v for k, v in r.items()}
pk_key = tuple(rr.get(pk) for pk in pk_cols)
current_by_pk[pk_key] = rr
# Work out which primary keys need to be closed / inserted
to_close: list[tuple[Any, ...]] = []
to_insert: list[tuple[Dict[str, Any], int]] = []
for pk_key, incoming in src_rows_by_pk.items():
current = current_by_pk.get(pk_key)
if current and not self._is_row_changed(current, incoming, dwd_cols):
continue
if current:
version = (current.get("scd2_version") or 1) + 1
to_close.append(pk_key)
else:
version = 1
to_insert.append((incoming, version))
# Close old versions first (one shared end_time for the whole batch)
if to_close:
self._close_current_dim_bulk(cur, dwd_table, pk_cols, to_close, now)
# Bulk-insert the new versions
if to_insert:
self._insert_dim_rows_bulk(cur, dwd_table, dwd_cols, to_insert, now)
return len(src_rows_by_pk)
def _close_current_dim_bulk(
self,
cur,
table: str,
pk_cols: Sequence[str],
pk_keys: Sequence[tuple[Any, ...]],
now: datetime,
) -> None:
"""批量关闭当前版本scd2_is_current=0 + 填充结束时间)。"""
table_sql = self._format_table(table, "billiards_dwd")
if len(pk_cols) == 1:
pk = pk_cols[0]
ids = [k[0] for k in pk_keys]
cur.execute(
f'UPDATE {table_sql} SET scd2_end_time=%s, scd2_is_current=0 '
f'WHERE COALESCE(scd2_is_current,1)=1 AND "{pk}" = ANY(%s)',
(now, ids),
)
return
# Composite primary key: close each changed key individually (usually far fewer than the full row count)
where_clause = " AND ".join(f'"{pk}" = %s' for pk in pk_cols)
sql = (
f"UPDATE {table_sql} SET scd2_end_time=%s, scd2_is_current=0 "
f"WHERE COALESCE(scd2_is_current,1)=1 AND {where_clause}"
)
args_list = [(now, *pk_key) for pk_key in pk_keys]
execute_batch(cur, sql, args_list, page_size=500)
def _insert_dim_rows_bulk(
self,
cur,
table: str,
dwd_cols: Sequence[str],
rows_with_version: Sequence[tuple[Dict[str, Any], int]],
now: datetime,
) -> None:
"""批量插入新的 SCD2 版本行。"""
sorted_cols = [c.lower() for c in sorted(dwd_cols)]
insert_cols_sql = ", ".join(f'"{c}"' for c in sorted_cols)
table_sql = self._format_table(table, "billiards_dwd")
def build_row(src_row: Dict[str, Any], version: int) -> list[Any]:
values: list[Any] = []
for c in sorted_cols:
if c == "scd2_start_time":
values.append(now)
elif c == "scd2_end_time":
values.append(datetime(9999, 12, 31, 0, 0, 0))
elif c == "scd2_is_current":
values.append(1)
elif c == "scd2_version":
values.append(version)
else:
values.append(src_row.get(c))
return values
values_rows = [build_row(r, ver) for r, ver in rows_with_version]
insert_sql = f"INSERT INTO {table_sql} ({insert_cols_sql}) VALUES %s"
execute_values(cur, insert_sql, values_rows, page_size=500)
def _upsert_scd2_row(
self,
@@ -762,6 +1033,8 @@ class DwdLoadTask(BaseTask):
ods_cols: Sequence[str],
dwd_types: Dict[str, str],
ods_types: Dict[str, str],
window_start: datetime | None = None,
window_end: datetime | None = None,
) -> int:
"""事实表按时间增量插入,默认按列名交集写入。"""
mapping_entries = self.FACT_MAPPINGS.get(dwd_table) or []
@@ -813,7 +1086,10 @@ class DwdLoadTask(BaseTask):
params: List[Any] = []
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
ods_table_sql = self._format_table(ods_table, "billiards_ods")
if order_col:
if order_col and window_start and window_end:
where_sql = f'WHERE "{order_col}" >= %s AND "{order_col}" < %s'
params.extend([window_start, window_end])
elif order_col:
cur.execute(f'SELECT COALESCE(MAX("{order_col}"), %s) FROM {dwd_table_sql}', ("1970-01-01",))
row = cur.fetchone() or {}
watermark = list(row.values())[0] if row else "1970-01-01"
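
Sketch of the predicate selection shown in this hunk: an explicit window override filters the source on [window_start, window_end); otherwise the task derives a watermark from the maximum order column already present in DWD. The watermark comparison operator is not visible in this hunk, so the `>` below is an assumption.

def build_incremental_where(order_col, window_start=None, window_end=None, watermark="1970-01-01"):
    if order_col and window_start and window_end:
        return f'WHERE "{order_col}" >= %s AND "{order_col}" < %s', [window_start, window_end]
    if order_col:
        return f'WHERE "{order_col}" > %s', [watermark]  # assumed comparator
    return "", []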

View File

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
"""Build DWS order summary table from DWD fact tables."""
from __future__ import annotations
from datetime import date
from typing import Any
from .base_task import BaseTask, TaskContext
from scripts.build_dws_order_summary import SQL_BUILD_SUMMARY
class DwsBuildOrderSummaryTask(BaseTask):
"""Recompute/refresh `billiards_dws.dws_order_summary` for a date window."""
def get_task_code(self) -> str:
return "DWS_BUILD_ORDER_SUMMARY"
def execute(self, cursor_data: dict | None = None) -> dict:
context = self._build_context(cursor_data)
task_code = self.get_task_code()
self.logger.info(
"%s: start, window[%s ~ %s]",
task_code,
context.window_start,
context.window_end,
)
try:
extracted = self.extract(context)
transformed = self.transform(extracted, context)
load_result = self.load(transformed, context) or {}
self.db.commit()
except Exception:
self.db.rollback()
self.logger.error("%s: failed", task_code, exc_info=True)
raise
counts = load_result.get("counts") or {}
result = {"status": "SUCCESS", "counts": counts}
result["window"] = {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
}
if "request_params" in load_result:
result["request_params"] = load_result["request_params"]
if "extra" in load_result:
result["extra"] = load_result["extra"]
self.logger.info("%s: done, counts=%s", task_code, counts)
return result
def extract(self, context: TaskContext) -> dict[str, Any]:
store_id = int(self.config.get("app.store_id"))
full_refresh = bool(self.config.get("dws.order_summary.full_refresh", False))
site_id = self.config.get("dws.order_summary.site_id", store_id)
if site_id in ("", None, "null", "NULL"):
site_id = None
start_date = self.config.get("dws.order_summary.start_date")
end_date = self.config.get("dws.order_summary.end_date")
if not full_refresh:
if not start_date:
start_date = context.window_start.date()
if not end_date:
end_date = context.window_end.date()
else:
start_date = None
end_date = None
delete_before_insert = bool(self.config.get("dws.order_summary.delete_before_insert", True))
return {
"site_id": site_id,
"start_date": start_date,
"end_date": end_date,
"full_refresh": full_refresh,
"delete_before_insert": delete_before_insert,
}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
sql_params = {
"site_id": extracted["site_id"],
"start_date": extracted["start_date"],
"end_date": extracted["end_date"],
}
request_params = {
"site_id": extracted["site_id"],
"start_date": _jsonable_date(extracted["start_date"]),
"end_date": _jsonable_date(extracted["end_date"]),
}
with self.db.conn.cursor() as cur:
cur.execute("SELECT to_regclass('billiards_dws.dws_order_summary') AS reg;")
row = cur.fetchone()
reg = row[0] if row else None
if not reg:
raise RuntimeError("DWS 表不存在:请先运行任务 INIT_DWS_SCHEMA")
deleted = 0
if extracted["delete_before_insert"]:
if extracted["full_refresh"] and extracted["site_id"] is None:
cur.execute("TRUNCATE TABLE billiards_dws.dws_order_summary;")
self.logger.info("DWS_BUILD_ORDER_SUMMARY: truncated billiards_dws.dws_order_summary")
else:
delete_sql = "DELETE FROM billiards_dws.dws_order_summary WHERE 1=1"
delete_args: list[Any] = []
if extracted["site_id"] is not None:
delete_sql += " AND site_id = %s"
delete_args.append(extracted["site_id"])
if extracted["start_date"] is not None:
delete_sql += " AND order_date >= %s"
delete_args.append(_as_date(extracted["start_date"]))
if extracted["end_date"] is not None:
delete_sql += " AND order_date <= %s"
delete_args.append(_as_date(extracted["end_date"]))
cur.execute(delete_sql, delete_args)
deleted = cur.rowcount
self.logger.info("DWS_BUILD_ORDER_SUMMARY: deleted=%s sql=%s", deleted, delete_sql)
cur.execute(SQL_BUILD_SUMMARY, sql_params)
affected = cur.rowcount
return {
"counts": {"fetched": 0, "inserted": affected, "updated": 0, "skipped": 0, "errors": 0},
"request_params": request_params,
"extra": {"deleted": deleted},
}
def _as_date(v: Any) -> date:
if isinstance(v, date):
return v
return date.fromisoformat(str(v))
def _jsonable_date(v: Any):
if v is None:
return None
if isinstance(v, date):
return v.isoformat()
return str(v)
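
A tiny sketch of the date-window defaulting done in extract() above: explicit config dates win, otherwise the task window's datetimes are truncated to dates, and a full refresh clears both bounds. Values are illustrative.

from datetime import datetime, date

def resolve_summary_dates(full_refresh, start_date, end_date, window_start, window_end):
    if full_refresh:
        return None, None
    return (start_date or window_start.date(), end_date or window_end.date())

assert resolve_summary_dates(
    False, None, None, datetime(2026, 1, 17, 4, 0), datetime(2026, 1, 18, 4, 0)
) == (date(2026, 1, 17), date(2026, 1, 18))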

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""Initialize DWS schema (billiards_dws)."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from .base_task import BaseTask, TaskContext
class InitDwsSchemaTask(BaseTask):
"""Apply DWS schema SQL."""
def get_task_code(self) -> str:
return "INIT_DWS_SCHEMA"
def extract(self, context: TaskContext) -> dict[str, Any]:
base_dir = Path(__file__).resolve().parents[1] / "database"
dws_path = Path(self.config.get("schema.dws_file", base_dir / "schema_dws.sql"))
if not dws_path.exists():
raise FileNotFoundError(f"未找到 DWS schema 文件: {dws_path}")
drop_first = bool(self.config.get("dws.drop_schema_first", False))
return {"dws_sql": dws_path.read_text(encoding="utf-8"), "dws_file": str(dws_path), "drop_first": drop_first}
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
with self.db.conn.cursor() as cur:
if extracted["drop_first"]:
cur.execute("DROP SCHEMA IF EXISTS billiards_dws CASCADE;")
self.logger.info("已执行 DROP SCHEMA billiards_dws CASCADE")
self.logger.info("执行 DWS schema 文件: %s", extracted["dws_file"])
cur.execute(extracted["dws_sql"])
return {"executed": 1, "files": [extracted["dws_file"]]}

View File

@@ -7,7 +7,7 @@ import os
from datetime import datetime
from typing import Any, Iterable
from psycopg2.extras import Json
from psycopg2.extras import Json, execute_values
from .base_task import BaseTask
@@ -75,7 +75,7 @@ class ManualIngestTask(BaseTask):
return "MANUAL_INGEST"
def execute(self, cursor_data: dict | None = None) -> dict:
"""从目录读取 JSON按表定义批量入库。"""
"""从目录读取 JSON按表定义批量入库(按文件提交事务,避免长事务导致连接不稳定)"""
data_dir = (
self.config.get("manual.data_dir")
or self.config.get("pipeline.ingest_source_dir")
@@ -87,9 +87,15 @@ class ManualIngestTask(BaseTask):
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
include_files_cfg = self.config.get("manual.include_files") or []
include_files = {str(x).strip().lower() for x in include_files_cfg if str(x).strip()} if include_files_cfg else set()
for filename in sorted(os.listdir(data_dir)):
if not filename.endswith(".json"):
continue
stem = os.path.splitext(filename)[0].lower()
if include_files and stem not in include_files:
continue
filepath = os.path.join(data_dir, filename)
try:
with open(filepath, "r", encoding="utf-8") as fh:
@@ -113,22 +119,25 @@ class ManualIngestTask(BaseTask):
self.logger.info("Ingesting %s into %s", filename, target_table)
try:
inserted, updated = self._ingest_table(target_table, records, filename)
inserted, updated, row_errors = self._ingest_table(target_table, records, filename)
counts["inserted"] += inserted
counts["updated"] += updated
counts["fetched"] += len(records)
counts["errors"] += row_errors
# Commit once per file: keeps each transaction small and avoids a long transaction / connection failure sinking the whole run.
self.db.commit()
except Exception:
counts["errors"] += 1
self.logger.exception("Error processing %s", filename)
self.db.rollback()
try:
self.db.rollback()
except Exception:
pass
# If the connection is gone, later files cannot proceed; re-raise and let the caller handle it (reconnect / re-run).
if getattr(self.db.conn, "closed", 0):
raise
continue
try:
self.db.commit()
except Exception:
self.db.rollback()
raise
return {"status": "SUCCESS", "counts": counts}
def _match_by_filename(self, filename: str) -> str | None:
@@ -211,8 +220,15 @@ class ManualIngestTask(BaseTask):
self._table_columns_cache = cache
return cols
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int]:
"""构建 INSERT/ON CONFLICT 语句并批量执行。"""
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int, int]:
"""
构建 INSERT/ON CONFLICT 语句并批量执行(优先向量化,小批次提交)。
设计目标:
- 控制单条 SQL 体积(避免一次性 VALUES 过大导致服务端 backend 被 OOM/异常终止);
- 发生异常时,可降级逐行并用 SAVEPOINT 跳过异常行;
- 统计口径偏“尽量可跑通”,插入/更新计数为近似值(不强依赖 RETURNING
"""
spec = self.TABLE_SPECS.get(table)
if not spec:
raise ValueError(f"No table spec for {table}")
@@ -229,15 +245,19 @@ class ManualIngestTask(BaseTask):
pk_col_db = None
if pk_col:
pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)
pk_index = None
if pk_col_db:
try:
pk_index = next(i for i, c in enumerate(columns_info) if c[0] == pk_col_db)
except Exception:
pk_index = None
placeholders = ", ".join(["%s"] * len(columns))
col_list = ", ".join(f'"{c}"' for c in columns)
sql = f'INSERT INTO {table} ({col_list}) VALUES ({placeholders})'
sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
if pk_col_db:
update_cols = [c for c in columns if c != pk_col_db]
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
sql += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
sql += " RETURNING (xmax = 0) AS inserted"
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
params = []
now = datetime.now()
@@ -288,19 +308,55 @@ class ManualIngestTask(BaseTask):
params.append(tuple(row_vals))
if not params:
return 0, 0
return 0, 0, 0
# Try vectorized execution first (fast); on failure, fall back to row-by-row with SAVEPOINTs to skip bad rows.
try:
with self.db.conn.cursor() as cur:
# Chunked execution: reduces per-transaction / per-statement pressure so a server-side error does not drop the connection.
affected = 0
chunk_size = int(self.config.get("manual.execute_values_page_size", 50) or 50)
chunk_size = max(1, min(chunk_size, 500))
for i in range(0, len(params), chunk_size):
chunk = params[i : i + chunk_size]
execute_values(cur, sql_prefix, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
affected += int(cur.rowcount)
# Cannot split inserted vs updated precisely here (would require RETURNING); approximate by treating affected rows as inserts.
return int(affected), 0, 0
except Exception as exc:
self.logger.warning("批量入库失败准备降级逐行处理table=%s, err=%s", table, exc)
try:
self.db.rollback()
except Exception:
pass
inserted = 0
updated = 0
errors = 0
with self.db.conn.cursor() as cur:
for row in params:
cur.execute(sql, row)
flag = cur.fetchone()[0]
if flag:
cur.execute("SAVEPOINT sp_manual_ingest_row")
try:
cur.execute(sql_prefix.replace(" VALUES %s", f" VALUES ({', '.join(['%s'] * len(row))})"), row)
inserted += 1
else:
updated += 1
return inserted, updated
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
except Exception as exc: # noqa: BLE001
errors += 1
try:
cur.execute("ROLLBACK TO SAVEPOINT sp_manual_ingest_row")
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
except Exception:
pass
pk_val = None
if pk_index is not None:
try:
pk_val = row[pk_index]
except Exception:
pk_val = None
self.logger.warning("跳过异常行table=%s pk=%s err=%s", table, pk_val, exc)
return inserted, updated, errors
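
A stripped-down sketch of the SAVEPOINT-per-row fallback used above: each row gets its own savepoint so a single bad row is rolled back without aborting the surrounding transaction. `cur` and `insert_one` are placeholders.

def insert_rows_with_savepoints(cur, rows, insert_one):
    ok, failed = 0, 0
    for row in rows:
        cur.execute("SAVEPOINT sp_row")
        try:
            insert_one(cur, row)  # e.g. a single-row INSERT ... ON CONFLICT
            cur.execute("RELEASE SAVEPOINT sp_row")
            ok += 1
        except Exception:  # noqa: BLE001
            cur.execute("ROLLBACK TO SAVEPOINT sp_row")
            cur.execute("RELEASE SAVEPOINT sp_row")
            failed += 1
    return ok, failed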
@staticmethod
def _get_value_case_insensitive(record: dict, col: str | None):

View File

@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*-
"""在线抓取 ODS 相关接口并落盘为 JSON用于后续离线回放/入库)。"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from api.client import APIClient
from models.parsers import TypeParser
from utils.json_store import dump_json, endpoint_to_filename
from .base_task import BaseTask, TaskContext
@dataclass(frozen=True)
class EndpointSpec:
endpoint: str
window_style: str # site | start_end | range | pay | none
data_path: tuple[str, ...] = ("data",)
list_key: str | None = None
class OdsJsonArchiveTask(BaseTask):
"""
抓取一组 ODS 所需接口并落盘为“简化 JSON”
{"code": 0, "data": [...records...]}
说明:
- 该输出格式与 tasks/manual_ingest_task.py 的解析逻辑兼容;
- 默认每页一个文件,避免单文件过大;
- 结算小票(/Order/GetOrderSettleTicketNew按 orderSettleId 分文件写入。
"""
ENDPOINTS: tuple[EndpointSpec, ...] = (
EndpointSpec("/MemberProfile/GetTenantMemberList", "site", list_key="tenantMemberInfos"),
EndpointSpec("/MemberProfile/GetTenantMemberCardList", "site", list_key="tenantMemberCards"),
EndpointSpec("/MemberProfile/GetMemberCardBalanceChange", "start_end"),
EndpointSpec("/PersonnelManagement/SearchAssistantInfo", "site", list_key="assistantInfos"),
EndpointSpec(
"/AssistantPerformance/GetOrderAssistantDetails",
"start_end",
list_key="orderAssistantDetails",
),
EndpointSpec(
"/AssistantPerformance/GetAbolitionAssistant",
"start_end",
list_key="abolitionAssistants",
),
EndpointSpec("/Table/GetSiteTables", "site", list_key="siteTables"),
EndpointSpec(
"/TenantGoodsCategory/QueryPrimarySecondaryCategory",
"site",
list_key="goodsCategoryList",
),
EndpointSpec("/TenantGoods/QueryTenantGoods", "site", list_key="tenantGoodsList"),
EndpointSpec("/TenantGoods/GetGoodsInventoryList", "site", list_key="orderGoodsList"),
EndpointSpec("/TenantGoods/GetGoodsStockReport", "site"),
EndpointSpec("/TenantGoods/GetGoodsSalesList", "start_end", list_key="orderGoodsLedgers"),
EndpointSpec(
"/PackageCoupon/QueryPackageCouponList",
"site",
list_key="packageCouponList",
),
EndpointSpec("/Site/GetSiteTableUseDetails", "start_end", list_key="siteTableUseDetailsList"),
EndpointSpec("/Site/GetSiteTableOrderDetails", "start_end", list_key="siteTableUseDetailsList"),
EndpointSpec("/Site/GetTaiFeeAdjustList", "start_end", list_key="taiFeeAdjustInfos"),
EndpointSpec(
"/GoodsStockManage/QueryGoodsOutboundReceipt",
"start_end",
list_key="queryDeliveryRecordsList",
),
EndpointSpec("/Promotion/GetOfflineCouponConsumePageList", "start_end"),
EndpointSpec("/Order/GetRefundPayLogList", "start_end"),
EndpointSpec("/Site/GetAllOrderSettleList", "range", list_key="settleList"),
EndpointSpec("/Site/GetRechargeSettleList", "range", list_key="settleList"),
EndpointSpec("/PayLog/GetPayLogListPage", "pay"),
)
TICKET_ENDPOINT = "/Order/GetOrderSettleTicketNew"
def get_task_code(self) -> str:
return "ODS_JSON_ARCHIVE"
def extract(self, context: TaskContext) -> dict:
base_client = getattr(self.api, "base", None) or self.api
if not isinstance(base_client, APIClient):
raise TypeError("ODS_JSON_ARCHIVE 需要 APIClient在线抓取")
output_dir = getattr(self.api, "output_dir", None)
if output_dir:
out = Path(output_dir)
else:
out = Path(self.config.get("pipeline.fetch_root") or self.config["pipeline"]["fetch_root"])
out.mkdir(parents=True, exist_ok=True)
write_pretty = bool(self.config.get("io.write_pretty_json", False))
page_size = int(self.config.get("api.page_size", 200) or 200)
store_id = int(context.store_id)
total_records = 0
ticket_ids: set[int] = set()
per_endpoint: list[dict] = []
self.logger.info(
"ODS_JSON_ARCHIVE: 开始抓取,窗口[%s ~ %s] 输出目录=%s",
context.window_start,
context.window_end,
out,
)
for spec in self.ENDPOINTS:
self.logger.info("ODS_JSON_ARCHIVE: 抓取 endpoint=%s", spec.endpoint)
built_params = self._build_params(
spec.window_style, store_id, context.window_start, context.window_end
)
# /TenantGoods/GetGoodsInventoryList requires siteId as an array (a scalar triggers a server-side error that returns the malformed status line HTTP/1.1 1400)
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
built_params["siteId"] = [store_id]
params = self._merge_common_params(built_params)
base_filename = endpoint_to_filename(spec.endpoint)
stem = Path(base_filename).stem
suffix = Path(base_filename).suffix or ".json"
endpoint_records = 0
endpoint_pages = 0
endpoint_error: str | None = None
try:
for page_no, records, _, _ in base_client.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
endpoint_pages += 1
total_records += len(records)
endpoint_records += len(records)
if spec.endpoint == "/PayLog/GetPayLogListPage":
for rec in records or []:
relate_id = TypeParser.parse_int(
(rec or {}).get("relateId")
or (rec or {}).get("orderSettleId")
or (rec or {}).get("order_settle_id")
)
if relate_id:
ticket_ids.add(relate_id)
out_path = out / f"{stem}__p{int(page_no):04d}{suffix}"
dump_json(out_path, {"code": 0, "data": records}, pretty=write_pretty)
except Exception as exc: # noqa: BLE001
endpoint_error = f"{type(exc).__name__}: {exc}"
self.logger.error("ODS_JSON_ARCHIVE: 接口抓取失败 endpoint=%s err=%s", spec.endpoint, endpoint_error)
per_endpoint.append(
{
"endpoint": spec.endpoint,
"file_stem": stem,
"pages": endpoint_pages,
"records": endpoint_records,
"error": endpoint_error,
}
)
if endpoint_error:
self.logger.warning(
"ODS_JSON_ARCHIVE: endpoint=%s 完成失败pages=%s records=%s err=%s",
spec.endpoint,
endpoint_pages,
endpoint_records,
endpoint_error,
)
else:
self.logger.info(
"ODS_JSON_ARCHIVE: endpoint=%s 完成 pages=%s records=%s",
spec.endpoint,
endpoint_pages,
endpoint_records,
)
# Ticket details: per orderSettleId
ticket_ids_sorted = sorted(ticket_ids)
self.logger.info("ODS_JSON_ARCHIVE: 小票候选数=%s", len(ticket_ids_sorted))
ticket_file_stem = Path(endpoint_to_filename(self.TICKET_ENDPOINT)).stem
ticket_file_suffix = Path(endpoint_to_filename(self.TICKET_ENDPOINT)).suffix or ".json"
ticket_records = 0
for order_settle_id in ticket_ids_sorted:
params = self._merge_common_params({"orderSettleId": int(order_settle_id)})
try:
records, _ = base_client.get_paginated(
endpoint=self.TICKET_ENDPOINT,
params=params,
page_size=None,
data_path=("data",),
list_key=None,
)
if not records:
continue
ticket_records += len(records)
out_path = out / f"{ticket_file_stem}__{int(order_settle_id)}{ticket_file_suffix}"
dump_json(out_path, {"code": 0, "data": records}, pretty=write_pretty)
except Exception as exc: # noqa: BLE001
self.logger.error(
"ODS_JSON_ARCHIVE: 小票抓取失败 orderSettleId=%s err=%s",
order_settle_id,
exc,
)
continue
total_records += ticket_records
manifest = {
"task": self.get_task_code(),
"store_id": store_id,
"window_start": context.window_start.isoformat(),
"window_end": context.window_end.isoformat(),
"page_size": page_size,
"total_records": total_records,
"ticket_ids": len(ticket_ids_sorted),
"ticket_records": ticket_records,
"endpoints": per_endpoint,
}
manifest_path = out / "manifest.json"
dump_json(manifest_path, manifest, pretty=True)
if hasattr(self.api, "last_dump"):
try:
self.api.last_dump = {"file": str(manifest_path), "records": total_records, "pages": None}
except Exception:
pass
self.logger.info("ODS_JSON_ARCHIVE: 抓取完成,总记录数=%s(含小票=%s", total_records, ticket_records)
return {"fetched": total_records, "ticket_ids": len(ticket_ids_sorted)}
def _build_params(self, window_style: str, store_id: int, window_start, window_end) -> dict:
if window_style == "none":
return {}
if window_style == "site":
return {"siteId": store_id}
if window_style == "range":
return {
"siteId": store_id,
"rangeStartTime": TypeParser.format_timestamp(window_start, self.tz),
"rangeEndTime": TypeParser.format_timestamp(window_end, self.tz),
}
if window_style == "pay":
return {
"siteId": store_id,
"StartPayTime": TypeParser.format_timestamp(window_start, self.tz),
"EndPayTime": TypeParser.format_timestamp(window_end, self.tz),
}
# default: startTime/endTime
return {
"siteId": store_id,
"startTime": TypeParser.format_timestamp(window_start, self.tz),
"endTime": TypeParser.format_timestamp(window_end, self.tz),
}
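
For reference, these are the request-parameter shapes _build_params produces for each window_style; the site id and timestamp strings are illustrative, and the exact timestamp format is whatever TypeParser.format_timestamp emits.

PARAM_SHAPES = {  # illustrative only
    "none": {},
    "site": {"siteId": 1001},
    "range": {"siteId": 1001, "rangeStartTime": "2026-01-17 04:00:00", "rangeEndTime": "2026-01-18 04:00:00"},
    "pay": {"siteId": 1001, "StartPayTime": "2026-01-17 04:00:00", "EndPayTime": "2026-01-18 04:00:00"},
    "start_end": {"siteId": 1001, "startTime": "2026-01-17 04:00:00", "endTime": "2026-01-18 04:00:00"},
}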

View File

@@ -2,11 +2,13 @@
"""ODS ingestion tasks."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Type
from loaders.ods import GenericODSLoader
from psycopg2.extras import Json, execute_values
from models.parsers import TypeParser
from .base_task import BaseTask
@@ -60,70 +62,61 @@ class BaseOdsTask(BaseTask):
def get_task_code(self) -> str:
return self.SPEC.code
def execute(self) -> dict:
def execute(self, cursor_data: dict | None = None) -> dict:
spec = self.SPEC
self.logger.info("寮€濮嬫墽琛?%s (ODS)", spec.code)
window_start, window_end, window_minutes = self._resolve_window(cursor_data)
store_id = TypeParser.parse_int(self.config.get("app.store_id"))
if not store_id:
raise ValueError("app.store_id 鏈厤缃紝鏃犳硶鎵ц ODS 浠诲姟")
page_size = self.config.get("api.page_size", 200)
params = self._build_params(spec, store_id)
columns = self._resolve_columns(spec)
if spec.conflict_columns_override:
conflict_columns = list(spec.conflict_columns_override)
else:
conflict_columns = []
if spec.include_site_column:
conflict_columns.append("site_id")
conflict_columns += [col.column for col in spec.pk_columns]
loader = GenericODSLoader(
self.db,
spec.table_name,
columns,
conflict_columns,
params = self._build_params(
spec,
store_id,
window_start=window_start,
window_end=window_end,
)
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
source_file = self._resolve_source_file_hint(spec)
try:
global_index = 0
for page_no, page_records, _, _ in self.api.iter_paginated(
for _, page_records, _, response_payload in self.api.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
rows: List[dict] = []
for raw in page_records:
row = self._build_row(
spec=spec,
store_id=store_id,
record=raw,
page_no=page_no if spec.include_page_no else None,
page_size_value=len(page_records)
if spec.include_page_size
else None,
source_file=source_file,
record_index=global_index if spec.include_record_index else None,
)
if row is None:
counts["skipped"] += 1
continue
rows.append(row)
global_index += 1
inserted, updated, _ = loader.upsert_rows(rows)
counts["inserted"] += inserted
counts["updated"] += updated
inserted, skipped = self._insert_records_schema_aware(
table=spec.table_name,
records=page_records,
response_payload=response_payload,
source_file=source_file,
source_endpoint=spec.endpoint if spec.include_source_endpoint else None,
)
counts["fetched"] += len(page_records)
counts["inserted"] += inserted
counts["skipped"] += skipped
self.db.commit()
self.logger.info("%s ODS 浠诲姟瀹屾垚: %s", spec.code, counts)
return self._build_result("SUCCESS", counts)
allow_empty_advance = bool(self.config.get("run.allow_empty_result_advance", False))
status = "SUCCESS"
if counts["fetched"] == 0 and not allow_empty_advance:
status = "PARTIAL"
result = self._build_result(status, counts)
result["window"] = {
"start": window_start,
"end": window_end,
"minutes": window_minutes,
}
result["request_params"] = params
return result
except Exception:
self.db.rollback()
@@ -131,12 +124,70 @@ class BaseOdsTask(BaseTask):
self.logger.error("%s ODS 浠诲姟澶辫触", spec.code, exc_info=True)
raise
def _build_params(self, spec: OdsTaskSpec, store_id: int) -> dict:
def _resolve_window(self, cursor_data: dict | None) -> tuple[datetime, datetime, int]:
base_start, base_end, base_minutes = self._get_time_window(cursor_data)
if self.config.get("run.force_window_override"):
override_start = self.config.get("run.window_override.start")
override_end = self.config.get("run.window_override.end")
if override_start and override_end:
return base_start, base_end, base_minutes
# Fall back to the ODS table's MAX(fetched_at): guards against data loss when the window cursor advanced but nothing was actually ingested.
last_fetched = self._get_max_fetched_at(self.SPEC.table_name)
if last_fetched:
overlap_seconds = int(self.config.get("run.overlap_seconds", 120) or 120)
cursor_end = cursor_data.get("last_end") if isinstance(cursor_data, dict) else None
anchor = cursor_end or last_fetched
# If cursor_end is later than the real ingest time (last_fetched), the cursor ran ahead of the table: use last_fetched as the starting point instead
if isinstance(cursor_end, datetime) and cursor_end.tzinfo is None:
cursor_end = cursor_end.replace(tzinfo=self.tz)
if isinstance(cursor_end, datetime) and cursor_end > last_fetched:
anchor = last_fetched
start = anchor - timedelta(seconds=max(0, overlap_seconds))
if start.tzinfo is None:
start = start.replace(tzinfo=self.tz)
else:
start = start.astimezone(self.tz)
end = datetime.now(self.tz)
minutes = max(1, int((end - start).total_seconds() // 60))
return start, end, minutes
return base_start, base_end, base_minutes
def _get_max_fetched_at(self, table_name: str) -> datetime | None:
try:
rows = self.db.query(f"SELECT MAX(fetched_at) AS mx FROM {table_name}")
except Exception:
return None
if not rows or not rows[0].get("mx"):
return None
mx = rows[0]["mx"]
if not isinstance(mx, datetime):
return None
if mx.tzinfo is None:
return mx.replace(tzinfo=self.tz)
return mx.astimezone(self.tz)
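
A self-contained sketch of the fallback-window logic in _resolve_window above: anchor on the cursor's last_end unless the table's MAX(fetched_at) lags behind it, subtract the overlap, and run up to now. All inputs are illustrative.

from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

tz = ZoneInfo("Asia/Shanghai")  # assumed timezone
last_fetched = datetime(2026, 1, 18, 3, 0, tzinfo=tz)  # MAX(fetched_at) in the ODS table
cursor_end = datetime(2026, 1, 18, 4, 0, tzinfo=tz)    # cursor ran ahead of the table
anchor = last_fetched if cursor_end > last_fetched else cursor_end
window_start = anchor - timedelta(seconds=120)          # overlap_seconds
window_end = datetime.now(tz)
window_minutes = max(1, int((window_end - window_start).total_seconds() // 60))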
def _build_params(
self,
spec: OdsTaskSpec,
store_id: int,
*,
window_start: datetime,
window_end: datetime,
) -> dict:
base: dict[str, Any] = {}
if spec.include_site_id:
base["siteId"] = store_id
# /TenantGoods/GetGoodsInventoryList requires siteId as an array (a scalar triggers a server-side error that returns the malformed status line HTTP/1.1 1400)
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [store_id]
else:
base["siteId"] = store_id
if spec.requires_window and spec.time_fields:
window_start, window_end, _ = self._get_time_window()
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(window_start, self.tz)
base[end_key] = TypeParser.format_timestamp(window_end, self.tz)
@@ -145,109 +196,226 @@ class BaseOdsTask(BaseTask):
params.update(spec.extra_params)
return params
def _resolve_columns(self, spec: OdsTaskSpec) -> List[str]:
columns: List[str] = []
if spec.include_site_column:
columns.append("site_id")
seen = set(columns)
for col_spec in list(spec.pk_columns) + list(spec.extra_columns):
if col_spec.column not in seen:
columns.append(col_spec.column)
seen.add(col_spec.column)
# ------------------------------------------------------------------ schema-aware ingest (ODS doc schema)
def _get_table_columns(self, table: str) -> list[tuple[str, str, str]]:
cache = getattr(self, "_table_columns_cache", {})
if table in cache:
return cache[table]
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
cache[table] = cols
self._table_columns_cache = cache
return cols
if spec.include_record_index and "record_index" not in seen:
columns.append("record_index")
seen.add("record_index")
def _get_table_pk_columns(self, table: str) -> list[str]:
cache = getattr(self, "_table_pk_cache", {})
if table in cache:
return cache[table]
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [r[0] for r in cur.fetchall()]
cache[table] = cols
self._table_pk_cache = cache
return cols
if spec.include_page_no and "page_no" not in seen:
columns.append("page_no")
seen.add("page_no")
if spec.include_page_size and "page_size" not in seen:
columns.append("page_size")
seen.add("page_size")
if spec.include_source_file and "source_file" not in seen:
columns.append("source_file")
seen.add("source_file")
if spec.include_source_endpoint and "source_endpoint" not in seen:
columns.append("source_endpoint")
seen.add("source_endpoint")
if spec.include_fetched_at and "fetched_at" not in seen:
columns.append("fetched_at")
seen.add("fetched_at")
if "payload" not in seen:
columns.append("payload")
return columns
def _build_row(
def _insert_records_schema_aware(
self,
spec: OdsTaskSpec,
store_id: int,
record: dict,
page_no: int | None,
page_size_value: int | None,
*,
table: str,
records: list,
response_payload: dict | list | None,
source_file: str | None,
record_index: int | None = None,
) -> dict | None:
row: dict[str, Any] = {}
if spec.include_site_column:
row["site_id"] = store_id
source_endpoint: str | None,
) -> tuple[int, int]:
"""
按 DB 表结构动态写入 ODS只插新数据ON CONFLICT DO NOTHING
返回 (inserted, skipped)。
"""
if not records:
return 0, 0
for col_spec in spec.pk_columns + spec.extra_columns:
value = self._extract_value(record, col_spec)
if value is None and col_spec.required:
self.logger.warning(
"%s 缂哄皯蹇呭~瀛楁 %s锛屽師濮嬭褰? %s",
spec.code,
col_spec.column,
record,
)
return None
row[col_spec.column] = value
cols_info = self._get_table_columns(table)
if not cols_info:
raise ValueError(f"Cannot resolve columns for table={table}")
if spec.include_page_no:
row["page_no"] = page_no
if spec.include_page_size:
row["page_size"] = page_size_value
if spec.include_record_index:
row["record_index"] = record_index
if spec.include_source_file:
row["source_file"] = source_file
if spec.include_source_endpoint:
row["source_endpoint"] = spec.endpoint
pk_cols = self._get_table_pk_columns(table)
db_json_cols_lower = {
c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
if spec.include_fetched_at:
row["fetched_at"] = datetime.now(self.tz)
row["payload"] = record
return row
col_names = [c[0] for c in cols_info]
quoted_cols = ", ".join(f'\"{c}\"' for c in col_names)
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
if pk_cols:
pk_clause = ", ".join(f'\"{c}\"' for c in pk_cols)
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
def _extract_value(self, record: dict, spec: ColumnSpec):
value = None
for key in spec.sources:
value = self._dig(record, key)
if value is not None:
break
if value is None and spec.default is not None:
value = spec.default
if value is not None and spec.transform:
value = spec.transform(value)
now = datetime.now(self.tz)
json_dump = lambda v: json.dumps(v, ensure_ascii=False) # noqa: E731
params: list[tuple] = []
skipped = 0
root_site_profile = None
if isinstance(response_payload, dict):
data_part = response_payload.get("data")
if isinstance(data_part, dict):
sp = data_part.get("siteProfile") or data_part.get("site_profile")
if isinstance(sp, dict):
root_site_profile = sp
for rec in records:
if not isinstance(rec, dict):
skipped += 1
continue
merged_rec = self._merge_record_layers(rec)
if table in {"billiards_ods.recharge_settlements", "billiards_ods.settlement_records"}:
site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile") or root_site_profile
if isinstance(site_profile, dict):
# Avoid overwriting existing camelCase fields (e.g. tenantId/siteId/siteName) with None
def _fill_missing(target_col: str, candidates: list[Any]):
existing = self._get_value_case_insensitive(merged_rec, target_col)
if existing not in (None, ""):
return
for cand in candidates:
if cand in (None, "", 0):
continue
merged_rec[target_col] = cand
return
_fill_missing("tenantid", [site_profile.get("tenant_id"), site_profile.get("tenantId")])
_fill_missing("siteid", [site_profile.get("siteId"), site_profile.get("id")])
_fill_missing("sitename", [site_profile.get("shop_name"), site_profile.get("siteName")])
if pk_cols:
missing_pk = False
for pk in pk_cols:
pk_val = self._get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
break
if missing_pk:
skipped += 1
continue
row_vals: list[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append(source_file)
continue
if col_lower == "source_endpoint":
row_vals.append(source_endpoint)
continue
if col_lower == "fetched_at":
row_vals.append(now)
continue
value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
row_vals.append(self._cast_value(value, data_type))
params.append(tuple(row_vals))
if not params:
return 0, skipped
inserted = 0
chunk_size = int(self.config.get("run.ods_execute_values_page_size", 200) or 200)
chunk_size = max(1, min(chunk_size, 2000))
with self.db.conn.cursor() as cur:
for i in range(0, len(params), chunk_size):
chunk = params[i : i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
inserted += int(cur.rowcount)
return inserted, skipped
@staticmethod
def _merge_record_layers(record: dict) -> dict:
merged = record
data_part = merged.get("data")
while isinstance(data_part, dict):
merged = {**data_part, **merged}
data_part = data_part.get("data")
settle_inner = merged.get("settleList")
if isinstance(settle_inner, dict):
merged = {**settle_inner, **merged}
return merged
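
A worked example of _merge_record_layers above: nested "data" layers are folded outward (outer keys win on key collisions) and an inner "settleList" dict is merged the same way, so the schema-aware column lookup can see camelCase fields at the top level. The record is made up.

record = {
    "id": 1,
    "data": {"tenantId": 7, "data": {"siteId": 1001}},
    "settleList": {"payTime": "2026-01-18 03:00:00"},
}
merged = record
data_part = merged.get("data")
while isinstance(data_part, dict):
    merged = {**data_part, **merged}
    data_part = data_part.get("data")
settle_inner = merged.get("settleList")
if isinstance(settle_inner, dict):
    merged = {**settle_inner, **merged}
# merged now exposes id, tenantId, siteId and payTime at the top level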
@staticmethod
def _get_value_case_insensitive(record: dict | None, col: str | None):
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
@staticmethod
def _normalize_scalar(value):
if value == "" or value == "{}" or value == "[]":
return None
return value
@staticmethod
def _dig(record: Any, path: str | None):
if not path:
def _cast_value(value, data_type: str):
if value is None:
return None
current = record
for part in path.split("."):
if isinstance(current, dict):
current = current.get(part)
else:
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
return current
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _resolve_source_file_hint(self, spec: OdsTaskSpec) -> str | None:
resolver = getattr(self.api, "get_source_hint", None)
@@ -319,15 +487,16 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
endpoint="/Site/GetAllOrderSettleList",
data_path=("data",),
list_key="settleList",
time_fields=("rangeStartTime", "rangeEndTime"),
pk_columns=(),
include_site_column=False,
include_source_endpoint=False,
include_source_endpoint=True,
include_page_no=False,
include_page_size=False,
include_fetched_at=False,
include_record_index=True,
conflict_columns_override=("source_file", "record_index"),
requires_window=False,
requires_window=True,
description="缁撹处璁板綍 ODS锛欸etAllOrderSettleList -> settleList 鍘熷 JSON",
),
OdsTaskSpec(
@@ -512,6 +681,7 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
endpoint="/Site/GetRechargeSettleList",
data_path=("data",),
list_key="settleList",
time_fields=("rangeStartTime", "rangeEndTime"),
pk_columns=(_int_col("recharge_order_id", "settleList.id", "id", required=True),),
extra_columns=(
_int_col("tenant_id", "settleList.tenantId", "tenantId"),
@@ -583,7 +753,7 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
include_fetched_at=True,
include_record_index=False,
conflict_columns_override=None,
requires_window=False,
requires_window=True,
description="?????? ODS?GetRechargeSettleList -> data.settleList ????",
),
@@ -800,12 +970,6 @@ class OdsSettlementTicketTask(BaseOdsTask):
store_id = TypeParser.parse_int(self.config.get("app.store_id")) or 0
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
loader = GenericODSLoader(
self.db,
spec.table_name,
self._resolve_columns(spec),
list(spec.conflict_columns_override or ("source_file", "record_index")),
)
source_file = self._resolve_source_file_hint(spec)
try:
@@ -823,39 +987,43 @@ class OdsSettlementTicketTask(BaseOdsTask):
context.window_start,
context.window_end,
)
return self._build_result("SUCCESS", counts)
result = self._build_result("SUCCESS", counts)
result["window"] = {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
}
result["request_params"] = {"candidates": 0}
return result
payloads, skipped = self._fetch_ticket_payloads(candidates)
counts["skipped"] += skipped
rows: list[dict] = []
for idx, payload in enumerate(payloads):
row = self._build_row(
spec=spec,
store_id=store_id,
record=payload,
page_no=None,
page_size_value=None,
source_file=source_file,
record_index=idx if spec.include_record_index else None,
)
if row is None:
counts["skipped"] += 1
continue
rows.append(row)
inserted, updated, _ = loader.upsert_rows(rows)
inserted, skipped2 = self._insert_records_schema_aware(
table=spec.table_name,
records=payloads,
response_payload=None,
source_file=source_file,
source_endpoint=spec.endpoint,
)
counts["inserted"] += inserted
counts["updated"] += updated
counts["skipped"] += skipped2
self.db.commit()
self.logger.info(
"%s: 灏忕エ鎶撳彇瀹屾垚锛屽€欓€?%s 鎻掑叆=%s 鏇存柊=%s 璺宠繃=%s",
spec.code,
len(candidates),
inserted,
updated,
0,
counts["skipped"],
)
return self._build_result("SUCCESS", counts)
result = self._build_result("SUCCESS", counts)
result["window"] = {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
}
result["request_params"] = {"candidates": len(candidates)}
return result
except Exception:
counts["errors"] += 1
@@ -1026,4 +1194,3 @@ ODS_TASK_CLASSES: Dict[str, Type[BaseOdsTask]] = {
ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"] = OdsSettlementTicketTask
__all__ = ["ODS_TASK_CLASSES", "ODS_TASK_SPECS", "BaseOdsTask", "ENABLED_ODS_CODES"]