Complete tasks and tests

Neo
2025-11-19 03:36:44 +08:00
parent 5bb5a8a568
commit 9a1df70a23
31 changed files with 3034 additions and 6 deletions


@@ -0,0 +1,638 @@
# -*- coding: utf-8 -*-
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
from __future__ import annotations
import json
import os
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple, Type
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations as PgDBOperations
from tasks.assistant_abolish_task import AssistantAbolishTask
from tasks.assistants_task import AssistantsTask
from tasks.coupon_usage_task import CouponUsageTask
from tasks.inventory_change_task import InventoryChangeTask
from tasks.ledger_task import LedgerTask
from tasks.members_task import MembersTask
from tasks.orders_task import OrdersTask
from tasks.packages_task import PackagesDefTask
from tasks.payments_task import PaymentsTask
from tasks.products_task import ProductsTask
from tasks.refunds_task import RefundsTask
from tasks.table_discount_task import TableDiscountTask
from tasks.tables_task import TablesTask
from tasks.topups_task import TopupsTask
DEFAULT_STORE_ID = 2790685415443269
BASE_TS = "2025-01-01 10:00:00"
END_TS = "2025-01-01 12:00:00"
@dataclass(frozen=True)
class TaskSpec:
"""描述单个任务在测试中如何被驱动的元数据包含任务代码、API 路径、数据路径与样例记录。"""
code: str
task_cls: Type
endpoint: str
data_path: Tuple[str, ...]
sample_records: List[Dict]
@property
def archive_filename(self) -> str:
return endpoint_to_filename(self.endpoint)
def endpoint_to_filename(endpoint: str) -> str:
"""根据 API endpoint 生成稳定可复用的文件名,便于离线模式在目录中直接定位归档 JSON。"""
normalized = endpoint.strip("/").replace("/", "__").replace(" ", "_").lower()
return f"{normalized or 'root'}.json"
def wrap_records(records: List[Dict], data_path: Sequence[str]):
"""按照 data_path 逐层包裹记录列表,使其结构与真实 API 返回体一致,方便离线回放。"""
payload = records
for key in reversed(data_path):
payload = {key: payload}
return payload
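# Illustrative nesting (comment only, not executed): wrap_records rebuilds the response
# envelope that the offline client later unwraps, e.g.
#   wrap_records([{"id": 1}], ("data", "siteTables")) -> {"data": {"siteTables": [{"id": 1}]}}
# With an empty data_path the record list is returned unchanged.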
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
"""构建一份适合测试的 AppConfig自动填充存储、日志、归档目录等参数并保证目录存在。"""
archive_dir = Path(archive_dir)
temp_dir = Path(temp_dir)
archive_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
overrides = {
"app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Taipei"},
"db": {"dsn": "postgresql://user:pass@localhost:5432/etl_billiards_test"},
"api": {
"base_url": "https://api.example.com",
"token": "test-token",
"timeout_sec": 3,
"page_size": 50,
},
"testing": {
"mode": mode,
"json_archive_dir": str(archive_dir),
"temp_json_dir": str(temp_dir),
},
"io": {
"export_root": str(temp_dir / "export"),
"log_root": str(temp_dir / "logs"),
},
}
return AppConfig.load(overrides)
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
"""将 TaskSpec 的样例数据写入指定归档目录,供离线测试回放使用,并返回生成文件的完整路径。"""
archive_dir = Path(archive_dir)
payload = wrap_records(spec.sample_records, spec.data_path)
file_path = archive_dir / spec.archive_filename
with file_path.open("w", encoding="utf-8") as fp:
json.dump(payload, fp, ensure_ascii=False)
return file_path
class FakeCursor:
"""极简游标桩对象,记录 SQL/参数并支持上下文管理,供 FakeDBOperations 与 SCD2Handler 使用。"""
def __init__(self, recorder: List[Dict]):
self.recorder = recorder
# pylint: disable=unused-argument
def execute(self, sql: str, params=None):
self.recorder.append({"sql": sql.strip(), "params": params})
def fetchone(self):
return None
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class FakeConnection:
"""仿 psycopg 连接对象,仅满足 SCD2Handler 对 cursor 的最小需求,并缓存执行过的语句。"""
def __init__(self):
self.statements: List[Dict] = []
def cursor(self):
return FakeCursor(self.statements)
class FakeDBOperations:
"""拦截并记录批量 upsert/事务操作,避免触碰真实数据库,同时提供 commit/rollback 计数。"""
def __init__(self):
self.upserts: List[Dict] = []
self.commits = 0
self.rollbacks = 0
self.conn = FakeConnection()
def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
self.upserts.append({"sql": sql.strip(), "count": len(rows), "page_size": page_size})
return len(rows), 0
def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
self.upserts.append({"sql": sql.strip(), "count": len(rows), "page_size": page_size})
def commit(self):
self.commits += 1
def rollback(self):
self.rollbacks += 1
class FakeAPIClient:
"""在线模式使用的伪 API Client直接返回预置的内存数据并记录调用以确保任务参数正确传递。"""
def __init__(self, data_map: Dict[str, List[Dict]]):
self.data_map = data_map
self.calls: List[Dict] = []
# pylint: disable=unused-argument
def get_paginated(self, endpoint: str, params=None, **kwargs):
self.calls.append({"endpoint": endpoint, "params": params})
if endpoint not in self.data_map:
raise AssertionError(f"Missing fixture for endpoint {endpoint}")
return list(self.data_map[endpoint]), [{"page": 1, "size": len(self.data_map[endpoint])}]
class OfflineAPIClient:
"""离线模式专用 API Client根据 endpoint 读取归档 JSON、套用 data_path 并回放列表数据。"""
def __init__(self, file_map: Dict[str, Path]):
self.file_map = {k: Path(v) for k, v in file_map.items()}
self.calls: List[Dict] = []
# pylint: disable=unused-argument
def get_paginated(self, endpoint: str, params=None, page_size: int = 200, data_path: Tuple[str, ...] = (), **kwargs):
self.calls.append({"endpoint": endpoint, "params": params})
if endpoint not in self.file_map:
raise AssertionError(f"Missing archive for endpoint {endpoint}")
with self.file_map[endpoint].open("r", encoding="utf-8") as fp:
payload = json.load(fp)
data = payload
for key in data_path:
if isinstance(data, dict):
data = data.get(key, [])
else:
data = []
break
if not isinstance(data, list):
data = []
return data, [{"page": 1, "mode": "offline"}]
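# Both fake clients return the same (records, page_meta) tuple shape that the tasks unpack.
# Sketch (assuming an archive written by dump_offline_payload at archive_path):
#   client = OfflineAPIClient({"/order/list": archive_path})
#   records, pages = client.get_paginated("/order/list", data_path=("data",))
#   # records -> the archived list, pages -> [{"page": 1, "mode": "offline"}]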
class RealDBOperationsAdapter:
"""连接真实 PostgreSQL 的适配器,为任务提供 batch_upsert + 事务能力。"""
def __init__(self, dsn: str):
self._conn = DatabaseConnection(dsn)
self._ops = PgDBOperations(self._conn)
# SCD2Handler accesses db.conn.cursor(), so expose the underlying connection
self.conn = self._conn.conn
def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)
def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
return self._ops.batch_execute(sql, rows, page_size=page_size)
def commit(self):
self._conn.commit()
def rollback(self):
self._conn.rollback()
def close(self):
self._conn.close()
@contextmanager
def get_db_operations():
"""
测试专用的 DB 操作上下文:
- 若设置 TEST_DB_DSN则连接真实 PostgreSQL
- 否则回退到 FakeDBOperations内存桩
"""
dsn = os.environ.get("TEST_DB_DSN")
if dsn:
adapter = RealDBOperationsAdapter(dsn)
try:
yield adapter
finally:
adapter.close()
else:
fake = FakeDBOperations()
yield fake
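# Usage sketch (the DSN below is illustrative; adjust it to your environment): export
# TEST_DB_DSN before running pytest to exercise the real-database path, or leave it unset
# to stay on the in-memory stub.
#   export TEST_DB_DSN="postgresql://user:pass@localhost:5432/etl_billiards_test"
#   pytest -k "task"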
TASK_SPECS: List[TaskSpec] = [
TaskSpec(
code="PRODUCTS",
task_cls=ProductsTask,
endpoint="/TenantGoods/QueryTenantGoods",
data_path=("data",),
sample_records=[
{
"siteGoodsId": 101,
"tenantGoodsId": 101,
"goodsName": "测试球杆",
"goodsCategoryId": 201,
"categoryName": "器材",
"goodsCategorySecondId": 202,
"goodsUnit": "",
"costPrice": "100.00",
"goodsPrice": "150.00",
"goodsState": "ON",
"supplierId": 20,
"barcode": "PRD001",
"isCombo": False,
"createTime": BASE_TS,
"updateTime": END_TS,
}
],
),
TaskSpec(
code="TABLES",
task_cls=TablesTask,
endpoint="/Table/GetSiteTables",
data_path=("data", "siteTables"),
sample_records=[
{
"id": 301,
"site_id": 30,
"site_table_area_id": 40,
"areaName": "大厅",
"table_name": "1号桌",
"table_price": "50.00",
"table_status": "FREE",
"tableStatusName": "空闲",
"light_status": "OFF",
"is_rest_area": False,
"show_status": True,
"virtual_table": False,
"charge_free": False,
"only_allow_groupon": False,
"is_online_reservation": True,
"createTime": BASE_TS,
}
],
),
TaskSpec(
code="MEMBERS",
task_cls=MembersTask,
endpoint="/MemberProfile/GetTenantMemberList",
data_path=("data",),
sample_records=[
{
"memberId": 401,
"memberName": "张三",
"phone": "13800000000",
"balance": "88.88",
"status": "ACTIVE",
"registerTime": BASE_TS,
}
],
),
TaskSpec(
code="ASSISTANTS",
task_cls=AssistantsTask,
endpoint="/Assistant/List",
data_path=("data", "assistantInfos"),
sample_records=[
{
"id": 501,
"assistant_no": "AS001",
"nickname": "小李",
"real_name": "李雷",
"gender": "M",
"mobile": "13900000000",
"level": "A",
"team_id": 10,
"team_name": "先锋队",
"assistant_status": "ON",
"work_status": "BUSY",
"entry_time": BASE_TS,
"resign_time": END_TS,
"start_time": BASE_TS,
"end_time": END_TS,
"create_time": BASE_TS,
"update_time": END_TS,
"system_role_id": 1,
"online_status": "ONLINE",
"allow_cx": True,
"charge_way": "TIME",
"pd_unit_price": "30.00",
"cx_unit_price": "20.00",
"is_guaranteed": True,
"is_team_leader": False,
"serial_number": "SN001",
"show_sort": 1,
"is_delete": False,
}
],
),
TaskSpec(
code="PACKAGES_DEF",
task_cls=PackagesDefTask,
endpoint="/Package/List",
data_path=("data", "packageCouponList"),
sample_records=[
{
"id": 601,
"package_id": "PKG001",
"package_name": "白天特惠",
"table_area_id": 70,
"table_area_name": "大厅",
"selling_price": "199.00",
"duration": 120,
"start_time": BASE_TS,
"end_time": END_TS,
"type": "Groupon",
"is_enabled": True,
"is_delete": False,
"usable_count": 3,
"creator_name": "系统",
"date_type": "WEEKDAY",
"group_type": "DINE_IN",
"coupon_money": "30.00",
"area_tag_type": "VIP",
"system_group_type": "BASIC",
"card_type_ids": "1,2,3",
}
],
),
TaskSpec(
code="ORDERS",
task_cls=OrdersTask,
endpoint="/order/list",
data_path=("data",),
sample_records=[
{
"orderId": 701,
"orderNo": "ORD001",
"memberId": 401,
"tableId": 301,
"orderTime": BASE_TS,
"endTime": END_TS,
"totalAmount": "300.00",
"discountAmount": "20.00",
"finalAmount": "280.00",
"payStatus": "PAID",
"orderStatus": "CLOSED",
"remark": "测试订单",
}
],
),
TaskSpec(
code="PAYMENTS",
task_cls=PaymentsTask,
endpoint="/pay/records",
data_path=("data",),
sample_records=[
{
"payId": 801,
"orderId": 701,
"payTime": END_TS,
"payAmount": "280.00",
"payType": "CARD",
"payStatus": "SUCCESS",
"remark": "测试支付",
}
],
),
TaskSpec(
code="REFUNDS",
task_cls=RefundsTask,
endpoint="/Pay/RefundList",
data_path=(),
sample_records=[
{
"id": 901,
"site_id": 1,
"tenant_id": 2,
"pay_amount": "100.00",
"pay_status": "SUCCESS",
"pay_time": END_TS,
"create_time": END_TS,
"relate_type": "ORDER",
"relate_id": 701,
"payment_method": "CARD",
"refund_amount": "20.00",
"action_type": "PARTIAL",
"pay_terminal": "POS",
"operator_id": 11,
"channel_pay_no": "CH001",
"channel_fee": "1.00",
"is_delete": False,
"member_id": 401,
"member_card_id": 501,
}
],
),
TaskSpec(
code="COUPON_USAGE",
task_cls=CouponUsageTask,
endpoint="/Coupon/UsageList",
data_path=(),
sample_records=[
{
"id": 1001,
"coupon_code": "CP001",
"coupon_channel": "MEITUAN",
"coupon_name": "双人券",
"sale_price": "50.00",
"coupon_money": "30.00",
"coupon_free_time": 60,
"use_status": "USED",
"create_time": BASE_TS,
"consume_time": END_TS,
"operator_id": 11,
"operator_name": "操作员",
"table_id": 301,
"site_order_id": 701,
"group_package_id": 601,
"coupon_remark": "备注",
"deal_id": "DEAL001",
"certificate_id": "CERT001",
"verify_id": "VERIFY001",
"is_delete": False,
}
],
),
TaskSpec(
code="INVENTORY_CHANGE",
task_cls=InventoryChangeTask,
endpoint="/Inventory/ChangeList",
data_path=("data", "queryDeliveryRecordsList"),
sample_records=[
{
"siteGoodsStockId": 1101,
"siteGoodsId": 101,
"stockType": "OUT",
"goodsName": "测试球杆",
"createTime": END_TS,
"startNum": 10,
"endNum": 8,
"changeNum": -2,
"unit": "",
"price": "120.00",
"operatorName": "仓管",
"remark": "测试出库",
"goodsCategoryId": 201,
"goodsSecondCategoryId": 202,
}
],
),
TaskSpec(
code="TOPUPS",
task_cls=TopupsTask,
endpoint="/Topup/SettleList",
data_path=("data", "settleList"),
sample_records=[
{
"id": 1201,
"memberId": 401,
"memberName": "张三",
"memberPhone": "13800000000",
"tenantMemberCardId": 1301,
"memberCardTypeName": "金卡",
"payAmount": "500.00",
"consumeMoney": "100.00",
"settleStatus": "DONE",
"settleType": "AUTO",
"settleName": "日结",
"settleRelateId": 1501,
"payTime": BASE_TS,
"createTime": END_TS,
"operatorId": 11,
"operatorName": "收银员",
"paymentMethod": "CASH",
"refundAmount": "0",
"cashAmount": "500.00",
"cardAmount": "0",
"balanceAmount": "0",
"onlineAmount": "0",
"roundingAmount": "0",
"adjustAmount": "0",
"goodsMoney": "0",
"tableChargeMoney": "0",
"serviceMoney": "0",
"couponAmount": "0",
"orderRemark": "首次充值",
}
],
),
TaskSpec(
code="TABLE_DISCOUNT",
task_cls=TableDiscountTask,
endpoint="/Table/AdjustList",
data_path=("data", "taiFeeAdjustInfos"),
sample_records=[
{
"id": 1301,
"adjust_type": "DISCOUNT",
"applicant_id": 11,
"applicant_name": "店长",
"operator_id": 22,
"operator_name": "值班",
"ledger_amount": "50.00",
"ledger_count": 2,
"ledger_name": "调价",
"ledger_status": "APPROVED",
"order_settle_id": 7010,
"order_trade_no": 8001,
"site_table_id": 301,
"create_time": END_TS,
"is_delete": False,
"tableProfile": {
"id": 301,
"site_table_area_id": 40,
"site_table_area_name": "大厅",
},
}
],
),
TaskSpec(
code="ASSISTANT_ABOLISH",
task_cls=AssistantAbolishTask,
endpoint="/Assistant/AbolishList",
data_path=("data", "abolitionAssistants"),
sample_records=[
{
"id": 1401,
"tableId": 301,
"tableName": "1号桌",
"tableAreaId": 40,
"tableArea": "大厅",
"assistantOn": "AS001",
"assistantName": "小李",
"pdChargeMinutes": 30,
"assistantAbolishAmount": "15.00",
"createTime": END_TS,
"trashReason": "测试",
}
],
),
TaskSpec(
code="LEDGER",
task_cls=LedgerTask,
endpoint="/Assistant/LedgerList",
data_path=("data", "orderAssistantDetails"),
sample_records=[
{
"id": 1501,
"assistantNo": "AS001",
"assistantName": "小李",
"nickname": "",
"levelName": "L1",
"tableName": "1号桌",
"ledger_unit_price": "30.00",
"ledger_count": 2,
"ledger_amount": "60.00",
"projected_income": "80.00",
"service_money": "5.00",
"member_discount_amount": "2.00",
"manual_discount_amount": "1.00",
"coupon_deduct_money": "3.00",
"order_trade_no": 8001,
"order_settle_id": 7010,
"operator_id": 22,
"operator_name": "值班",
"assistant_team_id": 10,
"assistant_level": "A",
"site_table_id": 301,
"order_assistant_id": 1601,
"site_assistant_id": 501,
"user_id": 5010,
"ledger_start_time": BASE_TS,
"ledger_end_time": END_TS,
"start_use_time": BASE_TS,
"last_use_time": END_TS,
"income_seconds": 3600,
"real_use_seconds": 3300,
"is_trash": False,
"trash_reason": "",
"is_confirm": True,
"ledger_status": "CLOSED",
"create_time": END_TS,
}
],
),
]
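# Optional convenience, not required by the parametrized tests: a hypothetical lookup table
# keyed by task code, handy for ad-hoc debugging (e.g. SPEC_BY_CODE["ORDERS"].endpoint).
SPEC_BY_CODE: Dict[str, TaskSpec] = {spec.code: spec for spec in TASK_SPECS}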


@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""离线模式任务测试,通过回放归档 JSON 来验证 T+L 链路可用。"""
import logging
from pathlib import Path
import pytest
from .task_test_utils import (
TASK_SPECS,
OfflineAPIClient,
create_test_config,
dump_offline_payload,
get_db_operations,
)
@pytest.mark.parametrize("spec", TASK_SPECS, ids=lambda spec: spec.code)
def test_task_offline_mode(spec, tmp_path):
"""确保每个任务都能读取归档 JSON 并完成 Transform + Load 操作。"""
archive_dir = tmp_path / "archive"
temp_dir = tmp_path / "tmp"
archive_dir.mkdir()
temp_dir.mkdir()
file_path = dump_offline_payload(spec, archive_dir)
config = create_test_config("OFFLINE", archive_dir, temp_dir)
offline_api = OfflineAPIClient({spec.endpoint: Path(file_path)})
logger = logging.getLogger(f"test_offline_{spec.code.lower()}")
with get_db_operations() as db_ops:
task = spec.task_cls(config, db_ops, offline_api, logger)
result = task.execute()
assert result["status"] == "SUCCESS"
assert result["counts"]["fetched"] == len(spec.sample_records)
assert result["counts"]["inserted"] == len(spec.sample_records)
if hasattr(db_ops, "commits"):
assert db_ops.commits == 1
assert db_ops.rollbacks == 0


@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""在线模式下的端到端任务测试,验证所有任务在模拟 API 下能顺利执行。"""
import logging
import pytest
from .task_test_utils import (
TASK_SPECS,
FakeAPIClient,
create_test_config,
get_db_operations,
)
@pytest.mark.parametrize("spec", TASK_SPECS, ids=lambda spec: spec.code)
def test_task_online_mode(spec, tmp_path):
"""针对每个 TaskSpec 验证:模拟 API 数据下依旧能完整跑完 ETL并正确统计。"""
archive_dir = tmp_path / "archive"
temp_dir = tmp_path / "tmp"
config = create_test_config("ONLINE", archive_dir, temp_dir)
fake_api = FakeAPIClient({spec.endpoint: spec.sample_records})
logger = logging.getLogger(f"test_online_{spec.code.lower()}")
with get_db_operations() as db_ops:
task = spec.task_cls(config, db_ops, fake_api, logger)
result = task.execute()
assert result["status"] == "SUCCESS"
assert result["counts"]["fetched"] == len(spec.sample_records)
assert result["counts"]["inserted"] == len(spec.sample_records)
if hasattr(db_ops, "commits"):
assert db_ops.commits == 1
assert db_ops.rollbacks == 0