# -*- coding: utf-8 -*-
# AI_CHANGELOG
# - 2026-02-14 | Removed the 14 deprecated standalone ODS task TaskSpec definitions and their imports; fixed a syntax error (orphaned code left behind after TASK_SPECS = []).
#   Direct cause: the previous cleanup only reset TASK_SPECS to an empty list but left ~370 lines of obsolete TaskSpec definitions, causing an IndentationError.
#   Verified with: `python -c "import ast; ast.parse(open('tests/unit/task_test_utils.py','utf-8').read()); print('OK')"`
"""Shared helpers for ETL task tests: fake data, fake clients and config utilities for both online and offline modes."""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import re
|
||
from types import SimpleNamespace
|
||
from contextlib import contextmanager
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Dict, List, Sequence, Tuple, Type
|
||
|
||
from config.settings import AppConfig
|
||
from database.connection import DatabaseConnection
|
||
from database.operations import DatabaseOperations as PgDBOperations
|
||
from utils.json_store import endpoint_to_filename
|
||
|
||
# Store id used by all test fixtures.
DEFAULT_STORE_ID = 2790685415443269
# Default time-window boundaries passed to tasks in tests.
BASE_TS = "2025-01-01 10:00:00"
END_TS = "2025-01-01 12:00:00"
||
|
||
@dataclass(frozen=True)
class TaskSpec:
    """Metadata describing how a single task is driven in tests: task code,
    API endpoint, response data path and sample records."""

    code: str  # task identifier code
    task_cls: Type  # task class under test
    endpoint: str  # API endpoint the task pulls from
    data_path: Tuple[str, ...]  # nesting keys wrapping the record list in the response body
    sample_records: List[Dict]  # sample records replayed by the fixtures

    @property
    def archive_filename(self) -> str:
        # Archive file name derived from the endpoint (via utils.json_store).
        return endpoint_to_filename(self.endpoint)
||
|
||
def wrap_records(records: List[Dict], data_path: Sequence[str]):
    """Nest *records* under each key of *data_path* (outermost key first) so
    the result mirrors a real API response body for offline replay."""
    wrapped = records
    for level_key in list(data_path)[::-1]:
        wrapped = {level_key: wrapped}
    return wrapped
|
||
|
||
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
    """Build an AppConfig suited for tests: fills in store, DB, API, pipeline
    and IO settings, and makes sure the archive/temp directories exist."""
    archive_path = Path(archive_dir)
    scratch_path = Path(temp_dir)
    for directory in (archive_path, scratch_path):
        directory.mkdir(parents=True, exist_ok=True)

    # ONLINE mode runs the full pipeline; everything else ingests only.
    is_online = str(mode or "").upper() == "ONLINE"
    overrides = {
        "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Shanghai"},
        "db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
        "api": {
            "base_url": "https://api.example.com",
            "token": "test-token",
            "timeout_sec": 3,
            "page_size": 50,
        },
        "pipeline": {
            "flow": "FULL" if is_online else "INGEST_ONLY",
            "fetch_root": str(scratch_path / "json_fetch"),
            "ingest_source_dir": str(archive_path),
        },
        "io": {
            "export_root": str(scratch_path / "export"),
            "log_root": str(scratch_path / "logs"),
        },
    }
    return AppConfig.load(overrides)
|
||
|
||
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
    """Write the spec's sample records (wrapped per its data_path) into
    *archive_dir* for offline replay; return the full path of the file."""
    target = Path(archive_dir) / spec.archive_filename
    wrapped = wrap_records(spec.sample_records, spec.data_path)
    target.write_text(json.dumps(wrapped, ensure_ascii=False), encoding="utf-8")
    return target
