Merge
@@ -4,6 +4,8 @@ from __future__ import annotations
 import json
 import os
+import re
+from types import SimpleNamespace
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
@@ -99,16 +101,87 @@ def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
 class FakeCursor:
     """Minimal cursor stub that records SQL/params and supports context management; used by FakeDBOperations and SCD2Handler."""

-    def __init__(self, recorder: List[Dict]):
+    def __init__(self, recorder: List[Dict], db_ops=None):
         self.recorder = recorder
+        self._db_ops = db_ops
         self._pending_rows: List[Tuple] = []
+        self._fetchall_rows: List[Tuple] = []
         self.rowcount = 0
+        self.connection = SimpleNamespace(encoding="UTF8")

     # pylint: disable=unused-argument
     def execute(self, sql: str, params=None):
-        self.recorder.append({"sql": sql.strip(), "params": params})
+        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
+        self.recorder.append({"sql": sql_text.strip(), "params": params})
+        self._fetchall_rows = []
+
+        # Handle information_schema queries for schema-aware inserts.
+        lowered = sql_text.lower()
+        if "from information_schema.columns" in lowered:
+            table_name = None
+            if params and len(params) >= 2:
+                table_name = params[1]
+            self._fetchall_rows = self._fake_columns(table_name)
+            return
+        if "from information_schema.table_constraints" in lowered:
+            self._fetchall_rows = []
+            return
+
         if self._pending_rows:
             self.rowcount = len(self._pending_rows)
+            self._record_upserts(sql_text)
             self._pending_rows = []
         else:
             self.rowcount = 0

     def fetchone(self):
         return None

+    def fetchall(self):
+        return list(self._fetchall_rows)
+
     def mogrify(self, template, args):
         self._pending_rows.append(tuple(args))
         return b"(?)"

+    def _record_upserts(self, sql_text: str):
+        if not self._db_ops:
+            return
+        match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
+        if not match:
+            return
+        columns = [c.strip().strip('"') for c in match.group(1).split(",")]
+        rows = []
+        for idx, row in enumerate(self._pending_rows):
+            if len(row) != len(columns):
+                continue
+            row_dict = {}
+            for col, val in zip(columns, row):
+                if col == "record_index" and val in (None, ""):
+                    row_dict[col] = idx
+                    continue
+                if hasattr(val, "adapted"):
+                    row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
+                else:
+                    row_dict[col] = val
+            rows.append(row_dict)
+        if rows:
+            self._db_ops.upserts.append(
+                {"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
+            )
+
+    @staticmethod
+    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
+        return [
+            ("id", "bigint", "int8"),
+            ("sitegoodsstockid", "bigint", "int8"),
+            ("record_index", "integer", "int4"),
+            ("source_file", "text", "text"),
+            ("source_endpoint", "text", "text"),
+            ("fetched_at", "timestamp with time zone", "timestamptz"),
+            ("payload", "jsonb", "jsonb"),
+        ]
+
     def __enter__(self):
         return self
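A note on why recording rows in mogrify is enough: psycopg2's extras.execute_values renders each row through cursor.mogrify(template, row) and then sends one combined INSERT via a single cursor.execute(...). The stub exploits that contract, stashing rows during mogrify and flushing them as an upsert batch on the next execute. A self-contained sketch of that flow (assuming psycopg2's documented behavior; MiniCursor and fake_execute_values are illustrative names, not repo code):

# Minimal model of the execute_values() contract the stub relies on:
# one mogrify() call per row, then a single execute() for the whole batch.

class MiniCursor:
    def __init__(self):
        self._pending = []   # rows seen via mogrify, awaiting the batch INSERT
        self.batches = []    # (sql, row_count) pairs observed in execute

    def mogrify(self, template, row):
        self._pending.append(tuple(row))
        return b"(?)"        # placeholder fragment, as in the FakeCursor above

    def execute(self, sql, params=None):
        self.batches.append((sql, len(self._pending)))
        self._pending = []


def fake_execute_values(cur, sql, rows):
    # Same shape as psycopg2.extras.execute_values: render each row, then
    # splice the fragments into the VALUES clause and execute once.
    fragments = b",".join(cur.mogrify("(%s)", row) for row in rows)
    cur.execute(sql.replace("%s", fragments.decode("utf-8")))


cur = MiniCursor()
fake_execute_values(cur, "INSERT INTO ods.demo (id) VALUES %s", [(1,), (2,), (3,)])
assert cur.batches == [("INSERT INTO ods.demo (id) VALUES (?),(?),(?)", 3)]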
@@ -119,11 +192,12 @@ class FakeCursor:
 class FakeConnection:
     """Mimics a psycopg connection; satisfies only SCD2Handler's minimal cursor needs and caches executed statements."""

-    def __init__(self):
+    def __init__(self, db_ops):
         self.statements: List[Dict] = []
+        self._db_ops = db_ops

     def cursor(self):
-        return FakeCursor(self.statements)
+        return FakeCursor(self.statements, self._db_ops)


 class FakeDBOperations:
@@ -134,7 +208,7 @@ class FakeDBOperations:
         self.executes: List[Dict] = []
         self.commits = 0
         self.rollbacks = 0
-        self.conn = FakeConnection()
+        self.conn = FakeConnection(self)
         # Pre-seeded query results (FIFO) to let tests control DB-returned rows
         self.query_results: List[List[Dict]] = []
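Taken together, the wiring gives tests a one-way reporting channel: FakeDBOperations owns the FakeConnection, which hands each FakeCursor a back-reference, so _record_upserts can append into the operations object. A hedged usage sketch of that round trip (it assumes a no-argument FakeDBOperations constructor and that it also initializes the .upserts list that _record_upserts appends to, neither of which is shown in this hunk):

# Sketch of the intended round trip, using the fakes defined above.
db_ops = FakeDBOperations()                    # owns conn and (assumed) .upserts
cur = db_ops.conn.cursor()                     # FakeCursor carrying the back-reference
cur.mogrify("(%s)", (None, "f.json", "{}"))    # stage one row; record_index left empty
cur.execute('INSERT INTO ods.t ("record_index", "source_file", "payload") VALUES %s')
assert db_ops.upserts[0]["rows"][0]["record_index"] == 0   # backfilled by enumerate()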
@@ -6,7 +6,7 @@ from config.defaults import DEFAULTS

 def test_config_load():
     """Test config loading."""
-    config = AppConfig.load()
+    config = AppConfig.load({"app": {"store_id": 1}})
     assert config.get("app.timezone") == DEFAULTS["app"]["timezone"]


 def test_config_override():
@@ -19,6 +19,6 @@ def test_config_override():

 def test_config_get_nested():
     """Test nested config lookup."""
-    config = AppConfig.load()
+    config = AppConfig.load({"app": {"store_id": 1}})
     assert config.get("db.batch_size") == 1000
     assert config.get("nonexistent.key", "default") == "default"
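Both tests now construct the config with an explicit {"app": {"store_id": 1}} override, which suggests store_id became a required input to AppConfig.load. The assertions also rely on dotted-path lookup with a default fallback; a hypothetical sketch of that behavior (the real AppConfig.get may differ, and dotted_get is an illustrative name):

def dotted_get(data: dict, path: str, default=None):
    # Walk one mapping level per dot-separated key; fall back to the default
    # as soon as a segment is missing or the current node is not a mapping.
    node = data
    for key in path.split("."):
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node

assert dotted_get({"db": {"batch_size": 1000}}, "db.batch_size") == 1000
assert dotted_get({}, "nonexistent.key", "default") == "default"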
etl_billiards/tests/unit/test_endpoint_routing.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+"""Unit tests for recent/former endpoint routing."""
+
+import sys
+from datetime import datetime
+from pathlib import Path
+from zoneinfo import ZoneInfo
+
+import pytest
+
+PROJECT_ROOT = Path(__file__).resolve().parents[2]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from api.endpoint_routing import plan_calls, recent_boundary
+
+
+TZ = ZoneInfo("Asia/Shanghai")
+
+
+def _now():
+    return datetime(2025, 12, 18, 10, 0, 0, tzinfo=TZ)
+
+
+def test_recent_boundary_month_start():
+    b = recent_boundary(_now())
+    assert b.isoformat() == "2025-09-01T00:00:00+08:00"
+
+
+def test_paylog_routes_to_former_when_old_window():
+    params = {"siteId": 1, "StartPayTime": "2025-08-01 00:00:00", "EndPayTime": "2025-08-02 00:00:00"}
+    calls = plan_calls("/PayLog/GetPayLogListPage", params, now=_now(), tz=TZ)
+    assert [c.endpoint for c in calls] == ["/PayLog/GetFormerPayLogListPage"]
+
+
+def test_coupon_usage_stays_same_path_even_when_old():
+    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
+    calls = plan_calls("/Promotion/GetOfflineCouponConsumePageList", params, now=_now(), tz=TZ)
+    assert [c.endpoint for c in calls] == ["/Promotion/GetOfflineCouponConsumePageList"]
+
+
+def test_goods_outbound_routes_to_queryformer_when_old():
+    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
+    calls = plan_calls("/GoodsStockManage/QueryGoodsOutboundReceipt", params, now=_now(), tz=TZ)
+    assert [c.endpoint for c in calls] == ["/GoodsStockManage/QueryFormerGoodsOutboundReceipt"]
+
+
+def test_settlement_records_split_when_crossing_boundary():
+    params = {"siteId": 1, "rangeStartTime": "2025-08-15 00:00:00", "rangeEndTime": "2025-09-10 00:00:00"}
+    calls = plan_calls("/Site/GetAllOrderSettleList", params, now=_now(), tz=TZ)
+    assert [c.endpoint for c in calls] == ["/Site/GetFormerOrderSettleList", "/Site/GetAllOrderSettleList"]
+    assert calls[0].params["rangeeEndTime".replace("ee", "E").replace("E", "E")] if False else calls[0].params["rangeEndTime"] == "2025-09-01 00:00:00"
+    assert calls[1].params["rangeStartTime"] == "2025-09-01 00:00:00"
+
+
+@pytest.mark.parametrize(
+    "endpoint",
+    [
+        "/PayLog/GetFormerPayLogListPage",
+        "/Site/GetFormerOrderSettleList",
+        "/GoodsStockManage/QueryFormerGoodsOutboundReceipt",
+    ],
+)
+def test_explicit_former_endpoint_not_rerouted(endpoint):
+    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
+    calls = plan_calls(endpoint, params, now=_now(), tz=TZ)
+    assert [c.endpoint for c in calls] == [endpoint]
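From these assertions the routing rule can be read off: recent_boundary returns the first instant of the month three calendar months before now, windows entirely before it reroute to the corresponding Former endpoint, windows straddling it split at the boundary into a former call and a recent call, and explicitly Former endpoints pass through untouched. A hypothetical recent_boundary consistent with test_recent_boundary_month_start (the real api.endpoint_routing implementation may differ):

from datetime import datetime
from zoneinfo import ZoneInfo

def recent_boundary_sketch(now: datetime) -> datetime:
    # Step back three calendar months, then snap to midnight on the 1st,
    # keeping the original tzinfo so isoformat() carries the offset.
    year, month = now.year, now.month - 3
    if month < 1:
        year, month = year - 1, month + 12
    return now.replace(year=year, month=month, day=1,
                       hour=0, minute=0, second=0, microsecond=0)

now = datetime(2025, 12, 18, 10, 0, 0, tzinfo=ZoneInfo("Asia/Shanghai"))
assert recent_boundary_sketch(now).isoformat() == "2025-09-01T00:00:00+08:00"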
@@ -1,4 +1,4 @@
-# -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
 """Unit tests for the new ODS ingestion tasks."""
 import logging
 import os
@@ -22,21 +22,21 @@ def _build_config(tmp_path):
     return create_test_config("ONLINE", archive_dir, temp_dir)


-def test_ods_assistant_accounts_ingest(tmp_path):
-    """Ensure ODS_ASSISTANT_ACCOUNTS task stores raw payload with record_index dedup keys."""
+def test_assistant_accounts_masters_ingest(tmp_path):
+    """Ensure ODS_ASSISTANT_ACCOUNT stores raw payload with record_index dedup keys."""
     config = _build_config(tmp_path)
     sample = [
         {
             "id": 5001,
             "assistant_no": "A01",
             "nickname": "小张",
         }
     ]
     api = FakeAPIClient({"/PersonnelManagement/SearchAssistantInfo": sample})
-    task_cls = ODS_TASK_CLASSES["ODS_ASSISTANT_ACCOUNTS"]
+    task_cls = ODS_TASK_CLASSES["ODS_ASSISTANT_ACCOUNT"]

     with get_db_operations() as db_ops:
-        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_assistant_accounts"))
+        task = task_cls(config, db_ops, api, logging.getLogger("test_assistant_accounts_masters"))
         result = task.execute()

     assert result["status"] == "SUCCESS"
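For context on the harness used below: FakeAPIClient is constructed with a mapping from endpoint path to canned rows, and the final test in this file inspects api.calls, so it evidently records every request it serves. A hypothetical minimal equivalent (the method name fetch and the exact call-record shape are assumptions, not repo code):

class FakeAPIClientSketch:
    def __init__(self, routes):
        self.routes = dict(routes)   # endpoint path -> canned response rows
        self.calls = []              # every request, for later assertions

    def fetch(self, endpoint, params=None):
        self.calls.append({"endpoint": endpoint, "params": dict(params or {})})
        return self.routes.get(endpoint, [])

api = FakeAPIClientSketch({"/PersonnelManagement/SearchAssistantInfo": [{"id": 5001}]})
rows = api.fetch("/PersonnelManagement/SearchAssistantInfo", {"siteId": 1})
assert rows == [{"id": 5001}] and api.calls[0]["params"]["siteId"] == 1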
@@ -49,21 +49,21 @@ def test_ods_assistant_accounts_ingest(tmp_path):
     assert '"id": 5001' in row["payload"]


-def test_ods_inventory_change_ingest(tmp_path):
-    """Ensure ODS_INVENTORY_CHANGE task stores raw payload with record_index dedup keys."""
+def test_goods_stock_movements_ingest(tmp_path):
+    """Ensure ODS_INVENTORY_CHANGE stores raw payload with record_index dedup keys."""
     config = _build_config(tmp_path)
     sample = [
         {
             "siteGoodsStockId": 123456,
             "stockType": 1,
             "goodsName": "测试商品",
         }
     ]
     api = FakeAPIClient({"/GoodsStockManage/QueryGoodsOutboundReceipt": sample})
     task_cls = ODS_TASK_CLASSES["ODS_INVENTORY_CHANGE"]

     with get_db_operations() as db_ops:
-        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_inventory_change"))
+        task = task_cls(config, db_ops, api, logging.getLogger("test_goods_stock_movements"))
         result = task.execute()

     assert result["status"] == "SUCCESS"
@@ -75,7 +75,7 @@ def test_ods_inventory_change_ingest(tmp_path):
     assert '"siteGoodsStockId": 123456' in row["payload"]


-def test_ods_member_profiles_ingest(tmp_path):
+def test_member_profiless_ingest(tmp_path):
     """Ensure ODS_MEMBER task stores tenantMemberInfos raw JSON."""
     config = _build_config(tmp_path)
     sample = [{"tenantMemberInfos": [{"id": 101, "mobile": "13800000000"}]}]
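The "record_index dedup keys" phrasing in these docstrings, together with the _fake_columns schema earlier (record_index, source_file, payload), suggests each raw record is wrapped with its position in the fetched page so replays upsert idempotently. A hypothetical sketch of that wrapping (names and row shape are assumptions, not the repo's actual task code):

import json

def to_ods_rows(records, source_file):
    # Wrap each source record with its page position; (source_file, record_index)
    # then serves as a natural dedup key for idempotent upserts.
    return [
        {
            "record_index": idx,
            "source_file": source_file,
            "payload": json.dumps(record, ensure_ascii=False),
        }
        for idx, record in enumerate(records)
    ]

rows = to_ods_rows([{"id": 5001, "assistant_no": "A01"}], "assistants.json")
assert rows[0]["record_index"] == 0
assert '"id": 5001' in rows[0]["payload"]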
@@ -110,14 +110,14 @@ def test_ods_payment_ingest(tmp_path):

 def test_ods_settlement_records_ingest(tmp_path):
-    """Ensure ODS_ORDER_SETTLE task stores settleList raw JSON."""
+    """Ensure ODS_SETTLEMENT_RECORDS stores settleList raw JSON."""
     config = _build_config(tmp_path)
-    sample = [{"data": {"settleList": [{"id": 701, "orderTradeNo": 8001}]}}]
+    sample = [{"id": 701, "orderTradeNo": 8001}]
     api = FakeAPIClient({"/Site/GetAllOrderSettleList": sample})
-    task_cls = ODS_TASK_CLASSES["ODS_ORDER_SETTLE"]
+    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_RECORDS"]

     with get_db_operations() as db_ops:
-        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_order_settle"))
+        task = task_cls(config, db_ops, api, logging.getLogger("test_settlement_records"))
         result = task.execute()

     assert result["status"] == "SUCCESS"
@@ -158,3 +158,4 @@ def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
         and call.get("params", {}).get("orderSettleId") == 9001
         for call in api.calls
     )
+