init: 项目初始提交 - NeoZQYY Monorepo 完整代码
This commit is contained in:
392
apps/etl/pipelines/feiqiu/tests/unit/task_test_utils.py
Normal file
392
apps/etl/pipelines/feiqiu/tests/unit/task_test_utils.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-14 | 删除废弃的 14 个独立 ODS 任务 TaskSpec 定义和对应 import;修复语法错误(TASK_SPECS=[] 后残留孤立代码块)
|
||||
# 直接原因: 之前清理只把 TASK_SPECS 赋值为空列表,但未删除后续 ~370 行废弃 TaskSpec 定义,导致 IndentationError
|
||||
# 验证: `python -c "import ast; ast.parse(open('tests/unit/task_test_utils.py','utf-8').read()); print('OK')"`
|
||||
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Tuple, Type
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations as PgDBOperations
|
||||
from utils.json_store import endpoint_to_filename
|
||||
|
||||
DEFAULT_STORE_ID = 2790685415443269
|
||||
BASE_TS = "2025-01-01 10:00:00"
|
||||
END_TS = "2025-01-01 12:00:00"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TaskSpec:
    """Metadata describing how a single task is driven in tests: the task
    code, its API endpoint, the path to the record list inside the response
    payload, and sample records to replay."""

    # Identifier of the ETL task under test.
    code: str
    # Task class to instantiate; kept as a bare ``Type`` since concrete classes vary.
    task_cls: Type
    # API endpoint the task fetches from; also determines the archive filename.
    endpoint: str
    # Keys to traverse in the API response body to reach the record list.
    data_path: Tuple[str, ...]
    # In-memory sample records used by the fake/offline clients.
    sample_records: List[Dict]

    @property
    def archive_filename(self) -> str:
        # Derive the offline-archive JSON filename from the endpoint.
        return endpoint_to_filename(self.endpoint)
|
||||
|
||||
|
||||
def wrap_records(records: List[Dict], data_path: Sequence[str]):
    """Wrap *records* in nested dicts, one level per key in *data_path*.

    The result mirrors the structure of a real API response body so the
    offline clients can replay it. An empty *data_path* returns *records*
    unchanged (the very same object).
    """
    if not data_path:
        return records
    head, *rest = data_path
    return {head: wrap_records(records, rest)}
|
||||
|
||||
|
||||
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
    """Build an AppConfig suitable for tests.

    Ensures the archive/temp directories exist, then assembles override
    settings (store, DB DSN, API, pipeline and IO paths) and loads them
    through ``AppConfig.load``.

    Args:
        mode: "ONLINE" (case-insensitive) selects the FULL pipeline flow;
            anything else (including None) selects INGEST_ONLY.
        archive_dir: directory holding archived JSON files for ingestion.
        temp_dir: scratch directory for fetch/export/log subtrees.
    """
    archive_dir, temp_dir = Path(archive_dir), Path(temp_dir)
    for directory in (archive_dir, temp_dir):
        directory.mkdir(parents=True, exist_ok=True)

    is_online = str(mode or "").upper() == "ONLINE"
    overrides = {
        "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Shanghai"},
        "db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
        "api": {
            "base_url": "https://api.example.com",
            "token": "test-token",
            "timeout_sec": 3,
            "page_size": 50,
        },
        "pipeline": {
            "flow": "FULL" if is_online else "INGEST_ONLY",
            "fetch_root": str(temp_dir / "json_fetch"),
            "ingest_source_dir": str(archive_dir),
        },
        "io": {
            "export_root": str(temp_dir / "export"),
            "log_root": str(temp_dir / "logs"),
        },
    }
    return AppConfig.load(overrides)
|
||||
|
||||
|
||||
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
    """Write the spec's sample records into *archive_dir* for offline replay.

    The records are wrapped with ``wrap_records`` so the file matches the
    real API response shape, then serialized as UTF-8 JSON under the spec's
    archive filename. Returns the full path of the created file.
    """
    target = Path(archive_dir) / spec.archive_filename
    body = wrap_records(spec.sample_records, spec.data_path)
    target.write_text(json.dumps(body, ensure_ascii=False), encoding="utf-8")
    return target