|
||
|
||
class FakeCursor:
    """Minimal cursor stub: records every SQL/params pair into *recorder* and
    supports context-manager use, for FakeDBOperations and SCD2Handler."""

    def __init__(self, recorder: List[Dict], db_ops=None):
        self.recorder = recorder
        self._db_ops = db_ops  # optional FakeDBOperations to mirror upserts into
        self._pending_rows: List[Tuple] = []  # rows staged via mogrify() before execute()
        self._fetchall_rows: List[Tuple] = []  # canned result for the next fetchall()
        self.rowcount = 0
        # SimpleNamespace so callers can read cursor.connection.encoding
        self.connection = SimpleNamespace(encoding="UTF8")

    # pylint: disable=unused-argument
    def execute(self, sql: str, params=None):
        # Normalize bytes SQL to str so recording/matching below works uniformly.
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []

        # Answer information_schema queries used for schema-aware writes.
        lowered = sql_text.lower()
        if "from information_schema.columns" in lowered:
            table_name = None
            if params and len(params) >= 2:
                table_name = params[1]
            self._fetchall_rows = self._fake_columns(table_name)
            return
        if "from information_schema.table_constraints" in lowered:
            self._fetchall_rows = []
            return

        if self._pending_rows:
            # A batched insert: report the staged row count and mirror the
            # upsert into the owning FakeDBOperations for assertions.
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            self._pending_rows = []
        else:
            self.rowcount = 0

    def fetchone(self):
        return None

    def fetchall(self):
        return list(self._fetchall_rows)

    def mogrify(self, template, args):
        # Stage the row values; the placeholder bytes satisfy callers that
        # join mogrified fragments into one statement.
        self._pending_rows.append(tuple(args))
        return b"(?)"

    def _record_upserts(self, sql_text: str):
        """Parse column names out of an INSERT ... VALUES statement and append
        the staged rows (as dicts) to the owning FakeDBOperations.upserts."""
        if not self._db_ops:
            return
        match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
        if not match:
            return
        columns = [c.strip().strip('"') for c in match.group(1).split(",")]
        rows = []
        for idx, row in enumerate(self._pending_rows):
            if len(row) != len(columns):
                continue  # skip rows that don't line up with the parsed columns
            row_dict = {}
            for col, val in zip(columns, row):
                if col == "record_index" and val in (None, ""):
                    # Fill a missing record_index with the row's position.
                    row_dict[col] = idx
                    continue
                if hasattr(val, "adapted"):
                    # Values exposing .adapted (e.g. JSON adapter wrappers)
                    # are stored as their JSON-serialized payload.
                    row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
                else:
                    row_dict[col] = val
            rows.append(row_dict)
        if rows:
            self._db_ops.upserts.append(
                {"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
            )

    @staticmethod
    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
        # Fixed (column_name, data_type, udt_name) triples returned for any table.
        return [
            ("id", "bigint", "int8"),
            ("sitegoodsstockid", "bigint", "int8"),
            ("record_index", "integer", "int4"),
            ("content_hash", "text", "text"),
            ("source_file", "text", "text"),
            ("source_endpoint", "text", "text"),
            ("fetched_at", "timestamp with time zone", "timestamptz"),
            ("payload", "jsonb", "jsonb"),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
|
||
|
||
class FakeConnection:
    """Connection stand-in exposing just enough for SCD2Handler: ``cursor()``
    plus a cached list of every statement executed through its cursors."""

    def __init__(self, db_ops):
        self._db_ops = db_ops
        self.statements: List[Dict] = []

    def cursor(self):
        # Every cursor shares the same statements list so tests see all calls.
        return FakeCursor(self.statements, self._db_ops)
|
||
|
||
class FakeDBOperations:
    """Records batch upsert/execute calls and transaction counts in memory so
    tests never touch a real database."""

    def __init__(self):
        self.upserts: List[Dict] = []
        self.executes: List[Dict] = []
        self.commits = 0
        self.rollbacks = 0
        self.conn = FakeConnection(self)
        # Pre-seeded query results (FIFO) controlling what query() returns.
        self.query_results: List[List[Dict]] = []

    @staticmethod
    def _snapshot(sql: str, rows: List[Dict], page_size: int) -> Dict:
        # Capture a call record with defensive copies of the row dicts.
        return {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(row) for row in rows],
        }

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        self.upserts.append(self._snapshot(sql, rows, page_size))
        # Pretend every row was inserted and none updated.
        return len(rows), 0

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        self.executes.append(self._snapshot(sql, rows, page_size))

    def execute(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params})

    def query(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
        return self.query_results.pop(0) if self.query_results else []

    def cursor(self):
        return self.conn.cursor()

    def commit(self):
        self.commits += 1

    def rollback(self):
        self.rollbacks += 1
|
||
|
||
class FakeAPIClient:
|
||
"""在线模式使用的伪 API Client,直接返回预置的内存数据并记录调用,以确保任务参数正确传递。"""
|
||
|
||
def __init__(self, data_map: Dict[str, List[Dict]]):
|
||
self.data_map = data_map
|
||
self.calls: List[Dict] = []
|
||
|
||
# pylint: disable=unused-argument
|
||
def iter_paginated(
|
||
self,
|
||
endpoint: str,
|
||
params=None,
|
||
page_size: int = 200,
|
||
page_field: str = "page",
|
||
size_field: str = "limit",
|
||
data_path: Tuple[str, ...] = (),
|
||
list_key: str | None = None,
|
||
):
|
||
self.calls.append({"endpoint": endpoint, "params": params})
|
||
if endpoint not in self.data_map:
|
||
raise AssertionError(f"Missing fixture for endpoint {endpoint}")
|
||
|
||
records = list(self.data_map[endpoint])
|
||
yield 1, records, dict(params or {}), {"data": records}
|
||
|
||
def get_paginated(self, endpoint: str, params=None, **kwargs):
|
||
records = []
|
||
pages = []
|
||
for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
|
||
records.extend(page_records)
|
||
pages.append({"page": page_no, "request": req, "response": resp})
|
||
return records, pages
|
||
|
||
def get_source_hint(self, endpoint: str) -> str | None:
|
||
return None
|
||
|
||
|
||
class OfflineAPIClient:
|
||
"""离线模式专用 API Client,根据 endpoint 读取归档 JSON、套入 data_path 并回放列表数据。"""
|
||
|
||
def __init__(self, file_map: Dict[str, Path]):
|
||
self.file_map = {k: Path(v) for k, v in file_map.items()}
|
||
self.calls: List[Dict] = []
|
||
|
||
# pylint: disable=unused-argument
|
||
def iter_paginated(
|
||
self,
|
||
endpoint: str,
|
||
params=None,
|
||
page_size: int = 200,
|
||
page_field: str = "page",
|
||
size_field: str = "limit",
|
||
data_path: Tuple[str, ...] = (),
|
||
list_key: str | None = None,
|
||
):
|
||
self.calls.append({"endpoint": endpoint, "params": params})
|
||
if endpoint not in self.file_map:
|
||
raise AssertionError(f"Missing archive for endpoint {endpoint}")
|
||
|
||
with self.file_map[endpoint].open("r", encoding="utf-8") as fp:
|
||
payload = json.load(fp)
|
||
|
||
data = payload
|
||
for key in data_path:
|
||
if isinstance(data, dict):
|
||
data = data.get(key, [])
|
||
|
||
if list_key and isinstance(data, dict):
|
||
data = data.get(list_key, [])
|
||
|
||
if not isinstance(data, list):
|
||
data = []
|
||
|
||
total = len(data)
|
||
start = 0
|
||
page = 1
|
||
while start < total or (start == 0 and total == 0):
|
||
chunk = data[start : start + page_size]
|
||
if not chunk and total != 0:
|
||
break
|
||
yield page, list(chunk), dict(params or {}), payload
|
||
if len(chunk) < page_size:
|
||
break
|
||
start += page_size
|
||
page += 1
|
||
|
||
def get_paginated(self, endpoint: str, params=None, **kwargs):
|
||
records = []
|
||
pages = []
|
||
for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
|
||
records.extend(page_records)
|
||
pages.append({"page": page_no, "request": req, "response": resp})
|
||
return records, pages
|
||
|
||
def get_source_hint(self, endpoint: str) -> str | None:
|
||
if endpoint not in self.file_map:
|
||
return None
|
||
return str(self.file_map[endpoint])
|
||
|
||
|
||
class RealDBOperationsAdapter:
    """Adapter over a real PostgreSQL connection, giving tasks batch_upsert
    plus transaction control."""

    def __init__(self, dsn: str):
        self._conn = DatabaseConnection(dsn)
        self._ops = PgDBOperations(self._conn)
        # SCD2Handler accesses db.conn.cursor(), so expose the raw connection.
        self.conn = self._conn.conn

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        # Delegate to the real DatabaseOperations implementation.
        return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        return self._ops.batch_execute(sql, rows, page_size=page_size)

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()
|
||
|
||
@contextmanager
def get_db_operations():
    """Test-only DB operations context.

    Connects to a real PostgreSQL instance when TEST_DB_DSN is set; otherwise
    falls back to the in-memory FakeDBOperations stub.
    """
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        yield FakeDBOperations()
        return
    adapter = RealDBOperationsAdapter(dsn)
    try:
        yield adapter
    finally:
        # Always release the real connection, even if the body raised.
        adapter.close()
|
||
|
||
# The 14 standalone ODS tasks were removed as deprecated (they wrote to the
# nonexistent billiards.* schema and were superseded by the generic ODS task).
TASK_SPECS: List[TaskSpec] = []