# -*- coding: utf-8 -*- """ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。""" from __future__ import annotations import json import os from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Sequence, Tuple, Type from config.settings import AppConfig from database.connection import DatabaseConnection from database.operations import DatabaseOperations as PgDBOperations from tasks.assistant_abolish_task import AssistantAbolishTask from tasks.assistants_task import AssistantsTask from tasks.coupon_usage_task import CouponUsageTask from tasks.inventory_change_task import InventoryChangeTask from tasks.ledger_task import LedgerTask from tasks.members_task import MembersTask from tasks.orders_task import OrdersTask from tasks.packages_task import PackagesDefTask from tasks.payments_task import PaymentsTask from tasks.products_task import ProductsTask from tasks.refunds_task import RefundsTask from tasks.table_discount_task import TableDiscountTask from tasks.tables_task import TablesTask from tasks.topups_task import TopupsTask DEFAULT_STORE_ID = 2790685415443269 BASE_TS = "2025-01-01 10:00:00" END_TS = "2025-01-01 12:00:00" @dataclass(frozen=True) class TaskSpec: """描述单个任务在测试中如何被驱动的元数据,包含任务代码、API 路径、数据路径与样例记录。""" code: str task_cls: Type endpoint: str data_path: Tuple[str, ...] sample_records: List[Dict] @property def archive_filename(self) -> str: return endpoint_to_filename(self.endpoint) def endpoint_to_filename(endpoint: str) -> str: """根据 API endpoint 生成稳定可复用的文件名,便于离线模式在目录中直接定位归档 JSON。""" normalized = endpoint.strip("/").replace("/", "__").replace(" ", "_").lower() return f"{normalized or 'root'}.json" def wrap_records(records: List[Dict], data_path: Sequence[str]): """按照 data_path 逐层包裹记录列表,使其结构与真实 API 返回体一致,方便离线回放。""" payload = records for key in reversed(data_path): payload = {key: payload} return payload def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig: """构建一份适合测试的 AppConfig,自动填充存储、日志、归档目录等参数并保证目录存在。""" archive_dir = Path(archive_dir) temp_dir = Path(temp_dir) archive_dir.mkdir(parents=True, exist_ok=True) temp_dir.mkdir(parents=True, exist_ok=True) flow = "FULL" if str(mode or "").upper() == "ONLINE" else "INGEST_ONLY" overrides = { "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Taipei"}, "db": {"dsn": "postgresql://user:pass@localhost:5432/etl_billiards_test"}, "api": { "base_url": "https://api.example.com", "token": "test-token", "timeout_sec": 3, "page_size": 50, }, "pipeline": { "flow": flow, "fetch_root": str(temp_dir / "json_fetch"), "ingest_source_dir": str(archive_dir), }, "io": { "export_root": str(temp_dir / "export"), "log_root": str(temp_dir / "logs"), }, } return AppConfig.load(overrides) def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path: """将 TaskSpec 的样例数据写入指定归档目录,供离线测试回放使用,并返回生成文件的完整路径。""" archive_dir = Path(archive_dir) payload = wrap_records(spec.sample_records, spec.data_path) file_path = archive_dir / spec.archive_filename with file_path.open("w", encoding="utf-8") as fp: json.dump(payload, fp, ensure_ascii=False) return file_path class FakeCursor: """极简游标桩对象,记录 SQL/参数并支持上下文管理,供 FakeDBOperations 与 SCD2Handler 使用。""" def __init__(self, recorder: List[Dict]): self.recorder = recorder # pylint: disable=unused-argument def execute(self, sql: str, params=None): self.recorder.append({"sql": sql.strip(), "params": params}) def fetchone(self): return None def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False class FakeConnection: """仿 psycopg 连接对象,仅满足 SCD2Handler 对 cursor 的最小需求,并缓存执行过的语句。""" def __init__(self): self.statements: List[Dict] = [] def cursor(self): return FakeCursor(self.statements) class FakeDBOperations: """拦截并记录批量 upsert/事务操作,避免触碰真实数据库,同时提供 commit/rollback 计数。""" def __init__(self): self.upserts: List[Dict] = [] self.executes: List[Dict] = [] self.commits = 0 self.rollbacks = 0 self.conn = FakeConnection() def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000): self.upserts.append( { "sql": sql.strip(), "count": len(rows), "page_size": page_size, "rows": [dict(row) for row in rows], } ) return len(rows), 0 def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000): self.executes.append( { "sql": sql.strip(), "count": len(rows), "page_size": page_size, "rows": [dict(row) for row in rows], } ) def execute(self, sql: str, params=None): self.executes.append({"sql": sql.strip(), "params": params}) def query(self, sql: str, params=None): self.executes.append({"sql": sql.strip(), "params": params, "type": "query"}) return [] def cursor(self): return self.conn.cursor() def commit(self): self.commits += 1 def rollback(self): self.rollbacks += 1 class FakeAPIClient: """在线模式使用的伪 API Client,直接返回预置的内存数据并记录调用,以确保任务参数正确传递。""" def __init__(self, data_map: Dict[str, List[Dict]]): self.data_map = data_map self.calls: List[Dict] = [] # pylint: disable=unused-argument def iter_paginated( self, endpoint: str, params=None, page_size: int = 200, page_field: str = "page", size_field: str = "limit", data_path: Tuple[str, ...] = (), list_key: str | None = None, ): self.calls.append({"endpoint": endpoint, "params": params}) if endpoint not in self.data_map: raise AssertionError(f"Missing fixture for endpoint {endpoint}") records = list(self.data_map[endpoint]) yield 1, records, dict(params or {}), {"data": records} def get_paginated(self, endpoint: str, params=None, **kwargs): records = [] pages = [] for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs): records.extend(page_records) pages.append({"page": page_no, "request": req, "response": resp}) return records, pages def get_source_hint(self, endpoint: str) -> str | None: return None class OfflineAPIClient: """离线模式专用 API Client,根据 endpoint 读取归档 JSON、套入 data_path 并回放列表数据。""" def __init__(self, file_map: Dict[str, Path]): self.file_map = {k: Path(v) for k, v in file_map.items()} self.calls: List[Dict] = [] # pylint: disable=unused-argument def iter_paginated( self, endpoint: str, params=None, page_size: int = 200, page_field: str = "page", size_field: str = "limit", data_path: Tuple[str, ...] = (), list_key: str | None = None, ): self.calls.append({"endpoint": endpoint, "params": params}) if endpoint not in self.file_map: raise AssertionError(f"Missing archive for endpoint {endpoint}") with self.file_map[endpoint].open("r", encoding="utf-8") as fp: payload = json.load(fp) data = payload for key in data_path: if isinstance(data, dict): data = data.get(key, []) if list_key and isinstance(data, dict): data = data.get(list_key, []) if not isinstance(data, list): data = [] total = len(data) start = 0 page = 1 while start < total or (start == 0 and total == 0): chunk = data[start : start + page_size] if not chunk and total != 0: break yield page, list(chunk), dict(params or {}), payload if len(chunk) < page_size: break start += page_size page += 1 def get_paginated(self, endpoint: str, params=None, **kwargs): records = [] pages = [] for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs): records.extend(page_records) pages.append({"page": page_no, "request": req, "response": resp}) return records, pages def get_source_hint(self, endpoint: str) -> str | None: if endpoint not in self.file_map: return None return str(self.file_map[endpoint]) class RealDBOperationsAdapter: """连接真实 PostgreSQL 的适配器,为任务提供 batch_upsert + 事务能力。""" def __init__(self, dsn: str): self._conn = DatabaseConnection(dsn) self._ops = PgDBOperations(self._conn) # SCD2Handler 会访问 db.conn.cursor(),因此暴露底层连接 self.conn = self._conn.conn def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000): return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size) def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000): return self._ops.batch_execute(sql, rows, page_size=page_size) def commit(self): self._conn.commit() def rollback(self): self._conn.rollback() def close(self): self._conn.close() @contextmanager def get_db_operations(): """ 测试专用的 DB 操作上下文: - 若设置 TEST_DB_DSN,则连接真实 PostgreSQL; - 否则回退到 FakeDBOperations(内存桩)。 """ dsn = os.environ.get("TEST_DB_DSN") if dsn: adapter = RealDBOperationsAdapter(dsn) try: yield adapter finally: adapter.close() else: fake = FakeDBOperations() yield fake TASK_SPECS: List[TaskSpec] = [ TaskSpec( code="PRODUCTS", task_cls=ProductsTask, endpoint="/TenantGoods/QueryTenantGoods", data_path=("data", "tenantGoodsList"), sample_records=[ { "siteGoodsId": 101, "tenantGoodsId": 101, "goodsName": "测试球杆", "goodsCategoryId": 201, "categoryName": "器材", "goodsCategorySecondId": 202, "goodsUnit": "支", "costPrice": "100.00", "goodsPrice": "150.00", "goodsState": "ON", "supplierId": 20, "barcode": "PRD001", "isCombo": False, "createTime": BASE_TS, "updateTime": END_TS, } ], ), TaskSpec( code="TABLES", task_cls=TablesTask, endpoint="/Table/GetSiteTables", data_path=("data", "siteTables"), sample_records=[ { "id": 301, "site_id": 30, "site_table_area_id": 40, "areaName": "大厅", "table_name": "1号桌", "table_price": "50.00", "table_status": "FREE", "tableStatusName": "空闲", "light_status": "OFF", "is_rest_area": False, "show_status": True, "virtual_table": False, "charge_free": False, "only_allow_groupon": False, "is_online_reservation": True, "createTime": BASE_TS, } ], ), TaskSpec( code="MEMBERS", task_cls=MembersTask, endpoint="/MemberProfile/GetTenantMemberList", data_path=("data", "tenantMemberInfos"), sample_records=[ { "memberId": 401, "memberName": "张三", "phone": "13800000000", "balance": "88.88", "status": "ACTIVE", "registerTime": BASE_TS, } ], ), TaskSpec( code="ASSISTANTS", task_cls=AssistantsTask, endpoint="/PersonnelManagement/SearchAssistantInfo", data_path=("data", "assistantInfos"), sample_records=[ { "id": 501, "assistant_no": "AS001", "nickname": "小李", "real_name": "李雷", "gender": "M", "mobile": "13900000000", "level": "A", "team_id": 10, "team_name": "先锋队", "assistant_status": "ON", "work_status": "BUSY", "entry_time": BASE_TS, "resign_time": END_TS, "start_time": BASE_TS, "end_time": END_TS, "create_time": BASE_TS, "update_time": END_TS, "system_role_id": 1, "online_status": "ONLINE", "allow_cx": True, "charge_way": "TIME", "pd_unit_price": "30.00", "cx_unit_price": "20.00", "is_guaranteed": True, "is_team_leader": False, "serial_number": "SN001", "show_sort": 1, "is_delete": False, } ], ), TaskSpec( code="PACKAGES_DEF", task_cls=PackagesDefTask, endpoint="/PackageCoupon/QueryPackageCouponList", data_path=("data", "packageCouponList"), sample_records=[ { "id": 601, "package_id": "PKG001", "package_name": "白天特惠", "table_area_id": 70, "table_area_name": "大厅", "selling_price": "199.00", "duration": 120, "start_time": BASE_TS, "end_time": END_TS, "type": "Groupon", "is_enabled": True, "is_delete": False, "usable_count": 3, "creator_name": "系统", "date_type": "WEEKDAY", "group_type": "DINE_IN", "coupon_money": "30.00", "area_tag_type": "VIP", "system_group_type": "BASIC", "card_type_ids": "1,2,3", } ], ), TaskSpec( code="ORDERS", task_cls=OrdersTask, endpoint="/Site/GetAllOrderSettleList", data_path=("data", "settleList"), sample_records=[ { "orderId": 701, "orderNo": "ORD001", "memberId": 401, "tableId": 301, "orderTime": BASE_TS, "endTime": END_TS, "totalAmount": "300.00", "discountAmount": "20.00", "finalAmount": "280.00", "payStatus": "PAID", "orderStatus": "CLOSED", "remark": "测试订单", } ], ), TaskSpec( code="PAYMENTS", task_cls=PaymentsTask, endpoint="/PayLog/GetPayLogListPage", data_path=("data",), sample_records=[ { "payId": 801, "orderId": 701, "payTime": END_TS, "payAmount": "280.00", "payType": "CARD", "payStatus": "SUCCESS", "remark": "测试支付", } ], ), TaskSpec( code="REFUNDS", task_cls=RefundsTask, endpoint="/Order/GetRefundPayLogList", data_path=("data",), sample_records=[ { "id": 901, "site_id": 1, "tenant_id": 2, "pay_amount": "100.00", "pay_status": "SUCCESS", "pay_time": END_TS, "create_time": END_TS, "relate_type": "ORDER", "relate_id": 701, "payment_method": "CARD", "refund_amount": "20.00", "action_type": "PARTIAL", "pay_terminal": "POS", "operator_id": 11, "channel_pay_no": "CH001", "channel_fee": "1.00", "is_delete": False, "member_id": 401, "member_card_id": 501, } ], ), TaskSpec( code="COUPON_USAGE", task_cls=CouponUsageTask, endpoint="/Promotion/GetOfflineCouponConsumePageList", data_path=("data",), sample_records=[ { "id": 1001, "coupon_code": "CP001", "coupon_channel": "MEITUAN", "coupon_name": "双人券", "sale_price": "50.00", "coupon_money": "30.00", "coupon_free_time": 60, "use_status": "USED", "create_time": BASE_TS, "consume_time": END_TS, "operator_id": 11, "operator_name": "操作员", "table_id": 301, "site_order_id": 701, "group_package_id": 601, "coupon_remark": "备注", "deal_id": "DEAL001", "certificate_id": "CERT001", "verify_id": "VERIFY001", "is_delete": False, } ], ), TaskSpec( code="INVENTORY_CHANGE", task_cls=InventoryChangeTask, endpoint="/GoodsStockManage/QueryGoodsOutboundReceipt", data_path=("data", "queryDeliveryRecordsList"), sample_records=[ { "siteGoodsStockId": 1101, "siteGoodsId": 101, "stockType": "OUT", "goodsName": "测试球杆", "createTime": END_TS, "startNum": 10, "endNum": 8, "changeNum": -2, "unit": "支", "price": "120.00", "operatorName": "仓管", "remark": "测试出库", "goodsCategoryId": 201, "goodsSecondCategoryId": 202, } ], ), TaskSpec( code="TOPUPS", task_cls=TopupsTask, endpoint="/Site/GetRechargeSettleList", data_path=("data", "settleList"), sample_records=[ { "id": 1201, "memberId": 401, "memberName": "张三", "memberPhone": "13800000000", "tenantMemberCardId": 1301, "memberCardTypeName": "金卡", "payAmount": "500.00", "consumeMoney": "100.00", "settleStatus": "DONE", "settleType": "AUTO", "settleName": "日结", "settleRelateId": 1501, "payTime": BASE_TS, "createTime": END_TS, "operatorId": 11, "operatorName": "收银员", "paymentMethod": "CASH", "refundAmount": "0", "cashAmount": "500.00", "cardAmount": "0", "balanceAmount": "0", "onlineAmount": "0", "roundingAmount": "0", "adjustAmount": "0", "goodsMoney": "0", "tableChargeMoney": "0", "serviceMoney": "0", "couponAmount": "0", "orderRemark": "首次充值", } ], ), TaskSpec( code="TABLE_DISCOUNT", task_cls=TableDiscountTask, endpoint="/Site/GetTaiFeeAdjustList", data_path=("data", "taiFeeAdjustInfos"), sample_records=[ { "id": 1301, "adjust_type": "DISCOUNT", "applicant_id": 11, "applicant_name": "店长", "operator_id": 22, "operator_name": "值班", "ledger_amount": "50.00", "ledger_count": 2, "ledger_name": "调价", "ledger_status": "APPROVED", "order_settle_id": 7010, "order_trade_no": 8001, "site_table_id": 301, "create_time": END_TS, "is_delete": False, "tableProfile": { "id": 301, "site_table_area_id": 40, "site_table_area_name": "大厅", }, } ], ), TaskSpec( code="ASSISTANT_ABOLISH", task_cls=AssistantAbolishTask, endpoint="/AssistantPerformance/GetAbolitionAssistant", data_path=("data", "abolitionAssistants"), sample_records=[ { "id": 1401, "tableId": 301, "tableName": "1号桌", "tableAreaId": 40, "tableArea": "大厅", "assistantOn": "AS001", "assistantName": "小李", "pdChargeMinutes": 30, "assistantAbolishAmount": "15.00", "createTime": END_TS, "trashReason": "测试", } ], ), TaskSpec( code="LEDGER", task_cls=LedgerTask, endpoint="/AssistantPerformance/GetOrderAssistantDetails", data_path=("data", "orderAssistantDetails"), sample_records=[ { "id": 1501, "assistantNo": "AS001", "assistantName": "小李", "nickname": "李", "levelName": "L1", "tableName": "1号桌", "ledger_unit_price": "30.00", "ledger_count": 2, "ledger_amount": "60.00", "projected_income": "80.00", "service_money": "5.00", "member_discount_amount": "2.00", "manual_discount_amount": "1.00", "coupon_deduct_money": "3.00", "order_trade_no": 8001, "order_settle_id": 7010, "operator_id": 22, "operator_name": "值班", "assistant_team_id": 10, "assistant_level": "A", "site_table_id": 301, "order_assistant_id": 1601, "site_assistant_id": 501, "user_id": 5010, "ledger_start_time": BASE_TS, "ledger_end_time": END_TS, "start_use_time": BASE_TS, "last_use_time": END_TS, "income_seconds": 3600, "real_use_seconds": 3300, "is_trash": False, "trash_reason": "", "is_confirm": True, "ledger_status": "CLOSED", "create_time": END_TS, } ], ), ]