Files
Neo-ZQYY/apps/etl/pipelines/feiqiu/tests/unit/task_test_utils.py

393 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
# AI_CHANGELOG
# - 2026-02-14 | 删除废弃的 14 个独立 ODS 任务 TaskSpec 定义和对应 import修复语法错误TASK_SPECS=[] 后残留孤立代码块)
# 直接原因: 之前清理只把 TASK_SPECS 赋值为空列表,但未删除后续 ~370 行废弃 TaskSpec 定义,导致 IndentationError
# 验证: `python -c "import ast; ast.parse(open('tests/unit/task_test_utils.py','utf-8').read()); print('OK')"`
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
from __future__ import annotations
import json
import os
import re
from types import SimpleNamespace
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple, Type
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations as PgDBOperations
from utils.json_store import endpoint_to_filename
DEFAULT_STORE_ID = 2790685415443269
BASE_TS = "2025-01-01 10:00:00"
END_TS = "2025-01-01 12:00:00"
@dataclass(frozen=True)
class TaskSpec:
"""描述单个任务在测试中如何被驱动的元数据包含任务代码、API 路径、数据路径与样例记录。"""
code: str
task_cls: Type
endpoint: str
data_path: Tuple[str, ...]
sample_records: List[Dict]
@property
def archive_filename(self) -> str:
return endpoint_to_filename(self.endpoint)
def wrap_records(records: List[Dict], data_path: Sequence[str]):
"""按照 data_path 逐层包裹记录列表,使其结构与真实 API 返回体一致,方便离线回放。"""
payload = records
for key in reversed(data_path):
payload = {key: payload}
return payload
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
"""构建一份适合测试的 AppConfig自动填充存储、日志、归档目录等参数并保证目录存在。"""
archive_dir = Path(archive_dir)
temp_dir = Path(temp_dir)
archive_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
flow = "FULL" if str(mode or "").upper() == "ONLINE" else "INGEST_ONLY"
overrides = {
"app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Shanghai"},
"db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
"api": {
"base_url": "https://api.example.com",
"token": "test-token",
"timeout_sec": 3,
"page_size": 50,
},
"pipeline": {
"flow": flow,
"fetch_root": str(temp_dir / "json_fetch"),
"ingest_source_dir": str(archive_dir),
},
"io": {
"export_root": str(temp_dir / "export"),
"log_root": str(temp_dir / "logs"),
},
}
return AppConfig.load(overrides)
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
"""将 TaskSpec 的样例数据写入指定归档目录,供离线测试回放使用,并返回生成文件的完整路径。"""
archive_dir = Path(archive_dir)
payload = wrap_records(spec.sample_records, spec.data_path)
file_path = archive_dir / spec.archive_filename
with file_path.open("w", encoding="utf-8") as fp:
json.dump(payload, fp, ensure_ascii=False)
return file_path
class FakeCursor:
"""极简游标桩对象,记录 SQL/参数并支持上下文管理,供 FakeDBOperations 与 SCD2Handler 使用。"""
def __init__(self, recorder: List[Dict], db_ops=None):
self.recorder = recorder
self._db_ops = db_ops
self._pending_rows: List[Tuple] = []
self._fetchall_rows: List[Tuple] = []
self.rowcount = 0
self.connection = SimpleNamespace(encoding="UTF8")
# pylint: disable=unused-argument
def execute(self, sql: str, params=None):
sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
self.recorder.append({"sql": sql_text.strip(), "params": params})
self._fetchall_rows = []
# 处理 information_schema 查询,用于结构感知写入。
lowered = sql_text.lower()
if "from information_schema.columns" in lowered:
table_name = None
if params and len(params) >= 2:
table_name = params[1]
self._fetchall_rows = self._fake_columns(table_name)
return
if "from information_schema.table_constraints" in lowered:
self._fetchall_rows = []
return
if self._pending_rows:
self.rowcount = len(self._pending_rows)
self._record_upserts(sql_text)
self._pending_rows = []
else:
self.rowcount = 0
def fetchone(self):
return None
def fetchall(self):
return list(self._fetchall_rows)
def mogrify(self, template, args):
self._pending_rows.append(tuple(args))
return b"(?)"
def _record_upserts(self, sql_text: str):
if not self._db_ops:
return
match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
if not match:
return
columns = [c.strip().strip('"') for c in match.group(1).split(",")]
rows = []
for idx, row in enumerate(self._pending_rows):
if len(row) != len(columns):
continue
row_dict = {}
for col, val in zip(columns, row):
if col == "record_index" and val in (None, ""):
row_dict[col] = idx
continue
if hasattr(val, "adapted"):
row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
else:
row_dict[col] = val
rows.append(row_dict)
if rows:
self._db_ops.upserts.append(
{"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
)
@staticmethod
def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
return [
("id", "bigint", "int8"),
("sitegoodsstockid", "bigint", "int8"),
("record_index", "integer", "int4"),
("content_hash", "text", "text"),
("source_file", "text", "text"),
("source_endpoint", "text", "text"),
("fetched_at", "timestamp with time zone", "timestamptz"),
("payload", "jsonb", "jsonb"),
]
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
class FakeConnection:
"""仿 psycopg 连接对象,仅满足 SCD2Handler 对 cursor 的最小需求,并缓存执行过的语句。"""
def __init__(self, db_ops):
self.statements: List[Dict] = []
self._db_ops = db_ops
def cursor(self):
return FakeCursor(self.statements, self._db_ops)
class FakeDBOperations:
"""拦截并记录批量 upsert/事务操作,避免触碰真实数据库,同时提供 commit/rollback 计数。"""
def __init__(self):
self.upserts: List[Dict] = []
self.executes: List[Dict] = []
self.commits = 0
self.rollbacks = 0
self.conn = FakeConnection(self)
# 预设查询结果FIFO用于测试中控制数据库返回的行
self.query_results: List[List[Dict]] = []
def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
self.upserts.append(
{
"sql": sql.strip(),
"count": len(rows),
"page_size": page_size,
"rows": [dict(row) for row in rows],
}
)
return len(rows), 0
def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
self.executes.append(
{
"sql": sql.strip(),
"count": len(rows),
"page_size": page_size,
"rows": [dict(row) for row in rows],
}
)
def execute(self, sql: str, params=None):
self.executes.append({"sql": sql.strip(), "params": params})
def query(self, sql: str, params=None):
self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
if self.query_results:
return self.query_results.pop(0)
return []
def cursor(self):
return self.conn.cursor()
def commit(self):
self.commits += 1
def rollback(self):
self.rollbacks += 1
class FakeAPIClient:
"""在线模式使用的伪 API Client直接返回预置的内存数据并记录调用以确保任务参数正确传递。"""
def __init__(self, data_map: Dict[str, List[Dict]]):
self.data_map = data_map
self.calls: List[Dict] = []
# pylint: disable=unused-argument
def iter_paginated(
self,
endpoint: str,
params=None,
page_size: int = 200,
page_field: str = "page",
size_field: str = "limit",
data_path: Tuple[str, ...] = (),
list_key: str | None = None,
):
self.calls.append({"endpoint": endpoint, "params": params})
if endpoint not in self.data_map:
raise AssertionError(f"Missing fixture for endpoint {endpoint}")
records = list(self.data_map[endpoint])
yield 1, records, dict(params or {}), {"data": records}
def get_paginated(self, endpoint: str, params=None, **kwargs):
records = []
pages = []
for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
records.extend(page_records)
pages.append({"page": page_no, "request": req, "response": resp})
return records, pages
def get_source_hint(self, endpoint: str) -> str | None:
return None
class OfflineAPIClient:
"""离线模式专用 API Client根据 endpoint 读取归档 JSON、套入 data_path 并回放列表数据。"""
def __init__(self, file_map: Dict[str, Path]):
self.file_map = {k: Path(v) for k, v in file_map.items()}
self.calls: List[Dict] = []
# pylint: disable=unused-argument
def iter_paginated(
self,
endpoint: str,
params=None,
page_size: int = 200,
page_field: str = "page",
size_field: str = "limit",
data_path: Tuple[str, ...] = (),
list_key: str | None = None,
):
self.calls.append({"endpoint": endpoint, "params": params})
if endpoint not in self.file_map:
raise AssertionError(f"Missing archive for endpoint {endpoint}")
with self.file_map[endpoint].open("r", encoding="utf-8") as fp:
payload = json.load(fp)
data = payload
for key in data_path:
if isinstance(data, dict):
data = data.get(key, [])
if list_key and isinstance(data, dict):
data = data.get(list_key, [])
if not isinstance(data, list):
data = []
total = len(data)
start = 0
page = 1
while start < total or (start == 0 and total == 0):
chunk = data[start : start + page_size]
if not chunk and total != 0:
break
yield page, list(chunk), dict(params or {}), payload
if len(chunk) < page_size:
break
start += page_size
page += 1
def get_paginated(self, endpoint: str, params=None, **kwargs):
records = []
pages = []
for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
records.extend(page_records)
pages.append({"page": page_no, "request": req, "response": resp})
return records, pages
def get_source_hint(self, endpoint: str) -> str | None:
if endpoint not in self.file_map:
return None
return str(self.file_map[endpoint])
class RealDBOperationsAdapter:
"""连接真实 PostgreSQL 的适配器,为任务提供 batch_upsert + 事务能力。"""
def __init__(self, dsn: str):
self._conn = DatabaseConnection(dsn)
self._ops = PgDBOperations(self._conn)
# SCD2Handler 会访问 db.conn.cursor(),因此暴露底层连接
self.conn = self._conn.conn
def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)
def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
return self._ops.batch_execute(sql, rows, page_size=page_size)
def commit(self):
self._conn.commit()
def rollback(self):
self._conn.rollback()
def close(self):
self._conn.close()
@contextmanager
def get_db_operations():
"""
测试专用的 DB 操作上下文:
- 若设置 TEST_DB_DSN则连接真实 PostgreSQL
- 否则回退到 FakeDBOperations内存桩
"""
dsn = os.environ.get("TEST_DB_DSN")
if dsn:
adapter = RealDBOperationsAdapter(dsn)
try:
yield adapter
finally:
adapter.close()
else:
fake = FakeDBOperations()
yield fake
# 14 个独立 ODS 任务已废弃删除(写入不存在的 billiards.* schema已被通用 ODS 任务替代)
TASK_SPECS: List[TaskSpec] = []