ETL complete

Neo
2026-01-18 22:37:38 +08:00
parent 8da6cb6563
commit 7ca19a4a2c
159 changed files with 31225 additions and 467 deletions


@@ -2,11 +2,13 @@
"""ODS ingestion tasks."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Type
from loaders.ods import GenericODSLoader
from psycopg2.extras import Json, execute_values
from models.parsers import TypeParser
from .base_task import BaseTask
@@ -60,70 +62,61 @@ class BaseOdsTask(BaseTask):
def get_task_code(self) -> str:
return self.SPEC.code
def execute(self) -> dict:
def execute(self, cursor_data: dict | None = None) -> dict:
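# cursor_data (optional) carries the previous run's incremental-window state (e.g. "last_end"); it feeds _resolve_window below.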
spec = self.SPEC
self.logger.info("寮€濮嬫墽琛?%s (ODS)", spec.code)
window_start, window_end, window_minutes = self._resolve_window(cursor_data)
store_id = TypeParser.parse_int(self.config.get("app.store_id"))
if not store_id:
raise ValueError("app.store_id 鏈厤缃紝鏃犳硶鎵ц ODS 浠诲姟")
page_size = self.config.get("api.page_size", 200)
params = self._build_params(spec, store_id)
columns = self._resolve_columns(spec)
if spec.conflict_columns_override:
conflict_columns = list(spec.conflict_columns_override)
else:
conflict_columns = []
if spec.include_site_column:
conflict_columns.append("site_id")
conflict_columns += [col.column for col in spec.pk_columns]
loader = GenericODSLoader(
self.db,
spec.table_name,
columns,
conflict_columns,
params = self._build_params(
spec,
store_id,
window_start=window_start,
window_end=window_end,
)
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
source_file = self._resolve_source_file_hint(spec)
try:
global_index = 0
for page_no, page_records, _, _ in self.api.iter_paginated(
for _, page_records, _, response_payload in self.api.iter_paginated(
endpoint=spec.endpoint,
params=params,
page_size=page_size,
data_path=spec.data_path,
list_key=spec.list_key,
):
rows: List[dict] = []
for raw in page_records:
row = self._build_row(
spec=spec,
store_id=store_id,
record=raw,
page_no=page_no if spec.include_page_no else None,
page_size_value=len(page_records)
if spec.include_page_size
else None,
source_file=source_file,
record_index=global_index if spec.include_record_index else None,
)
if row is None:
counts["skipped"] += 1
continue
rows.append(row)
global_index += 1
inserted, updated, _ = loader.upsert_rows(rows)
counts["inserted"] += inserted
counts["updated"] += updated
inserted, skipped = self._insert_records_schema_aware(
table=spec.table_name,
records=page_records,
response_payload=response_payload,
source_file=source_file,
source_endpoint=spec.endpoint if spec.include_source_endpoint else None,
)
counts["fetched"] += len(page_records)
counts["inserted"] += inserted
counts["skipped"] += skipped
self.db.commit()
self.logger.info("%s ODS 浠诲姟瀹屾垚: %s", spec.code, counts)
return self._build_result("SUCCESS", counts)
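# An empty fetch is reported as PARTIAL unless run.allow_empty_result_advance is enabled, presumably so the window cursor is not advanced when nothing was ingested.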
allow_empty_advance = bool(self.config.get("run.allow_empty_result_advance", False))
status = "SUCCESS"
if counts["fetched"] == 0 and not allow_empty_advance:
status = "PARTIAL"
result = self._build_result(status, counts)
result["window"] = {
"start": window_start,
"end": window_end,
"minutes": window_minutes,
}
result["request_params"] = params
return result
except Exception:
self.db.rollback()
@@ -131,12 +124,70 @@ class BaseOdsTask(BaseTask):
self.logger.error("%s ODS task failed", spec.code, exc_info=True)
raise
def _build_params(self, spec: OdsTaskSpec, store_id: int) -> dict:
def _resolve_window(self, cursor_data: dict | None) -> tuple[datetime, datetime, int]:
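# Prefer anchoring the window on actual ingest progress (MAX(fetched_at), minus a small overlap) rather than on the stored cursor; fall back to the default window otherwise.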
base_start, base_end, base_minutes = self._get_time_window(cursor_data)
if self.config.get("run.force_window_override"):
override_start = self.config.get("run.window_override.start")
override_end = self.config.get("run.window_override.end")
if override_start and override_end:
return base_start, base_end, base_minutes
# Fall back to MAX(fetched_at) from the ODS table: avoids losing data when the window cursor advanced but rows were never actually written.
last_fetched = self._get_max_fetched_at(self.SPEC.table_name)
if last_fetched:
overlap_seconds = int(self.config.get("run.overlap_seconds", 120) or 120)
cursor_end = cursor_data.get("last_end") if isinstance(cursor_data, dict) else None
anchor = cursor_end or last_fetched
# If cursor_end is later than the actual ingest time (last_fetched), the cursor was advanced but the table did not keep up: use last_fetched as the start anchor instead.
if isinstance(cursor_end, datetime) and cursor_end.tzinfo is None:
cursor_end = cursor_end.replace(tzinfo=self.tz)
if isinstance(cursor_end, datetime) and cursor_end > last_fetched:
anchor = last_fetched
start = anchor - timedelta(seconds=max(0, overlap_seconds))
if start.tzinfo is None:
start = start.replace(tzinfo=self.tz)
else:
start = start.astimezone(self.tz)
end = datetime.now(self.tz)
minutes = max(1, int((end - start).total_seconds() // 60))
return start, end, minutes
return base_start, base_end, base_minutes
def _get_max_fetched_at(self, table_name: str) -> datetime | None:
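# Latest fetched_at in the target ODS table, normalized to the task timezone; None when the table is empty or the query fails.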
try:
rows = self.db.query(f"SELECT MAX(fetched_at) AS mx FROM {table_name}")
except Exception:
return None
if not rows or not rows[0].get("mx"):
return None
mx = rows[0]["mx"]
if not isinstance(mx, datetime):
return None
if mx.tzinfo is None:
return mx.replace(tzinfo=self.tz)
return mx.astimezone(self.tz)
def _build_params(
self,
spec: OdsTaskSpec,
store_id: int,
*,
window_start: datetime,
window_end: datetime,
) -> dict:
base: dict[str, Any] = {}
if spec.include_site_id:
base["siteId"] = store_id
# /TenantGoods/GetGoodsInventoryList requires siteId to be an array (a scalar triggers a server-side error that returns the malformed status line "HTTP/1.1 1400").
if spec.endpoint == "/TenantGoods/GetGoodsInventoryList":
base["siteId"] = [store_id]
else:
base["siteId"] = store_id
if spec.requires_window and spec.time_fields:
window_start, window_end, _ = self._get_time_window()
start_key, end_key = spec.time_fields
base[start_key] = TypeParser.format_timestamp(window_start, self.tz)
base[end_key] = TypeParser.format_timestamp(window_end, self.tz)
@@ -145,109 +196,226 @@ class BaseOdsTask(BaseTask):
params.update(spec.extra_params)
return params
def _resolve_columns(self, spec: OdsTaskSpec) -> List[str]:
columns: List[str] = []
if spec.include_site_column:
columns.append("site_id")
seen = set(columns)
for col_spec in list(spec.pk_columns) + list(spec.extra_columns):
if col_spec.column not in seen:
columns.append(col_spec.column)
seen.add(col_spec.column)
# ------------------------------------------------------------------ schema-aware ingest (ODS doc schema)
def _get_table_columns(self, table: str) -> list[tuple[str, str, str]]:
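# (column_name, data_type, udt_name) tuples in ordinal order, read from information_schema and cached per table.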
cache = getattr(self, "_table_columns_cache", {})
if table in cache:
return cache[table]
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = %s AND table_name = %s
ORDER BY ordinal_position
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]
cache[table] = cols
self._table_columns_cache = cache
return cols
if spec.include_record_index and "record_index" not in seen:
columns.append("record_index")
seen.add("record_index")
def _get_table_pk_columns(self, table: str) -> list[str]:
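# Primary-key column names in ordinal order, read from information_schema and cached per table.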
cache = getattr(self, "_table_pk_cache", {})
if table in cache:
return cache[table]
if "." in table:
schema, name = table.split(".", 1)
else:
schema, name = "public", table
sql = """
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = %s
AND tc.table_name = %s
ORDER BY kcu.ordinal_position
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (schema, name))
cols = [r[0] for r in cur.fetchall()]
cache[table] = cols
self._table_pk_cache = cache
return cols
if spec.include_page_no and "page_no" not in seen:
columns.append("page_no")
seen.add("page_no")
if spec.include_page_size and "page_size" not in seen:
columns.append("page_size")
seen.add("page_size")
if spec.include_source_file and "source_file" not in seen:
columns.append("source_file")
seen.add("source_file")
if spec.include_source_endpoint and "source_endpoint" not in seen:
columns.append("source_endpoint")
seen.add("source_endpoint")
if spec.include_fetched_at and "fetched_at" not in seen:
columns.append("fetched_at")
seen.add("fetched_at")
if "payload" not in seen:
columns.append("payload")
return columns
def _build_row(
def _insert_records_schema_aware(
self,
spec: OdsTaskSpec,
store_id: int,
record: dict,
page_no: int | None,
page_size_value: int | None,
*,
table: str,
records: list,
response_payload: dict | list | None,
source_file: str | None,
record_index: int | None = None,
) -> dict | None:
row: dict[str, Any] = {}
if spec.include_site_column:
row["site_id"] = store_id
source_endpoint: str | None,
) -> tuple[int, int]:
"""
Write to the ODS table dynamically based on the live DB schema; only new rows are inserted (ON CONFLICT DO NOTHING).
Returns (inserted, skipped).
"""
if not records:
return 0, 0
for col_spec in spec.pk_columns + spec.extra_columns:
value = self._extract_value(record, col_spec)
if value is None and col_spec.required:
self.logger.warning(
"%s missing required field %s, raw record: %s",
spec.code,
col_spec.column,
record,
)
return None
row[col_spec.column] = value
cols_info = self._get_table_columns(table)
if not cols_info:
raise ValueError(f"Cannot resolve columns for table={table}")
if spec.include_page_no:
row["page_no"] = page_no
if spec.include_page_size:
row["page_size"] = page_size_value
if spec.include_record_index:
row["record_index"] = record_index
if spec.include_source_file:
row["source_file"] = source_file
if spec.include_source_endpoint:
row["source_endpoint"] = spec.endpoint
pk_cols = self._get_table_pk_columns(table)
db_json_cols_lower = {
c[0].lower() for c in cols_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
}
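# json/jsonb columns are wrapped with psycopg2 Json; all other values go through _cast_value.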
if spec.include_fetched_at:
row["fetched_at"] = datetime.now(self.tz)
row["payload"] = record
return row
col_names = [c[0] for c in cols_info]
quoted_cols = ", ".join(f'\"{c}\"' for c in col_names)
sql = f"INSERT INTO {table} ({quoted_cols}) VALUES %s"
if pk_cols:
pk_clause = ", ".join(f'\"{c}\"' for c in pk_cols)
sql += f" ON CONFLICT ({pk_clause}) DO NOTHING"
def _extract_value(self, record: dict, spec: ColumnSpec):
value = None
for key in spec.sources:
value = self._dig(record, key)
if value is not None:
break
if value is None and spec.default is not None:
value = spec.default
if value is not None and spec.transform:
value = spec.transform(value)
now = datetime.now(self.tz)
json_dump = lambda v: json.dumps(v, ensure_ascii=False) # noqa: E731
params: list[tuple] = []
skipped = 0
root_site_profile = None
if isinstance(response_payload, dict):
data_part = response_payload.get("data")
if isinstance(data_part, dict):
sp = data_part.get("siteProfile") or data_part.get("site_profile")
if isinstance(sp, dict):
root_site_profile = sp
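# Some endpoints return siteProfile once at the data level; keep it so per-row tenant/site fields can be backfilled below.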
for rec in records:
if not isinstance(rec, dict):
skipped += 1
continue
merged_rec = self._merge_record_layers(rec)
if table in {"billiards_ods.recharge_settlements", "billiards_ods.settlement_records"}:
site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile") or root_site_profile
if isinstance(site_profile, dict):
# Avoid overwriting existing camelCase fields (e.g. tenantId/siteId/siteName) with None.
def _fill_missing(target_col: str, candidates: list[Any]):
existing = self._get_value_case_insensitive(merged_rec, target_col)
if existing not in (None, ""):
return
for cand in candidates:
if cand in (None, "", 0):
continue
merged_rec[target_col] = cand
return
_fill_missing("tenantid", [site_profile.get("tenant_id"), site_profile.get("tenantId")])
_fill_missing("siteid", [site_profile.get("siteId"), site_profile.get("id")])
_fill_missing("sitename", [site_profile.get("shop_name"), site_profile.get("siteName")])
if pk_cols:
missing_pk = False
for pk in pk_cols:
pk_val = self._get_value_case_insensitive(merged_rec, pk)
if pk_val is None or pk_val == "":
missing_pk = True
break
if missing_pk:
skipped += 1
continue
row_vals: list[Any] = []
for (col_name, data_type, _udt) in cols_info:
col_lower = col_name.lower()
if col_lower == "payload":
row_vals.append(Json(rec, dumps=json_dump))
continue
if col_lower == "source_file":
row_vals.append(source_file)
continue
if col_lower == "source_endpoint":
row_vals.append(source_endpoint)
continue
if col_lower == "fetched_at":
row_vals.append(now)
continue
value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
if col_lower in db_json_cols_lower:
row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
continue
row_vals.append(self._cast_value(value, data_type))
params.append(tuple(row_vals))
if not params:
return 0, skipped
inserted = 0
chunk_size = int(self.config.get("run.ods_execute_values_page_size", 200) or 200)
chunk_size = max(1, min(chunk_size, 2000))
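# Insert in bounded chunks; with ON CONFLICT DO NOTHING, rowcount only counts rows that were actually inserted.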
with self.db.conn.cursor() as cur:
for i in range(0, len(params), chunk_size):
chunk = params[i : i + chunk_size]
execute_values(cur, sql, chunk, page_size=len(chunk))
if cur.rowcount is not None and cur.rowcount > 0:
inserted += int(cur.rowcount)
return inserted, skipped
@staticmethod
def _merge_record_layers(record: dict) -> dict:
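# Flatten nested "data" wrappers (and a dict-valued "settleList") into a single record; outer keys win on key collisions.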
merged = record
data_part = merged.get("data")
while isinstance(data_part, dict):
merged = {**data_part, **merged}
data_part = data_part.get("data")
settle_inner = merged.get("settleList")
if isinstance(settle_inner, dict):
merged = {**settle_inner, **merged}
return merged
@staticmethod
def _get_value_case_insensitive(record: dict | None, col: str | None):
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
@staticmethod
def _normalize_scalar(value):
if value == "" or value == "{}" or value == "[]":
return None
return value
@staticmethod
def _dig(record: Any, path: str | None):
if not path:
def _cast_value(value, data_type: str):
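# Best-effort cast of an API scalar to the DB column type; values that cannot be parsed become NULL instead of failing the batch.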
if value is None:
return None
current = record
for part in path.split("."):
if isinstance(current, dict):
current = current.get(part)
else:
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
return current
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, (str, datetime)) else None
return value
def _resolve_source_file_hint(self, spec: OdsTaskSpec) -> str | None:
resolver = getattr(self.api, "get_source_hint", None)
@@ -319,15 +487,16 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
endpoint="/Site/GetAllOrderSettleList",
data_path=("data",),
list_key="settleList",
time_fields=("rangeStartTime", "rangeEndTime"),
pk_columns=(),
include_site_column=False,
include_source_endpoint=False,
include_source_endpoint=True,
include_page_no=False,
include_page_size=False,
include_fetched_at=False,
include_record_index=True,
conflict_columns_override=("source_file", "record_index"),
requires_window=False,
requires_window=True,
description="缁撹处璁板綍 ODS锛欸etAllOrderSettleList -> settleList 鍘熷 JSON",
),
OdsTaskSpec(
@@ -512,6 +681,7 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
endpoint="/Site/GetRechargeSettleList",
data_path=("data",),
list_key="settleList",
time_fields=("rangeStartTime", "rangeEndTime"),
pk_columns=(_int_col("recharge_order_id", "settleList.id", "id", required=True),),
extra_columns=(
_int_col("tenant_id", "settleList.tenantId", "tenantId"),
@@ -583,7 +753,7 @@ ODS_TASK_SPECS: Tuple[OdsTaskSpec, ...] = (
include_fetched_at=True,
include_record_index=False,
conflict_columns_override=None,
requires_window=False,
requires_window=True,
description="?????? ODS?GetRechargeSettleList -> data.settleList ????",
),
@@ -800,12 +970,6 @@ class OdsSettlementTicketTask(BaseOdsTask):
store_id = TypeParser.parse_int(self.config.get("app.store_id")) or 0
counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
loader = GenericODSLoader(
self.db,
spec.table_name,
self._resolve_columns(spec),
list(spec.conflict_columns_override or ("source_file", "record_index")),
)
source_file = self._resolve_source_file_hint(spec)
try:
@@ -823,39 +987,43 @@ class OdsSettlementTicketTask(BaseOdsTask):
context.window_start,
context.window_end,
)
return self._build_result("SUCCESS", counts)
result = self._build_result("SUCCESS", counts)
result["window"] = {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
}
result["request_params"] = {"candidates": 0}
return result
payloads, skipped = self._fetch_ticket_payloads(candidates)
counts["skipped"] += skipped
rows: list[dict] = []
for idx, payload in enumerate(payloads):
row = self._build_row(
spec=spec,
store_id=store_id,
record=payload,
page_no=None,
page_size_value=None,
source_file=source_file,
record_index=idx if spec.include_record_index else None,
)
if row is None:
counts["skipped"] += 1
continue
rows.append(row)
inserted, updated, _ = loader.upsert_rows(rows)
inserted, skipped2 = self._insert_records_schema_aware(
table=spec.table_name,
records=payloads,
response_payload=None,
source_file=source_file,
source_endpoint=spec.endpoint,
)
counts["inserted"] += inserted
counts["updated"] += updated
counts["skipped"] += skipped2
self.db.commit()
self.logger.info(
"%s: settlement ticket fetch completed, candidates=%s inserted=%s updated=%s skipped=%s",
spec.code,
len(candidates),
inserted,
updated,
0,
counts["skipped"],
)
return self._build_result("SUCCESS", counts)
result = self._build_result("SUCCESS", counts)
result["window"] = {
"start": context.window_start,
"end": context.window_end,
"minutes": context.window_minutes,
}
result["request_params"] = {"candidates": len(candidates)}
return result
except Exception:
counts["errors"] += 1
@@ -1026,4 +1194,3 @@ ODS_TASK_CLASSES: Dict[str, Type[BaseOdsTask]] = {
ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"] = OdsSettlementTicketTask
__all__ = ["ODS_TASK_CLASSES", "ODS_TASK_SPECS", "BaseOdsTask", "ENABLED_ODS_CODES"]