ETL 完成
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
from dateutil import parser as dtparser
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -92,6 +93,36 @@ class BaseTask:
|
||||
"""计算时间窗口"""
|
||||
now = datetime.now(self.tz)
|
||||
|
||||
override_start = self.config.get("run.window_override.start")
|
||||
override_end = self.config.get("run.window_override.end")
|
||||
if override_start or override_end:
|
||||
if not (override_start and override_end):
|
||||
raise ValueError("run.window_override.start/end 需要同时提供")
|
||||
|
||||
window_start = override_start
|
||||
if isinstance(window_start, str):
|
||||
window_start = dtparser.parse(window_start)
|
||||
if isinstance(window_start, datetime) and window_start.tzinfo is None:
|
||||
window_start = window_start.replace(tzinfo=self.tz)
|
||||
elif isinstance(window_start, datetime):
|
||||
window_start = window_start.astimezone(self.tz)
|
||||
|
||||
window_end = override_end
|
||||
if isinstance(window_end, str):
|
||||
window_end = dtparser.parse(window_end)
|
||||
if isinstance(window_end, datetime) and window_end.tzinfo is None:
|
||||
window_end = window_end.replace(tzinfo=self.tz)
|
||||
elif isinstance(window_end, datetime):
|
||||
window_end = window_end.astimezone(self.tz)
|
||||
|
||||
if not isinstance(window_start, datetime) or not isinstance(window_end, datetime):
|
||||
raise ValueError("run.window_override.start/end 解析失败")
|
||||
if window_end <= window_start:
|
||||
raise ValueError("run.window_override.end 必须大于 start")
|
||||
|
||||
window_minutes = max(1, int((window_end - window_start).total_seconds() // 60))
|
||||
return window_start, window_end, window_minutes
|
||||
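# A minimal standalone sketch of the override-window resolution above (assumed helper,
# not part of this commit; the "Asia/Shanghai" timezone is an illustrative assumption):
# parse run.window_override.start/end, localize naive values to the task timezone,
# validate ordering, and derive the window length in minutes.
from datetime import datetime
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser

def resolve_override_window(start_raw, end_raw, tz=ZoneInfo("Asia/Shanghai")):
    if not (start_raw and end_raw):
        raise ValueError("run.window_override.start/end 需要同时提供")
    def _to_dt(value):
        dt = dtparser.parse(value) if isinstance(value, str) else value
        if not isinstance(dt, datetime):
            raise ValueError("run.window_override.start/end 解析失败")
        return dt.replace(tzinfo=tz) if dt.tzinfo is None else dt.astimezone(tz)
    start, end = _to_dt(start_raw), _to_dt(end_raw)
    if end <= start:
        raise ValueError("run.window_override.end 必须大于 start")
    return start, end, max(1, int((end - start).total_seconds() // 60))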
|
||||
idle_start = self.config.get("run.idle_window.start", "04:00")
|
||||
idle_end = self.config.get("run.idle_window.end", "16:00")
|
||||
is_idle = self._is_in_idle_window(now, idle_start, idle_end)
|
||||
|
||||
125
etl_billiards/tasks/check_cutoff_task.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task: report last successful cursor cutoff times from etl_admin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base_task import BaseTask
|
||||
|
||||
|
||||
class CheckCutoffTask(BaseTask):
|
||||
"""Report per-task cursor cutoff times (etl_admin.etl_cursor.last_end)."""
|
||||
|
||||
def get_task_code(self) -> str:
|
||||
return "CHECK_CUTOFF"
|
||||
|
||||
def execute(self, cursor_data: dict | None = None) -> dict:
|
||||
store_id = int(self.config.get("app.store_id"))
|
||||
filter_codes = self.config.get("run.cutoff_task_codes") or None
|
||||
if isinstance(filter_codes, str):
|
||||
filter_codes = [c.strip().upper() for c in filter_codes.split(",") if c.strip()]
|
||||
|
||||
sql = """
|
||||
SELECT
|
||||
t.task_code,
|
||||
c.last_start,
|
||||
c.last_end,
|
||||
c.last_id,
|
||||
c.last_run_id,
|
||||
c.updated_at
|
||||
FROM etl_admin.etl_task t
|
||||
LEFT JOIN etl_admin.etl_cursor c
|
||||
ON c.task_id = t.task_id AND c.store_id = t.store_id
|
||||
WHERE t.store_id = %s
|
||||
AND t.enabled = TRUE
|
||||
ORDER BY t.task_code
|
||||
"""
|
||||
rows = self.db.query(sql, (store_id,))
|
||||
|
||||
if filter_codes:
|
||||
wanted = {str(c).upper() for c in filter_codes}
|
||||
rows = [r for r in rows if str(r.get("task_code", "")).upper() in wanted]
|
||||
|
||||
def _ts(v: Any) -> str:
|
||||
return "-" if not v else str(v)
|
||||
|
||||
self.logger.info("CHECK_CUTOFF: store_id=%s enabled_tasks=%s", store_id, len(rows))
|
||||
for r in rows:
|
||||
self.logger.info(
|
||||
"CHECK_CUTOFF: %-24s last_end=%s last_start=%s last_run_id=%s",
|
||||
str(r.get("task_code") or ""),
|
||||
_ts(r.get("last_end")),
|
||||
_ts(r.get("last_start")),
|
||||
_ts(r.get("last_run_id")),
|
||||
)
|
||||
|
||||
cutoff_candidates = [
|
||||
r.get("last_end")
|
||||
for r in rows
|
||||
if r.get("last_end") is not None and not str(r.get("task_code", "")).upper().startswith("INIT_")
|
||||
]
|
||||
cutoff = min(cutoff_candidates) if cutoff_candidates else None
|
||||
self.logger.info("CHECK_CUTOFF: overall_cutoff(min last_end, excl INIT_*)=%s", _ts(cutoff))
|
||||
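# Illustrative sketch of the cutoff selection above (hypothetical sample data): the
# overall cutoff is the minimum last_end across enabled tasks, excluding INIT_* tasks
# and rows that have never run.
from datetime import datetime

sample_rows = [
    {"task_code": "ODS_PAYLOG", "last_end": datetime(2024, 5, 1, 4, 0)},
    {"task_code": "ODS_SETTLE", "last_end": datetime(2024, 5, 1, 3, 30)},
    {"task_code": "INIT_DWS_SCHEMA", "last_end": datetime(2024, 1, 1)},
    {"task_code": "DWD_LOAD", "last_end": None},
]
candidates = [
    r["last_end"]
    for r in sample_rows
    if r["last_end"] is not None and not r["task_code"].upper().startswith("INIT_")
]
print(min(candidates) if candidates else None)  # -> 2024-05-01 03:30:00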
|
||||
ods_fetched = self._probe_ods_fetched_at(store_id)
|
||||
if ods_fetched:
|
||||
non_null = [v["max_fetched_at"] for v in ods_fetched.values() if v.get("max_fetched_at") is not None]
|
||||
ods_cutoff = min(non_null) if non_null else None
|
||||
self.logger.info("CHECK_CUTOFF: ODS cutoff(min MAX(fetched_at))=%s", _ts(ods_cutoff))
|
||||
worst = sorted(
|
||||
((k, v.get("max_fetched_at")) for k, v in ods_fetched.items()),
|
||||
key=lambda kv: (kv[1] is None, kv[1]),
|
||||
)[:8]
|
||||
for table, mx in worst:
|
||||
self.logger.info("CHECK_CUTOFF: ODS table=%s max_fetched_at=%s", table, _ts(mx))
|
||||
|
||||
dw_checks = self._probe_dw_time_columns()
|
||||
for name, value in dw_checks.items():
|
||||
self.logger.info("CHECK_CUTOFF: %s=%s", name, _ts(value))
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"counts": {"fetched": len(rows), "inserted": 0, "updated": 0, "skipped": 0, "errors": 0},
|
||||
"window": None,
|
||||
"request_params": {"store_id": store_id, "filter_task_codes": filter_codes or []},
|
||||
"report": {
|
||||
"rows": rows,
|
||||
"overall_cutoff": cutoff,
|
||||
"ods_fetched_at": ods_fetched,
|
||||
"dw_max_times": dw_checks,
|
||||
},
|
||||
}
|
||||
|
||||
def _probe_ods_fetched_at(self, store_id: int) -> dict[str, dict[str, Any]]:
|
||||
try:
|
||||
from tasks.dwd_load_task import DwdLoadTask # local import to avoid circulars
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
ods_tables = sorted({str(t) for t in DwdLoadTask.TABLE_MAP.values() if str(t).startswith("billiards_ods.")})
|
||||
results: dict[str, dict[str, Any]] = {}
|
||||
for table in ods_tables:
|
||||
try:
|
||||
row = self.db.query(f"SELECT MAX(fetched_at) AS mx, COUNT(*) AS cnt FROM {table}")[0]
|
||||
results[table] = {"max_fetched_at": row.get("mx"), "count": row.get("cnt")}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
results[table] = {"max_fetched_at": None, "count": None, "error": str(exc)}
|
||||
return results
|
||||
|
||||
def _probe_dw_time_columns(self) -> dict[str, Any]:
|
||||
checks: dict[str, Any] = {}
|
||||
probes = {
|
||||
"DWD.max_settlement_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_settlement_head",
|
||||
"DWD.max_payment_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_payment",
|
||||
"DWD.max_refund_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_refund",
|
||||
"DWS.max_order_date": "SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary",
|
||||
"DWS.max_updated_at": "SELECT MAX(updated_at) AS mx FROM billiards_dws.dws_order_summary",
|
||||
}
|
||||
for name, sql2 in probes.items():
|
||||
try:
|
||||
row = self.db.query(sql2)[0]
|
||||
checks[name] = row.get("mx")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
checks[name] = f"ERROR: {exc}"
|
||||
return checks
|
||||
@@ -2,10 +2,11 @@
|
||||
"""DWD 装载任务:从 ODS 增量写入 DWD(维度 SCD2,事实按时间增量)。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Iterable, List, Sequence
|
||||
|
||||
from psycopg2.extras import RealDictCursor
|
||||
from psycopg2.extras import RealDictCursor, execute_batch, execute_values
|
||||
|
||||
from .base_task import BaseTask, TaskContext
|
||||
|
||||
@@ -61,14 +62,15 @@ class DwdLoadTask(BaseTask):
|
||||
}
|
||||
|
||||
SCD_COLS = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}
|
||||
# 增量/窗口过滤优先使用业务时间;fetched_at(入库时间)放最后,避免回溯窗口被“当前入库时间”干扰。
|
||||
FACT_ORDER_CANDIDATES = [
|
||||
"fetched_at",
|
||||
"pay_time",
|
||||
"create_time",
|
||||
"update_time",
|
||||
"occur_time",
|
||||
"settle_time",
|
||||
"start_use_time",
|
||||
"fetched_at",
|
||||
]
|
||||
|
||||
# 特殊列映射:dwd 列名 -> 源列表达式(可选 CAST)
|
||||
@@ -457,30 +459,69 @@ class DwdLoadTask(BaseTask):
|
||||
return {"now": datetime.now()}
|
||||
|
||||
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict[str, Any]:
|
||||
"""遍历映射关系,维度执行 SCD2 合并,事实表按时间增量插入。"""
|
||||
"""
|
||||
遍历映射关系,维度执行 SCD2 合并,事实表按时间增量插入。
|
||||
|
||||
说明:
|
||||
- 为避免长事务导致锁堆积/中断后遗留 idle-in-tx,本任务按“每张表一次事务”提交;
|
||||
- 单表失败会回滚该表并继续后续表,最终在结果中汇总错误信息。
|
||||
"""
|
||||
now = extracted["now"]
|
||||
summary: List[Dict[str, Any]] = []
|
||||
errors: List[Dict[str, Any]] = []
|
||||
only_tables_cfg = self.config.get("dwd.only_tables") or []
|
||||
only_tables = {str(t).strip().lower() for t in only_tables_cfg if str(t).strip()} if only_tables_cfg else set()
|
||||
with self.db.conn.cursor(cursor_factory=RealDictCursor) as cur:
|
||||
for dwd_table, ods_table in self.TABLE_MAP.items():
|
||||
dwd_cols = self._get_columns(cur, dwd_table)
|
||||
ods_cols = self._get_columns(cur, ods_table)
|
||||
if not dwd_cols:
|
||||
self.logger.warning("跳过 %s,未能获取 DWD 列信息", dwd_table)
|
||||
if only_tables and dwd_table.lower() not in only_tables and self._table_base(dwd_table).lower() not in only_tables:
|
||||
continue
|
||||
started = time.monotonic()
|
||||
self.logger.info("DWD 装载开始:%s <= %s", dwd_table, ods_table)
|
||||
try:
|
||||
dwd_cols = self._get_columns(cur, dwd_table)
|
||||
ods_cols = self._get_columns(cur, ods_table)
|
||||
if not dwd_cols:
|
||||
self.logger.warning("跳过 %s:未能获取 DWD 列信息", dwd_table)
|
||||
continue
|
||||
|
||||
if self._table_base(dwd_table).startswith("dim_"):
|
||||
processed = self._merge_dim(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
|
||||
self.db.conn.commit()
|
||||
summary.append({"table": dwd_table, "mode": "SCD2", "processed": processed})
|
||||
else:
|
||||
dwd_types = self._get_column_types(cur, dwd_table, "billiards_dwd")
|
||||
ods_types = self._get_column_types(cur, ods_table, "billiards_ods")
|
||||
use_window = bool(
|
||||
self.config.get("run.window_override.start")
|
||||
and self.config.get("run.window_override.end")
|
||||
)
|
||||
inserted = self._merge_fact_increment(
|
||||
cur,
|
||||
dwd_table,
|
||||
ods_table,
|
||||
dwd_cols,
|
||||
ods_cols,
|
||||
dwd_types,
|
||||
ods_types,
|
||||
window_start=context.window_start if use_window else None,
|
||||
window_end=context.window_end if use_window else None,
|
||||
)
|
||||
self.db.conn.commit()
|
||||
summary.append({"table": dwd_table, "mode": "INCREMENT", "inserted": inserted})
|
||||
|
||||
elapsed = time.monotonic() - started
|
||||
self.logger.info("DWD 装载完成:%s,用时 %.2fs", dwd_table, elapsed)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
try:
|
||||
self.db.conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
elapsed = time.monotonic() - started
|
||||
self.logger.exception("DWD 装载失败:%s,用时 %.2fs,err=%s", dwd_table, elapsed, exc)
|
||||
errors.append({"table": dwd_table, "error": str(exc)})
|
||||
continue
|
||||
|
||||
if self._table_base(dwd_table).startswith("dim_"):
|
||||
processed = self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
|
||||
summary.append({"table": dwd_table, "mode": "SCD2", "processed": processed})
|
||||
else:
|
||||
dwd_types = self._get_column_types(cur, dwd_table, "billiards_dwd")
|
||||
ods_types = self._get_column_types(cur, ods_table, "billiards_ods")
|
||||
inserted = self._merge_fact_increment(
|
||||
cur, dwd_table, ods_table, dwd_cols, ods_cols, dwd_types, ods_types
|
||||
)
|
||||
summary.append({"table": dwd_table, "mode": "INCREMENT", "inserted": inserted})
|
||||
|
||||
self.db.conn.commit()
|
||||
return {"tables": summary}
|
||||
return {"tables": summary, "errors": errors}
|
||||
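# Minimal sketch of the per-table transaction pattern used in load() above (assumed
# psycopg2-style connection and a hypothetical load_one callback): commit after each
# table, roll back and record the error on failure, and keep processing the rest.
def load_tables(conn, tables, load_one):
    summary, errors = [], []
    for table in tables:
        try:
            result = load_one(conn, table)   # runs the per-table SQL on an open cursor
            conn.commit()                    # one transaction per table
            summary.append({"table": table, "result": result})
        except Exception as exc:             # noqa: BLE001
            try:
                conn.rollback()              # drop only the failed table's work
            except Exception:
                pass
            errors.append({"table": table, "error": str(exc)})
    return {"tables": summary, "errors": errors}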
|
||||
# ---------------------- helpers ----------------------
|
||||
def _get_columns(self, cur, table: str) -> List[str]:
|
||||
@@ -589,6 +630,135 @@ class DwdLoadTask(BaseTask):
|
||||
expanded.append(child_row)
|
||||
return expanded
|
||||
|
||||
def _merge_dim(
|
||||
self,
|
||||
cur,
|
||||
dwd_table: str,
|
||||
ods_table: str,
|
||||
dwd_cols: Sequence[str],
|
||||
ods_cols: Sequence[str],
|
||||
now: datetime,
|
||||
) -> int:
|
||||
"""
|
||||
维表合并策略:
|
||||
- 若主键包含 scd2 列(如 scd2_start_time/scd2_version),执行真正的 SCD2(关闭旧版+插入新版)。
|
||||
- 否则(多数现有表主键仅为业务主键),执行 Type1 Upsert,避免重复键异常并保证可重复回放。
|
||||
"""
|
||||
pk_cols = self._get_primary_keys(cur, dwd_table)
|
||||
if not pk_cols:
|
||||
raise ValueError(f"{dwd_table} 未配置主键,无法执行维表合并")
|
||||
|
||||
pk_has_scd = any(pk.lower() in self.SCD_COLS for pk in pk_cols)
|
||||
scd_cols_present = any(c.lower() in self.SCD_COLS for c in dwd_cols)
|
||||
if scd_cols_present and pk_has_scd:
|
||||
return self._merge_dim_scd2(cur, dwd_table, ods_table, dwd_cols, ods_cols, now)
|
||||
return self._merge_dim_type1_upsert(cur, dwd_table, ods_table, dwd_cols, ods_cols, pk_cols, now)
|
||||
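# Sketch of the dimension-merge dispatch above with hypothetical column sets: SCD2 is
# only used when the primary key itself can hold multiple versions (i.e. it contains an
# scd2_* column); otherwise a Type1 upsert keeps replays idempotent.
SCD_COLS = {"scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"}

def choose_dim_mode(pk_cols, dwd_cols):
    pk_has_scd = any(c.lower() in SCD_COLS for c in pk_cols)
    has_scd_cols = any(c.lower() in SCD_COLS for c in dwd_cols)
    return "SCD2" if (has_scd_cols and pk_has_scd) else "TYPE1_UPSERT"

assert choose_dim_mode(["member_id", "scd2_start_time"], ["member_id", "scd2_start_time"]) == "SCD2"
assert choose_dim_mode(["member_id"], ["member_id", "scd2_is_current"]) == "TYPE1_UPSERT"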
|
||||
def _merge_dim_type1_upsert(
|
||||
self,
|
||||
cur,
|
||||
dwd_table: str,
|
||||
ods_table: str,
|
||||
dwd_cols: Sequence[str],
|
||||
ods_cols: Sequence[str],
|
||||
pk_cols: Sequence[str],
|
||||
now: datetime,
|
||||
) -> int:
|
||||
"""维表 Type1 Upsert(主键冲突则更新),兼容带 scd2 字段但主键不支持多版本的表。"""
|
||||
mapping = self._build_column_mapping(dwd_table, pk_cols, ods_cols)
|
||||
ods_set = {c.lower() for c in ods_cols}
|
||||
ods_table_sql = self._format_table(ods_table, "billiards_ods")
|
||||
|
||||
select_exprs: list[str] = []
|
||||
added: set[str] = set()
|
||||
for col in dwd_cols:
|
||||
lc = col.lower()
|
||||
if lc in self.SCD_COLS:
|
||||
continue
|
||||
if lc in mapping:
|
||||
src, cast_type = mapping[lc]
|
||||
select_exprs.append(f"{self._cast_expr(src, cast_type)} AS \"{lc}\"")
|
||||
added.add(lc)
|
||||
elif lc in ods_set:
|
||||
select_exprs.append(f'\"{lc}\" AS \"{lc}\"')
|
||||
added.add(lc)
|
||||
|
||||
for pk in pk_cols:
|
||||
lc = pk.lower()
|
||||
if lc in added:
|
||||
continue
|
||||
if lc in mapping:
|
||||
src, cast_type = mapping[lc]
|
||||
select_exprs.append(f"{self._cast_expr(src, cast_type)} AS \"{lc}\"")
|
||||
elif lc in ods_set:
|
||||
select_exprs.append(f'\"{lc}\" AS \"{lc}\"')
|
||||
added.add(lc)
|
||||
|
||||
if not select_exprs:
|
||||
return 0
|
||||
|
||||
cur.execute(f"SELECT {', '.join(select_exprs)} FROM {ods_table_sql}")
|
||||
rows = [{k.lower(): v for k, v in r.items()} for r in cur.fetchall()]
|
||||
|
||||
if dwd_table == "billiards_dwd.dim_goods_category":
|
||||
rows = self._expand_goods_category_rows(rows)
|
||||
|
||||
# 按主键去重
|
||||
seen_pk: set[tuple[Any, ...]] = set()
|
||||
src_rows: list[Dict[str, Any]] = []
|
||||
pk_lower = [c.lower() for c in pk_cols]
|
||||
for row in rows:
|
||||
pk_key = tuple(row.get(pk) for pk in pk_lower)
|
||||
if pk_key in seen_pk:
|
||||
continue
|
||||
if any(v is None for v in pk_key):
|
||||
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
|
||||
continue
|
||||
seen_pk.add(pk_key)
|
||||
src_rows.append(row)
|
||||
|
||||
if not src_rows:
|
||||
return 0
|
||||
|
||||
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
|
||||
sorted_cols = [c.lower() for c in sorted(dwd_cols)]
|
||||
insert_cols_sql = ", ".join(f'\"{c}\"' for c in sorted_cols)
|
||||
|
||||
def build_row(src_row: Dict[str, Any]) -> list[Any]:
|
||||
values: list[Any] = []
|
||||
for c in sorted_cols:
|
||||
if c == "scd2_start_time":
|
||||
values.append(now)
|
||||
elif c == "scd2_end_time":
|
||||
values.append(datetime(9999, 12, 31, 0, 0, 0))
|
||||
elif c == "scd2_is_current":
|
||||
values.append(1)
|
||||
elif c == "scd2_version":
|
||||
values.append(1)
|
||||
else:
|
||||
values.append(src_row.get(c))
|
||||
return values
|
||||
|
||||
pk_sql = ", ".join(f'\"{c.lower()}\"' for c in pk_cols)
|
||||
pk_lower_set = {c.lower() for c in pk_cols}
|
||||
set_exprs: list[str] = []
|
||||
for c in sorted_cols:
|
||||
if c in pk_lower_set:
|
||||
continue
|
||||
if c == "scd2_start_time":
|
||||
set_exprs.append(f'\"{c}\" = COALESCE({dwd_table_sql}.\"{c}\", EXCLUDED.\"{c}\")')
|
||||
elif c == "scd2_version":
|
||||
set_exprs.append(f'\"{c}\" = COALESCE({dwd_table_sql}.\"{c}\", EXCLUDED.\"{c}\")')
|
||||
else:
|
||||
set_exprs.append(f'\"{c}\" = EXCLUDED.\"{c}\"')
|
||||
|
||||
upsert_sql = (
|
||||
f"INSERT INTO {dwd_table_sql} ({insert_cols_sql}) VALUES %s "
|
||||
f"ON CONFLICT ({pk_sql}) DO UPDATE SET {', '.join(set_exprs)}"
|
||||
)
|
||||
execute_values(cur, upsert_sql, [build_row(r) for r in src_rows], page_size=500)
|
||||
return len(src_rows)
|
||||
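# For reference, the upsert statement built above takes roughly this shape for a
# hypothetical dim table with primary key member_id (columns sorted alphabetically):
#
#   INSERT INTO billiards_dwd.dim_member ("member_id", "member_name", "scd2_is_current", ...)
#   VALUES %s
#   ON CONFLICT ("member_id") DO UPDATE SET
#       "member_name" = EXCLUDED."member_name",
#       "scd2_start_time" = COALESCE(billiards_dwd.dim_member."scd2_start_time", EXCLUDED."scd2_start_time"),
#       ...
#
# execute_values expands VALUES %s into batched row tuples (page_size=500 above).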
|
||||
def _merge_dim_scd2(
|
||||
self,
|
||||
cur,
|
||||
@@ -646,8 +816,9 @@ class DwdLoadTask(BaseTask):
|
||||
if dwd_table == "billiards_dwd.dim_goods_category":
|
||||
rows = self._expand_goods_category_rows(rows)
|
||||
|
||||
inserted_or_updated = 0
|
||||
# 归一化源行并按主键去重
|
||||
seen_pk = set()
|
||||
src_rows_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
|
||||
for row in rows:
|
||||
mapped_row: Dict[str, Any] = {}
|
||||
for col in dwd_cols:
|
||||
@@ -663,10 +834,110 @@ class DwdLoadTask(BaseTask):
|
||||
pk_key = tuple(mapped_row.get(pk) for pk in pk_cols)
|
||||
if pk_key in seen_pk:
|
||||
continue
|
||||
if any(v is None for v in pk_key):
|
||||
self.logger.warning("跳过 %s:主键缺失 %s", dwd_table, dict(zip(pk_cols, pk_key)))
|
||||
continue
|
||||
seen_pk.add(pk_key)
|
||||
if self._upsert_scd2_row(cur, dwd_table, dwd_cols, pk_cols, mapped_row, now):
|
||||
inserted_or_updated += 1
|
||||
return len(rows)
|
||||
src_rows_by_pk[pk_key] = mapped_row
|
||||
|
||||
if not src_rows_by_pk:
|
||||
return 0
|
||||
|
||||
# 预加载当前版本(scd2_is_current=1),避免逐行 SELECT 造成大量 round-trip
|
||||
table_sql_dwd = self._format_table(dwd_table, "billiards_dwd")
|
||||
where_current = " AND ".join([f"COALESCE(scd2_is_current,1)=1"])
|
||||
cur.execute(f"SELECT * FROM {table_sql_dwd} WHERE {where_current}")
|
||||
current_rows = cur.fetchall() or []
|
||||
current_by_pk: dict[tuple[Any, ...], Dict[str, Any]] = {}
|
||||
for r in current_rows:
|
||||
rr = {k.lower(): v for k, v in r.items()}
|
||||
pk_key = tuple(rr.get(pk) for pk in pk_cols)
|
||||
current_by_pk[pk_key] = rr
|
||||
|
||||
# 计算需要关闭/插入的主键集合
|
||||
to_close: list[tuple[Any, ...]] = []
|
||||
to_insert: list[tuple[Dict[str, Any], int]] = []
|
||||
for pk_key, incoming in src_rows_by_pk.items():
|
||||
current = current_by_pk.get(pk_key)
|
||||
if current and not self._is_row_changed(current, incoming, dwd_cols):
|
||||
continue
|
||||
if current:
|
||||
version = (current.get("scd2_version") or 1) + 1
|
||||
to_close.append(pk_key)
|
||||
else:
|
||||
version = 1
|
||||
to_insert.append((incoming, version))
|
||||
|
||||
# 先关闭旧版本(同一批次统一 end_time)
|
||||
if to_close:
|
||||
self._close_current_dim_bulk(cur, dwd_table, pk_cols, to_close, now)
|
||||
|
||||
# 批量插入新版本
|
||||
if to_insert:
|
||||
self._insert_dim_rows_bulk(cur, dwd_table, dwd_cols, to_insert, now)
|
||||
|
||||
return len(src_rows_by_pk)
|
||||
|
||||
def _close_current_dim_bulk(
|
||||
self,
|
||||
cur,
|
||||
table: str,
|
||||
pk_cols: Sequence[str],
|
||||
pk_keys: Sequence[tuple[Any, ...]],
|
||||
now: datetime,
|
||||
) -> None:
|
||||
"""批量关闭当前版本(scd2_is_current=0 + 填充结束时间)。"""
|
||||
table_sql = self._format_table(table, "billiards_dwd")
|
||||
if len(pk_cols) == 1:
|
||||
pk = pk_cols[0]
|
||||
ids = [k[0] for k in pk_keys]
|
||||
cur.execute(
|
||||
f'UPDATE {table_sql} SET scd2_end_time=%s, scd2_is_current=0 '
|
||||
f'WHERE COALESCE(scd2_is_current,1)=1 AND "{pk}" = ANY(%s)',
|
||||
(now, ids),
|
||||
)
|
||||
return
|
||||
|
||||
# 复合主键:对“发生变更的键”逐条关闭(数量通常远小于全量行数)
|
||||
where_clause = " AND ".join(f'"{pk}" = %s' for pk in pk_cols)
|
||||
sql = (
|
||||
f"UPDATE {table_sql} SET scd2_end_time=%s, scd2_is_current=0 "
|
||||
f"WHERE COALESCE(scd2_is_current,1)=1 AND {where_clause}"
|
||||
)
|
||||
args_list = [(now, *pk_key) for pk_key in pk_keys]
|
||||
execute_batch(cur, sql, args_list, page_size=500)
|
||||
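# Note on the two close-out paths above: with a single-column primary key the update is
# one statement using "pk" = ANY(%s); with a composite key each changed key gets its own
# parameter tuple and execute_batch groups the UPDATEs (page_size=500) to cut round-trips.
# Example parameter shapes (hypothetical keys):
#   single pk  -> cur.execute(sql, (now, [101, 102, 103]))
#   composite  -> execute_batch(cur, sql, [(now, 101, "A"), (now, 102, "B")], page_size=500)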
|
||||
def _insert_dim_rows_bulk(
|
||||
self,
|
||||
cur,
|
||||
table: str,
|
||||
dwd_cols: Sequence[str],
|
||||
rows_with_version: Sequence[tuple[Dict[str, Any], int]],
|
||||
now: datetime,
|
||||
) -> None:
|
||||
"""批量插入新的 SCD2 版本行。"""
|
||||
sorted_cols = [c.lower() for c in sorted(dwd_cols)]
|
||||
insert_cols_sql = ", ".join(f'"{c}"' for c in sorted_cols)
|
||||
table_sql = self._format_table(table, "billiards_dwd")
|
||||
|
||||
def build_row(src_row: Dict[str, Any], version: int) -> list[Any]:
|
||||
values: list[Any] = []
|
||||
for c in sorted_cols:
|
||||
if c == "scd2_start_time":
|
||||
values.append(now)
|
||||
elif c == "scd2_end_time":
|
||||
values.append(datetime(9999, 12, 31, 0, 0, 0))
|
||||
elif c == "scd2_is_current":
|
||||
values.append(1)
|
||||
elif c == "scd2_version":
|
||||
values.append(version)
|
||||
else:
|
||||
values.append(src_row.get(c))
|
||||
return values
|
||||
|
||||
values_rows = [build_row(r, ver) for r, ver in rows_with_version]
|
||||
insert_sql = f"INSERT INTO {table_sql} ({insert_cols_sql}) VALUES %s"
|
||||
execute_values(cur, insert_sql, values_rows, page_size=500)
|
||||
|
||||
def _upsert_scd2_row(
|
||||
self,
|
||||
@@ -762,6 +1033,8 @@ class DwdLoadTask(BaseTask):
|
||||
ods_cols: Sequence[str],
|
||||
dwd_types: Dict[str, str],
|
||||
ods_types: Dict[str, str],
|
||||
window_start: datetime | None = None,
|
||||
window_end: datetime | None = None,
|
||||
) -> int:
|
||||
"""事实表按时间增量插入,默认按列名交集写入。"""
|
||||
mapping_entries = self.FACT_MAPPINGS.get(dwd_table) or []
|
||||
@@ -813,7 +1086,10 @@ class DwdLoadTask(BaseTask):
|
||||
params: List[Any] = []
|
||||
dwd_table_sql = self._format_table(dwd_table, "billiards_dwd")
|
||||
ods_table_sql = self._format_table(ods_table, "billiards_ods")
|
||||
if order_col:
|
||||
if order_col and window_start and window_end:
|
||||
where_sql = f'WHERE "{order_col}" >= %s AND "{order_col}" < %s'
|
||||
params.extend([window_start, window_end])
|
||||
elif order_col:
|
||||
cur.execute(f'SELECT COALESCE(MAX("{order_col}"), %s) FROM {dwd_table_sql}', ("1970-01-01",))
|
||||
row = cur.fetchone() or {}
|
||||
watermark = list(row.values())[0] if row else "1970-01-01"
|
||||
|
||||
142
etl_billiards/tasks/dws_build_order_summary_task.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Build DWS order summary table from DWD fact tables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from .base_task import BaseTask, TaskContext
|
||||
from scripts.build_dws_order_summary import SQL_BUILD_SUMMARY
|
||||
|
||||
|
||||
class DwsBuildOrderSummaryTask(BaseTask):
|
||||
"""Recompute/refresh `billiards_dws.dws_order_summary` for a date window."""
|
||||
|
||||
def get_task_code(self) -> str:
|
||||
return "DWS_BUILD_ORDER_SUMMARY"
|
||||
|
||||
def execute(self, cursor_data: dict | None = None) -> dict:
|
||||
context = self._build_context(cursor_data)
|
||||
task_code = self.get_task_code()
|
||||
self.logger.info(
|
||||
"%s: start, window[%s ~ %s]",
|
||||
task_code,
|
||||
context.window_start,
|
||||
context.window_end,
|
||||
)
|
||||
|
||||
try:
|
||||
extracted = self.extract(context)
|
||||
transformed = self.transform(extracted, context)
|
||||
load_result = self.load(transformed, context) or {}
|
||||
self.db.commit()
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
self.logger.error("%s: failed", task_code, exc_info=True)
|
||||
raise
|
||||
|
||||
counts = load_result.get("counts") or {}
|
||||
result = {"status": "SUCCESS", "counts": counts}
|
||||
result["window"] = {
|
||||
"start": context.window_start,
|
||||
"end": context.window_end,
|
||||
"minutes": context.window_minutes,
|
||||
}
|
||||
if "request_params" in load_result:
|
||||
result["request_params"] = load_result["request_params"]
|
||||
if "extra" in load_result:
|
||||
result["extra"] = load_result["extra"]
|
||||
self.logger.info("%s: done, counts=%s", task_code, counts)
|
||||
return result
|
||||
|
||||
def extract(self, context: TaskContext) -> dict[str, Any]:
|
||||
store_id = int(self.config.get("app.store_id"))
|
||||
|
||||
full_refresh = bool(self.config.get("dws.order_summary.full_refresh", False))
|
||||
site_id = self.config.get("dws.order_summary.site_id", store_id)
|
||||
if site_id in ("", None, "null", "NULL"):
|
||||
site_id = None
|
||||
|
||||
start_date = self.config.get("dws.order_summary.start_date")
|
||||
end_date = self.config.get("dws.order_summary.end_date")
|
||||
if not full_refresh:
|
||||
if not start_date:
|
||||
start_date = context.window_start.date()
|
||||
if not end_date:
|
||||
end_date = context.window_end.date()
|
||||
else:
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
delete_before_insert = bool(self.config.get("dws.order_summary.delete_before_insert", True))
|
||||
return {
|
||||
"site_id": site_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"full_refresh": full_refresh,
|
||||
"delete_before_insert": delete_before_insert,
|
||||
}
|
||||
|
||||
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
|
||||
sql_params = {
|
||||
"site_id": extracted["site_id"],
|
||||
"start_date": extracted["start_date"],
|
||||
"end_date": extracted["end_date"],
|
||||
}
|
||||
request_params = {
|
||||
"site_id": extracted["site_id"],
|
||||
"start_date": _jsonable_date(extracted["start_date"]),
|
||||
"end_date": _jsonable_date(extracted["end_date"]),
|
||||
}
|
||||
|
||||
with self.db.conn.cursor() as cur:
|
||||
cur.execute("SELECT to_regclass('billiards_dws.dws_order_summary') AS reg;")
|
||||
row = cur.fetchone()
|
||||
reg = row[0] if row else None
|
||||
if not reg:
|
||||
raise RuntimeError("DWS 表不存在:请先运行任务 INIT_DWS_SCHEMA")
|
||||
|
||||
deleted = 0
|
||||
if extracted["delete_before_insert"]:
|
||||
if extracted["full_refresh"] and extracted["site_id"] is None:
|
||||
cur.execute("TRUNCATE TABLE billiards_dws.dws_order_summary;")
|
||||
self.logger.info("DWS_BUILD_ORDER_SUMMARY: truncated billiards_dws.dws_order_summary")
|
||||
else:
|
||||
delete_sql = "DELETE FROM billiards_dws.dws_order_summary WHERE 1=1"
|
||||
delete_args: list[Any] = []
|
||||
if extracted["site_id"] is not None:
|
||||
delete_sql += " AND site_id = %s"
|
||||
delete_args.append(extracted["site_id"])
|
||||
if extracted["start_date"] is not None:
|
||||
delete_sql += " AND order_date >= %s"
|
||||
delete_args.append(_as_date(extracted["start_date"]))
|
||||
if extracted["end_date"] is not None:
|
||||
delete_sql += " AND order_date <= %s"
|
||||
delete_args.append(_as_date(extracted["end_date"]))
|
||||
cur.execute(delete_sql, delete_args)
|
||||
deleted = cur.rowcount
|
||||
self.logger.info("DWS_BUILD_ORDER_SUMMARY: deleted=%s sql=%s", deleted, delete_sql)
|
||||
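# Example of the incremental delete assembled above for a hypothetical run with
# site_id=1001 and a May 2024 window (full_refresh=False):
#   DELETE FROM billiards_dws.dws_order_summary
#   WHERE 1=1 AND site_id = %s AND order_date >= %s AND order_date <= %s
#   -- args: [1001, date(2024, 5, 1), date(2024, 5, 31)]
# Only when full_refresh is set and no site filter applies does the task TRUNCATE instead.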
|
||||
cur.execute(SQL_BUILD_SUMMARY, sql_params)
|
||||
affected = cur.rowcount
|
||||
|
||||
return {
|
||||
"counts": {"fetched": 0, "inserted": affected, "updated": 0, "skipped": 0, "errors": 0},
|
||||
"request_params": request_params,
|
||||
"extra": {"deleted": deleted},
|
||||
}
|
||||
|
||||
|
||||
def _as_date(v: Any) -> date:
|
||||
if isinstance(v, date):
|
||||
return v
|
||||
return date.fromisoformat(str(v))
|
||||
|
||||
|
||||
def _jsonable_date(v: Any):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, date):
|
||||
return v.isoformat()
|
||||
return str(v)
|
||||
34
etl_billiards/tasks/init_dws_schema_task.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Initialize DWS schema (billiards_dws)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .base_task import BaseTask, TaskContext
|
||||
|
||||
|
||||
class InitDwsSchemaTask(BaseTask):
|
||||
"""Apply DWS schema SQL."""
|
||||
|
||||
def get_task_code(self) -> str:
|
||||
return "INIT_DWS_SCHEMA"
|
||||
|
||||
def extract(self, context: TaskContext) -> dict[str, Any]:
|
||||
base_dir = Path(__file__).resolve().parents[1] / "database"
|
||||
dws_path = Path(self.config.get("schema.dws_file", base_dir / "schema_dws.sql"))
|
||||
if not dws_path.exists():
|
||||
raise FileNotFoundError(f"未找到 DWS schema 文件: {dws_path}")
|
||||
drop_first = bool(self.config.get("dws.drop_schema_first", False))
|
||||
return {"dws_sql": dws_path.read_text(encoding="utf-8"), "dws_file": str(dws_path), "drop_first": drop_first}
|
||||
|
||||
def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
|
||||
with self.db.conn.cursor() as cur:
|
||||
if extracted["drop_first"]:
|
||||
cur.execute("DROP SCHEMA IF EXISTS billiards_dws CASCADE;")
|
||||
self.logger.info("已执行 DROP SCHEMA billiards_dws CASCADE")
|
||||
self.logger.info("执行 DWS schema 文件: %s", extracted["dws_file"])
|
||||
cur.execute(extracted["dws_sql"])
|
||||
return {"executed": 1, "files": [extracted["dws_file"]]}
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterable
|
||||
|
||||
from psycopg2.extras import Json
|
||||
from psycopg2.extras import Json, execute_values
|
||||
|
||||
from .base_task import BaseTask
|
||||
|
||||
@@ -75,7 +75,7 @@ class ManualIngestTask(BaseTask):
|
||||
return "MANUAL_INGEST"
|
||||
|
||||
def execute(self, cursor_data: dict | None = None) -> dict:
|
||||
"""从目录读取 JSON,按表定义批量入库。"""
|
||||
"""从目录读取 JSON,按表定义批量入库(按文件提交事务,避免长事务导致连接不稳定)。"""
|
||||
data_dir = (
|
||||
self.config.get("manual.data_dir")
|
||||
or self.config.get("pipeline.ingest_source_dir")
|
||||
@@ -87,9 +87,15 @@ class ManualIngestTask(BaseTask):
|
||||
|
||||
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
|
||||
|
||||
include_files_cfg = self.config.get("manual.include_files") or []
|
||||
include_files = {str(x).strip().lower() for x in include_files_cfg if str(x).strip()} if include_files_cfg else set()
|
||||
|
||||
for filename in sorted(os.listdir(data_dir)):
|
||||
if not filename.endswith(".json"):
|
||||
continue
|
||||
stem = os.path.splitext(filename)[0].lower()
|
||||
if include_files and stem not in include_files:
|
||||
continue
|
||||
filepath = os.path.join(data_dir, filename)
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as fh:
|
||||
@@ -113,22 +119,25 @@ class ManualIngestTask(BaseTask):
|
||||
|
||||
self.logger.info("Ingesting %s into %s", filename, target_table)
|
||||
try:
|
||||
inserted, updated = self._ingest_table(target_table, records, filename)
|
||||
inserted, updated, row_errors = self._ingest_table(target_table, records, filename)
|
||||
counts["inserted"] += inserted
|
||||
counts["updated"] += updated
|
||||
counts["fetched"] += len(records)
|
||||
counts["errors"] += row_errors
|
||||
# 每个文件一次提交:降低单次事务体积,避免长事务/连接异常导致整体回滚失败。
|
||||
self.db.commit()
|
||||
except Exception:
|
||||
counts["errors"] += 1
|
||||
self.logger.exception("Error processing %s", filename)
|
||||
self.db.rollback()
|
||||
try:
|
||||
self.db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
# 若连接已断开,后续文件无法继续,直接抛出让上层处理(重连/重跑)。
|
||||
if getattr(self.db.conn, "closed", 0):
|
||||
raise
|
||||
continue
|
||||
|
||||
try:
|
||||
self.db.commit()
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
raise
|
||||
|
||||
return {"status": "SUCCESS", "counts": counts}
|
||||
|
||||
def _match_by_filename(self, filename: str) -> str | None:
|
||||
@@ -211,8 +220,15 @@ class ManualIngestTask(BaseTask):
|
||||
self._table_columns_cache = cache
|
||||
return cols
|
||||
|
||||
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int]:
|
||||
"""构建 INSERT/ON CONFLICT 语句并批量执行。"""
|
||||
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
构建 INSERT/ON CONFLICT 语句并批量执行(优先向量化,小批次提交)。
|
||||
|
||||
设计目标:
|
||||
- 控制单条 SQL 体积(避免一次性 VALUES 过大导致服务端 backend 被 OOM/异常终止);
|
||||
- 发生异常时,可降级逐行并用 SAVEPOINT 跳过异常行;
|
||||
- 统计口径偏“尽量可跑通”,插入/更新计数为近似值(不强依赖 RETURNING)。
|
||||
"""
|
||||
spec = self.TABLE_SPECS.get(table)
|
||||
if not spec:
|
||||
raise ValueError(f"No table spec for {table}")
|
||||
@@ -229,15 +245,19 @@ class ManualIngestTask(BaseTask):
|
||||
pk_col_db = None
|
||||
if pk_col:
|
||||
pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)
|
||||
pk_index = None
|
||||
if pk_col_db:
|
||||
try:
|
||||
pk_index = next(i for i, c in enumerate(columns_info) if c[0] == pk_col_db)
|
||||
except Exception:
|
||||
pk_index = None
|
||||
|
||||
placeholders = ", ".join(["%s"] * len(columns))
|
||||
col_list = ", ".join(f'"{c}"' for c in columns)
|
||||
sql = f'INSERT INTO {table} ({col_list}) VALUES ({placeholders})'
|
||||
sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
|
||||
if pk_col_db:
|
||||
update_cols = [c for c in columns if c != pk_col_db]
|
||||
set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
|
||||
sql += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
|
||||
sql += " RETURNING (xmax = 0) AS inserted"
|
||||
sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
|
||||
|
||||
params = []
|
||||
now = datetime.now()
|
||||
@@ -288,19 +308,55 @@ class ManualIngestTask(BaseTask):
|
||||
params.append(tuple(row_vals))
|
||||
|
||||
if not params:
|
||||
return 0, 0
|
||||
return 0, 0, 0
|
||||
|
||||
# 先尝试向量化执行(速度快);若失败,再降级逐行并用 SAVEPOINT 跳过异常行。
|
||||
try:
|
||||
with self.db.conn.cursor() as cur:
|
||||
# 分批提交:降低单次事务/单次 SQL 压力,避免服务端异常中断连接。
|
||||
affected = 0
|
||||
chunk_size = int(self.config.get("manual.execute_values_page_size", 50) or 50)
|
||||
chunk_size = max(1, min(chunk_size, 500))
|
||||
for i in range(0, len(params), chunk_size):
|
||||
chunk = params[i : i + chunk_size]
|
||||
execute_values(cur, sql_prefix, chunk, page_size=len(chunk))
|
||||
if cur.rowcount is not None and cur.rowcount > 0:
|
||||
affected += int(cur.rowcount)
|
||||
# 这里无法精确拆分 inserted/updated(除非 RETURNING),按“受影响行数≈插入”近似返回。
|
||||
return int(affected), 0, 0
|
||||
except Exception as exc:
|
||||
self.logger.warning("批量入库失败,准备降级逐行处理:table=%s, err=%s", table, exc)
|
||||
try:
|
||||
self.db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inserted = 0
|
||||
updated = 0
|
||||
errors = 0
|
||||
with self.db.conn.cursor() as cur:
|
||||
for row in params:
|
||||
cur.execute(sql, row)
|
||||
flag = cur.fetchone()[0]
|
||||
if flag:
|
||||
cur.execute("SAVEPOINT sp_manual_ingest_row")
|
||||
try:
|
||||
cur.execute(sql_prefix.replace(" VALUES %s", f" VALUES ({', '.join(['%s'] * len(row))})"), row)
|
||||
inserted += 1
|
||||
else:
|
||||
updated += 1
|
||||
return inserted, updated
|
||||
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors += 1
|
||||
try:
|
||||
cur.execute("ROLLBACK TO SAVEPOINT sp_manual_ingest_row")
|
||||
cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
|
||||
except Exception:
|
||||
pass
|
||||
pk_val = None
|
||||
if pk_index is not None:
|
||||
try:
|
||||
pk_val = row[pk_index]
|
||||
except Exception:
|
||||
pk_val = None
|
||||
self.logger.warning("跳过异常行:table=%s pk=%s err=%s", table, pk_val, exc)
|
||||
|
||||
return inserted, updated, errors
|
||||
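# Minimal sketch of the ingest strategy above (assumed psycopg2 connection, hypothetical
# SQL strings): try a chunked execute_values insert first; if the whole batch fails, roll
# back and replay rows one by one inside SAVEPOINTs so a single bad row cannot poison the
# transaction. The caller still commits per file, as in execute().
from psycopg2.extras import execute_values

def ingest_rows(conn, sql_values, sql_single, rows, chunk_size=50):
    try:
        with conn.cursor() as cur:
            for i in range(0, len(rows), chunk_size):
                execute_values(cur, sql_values, rows[i : i + chunk_size])
        return len(rows), 0                      # approximate: affected rows ~= inserted
    except Exception:                            # noqa: BLE001
        conn.rollback()
    ok, errors = 0, 0
    with conn.cursor() as cur:
        for row in rows:
            cur.execute("SAVEPOINT sp_row")
            try:
                cur.execute(sql_single, row)
                cur.execute("RELEASE SAVEPOINT sp_row")
                ok += 1
            except Exception:                    # noqa: BLE001
                cur.execute("ROLLBACK TO SAVEPOINT sp_row")
                errors += 1
    return ok, errors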
|
||||
@staticmethod
|
||||
def _get_value_case_insensitive(record: dict, col: str | None):
|
||||
|
||||
260
etl_billiards/tasks/ods_json_archive_task.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""在线抓取 ODS 相关接口并落盘为 JSON(用于后续离线回放/入库)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from api.client import APIClient
|
||||
from models.parsers import TypeParser
|
||||
from utils.json_store import dump_json, endpoint_to_filename
|
||||
|
||||
from .base_task import BaseTask, TaskContext
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EndpointSpec:
|
||||
endpoint: str
|
||||
window_style: str # site | start_end | range | pay | none
|
||||
data_path: tuple[str, ...] = ("data",)
|
||||
list_key: str | None = None
|
||||
|
||||
|
||||
class OdsJsonArchiveTask(BaseTask):
|
||||
"""
|
||||
抓取一组 ODS 所需接口并落盘为“简化 JSON”:
|
||||
{"code": 0, "data": [...records...]}
|
||||
|
||||
说明:
|
||||
- 该输出格式与 tasks/manual_ingest_task.py 的解析逻辑兼容;
|
||||
- 默认每页一个文件,避免单文件过大;
|
||||
- 结算小票(/Order/GetOrderSettleTicketNew)按 orderSettleId 分文件写入。
|
||||
"""
|
||||
|
||||
ENDPOINTS: tuple[EndpointSpec, ...] = (
|
||||
EndpointSpec("/MemberProfile/GetTenantMemberList", "site", list_key="tenantMemberInfos"),
|
||||
EndpointSpec("/MemberProfile/GetTenantMemberCardList", "site", list_key="tenantMemberCards"),
|
||||
EndpointSpec("/MemberProfile/GetMemberCardBalanceChange", "start_end"),
|
||||
EndpointSpec("/PersonnelManagement/SearchAssistantInfo", "site", list_key="assistantInfos"),
|
||||
EndpointSpec(
|
||||
"/AssistantPerformance/GetOrderAssistantDetails",
|
||||
"start_end",
|
||||
list_key="orderAssistantDetails",
|
||||
),
|
||||
EndpointSpec(
|
||||
"/AssistantPerformance/GetAbolitionAssistant",
|
||||
"start_end",
|
||||
list_key="abolitionAssistants",
|
||||
),
|
||||
EndpointSpec("/Table/GetSiteTables", "site", list_key="siteTables"),
|
||||
EndpointSpec(
|
||||
"/TenantGoodsCategory/QueryPrimarySecondaryCategory",
|
||||
"site",
|
||||
list_key="goodsCategoryList",
|
||||
),
|
||||
EndpointSpec("/TenantGoods/QueryTenantGoods", "site", list_key="tenantGoodsList"),
|
||||
EndpointSpec("/TenantGoods/GetGoodsInventoryList", "site", list_key="orderGoodsList"),
|
||||
EndpointSpec("/TenantGoods/GetGoodsStockReport", "site"),
|
||||
EndpointSpec("/TenantGoods/GetGoodsSalesList", "start_end", list_key="orderGoodsLedgers"),
|
||||
EndpointSpec(
|
||||
"/PackageCoupon/QueryPackageCouponList",
|
||||
"site",
|
||||
list_key="packageCouponList",
|
||||
),
|
||||
EndpointSpec("/Site/GetSiteTableUseDetails", "start_end", list_key="siteTableUseDetailsList"),
|
||||
EndpointSpec("/Site/GetSiteTableOrderDetails", "start_end", list_key="siteTableUseDetailsList"),
|
||||
EndpointSpec("/Site/GetTaiFeeAdjustList", "start_end", list_key="taiFeeAdjustInfos"),
|
||||
EndpointSpec(
|
||||
"/GoodsStockManage/QueryGoodsOutboundReceipt",
|
||||
"start_end",
|
||||
list_key="queryDeliveryRecordsList",
|
||||
),
|
||||
EndpointSpec("/Promotion/GetOfflineCouponConsumePageList", "start_end"),
|
||||
EndpointSpec("/Order/GetRefundPayLogList", "start_end"),
|
||||
EndpointSpec("/Site/GetAllOrderSettleList", "range", list_key="settleList"),
|
||||
EndpointSpec("/Site/GetRechargeSettleList", "range", list_key="settleList"),
|
||||
EndpointSpec("/PayLog/GetPayLogListPage", "pay"),
|
||||
)
|
||||
|
||||
TICKET_ENDPOINT = "/Order/GetOrderSettleTicketNew"
|
||||
|
||||
def get_task_code(self) -> str:
|
||||
return "ODS_JSON_ARCHIVE"
|
||||
|
||||
def extract(self, context: TaskContext) -> dict:
|
||||
base_client = getattr(self.api, "base", None) or self.api
|
||||
if not isinstance(base_client, APIClient):
|
||||
raise TypeError("ODS_JSON_ARCHIVE 需要 APIClient(在线抓取)")
|
||||
|
||||
output_dir = getattr(self.api, "output_dir", None)
|
||||
if output_dir:
|
||||
out = Path(output_dir)
|
||||
else:
|
||||
out = Path(self.config.get("pipeline.fetch_root") or self.config["pipeline"]["fetch_root"])
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
write_pretty = bool(self.config.get("io.write_pretty_json", False))
|
||||
page_size = int(self.config.get("api.page_size", 200) or 200)
|
||||
store_id = int(context.store_id)
|
||||
|
||||
total_records = 0
|
||||
ticket_ids: set[int] = set()
|
||||
per_endpoint: list[dict] = []
|
||||
|
||||
self.logger.info(
|
||||
"ODS_JSON_ARCHIVE: 开始抓取,窗口[%s ~ %s] 输出目录=%s",
|
||||
context.window_start,
|
||||
context.window_end,
|
||||
out,
|
||||
)
|
||||
|
||||
for spec in self.ENDPOINTS:
|
||||
self.logger.info("ODS_JSON_ARCHIVE: 抓取 endpoint=%s", spec.endpoint)
|
||||
built_params = self._build_params(
|
||||
spec.window_style, store_id, context.window_start, context.window_end
|
||||
)
|
||||
# /TenantGoods/GetGoodsInventoryList 要求 siteId 为数组(标量会触发服务端异常,返回畸形状态行 HTTP/1.1 1400)
|
||||
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
|
||||
built_params["siteId"] = [store_id]
|
||||
params = self._merge_common_params(built_params)
|
||||
|
||||
base_filename = endpoint_to_filename(spec.endpoint)
|
||||
stem = Path(base_filename).stem
|
||||
suffix = Path(base_filename).suffix or ".json"
|
||||
|
||||
endpoint_records = 0
|
||||
endpoint_pages = 0
|
||||
endpoint_error: str | None = None
|
||||
|
||||
try:
|
||||
for page_no, records, _, _ in base_client.iter_paginated(
|
||||
endpoint=spec.endpoint,
|
||||
params=params,
|
||||
page_size=page_size,
|
||||
data_path=spec.data_path,
|
||||
list_key=spec.list_key,
|
||||
):
|
||||
endpoint_pages += 1
|
||||
total_records += len(records)
|
||||
endpoint_records += len(records)
|
||||
|
||||
if spec.endpoint == "/PayLog/GetPayLogListPage":
|
||||
for rec in records or []:
|
||||
relate_id = TypeParser.parse_int(
|
||||
(rec or {}).get("relateId")
|
||||
or (rec or {}).get("orderSettleId")
|
||||
or (rec or {}).get("order_settle_id")
|
||||
)
|
||||
if relate_id:
|
||||
ticket_ids.add(relate_id)
|
||||
|
||||
out_path = out / f"{stem}__p{int(page_no):04d}{suffix}"
|
||||
dump_json(out_path, {"code": 0, "data": records}, pretty=write_pretty)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
endpoint_error = f"{type(exc).__name__}: {exc}"
|
||||
self.logger.error("ODS_JSON_ARCHIVE: 接口抓取失败 endpoint=%s err=%s", spec.endpoint, endpoint_error)
|
||||
|
||||
per_endpoint.append(
|
||||
{
|
||||
"endpoint": spec.endpoint,
|
||||
"file_stem": stem,
|
||||
"pages": endpoint_pages,
|
||||
"records": endpoint_records,
|
||||
"error": endpoint_error,
|
||||
}
|
||||
)
|
||||
if endpoint_error:
|
||||
self.logger.warning(
|
||||
"ODS_JSON_ARCHIVE: endpoint=%s 完成(失败)pages=%s records=%s err=%s",
|
||||
spec.endpoint,
|
||||
endpoint_pages,
|
||||
endpoint_records,
|
||||
endpoint_error,
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
"ODS_JSON_ARCHIVE: endpoint=%s 完成 pages=%s records=%s",
|
||||
spec.endpoint,
|
||||
endpoint_pages,
|
||||
endpoint_records,
|
||||
)
|
||||
|
||||
# Ticket details: per orderSettleId
|
||||
ticket_ids_sorted = sorted(ticket_ids)
|
||||
self.logger.info("ODS_JSON_ARCHIVE: 小票候选数=%s", len(ticket_ids_sorted))
|
||||
|
||||
ticket_file_stem = Path(endpoint_to_filename(self.TICKET_ENDPOINT)).stem
|
||||
ticket_file_suffix = Path(endpoint_to_filename(self.TICKET_ENDPOINT)).suffix or ".json"
|
||||
ticket_records = 0
|
||||
|
||||
for order_settle_id in ticket_ids_sorted:
|
||||
params = self._merge_common_params({"orderSettleId": int(order_settle_id)})
|
||||
try:
|
||||
records, _ = base_client.get_paginated(
|
||||
endpoint=self.TICKET_ENDPOINT,
|
||||
params=params,
|
||||
page_size=None,
|
||||
data_path=("data",),
|
||||
list_key=None,
|
||||
)
|
||||
if not records:
|
||||
continue
|
||||
ticket_records += len(records)
|
||||
out_path = out / f"{ticket_file_stem}__{int(order_settle_id)}{ticket_file_suffix}"
|
||||
dump_json(out_path, {"code": 0, "data": records}, pretty=write_pretty)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self.logger.error(
|
||||
"ODS_JSON_ARCHIVE: 小票抓取失败 orderSettleId=%s err=%s",
|
||||
order_settle_id,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
|
||||
total_records += ticket_records
|
||||
|
||||
manifest = {
|
||||
"task": self.get_task_code(),
|
||||
"store_id": store_id,
|
||||
"window_start": context.window_start.isoformat(),
|
||||
"window_end": context.window_end.isoformat(),
|
||||
"page_size": page_size,
|
||||
"total_records": total_records,
|
||||
"ticket_ids": len(ticket_ids_sorted),
|
||||
"ticket_records": ticket_records,
|
||||
"endpoints": per_endpoint,
|
||||
}
|
||||
manifest_path = out / "manifest.json"
|
||||
dump_json(manifest_path, manifest, pretty=True)
|
||||
if hasattr(self.api, "last_dump"):
|
||||
try:
|
||||
self.api.last_dump = {"file": str(manifest_path), "records": total_records, "pages": None}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.logger.info("ODS_JSON_ARCHIVE: 抓取完成,总记录数=%s(含小票=%s)", total_records, ticket_records)
|
||||
return {"fetched": total_records, "ticket_ids": len(ticket_ids_sorted)}
|
||||
|
||||
def _build_params(self, window_style: str, store_id: int, window_start, window_end) -> dict:
|
||||
if window_style == "none":
|
||||
return {}
|
||||
if window_style == "site":
|
||||
return {"siteId": store_id}
|
||||
if window_style == "range":
|
||||
return {
|
||||
"siteId": store_id,
|
||||
"rangeStartTime": TypeParser.format_timestamp(window_start, self.tz),
|
||||
"rangeEndTime": TypeParser.format_timestamp(window_end, self.tz),
|
||||
}
|
||||
if window_style == "pay":
|
||||
return {
|
||||
"siteId": store_id,
|
||||
"StartPayTime": TypeParser.format_timestamp(window_start, self.tz),
|
||||
"EndPayTime": TypeParser.format_timestamp(window_end, self.tz),
|
||||
}
|
||||
# default: startTime/endTime
|
||||
return {
|
||||
"siteId": store_id,
|
||||
"startTime": TypeParser.format_timestamp(window_start, self.tz),
|
||||
"endTime": TypeParser.format_timestamp(window_end, self.tz),
|
||||
}
|
||||
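# For a hypothetical store_id=1001 and window 2024-05-01 00:00 ~ 2024-05-02 00:00, the
# styles above produce parameter payloads along these lines (exact keys as in the code):
#   "site"      -> {"siteId": 1001}
#   "range"     -> {"siteId": 1001, "rangeStartTime": "...", "rangeEndTime": "..."}
#   "pay"       -> {"siteId": 1001, "StartPayTime": "...", "EndPayTime": "..."}
#   "start_end" -> {"siteId": 1001, "startTime": "...", "endTime": "..."}
#   "none"      -> {}
# Timestamp strings are produced by TypeParser.format_timestamp with the task timezone.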
@@ -2,11 +2,13 @@
|
||||
"""ODS ingestion tasks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Type
|
||||
|
||||
from loaders.ods import GenericODSLoader
|
||||
from psycopg2.extras import Json, execute_values
|
||||
|
||||
from models.parsers import TypeParser
|
||||
from .base_task import BaseTask
|
||||
|
||||
@@ -60,70 +62,61 @@ class BaseOdsTask(BaseTask):
|
||||
def get_task_code(self) -> str:
|
||||
return self.SPEC.code
|
||||
|
||||
def execute(self) -> dict:
|
||||
def execute(self, cursor_data: dict | None = None) -> dict:
|
||||
spec = self.SPEC
|
||||
self.logger.info("寮€濮嬫墽琛?%s (ODS)", spec.code)
|
||||
|
||||
window_start, window_end, window_minutes = self._resolve_window(cursor_data)
|
||||
|
||||
store_id = TypeParser.parse_int(self.config.get("app.store_id"))
|
||||
if not store_id:
|
||||
raise ValueError("app.store_id 鏈厤缃紝鏃犳硶鎵ц ODS 浠诲姟")
|
||||
|
||||
page_size = self.config.get("api.page_size", 200)
|
||||
params = self._build_params(spec, store_id)
|
||||
columns = self._resolve_columns(spec)
|
||||
if spec.conflict_columns_override:
|
||||
conflict_columns = list(spec.conflict_columns_override)
|
||||
else:
|
||||
conflict_columns = []
|
||||
if spec.include_site_column:
|
||||
conflict_columns.append("site_id")
|
||||
conflict_columns += [col.column for col in spec.pk_columns]
|
||||
loader = GenericODSLoader(
|
||||
self.db,
|
||||
spec.table_name,
|
||||
columns,
|
||||
conflict_columns,
|
||||
params = self._build_params(
|
||||
spec,
|
||||
store_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
)
|
||||
|
||||
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
|
||||
source_file = self._resolve_source_file_hint(spec)
|
||||
|
||||
try:
|
||||
global_index = 0
|
||||
for page_no, page_records, _, _ in self.api.iter_paginated(
|
||||
for _, page_records, _, response_payload in self.api.iter_paginated(
|
||||
endpoint=spec.endpoint,
|
||||
params=params,
|
||||
page_size=page_size,
|
||||
data_path=spec.data_path,
|
||||
list_key=spec.list_key,
|
||||
):
|
||||
rows: List[dict] = []
|
||||
for raw in page_records:
|
||||
row = self._build_row(
|
||||
spec=spec,
|
||||
store_id=store_id,
|
||||
record=raw,
|
||||
page_no=page_no if spec.include_page_no else None,
|
||||
page_size_value=len(page_records)
|
||||
if spec.include_page_size
|
||||
else None,
|
||||
source_file=source_file,
|
||||
record_index=global_index if spec.include_record_index else None,
|
||||
)
|
||||
if row is None:
|
||||
counts["skipped"] += 1
|
||||
continue
|
||||
rows.append(row)
|
||||
global_index += 1
|
||||
|
||||
inserted, updated, _ = loader.upsert_rows(rows)
|
||||
counts["inserted"] += inserted
|
||||
counts["updated"] += updated
|
||||
inserted, skipped = self._insert_records_schema_aware(
|
||||
table=spec.table_name,
|
||||
records=page_records,
|
||||
response_payload=response_payload,
|
||||
source_file=source_file,
|
||||
source_endpoint=spec.endpoint if spec.include_source_endpoint else None,
|
||||
)
|
||||
counts["fetched"] += len(page_records)
|
||||
counts["inserted"] += inserted
|
||||
counts["skipped"] += skipped
|
||||
|
||||
self.db.commit()
|
||||
self.logger.info("%s ODS 浠诲姟瀹屾垚: %s", spec.code, counts)
|
||||
return self._build_result("SUCCESS", counts)
|
||||
allow_empty_advance = bool(self.config.get("run.allow_empty_result_advance", False))
|
||||
status = "SUCCESS"
|
||||
if counts["fetched"] == 0 and not allow_empty_advance:
|
||||
status = "PARTIAL"
|
||||
|
||||
result = self._build_result(status, counts)
|
||||
result["window"] = {
|
||||
"start": window_start,
|
||||
"end": window_end,
|
||||
"minutes": window_minutes,
|
||||
}
|
||||
result["request_params"] = params
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
@@ -131,12 +124,70 @@ class BaseOdsTask(BaseTask):
|
||||
self.logger.error("%s ODS 浠诲姟澶辫触", spec.code, exc_info=True)
|
||||
raise
|
||||
|
||||
def _build_params(self, spec: OdsTaskSpec, store_id: int) -> dict:
|
||||
def _resolve_window(self, cursor_data: dict | None) -> tuple[datetime, datetime, int]:
|
||||
base_start, base_end, base_minutes = self._get_time_window(cursor_data)
|
||||
|
||||
if self.config.get("run.force_window_override"):
|
||||
override_start = self.config.get("run.window_override.start")
|
||||
override_end = self.config.get("run.window_override.end")
|
||||
if override_start and override_end:
|
||||
return base_start, base_end, base_minutes
|
||||
|
||||
# 以 ODS 表 MAX(fetched_at) 兜底:避免“窗口游标推进但未实际入库”导致漏数。
|
||||
last_fetched = self._get_max_fetched_at(self.SPEC.table_name)
|
||||
if last_fetched:
|
||||
overlap_seconds = int(self.config.get("run.overlap_seconds", 120) or 120)
|
||||
cursor_end = cursor_data.get("last_end") if isinstance(cursor_data, dict) else None
|
||||
anchor = cursor_end or last_fetched
|
||||
# 如果 cursor_end 比真实入库时间(last_fetched)更靠后,说明游标被推进但表未跟上:改用 last_fetched 作为起点
|
||||
if isinstance(cursor_end, datetime) and cursor_end.tzinfo is None:
|
||||
cursor_end = cursor_end.replace(tzinfo=self.tz)
|
||||
if isinstance(cursor_end, datetime) and cursor_end > last_fetched:
|
||||
anchor = last_fetched
|
||||
start = anchor - timedelta(seconds=max(0, overlap_seconds))
|
||||
if start.tzinfo is None:
|
||||
start = start.replace(tzinfo=self.tz)
|
||||
else:
|
||||
start = start.astimezone(self.tz)
|
||||
|
||||
end = datetime.now(self.tz)
|
||||
minutes = max(1, int((end - start).total_seconds() // 60))
|
||||
return start, end, minutes
|
||||
|
||||
return base_start, base_end, base_minutes
|
||||
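# Worked example of the fallback above (hypothetical times, overlap_seconds=120): if the
# cursor says last_end=10:00 but the table's MAX(fetched_at) is only 09:30, the cursor ran
# ahead of the actual load, so the anchor falls back to 09:30 and the next window starts
# at 09:28 (anchor minus the overlap), ending at "now".
from datetime import datetime, timedelta

cursor_end = datetime(2024, 5, 1, 10, 0)
last_fetched = datetime(2024, 5, 1, 9, 30)
anchor = last_fetched if cursor_end > last_fetched else cursor_end
start = anchor - timedelta(seconds=120)
print(start)  # 2024-05-01 09:28:00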
|
||||
def _get_max_fetched_at(self, table_name: str) -> datetime | None:
|
||||
try:
|
||||
rows = self.db.query(f"SELECT MAX(fetched_at) AS mx FROM {table_name}")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not rows or not rows[0].get("mx"):
|
||||
return None
|
||||
|
||||
mx = rows[0]["mx"]
|
||||
if not isinstance(mx, datetime):
|
||||
return None
|
||||
if mx.tzinfo is None:
|
||||
return mx.replace(tzinfo=self.tz)
|
||||
return mx.astimezone(self.tz)
|
||||
|
||||
def _build_params(
|
||||
self,
|
||||
spec: OdsTaskSpec,
|
||||
store_id: int,
|
||||
*,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> dict:
|
||||
base: dict[str, Any] = {}
|
||||
if spec.include_site_id:
|
||||
base["siteId"] = store_id
|
||||
# /TenantGoods/GetGoodsInventoryList 要求 siteId 为数组(标量会触发服务端异常,返回畸形状态行 HTTP/1.1 1400)
|
||||
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
|
||||
base["siteId"] = [store_id]
|
||||
else:
|
||||
base["siteId"] = store_id
|
||||
if spec.requires_window and spec.time_fields:
|
||||
window_start, window_end, _ = self._get_time_window()
|
||||
start_key, end_key = spec.time_fields
|
||||
base[start_key] = TypeParser.format_timestamp(window_start, self.tz)
|
||||
base[end_key] = TypeParser.format_timestamp(window_end, self.tz)
|
||||
@@ -145,109 +196,226 @@ class BaseOdsTask(BaseTask):
|
||||
params.update(spec.extra_params)
|
||||
return params
|
||||
|
||||
def _resolve_columns(self, spec: OdsTaskSpec) -> List[str]:
|
||||
columns: List[str] = []
|
||||
if spec.include_site_column:
|
||||
columns.append("site_id")
|
||||
seen = set(columns)
|
||||
for col_spec in list(spec.pk_columns) + list(spec.extra_columns):
|
||||
if col_spec.column not in seen:
|
||||
columns.append(col_spec.column)
|
||||
seen.add(col_spec.column)
|
||||
# ------------------------------------------------------------------ schema-aware ingest (ODS doc schema)
|
||||
def _get_table_columns(self, table: str) -> list[tuple[str, str, str]]:
|
||||
cache = getattr(self, "_table_columns_cache", {})
|
||||
if table in cache:
|
||||
return cache[table]
|
||||
if "." in table:
|
||||
schema, name = table.split(".", 1)
|
||||
else:
|
||||
schema, name = "public", table
|
||||
sql = """
|
||||
SELECT column_name, data_type, udt_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
with self.db.conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, name))
|
||||
cols = [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
|
||||
cache[table] = cols
|
||||
self._table_columns_cache = cache
|
||||
return cols
|
||||
|
||||
if spec.include_record_index and "record_index" not in seen:
|
||||
columns.append("record_index")
|
||||
seen.add("record_index")
|
||||
def _get_table_pk_columns(self, table: str) -> list[str]:
|
||||
cache = getattr(self, "_table_pk_cache", {})
|
||||
if table in cache:
|
||||
return cache[table]
|
||||
if "." in table:
|
||||
schema, name = table.split(".", 1)
|
||||
else:
|
||||
schema, name = "public", table
|
||||
sql = """
|
||||
SELECT kcu.column_name
|
||||
FROM information_schema.table_constraints tc
|
||||
JOIN information_schema.key_column_usage kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
WHERE tc.constraint_type = 'PRIMARY KEY'
|
||||
AND tc.table_schema = %s
|
||||
AND tc.table_name = %s
|
||||
ORDER BY kcu.ordinal_position
|
||||
"""
|
||||
with self.db.conn.cursor() as cur:
|
||||
cur.execute(sql, (schema, name))
|
||||
cols = [r[0] for r in cur.fetchall()]
|
||||
cache[table] = cols
|
||||
self._table_pk_cache = cache
|
||||
return cols
|
||||
|
||||
if spec.include_page_no and "page_no" not in seen:
|
||||
columns.append("page_no")
|
||||
seen.add("page_no")
|
||||
|
||||
if spec.include_page_size and "page_size" not in seen:
|
||||
columns.append("page_size")
|
||||
seen.add("page_size")
|
||||
|
||||
if spec.include_source_file and "source_file" not in seen:
|
||||
columns.append("source_file")
|
||||
seen.add("source_file")
|
||||
|
||||
if spec.include_source_endpoint and "source_endpoint" not in seen:
|
||||
columns.append("source_endpoint")
|
||||
seen.add("source_endpoint")
|
||||
|
||||
if spec.include_fetched_at and "fetched_at" not in seen:
|
||||
columns.append("fetched_at")
|
||||
seen.add("fetched_at")
|
||||
if "payload" not in seen:
|
||||
columns.append("payload")
|
||||
|
||||
return columns
|
||||
|
||||
def _build_row(
|
||||
def _insert_records_schema_aware(
|
||||
self,
|
||||
spec: OdsTaskSpec,
|
||||
store_id: int,
|
||||
record: dict,
|
||||
page_no: int | None,
|
||||
page_size_value: int | None,
|
||||
*,
|
||||
table: str,
|
||||
records: list,
|
||||
response_payload: dict | list | None,
|
||||
source_file: str | None,
|
||||
record_index: int | None = None,
|
||||
) -> dict | None:
|
||||
row: dict[str, Any] = {}
|
||||
if spec.include_site_column:
|
||||
row["site_id"] = store_id
|
||||
source_endpoint: str | None,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
按 DB 表结构动态写入 ODS(只插新数据:ON CONFLICT DO NOTHING)。
|
||||
返回 (inserted, skipped)。
|
||||
"""
|
||||
if not records:
|
||||
return 0, 0
|
||||
|
||||
for col_spec in spec.pk_columns + spec.extra_columns:
|
||||
value = self._extract_value(record, col_spec)
|
||||
if value is None and col_spec.required:
|
||||
self.logger.warning(
|
||||
"%s 缂哄皯蹇呭~瀛楁 %s锛屽師濮嬭褰? %s",
|
||||
spec.code,
|
||||
col_spec.column,
|
||||
record,
|
||||
)
|
||||
return None
|
||||
row[col_spec.column] = value
|
||||
cols_info = self._get_table_columns(table)
|
||||
if not cols_info:
|
||||
raise ValueError(f"Cannot resolve columns for table={table}")
|
||||
|
||||
if spec.include_page_no:
|
||||
row["page_no"] = page_no
|
||||
if spec.include_page_size:
|
||||
row["page_size"] = page_size_value
|
||||
if spec.include_record_index:
|
||||
row["record_index"] = record_index
|
||||
if spec.include_source_file:
|
||||
row["source_file"] = source_file
|
||||
if spec.include_source_endpoint:
|
||||
row["source_endpoint"] = spec.endpoint
|
||||
pk_cols = self._get_table_pk_columns(table)
|
||||
db_json_cols_lower = {
|
||||
c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
|
||||
}
|
||||
|
||||
if spec.include_fetched_at:
|
||||
row["fetched_at"] = datetime.now(self.tz)
|
||||
row["payload"] = record
|
||||
return row
|
||||
col_names = [c[0] for c in cols_info]
|
||||
quoted_cols = ", ".join(f'\"{c}\"' for c in col_names)
|
||||
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
|
||||
if pk_cols:
|
||||
pk_clause = ", ".join(f'\"{c}\"' for c in pk_cols)
|
||||
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"

    def _extract_value(self, record: dict, spec: ColumnSpec):
        value = None
        for key in spec.sources:
            value = self._dig(record, key)
            if value is not None:
                break
        if value is None and spec.default is not None:
            value = spec.default
        if value is not None and spec.transform:
            value = spec.transform(value)
        now = datetime.now(self.tz)
        json_dump = lambda v: json.dumps(v, ensure_ascii=False)  # noqa: E731

        params: list[tuple] = []
        skipped = 0

        root_site_profile = None
        if isinstance(response_payload, dict):
            data_part = response_payload.get("data")
            if isinstance(data_part, dict):
                sp = data_part.get("siteProfile") or data_part.get("site_profile")
                if isinstance(sp, dict):
                    root_site_profile = sp
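        # response_payload is the raw API envelope; when it carries a top-level
        # siteProfile (e.g. {"code": 0, "data": {"siteProfile": {...}, "settleList": [...]}},
        # shape assumed from the lookups above), that profile is kept as a
        # fallback for records that do not embed their own siteProfile.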

        for rec in records:
            if not isinstance(rec, dict):
                skipped += 1
                continue

            merged_rec = self._merge_record_layers(rec)
            if table in {"billiards_ods.recharge_settlements", "billiards_ods.settlement_records"}:
                site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile") or root_site_profile
                if isinstance(site_profile, dict):
                    # Do not overwrite camelCase fields that already exist (e.g. tenantId/siteId/siteName) with None
                    def _fill_missing(target_col: str, candidates: list[Any]):
                        existing = self._get_value_case_insensitive(merged_rec, target_col)
                        if existing not in (None, ""):
                            return
                        for cand in candidates:
                            if cand in (None, "", 0):
                                continue
                            merged_rec[target_col] = cand
                            return

                    _fill_missing("tenantid", [site_profile.get("tenant_id"), site_profile.get("tenantId")])
                    _fill_missing("siteid", [site_profile.get("siteId"), site_profile.get("id")])
                    _fill_missing("sitename", [site_profile.get("shop_name"), site_profile.get("siteName")])

            if pk_cols:
                missing_pk = False
                for pk in pk_cols:
                    pk_val = self._get_value_case_insensitive(merged_rec, pk)
                    if pk_val is None or pk_val == "":
                        missing_pk = True
                        break
                if missing_pk:
                    skipped += 1
                    continue

            row_vals: list[Any] = []
            for (col_name, data_type, _udt) in cols_info:
                col_lower = col_name.lower()
                if col_lower == "payload":
                    row_vals.append(Json(rec, dumps=json_dump))
                    continue
                if col_lower == "source_file":
                    row_vals.append(source_file)
                    continue
                if col_lower == "source_endpoint":
                    row_vals.append(source_endpoint)
                    continue
                if col_lower == "fetched_at":
                    row_vals.append(now)
                    continue

                value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
                if col_lower in db_json_cols_lower:
                    row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                    continue

                row_vals.append(self._cast_value(value, data_type))

            params.append(tuple(row_vals))

        if not params:
            return 0, skipped

        inserted = 0
        chunk_size = int(self.config.get("run.ods_execute_values_page_size", 200) or 200)
        chunk_size = max(1, min(chunk_size, 2000))
        with self.db.conn.cursor() as cur:
            for i in range(0, len(params), chunk_size):
                chunk = params[i : i + chunk_size]
                execute_values(cur, sql, chunk, page_size=len(chunk))
                if cur.rowcount is not None and cur.rowcount > 0:
                    inserted += int(cur.rowcount)
        return inserted, skipped
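# A hedged, standalone sketch of the chunked execute_values() pattern used
# above; the connection, SQL text and params list are assumed to exist and the
# function name is illustrative, not part of this module.
from psycopg2.extras import execute_values

def insert_in_chunks(conn, sql: str, params: list[tuple], chunk_size: int = 200) -> int:
    """Insert row tuples in bounded batches and return the affected row count."""
    chunk_size = max(1, min(chunk_size, 2000))
    inserted = 0
    with conn.cursor() as cur:
        for i in range(0, len(params), chunk_size):
            chunk = params[i : i + chunk_size]
            # page_size=len(chunk) keeps each call to a single multi-row statement
            execute_values(cur, sql, chunk, page_size=len(chunk))
            if cur.rowcount and cur.rowcount > 0:
                inserted += int(cur.rowcount)
    return inserted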

    @staticmethod
    def _merge_record_layers(record: dict) -> dict:
        merged = record
        data_part = merged.get("data")
        while isinstance(data_part, dict):
            merged = {**data_part, **merged}
            data_part = data_part.get("data")
        settle_inner = merged.get("settleList")
        if isinstance(settle_inner, dict):
            merged = {**settle_inner, **merged}
        return merged
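# A runnable, standalone illustration of the merging rule above: nested "data"
# layers are folded upward, an inner "settleList" dict is flattened, and outer
# keys always win on conflict. The sample payload is invented.
def merge_record_layers(record: dict) -> dict:
    merged = record
    data_part = merged.get("data")
    while isinstance(data_part, dict):
        merged = {**data_part, **merged}
        data_part = data_part.get("data")
    settle_inner = merged.get("settleList")
    if isinstance(settle_inner, dict):
        merged = {**settle_inner, **merged}
    return merged

sample = {"code": 0, "data": {"settleList": {"id": 7, "payAmount": 12.5}, "data": {"siteId": 3}}}
flat = merge_record_layers(sample)
print(flat["id"], flat["payAmount"], flat["siteId"], flat["code"])  # 7 12.5 3 0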

    @staticmethod
    def _get_value_case_insensitive(record: dict | None, col: str | None):
        if record is None or col is None:
            return None
        if col in record:
            return record.get(col)
        col_lower = col.lower()
        for k, v in record.items():
            if isinstance(k, str) and k.lower() == col_lower:
                return v
        return None

    @staticmethod
    def _normalize_scalar(value):
        if value == "" or value == "{}" or value == "[]":
            return None
        return value
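# A runnable, standalone illustration of the two lookups above: column names
# are matched case-insensitively against API keys, and empty strings / "{}" /
# "[]" are normalized to NULL before casting. Sample values are invented.
def get_value_case_insensitive(record: dict, col: str):
    if col in record:
        return record[col]
    col_lower = col.lower()
    for k, v in record.items():
        if isinstance(k, str) and k.lower() == col_lower:
            return v
    return None

def normalize_scalar(value):
    if value == "" or value == "{}" or value == "[]":
        return None
    return value

rec = {"tenantId": 11, "memberName": "", "extra": "{}"}
print(get_value_case_insensitive(rec, "tenantid"))                       # 11
print(normalize_scalar(get_value_case_insensitive(rec, "membername")))   # None
print(normalize_scalar(get_value_case_insensitive(rec, "extra")))        # None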

    @staticmethod
    def _dig(record: Any, path: str | None):
        if not path:
    def _cast_value(value, data_type: str):
        if value is None:
            return None
        current = record
        for part in path.split("."):
            if isinstance(current, dict):
                current = current.get(part)
            else:
        dt = (data_type or "").lower()
        if dt in ("integer", "bigint", "smallint"):
            if isinstance(value, bool):
                return int(value)
            try:
                return int(value)
            except Exception:
                return None
        return current
        if dt in ("numeric", "double precision", "real", "decimal"):
            if isinstance(value, bool):
                return int(value)
            try:
                return float(value)
            except Exception:
                return None
        if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
            return value if isinstance(value, (str, datetime)) else None
        return value
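    # Casting behaviour of the branches above (sample values invented):
    #   _cast_value("12", "bigint")                                  -> 12
    #   _cast_value("abc", "integer")                                -> None  (unparseable)
    #   _cast_value(True, "numeric")                                 -> 1     (bools become ints)
    #   _cast_value("2024-01-01 10:00:00", "timestamp with time zone") -> the string unchanged
    #   _cast_value({"k": 1}, "timestamp with time zone")            -> None  (only str/datetime pass)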

    def _resolve_source_file_hint(self, spec: OdsTaskSpec) -> str | None:
        resolver = getattr(self.api, "get_source_hint", None)
@@ -319,15 +487,16 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
        endpoint="/Site/GetAllOrderSettleList",
        data_path=("data",),
        list_key="settleList",
        time_fields=("rangeStartTime", "rangeEndTime"),
        pk_columns=(),
        include_site_column=False,
        include_source_endpoint=False,
        include_source_endpoint=True,
        include_page_no=False,
        include_page_size=False,
        include_fetched_at=False,
        include_record_index=True,
        conflict_columns_override=("source_file", "record_index"),
        requires_window=False,
        requires_window=True,
        description="Settlement records ODS: GetAllOrderSettleList -> settleList raw JSON",
    ),
    OdsTaskSpec(
@@ -512,6 +681,7 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
        endpoint="/Site/GetRechargeSettleList",
        data_path=("data",),
        list_key="settleList",
        time_fields=("rangeStartTime", "rangeEndTime"),
        pk_columns=(_int_col("recharge_order_id", "settleList.id", "id", required=True),),
        extra_columns=(
            _int_col("tenant_id", "settleList.tenantId", "tenantId"),
@@ -583,7 +753,7 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
        include_fetched_at=True,
        include_record_index=False,
        conflict_columns_override=None,
        requires_window=False,
        requires_window=True,
        description="Recharge settlement records ODS: GetRechargeSettleList -> data.settleList raw JSON",
    ),

@@ -800,12 +970,6 @@ class OdsSettlementTicketTask(BaseOdsTask):
        store_id = TypeParser.parse_int(self.config.get("app.store_id")) or 0

        counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
        loader = GenericODSLoader(
            self.db,
            spec.table_name,
            self._resolve_columns(spec),
            list(spec.conflict_columns_override or ("source_file", "record_index")),
        )
        source_file = self._resolve_source_file_hint(spec)

        try:
@@ -823,39 +987,43 @@ class OdsSettlementTicketTask(BaseOdsTask):
                    context.window_start,
                    context.window_end,
                )
            return self._build_result("SUCCESS", counts)
            result = self._build_result("SUCCESS", counts)
            result["window"] = {
                "start": context.window_start,
                "end": context.window_end,
                "minutes": context.window_minutes,
            }
            result["request_params"] = {"candidates": 0}
            return result

            payloads, skipped = self._fetch_ticket_payloads(candidates)
            counts["skipped"] += skipped
            rows: list[dict] = []
            for idx, payload in enumerate(payloads):
                row = self._build_row(
                    spec=spec,
                    store_id=store_id,
                    record=payload,
                    page_no=None,
                    page_size_value=None,
                    source_file=source_file,
                    record_index=idx if spec.include_record_index else None,
                )
                if row is None:
                    counts["skipped"] += 1
                    continue
                rows.append(row)

            inserted, updated, _ = loader.upsert_rows(rows)
            inserted, skipped2 = self._insert_records_schema_aware(
                table=spec.table_name,
                records=payloads,
                response_payload=None,
                source_file=source_file,
                source_endpoint=spec.endpoint,
            )
            counts["inserted"] += inserted
            counts["updated"] += updated
            counts["skipped"] += skipped2
            self.db.commit()
            self.logger.info(
                "%s: ticket fetch finished, candidates=%s inserted=%s updated=%s skipped=%s",
                spec.code,
                len(candidates),
                inserted,
                updated,
                0,
                counts["skipped"],
            )
            return self._build_result("SUCCESS", counts)
            result = self._build_result("SUCCESS", counts)
            result["window"] = {
                "start": context.window_start,
                "end": context.window_end,
                "minutes": context.window_minutes,
            }
            result["request_params"] = {"candidates": len(candidates)}
            return result
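            # As in the empty-candidates branch above, the success result now
            # reports the effective window ({"start", "end", "minutes"}) plus a
            # small request summary; here "candidates" counts the settle
            # candidates actually processed.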

        except Exception:
            counts["errors"] += 1
@@ -1026,4 +1194,3 @@ ODS_TASK_CLASSES: Dict[str, Type[BaseOdsTask]] = {
ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"] = OdsSettlementTicketTask

__all__ = ["ODS_TASK_CLASSES", "ODS_TASK_SPECS", "BaseOdsTask", "ENABLED_ODS_CODES"]