|
||||
|
||||
|
||||
class FakeCursor:
    """Minimal cursor stub that records SQL/params and supports the context
    manager protocol, used by FakeDBOperations and SCD2Handler."""

    def __init__(self, recorder: List[Dict], db_ops=None):
        # Shared list (owned by the connection stub) collecting executed statements.
        self.recorder = recorder
        # Optional FakeDBOperations: when set, parsed INSERTs are mirrored
        # into its ``upserts`` list by _record_upserts().
        self._db_ops = db_ops
        # Rows accumulated via mogrify() until the next execute() consumes them.
        self._pending_rows: List[Tuple] = []
        # Result set that fetchall() returns for the most recent execute().
        self._fetchall_rows: List[Tuple] = []
        self.rowcount = 0
        # psycopg-compatible attribute; some callers read cursor.connection.encoding.
        self.connection = SimpleNamespace(encoding="UTF8")

    # pylint: disable=unused-argument
    def execute(self, sql: str, params=None):
        """Record the statement; answer information_schema queries with fake metadata."""
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []

        # Handle information_schema queries used by structure-aware writes.
        lowered = sql_text.lower()
        if "from information_schema.columns" in lowered:
            table_name = None
            # NOTE(review): assumes the table name is the second query
            # parameter (after the schema) — confirm against the caller.
            if params and len(params) >= 2:
                table_name = params[1]
            self._fetchall_rows = self._fake_columns(table_name)
            return
        if "from information_schema.table_constraints" in lowered:
            # No constraints are simulated; callers see an empty result set.
            self._fetchall_rows = []
            return

        # If mogrify() buffered rows, treat this execute() as the batched
        # write that consumes them.
        if self._pending_rows:
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            self._pending_rows = []
        else:
            self.rowcount = 0

    def fetchone(self):
        # This stub never produces single-row results.
        return None

    def fetchall(self):
        # Return a copy so callers cannot mutate the stored result set.
        return list(self._fetchall_rows)

    def mogrify(self, template, args):
        # Capture the row values for the upcoming execute(); callers only
        # need some bytes back, so a fixed placeholder is returned.
        self._pending_rows.append(tuple(args))
        return b"(?)"

    def _record_upserts(self, sql_text: str):
        """Parse an INSERT column list and mirror pending rows into db_ops.upserts."""
        if not self._db_ops:
            return
        match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
        if not match:
            return
        columns = [c.strip().strip('"') for c in match.group(1).split(",")]
        rows = []
        for idx, row in enumerate(self._pending_rows):
            # Skip rows whose arity does not match the parsed column list.
            if len(row) != len(columns):
                continue
            row_dict = {}
            for col, val in zip(columns, row):
                # Backfill empty record_index values with the row position.
                if col == "record_index" and val in (None, ""):
                    row_dict[col] = idx
                    continue
                # Values exposing ``.adapted`` (presumably psycopg Json
                # wrappers — TODO confirm) are re-serialized to JSON text.
                if hasattr(val, "adapted"):
                    row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
                else:
                    row_dict[col] = val
            rows.append(row_dict)
        if rows:
            self._db_ops.upserts.append(
                {"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
            )

    @staticmethod
    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
        """Return a fixed (name, data_type, udt_name) catalog regardless of table."""
        return [
            ("id", "bigint", "int8"),
            ("sitegoodsstockid", "bigint", "int8"),
            ("record_index", "integer", "int4"),
            ("content_hash", "text", "text"),
            ("source_file", "text", "text"),
            ("source_endpoint", "text", "text"),
            ("fetched_at", "timestamp with time zone", "timestamptz"),
            ("payload", "jsonb", "jsonb"),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
|
||||
|
||||
|
||||
class FakeConnection:
    """Mimics a psycopg connection just enough for SCD2Handler's ``cursor()``
    needs, caching every executed statement in ``statements``."""

    def __init__(self, db_ops):
        self._db_ops = db_ops
        # Shared recorder: every cursor created from this connection appends
        # the statements it executes here.
        self.statements: List[Dict] = []

    def cursor(self):
        # Hand out a fresh cursor bound to the shared recorder each time.
        return FakeCursor(self.statements, self._db_ops)
|
||||
|
||||
|
||||
class FakeDBOperations:
    """Records batch upsert/execute calls and transaction counters in memory
    so tests never touch a real database."""

    def __init__(self):
        self.upserts: List[Dict] = []
        self.executes: List[Dict] = []
        self.commits = 0
        self.rollbacks = 0
        self.conn = FakeConnection(self)
        # Pre-seeded query results (FIFO) letting tests script what query() returns.
        self.query_results: List[List[Dict]] = []

    @staticmethod
    def _snapshot(sql: str, rows: List[Dict], page_size: int) -> Dict:
        # Build the recorded entry for one batched call, copying each row
        # so later mutations by the caller don't corrupt the record.
        return {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(row) for row in rows],
        }

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Record the upsert and report (inserted, updated) = (len(rows), 0)."""
        self.upserts.append(self._snapshot(sql, rows, page_size))
        return len(rows), 0

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Record a batched statement without executing anything."""
        self.executes.append(self._snapshot(sql, rows, page_size))

    def execute(self, sql: str, params=None):
        """Record a single statement without executing anything."""
        self.executes.append({"sql": sql.strip(), "params": params})

    def query(self, sql: str, params=None):
        """Record the query and pop the next scripted result (or [])."""
        self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
        return self.query_results.pop(0) if self.query_results else []

    def cursor(self):
        return self.conn.cursor()

    def commit(self):
        self.commits += 1

    def rollback(self):
        self.rollbacks += 1
|
||||
|
||||
|
||||
class FakeAPIClient:
    """Online-mode stub API client: serves pre-seeded in-memory records and
    logs each call so tests can verify the parameters tasks pass through."""

    def __init__(self, data_map: Dict[str, List[Dict]]):
        self.data_map = data_map
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield exactly one fake page: (page_no, records, request, response).

        Raises AssertionError when no fixture was registered for *endpoint*.
        """
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.data_map:
            raise AssertionError(f"Missing fixture for endpoint {endpoint}")

        page_records = list(self.data_map[endpoint])
        yield 1, page_records, dict(params or {}), {"data": page_records}

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        """Aggregate every page into (records, page_metadata) like the real client."""
        all_records: List[Dict] = []
        page_meta: List[Dict] = []
        for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
            all_records += page_records
            page_meta.append({"page": page_no, "request": req, "response": resp})
        return all_records, page_meta

    def get_source_hint(self, endpoint: str) -> str | None:
        """In-memory fixtures have no backing file, so there is never a hint."""
        return None
|
||||
|
||||
|
||||
class OfflineAPIClient:
    """Offline-mode API client: replays archived JSON files keyed by endpoint,
    unwrapping data_path/list_key and paginating the resulting list."""

    def __init__(self, file_map: Dict[str, Path]):
        self.file_map = {endpoint: Path(path) for endpoint, path in file_map.items()}
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield (page_no, records, request, full_payload) pages read from the
        archived JSON file registered for *endpoint*.

        Raises AssertionError when no archive was registered for *endpoint*.
        An empty record list still produces a single empty page.
        """
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.file_map:
            raise AssertionError(f"Missing archive for endpoint {endpoint}")

        payload = json.loads(self.file_map[endpoint].read_text(encoding="utf-8"))

        # Walk data_path into the payload; non-dict nodes are left as-is so
        # the final isinstance check can normalize them to a list.
        node = payload
        for segment in data_path:
            if isinstance(node, dict):
                node = node.get(segment, [])
        if list_key and isinstance(node, dict):
            node = node.get(list_key, [])
        records = node if isinstance(node, list) else []

        total = len(records)
        if total == 0:
            # Even an empty archive produces one empty page.
            yield 1, [], dict(params or {}), payload
            return

        offset, page_no = 0, 1
        while offset < total:
            chunk = records[offset : offset + page_size]
            if not chunk:
                # Non-positive page_size yields an empty slice; stop rather
                # than loop forever.
                break
            yield page_no, list(chunk), dict(params or {}), payload
            if len(chunk) < page_size:
                break
            offset += page_size
            page_no += 1

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        """Aggregate every page into (records, page_metadata)."""
        collected: List[Dict] = []
        meta: List[Dict] = []
        for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
            collected.extend(page_records)
            meta.append({"page": page_no, "request": req, "response": resp})
        return collected, meta

    def get_source_hint(self, endpoint: str) -> str | None:
        """Return the archive file path for *endpoint*, or None if unregistered."""
        path = self.file_map.get(endpoint)
        return None if path is None else str(path)
|
||||
|
||||
|
||||
class RealDBOperationsAdapter:

    """Adapter around a real PostgreSQL connection, giving tasks batch-upsert
    plus transaction capabilities with the same surface as the fakes."""

    def __init__(self, dsn: str):
        # dsn: PostgreSQL connection string, typically from TEST_DB_DSN.
        self._conn = DatabaseConnection(dsn)
        self._ops = PgDBOperations(self._conn)
        # SCD2Handler accesses db.conn.cursor(), so expose the underlying connection.
        self.conn = self._conn.conn

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        # Delegate to the real operations layer unchanged.
        return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        return self._ops.batch_execute(sql, rows, page_size=page_size)

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()
|
||||
|
||||
|
||||
@contextmanager
def get_db_operations():
    """Test-only DB operations context.

    Connects to a real PostgreSQL instance when TEST_DB_DSN is set in the
    environment; otherwise falls back to the in-memory FakeDBOperations stub.
    The real adapter is always closed on exit.
    """
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        yield FakeDBOperations()
        return
    adapter = RealDBOperationsAdapter(dsn)
    try:
        yield adapter
    finally:
        adapter.close()
|
||||
|
||||
|
||||
# The 14 standalone ODS tasks were removed as deprecated (they wrote to the
# nonexistent billiards.* schema and were replaced by the generic ODS task).
TASK_SPECS: List[TaskSpec] = []
|
||||
Reference in New Issue
Block a user