在准备环境前提交此次全部更改。
This commit is contained in:
59
apps/etl/connectors/feiqiu/tests/README.md
Normal file
59
apps/etl/connectors/feiqiu/tests/README.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# tests/ — 测试套件
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
tests/
|
||||
├── unit/ # 单元测试(FakeDB/FakeAPI,无需真实数据库)
|
||||
│ ├── task_test_utils.py # 测试工具:FakeDBOperations、FakeAPIClient、OfflineAPIClient、TaskSpec
|
||||
│ ├── test_ods_tasks.py # ODS 任务在线/离线模式测试
|
||||
│ ├── test_cli_args.py # CLI 参数解析测试
|
||||
│ ├── test_config.py # 配置管理测试
|
||||
│ ├── test_e2e_flow.py # 端到端流程测试(CLI → FlowRunner → TaskExecutor)
|
||||
│ ├── test_task_registry.py # 任务注册表测试
|
||||
│ ├── test_*_properties.py # 属性测试(hypothesis)
|
||||
│ └── test_audit_*.py # 仓库审计相关测试
|
||||
└── integration/ # 集成测试(需要真实数据库)
|
||||
├── test_database.py # 数据库连接与操作测试
|
||||
└── test_index_tasks.py # 指数任务集成测试
|
||||
```
|
||||
|
||||
## 运行测试
|
||||
|
||||
```bash
|
||||
# 安装测试依赖
|
||||
pip install pytest hypothesis
|
||||
|
||||
# 全部单元测试
|
||||
pytest tests/unit
|
||||
|
||||
# 指定测试文件
|
||||
pytest tests/unit/test_ods_tasks.py
|
||||
|
||||
# 按关键字过滤
|
||||
pytest tests/unit -k "online"
|
||||
|
||||
# 集成测试(需要设置 TEST_DB_DSN)
|
||||
TEST_DB_DSN="postgresql://user:pass@host:5432/db" pytest tests/integration
|
||||
|
||||
# 查看详细输出
|
||||
pytest tests/unit -v --tb=short
|
||||
```
|
||||
|
||||
## 测试工具(task_test_utils.py)
|
||||
|
||||
单元测试通过 `tests/unit/task_test_utils.py` 提供的桩对象避免依赖真实数据库和 API:
|
||||
|
||||
- `FakeDBOperations` — 拦截并记录 upsert/execute/commit/rollback,不触碰真实数据库
|
||||
- `FakeAPIClient` — 在线模式桩,直接返回预置的内存数据
|
||||
- `OfflineAPIClient` — 离线模式桩,从归档 JSON 文件回放数据
|
||||
- `TaskSpec` — 描述任务测试元数据(任务代码、端点、数据路径、样例记录)
|
||||
- `create_test_config()` — 构建测试用 `AppConfig`
|
||||
- `dump_offline_payload()` — 将样例数据写入归档目录供离线测试使用
|
||||
|
||||
## 编写新测试
|
||||
|
||||
- 单元测试放在 `tests/unit/`,文件名 `test_*.py`
|
||||
- 使用 `FakeDBOperations` 和 `FakeAPIClient` 避免外部依赖
|
||||
- 属性测试使用 `hypothesis`,文件名以 `_properties.py` 结尾
|
||||
- 集成测试放在 `tests/integration/`,通过 `TEST_DB_DSN` 环境变量控制是否执行
|
||||
0
apps/etl/connectors/feiqiu/tests/__init__.py
Normal file
0
apps/etl/connectors/feiqiu/tests/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库集成测试"""
|
||||
import pytest
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations
|
||||
|
||||
# 注意:这些测试需要实际的数据库连接
|
||||
# 在CI/CD环境中应使用测试数据库
|
||||
|
||||
@pytest.fixture
def db_connection():
    """Yield a live DatabaseConnection for integration tests.

    The DSN is read from the TEST_DB_DSN environment variable; when it is
    not set, the test is skipped so the suite can run without a database.
    Teardown closes the connection after the test finishes.
    """
    # Read the test-database DSN from the environment.
    import os
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        pytest.skip("未配置测试数据库")

    conn = DatabaseConnection(dsn)
    yield conn
    # Teardown: close the connection once the test body has run.
    conn.close()
|
||||
|
||||
def test_database_query(db_connection):
|
||||
"""测试数据库查询"""
|
||||
result = db_connection.query("SELECT 1 AS test")
|
||||
assert len(result) == 1
|
||||
assert result[0]["test"] == 1
|
||||
|
||||
def test_database_operations(db_connection):
    """Placeholder: only verifies that DatabaseOperations can be constructed.

    TODO: add real assertions against the operations layer.
    """
    ops = DatabaseOperations(db_connection)
    # Add actual test cases here.
    pass
|
||||
238
apps/etl/connectors/feiqiu/tests/integration/test_index_tasks.py
Normal file
238
apps/etl/connectors/feiqiu/tests/integration/test_index_tasks.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG [2026-02-13] 移除 dws_member_assistant_intimacy 表存在性检查
|
||||
"""Smoke test scripts for WBI/NCI index tasks."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
# Make the project root importable when this script is run directly
# (the file lives two levels below the root).
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from tasks.dws.index import NewconvIndexTask, WinbackIndexTask


# Console logging for the smoke run; tasks receive this logger directly.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("test_index_tasks")
|
||||
|
||||
|
||||
def _make_db() -> tuple[AppConfig, DatabaseConnection, DatabaseOperations]:
    """Build (config, connection, operations) from the application's own config.

    The caller is responsible for closing the returned connection.
    """
    config = AppConfig.load()
    db_conn = DatabaseConnection(config.config["db"]["dsn"])
    db = DatabaseOperations(db_conn)
    return config, db_conn, db
|
||||
|
||||
|
||||
def _dict_rows(rows) -> List[Dict]:
|
||||
return [dict(r) for r in (rows or [])]
|
||||
|
||||
|
||||
def _fmt(value, digits: int = 2) -> str:
|
||||
if value is None:
|
||||
return "-"
|
||||
if isinstance(value, (int, float)):
|
||||
return f"{value:.{digits}f}"
|
||||
return str(value)
|
||||
|
||||
|
||||
def _check_required_tables() -> None:
    """Fail fast if any dws table required by the index tasks is missing.

    Raises:
        RuntimeError: listing the missing table names.
    """
    _, db_conn, db = _make_db()
    try:
        sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'dws'
          AND table_name IN (
            'cfg_index_parameters',
            'dws_member_winback_index',
            'dws_member_newconv_index'
          )
        """
        rows = _dict_rows(db.query(sql))
        existing = {r["table_name"] for r in rows}
        required = {
            "cfg_index_parameters",
            "dws_member_winback_index",
            "dws_member_newconv_index",
        }
        missing = sorted(required - existing)
        if missing:
            raise RuntimeError(f"Missing required tables: {', '.join(missing)}")
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
def test_winback_index() -> Dict:
    """Run the WBI (winback index) task end-to-end and log summary statistics.

    Returns the task's result dict (expected to carry at least a "status" key).
    """
    logger.info("=" * 80)
    logger.info("Run WBI task")
    logger.info("=" * 80)

    config, db_conn, db = _make_db()
    try:
        task = WinbackIndexTask(config, db, None, logger)
        result = task.execute(None)
        logger.info("WBI result: %s", result)

        # Only inspect the output table when the task reported success.
        if result.get("status") == "success":
            stats_sql = """
                SELECT
                    COUNT(*) AS total_count,
                    ROUND(AVG(display_score)::numeric, 2) AS avg_display,
                    ROUND(MIN(display_score)::numeric, 2) AS min_display,
                    ROUND(MAX(display_score)::numeric, 2) AS max_display,
                    ROUND(AVG(raw_score)::numeric, 4) AS avg_raw,
                    ROUND(AVG(overdue_old)::numeric, 4) AS avg_overdue,
                    ROUND(AVG(drop_old)::numeric, 4) AS avg_drop,
                    ROUND(AVG(recharge_old)::numeric, 4) AS avg_recharge,
                    ROUND(AVG(value_old)::numeric, 4) AS avg_value,
                    ROUND(AVG(t_v)::numeric, 2) AS avg_t_v
                FROM dws.dws_member_winback_index
            """
            stats_rows = _dict_rows(db.query(stats_sql))
            if stats_rows:
                s = stats_rows[0]
                logger.info(
                    "WBI stats | total=%s, display(avg/min/max)=%s/%s/%s, raw_avg=%s, overdue=%s, drop=%s, recharge=%s, value=%s, t_v=%s",
                    s.get("total_count"),
                    _fmt(s.get("avg_display")),
                    _fmt(s.get("min_display")),
                    _fmt(s.get("max_display")),
                    _fmt(s.get("avg_raw"), 4),
                    _fmt(s.get("avg_overdue"), 4),
                    _fmt(s.get("avg_drop"), 4),
                    _fmt(s.get("avg_recharge"), 4),
                    _fmt(s.get("avg_value"), 4),
                    _fmt(s.get("avg_t_v"), 2),
                )

            # Log the five members with the highest display score.
            top_sql = """
                SELECT member_id, display_score, raw_score, t_v, visits_14d, sv_balance
                FROM dws.dws_member_winback_index
                ORDER BY display_score DESC NULLS LAST
                LIMIT 5
            """
            for i, r in enumerate(_dict_rows(db.query(top_sql)), 1):
                logger.info(
                    "WBI TOP%d | member=%s, display=%s, raw=%s, t_v=%s, visits_14d=%s, sv_balance=%s",
                    i,
                    r.get("member_id"),
                    _fmt(r.get("display_score")),
                    _fmt(r.get("raw_score"), 4),
                    _fmt(r.get("t_v"), 2),
                    _fmt(r.get("visits_14d"), 0),
                    _fmt(r.get("sv_balance"), 2),
                )

        return result
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
def test_newconv_index() -> Dict:
    """Run the NCI (new-conversion index) task end-to-end and log summary statistics.

    Returns the task's result dict (expected to carry at least a "status" key).
    """
    logger.info("=" * 80)
    logger.info("Run NCI task")
    logger.info("=" * 80)

    config, db_conn, db = _make_db()
    try:
        task = NewconvIndexTask(config, db, None, logger)
        result = task.execute(None)
        logger.info("NCI result: %s", result)

        # Only inspect the output table when the task reported success.
        if result.get("status") == "success":
            stats_sql = """
                SELECT
                    COUNT(*) AS total_count,
                    ROUND(AVG(display_score)::numeric, 2) AS avg_display,
                    ROUND(MIN(display_score)::numeric, 2) AS min_display,
                    ROUND(MAX(display_score)::numeric, 2) AS max_display,
                    ROUND(AVG(display_score_welcome)::numeric, 2) AS avg_display_welcome,
                    ROUND(AVG(display_score_convert)::numeric, 2) AS avg_display_convert,
                    ROUND(AVG(raw_score)::numeric, 4) AS avg_raw,
                    ROUND(AVG(raw_score_welcome)::numeric, 4) AS avg_raw_welcome,
                    ROUND(AVG(raw_score_convert)::numeric, 4) AS avg_raw_convert,
                    ROUND(AVG(need_new)::numeric, 4) AS avg_need,
                    ROUND(AVG(salvage_new)::numeric, 4) AS avg_salvage,
                    ROUND(AVG(recharge_new)::numeric, 4) AS avg_recharge,
                    ROUND(AVG(value_new)::numeric, 4) AS avg_value,
                    ROUND(AVG(welcome_new)::numeric, 4) AS avg_welcome,
                    ROUND(AVG(t_v)::numeric, 2) AS avg_t_v
                FROM dws.dws_member_newconv_index
            """
            stats_rows = _dict_rows(db.query(stats_sql))
            if stats_rows:
                s = stats_rows[0]
                logger.info(
                    "NCI stats | total=%s, display(avg/min/max)=%s/%s/%s, display_welcome=%s, display_convert=%s, raw_avg=%s, raw_welcome=%s, raw_convert=%s",
                    s.get("total_count"),
                    _fmt(s.get("avg_display")),
                    _fmt(s.get("min_display")),
                    _fmt(s.get("max_display")),
                    _fmt(s.get("avg_display_welcome")),
                    _fmt(s.get("avg_display_convert")),
                    _fmt(s.get("avg_raw"), 4),
                    _fmt(s.get("avg_raw_welcome"), 4),
                    _fmt(s.get("avg_raw_convert"), 4),
                )
                logger.info(
                    "NCI components | need=%s, salvage=%s, recharge=%s, value=%s, welcome=%s, t_v=%s",
                    _fmt(s.get("avg_need"), 4),
                    _fmt(s.get("avg_salvage"), 4),
                    _fmt(s.get("avg_recharge"), 4),
                    _fmt(s.get("avg_value"), 4),
                    _fmt(s.get("avg_welcome"), 4),
                    _fmt(s.get("avg_t_v"), 2),
                )

            # Log the five members with the highest display score.
            top_sql = """
                SELECT member_id, display_score, display_score_welcome, display_score_convert,
                       raw_score, raw_score_welcome, raw_score_convert, t_v, visits_14d
                FROM dws.dws_member_newconv_index
                ORDER BY display_score DESC NULLS LAST
                LIMIT 5
            """
            for i, r in enumerate(_dict_rows(db.query(top_sql)), 1):
                logger.info(
                    "NCI TOP%d | member=%s, nci=%s (welcome=%s, convert=%s), raw=%s (w=%s,c=%s), t_v=%s, visits_14d=%s",
                    i,
                    r.get("member_id"),
                    _fmt(r.get("display_score")),
                    _fmt(r.get("display_score_welcome")),
                    _fmt(r.get("display_score_convert")),
                    _fmt(r.get("raw_score"), 4),
                    _fmt(r.get("raw_score_welcome"), 4),
                    _fmt(r.get("raw_score_convert"), 4),
                    _fmt(r.get("t_v"), 2),
                    _fmt(r.get("visits_14d"), 0),
                )

        return result
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
|
||||
|
||||
def main() -> None:
    """Verify required tables exist, run both index tasks, and log a summary."""
    _check_required_tables()

    results = {
        "WBI": test_winback_index(),
        "NCI": test_newconv_index(),
    }

    logger.info("=" * 80)
    logger.info("Test complete")
    logger.info("WBI=%s, NCI=%s", results["WBI"].get("status"), results["NCI"].get("status"))
    logger.info("=" * 80)


if __name__ == "__main__":
    main()
|
||||
0
apps/etl/connectors/feiqiu/tests/unit/__init__.py
Normal file
0
apps/etl/connectors/feiqiu/tests/unit/__init__.py
Normal file
392
apps/etl/connectors/feiqiu/tests/unit/task_test_utils.py
Normal file
392
apps/etl/connectors/feiqiu/tests/unit/task_test_utils.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-14 | 删除废弃的 14 个独立 ODS 任务 TaskSpec 定义和对应 import;修复语法错误(TASK_SPECS=[] 后残留孤立代码块)
|
||||
# 直接原因: 之前清理只把 TASK_SPECS 赋值为空列表,但未删除后续 ~370 行废弃 TaskSpec 定义,导致 IndentationError
|
||||
# 验证: `python -c "import ast; ast.parse(open('tests/unit/task_test_utils.py','utf-8').read()); print('OK')"`
|
||||
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Tuple, Type
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations as PgDBOperations
|
||||
from utils.json_store import endpoint_to_filename
|
||||
|
||||
DEFAULT_STORE_ID = 2790685415443269
|
||||
BASE_TS = "2025-01-01 10:00:00"
|
||||
END_TS = "2025-01-01 12:00:00"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TaskSpec:
    """Metadata describing how a single task is driven in tests.

    Bundles the task code, task class, API endpoint, the path to the record
    list inside the response payload, and sample records for fixtures.
    """

    code: str                    # task registry code
    task_cls: Type               # task class under test
    endpoint: str                # API endpoint the task calls
    data_path: Tuple[str, ...]   # keys leading to the record list in the payload
    sample_records: List[Dict]   # records used to build online/offline fixtures

    @property
    def archive_filename(self) -> str:
        # File name this endpoint's payload is archived under.
        return endpoint_to_filename(self.endpoint)
|
||||
|
||||
|
||||
def wrap_records(records: List[Dict], data_path: Sequence[str]):
    """Nest *records* under the keys of *data_path* (outermost key first),
    mirroring the structure of a real API response body for offline replay."""
    wrapped = records
    for key in list(data_path)[::-1]:
        wrapped = {key: wrapped}
    return wrapped
|
||||
|
||||
|
||||
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
    """Build an AppConfig suitable for tests.

    Fills in storage, API, logging and archive-directory settings, and makes
    sure both directories exist before the config is loaded.
    """
    archive_dir = Path(archive_dir)
    temp_dir = Path(temp_dir)
    archive_dir.mkdir(parents=True, exist_ok=True)
    temp_dir.mkdir(parents=True, exist_ok=True)

    # ONLINE runs the full pipeline; any other mode only ingests archived data.
    flow = "FULL" if str(mode or "").upper() == "ONLINE" else "INGEST_ONLY"
    overrides = {
        "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Shanghai"},
        "db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
        "api": {
            "base_url": "https://api.example.com",
            "token": "test-token",
            "timeout_sec": 3,
            "page_size": 50,
        },
        "pipeline": {
            "flow": flow,
            "fetch_root": str(temp_dir / "json_fetch"),
            "ingest_source_dir": str(archive_dir),
        },
        "io": {
            "export_root": str(temp_dir / "export"),
            "log_root": str(temp_dir / "logs"),
        },
    }
    return AppConfig.load(overrides)
|
||||
|
||||
|
||||
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
    """Write the spec's sample records into the archive directory for offline
    replay and return the full path of the generated file."""
    target = Path(archive_dir) / spec.archive_filename
    body = wrap_records(spec.sample_records, spec.data_path)
    target.write_text(json.dumps(body, ensure_ascii=False), encoding="utf-8")
    return target
|
||||
|
||||
|
||||
class FakeCursor:
    """Minimal cursor stub that records SQL/params and supports the
    context-manager protocol, for use by FakeDBOperations and SCD2Handler."""

    def __init__(self, recorder: List[Dict], db_ops=None):
        self.recorder = recorder                # shared list receiving {"sql", "params"} entries
        self._db_ops = db_ops                   # owning FakeDBOperations (receives upsert records)
        self._pending_rows: List[Tuple] = []    # rows buffered via mogrify() until the next execute()
        self._fetchall_rows: List[Tuple] = []   # canned result returned by fetchall()
        self.rowcount = 0
        # Exposes .connection.encoding for callers that inspect it
        # (presumably psycopg batch helpers — NOTE(review): assumption, confirm against callers).
        self.connection = SimpleNamespace(encoding="UTF8")

    # pylint: disable=unused-argument
    def execute(self, sql: str, params=None):
        """Record the statement; answer information_schema probes with fake metadata."""
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []

        # Handle information_schema queries used for structure-aware writes.
        lowered = sql_text.lower()
        if "from information_schema.columns" in lowered:
            table_name = None
            if params and len(params) >= 2:
                # assumes params are (schema, table) — table name is the second item; TODO confirm
                table_name = params[1]
            self._fetchall_rows = self._fake_columns(table_name)
            return
        if "from information_schema.table_constraints" in lowered:
            self._fetchall_rows = []
            return

        if self._pending_rows:
            # An execute() following mogrify() calls is treated as the batched INSERT flush.
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            self._pending_rows = []
        else:
            self.rowcount = 0

    def fetchone(self):
        """Single-row results are not simulated."""
        return None

    def fetchall(self):
        return list(self._fetchall_rows)

    def mogrify(self, template, args):
        """Buffer one row of values; returns a dummy SQL fragment."""
        self._pending_rows.append(tuple(args))
        return b"(?)"

    def _record_upserts(self, sql_text: str):
        """Parse the INSERT column list and report buffered rows to the owning db_ops."""
        if not self._db_ops:
            return
        match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
        if not match:
            return
        columns = [c.strip().strip('"') for c in match.group(1).split(",")]
        rows = []
        for idx, row in enumerate(self._pending_rows):
            if len(row) != len(columns):
                continue  # skip rows whose arity does not match the column list
            row_dict = {}
            for col, val in zip(columns, row):
                if col == "record_index" and val in (None, ""):
                    row_dict[col] = idx  # auto-number missing record_index values
                    continue
                if hasattr(val, "adapted"):
                    # Values carrying an .adapted attribute are serialized to a JSON
                    # string (presumably psycopg Json wrappers — confirm against callers).
                    row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
                else:
                    row_dict[col] = val
            rows.append(row_dict)
        if rows:
            self._db_ops.upserts.append(
                {"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
            )

    @staticmethod
    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
        """Fixed (name, data_type, udt_name) metadata returned for every table probe."""
        return [
            ("id", "bigint", "int8"),
            ("sitegoodsstockid", "bigint", "int8"),
            ("record_index", "integer", "int4"),
            ("content_hash", "text", "text"),
            ("source_file", "text", "text"),
            ("source_endpoint", "text", "text"),
            ("fetched_at", "timestamp with time zone", "timestamptz"),
            ("payload", "jsonb", "jsonb"),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
|
||||
|
||||
|
||||
class FakeConnection:
    """psycopg-like connection stub: satisfies SCD2Handler's minimal cursor
    needs and keeps a record of every statement executed through it."""

    def __init__(self, db_ops):
        self._db_ops = db_ops
        self.statements: List[Dict] = []

    def cursor(self):
        """Hand out a fresh FakeCursor sharing this connection's statement log."""
        return FakeCursor(self.statements, self._db_ops)
|
||||
|
||||
|
||||
class FakeDBOperations:
    """Intercepts and records batch upsert/transaction calls instead of touching
    a real database; also counts commit/rollback calls for assertions."""

    def __init__(self):
        self.upserts: List[Dict] = []    # one entry per batch upsert call
        self.executes: List[Dict] = []   # execute/batch_execute/query records
        self.commits = 0
        self.rollbacks = 0
        self.conn = FakeConnection(self)
        # Preloaded query results (FIFO), used to control what query() returns.
        self.query_results: List[List[Dict]] = []

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Record the upsert and report every row as inserted (0 updated)."""
        self.upserts.append(
            {
                "sql": sql.strip(),
                "count": len(rows),
                "page_size": page_size,
                "rows": [dict(row) for row in rows],
            }
        )
        return len(rows), 0

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Record a batched statement without executing anything."""
        self.executes.append(
            {
                "sql": sql.strip(),
                "count": len(rows),
                "page_size": page_size,
                "rows": [dict(row) for row in rows],
            }
        )

    def execute(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params})

    def query(self, sql: str, params=None):
        """Record the query and pop the next preloaded result, or return []."""
        self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
        if self.query_results:
            return self.query_results.pop(0)
        return []

    def cursor(self):
        return self.conn.cursor()

    def commit(self):
        self.commits += 1

    def rollback(self):
        self.rollbacks += 1
|
||||
|
||||
|
||||
class FakeAPIClient:
    """Online-mode API stub: serves preloaded in-memory records as a single
    fake page and logs every call so tests can assert task parameters."""

    def __init__(self, data_map: Dict[str, List[Dict]]):
        self.data_map = data_map
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield exactly one page: (page_no, records, request_params, raw_response)."""
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.data_map:
            raise AssertionError(f"Missing fixture for endpoint {endpoint}")

        rows = list(self.data_map[endpoint])
        yield 1, rows, dict(params or {}), {"data": rows}

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        """Drain iter_paginated into (all_records, per-page metadata)."""
        all_records: List[Dict] = []
        page_meta: List[Dict] = []
        for page_no, rows, req, resp in self.iter_paginated(endpoint, params, **kwargs):
            all_records.extend(rows)
            page_meta.append({"page": page_no, "request": req, "response": resp})
        return all_records, page_meta

    def get_source_hint(self, endpoint: str) -> str | None:
        """In-memory data has no backing file, so there is never a source hint."""
        return None
|
||||
|
||||
|
||||
class OfflineAPIClient:
    """Offline-mode API client: replays list data from archived JSON files.

    Each endpoint maps to a JSON file; the payload is unwrapped along
    ``data_path`` (and optionally ``list_key``) and served page by page.
    """

    def __init__(self, file_map: Dict[str, Path]):
        self.calls: List[Dict] = []
        self.file_map = {endpoint: Path(path) for endpoint, path in file_map.items()}

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield (page_no, records, request_params, raw_payload) from the archive."""
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.file_map:
            raise AssertionError(f"Missing archive for endpoint {endpoint}")

        payload = json.loads(self.file_map[endpoint].read_text(encoding="utf-8"))

        # Walk the payload down the configured data path.
        node = payload
        for step in data_path:
            if isinstance(node, dict):
                node = node.get(step, [])
        if list_key and isinstance(node, dict):
            node = node.get(list_key, [])
        data = node if isinstance(node, list) else []

        # Serve page_size-sized slices; an empty archive still yields one empty page.
        offset, page_no = 0, 1
        while True:
            chunk = data[offset : offset + page_size]
            if not chunk and offset > 0:
                break
            yield page_no, list(chunk), dict(params or {}), payload
            if len(chunk) < page_size:
                break
            offset += page_size
            page_no += 1

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        """Drain iter_paginated into (all_records, per-page metadata)."""
        all_records: List[Dict] = []
        page_meta: List[Dict] = []
        for page_no, rows, req, resp in self.iter_paginated(endpoint, params, **kwargs):
            all_records.extend(rows)
            page_meta.append({"page": page_no, "request": req, "response": resp})
        return all_records, page_meta

    def get_source_hint(self, endpoint: str) -> str | None:
        """Return the archive file path for *endpoint*, or None when unknown."""
        path = self.file_map.get(endpoint)
        return None if path is None else str(path)
|
||||
|
||||
|
||||
class RealDBOperationsAdapter:
    """Adapter over a real PostgreSQL connection, giving tasks batch upserts
    plus transaction control."""

    def __init__(self, dsn: str):
        self._conn = DatabaseConnection(dsn)
        self._ops = PgDBOperations(self._conn)
        # SCD2Handler accesses db.conn.cursor(), so expose the underlying connection.
        self.conn = self._conn.conn

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        return self._ops.batch_execute(sql, rows, page_size=page_size)

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()
|
||||
|
||||
|
||||
@contextmanager
def get_db_operations():
    """Test-only DB operations context.

    - With TEST_DB_DSN set: connect to a real PostgreSQL via RealDBOperationsAdapter
      (closed on exit).
    - Otherwise: fall back to the in-memory FakeDBOperations stub.
    """
    dsn = os.environ.get("TEST_DB_DSN")
    if dsn:
        adapter = RealDBOperationsAdapter(dsn)
        try:
            yield adapter
        finally:
            adapter.close()
    else:
        fake = FakeDBOperations()
        yield fake
|
||||
|
||||
|
||||
# The 14 standalone ODS task specs were removed (they wrote to a nonexistent
# billiards.* schema and were superseded by the generic ODS tasks).
TASK_SPECS: List[TaskSpec] = []
|
||||
@@ -0,0 +1,696 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试 — 文档对齐分析器 (doc_alignment_analyzer.py)
|
||||
|
||||
覆盖:
|
||||
- scan_docs 文档来源识别
|
||||
- extract_code_references 代码引用提取
|
||||
- check_reference_validity 引用有效性检查
|
||||
- find_undocumented_modules 缺失文档检测
|
||||
- check_ddl_vs_dictionary DDL 与数据字典比对
|
||||
- check_api_samples_vs_parsers API 样本与解析器比对
|
||||
- render_alignment_report 报告渲染
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.audit import AlignmentIssue, DocMapping
|
||||
from scripts.audit.doc_alignment_analyzer import (
|
||||
_parse_ddl_tables,
|
||||
_parse_dictionary_tables,
|
||||
build_mappings,
|
||||
check_api_samples_vs_parsers,
|
||||
check_ddl_vs_dictionary,
|
||||
check_reference_validity,
|
||||
extract_code_references,
|
||||
find_undocumented_modules,
|
||||
render_alignment_report,
|
||||
scan_docs,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_docs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScanDocs:
    """Document-source discovery tests for scan_docs."""

    def test_finds_docs_dir_md(self, tmp_path: Path) -> None:
        (tmp_path / "docs").mkdir()
        (tmp_path / "docs" / "guide.md").write_text("# Guide", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/guide.md" in result

    def test_finds_root_readme(self, tmp_path: Path) -> None:
        (tmp_path / "README.md").write_text("# Readme", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "README.md" in result

    def test_finds_docs_subdir_requirements(self, tmp_path: Path) -> None:
        """Files under docs/requirements/ should be picked up."""
        req_dir = tmp_path / "docs" / "requirements"
        req_dir.mkdir(parents=True)
        (req_dir / "需求.md").write_text("需求", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/requirements/需求.md" in result

    def test_finds_module_readme(self, tmp_path: Path) -> None:
        (tmp_path / "gui").mkdir()
        (tmp_path / "gui" / "README.md").write_text("# GUI", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "gui/README.md" in result

    def test_finds_steering_files(self, tmp_path: Path) -> None:
        steering = tmp_path / ".kiro" / "steering"
        steering.mkdir(parents=True)
        (steering / "tech.md").write_text("# Tech", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert ".kiro/steering/tech.md" in result

    def test_finds_json_samples(self, tmp_path: Path) -> None:
        sample_dir = tmp_path / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / "member.json").write_text("[]", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/test-json-doc/member.json" in result

    def test_empty_repo_returns_empty(self, tmp_path: Path) -> None:
        result = scan_docs(tmp_path)
        assert result == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        (tmp_path / "docs").mkdir()
        (tmp_path / "docs" / "z.md").write_text("z", encoding="utf-8")
        (tmp_path / "docs" / "a.md").write_text("a", encoding="utf-8")
        (tmp_path / "README.md").write_text("r", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert result == sorted(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_code_references
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractCodeReferences:
    """Code-reference extraction tests for extract_code_references."""

    def test_extracts_backtick_paths(self, tmp_path: Path) -> None:
        doc = tmp_path / "doc.md"
        doc.write_text("使用 `tasks/base_task.py` 作为基类", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "tasks/base_task.py" in refs

    def test_extracts_class_names(self, tmp_path: Path) -> None:
        doc = tmp_path / "doc.md"
        doc.write_text("继承 `BaseTask` 类", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "BaseTask" in refs

    def test_skips_single_char(self, tmp_path: Path) -> None:
        doc = tmp_path / "doc.md"
        doc.write_text("变量 `x` 和 `y`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs == []

    def test_skips_pure_numbers(self, tmp_path: Path) -> None:
        doc = tmp_path / "doc.md"
        doc.write_text("版本 `2.0.0` 和 ID `12345`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs == []

    def test_deduplicates(self, tmp_path: Path) -> None:
        doc = tmp_path / "doc.md"
        doc.write_text("`foo.py` 和 `foo.py` 重复", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs.count("foo.py") == 1

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        refs = extract_code_references(tmp_path / "nonexistent.md")
        assert refs == []

    def test_normalizes_backslash(self, tmp_path: Path) -> None:
        # Windows-style separators should be normalized to forward slashes.
        doc = tmp_path / "doc.md"
        doc.write_text("路径 `tasks\\base_task.py`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "tasks/base_task.py" in refs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_reference_validity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckReferenceValidity:
    """Reference-validity checking tests for check_reference_validity."""

    def test_valid_file_path(self, tmp_path: Path) -> None:
        (tmp_path / "tasks").mkdir()
        (tmp_path / "tasks" / "base.py").write_text("", encoding="utf-8")
        assert check_reference_validity("tasks/base.py", tmp_path) is True

    def test_invalid_file_path(self, tmp_path: Path) -> None:
        assert check_reference_validity("nonexistent/file.py", tmp_path) is False

    def test_strips_legacy_prefix(self, tmp_path: Path) -> None:
        """Tolerates the legacy package prefix (etl_billiards/) and the current
        repository-root prefix (FQ-ETL/)."""
        (tmp_path / "tasks").mkdir()
        (tmp_path / "tasks" / "x.py").write_text("", encoding="utf-8")
        assert check_reference_validity("etl_billiards/tasks/x.py", tmp_path) is True
        assert check_reference_validity("FQ-ETL/tasks/x.py", tmp_path) is True

    def test_directory_path(self, tmp_path: Path) -> None:
        (tmp_path / "loaders").mkdir()
        assert check_reference_validity("loaders", tmp_path) is True

    def test_dotted_module_path(self, tmp_path: Path) -> None:
        # A dotted module path should resolve to the corresponding file.
        (tmp_path / "config").mkdir()
        (tmp_path / "config" / "settings.py").write_text("", encoding="utf-8")
        assert check_reference_validity("config.settings", tmp_path) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_undocumented_modules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindUndocumentedModules:
    """Tests for undocumented-module detection."""

    def test_finds_undocumented(self, tmp_path: Path) -> None:
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        for fname in ("__init__.py", "ods_task.py"):
            (tasks_dir / fname).write_text("", encoding="utf-8")
        missing = find_undocumented_modules(tmp_path, set())
        assert "tasks/ods_task.py" in missing

    def test_excludes_init(self, tmp_path: Path) -> None:
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        (tasks_dir / "__init__.py").write_text("", encoding="utf-8")
        missing = find_undocumented_modules(tmp_path, set())
        assert all("__init__" not in m for m in missing)

    def test_documented_module_excluded(self, tmp_path: Path) -> None:
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        (tasks_dir / "ods_task.py").write_text("", encoding="utf-8")
        missing = find_undocumented_modules(tmp_path, {"tasks/ods_task.py"})
        assert "tasks/ods_task.py" not in missing

    def test_non_core_dirs_ignored(self, tmp_path: Path) -> None:
        """gui/ is not in the core code directory list and must be skipped."""
        gui_dir = tmp_path / "gui"
        gui_dir.mkdir()
        (gui_dir / "main.py").write_text("", encoding="utf-8")
        missing = find_undocumented_modules(tmp_path, set())
        assert all("gui/" not in m for m in missing)

    def test_results_sorted(self, tmp_path: Path) -> None:
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        for fname in ("z_task.py", "a_task.py"):
            (tasks_dir / fname).write_text("", encoding="utf-8")
        missing = find_undocumented_modules(tmp_path, set())
        assert missing == sorted(missing)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_ddl_tables / _parse_dictionary_tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseDdlTables:
    """Tests for DDL parsing."""

    def test_extracts_table_and_columns(self) -> None:
        sql = """
        CREATE TABLE IF NOT EXISTS dim_member (
            member_id BIGINT,
            nickname TEXT,
            mobile TEXT,
            PRIMARY KEY (member_id)
        );
        """
        tables = _parse_ddl_tables(sql)
        assert "dim_member" in tables
        for col in ("member_id", "nickname", "mobile"):
            assert col in tables["dim_member"]

    def test_handles_schema_prefix(self) -> None:
        sql = "CREATE TABLE billiards_dwd.dim_site (\n site_id BIGINT\n);"
        assert "dim_site" in _parse_ddl_tables(sql)

    def test_excludes_sql_keywords(self) -> None:
        sql = """
        CREATE TABLE test_tbl (
            id INTEGER,
            PRIMARY KEY (id)
        );
        """
        tables = _parse_ddl_tables(sql)
        assert "primary" not in tables.get("test_tbl", set())
|
||||
|
||||
|
||||
class TestParseDictionaryTables:
    """Tests for data-dictionary markdown parsing."""

    def test_extracts_table_and_fields(self) -> None:
        md = """## dim_member

| 字段 | 类型 | 说明 |
|------|------|------|
| member_id | BIGINT | 会员ID |
| nickname | TEXT | 昵称 |
"""
        parsed = _parse_dictionary_tables(md)
        assert "dim_member" in parsed
        assert "member_id" in parsed["dim_member"]
        assert "nickname" in parsed["dim_member"]

    def test_skips_header_row(self) -> None:
        md = """## dim_test

| 字段 | 类型 |
|------|------|
| col_a | INT |
"""
        parsed = _parse_dictionary_tables(md)
        assert "字段" not in parsed.get("dim_test", set())

    def test_handles_backtick_table_name(self) -> None:
        md = "## `dim_goods`\n\n| 字段 |\n| goods_id |"
        assert "dim_goods" in _parse_dictionary_tables(md)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_ddl_vs_dictionary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckDdlVsDictionary:
    """Tests comparing the DDL against the data dictionary."""

    @staticmethod
    def _write_fixture(root: Path, sql: str, md: str) -> None:
        """Write one schema SQL file and one dictionary markdown file under *root*."""
        db_dir = root / "database"
        db_dir.mkdir()
        (db_dir / "schema_test.sql").write_text(sql, encoding="utf-8")
        docs_dir = root / "docs"
        docs_dir.mkdir()
        (docs_dir / "dwd_main_tables_dictionary.md").write_text(md, encoding="utf-8")

    def test_detects_missing_table_in_dictionary(self, tmp_path: Path) -> None:
        # Table present in the DDL but absent from the dictionary.
        self._write_fixture(
            tmp_path,
            "CREATE TABLE dim_orphan (\n id BIGINT\n);",
            "## dim_other\n\n| 字段 |\n| id |",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        missing = [i for i in issues if i.issue_type == "missing"]
        assert any("dim_orphan" in i.description for i in missing)

    def test_detects_column_mismatch(self, tmp_path: Path) -> None:
        self._write_fixture(
            tmp_path,
            "CREATE TABLE dim_x (\n id BIGINT,\n extra_col TEXT\n);",
            "## dim_x\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        conflict = [i for i in issues if i.issue_type == "conflict"]
        assert any("extra_col" in i.description for i in conflict)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        self._write_fixture(
            tmp_path,
            "CREATE TABLE dim_ok (\n id BIGINT\n);",
            "## dim_ok\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        assert len(issues) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_api_samples_vs_parsers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckApiSamplesVsParsers:
    """Tests comparing API JSON samples against the ODS schema."""

    @staticmethod
    def _write_sample(root: Path, filename: str, payload: list) -> None:
        """Write a JSON sample file into docs/test-json-doc/."""
        sample_dir = root / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / filename).write_text(json.dumps(payload), encoding="utf-8")

    @staticmethod
    def _write_ods_schema(root: Path, ddl: str) -> None:
        """Write the ODS schema DDL into database/."""
        db_dir = root / "database"
        db_dir.mkdir()
        (db_dir / "schema_ODS_doc.sql").write_text(ddl, encoding="utf-8")

    def test_detects_json_field_not_in_ods(self, tmp_path: Path) -> None:
        # The JSON sample carries extra_field, which the ODS table lacks.
        self._write_sample(
            tmp_path, "test_entity.json", [{"id": 1, "name": "a", "extra_field": "x"}]
        )
        self._write_ods_schema(
            tmp_path,
            "CREATE TABLE billiards_ods.test_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
        )
        issues = check_api_samples_vs_parsers(tmp_path)
        assert any("extra_field" in i.description for i in issues)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        self._write_sample(tmp_path, "aligned_entity.json", [{"id": 1, "name": "a"}])
        self._write_ods_schema(
            tmp_path,
            "CREATE TABLE billiards_ods.aligned_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
        )
        issues = check_api_samples_vs_parsers(tmp_path)
        assert len(issues) == 0

    def test_skips_when_no_ods_table(self, tmp_path: Path) -> None:
        self._write_sample(tmp_path, "unknown.json", [{"a": 1}])
        self._write_ods_schema(tmp_path, "-- empty")
        issues = check_api_samples_vs_parsers(tmp_path)
        assert len(issues) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_alignment_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRenderAlignmentReport:
    """Tests for alignment-report rendering."""

    def test_contains_all_sections(self) -> None:
        report = render_alignment_report([], [], "/repo")
        for heading in ("## 映射关系", "## 过期点", "## 冲突点", "## 缺失点", "## 统计摘要"):
            assert heading in report

    def test_contains_header_metadata(self) -> None:
        report = render_alignment_report([], [], "/repo")
        assert "生成时间" in report
        assert "`/repo`" in report

    def test_contains_iso_timestamp(self) -> None:
        import re

        report = render_alignment_report([], [], "/repo")
        # An ISO-format timestamp contains both 'T' and a trailing 'Z'.
        assert re.search(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", report)

    def test_mapping_table_rendered(self) -> None:
        mapping = DocMapping(
            doc_path="docs/guide.md",
            doc_topic="项目文档",
            related_code=["tasks/base.py"],
            status="aligned",
        )
        report = render_alignment_report([mapping], [], "/repo")
        assert "`docs/guide.md`" in report
        assert "`tasks/base.py`" in report
        assert "aligned" in report

    def test_stale_issues_rendered(self) -> None:
        issue = AlignmentIssue(
            doc_path="docs/old.md",
            issue_type="stale",
            description="引用了已删除的文件",
            related_code="tasks/deleted.py",
        )
        report = render_alignment_report([], [issue], "/repo")
        assert "引用了已删除的文件" in report
        assert "## 过期点" in report

    def test_conflict_issues_rendered(self) -> None:
        issue = AlignmentIssue(
            doc_path="docs/dict.md",
            issue_type="conflict",
            description="字段不一致",
            related_code="database/schema.sql",
        )
        report = render_alignment_report([], [issue], "/repo")
        assert "字段不一致" in report

    def test_missing_issues_rendered(self) -> None:
        issue = AlignmentIssue(
            doc_path="docs/dict.md",
            issue_type="missing",
            description="缺少表定义",
            related_code="database/schema.sql",
        )
        report = render_alignment_report([], [issue], "/repo")
        assert "缺少表定义" in report

    def test_summary_counts(self) -> None:
        issues = [
            AlignmentIssue("a", "stale", "d1", "c1"),
            AlignmentIssue("b", "stale", "d2", "c2"),
            AlignmentIssue("c", "conflict", "d3", "c3"),
            AlignmentIssue("d", "missing", "d4", "c4"),
        ]
        mappings = [DocMapping("x", "t", [], "aligned")]
        report = render_alignment_report(mappings, issues, "/repo")
        assert "过期点数量:2" in report
        assert "冲突点数量:1" in report
        assert "缺失点数量:1" in report
        assert "文档总数:1" in report

    def test_empty_report(self) -> None:
        report = render_alignment_report([], [], "/repo")
        for text in ("未发现过期点", "未发现冲突点", "未发现缺失点", "过期点数量:0"):
            assert text in report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 11 / 12 / 16 (hypothesis)
|
||||
# hypothesis 与 pytest 的 function-scoped fixture (tmp_path) 不兼容,
|
||||
# 因此在测试内部使用 tempfile.mkdtemp 自行管理临时目录。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit.doc_alignment_analyzer import _CORE_CODE_DIRS
|
||||
|
||||
|
||||
class TestPropertyStaleReferenceDetection:
    """Feature: repo-audit, Property 11: stale-reference detection.

    *For any* code reference extracted from a document, if the referenced
    file path does not exist in the repository, check_reference_validity
    must return False.

    Validates: Requirements 3.3
    """

    _safe_name = st.from_regex(r"[a-z][a-z0-9_]{1,12}", fullmatch=True)

    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
        missing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_nonexistent_path_returns_false(
        self,
        existing_names: list[str],
        missing_names: list[str],
    ) -> None:
        """References to nonexistent file paths must return False."""
        # TemporaryDirectory guarantees cleanup even when an assertion fails,
        # replacing the manual mkdtemp + try/finally + rmtree pattern.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp = Path(tmp_dir)
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")

            existing_set = set(existing_names)
            # Only probe names that genuinely do not exist.
            truly_missing = [n for n in missing_names if n not in existing_set]
            for name in truly_missing:
                ref = f"nonexistent_dir/{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is False, (
                    f"引用 '{ref}' 指向不存在的文件,但返回了 True"
                )

    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_existing_path_returns_true(
        self,
        existing_names: list[str],
    ) -> None:
        """References to existing file paths must return True."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp = Path(tmp_dir)
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")

            for name in existing_names:
                ref = f"{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is True, (
                    f"引用 '{ref}' 指向存在的文件,但返回了 False"
                )
|
||||
|
||||
@given(
|
||||
existing_names=st.lists(
|
||||
_safe_name, min_size=1, max_size=5, unique=True,
|
||||
),
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_existing_path_returns_true(
|
||||
self,
|
||||
existing_names: list[str],
|
||||
) -> None:
|
||||
"""存在的文件路径引用应返回 True。"""
|
||||
tmp = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
for name in existing_names:
|
||||
(tmp / f"{name}.py").write_text("# ok", encoding="utf-8")
|
||||
|
||||
for name in existing_names:
|
||||
ref = f"{name}.py"
|
||||
result = check_reference_validity(ref, tmp)
|
||||
assert result is True, (
|
||||
f"引用 '{ref}' 指向存在的文件,但返回了 False"
|
||||
)
|
||||
finally:
|
||||
shutil.rmtree(tmp, ignore_errors=True)
|
||||
|
||||
|
||||
class TestPropertyMissingDocDetection:
    """Feature: repo-audit, Property 12: missing-documentation detection.

    *For any* set of core code modules and any documented-module set,
    find_undocumented_modules must return exactly the sorted difference
    between the core-module set and the documented set.

    Validates: Requirements 3.5
    """

    _core_dir = st.sampled_from(list(_CORE_CODE_DIRS))
    _module_name = st.from_regex(r"[a-z][a-z0-9_]{1,10}", fullmatch=True)

    @given(
        core_dir=_core_dir,
        module_names=st.lists(
            _module_name, min_size=2, max_size=6, unique=True,
        ),
        doc_fraction=st.floats(min_value=0.0, max_value=1.0),
    )
    @settings(max_examples=100)
    def test_undocumented_equals_difference(
        self,
        core_dir: str,
        module_names: list[str],
        doc_fraction: float,
    ) -> None:
        """The returned missing list must equal core modules minus documented ones."""
        # TemporaryDirectory guarantees cleanup even when an assertion fails,
        # replacing the manual mkdtemp + try/finally + rmtree pattern.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp = Path(tmp_dir)
            code_dir = tmp / core_dir
            code_dir.mkdir(parents=True, exist_ok=True)

            all_modules: set[str] = set()
            for name in module_names:
                (code_dir / f"{name}.py").write_text("# module", encoding="utf-8")
                all_modules.add(f"{core_dir}/{name}.py")

            # Mark a doc_fraction-sized prefix of the modules as documented.
            split_idx = int(len(module_names) * doc_fraction)
            documented = {
                f"{core_dir}/{n}.py" for n in module_names[:split_idx]
            }

            result = find_undocumented_modules(tmp, documented)
            expected = sorted(all_modules - documented)

            assert result == expected, (
                f"期望缺失列表 {expected},实际得到 {result}"
            )
|
||||
|
||||
|
||||
class TestPropertyAlignmentReportSections:
    """Feature: repo-audit, Property 16: alignment-report section completeness.

    *For any* output of render_alignment_report, the Markdown text must
    contain the four section headings 映射关系 / 过期点 / 冲突点 / 缺失点.

    Validates: Requirements 3.8
    """

    _issue_type = st.sampled_from(["stale", "conflict", "missing"])
    _text = st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P"),
            blacklist_characters="\x00",
        ),
        min_size=1,
        max_size=30,
    )

    _mapping_st = st.builds(
        DocMapping,
        doc_path=_text,
        doc_topic=_text,
        related_code=st.lists(_text, max_size=3),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )

    _issue_st = st.builds(
        AlignmentIssue,
        doc_path=_text,
        issue_type=_issue_type,
        description=_text,
        related_code=_text,
    )

    @given(
        mappings=st.lists(_mapping_st, max_size=5),
        issues=st.lists(_issue_st, max_size=8),
    )
    @settings(max_examples=100)
    def test_report_contains_four_sections(
        self,
        mappings: list[DocMapping],
        issues: list[AlignmentIssue],
    ) -> None:
        """The rendered report must contain all four section headings."""
        report = render_alignment_report(mappings, issues, "/test/repo")

        for section in ("## 映射关系", "## 过期点", "## 冲突点", "## 缺失点"):
            assert section in report, (
                f"报告中缺少分区标题 '{section}'"
            )
|
||||
667
apps/etl/connectors/feiqiu/tests/unit/test_audit_flow.py
Normal file
667
apps/etl/connectors/feiqiu/tests/unit/test_audit_flow.py
Normal file
@@ -0,0 +1,667 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试 — 流程树分析器 (flow_analyzer.py)
|
||||
|
||||
覆盖:
|
||||
- parse_imports: import 语句解析、标准库/第三方排除、语法错误容错
|
||||
- build_flow_tree: 递归构建、循环导入处理
|
||||
- find_orphan_modules: 孤立模块检测
|
||||
- render_flow_report: Markdown 渲染、Mermaid 图、统计摘要
|
||||
- discover_entry_points: 入口点识别
|
||||
- classify_task_type / classify_loader_type: 类型区分
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.audit import FileEntry, FlowNode
|
||||
from scripts.audit.flow_analyzer import (
|
||||
build_flow_tree,
|
||||
classify_loader_type,
|
||||
classify_task_type,
|
||||
discover_entry_points,
|
||||
find_orphan_modules,
|
||||
parse_imports,
|
||||
render_flow_report,
|
||||
_path_to_module_name,
|
||||
_parse_bat_python_target,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_imports 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseImports:
    """Tests for import-statement parsing."""

    @staticmethod
    def _write(tmp_path: Path, source: str, name: str = "test.py") -> Path:
        """Write *source* into tmp_path/name and return the file path."""
        f = tmp_path / name
        f.write_text(source, encoding="utf-8")
        return f

    def test_absolute_import(self, tmp_path: Path) -> None:
        """Absolute imports of project-internal modules are recognized."""
        f = self._write(tmp_path, "import cli.main\nimport config.settings\n")
        result = parse_imports(f)
        assert "cli.main" in result
        assert "config.settings" in result

    def test_from_import(self, tmp_path: Path) -> None:
        """`from ... import` statements are recognized."""
        f = self._write(tmp_path, "from tasks.base_task import BaseTask\n")
        assert "tasks.base_task" in parse_imports(f)

    def test_stdlib_excluded(self, tmp_path: Path) -> None:
        """Standard-library modules are excluded."""
        f = self._write(
            tmp_path,
            "import os\nimport sys\nimport json\nfrom pathlib import Path\n",
        )
        assert parse_imports(f) == []

    def test_third_party_excluded(self, tmp_path: Path) -> None:
        """Third-party packages are excluded."""
        f = self._write(
            tmp_path,
            "import requests\nfrom psycopg2 import sql\nimport flask\n",
        )
        assert parse_imports(f) == []

    def test_mixed_imports(self, tmp_path: Path) -> None:
        """Mixed imports keep only project-internal modules."""
        f = self._write(
            tmp_path,
            "import os\nimport cli.main\nimport requests\nfrom loaders.base_loader import BaseLoader\n",
        )
        result = parse_imports(f)
        assert "cli.main" in result
        assert "loaders.base_loader" in result
        assert "os" not in result
        assert "requests" not in result

    def test_syntax_error_returns_empty(self, tmp_path: Path) -> None:
        """A file with a syntax error yields an empty list."""
        f = self._write(tmp_path, "def broken(\n", name="bad.py")
        assert parse_imports(f) == []

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        """A missing file yields an empty list."""
        assert parse_imports(tmp_path / "nonexistent.py") == []

    def test_deduplication(self, tmp_path: Path) -> None:
        """Repeated imports of the same module are deduplicated."""
        f = self._write(
            tmp_path,
            "import cli.main\nimport cli.main\nfrom cli.main import main\n",
        )
        assert parse_imports(f).count("cli.main") == 1

    def test_empty_file(self, tmp_path: Path) -> None:
        """An empty file yields an empty list."""
        f = self._write(tmp_path, "", name="empty.py")
        assert parse_imports(f) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_flow_tree 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildFlowTree:
    """Tests for flow-tree construction."""

    @staticmethod
    def _make_pkg(root: Path, pkg: str, **modules: str) -> None:
        """Create package *pkg* with an empty __init__.py and the given modules."""
        pkg_dir = root / pkg
        pkg_dir.mkdir()
        (pkg_dir / "__init__.py").write_text("", encoding="utf-8")
        for mod_name, source in modules.items():
            (pkg_dir / f"{mod_name}.py").write_text(source, encoding="utf-8")

    def test_single_file_no_imports(self, tmp_path: Path) -> None:
        """A single file with no imports yields a leaf node."""
        self._make_pkg(tmp_path, "cli", main="def main(): pass\n")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert tree.source_file == "cli/main.py"
        assert tree.children == []

    def test_simple_import_chain(self, tmp_path: Path) -> None:
        """A simple import chain becomes a child node."""
        # cli/main.py → config/settings.py
        self._make_pkg(tmp_path, "cli", main="from config.settings import AppConfig\n")
        self._make_pkg(tmp_path, "config", settings="class AppConfig: pass\n")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert len(tree.children) == 1
        assert tree.children[0].name == "config.settings"

    def test_circular_import_no_infinite_loop(self, tmp_path: Path) -> None:
        """A circular import must not cause unbounded recursion."""
        # a → b → a (cycle)
        self._make_pkg(
            tmp_path,
            "utils",
            a="from utils.b import func_b\n",
            b="from utils.a import func_a\n",
        )
        # Must not raise RecursionError.
        tree = build_flow_tree(tmp_path, "utils/a.py")
        assert tree.name == "utils.a"

    def test_entry_node_type(self, tmp_path: Path) -> None:
        """CLI entry files are tagged with node_type 'entry'."""
        self._make_pkg(tmp_path, "cli", main="def main(): pass\n")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.node_type == "entry"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_orphan_modules 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindOrphanModules:
    """Tests for orphan-module detection."""

    def test_all_reachable(self, tmp_path: Path) -> None:
        """When every module is reachable, the orphan list is empty."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("config/settings.py", False, 200, ".py", False),
        ]
        reachable = {"cli/main.py", "config/settings.py"}
        assert find_orphan_modules(tmp_path, entries, reachable) == []

    def test_orphan_detected(self, tmp_path: Path) -> None:
        """Unreachable modules are reported as orphans."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("utils/orphan.py", False, 50, ".py", False),
        ]
        orphans = find_orphan_modules(tmp_path, entries, {"cli/main.py"})
        assert "utils/orphan.py" in orphans

    def test_init_files_excluded(self, tmp_path: Path) -> None:
        """__init__.py files are never treated as orphans."""
        entries = [FileEntry("cli/__init__.py", False, 0, ".py", False)]
        orphans = find_orphan_modules(tmp_path, entries, set())
        assert "cli/__init__.py" not in orphans

    def test_test_files_excluded(self, tmp_path: Path) -> None:
        """Test files are never treated as orphans."""
        entries = [FileEntry("tests/unit/test_something.py", False, 100, ".py", False)]
        assert find_orphan_modules(tmp_path, entries, set()) == []

    def test_audit_scripts_excluded(self, tmp_path: Path) -> None:
        """The audit scripts themselves are never treated as orphans."""
        entries = [FileEntry("scripts/audit/scanner.py", False, 100, ".py", False)]
        assert find_orphan_modules(tmp_path, entries, set()) == []

    def test_directories_excluded(self, tmp_path: Path) -> None:
        """Directory entries never appear in the orphan list."""
        entries = [FileEntry("cli", True, 0, "", False)]
        assert find_orphan_modules(tmp_path, entries, set()) == []

    def test_sorted_output(self, tmp_path: Path) -> None:
        """The orphan list is sorted by path."""
        entries = [
            FileEntry("utils/z.py", False, 50, ".py", False),
            FileEntry("utils/a.py", False, 50, ".py", False),
            FileEntry("cli/orphan.py", False, 50, ".py", False),
        ]
        orphans = find_orphan_modules(tmp_path, entries, set())
        assert orphans == sorted(orphans)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_flow_report 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRenderFlowReport:
    """Tests for flow-tree report rendering."""

    @staticmethod
    def _entry_tree() -> list:
        """A minimal tree with a single CLI entry node."""
        return [FlowNode("cli.main", "cli/main.py", "entry", [])]

    def test_header_contains_timestamp_and_path(self) -> None:
        report = render_flow_report(self._entry_tree(), [], "/repo")
        assert "生成时间:" in report
        assert "`/repo`" in report

    def test_contains_mermaid_block(self) -> None:
        report = render_flow_report(self._entry_tree(), [], "/repo")
        assert "```mermaid" in report
        assert "graph TD" in report

    def test_contains_indented_text(self) -> None:
        """The report contains an indented-text rendering of the tree."""
        child = FlowNode("config.settings", "config/settings.py", "module", [])
        root = FlowNode("cli.main", "cli/main.py", "entry", [child])
        report = render_flow_report([root], [], "/repo")
        assert "`cli.main`" in report
        assert "`config.settings`" in report

    def test_orphan_section(self) -> None:
        orphans = ["utils/orphan.py", "models/unused.py"]
        report = render_flow_report(self._entry_tree(), orphans, "/repo")
        assert "孤立模块" in report
        assert "`utils/orphan.py`" in report
        assert "`models/unused.py`" in report

    def test_no_orphans_message(self) -> None:
        report = render_flow_report(self._entry_tree(), [], "/repo")
        assert "未发现孤立模块" in report

    def test_statistics_summary(self) -> None:
        report = render_flow_report(self._entry_tree(), ["a.py"], "/repo")
        for text in ("统计摘要", "入口点", "任务", "加载器", "孤立模块"):
            assert text in report

    def test_task_type_annotation(self) -> None:
        """Task modules carry a type annotation in the report."""
        task_node = FlowNode("tasks.ods_member", "tasks/ods_member.py", "module", [])
        root = FlowNode("cli.main", "cli/main.py", "entry", [task_node])
        report = render_flow_report([root], [], "/repo")
        assert "ODS" in report

    def test_loader_type_annotation(self) -> None:
        """Loader modules carry a type annotation in the report."""
        loader_node = FlowNode(
            "loaders.dimensions.member", "loaders/dimensions/member.py", "module", []
        )
        root = FlowNode("cli.main", "cli/main.py", "entry", [loader_node])
        report = render_flow_report([root], [], "/repo")
        assert "维度" in report or "SCD2" in report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# discover_entry_points 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDiscoverEntryPoints:
    """Tests for entry-point discovery."""

    def test_cli_entry(self, tmp_path: Path) -> None:
        """A CLI entry point is recognized."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")

        cli_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "CLI"
        ]
        assert len(cli_entries) == 1
        assert cli_entries[0]["file"] == "cli/main.py"

    def test_gui_entry(self, tmp_path: Path) -> None:
        """A GUI entry point is recognized."""
        gui_dir = tmp_path / "gui"
        gui_dir.mkdir()
        (gui_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")

        gui_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "GUI"
        ]
        assert len(gui_entries) == 1

    def test_bat_entry(self, tmp_path: Path) -> None:
        """A batch-file entry point is recognized."""
        (tmp_path / "run_etl.bat").write_text(
            "@echo off\npython -m cli.main %*\n", encoding="utf-8"
        )

        bat_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "批处理"
        ]
        assert len(bat_entries) == 1
        assert "cli.main" in bat_entries[0]["description"]

    def test_script_entry(self, tmp_path: Path) -> None:
        """An operations-script entry point is recognized."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        (scripts_dir / "rebuild_db.py").write_text(
            'if __name__ == "__main__": pass\n', encoding="utf-8"
        )

        script_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "运维脚本"
        ]
        assert len(script_entries) == 1
        assert script_entries[0]["file"] == "scripts/rebuild_db.py"

    def test_init_py_excluded_from_scripts(self, tmp_path: Path) -> None:
        """scripts/__init__.py is never treated as an entry point."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")

        script_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "运维脚本"
        ]
        assert all(e["file"] != "scripts/__init__.py" for e in script_entries)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify_task_type / classify_loader_type 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClassifyTypes:
    """Tests distinguishing task types and loader types from file paths."""

    def test_ods_task(self) -> None:
        label = classify_task_type("tasks/ods_member.py")
        assert "ODS" in label

    def test_dwd_task(self) -> None:
        label = classify_task_type("tasks/dwd_load.py")
        assert "DWD" in label

    def test_dws_task(self) -> None:
        label = classify_task_type("tasks/dws/assistant_daily.py")
        assert "DWS" in label

    def test_verification_task(self) -> None:
        label = classify_task_type("tasks/verification/balance_check.py")
        assert "校验" in label

    def test_schema_init_task(self) -> None:
        label = classify_task_type("tasks/init_ods_schema.py")
        assert "Schema" in label

    def test_dimension_loader(self) -> None:
        label = classify_loader_type("loaders/dimensions/member.py")
        # Either wording is acceptable for dimension loaders.
        assert "维度" in label or "SCD2" in label

    def test_fact_loader(self) -> None:
        label = classify_loader_type("loaders/facts/order.py")
        assert "事实" in label

    def test_ods_loader(self) -> None:
        label = classify_loader_type("loaders/ods/generic.py")
        assert "ODS" in label
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _path_to_module_name 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPathToModuleName:
    """Tests converting a relative file path into a dotted module name."""

    def test_simple_file(self) -> None:
        result = _path_to_module_name("cli/main.py")
        assert result == "cli.main"

    def test_init_file(self) -> None:
        # A package's __init__.py maps to the package name itself.
        result = _path_to_module_name("cli/__init__.py")
        assert result == "cli"

    def test_nested_path(self) -> None:
        result = _path_to_module_name("tasks/dws/assistant.py")
        assert result == "tasks.dws.assistant"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_bat_python_target 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseBatPythonTarget:
    """Tests for extracting the `python -m` target from a batch file."""

    def test_module_invocation(self, tmp_path: Path) -> None:
        bat_file = tmp_path / "run.bat"
        bat_file.write_text("@echo off\npython -m cli.main %*\n", encoding="utf-8")
        target = _parse_bat_python_target(bat_file)
        assert target == "cli.main"

    def test_no_python_command(self, tmp_path: Path) -> None:
        bat_file = tmp_path / "run.bat"
        bat_file.write_text("@echo off\necho hello\n", encoding="utf-8")
        assert _parse_bat_python_target(bat_file) is None

    def test_nonexistent_file(self, tmp_path: Path) -> None:
        # A missing file yields None rather than raising.
        assert _parse_bat_python_target(tmp_path / "missing.bat") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 9 & 10(hypothesis)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import string
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:项目包名列表(与 flow_analyzer 中 _PROJECT_PACKAGES 一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# First-party package names; mirrors _PROJECT_PACKAGES in flow_analyzer and
# must be kept in sync with it manually.
_PROJECT_PACKAGES_LIST = [
    "cli", "config", "api", "database", "tasks", "loaders",
    "scd", "orchestration", "quality", "models", "utils",
    "gui", "scripts",
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 9: 流程树节点 source_file 有效性
|
||||
# Feature: repo-audit, Property 9: 流程树节点 source_file 有效性
|
||||
# Validates: Requirements 2.7
|
||||
#
|
||||
# 策略:在临时目录中随机生成 1~5 个项目内部模块文件,
|
||||
# 其中一个作为入口,其他文件通过 import 语句相互引用。
|
||||
# 构建流程树后,遍历所有节点验证 source_file 非空且文件存在。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_all_nodes(node: FlowNode) -> list[FlowNode]:
|
||||
"""递归收集流程树中所有节点。"""
|
||||
result = [node]
|
||||
for child in node.children:
|
||||
result.extend(_collect_all_nodes(child))
|
||||
return result
|
||||
|
||||
|
||||
# Strategy producing legal lowercase Python identifiers (1-9 chars) used as
# module file names. The regex can never produce "__init__" or "", but both
# are excluded explicitly for clarity.
_module_name_st = st.from_regex(r"[a-z][a-z0-9_]{0,8}", fullmatch=True).filter(
    lambda s: s not in {"__init__", ""}
)
|
||||
|
||||
|
||||
@st.composite
def project_layout(draw):
    """Generate a random project layout: package, module files and imports.

    Returns:
        (package, module_names, imports_map)
        - package: a first-party package name (e.g. "cli")
        - module_names: module file names (without the .py suffix); the
          first one is used as the entry point
        - imports_map: dict[str, list[str]] mapping each module to the
          other modules it imports
    """
    package = draw(st.sampled_from(_PROJECT_PACKAGES_LIST))
    n_modules = draw(st.integers(min_value=1, max_value=5))
    module_names = draw(
        st.lists(
            _module_name_st,
            min_size=n_modules,
            max_size=n_modules,
            unique=True,
        )
    )
    # Guard against degenerate draws (cannot actually happen: min_size >= 1).
    assume(len(module_names) >= 1)

    # For each module, draw a random subset of the *other* modules to import.
    # (Fixed: the original used enumerate() but never used the index.)
    imports_map: dict[str, list[str]] = {}
    for mod in module_names:
        others = [m for m in module_names if m != mod]
        if others:
            imported = draw(
                st.lists(st.sampled_from(others), max_size=len(others), unique=True)
            )
        else:
            imported = []
        imports_map[mod] = imported

    return package, module_names, imports_map
|
||||
|
||||
|
||||
@given(layout=project_layout())
@settings(max_examples=100)
def test_property9_flow_tree_source_file_validity(layout, tmp_path_factory):
    """Property 9: every flow-tree node's source_file is a non-empty string
    that names a file which actually exists in the repository.

    **Feature: repo-audit, Property 9: 流程树节点 source_file 有效性**
    **Validates: Requirements 2.7**
    """
    package, module_names, imports_map = layout
    tmp_path = tmp_path_factory.mktemp("prop9")

    # Create the package directory and its __init__.py.
    pkg_dir = tmp_path / package
    pkg_dir.mkdir(parents=True, exist_ok=True)
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")

    # Write each module file with the import statements drawn for it.
    for mod in module_names:
        lines = []
        for imp in imports_map[mod]:
            lines.append(f"from {package}.{imp} import *")
        lines.append("")  # guarantee the file is non-empty
        (pkg_dir / f"{mod}.py").write_text("\n".join(lines), encoding="utf-8")

    # Build the flow tree using the first module as the entry point.
    entry_rel = f"{package}/{module_names[0]}.py"
    tree = build_flow_tree(tmp_path, entry_rel)

    # Walk every node and validate its source_file.
    all_nodes = _collect_all_nodes(tree)
    for node in all_nodes:
        # source_file must be a non-empty string ...
        assert isinstance(node.source_file, str), (
            f"source_file 应为字符串,实际为 {type(node.source_file)}"
        )
        assert node.source_file != "", "source_file 不应为空字符串"

        # ... and must point at a file that actually exists in the repo.
        full_path = tmp_path / node.source_file
        assert full_path.exists(), (
            f"source_file '{node.source_file}' 对应的文件不存在: {full_path}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 10: 孤立模块检测正确性
|
||||
# Feature: repo-audit, Property 10: 孤立模块检测正确性
|
||||
# Validates: Requirements 2.8
|
||||
#
|
||||
# 策略:生成随机的 FileEntry 列表(模拟项目中的 .py 文件),
|
||||
# 生成随机的 reachable 集合(是 FileEntry 路径的子集),
|
||||
# 调用 find_orphan_modules 验证:
|
||||
# 1. 返回的每个孤立模块都不在 reachable 集合中
|
||||
# 2. reachable 集合中的每个模块都不在孤立列表中
|
||||
#
|
||||
# 注意:find_orphan_modules 会排除 __init__.py、tests/、scripts/audit/ 下的文件,
|
||||
# 以及不属于 _PROJECT_PACKAGES 的子目录文件。生成器需要考虑这些排除规则。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Packages eligible for orphan-test path generation (paths ignored by
# find_orphan_modules are excluded).
_eligible_packages = [
    p for p in _PROJECT_PACKAGES_LIST
    if p not in ("scripts",)  # only scripts/audit/ would be excluded, but drop the whole package for simplicity
]
|
||||
|
||||
|
||||
@st.composite
def orphan_test_data(draw):
    """Generate (file_entries, reachable_set) inputs for find_orphan_modules.

    Only "eligible" file entries are produced (inside a project package, not
    __init__.py, not under tests/ or scripts/audit/), so the mutual exclusion
    between the reachable set and the orphan list can be verified exactly.
    """
    # Draw 1-10 eligible .py file paths.
    n_files = draw(st.integers(min_value=1, max_value=10))
    paths: list[str] = []
    for _ in range(n_files):
        pkg = draw(st.sampled_from(_eligible_packages))
        fname = draw(_module_name_st)
        path = f"{pkg}/{fname}.py"
        paths.append(path)

    # Deduplicate while preserving draw order.
    paths = list(dict.fromkeys(paths))
    assume(len(paths) >= 1)

    # Build the FileEntry list.
    entries = [
        FileEntry(rel_path=p, is_dir=False, size_bytes=100, extension=".py", is_empty_dir=False)
        for p in paths
    ]

    # Draw a random subset of the paths as the reachable set.
    reachable = set(draw(
        st.lists(st.sampled_from(paths), max_size=len(paths), unique=True)
    ))

    return entries, reachable
|
||||
|
||||
|
||||
@given(data=orphan_test_data())
@settings(max_examples=100)
def test_property10_orphan_module_detection(data, tmp_path_factory):
    """Property 10: orphan and reachable modules are mutually exclusive —
    no orphan is reachable, and no reachable module is reported as orphan.

    **Feature: repo-audit, Property 10: 孤立模块检测正确性**
    **Validates: Requirements 2.8**
    """
    entries, reachable = data
    repo_root = tmp_path_factory.mktemp("prop10")

    orphans = find_orphan_modules(repo_root, entries, reachable)
    orphan_set = set(orphans)

    # 1. No orphan may also be in the reachable set.
    overlap = orphan_set & reachable
    assert overlap == set(), (
        f"孤立模块与可达集合存在交集: {overlap}"
    )

    # 2. Conversely, no reachable module may be reported as an orphan.
    for r in reachable:
        assert r not in orphan_set, (
            f"可达模块 '{r}' 不应出现在孤立列表中"
        )

    # 3. The orphan list must come back sorted by path.
    assert orphans == sorted(orphans), "孤立模块列表应按路径排序"
|
||||
309
apps/etl/connectors/feiqiu/tests/unit/test_audit_inventory.py
Normal file
309
apps/etl/connectors/feiqiu/tests/unit/test_audit_inventory.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
属性测试 — classify 完整性
|
||||
|
||||
Feature: repo-audit, Property 1: classify 完整性
|
||||
Validates: Requirements 1.2, 1.3
|
||||
|
||||
对于任意 FileEntry,classify 函数返回的 InventoryItem 的 category 字段
|
||||
应属于 Category 枚举,disposition 字段应属于 Disposition 枚举,
|
||||
且 description 字段为非空字符串。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
|
||||
from scripts.audit.inventory_analyzer import classify
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生成器策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Common file extensions (the empty string models a file with no extension).
_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])

# Path-fragment alphabet: alphanumerics plus common filename punctuation.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."

# Single path segment, 1-20 characters.
_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
)

# Relative paths 1-4 directory levels deep.
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=4,
).map(lambda parts: "/".join(parts))
|
||||
|
||||
|
||||
def _file_entry_strategy() -> st.SearchStrategy[FileEntry]:
    """Hypothesis strategy producing random FileEntry instances.

    Covers every combination of extension, directory depth, size and the
    boolean flags that classify() inspects.
    """
    fields = {
        "rel_path": _rel_path,
        "is_dir": st.booleans(),
        "size_bytes": st.integers(min_value=0, max_value=10_000_000),
        "extension": _EXTENSIONS,
        "is_empty_dir": st.booleans(),
    }
    return st.builds(FileEntry, **fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 1: classify 完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(entry=_file_entry_strategy())
@settings(max_examples=100)
def test_classify_completeness(entry: FileEntry) -> None:
    """Property 1: classify totality.

    Feature: repo-audit, Property 1: classify 完整性
    Validates: Requirements 1.2, 1.3

    For any FileEntry, the InventoryItem returned by classify() must have a
    category from the Category enum, a disposition from the Disposition enum,
    and a non-empty description string.
    """
    item = classify(entry)

    # Correct return type.
    assert isinstance(item, InventoryItem), (
        f"classify 应返回 InventoryItem,实际返回 {type(item)}"
    )

    # category must be a Category enum member.
    assert isinstance(item.category, Category), (
        f"category 应为 Category 枚举成员,实际为 {item.category!r}"
    )

    # disposition must be a Disposition enum member.
    assert isinstance(item.disposition, Disposition), (
        f"disposition 应为 Disposition 枚举成员,实际为 {item.disposition!r}"
    )

    # description must be a non-empty string.
    assert isinstance(item.description, str) and len(item.description) > 0, (
        f"description 应为非空字符串,实际为 {item.description!r}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:高优先级目录前缀(用于在低优先级属性测试中排除)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Prefixes handled by higher-priority classification rules (excluded from the
# low-priority property tests below).
# NOTE(review): this constant is not referenced anywhere in this module's
# visible code — confirm it is still needed before removing.
_HIGH_PRIORITY_PREFIXES = ("tmp/", "logs/", "export/")

# Safe top-level directory names (never trigger a high-priority rule).
_SAFE_TOP_DIRS = st.sampled_from([
    "src", "lib", "data", "misc", "vendor", "tools", "archive",
    "assets", "resources", "contrib", "extras",
])

# Extensions excluding .lnk/.rar (those have their own deletion rule).
_SAFE_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])
|
||||
|
||||
|
||||
def _safe_rel_path() -> st.SearchStrategy[str]:
    """Relative paths whose top-level directory never triggers a
    high-priority classification rule."""
    def _join(top: str, rest: str) -> str:
        # A single segment when no tail was generated.
        return f"{top}/{rest}" if rest else top

    tail = st.lists(_path_segment, min_size=0, max_size=3).map("/".join)
    return st.builds(_join, top=_SAFE_TOP_DIRS, rest=tail)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 3: 空目录标记为候选删除
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_empty_dir_candidate_delete(data: st.DataObject) -> None:
    """Property 3: empty directories are marked as deletion candidates.

    Feature: repo-audit, Property 3: 空目录标记为候选删除
    Validates: Requirements 1.5

    For any FileEntry with is_empty_dir=True (excluding paths under tmp/,
    logs/, reports/, export/ and the .lnk/.rar extensions), classify() must
    return Disposition.CANDIDATE_DELETE.
    """
    empty_dir = FileEntry(
        rel_path=data.draw(_safe_rel_path()),
        is_dir=True,
        size_bytes=0,
        extension=data.draw(_SAFE_EXTENSIONS),
        is_empty_dir=True,
    )

    item = classify(empty_dir)

    assert item.disposition == Disposition.CANDIDATE_DELETE, (
        f"空目录 '{empty_dir.rel_path}' 应标记为候选删除,"
        f"实际为 {item.disposition.value}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4: .lnk/.rar 文件标记为候选删除
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_lnk_rar_candidate_delete(data: st.DataObject) -> None:
    """Property 4: .lnk/.rar files are marked as deletion candidates.

    Feature: repo-audit, Property 4: .lnk/.rar 文件标记为候选删除
    Validates: Requirements 1.6

    For any FileEntry with a .lnk or .rar extension (excluding paths under
    tmp/, logs/, reports/, export/, and with is_empty_dir=False), classify()
    must return Disposition.CANDIDATE_DELETE.
    """
    path = data.draw(_safe_rel_path())
    ext = data.draw(st.sampled_from([".lnk", ".rar"]))
    entry = FileEntry(
        rel_path=path,
        is_dir=False,
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=False,
    )

    verdict = classify(entry).disposition

    assert verdict == Disposition.CANDIDATE_DELETE, (
        f"文件 '{entry.rel_path}' (ext={ext}) 应标记为候选删除,"
        f"实际为 {verdict.value}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 5: tmp/ 下文件处置范围
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Extensions seen under tmp/ — includes .lnk/.rar/.tmp/.bak on top of the
# common set, since tmp/ may contain anything.
_TMP_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js", ".tmp", ".bak",
])
|
||||
|
||||
|
||||
def _tmp_rel_path() -> st.SearchStrategy[str]:
    """Relative paths rooted under tmp/ with 1-3 extra segments."""
    tail = st.lists(_path_segment, min_size=1, max_size=3).map("/".join)
    return tail.map(lambda rest: f"tmp/{rest}")
|
||||
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_tmp_disposition_range(data: st.DataObject) -> None:
    """Property 5: files under tmp/ are delete or archive candidates.

    Feature: repo-audit, Property 5: tmp/ 下文件处置范围
    Validates: Requirements 1.7

    For any FileEntry whose rel_path starts with tmp/, classify() must
    return either CANDIDATE_DELETE or CANDIDATE_ARCHIVE.
    """
    # Draws are kept in the original order for reproducible generation.
    tmp_file = data.draw(_tmp_rel_path())
    tmp_ext = data.draw(_TMP_EXTENSIONS)
    entry = FileEntry(
        rel_path=tmp_file,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=tmp_ext,
        is_empty_dir=data.draw(st.booleans()),
    )

    verdict = classify(entry).disposition

    allowed = {Disposition.CANDIDATE_DELETE, Disposition.CANDIDATE_ARCHIVE}
    assert verdict in allowed, (
        f"tmp/ 下文件 '{entry.rel_path}' 的处置应为候选删除或候选归档,"
        f"实际为 {verdict.value}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 6: 运行时产出目录标记为候选归档
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Runtime-output top-level directories covered by requirement 1.8.
_RUNTIME_DIRS = st.sampled_from(["logs", "export"])

# File basenames excluding the literal "__init__.py".
_NON_INIT_BASENAME = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
).filter(lambda s: s != "__init__.py")
|
||||
|
||||
|
||||
def _runtime_output_rel_path() -> st.SearchStrategy[str]:
    """Relative paths under logs/ or export/ whose basename is not __init__.py."""
    def _assemble(top: str, mid: list[str], name: str) -> str:
        # Joining [top, *mid, name] collapses the mid-empty case naturally.
        return "/".join([top, *mid, name])

    return st.builds(
        _assemble,
        top=_RUNTIME_DIRS,
        mid=st.lists(_path_segment, min_size=0, max_size=2),
        name=_NON_INIT_BASENAME,
    )
|
||||
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_runtime_output_candidate_archive(data: st.DataObject) -> None:
    """Property 6: runtime-output paths are marked as archive candidates.

    Feature: repo-audit, Property 6: 运行时产出目录标记为候选归档
    Validates: Requirements 1.8

    Any FileEntry under logs/ or export/ whose basename is not __init__.py
    must be classified as CANDIDATE_ARCHIVE. Requirement 1.8 covers only
    logs/ and export/ (reports/ is out of scope).
    """
    output_path = data.draw(_runtime_output_rel_path())
    output_ext = data.draw(_EXTENSIONS)
    entry = FileEntry(
        rel_path=output_path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=output_ext,
        is_empty_dir=data.draw(st.booleans()),
    )

    verdict = classify(entry).disposition

    assert verdict == Disposition.CANDIDATE_ARCHIVE, (
        f"运行时产出 '{entry.rel_path}' 应标记为候选归档,"
        f"实际为 {verdict.value}"
    )
|
||||
@@ -0,0 +1,165 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
属性测试 — 清单渲染完整性与分类分组
|
||||
|
||||
Feature: repo-audit
|
||||
- Property 2: 清单渲染完整性
|
||||
- Property 8: 清单按分类分组
|
||||
|
||||
Validates: Requirements 1.4, 1.10
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit import Category, Disposition, InventoryItem
|
||||
from scripts.audit.inventory_analyzer import render_inventory_report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生成器策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Path-fragment alphabet: alphanumerics plus common filename punctuation.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."

# Single path segment, 1-15 characters.
_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=15,
)

# Random relative paths (1-3 levels deep).
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=3,
).map(lambda parts: "/".join(parts))

# Random non-empty descriptions (no pipe or newline characters, so they
# cannot break Markdown table parsing).
_description = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=40,
)
|
||||
|
||||
|
||||
def _inventory_item_strategy() -> st.SearchStrategy[InventoryItem]:
    """Hypothesis strategy producing random InventoryItem instances."""
    fields = {
        "rel_path": _rel_path,
        "category": st.sampled_from(list(Category)),
        "disposition": st.sampled_from(list(Disposition)),
        "description": _description,
    }
    return st.builds(InventoryItem, **fields)
|
||||
|
||||
|
||||
# Lists of 0-20 random InventoryItem records.
_inventory_list = st.lists(
    _inventory_item_strategy(),
    min_size=0,
    max_size=20,
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 2: 清单渲染完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_completeness(items: list[InventoryItem]) -> None:
    """Property 2: inventory rendering completeness.

    Feature: repo-audit, Property 2: 清单渲染完整性
    Validates: Requirements 1.4

    For any list of InventoryItem, the Markdown produced by
    render_inventory_report must contain every entry's rel_path,
    category.value, disposition.value and description.
    """
    report = render_inventory_report(items, "/tmp/test-repo")

    for item in items:
        # All four fields share the same "<label> '<value>' 未出现在报告中"
        # failure-message shape, so check them uniformly.
        checks = [
            ("rel_path", item.rel_path),
            ("category", item.category.value),
            ("disposition", item.disposition.value),
            ("description", item.description),
        ]
        for label, value in checks:
            assert value in report, f"{label} '{value}' 未出现在报告中"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 8: 清单按分类分组
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_grouped_by_category(items: list[InventoryItem]) -> None:
    """Property 8: inventory entries are grouped by category.

    Feature: repo-audit, Property 8: 清单按分类分组
    Validates: Requirements 1.10

    In the Markdown produced by render_inventory_report, entries of the same
    Category must appear contiguously, never interleaved with entries of a
    different Category.
    """
    report = render_inventory_report(items, "/tmp/test-repo")

    if not items:
        return  # nothing to verify for an empty inventory

    # Reconstruct, line by line, which category each rendered entry belongs to.
    # Table row format:     | `{rel_path}` | {disposition} | {description} |
    # Group heading format: ## {category.value}
    lines = report.split("\n")

    # Categories of rendered entries, in order of appearance.
    categories_in_order: list[Category] = []
    current_category: Category | None = None

    # Map category.value -> Category for heading detection.
    value_to_cat = {c.value: c for c in Category}

    for line in lines:
        stripped = line.strip()
        # Group heading: "## {category.value}"
        if stripped.startswith("## ") and stripped[3:] in value_to_cat:
            current_category = value_to_cat[stripped[3:]]
            continue
        # Table data row (skip the header row and the separator row).
        if (
            current_category is not None
            and stripped.startswith("| `")
            and not stripped.startswith("| 相对路径")
            and not stripped.startswith("|---")
        ):
            categories_in_order.append(current_category)

    # Entries of the same Category must be contiguous: once a category's run
    # has ended, it must not reappear later.
    seen: set[Category] = set()
    prev: Category | None = None
    for cat in categories_in_order:
        if cat != prev:
            assert cat not in seen, (
                f"Category '{cat.value}' 的条目不连续——"
                f"在其他分类条目之后再次出现"
            )
            seen.add(cat)
            prev = cat
|
||||
@@ -0,0 +1,485 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
属性测试 — 报告输出属性
|
||||
|
||||
Feature: repo-audit
|
||||
- Property 13: 统计摘要一致性
|
||||
- Property 14: 报告头部元信息
|
||||
- Property 15: 写操作仅限 docs/audit/
|
||||
|
||||
Validates: Requirements 4.2, 4.5, 4.6, 4.7, 5.2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from pathlib import Path
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit import (
|
||||
AlignmentIssue,
|
||||
Category,
|
||||
Disposition,
|
||||
DocMapping,
|
||||
FlowNode,
|
||||
InventoryItem,
|
||||
)
|
||||
from scripts.audit.inventory_analyzer import render_inventory_report
|
||||
from scripts.audit.flow_analyzer import render_flow_report
|
||||
from scripts.audit.doc_alignment_analyzer import render_alignment_report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 共享生成器策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Path-fragment alphabet: alphanumerics plus common filename punctuation.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."

# Single path segment, 1-12 characters.
_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=12,
)

# Relative paths 1-3 segments deep.
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=3,
).map(lambda parts: "/".join(parts))

# Free text safe to embed in a Markdown table cell (no '|' or newlines).
_safe_text = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=30,
)

# Absolute-looking repository-root strings such as "/some/repo".
_repo_root_str = st.text(
    alphabet=string.ascii_letters + string.digits + "/_-.",
    min_size=3,
    max_size=40,
).map(lambda s: "/" + s.lstrip("/"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InventoryItem 生成器
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _inventory_item_st() -> st.SearchStrategy[InventoryItem]:
    """Random InventoryItem instances with Markdown-table-safe descriptions."""
    fields = {
        "rel_path": _rel_path,
        "category": st.sampled_from(list(Category)),
        "disposition": st.sampled_from(list(Disposition)),
        "description": _safe_text,
    }
    return st.builds(InventoryItem, **fields)


# Lists of 0-20 random inventory items.
_inventory_list = st.lists(_inventory_item_st(), min_size=0, max_size=20)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FlowNode 生成器(限制深度和宽度)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _flow_node_st(max_depth: int = 2) -> st.SearchStrategy[FlowNode]:
    """Random FlowNode trees of bounded depth (avoids combinatorial blow-up)."""
    node_type = st.sampled_from(["entry", "module", "class", "function"])
    if max_depth <= 0:
        # Leaf level: no children at all.
        children = st.just([])
    else:
        children = st.lists(
            _flow_node_st(max_depth - 1),
            min_size=0,
            max_size=3,
        )
    return st.builds(
        FlowNode,
        name=_path_segment,
        source_file=_rel_path,
        node_type=node_type,
        children=children,
    )
|
||||
|
||||
|
||||
# Lists of random flow trees / orphan module paths for report rendering tests.
_flow_tree_list = st.lists(_flow_node_st(), min_size=0, max_size=5)
_orphan_list = st.lists(_rel_path, min_size=0, max_size=10)


# ---------------------------------------------------------------------------
# DocMapping / AlignmentIssue generators
# ---------------------------------------------------------------------------

# The alignment issue kinds emitted by the doc-alignment analyzer.
_issue_type_st = st.sampled_from(["stale", "conflict", "missing"])
|
||||
|
||||
|
||||
def _alignment_issue_st() -> st.SearchStrategy[AlignmentIssue]:
    """Random AlignmentIssue records covering every issue type."""
    fields = {
        "doc_path": _rel_path,
        "issue_type": _issue_type_st,
        "description": _safe_text,
        "related_code": _rel_path,
    }
    return st.builds(AlignmentIssue, **fields)
|
||||
|
||||
|
||||
def _doc_mapping_st() -> st.SearchStrategy[DocMapping]:
    """Random DocMapping records covering every status value."""
    return st.builds(
        DocMapping,
        doc_path=_rel_path,
        doc_topic=_safe_text,
        related_code=st.lists(_rel_path, min_size=0, max_size=5),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )


# Lists of 0-15 random mappings / issues for the report rendering tests.
_mapping_list = st.lists(_doc_mapping_st(), min_size=0, max_size=15)
_issue_list = st.lists(_alignment_issue_st(), min_size=0, max_size=15)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 13: 统计摘要一致性
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty13SummaryConsistency:
|
||||
"""Property 13: 统计摘要一致性
|
||||
|
||||
Feature: repo-audit, Property 13: 统计摘要一致性
|
||||
Validates: Requirements 4.5, 4.6, 4.7
|
||||
|
||||
对于任意报告的统计摘要,各分类/标签的计数之和应等于对应条目列表的总长度。
|
||||
"""
|
||||
|
||||
# --- 13a: render_inventory_report 的分类计数之和 = 列表长度 ---
|
||||
|
||||
    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_category_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13: 统计摘要一致性
        Validates: Requirements 4.5

        The per-category counts in the render_inventory_report summary must
        sum to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")

        # Locate the "按用途分类" summary table and sum the per-row counts.
        cat_sum = _extract_summary_total(report, "按用途分类")
        assert cat_sum == len(items), (
            f"分类计数之和 {cat_sum} != 条目总数 {len(items)}"
        )
|
||||
|
||||
# --- 13b: render_inventory_report 的处置标签计数之和 = 列表长度 ---
|
||||
|
||||
@given(items=_inventory_list)
|
||||
@settings(max_examples=100)
|
||||
def test_inventory_disposition_counts_sum(
|
||||
self, items: list[InventoryItem]
|
||||
) -> None:
|
||||
"""Feature: repo-audit, Property 13: 统计摘要一致性
|
||||
Validates: Requirements 4.5
|
||||
|
||||
render_inventory_report 统计摘要中各处置标签的计数之和应等于条目总数。
|
||||
"""
|
||||
report = render_inventory_report(items, "/tmp/repo")
|
||||
|
||||
disp_sum = _extract_summary_total(report, "按处置标签")
|
||||
assert disp_sum == len(items), (
|
||||
f"处置标签计数之和 {disp_sum} != 条目总数 {len(items)}"
|
||||
)
|
||||
|
||||
# --- 13c: render_flow_report 的孤立模块数量 = orphans 列表长度 ---
|
||||
|
||||
@given(trees=_flow_tree_list, orphans=_orphan_list)
|
||||
@settings(max_examples=100)
|
||||
def test_flow_orphan_count_matches(
|
||||
self, trees: list[FlowNode], orphans: list[str]
|
||||
) -> None:
|
||||
"""Feature: repo-audit, Property 13: 统计摘要一致性
|
||||
Validates: Requirements 4.6
|
||||
|
||||
render_flow_report 统计摘要中的孤立模块数量应等于 orphans 列表长度。
|
||||
"""
|
||||
report = render_flow_report(trees, orphans, "/tmp/repo")
|
||||
|
||||
# 从统计摘要表格中提取"孤立模块"行的数字
|
||||
orphan_count = _extract_flow_stat(report, "孤立模块")
|
||||
assert orphan_count == len(orphans), (
|
||||
f"报告中孤立模块数 {orphan_count} != orphans 列表长度 {len(orphans)}"
|
||||
)
|
||||
|
||||
# --- 13d: render_alignment_report 的 issue 类型计数一致 ---
|
||||
|
||||
@given(mappings=_mapping_list, issues=_issue_list)
|
||||
@settings(max_examples=100)
|
||||
def test_alignment_issue_counts_match(
|
||||
self, mappings: list[DocMapping], issues: list[AlignmentIssue]
|
||||
) -> None:
|
||||
"""Feature: repo-audit, Property 13: 统计摘要一致性
|
||||
Validates: Requirements 4.7
|
||||
|
||||
render_alignment_report 统计摘要中过期/冲突/缺失点计数应与
|
||||
issues 列表中对应类型的实际数量一致。
|
||||
"""
|
||||
report = render_alignment_report(mappings, issues, "/tmp/repo")
|
||||
|
||||
expected_stale = sum(1 for i in issues if i.issue_type == "stale")
|
||||
expected_conflict = sum(1 for i in issues if i.issue_type == "conflict")
|
||||
expected_missing = sum(1 for i in issues if i.issue_type == "missing")
|
||||
|
||||
actual_stale = _extract_alignment_stat(report, "过期点数量")
|
||||
actual_conflict = _extract_alignment_stat(report, "冲突点数量")
|
||||
actual_missing = _extract_alignment_stat(report, "缺失点数量")
|
||||
|
||||
assert actual_stale == expected_stale, (
|
||||
f"过期点: 报告 {actual_stale} != 实际 {expected_stale}"
|
||||
)
|
||||
assert actual_conflict == expected_conflict, (
|
||||
f"冲突点: 报告 {actual_conflict} != 实际 {expected_conflict}"
|
||||
)
|
||||
assert actual_missing == expected_missing, (
|
||||
f"缺失点: 报告 {actual_missing} != 实际 {expected_missing}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 14: 报告头部元信息
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty14ReportHeader:
    """Property 14: report header metadata.

    Feature: repo-audit, Property 14: report header metadata
    Validates: Requirements 4.2

    Every report's header must contain an ISO-formatted timestamp string
    and the repository root path string.
    """

    # Matches e.g. "2025-01-01T00:00:00Z" (UTC "Z" suffix required).
    _ISO_TS_RE = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")

    @given(items=_inventory_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_inventory_report_header(
        self, items: list[InventoryItem], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: report header metadata
        Validates: Requirements 4.2

        The render_inventory_report header must contain an ISO timestamp
        and the repository path.
        """
        report = render_inventory_report(items, repo_root)
        # The header is assumed to sit within the first 500 characters.
        header = report[:500]

        assert self._ISO_TS_RE.search(header), (
            "inventory 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"inventory 报告头部缺少仓库路径 '{repo_root}'"
        )

    @given(trees=_flow_tree_list, orphans=_orphan_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_flow_report_header(
        self, trees: list[FlowNode], orphans: list[str], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: report header metadata
        Validates: Requirements 4.2

        The render_flow_report header must contain an ISO timestamp and the
        repository path.
        """
        report = render_flow_report(trees, orphans, repo_root)
        header = report[:500]

        assert self._ISO_TS_RE.search(header), (
            "flow 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"flow 报告头部缺少仓库路径 '{repo_root}'"
        )

    @given(mappings=_mapping_list, issues=_issue_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_alignment_report_header(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: report header metadata
        Validates: Requirements 4.2

        The render_alignment_report header must contain an ISO timestamp
        and the repository path.
        """
        report = render_alignment_report(mappings, issues, repo_root)
        header = report[:500]

        assert self._ISO_TS_RE.search(header), (
            "alignment 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"alignment 报告头部缺少仓库路径 '{repo_root}'"
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 15: 写操作仅限 docs/audit/
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty15WritesOnlyDocsAudit:
    """Property 15: writes restricted to docs/audit/.

    Feature: repo-audit, Property 15: writes restricted to docs/audit/
    Validates: Requirements 5.2

    For any audit run, every file-write target must live under docs/audit/.
    Uses few iterations because real filesystem work is involved.
    """

    @staticmethod
    def _make_minimal_repo(base: Path, variant: int) -> Path:
        """Build a minimal repo layout; *variant* adds diversity between runs."""
        repo = base / f"repo_{variant}"
        repo.mkdir()

        # Required CLI entry point.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )

        # config package.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")

        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()

        # Add different extra files depending on the variant.
        if variant % 3 == 0:
            (repo / "README.md").write_text("# 项目\n", encoding="utf-8")
        if variant % 3 == 1:
            scripts_dir = repo / "scripts"
            scripts_dir.mkdir()
            (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        if variant % 3 == 2:
            (docs_dir / "notes.md").write_text("# 笔记\n", encoding="utf-8")

        return repo

    @staticmethod
    def _snapshot_files(repo: Path) -> dict[str, float]:
        """Snapshot mtimes of all files in the repo, excluding docs/audit/."""
        snap: dict[str, float] = {}
        for p in repo.rglob("*"):
            if p.is_file():
                rel = p.relative_to(repo).as_posix()
                if not rel.startswith("docs/audit"):
                    snap[rel] = p.stat().st_mtime
        return snap

    @given(variant=st.integers(min_value=0, max_value=9))
    @settings(max_examples=10)
    def test_writes_only_under_docs_audit(self, variant: int) -> None:
        """Feature: repo-audit, Property 15: writes restricted to docs/audit/
        Validates: Requirements 5.2

        After run_audit, no new files may exist outside docs/audit/, and
        docs/audit/ must contain at least one report file.
        """
        import tempfile
        from scripts.audit.run_audit import run_audit

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)
            repo = self._make_minimal_repo(tmp_path, variant)
            before_snap = self._snapshot_files(repo)

            run_audit(repo)

            # docs/audit/ must exist and hold report files.
            audit_dir = repo / "docs" / "audit"
            assert audit_dir.is_dir(), "docs/audit/ 目录未创建"
            audit_files = list(audit_dir.iterdir())
            assert len(audit_files) > 0, "docs/audit/ 下无报告文件"

            # No new files outside docs/audit/.
            for p in repo.rglob("*"):
                if p.is_file():
                    rel = p.relative_to(repo).as_posix()
                    if rel.startswith("docs/audit"):
                        continue
                    assert rel in before_snap, (
                        f"docs/audit/ 外出现了新文件: {rel}"
                    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 辅助函数 — 从报告文本中提取统计数字
|
||||
# ===========================================================================
|
||||
|
||||
def _extract_summary_total(report: str, section_name: str) -> int:
|
||||
"""从 inventory 报告的统计摘要中提取指定分区的数字之和。
|
||||
|
||||
查找 "### {section_name}" 下的 Markdown 表格,
|
||||
累加每行最后一列的数字(排除合计行)。
|
||||
"""
|
||||
lines = report.split("\n")
|
||||
in_section = False
|
||||
total = 0
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped == f"### {section_name}":
|
||||
in_section = True
|
||||
continue
|
||||
if in_section and stripped.startswith("###"):
|
||||
# 进入下一个子节
|
||||
break
|
||||
if in_section and stripped.startswith("|") and "**合计**" not in stripped:
|
||||
# 跳过表头和分隔行
|
||||
if stripped.startswith("| 用途分类") or stripped.startswith("| 处置标签"):
|
||||
continue
|
||||
if stripped.startswith("|---"):
|
||||
continue
|
||||
# 提取最后一列的数字
|
||||
cells = [c.strip() for c in stripped.split("|") if c.strip()]
|
||||
if cells:
|
||||
try:
|
||||
total += int(cells[-1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def _extract_flow_stat(report: str, label: str) -> int:
|
||||
"""从 flow 报告统计摘要表格中提取指定指标的数字。"""
|
||||
# 匹配 "| 孤立模块 | 5 |" 格式
|
||||
pattern = re.compile(rf"\|\s*{re.escape(label)}\s*\|\s*(\d+)\s*\|")
|
||||
m = pattern.search(report)
|
||||
return int(m.group(1)) if m else -1
|
||||
|
||||
|
||||
def _extract_alignment_stat(report: str, label: str) -> int:
|
||||
"""从 alignment 报告统计摘要中提取指定指标的数字。
|
||||
|
||||
匹配 "- 过期点数量:3" 格式。
|
||||
"""
|
||||
# 兼容全角/半角冒号
|
||||
pattern = re.compile(rf"{re.escape(label)}[::]\s*(\d+)")
|
||||
m = pattern.search(report)
|
||||
return int(m.group(1)) if m else -1
|
||||
==== New file: apps/etl/connectors/feiqiu/tests/unit/test_audit_run.py (177 lines, @@ -0,0 +1,177 @@) ====
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
run_audit 主入口的单元测试。
|
||||
|
||||
验证:
|
||||
- docs/audit/ 目录自动创建
|
||||
- 三份报告文件正确生成
|
||||
- 报告头部包含时间戳和仓库路径
|
||||
- 目录创建失败时抛出 RuntimeError
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEnsureReportDir:
    """Tests for the ``_ensure_report_dir`` directory-creation logic."""

    def test_creates_dir_when_missing(self, tmp_path: Path) -> None:
        # The helper should create docs/audit/repo under the given root
        # and return that path.
        from scripts.audit.run_audit import _ensure_report_dir

        result = _ensure_report_dir(tmp_path)
        expected = tmp_path / "docs" / "audit" / "repo"
        assert result == expected
        assert expected.is_dir()

    def test_returns_existing_dir(self, tmp_path: Path) -> None:
        # A pre-existing docs/audit/repo is returned unchanged.
        from scripts.audit.run_audit import _ensure_report_dir

        audit_dir = tmp_path / "docs" / "audit" / "repo"
        audit_dir.mkdir(parents=True)
        result = _ensure_report_dir(tmp_path)
        assert result == audit_dir

    def test_raises_on_creation_failure(self, tmp_path: Path) -> None:
        # Place a regular file at docs/audit so mkdir must fail.
        from scripts.audit.run_audit import _ensure_report_dir

        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "audit").write_text("block", encoding="utf-8")

        with pytest.raises(RuntimeError, match="无法创建报告输出目录"):
            _ensure_report_dir(tmp_path)
|
||||
|
||||
|
||||
class TestInjectHeader:
    """Tests for the ``_inject_header`` fallback-injection logic."""

    def test_skips_when_header_present(self) -> None:
        # A report that already carries a header must be returned verbatim,
        # even when the injected values differ from the existing ones.
        from scripts.audit.run_audit import _inject_header

        report = "# 标题\n\n- 生成时间: 2025-01-01T00:00:00Z\n- 仓库路径: `/repo`\n"
        result = _inject_header(report, "2025-06-01T00:00:00Z", "/other")
        # The existing header must not be modified.
        assert result == report

    def test_injects_when_header_missing(self) -> None:
        # A header-less report gets the timestamp and repo path injected.
        from scripts.audit.run_audit import _inject_header

        report = "# 无头部报告\n\n内容..."
        result = _inject_header(report, "2025-06-01T00:00:00Z", "/repo")
        assert "生成时间: 2025-06-01T00:00:00Z" in result
        assert "仓库路径: `/repo`" in result
|
||||
|
||||
|
||||
class TestRunAudit:
    """End-to-end tests for ``run_audit`` against a minimal repo layout."""

    def _make_minimal_repo(self, tmp_path: Path) -> Path:
        """Build the smallest repo structure run_audit can process."""
        repo = tmp_path / "repo"
        repo.mkdir()

        # Core code directory with a CLI entry point.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )

        # config package.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        (config_dir / "defaults.py").write_text("DEFAULTS = {}\n", encoding="utf-8")

        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        (docs_dir / "README.md").write_text("# 文档\n", encoding="utf-8")

        # Root-level file.
        (repo / "README.md").write_text("# 项目\n", encoding="utf-8")

        return repo

    def test_creates_three_reports(self, tmp_path: Path) -> None:
        # run_audit must emit all three report files under docs/audit/repo.
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)

        audit_dir = repo / "docs" / "audit" / "repo"
        assert (audit_dir / "file_inventory.md").is_file()
        assert (audit_dir / "flow_tree.md").is_file()
        assert (audit_dir / "doc_alignment.md").is_file()

    def test_reports_contain_timestamp(self, tmp_path: Path) -> None:
        # Every report must carry an ISO-formatted timestamp.
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)

        audit_dir = repo / "docs" / "audit" / "repo"
        # ISO timestamp format with UTC "Z" suffix.
        ts_pattern = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")

        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert ts_pattern.search(content), f"{name} 缺少时间戳"

    def test_reports_contain_repo_path(self, tmp_path: Path) -> None:
        # Every report must embed the resolved repository path.
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)

        audit_dir = repo / "docs" / "audit" / "repo"
        repo_str = str(repo.resolve())

        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert repo_str in content, f"{name} 缺少仓库路径"

    def test_writes_only_to_docs_audit(self, tmp_path: Path) -> None:
        """All file writes must stay inside docs/audit/ (Property 15)."""
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)

        # Snapshot files before the run (excluding docs/audit/).
        before = set()
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if not rel.startswith("docs/audit"):
                before.add((rel, p.stat().st_mtime if p.is_file() else None))

        run_audit(repo)

        # After the run: files outside docs/audit/ must all predate it.
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if rel.startswith("docs/audit"):
                continue
            if p.is_file():
                # The file must appear in the before snapshot.
                found = any(r == rel for r, _ in before)
                assert found, f"意外创建了 docs/audit/ 外的文件: {rel}"

    def test_auto_creates_docs_audit_dir(self, tmp_path: Path) -> None:
        # run_audit must create docs/audit/repo itself when it is absent.
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        # Ensure docs/audit/ does not exist yet.
        audit_dir = repo / "docs" / "audit" / "repo"
        assert not audit_dir.exists()

        run_audit(repo)
        assert audit_dir.is_dir()
|
||||
==== New file: apps/etl/connectors/feiqiu/tests/unit/test_audit_scanner.py (428 lines, @@ -0,0 +1,428 @@) ====
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试 — 仓库扫描器 (scanner.py)
|
||||
|
||||
覆盖:
|
||||
- 排除模式匹配逻辑
|
||||
- 递归遍历与 FileEntry 构建
|
||||
- 空目录检测
|
||||
- 权限错误容错
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.audit import FileEntry
|
||||
from scripts.audit.scanner import EXCLUDED_PATTERNS, _is_excluded, scan_repo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_excluded 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsExcluded:
    """Tests for the exclusion-pattern matching logic (``_is_excluded``)."""

    # Exact directory-name matches in the default pattern list.
    def test_exact_match_git(self) -> None:
        assert _is_excluded(".git", EXCLUDED_PATTERNS) is True

    def test_exact_match_pycache(self) -> None:
        assert _is_excluded("__pycache__", EXCLUDED_PATTERNS) is True

    def test_exact_match_pytest_cache(self) -> None:
        assert _is_excluded(".pytest_cache", EXCLUDED_PATTERNS) is True

    def test_exact_match_kiro(self) -> None:
        assert _is_excluded(".kiro", EXCLUDED_PATTERNS) is True

    # Wildcard pattern ("*.pyc") matches compiled bytecode files.
    def test_wildcard_pyc(self) -> None:
        assert _is_excluded("module.pyc", EXCLUDED_PATTERNS) is True

    # Ordinary names must not match.
    def test_normal_py_not_excluded(self) -> None:
        assert _is_excluded("main.py", EXCLUDED_PATTERNS) is False

    def test_normal_dir_not_excluded(self) -> None:
        assert _is_excluded("src", EXCLUDED_PATTERNS) is False

    # With an empty pattern list nothing is excluded.
    def test_empty_patterns(self) -> None:
        assert _is_excluded(".git", []) is False

    # Caller-supplied patterns are honoured.
    def test_custom_pattern(self) -> None:
        assert _is_excluded("data.csv", ["*.csv"]) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_repo 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScanRepo:
    """Tests for ``scan_repo`` recursive traversal."""

    def test_basic_structure(self, tmp_path: Path) -> None:
        """Ordinary files and directories should be scanned correctly."""
        (tmp_path / "a.py").write_text("# code", encoding="utf-8")
        sub = tmp_path / "sub"
        sub.mkdir()
        (sub / "b.txt").write_text("hello", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "a.py" in paths
        assert "sub" in paths
        assert "sub/b.txt" in paths

    def test_file_entry_fields(self, tmp_path: Path) -> None:
        """Every FileEntry field should be populated correctly for a file."""
        (tmp_path / "hello.md").write_text("# hi", encoding="utf-8")

        entries = scan_repo(tmp_path)
        md = next(e for e in entries if e.rel_path == "hello.md")

        assert md.is_dir is False
        assert md.size_bytes > 0
        assert md.extension == ".md"
        assert md.is_empty_dir is False

    def test_directory_entry_fields(self, tmp_path: Path) -> None:
        """Directory entries should carry directory-specific field values."""
        sub = tmp_path / "mydir"
        sub.mkdir()
        (sub / "file.py").write_text("pass", encoding="utf-8")

        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "mydir")

        assert d.is_dir is True
        assert d.size_bytes == 0
        assert d.extension == ""
        assert d.is_empty_dir is False

    def test_excluded_git_dir(self, tmp_path: Path) -> None:
        """The .git directory and its contents should be excluded."""
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert ".git" not in paths
        assert ".git/config" not in paths

    def test_excluded_pycache(self, tmp_path: Path) -> None:
        """__pycache__ directories should be excluded."""
        cache = tmp_path / "pkg" / "__pycache__"
        cache.mkdir(parents=True)
        (cache / "mod.cpython-310.pyc").write_bytes(b"\x00")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert not any("__pycache__" in p for p in paths)

    def test_excluded_pyc_files(self, tmp_path: Path) -> None:
        """*.pyc files should be excluded while their .py source is kept."""
        (tmp_path / "mod.pyc").write_bytes(b"\x00")
        (tmp_path / "mod.py").write_text("pass", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "mod.pyc" not in paths
        assert "mod.py" in paths

    def test_empty_directory_detection(self, tmp_path: Path) -> None:
        """Empty directories should be flagged with is_empty_dir=True."""
        (tmp_path / "empty").mkdir()

        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "empty")

        assert d.is_dir is True
        assert d.is_empty_dir is True

    def test_dir_with_only_excluded_children(self, tmp_path: Path) -> None:
        """A directory whose only children are excluded counts as empty."""
        sub = tmp_path / "pkg"
        sub.mkdir()
        cache = sub / "__pycache__"
        cache.mkdir()
        (cache / "x.pyc").write_bytes(b"\x00")

        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "pkg")

        assert d.is_empty_dir is True

    def test_custom_exclude_patterns(self, tmp_path: Path) -> None:
        """Caller-supplied exclusion patterns should take effect."""
        (tmp_path / "keep.py").write_text("pass", encoding="utf-8")
        (tmp_path / "skip.log").write_text("log", encoding="utf-8")

        entries = scan_repo(tmp_path, exclude=["*.log"])
        paths = {e.rel_path for e in entries}

        assert "keep.py" in paths
        assert "skip.log" not in paths

    def test_empty_repo(self, tmp_path: Path) -> None:
        """An empty repository should yield an empty list."""
        entries = scan_repo(tmp_path)
        assert entries == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        """Results should come back sorted by rel_path."""
        (tmp_path / "z.py").write_text("", encoding="utf-8")
        (tmp_path / "a.py").write_text("", encoding="utf-8")
        sub = tmp_path / "m"
        sub.mkdir()
        (sub / "b.py").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = [e.rel_path for e in entries]

        assert paths == sorted(paths)

    @pytest.mark.skipif(
        os.name == "nt",
        reason="Windows 上 chmod 行为不同,跳过权限测试",
    )
    def test_permission_error_skipped(self, tmp_path: Path) -> None:
        """An unreadable directory should be skipped without aborting the scan."""
        ok_file = tmp_path / "ok.py"
        ok_file.write_text("pass", encoding="utf-8")

        no_access = tmp_path / "secret"
        no_access.mkdir()
        (no_access / "data.txt").write_text("x", encoding="utf-8")
        no_access.chmod(0o000)

        try:
            entries = scan_repo(tmp_path)
            paths = {e.rel_path for e in entries}
            # ok.py should still have been scanned.
            assert "ok.py" in paths
            # The "secret" directory itself is recorded (the walker logs the
            # directory before attempting iterdir) but its children are not.
            assert "secret/data.txt" not in paths
        finally:
            # Restore permissions so tmp_path cleanup can delete the tree.
            no_access.chmod(0o755)

    def test_nested_directories(self, tmp_path: Path) -> None:
        """Multi-level nested directories should be traversed fully."""
        deep = tmp_path / "a" / "b" / "c"
        deep.mkdir(parents=True)
        (deep / "leaf.py").write_text("pass", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "a" in paths
        assert "a/b" in paths
        assert "a/b/c" in paths
        assert "a/b/c/leaf.py" in paths

    def test_extension_lowercase(self, tmp_path: Path) -> None:
        """Extensions should be normalised to lower case."""
        (tmp_path / "README.MD").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        md = next(e for e in entries if "README" in e.rel_path)

        assert md.extension == ".md"

    def test_no_extension(self, tmp_path: Path) -> None:
        """Files without an extension should get an empty extension string."""
        (tmp_path / "Makefile").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        f = next(e for e in entries if e.rel_path == "Makefile")

        assert f.extension == ""

    def test_root_not_in_entries(self, tmp_path: Path) -> None:
        """The root directory itself must not appear in the results."""
        (tmp_path / "a.py").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "." not in paths
        assert "" not in paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 7: 扫描器排除规则
|
||||
# Feature: repo-audit, Property 7: 扫描器排除规则
|
||||
# Validates: Requirements 1.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import fnmatch
|
||||
import string
|
||||
import tempfile
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# --- 生成器策略 ---
|
||||
|
||||
# Legal file/directory name characters (no path separators or specials).
_SAFE_CHARS = string.ascii_lowercase + string.digits + "_-"

# Strategy for safe names that do not collide with any exclusion pattern.
_safe_name = st.text(_SAFE_CHARS, min_size=1, max_size=8)

# Directory names that appear in the default exclusion patterns.
_EXCLUDED_DIR_NAMES = [".git", "__pycache__", ".pytest_cache", ".kiro"]

# File extension covered by the default exclusion patterns.
_EXCLUDED_FILE_EXT = ".pyc"

# Strategy picking one excluded directory name at random.
_excluded_dir_name = st.sampled_from(_EXCLUDED_DIR_NAMES)
|
||||
|
||||
|
||||
def _build_tree(tmp: Path, normal_names: list[str], excluded_dirs: list[str],
|
||||
include_pyc: bool) -> None:
|
||||
"""在临时目录中构建包含正常文件和被排除条目的文件树。"""
|
||||
# 创建正常文件
|
||||
for name in normal_names:
|
||||
safe = name or "f"
|
||||
filepath = tmp / f"{safe}.txt"
|
||||
if not filepath.exists():
|
||||
filepath.write_text("ok", encoding="utf-8")
|
||||
|
||||
# 创建被排除的目录(含子文件)
|
||||
for dirname in excluded_dirs:
|
||||
d = tmp / dirname
|
||||
d.mkdir(exist_ok=True)
|
||||
(d / "inner.txt").write_text("hidden", encoding="utf-8")
|
||||
|
||||
# 可选:创建 .pyc 文件
|
||||
if include_pyc:
|
||||
(tmp / "module.pyc").write_bytes(b"\x00")
|
||||
|
||||
|
||||
class TestProperty7ScannerExclusionRules:
    """
    Property 7: scanner exclusion rules.

    For any file tree, the FileEntry list returned by scan_repo must not
    contain entries whose rel_path matches an exclusion pattern
    (.git, __pycache__, .pytest_cache, etc.).

    Feature: repo-audit, Property 7: scanner exclusion rules
    Validates: Requirements 1.1
    """

    @given(
        normal_names=st.lists(_safe_name, min_size=0, max_size=5),
        excluded_dirs=st.lists(_excluded_dir_name, min_size=1, max_size=3),
        include_pyc=st.booleans(),
    )
    @settings(max_examples=100)
    def test_excluded_entries_never_in_results(
        self,
        normal_names: list[str],
        excluded_dirs: list[str],
        include_pyc: bool,
    ) -> None:
        """No entry matching an exclusion pattern may appear in scan results."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            _build_tree(tmp, normal_names, excluded_dirs, include_pyc)

            entries = scan_repo(tmp)

            for entry in entries:
                # Check every path segment against the exclusion patterns.
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in EXCLUDED_PATTERNS:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )

    @given(
        excluded_dir=_excluded_dir_name,
        depth=st.integers(min_value=1, max_value=3),
    )
    @settings(max_examples=100)
    def test_excluded_dirs_at_any_depth(
        self,
        excluded_dir: str,
        depth: int,
    ) -> None:
        """Excluded directories must be filtered at every nesting depth."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)

            # Build a nested path: level0/level1/.../excluded_dir/file.txt
            current = tmp
            for i in range(depth):
                current = current / f"level{i}"
                current.mkdir(exist_ok=True)
                # Keep a normal file so the parent directory is non-empty.
                (current / "keep.txt").write_text("ok", encoding="utf-8")

            # Place the excluded directory at the deepest level.
            excluded = current / excluded_dir
            excluded.mkdir(exist_ok=True)
            (excluded / "secret.txt").write_text("hidden", encoding="utf-8")

            entries = scan_repo(tmp)

            for entry in entries:
                parts = entry.rel_path.split("/")
                assert excluded_dir not in parts, (
                    f"被排除目录 '{excluded_dir}' 不应出现在结果中,"
                    f"但发现 rel_path='{entry.rel_path}'"
                )

    @given(
        custom_patterns=st.lists(
            st.sampled_from(["*.log", "*.tmp", "*.bak", "node_modules", ".venv"]),
            min_size=1,
            max_size=3,
        ),
    )
    @settings(max_examples=100)
    def test_custom_exclude_patterns_respected(
        self,
        custom_patterns: list[str],
    ) -> None:
        """Caller-supplied exclusion patterns must also be honoured."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)

            # One regular file the scan should keep.
            (tmp / "main.py").write_text("pass", encoding="utf-8")

            # Create a matching file or directory per custom pattern.
            for pat in custom_patterns:
                if pat.startswith("*."):
                    # Wildcard pattern → create a matching file.
                    ext = pat[1:]  # e.g. ".log"
                    (tmp / f"data{ext}").write_text("x", encoding="utf-8")
                else:
                    # Exact-name pattern → create a directory.
                    d = tmp / pat
                    d.mkdir(exist_ok=True)
                    (d / "inner.txt").write_text("x", encoding="utf-8")

            entries = scan_repo(tmp, exclude=custom_patterns)

            for entry in entries:
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in custom_patterns:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"自定义排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )
|
||||
==== New file: apps/etl/connectors/feiqiu/tests/unit/test_base_dws_template.py (377 lines, @@ -0,0 +1,377 @@) ====
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
BaseDwsTask 默认模板方法 — 属性测试 + 单元测试
|
||||
|
||||
Feature: etl-dws-flow-refactor
|
||||
测试位置:apps/etl/connectors/feiqiu/tests/unit/
|
||||
|
||||
使用 hypothesis 验证 BaseDwsTask 默认 extract()/load() 的核心正确性属性。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume
|
||||
import hypothesis.strategies as st
|
||||
|
||||
# ── 将 ETL 模块加入 sys.path ──
|
||||
_ETL_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_ETL_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_ETL_ROOT))
|
||||
|
||||
from tasks.base_task import TaskContext
|
||||
from tasks.dws.base_dws_task import BaseDwsTask
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 策略(Strategies)
|
||||
# ============================================================================
|
||||
|
||||
# Timezone used for all generated windows. Hoisted to module level so the
# ZoneInfo object is constructed once, not via __import__("zoneinfo") on
# every example the strategy generates.
from zoneinfo import ZoneInfo

_TZ_SHANGHAI = ZoneInfo("Asia/Shanghai")

# Valid store_id: a positive integer.
st_store_id = st.integers(min_value=1, max_value=10**18)


def _to_aware_window(raw: tuple) -> tuple:
    """Map (naive_start, minutes) -> (aware_start, aware_end, minutes).

    Both endpoints are made timezone-aware in Asia/Shanghai; since
    minutes >= 1, window_start < window_end always holds.
    """
    start, minutes = raw
    aware_start = start.replace(tzinfo=_TZ_SHANGHAI)
    aware_end = (start + timedelta(minutes=minutes)).replace(tzinfo=_TZ_SHANGHAI)
    return (aware_start, aware_end, minutes)


# Valid time window: window_start < window_end, timezone-aware.
st_window = st.tuples(
    st.datetimes(
        min_value=datetime(2020, 1, 1),
        max_value=datetime(2030, 12, 31),
    ),
    st.integers(min_value=1, max_value=525600),  # 1 minute .. 1 year
).map(_to_aware_window)
|
||||
|
||||
|
||||
def build_context(store_id: int, window_start: datetime, window_end: datetime, window_minutes: int) -> TaskContext:
    """Build a valid TaskContext from the given store and window parameters."""
    kwargs = {
        "store_id": store_id,
        "window_start": window_start,
        "window_end": window_end,
        "window_minutes": window_minutes,
    }
    return TaskContext(**kwargs)
|
||||
|
||||
|
||||
@st.composite
def st_task_context(draw):
    """Generate a valid TaskContext from the shared store/window strategies."""
    sid = draw(st_store_id)
    start, end, minutes = draw(st_window)
    return build_context(sid, start, end, minutes)
|
||||
|
||||
|
||||
# Simple row payloads (lists of dicts) shaped like DWS result rows.
st_row = st.fixed_dictionaries({
    "site_id": st_store_id,
    "stat_date": st.dates(),
    "value": st.integers(min_value=0, max_value=10000),
})

# Possibly-empty and guaranteed-non-empty row lists for extract/load tests.
st_rows = st.lists(st_row, min_size=0, max_size=20)
st_nonempty_rows = st.lists(st_row, min_size=1, max_size=20)

# DATE_COL strategy: valid column-name identifiers used by the tasks under test.
st_date_col = st.sampled_from([
    "stat_date", "visit_date", "stat_month", "salary_month", "created_at",
])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 测试用具体子类
|
||||
# ============================================================================
|
||||
|
||||
class StubDwsTask(BaseDwsTask):
    """Minimal BaseDwsTask subclass for property testing.

    Declares DATE_COL and implements _do_extract(); deliberately does NOT
    override extract()/load(), so the base class's default template methods
    are the code under test.
    """

    DATE_COL = "stat_date"

    def __init__(self, date_col: str = "stat_date", extract_rows: list | None = None):
        # Build the minimal dependencies the base class constructor expects.
        mock_config = MagicMock()
        mock_config.get.side_effect = lambda key, default=None: (
            "Asia/Shanghai" if key == "app.timezone" else default
        )
        mock_db = MagicMock()
        mock_api = MagicMock()
        mock_logger = MagicMock()
        super().__init__(mock_config, mock_db, mock_api, mock_logger)

        # Set DATE_COL per instance so tests can vary it.
        self.DATE_COL = date_col
        # Canned data returned by _do_extract().
        self._extract_rows = extract_rows if extract_rows is not None else []

    def get_task_code(self) -> str:
        return "TEST_STUB_DWS"

    def get_target_table(self) -> str:
        return "test_stub_table"

    def get_primary_keys(self) -> list[str]:
        return ["site_id", "stat_date"]

    def _do_extract(self, context: TaskContext) -> list[dict]:
        # Return the preset rows; the default extract() wraps this result.
        return self._extract_rows
|
||||
|
||||
|
||||
class OverriddenExtractTask(StubDwsTask):
    """Subclass that overrides extract(); used to verify override behavior."""

    def extract(self, context: TaskContext) -> dict:
        custom_result = dict(custom=True)
        return custom_result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 1: 默认 extract() 返回标准结构
|
||||
# Feature: etl-dws-flow-refactor, Property 1: 默认 extract() 返回标准结构
|
||||
# Validates: Requirements 1.1, 1.4
|
||||
# ============================================================================
|
||||
|
||||
class TestExtractProperty:
    """Property tests: the default extract() returns the standard structure."""

    @given(context=st_task_context(), rows=st_rows, date_col=st_date_col)
    @settings(max_examples=100)
    def test_extract_returns_standard_keys(self, context: TaskContext, rows: list, date_col: str):
        """For any valid TaskContext and _do_extract return value, the default
        extract() must return a dict with exactly the four keys
        rows/start_date/end_date/site_id.

        **Validates: Requirements 1.1, 1.4**
        """
        task = StubDwsTask(date_col=date_col, extract_rows=rows)
        result = task.extract(context)

        # Must contain exactly the four standard keys.
        assert set(result.keys()) == {"rows", "start_date", "end_date", "site_id"}

    @given(context=st_task_context(), rows=st_rows, date_col=st_date_col)
    @settings(max_examples=100)
    def test_extract_rows_equals_do_extract(self, context: TaskContext, rows: list, date_col: str):
        """extract()'s "rows" must equal the value returned by _do_extract(context).

        **Validates: Requirements 1.1, 1.4**
        """
        task = StubDwsTask(date_col=date_col, extract_rows=rows)
        result = task.extract(context)

        assert result["rows"] == rows

    @given(context=st_task_context(), rows=st_rows)
    @settings(max_examples=100)
    def test_extract_dates_from_context(self, context: TaskContext, rows: list):
        """extract()'s start_date/end_date/site_id must come from the context.

        **Validates: Requirements 1.1, 1.4**
        """
        task = StubDwsTask(extract_rows=rows)
        result = task.extract(context)

        assert result["start_date"] == context.window_start.date()
        assert result["end_date"] == context.window_end.date()
        assert result["site_id"] == context.store_id

    @given(context=st_task_context())
    @settings(max_examples=100)
    def test_extract_none_treated_as_empty(self, context: TaskContext):
        """When _do_extract returns None, extract() must set rows to [].

        **Validates: Requirements 1.1, 1.4**
        """
        task = StubDwsTask(extract_rows=None)
        # Force _do_extract to return None to exercise the fallback.
        task._do_extract = lambda ctx: None
        result = task.extract(context)

        assert result["rows"] == []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 2: 默认 load() 幂等写入与标准统计
|
||||
# Feature: etl-dws-flow-refactor, Property 2: 默认 load() 幂等写入与标准统计
|
||||
# Validates: Requirements 1.2, 1.5
|
||||
# ============================================================================
|
||||
|
||||
class TestLoadProperty:
    """Property tests: the default load() writes idempotently and reports standard counts."""

    @given(context=st_task_context(), transformed=st_nonempty_rows, date_col=st_date_col)
    @settings(max_examples=100)
    def test_load_nonempty_returns_standard_counts(
        self, context: TaskContext, transformed: list, date_col: str
    ):
        """For non-empty transformed data, load() must return a dict with a
        "counts" key whose value holds the five integer keys
        fetched/inserted/updated/skipped/errors, with fetched == len(transformed).

        **Validates: Requirements 1.2, 1.5**
        """
        task = StubDwsTask(date_col=date_col)

        # Mock delete_existing_data and bulk_insert to avoid real DB work.
        task.delete_existing_data = MagicMock(return_value=len(transformed))
        task.bulk_insert = MagicMock(return_value=len(transformed))

        result = task.load(transformed, context)

        # Must contain the "counts" key.
        assert "counts" in result
        counts = result["counts"]

        # counts must contain exactly the five standard keys.
        expected_keys = {"fetched", "inserted", "updated", "skipped", "errors"}
        assert set(counts.keys()) == expected_keys

        # All values must be integers.
        for key in expected_keys:
            assert isinstance(counts[key], int), f"counts[{key}] 应为 int,实际为 {type(counts[key])}"

        # fetched equals len(transformed).
        assert counts["fetched"] == len(transformed)

    @given(context=st_task_context())
    @settings(max_examples=100)
    def test_load_empty_returns_zero_counts(self, context: TaskContext):
        """For empty transformed data, load() must return all counts as 0.

        **Validates: Requirements 1.2, 1.5**
        """
        task = StubDwsTask()
        result = task.load([], context)

        assert "counts" in result
        counts = result["counts"]
        expected_keys = {"fetched", "inserted", "updated", "skipped", "errors"}
        assert set(counts.keys()) == expected_keys

        for key in expected_keys:
            assert counts[key] == 0

    @given(context=st_task_context(), transformed=st_nonempty_rows, date_col=st_date_col)
    @settings(max_examples=100)
    def test_load_calls_delete_then_insert(
        self, context: TaskContext, transformed: list, date_col: str
    ):
        """With non-empty transformed data, load() must call delete_existing_data
        before bulk_insert (delete-then-insert idempotency).

        **Validates: Requirements 1.2, 1.5**
        """
        call_order = []
        task = StubDwsTask(date_col=date_col)

        def mock_delete(ctx, date_col=None):
            call_order.append("delete")
            return 0

        def mock_insert(rows):
            call_order.append("insert")
            return len(rows)

        task.delete_existing_data = mock_delete
        task.bulk_insert = mock_insert

        task.load(transformed, context)

        assert call_order == ["delete", "insert"]

    @given(context=st_task_context(), transformed=st_nonempty_rows, date_col=st_date_col)
    @settings(max_examples=100)
    def test_load_uses_date_col(
        self, context: TaskContext, transformed: list, date_col: str
    ):
        """load() must pass DATE_COL as the date_col argument of delete_existing_data.

        **Validates: Requirements 1.2, 1.5**
        """
        captured_date_col = []
        task = StubDwsTask(date_col=date_col)

        def mock_delete(ctx, date_col=None):
            captured_date_col.append(date_col)
            return 0

        task.delete_existing_data = mock_delete
        task.bulk_insert = MagicMock(return_value=len(transformed))

        task.load(transformed, context)

        assert len(captured_date_col) == 1
        assert captured_date_col[0] == date_col

    @given(context=st_task_context(), transformed=st_nonempty_rows)
    @settings(max_examples=100)
    def test_load_none_date_col_falls_back(self, context: TaskContext, transformed: list):
        """When DATE_COL is None, load() must fall back to "stat_date".

        **Validates: Requirements 1.2, 1.5**
        """
        captured_date_col = []
        task = StubDwsTask()
        task.DATE_COL = None  # explicitly set to None to exercise the fallback

        def mock_delete(ctx, date_col=None):
            captured_date_col.append(date_col)
            return 0

        task.delete_existing_data = mock_delete
        task.bulk_insert = MagicMock(return_value=len(transformed))

        task.load(transformed, context)

        assert captured_date_col[0] == "stat_date"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 单元测试:覆盖行为验证(Requirements 1.3)
|
||||
# ============================================================================
|
||||
|
||||
class TestOverrideBehavior:
    """Verify that subclasses overriding extract()/load() use their own implementation.

    Covers Requirement 1.3 (template-method override behavior).
    """

    @staticmethod
    def _make_ctx(store_id: int) -> TaskContext:
        """Build a fixed one-day window context in Asia/Shanghai.

        Replaces the repeated ``__import__("zoneinfo")`` inline hack with a
        proper (function-local) import so the module's import list is unchanged.
        """
        from zoneinfo import ZoneInfo

        tz = ZoneInfo("Asia/Shanghai")
        return build_context(
            store_id=store_id,
            window_start=datetime(2026, 1, 1, tzinfo=tz),
            window_end=datetime(2026, 1, 2, tzinfo=tz),
            window_minutes=1440,
        )

    def test_overridden_extract_used(self):
        """A subclass that overrides extract() must use the custom implementation."""
        ctx = self._make_ctx(123)
        task = OverriddenExtractTask()
        result = task.extract(ctx)
        assert result == {"custom": True}

    def test_not_implemented_do_extract_raises(self):
        """A subclass without _do_extract (and without overriding extract)
        must raise NotImplementedError when extract() is invoked."""

        class BareTask(BaseDwsTask):
            def get_task_code(self):
                return "BARE"

            def get_target_table(self):
                return "bare_table"

            def get_primary_keys(self):
                return ["id"]

        # Minimal config stub: only app.timezone is answered, everything else
        # falls through to the caller-supplied default.
        mock_config = MagicMock()
        mock_config.get.side_effect = lambda key, default=None: (
            "Asia/Shanghai" if key == "app.timezone" else default
        )
        task = BareTask(mock_config, MagicMock(), MagicMock(), MagicMock())
        ctx = self._make_ctx(1)
        with pytest.raises(NotImplementedError):
            task.extract(ctx)
|
||||
139
apps/etl/connectors/feiqiu/tests/unit/test_cli_args.py
Normal file
139
apps/etl/connectors/feiqiu/tests/unit/test_cli_args.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLI 参数解析单元测试
|
||||
|
||||
验证 --data-source 新参数、--pipeline-flow 弃用映射、
|
||||
--pipeline + --tasks 同时使用、以及 build_cli_overrides 集成行为。
|
||||
|
||||
需求: 3.1, 3.3, 3.5
|
||||
"""
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cli.main import parse_args, resolve_data_source, build_cli_overrides
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. --data-source 新参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDataSourceArg:
    """Tests for the new --data-source argument."""

    @pytest.mark.parametrize("value", ["online", "offline", "hybrid"])
    def test_data_source_valid_values(self, value):
        # Each allowed choice should parse through unchanged.
        with patch("sys.argv", ["cli", "--data-source", value]):
            args = parse_args()
            assert args.data_source == value

    def test_data_source_default_is_none(self):
        # Flag absent -> None, so resolve_data_source() can apply its fallbacks.
        with patch("sys.argv", ["cli"]):
            args = parse_args()
            assert args.data_source is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. resolve_data_source() 弃用映射
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestResolveDataSource:
    """Deprecation-mapping tests for resolve_data_source()."""

    def test_explicit_data_source_returns_directly(self):
        # An explicit --data-source is returned as-is.
        args = Namespace(data_source="online", pipeline_flow=None)
        assert resolve_data_source(args) == "online"

    def test_data_source_takes_priority_over_pipeline_flow(self):
        """--data-source takes priority over --pipeline-flow."""
        args = Namespace(data_source="online", pipeline_flow="FULL")
        assert resolve_data_source(args) == "online"

    @pytest.mark.parametrize(
        "flow, expected",
        [
            ("FULL", "hybrid"),
            ("FETCH_ONLY", "online"),
            ("INGEST_ONLY", "offline"),
        ],
    )
    def test_pipeline_flow_maps_with_deprecation_warning(self, flow, expected):
        """Legacy --pipeline-flow maps to the correct data_source and emits a
        DeprecationWarning."""
        args = Namespace(data_source=None, pipeline_flow=flow)
        with pytest.warns(DeprecationWarning, match="--pipeline-flow 已弃用"):
            result = resolve_data_source(args)
        assert result == expected

    def test_neither_arg_defaults_to_hybrid(self):
        """With neither argument given, the default is hybrid."""
        args = Namespace(data_source=None, pipeline_flow=None)
        assert resolve_data_source(args) == "hybrid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. build_cli_overrides() 集成
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestBuildCliOverrides:
    """Integration tests for build_cli_overrides()."""

    def _make_args(self, **kwargs):
        """Build a minimal Namespace; unspecified arguments default to None/False."""
        defaults = dict(
            store_id=None, tasks=None, dry_run=False,
            flow=None, pipeline_deprecated=None,
            processing_mode="increment_only",
            fetch_before_verify=False, verify_tables=None,
            window_split="none", lookback_hours=24, overlap_seconds=3600,
            pg_dsn=None, pg_host=None, pg_port=None, pg_name=None,
            pg_user=None, pg_password=None,
            api_base=None, api_token=None, api_timeout=None,
            api_page_size=None, api_retry_max=None,
            window_start=None, window_end=None,
            force_window_override=False,
            window_split_unit=None, window_split_days=None,
            window_compensation_hours=None,
            export_root=None, log_root=None,
            data_source=None, pipeline_flow=None,
            fetch_root=None, ingest_source=None, write_pretty_json=False,
            idle_start=None, idle_end=None, allow_empty_advance=False,
        )
        defaults.update(kwargs)
        return Namespace(**defaults)

    def test_data_source_online_sets_run_key(self):
        # --data-source online propagates into the run.data_source override.
        args = self._make_args(data_source="online")
        overrides = build_cli_overrides(args)
        assert overrides["run"]["data_source"] == "online"

    def test_pipeline_flow_sets_both_keys(self):
        """The legacy flag writes both pipeline.flow and run.data_source."""
        args = self._make_args(pipeline_flow="FULL")
        with warnings.catch_warnings():
            # Suppress the expected DeprecationWarning from the legacy mapping.
            warnings.simplefilter("ignore", DeprecationWarning)
            overrides = build_cli_overrides(args)
        assert overrides["pipeline"]["flow"] == "FULL"
        assert overrides["run"]["data_source"] == "hybrid"

    def test_default_data_source_is_hybrid(self):
        """Without --data-source or --pipeline-flow, run.data_source defaults to hybrid."""
        args = self._make_args()
        overrides = build_cli_overrides(args)
        assert overrides["run"]["data_source"] == "hybrid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. --pipeline + --tasks 同时使用
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPipelineAndTasks:
    """Behavior when --pipeline and --tasks are used together."""

    def test_pipeline_and_tasks_both_parsed(self):
        """--pipeline (deprecated alias) and --tasks can be parsed together."""
        with patch("sys.argv", [
            "cli",
            "--pipeline", "api_full",
            "--tasks", "ODS_MEMBER,ODS_ORDER",
        ]):
            args = parse_args()
            assert args.pipeline_deprecated == "api_full"
            assert args.tasks == "ODS_MEMBER,ODS_ORDER"
|
||||
426
apps/etl/connectors/feiqiu/tests/unit/test_compare_ddl.py
Normal file
426
apps/etl/connectors/feiqiu/tests/unit/test_compare_ddl.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""DDL 解析器和对比逻辑的单元测试。
|
||||
|
||||
测试范围:
|
||||
- DDL 解析器正确提取表名、字段、类型、约束
|
||||
- 类型标准化逻辑
|
||||
- 差异检测逻辑识别各类差异
|
||||
- 边界情况:空文件、COMMENT 含特殊字符
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.compare_ddl_db import (
|
||||
ColumnDef,
|
||||
DiffKind,
|
||||
SchemaDiff,
|
||||
TableDef,
|
||||
compare_tables,
|
||||
normalize_type,
|
||||
parse_ddl,
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# normalize_type 测试
|
||||
# =========================================================================
|
||||
|
||||
class TestNormalizeType:
    """Type-normalization tests."""

    @pytest.mark.parametrize("raw,expected", [
        ("BIGINT", "bigint"),
        ("INT8", "bigint"),
        ("INTEGER", "integer"),
        ("INT", "integer"),
        ("INT4", "integer"),
        ("SMALLINT", "smallint"),
        ("INT2", "smallint"),
        ("BOOLEAN", "boolean"),
        ("BOOL", "boolean"),
        ("TEXT", "text"),
        ("JSONB", "jsonb"),
        ("JSON", "json"),
        ("DATE", "date"),
        ("BYTEA", "bytea"),
        ("UUID", "uuid"),
    ])
    def test_simple_types(self, raw, expected):
        # Aliases (INT8, BOOL, ...) normalize to their canonical names.
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("NUMERIC(18,2)", "numeric(18,2)"),
        ("NUMERIC(10,6)", "numeric(10,6)"),
        ("DECIMAL(5,2)", "numeric(5,2)"),
        ("NUMERIC(10)", "numeric(10)"),
        ("NUMERIC", "numeric"),
    ])
    def test_numeric_types(self, raw, expected):
        # Precision/scale are kept; DECIMAL maps to numeric.
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("VARCHAR(50)", "varchar(50)"),
        ("CHARACTER VARYING(100)", "varchar(100)"),
        ("VARCHAR", "varchar"),
        ("CHAR(1)", "char(1)"),
        ("CHARACTER(10)", "char(10)"),
    ])
    def test_string_types(self, raw, expected):
        # Long spellings (CHARACTER VARYING) collapse to short forms.
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP WITHOUT TIME ZONE", "timestamp"),
        ("TIMESTAMPTZ", "timestamptz"),
        ("TIMESTAMP WITH TIME ZONE", "timestamptz"),
    ])
    def test_timestamp_types(self, raw, expected):
        # With/without time zone variants normalize consistently.
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("BIGSERIAL", "bigint"),
        ("SERIAL", "integer"),
        ("SMALLSERIAL", "smallint"),
    ])
    def test_serial_types(self, raw, expected):
        """The serial family should map to its underlying integer type."""
        assert normalize_type(raw) == expected

    def test_case_insensitive(self):
        # Normalization must be case-insensitive.
        assert normalize_type("bigint") == normalize_type("BIGINT")
        assert normalize_type("Numeric(18,2)") == normalize_type("NUMERIC(18,2)")
|
||||
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# parse_ddl 测试
|
||||
# =========================================================================
|
||||
|
||||
class TestParseDdl:
    """DDL parser tests."""

    def test_basic_create_table(self):
        """Basic CREATE TABLE parsing."""
        sql = """
        CREATE TABLE IF NOT EXISTS myschema.users (
            id BIGINT NOT NULL,
            name TEXT,
            age INTEGER,
            PRIMARY KEY (id)
        );
        """
        tables = parse_ddl(sql, target_schema="myschema")
        assert "users" in tables
        t = tables["users"]
        assert len(t.columns) == 3
        assert t.pk_columns == ["id"]
        assert t.columns["id"].data_type == "bigint"
        assert t.columns["id"].nullable is False
        assert t.columns["name"].data_type == "text"
        assert t.columns["name"].nullable is True
        assert t.columns["age"].data_type == "integer"

    def test_inline_primary_key(self):
        """Inline PRIMARY KEY constraint."""
        sql = """
        CREATE TABLE test_schema.items (
            item_id BIGSERIAL PRIMARY KEY,
            label TEXT NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="test_schema")
        t = tables["items"]
        assert t.columns["item_id"].is_pk is True
        # BIGSERIAL normalizes to bigint
        assert t.columns["item_id"].data_type == "bigint"
        assert t.columns["item_id"].nullable is False
        assert t.columns["label"].nullable is False

    def test_composite_primary_key(self):
        """Composite primary key."""
        sql = """
        CREATE TABLE IF NOT EXISTS billiards_ods.member_profiles (
            id BIGINT,
            content_hash TEXT NOT NULL,
            name TEXT,
            PRIMARY KEY (id, content_hash)
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        t = tables["member_profiles"]
        assert t.pk_columns == ["id", "content_hash"]
        assert t.columns["id"].is_pk is True
        assert t.columns["content_hash"].is_pk is True
        # PK implies NOT NULL
        assert t.columns["id"].nullable is False

    def test_various_data_types(self):
        """A spread of PostgreSQL data types."""
        sql = """
        CREATE TABLE s.t (
            a BIGINT,
            b VARCHAR(50),
            c NUMERIC(18,2),
            d TIMESTAMP,
            e TIMESTAMPTZ DEFAULT now(),
            f BOOLEAN DEFAULT TRUE,
            g JSONB NOT NULL,
            h TEXT,
            i INTEGER
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["a"].data_type == "bigint"
        assert t.columns["b"].data_type == "varchar(50)"
        assert t.columns["c"].data_type == "numeric(18,2)"
        assert t.columns["d"].data_type == "timestamp"
        assert t.columns["e"].data_type == "timestamptz"
        assert t.columns["f"].data_type == "boolean"
        assert t.columns["g"].data_type == "jsonb"
        assert t.columns["g"].nullable is False
        assert t.columns["h"].data_type == "text"
        assert t.columns["i"].data_type == "integer"

    def test_without_schema_prefix(self):
        """CREATE TABLE without a schema prefix (e.g. after SET search_path in DWD DDL)."""
        sql = """
        SET search_path TO billiards_dwd;
        CREATE TABLE IF NOT EXISTS dim_site (
            site_id BIGINT,
            shop_name TEXT,
            PRIMARY KEY (site_id)
        );
        """
        # With target_schema given, unprefixed tables should also be accepted.
        tables = parse_ddl(sql, target_schema="billiards_dwd")
        assert "dim_site" in tables

    def test_schema_filter(self):
        """Schema filtering: only tables in the target schema are kept."""
        sql = """
        CREATE TABLE schema_a.t1 (id BIGINT);
        CREATE TABLE schema_b.t2 (id BIGINT);
        """
        tables = parse_ddl(sql, target_schema="schema_a")
        assert "t1" in tables
        assert "t2" not in tables

    def test_empty_ddl(self):
        """An empty DDL file should yield an empty dict."""
        tables = parse_ddl("", target_schema="any")
        assert tables == {}

    def test_comments_ignored(self):
        """SQL comments must not affect parsing."""
        sql = """
        -- 这是注释
        /* 块注释 */
        CREATE TABLE s.t (
            id BIGINT, -- 行内注释
            name TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        assert "t" in tables
        assert len(tables["t"].columns) == 2

    def test_comment_on_statements_ignored(self):
        """COMMENT ON statements must not affect table parsing."""
        sql = """
        CREATE TABLE billiards_ods.test_table (
            id BIGINT NOT NULL,
            name TEXT,
            PRIMARY KEY (id)
        );
        COMMENT ON TABLE billiards_ods.test_table IS '测试表:含特殊字符 ''引号'' 和 (括号)';
        COMMENT ON COLUMN billiards_ods.test_table.id IS '【说明】主键 ID。【示例】12345。';
        COMMENT ON COLUMN billiards_ods.test_table.name IS '【说明】名称,含 ''单引号'' 和 "双引号"。';
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        assert "test_table" in tables
        assert len(tables["test_table"].columns) == 2

    def test_drop_then_create(self):
        """DROP TABLE followed by CREATE TABLE parses normally."""
        sql = """
        DROP TABLE IF EXISTS billiards_dws.cfg_test CASCADE;
        CREATE TABLE billiards_dws.cfg_test (
            id SERIAL PRIMARY KEY,
            value TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_dws")
        assert "cfg_test" in tables
        assert tables["cfg_test"].columns["id"].data_type == "integer"

    def test_default_values_parsed(self):
        """DEFAULT values must not disturb type/constraint parsing."""
        sql = """
        CREATE TABLE s.t (
            enabled BOOLEAN DEFAULT TRUE,
            created_at TIMESTAMPTZ DEFAULT now(),
            count INTEGER DEFAULT 0 NOT NULL,
            label VARCHAR(20) NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["enabled"].data_type == "boolean"
        assert t.columns["enabled"].nullable is True
        assert t.columns["created_at"].data_type == "timestamptz"
        assert t.columns["count"].data_type == "integer"
        assert t.columns["count"].nullable is False
        assert t.columns["label"].data_type == "varchar(20)"
        assert t.columns["label"].nullable is False

    def test_constraint_lines_skipped(self):
        """Table-level constraint lines (CONSTRAINT, UNIQUE, FOREIGN KEY) are skipped."""
        sql = """
        CREATE TABLE etl_admin.etl_task (
            task_id BIGSERIAL PRIMARY KEY,
            task_code TEXT NOT NULL,
            store_id BIGINT NOT NULL,
            UNIQUE (task_code, store_id)
        );
        """
        tables = parse_ddl(sql, target_schema="etl_admin")
        t = tables["etl_task"]
        assert len(t.columns) == 3
        assert "task_id" in t.columns
        assert "task_code" in t.columns
        assert "store_id" in t.columns

    def test_real_ods_ddl_parseable(self):
        """Sanity-check that the real ODS DDL file is parseable (skipped if absent)."""
        from pathlib import Path
        ddl_path = Path("database/schema_ODS_doc.sql")
        if not ddl_path.exists():
            pytest.skip("DDL 文件不存在")
        sql = ddl_path.read_text(encoding="utf-8")
        tables = parse_ddl(sql, target_schema="billiards_ods")
        # There should be at least 20+ tables.
        assert len(tables) >= 20
        # Every table should have columns.
        for tbl in tables.values():
            assert len(tbl.columns) > 0
|
||||
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# compare_tables 测试
|
||||
# =========================================================================
|
||||
|
||||
class TestCompareTables:
|
||||
"""差异检测逻辑测试。"""
|
||||
|
||||
def _make_table(self, name: str, columns: dict[str, tuple[str, bool]],
|
||||
pk: list[str] | None = None) -> TableDef:
|
||||
"""辅助方法:快速构建 TableDef。
|
||||
|
||||
columns: {col_name: (data_type, nullable)}
|
||||
"""
|
||||
cols = {}
|
||||
for col_name, (dtype, nullable) in columns.items():
|
||||
cols[col_name] = ColumnDef(
|
||||
name=col_name,
|
||||
data_type=dtype,
|
||||
nullable=nullable,
|
||||
is_pk=col_name in (pk or []),
|
||||
)
|
||||
return TableDef(name=name, columns=cols, pk_columns=pk or [])
|
||||
|
||||
def test_no_diff(self):
|
||||
"""完全一致时应返回空列表。"""
|
||||
t = self._make_table("t", {"id": ("bigint", False), "name": ("text", True)}, pk=["id"])
|
||||
diffs = compare_tables({"t": t}, {"t": t})
|
||||
assert diffs == []
|
||||
|
||||
def test_missing_table(self):
|
||||
"""数据库有但 DDL 没有 → MISSING_TABLE。"""
|
||||
ddl = {}
|
||||
db = {"extra": self._make_table("extra", {"id": ("bigint", False)})}
|
||||
diffs = compare_tables(ddl, db)
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].kind == DiffKind.MISSING_TABLE
|
||||
assert diffs[0].table == "extra"
|
||||
|
||||
def test_extra_table(self):
|
||||
"""DDL 有但数据库没有 → EXTRA_TABLE。"""
|
||||
ddl = {"orphan": self._make_table("orphan", {"id": ("bigint", False)})}
|
||||
db = {}
|
||||
diffs = compare_tables(ddl, db)
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].kind == DiffKind.EXTRA_TABLE
|
||||
assert diffs[0].table == "orphan"
|
||||
|
||||
def test_missing_column(self):
|
||||
"""数据库有但 DDL 没有的字段 → MISSING_COLUMN。"""
|
||||
ddl = {"t": self._make_table("t", {"id": ("bigint", False)})}
|
||||
db = {"t": self._make_table("t", {
|
||||
"id": ("bigint", False),
|
||||
"new_col": ("text", True),
|
||||
})}
|
||||
diffs = compare_tables(ddl, db)
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].kind == DiffKind.MISSING_COLUMN
|
||||
assert diffs[0].column == "new_col"
|
||||
|
||||
def test_extra_column(self):
|
||||
"""DDL 有但数据库没有的字段 → EXTRA_COLUMN。"""
|
||||
ddl = {"t": self._make_table("t", {
|
||||
"id": ("bigint", False),
|
||||
"old_col": ("text", True),
|
||||
})}
|
||||
db = {"t": self._make_table("t", {"id": ("bigint", False)})}
|
||||
diffs = compare_tables(ddl, db)
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].kind == DiffKind.EXTRA_COLUMN
|
||||
assert diffs[0].column == "old_col"
|
||||
|
||||
def test_type_mismatch(self):
|
||||
"""字段类型不一致 → TYPE_MISMATCH。"""
|
||||
ddl = {"t": self._make_table("t", {"val": ("text", True)})}
|
||||
db = {"t": self._make_table("t", {"val": ("varchar(100)", True)})}
|
||||
diffs = compare_tables(ddl, db)
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].kind == DiffKind.TYPE_MISMATCH
|
||||
assert diffs[0].ddl_value == "text"
|
||||
assert diffs[0].db_value == "varchar(100)"
|
||||
|
||||
def test_nullable_mismatch(self):
|
||||
"""可空约束不一致 → NULLABLE_MISMATCH。"""
|
||||
ddl = {"t": self._make_table("t", {"val": ("text", True)})}
|
||||
db = {"t": self._make_table("t", {"val": ("text", False)})}
|
||||
diffs = compare_tables(ddl, db)
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].kind == DiffKind.NULLABLE_MISMATCH
|
||||
assert diffs[0].ddl_value == "NULL"
|
||||
assert diffs[0].db_value == "NOT NULL"
|
||||
|
||||
def test_multiple_diffs(self):
    """Several kinds of differences can be reported from one comparison."""
    ddl_side = {
        "t1": self._make_table("t1", {
            "id": ("bigint", False),
            "extra": ("text", True),
        }),
        "ddl_only": self._make_table("ddl_only", {"x": ("integer", True)}),
    }
    db_side = {
        "t1": self._make_table("t1", {
            "id": ("integer", False),   # type differs -> TYPE_MISMATCH
            "missing": ("text", True),  # db-only column -> MISSING_COLUMN
        }),
        "db_only": self._make_table("db_only", {"y": ("text", True)}),
    }

    reported_kinds = {d.kind for d in compare_tables(ddl_side, db_side)}

    expected_kinds = {
        DiffKind.MISSING_TABLE,   # db_only
        DiffKind.EXTRA_TABLE,     # ddl_only
        DiffKind.MISSING_COLUMN,  # t1.missing
        DiffKind.EXTRA_COLUMN,    # t1.extra
        DiffKind.TYPE_MISMATCH,   # t1.id
    }
    assert expected_kinds <= reported_kinds
|
||||
|
||||
def test_empty_both(self):
    """Comparing two empty schemas produces no diffs."""
    result = compare_tables({}, {})
    assert result == []
|
||||
--- New file: apps/etl/connectors/feiqiu/tests/unit/test_compare_ddl_pbt.py (545 lines) ---
|
||||
"""DDL 对比脚本差异检测完整性 — 属性测试。
|
||||
|
||||
# AI_CHANGELOG [2026-02-13] max_examples 从 100 降至 30 以加速测试运行
|
||||
|
||||
**Property 2: DDL 对比脚本差异检测完整性**
|
||||
**Validates: Requirements 2.1, 2.2, 2.3, 2.4**
|
||||
|
||||
使用 hypothesis 生成随机的 DDL 表定义和数据库表定义,
|
||||
注入已知差异,验证 compare_tables 能检测到所有差异。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given, settings, assume
|
||||
|
||||
from scripts.compare_ddl_db import (
|
||||
ColumnDef,
|
||||
DiffKind,
|
||||
TableDef,
|
||||
compare_tables,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 自定义 Strategy:生成随机的 ColumnDef / TableDef
|
||||
# =========================================================================
|
||||
|
||||
# Pool of normalized type names (consistent with the output of normalize_type)
_NORMALIZED_TYPES = [
    "bigint", "integer", "smallint", "boolean", "text",
    "jsonb", "json", "date", "bytea", "uuid",
    "timestamp", "timestamptz", "double precision",
    "varchar(50)", "varchar(100)", "varchar(255)",
    "char(1)", "char(10)",
    "numeric(18,2)", "numeric(10,6)", "numeric(5,0)",
]

# Legal identifier character sets (lowercase letters + underscore + digits;
# the first character must be a letter)
_IDENT_ALPHABET = string.ascii_lowercase + "_"
_IDENT_FULL = _IDENT_ALPHABET + string.digits
|
||||
|
||||
|
||||
def st_identifier() -> st.SearchStrategy[str]:
    """Generate a legal PostgreSQL identifier (lowercase, 3–20 characters)."""
    return st.from_regex(r"[a-z][a-z0-9_]{2,19}", fullmatch=True)
|
||||
|
||||
|
||||
def st_data_type() -> st.SearchStrategy[str]:
    """Pick one type at random from the normalized type pool."""
    return st.sampled_from(_NORMALIZED_TYPES)
|
||||
|
||||
|
||||
def st_column_def(name: str | None = None) -> st.SearchStrategy[ColumnDef]:
    """Generate a random ColumnDef, optionally with a fixed column name."""
    return st.builds(
        ColumnDef,
        name=st.just(name) if name else st_identifier(),
        data_type=st_data_type(),
        nullable=st.booleans(),
        is_pk=st.just(False),  # PK flags are assigned at the TableDef level
        default=st.just(None),
    )
|
||||
|
||||
|
||||
@st.composite
def st_table_def(draw, name: str | None = None) -> TableDef:
    """Generate a random TableDef (1–8 columns, optional primary key)."""
    tbl_name = name or draw(st_identifier())
    num_cols = draw(st.integers(min_value=1, max_value=8))

    # Draw a set of distinct column names
    col_names = draw(
        st.lists(st_identifier(), min_size=num_cols, max_size=num_cols, unique=True)
    )

    columns: dict[str, ColumnDef] = {}
    for cn in col_names:
        col = draw(st_column_def(name=cn))
        columns[cn] = col

    # Randomly pick 0–2 columns to act as the primary key
    pk_count = draw(st.integers(min_value=0, max_value=min(2, len(col_names))))
    pk_cols = draw(
        st.lists(
            st.sampled_from(col_names),
            min_size=pk_count, max_size=pk_count, unique=True,
        )
    ) if pk_count > 0 and col_names else []

    # PK columns are forced to NOT NULL, mirroring real DDL semantics
    for pk_col in pk_cols:
        columns[pk_col].is_pk = True
        columns[pk_col].nullable = False

    return TableDef(name=tbl_name, columns=columns, pk_columns=pk_cols)
|
||||
|
||||
|
||||
@st.composite
def st_table_dict(draw, min_tables: int = 0, max_tables: int = 5) -> dict[str, TableDef]:
    """Generate a {table_name: TableDef} mapping with unique table names."""
    num = draw(st.integers(min_value=min_tables, max_value=max_tables))
    names = draw(st.lists(st_identifier(), min_size=num, max_size=num, unique=True))
    tables: dict[str, TableDef] = {}
    for n in names:
        tables[n] = draw(st_table_def(name=n))
    return tables
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Property 2: DDL 对比脚本差异检测完整性
|
||||
# **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestProperty2DiffDetection:
    """Property 2: the comparison script detects every injected difference.

    **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
    """

    @given(base=st_table_dict(min_tables=1, max_tables=4))
    @settings(max_examples=30)
    def test_identical_tables_produce_no_diffs(self, base: dict[str, TableDef]):
        """Completely identical definitions must not produce any diff.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        diffs = compare_tables(base, base)
        assert diffs == [], f"相同定义不应有差异,但得到: {diffs}"

    @given(
        common=st_table_dict(min_tables=0, max_tables=3),
        db_only=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_missing_table_detected(
        self,
        common: dict[str, TableDef],
        db_only: dict[str, TableDef],
    ):
        """Tables present only in the database must be reported as MISSING_TABLE.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        # Ensure db_only table names do not overlap with common
        overlap = set(common.keys()) & set(db_only.keys())
        assume(not overlap)

        ddl_tables = dict(common)
        db_tables = {**common, **db_only}

        diffs = compare_tables(ddl_tables, db_tables)
        missing_tables = {d.table for d in diffs if d.kind == DiffKind.MISSING_TABLE}

        for tbl in db_only:
            assert tbl in missing_tables, (
                f"表 '{tbl}' 仅存在于数据库中,应报告为 MISSING_TABLE"
            )

    @given(
        common=st_table_dict(min_tables=0, max_tables=3),
        ddl_only=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_extra_table_detected(
        self,
        common: dict[str, TableDef],
        ddl_only: dict[str, TableDef],
    ):
        """Tables present only in the DDL must be reported as EXTRA_TABLE.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        overlap = set(common.keys()) & set(ddl_only.keys())
        assume(not overlap)

        ddl_tables = {**common, **ddl_only}
        db_tables = dict(common)

        diffs = compare_tables(ddl_tables, db_tables)
        extra_tables = {d.table for d in diffs if d.kind == DiffKind.EXTRA_TABLE}

        for tbl in ddl_only:
            assert tbl in extra_tables, (
                f"表 '{tbl}' 仅存在于 DDL 中,应报告为 EXTRA_TABLE"
            )

    @given(
        table=st_table_def(),
        extra_cols=st.lists(
            st.tuples(st_identifier(), st_data_type(), st.booleans()),
            min_size=1, max_size=4, unique_by=lambda x: x[0],
        ),
    )
    @settings(max_examples=30)
    def test_missing_column_detected(
        self,
        table: TableDef,
        extra_cols: list[tuple[str, str, bool]],
    ):
        """Columns present only in the database must be reported as MISSING_COLUMN.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        # Ensure injected column names do not clash with existing ones
        existing_names = set(table.columns.keys())
        injected_names = {c[0] for c in extra_cols}
        assume(not (existing_names & injected_names))

        # DDL side: the original table
        ddl_tables = {table.name: table}

        # DB side: original table + injected extra columns
        db_table = TableDef(
            name=table.name,
            columns=dict(table.columns),
            pk_columns=list(table.pk_columns),
        )
        for col_name, col_type, nullable in extra_cols:
            db_table.columns[col_name] = ColumnDef(
                name=col_name, data_type=col_type, nullable=nullable,
            )
        db_tables = {table.name: db_table}

        diffs = compare_tables(ddl_tables, db_tables)
        missing_cols = {
            d.column for d in diffs
            if d.kind == DiffKind.MISSING_COLUMN and d.table == table.name
        }

        for col_name, _, _ in extra_cols:
            assert col_name in missing_cols, (
                f"字段 '{table.name}.{col_name}' 仅存在于数据库中,"
                f"应报告为 MISSING_COLUMN"
            )

    @given(
        table=st_table_def(),
        extra_cols=st.lists(
            st.tuples(st_identifier(), st_data_type(), st.booleans()),
            min_size=1, max_size=4, unique_by=lambda x: x[0],
        ),
    )
    @settings(max_examples=30)
    def test_extra_column_detected(
        self,
        table: TableDef,
        extra_cols: list[tuple[str, str, bool]],
    ):
        """Columns present only in the DDL must be reported as EXTRA_COLUMN.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        existing_names = set(table.columns.keys())
        injected_names = {c[0] for c in extra_cols}
        assume(not (existing_names & injected_names))

        # DDL side: original table + injected extra columns
        ddl_table = TableDef(
            name=table.name,
            columns=dict(table.columns),
            pk_columns=list(table.pk_columns),
        )
        for col_name, col_type, nullable in extra_cols:
            ddl_table.columns[col_name] = ColumnDef(
                name=col_name, data_type=col_type, nullable=nullable,
            )
        ddl_tables = {table.name: ddl_table}

        # DB side: the original table
        db_tables = {table.name: table}

        diffs = compare_tables(ddl_tables, db_tables)
        extra_col_set = {
            d.column for d in diffs
            if d.kind == DiffKind.EXTRA_COLUMN and d.table == table.name
        }

        for col_name, _, _ in extra_cols:
            assert col_name in extra_col_set, (
                f"字段 '{table.name}.{col_name}' 仅存在于 DDL 中,"
                f"应报告为 EXTRA_COLUMN"
            )

    @given(table=st_table_def())
    @settings(max_examples=30)
    def test_type_mismatch_detected(self, table: TableDef):
        """A column type difference must be reported as TYPE_MISMATCH.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        assume(len(table.columns) >= 1)

        # Pick the first column and change its type
        target_col_name = next(iter(table.columns))
        original_type = table.columns[target_col_name].data_type

        # Choose a different type from the pool
        alt_types = [t for t in _NORMALIZED_TYPES if t != original_type]
        assume(len(alt_types) > 0)
        new_type = alt_types[0]

        # DDL side: original definition
        ddl_tables = {table.name: table}

        # DB side: same table with the target column's type changed
        db_table = TableDef(
            name=table.name,
            columns={
                cn: ColumnDef(
                    name=c.name,
                    data_type=new_type if cn == target_col_name else c.data_type,
                    nullable=c.nullable,
                    is_pk=c.is_pk,
                    default=c.default,
                )
                for cn, c in table.columns.items()
            },
            pk_columns=list(table.pk_columns),
        )
        db_tables = {table.name: db_table}

        diffs = compare_tables(ddl_tables, db_tables)
        type_mismatches = [
            d for d in diffs
            if d.kind == DiffKind.TYPE_MISMATCH
            and d.table == table.name
            and d.column == target_col_name
        ]

        assert len(type_mismatches) == 1, (
            f"字段 '{table.name}.{target_col_name}' 类型从 "
            f"'{original_type}' 变为 '{new_type}',应报告 TYPE_MISMATCH"
        )
        assert type_mismatches[0].ddl_value == original_type
        assert type_mismatches[0].db_value == new_type

    @given(
        ddl_tables=st_table_dict(min_tables=0, max_tables=4),
        db_tables=st_table_dict(min_tables=0, max_tables=4),
    )
    @settings(max_examples=30)
    def test_all_diff_kinds_are_accounted_for(
        self,
        ddl_tables: dict[str, TableDef],
        db_tables: dict[str, TableDef],
    ):
        """Every diff produced for arbitrary input must be a known DiffKind.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**

        This is a "meta-property": the comparison must never produce an
        undefined diff kind, and every diff must carry sane table/column
        references.
        """
        diffs = compare_tables(ddl_tables, db_tables)

        valid_kinds = set(DiffKind)
        all_table_names = set(ddl_tables.keys()) | set(db_tables.keys())

        for d in diffs:
            # The diff kind must be a known enum member
            assert d.kind in valid_kinds, f"未知差异类型: {d.kind}"
            # The referenced table must exist on at least one side
            assert d.table in all_table_names, (
                f"差异引用了不存在的表: {d.table}"
            )
            # Table-level diffs must not carry a column
            if d.kind in (DiffKind.MISSING_TABLE, DiffKind.EXTRA_TABLE):
                assert d.column is None, (
                    f"表级差异 {d.kind} 不应包含 column 字段"
                )
            # Column-level diffs must carry a column
            if d.kind in (
                DiffKind.MISSING_COLUMN, DiffKind.EXTRA_COLUMN,
                DiffKind.TYPE_MISMATCH, DiffKind.NULLABLE_MISMATCH,
            ):
                assert d.column is not None, (
                    f"字段级差异 {d.kind} 必须包含 column 字段"
                )
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Property 3: DDL 修正后零差异(不动点)
|
||||
# **Validates: Requirements 2.5**
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestProperty3FixpointZeroDiff:
    """Property 3: zero diffs after correcting the DDL (fixpoint).

    Core idea: after the DDL file has been corrected to match the actual
    database state, re-running the comparison must yield an empty diff list,
    i.e. compare_tables(db_state, db_state) == [].

    **Validates: Requirements 2.5**
    """

    @given(db_state=st_table_dict(min_tables=0, max_tables=5))
    @settings(max_examples=30)
    def test_identical_state_yields_no_diff(self, db_state: dict[str, TableDef]):
        """Feeding the same definitions as both DDL and DB must yield zero diffs.

        Simulates the scenario where DDL and database fully agree after a fix.

        **Validates: Requirements 2.5**
        """
        diffs = compare_tables(db_state, db_state)
        assert diffs == [], (
            f"不动点违反:相同定义应产生零差异,但得到 {len(diffs)} 项: "
            f"{[str(d) for d in diffs]}"
        )

    @given(db_state=st_table_dict(min_tables=1, max_tables=5))
    @settings(max_examples=30)
    def test_copy_db_as_ddl_yields_no_diff(self, db_state: dict[str, TableDef]):
        """Simulated fix: deep-copy the DB definitions as the DDL; zero diffs.

        Verifies the fixpoint property of "regenerate the DDL from the DB".

        **Validates: Requirements 2.5**
        """
        # Simulated fix: copy every table and column from the DB definitions
        ddl_fixed: dict[str, TableDef] = {}
        for tbl_name, tbl in db_state.items():
            new_columns: dict[str, ColumnDef] = {}
            for col_name, col in tbl.columns.items():
                new_columns[col_name] = ColumnDef(
                    name=col.name,
                    data_type=col.data_type,
                    nullable=col.nullable,
                    is_pk=col.is_pk,
                    default=col.default,
                )
            ddl_fixed[tbl_name] = TableDef(
                name=tbl.name,
                columns=new_columns,
                pk_columns=list(tbl.pk_columns),
            )

        diffs = compare_tables(ddl_fixed, db_state)
        assert diffs == [], (
            f"不动点违反:从 DB 复制的 DDL 定义应与 DB 零差异,"
            f"但得到 {len(diffs)} 项: {[str(d) for d in diffs]}"
        )

    @given(
        db_state=st_table_dict(min_tables=1, max_tables=4),
        extra_tables=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_fixpoint_after_adding_missing_tables(
        self,
        db_state: dict[str, TableDef],
        extra_tables: dict[str, TableDef],
    ):
        """Simulated fix flow: DDL lacks tables → add them → zero diffs.

        Steps:
        1. DB has common + extra tables; the DDL only has the common ones
        2. MISSING_TABLE diffs are detected
        3. The missing tables are added to the DDL (simulated fix)
        4. A second comparison yields zero diffs

        **Validates: Requirements 2.5**
        """
        # Ensure table names do not overlap
        overlap = set(db_state.keys()) & set(extra_tables.keys())
        assume(not overlap)

        full_db = {**db_state, **extra_tables}
        ddl_before_fix = dict(db_state)  # lacks extra_tables

        # First comparison: diffs expected
        diffs_before = compare_tables(ddl_before_fix, full_db)
        missing = {d.table for d in diffs_before if d.kind == DiffKind.MISSING_TABLE}
        assert missing == set(extra_tables.keys()), (
            f"修正前应检测到缺失表 {set(extra_tables.keys())},实际 {missing}"
        )

        # Simulated fix: add the missing tables to the DDL
        ddl_after_fix = {**ddl_before_fix, **extra_tables}

        # Second comparison: fixpoint, zero diffs expected
        diffs_after = compare_tables(ddl_after_fix, full_db)
        assert diffs_after == [], (
            f"不动点违反:修正后应零差异,但得到 {len(diffs_after)} 项: "
            f"{[str(d) for d in diffs_after]}"
        )

    @given(table=st_table_def())
    @settings(max_examples=30)
    def test_fixpoint_after_correcting_type_mismatch(self, table: TableDef):
        """Simulated fix flow: type mismatch → adopt the DB type → zero diffs.

        Steps:
        1. One DDL column's type differs from the DB
        2. TYPE_MISMATCH is detected
        3. The DDL column type is changed to the DB's type (simulated fix)
        4. A second comparison yields zero diffs

        **Validates: Requirements 2.5**
        """
        assume(len(table.columns) >= 1)

        target_col = next(iter(table.columns))
        original_type = table.columns[target_col].data_type
        alt_types = [t for t in _NORMALIZED_TYPES if t != original_type]
        assume(len(alt_types) > 0)
        wrong_type = alt_types[0]

        # DDL side: the target column uses the wrong type
        ddl_table = TableDef(
            name=table.name,
            columns={
                cn: ColumnDef(
                    name=c.name,
                    data_type=wrong_type if cn == target_col else c.data_type,
                    nullable=c.nullable,
                    is_pk=c.is_pk,
                    default=c.default,
                )
                for cn, c in table.columns.items()
            },
            pk_columns=list(table.pk_columns),
        )
        ddl_before = {table.name: ddl_table}
        db_tables = {table.name: table}

        # First comparison: TYPE_MISMATCH expected
        diffs_before = compare_tables(ddl_before, db_tables)
        type_diffs = [
            d for d in diffs_before
            if d.kind == DiffKind.TYPE_MISMATCH and d.column == target_col
        ]
        assert len(type_diffs) >= 1, "修正前应检测到 TYPE_MISMATCH"

        # Simulated fix: adopt the DB definitions directly as the DDL
        ddl_after = dict(db_tables)

        # Second comparison: fixpoint
        diffs_after = compare_tables(ddl_after, db_tables)
        assert diffs_after == [], (
            f"不动点违反:类型修正后应零差异,但得到 {len(diffs_after)} 项: "
            f"{[str(d) for d in diffs_after]}"
        )
|
||||
--- New file: apps/etl/connectors/feiqiu/tests/unit/test_config.py (24 lines) ---
|
||||
# -*- coding: utf-8 -*-
|
||||
"""配置管理测试"""
|
||||
import pytest
|
||||
from config.settings import AppConfig
|
||||
from config.defaults import DEFAULTS
|
||||
|
||||
def test_config_load():
    """Loading with minimal overrides keeps default values intact."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    expected_tz = DEFAULTS["app"]["timezone"]
    assert cfg.get("app.timezone") == expected_tz
|
||||
|
||||
def test_config_override():
    """An explicit override replaces the corresponding default value."""
    cfg = AppConfig.load({"app": {"store_id": 12345}})
    assert cfg.get("app.store_id") == 12345
|
||||
|
||||
def test_config_get_nested():
    """Dotted-path lookups resolve nested keys and fall back to the default."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    assert cfg.get("db.batch_size") == 1000
    assert cfg.get("nonexistent.key", "default") == "default"
|
||||
@@ -0,0 +1,55 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""配置映射属性测试 — 使用 hypothesis 验证配置键兼容映射的通用正确性属性。"""
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from config.settings import AppConfig, _FLOW_TO_DATA_SOURCE
|
||||
|
||||
|
||||
# ── 确保测试不读取 .env 文件 ──────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
def skip_dotenv(monkeypatch):
    """Prevent these tests from reading a real .env file."""
    monkeypatch.setenv("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
|
||||
# ── Generation strategies ─────────────────────────────────────

flow_st = st.sampled_from(["FULL", "FETCH_ONLY", "INGEST_ONLY"])
|
||||
|
||||
|
||||
# ── Property 11: pipeline_flow → data_source 映射一致性 ──────
|
||||
# Feature: scheduler-refactor, Property 11: pipeline_flow → data_source 映射一致性
|
||||
# **Validates: Requirements 8.1, 8.2, 8.3, 5.2, 8.4**
|
||||
#
|
||||
# 对于任意旧 pipeline_flow 值(FULL/FETCH_ONLY/INGEST_ONLY),
|
||||
# 映射到 data_source 的结果应与预定义映射表一致:
|
||||
# FULL→hybrid、FETCH_ONLY→online、INGEST_ONLY→offline。
|
||||
# 同样,配置键 pipeline.flow 应自动映射到 run.data_source。
|
||||
|
||||
|
||||
class TestProperty11FlowToDataSourceMapping:
    """Property 11: pipeline_flow → data_source mapping consistency."""

    @given(flow=flow_st)
    @settings(max_examples=100)
    def test_pipeline_flow_maps_to_data_source(self, flow):
        """Setting a legacy pipeline.flow value must map run.data_source per the table."""
        expected = _FLOW_TO_DATA_SOURCE[flow]

        # Loading a deprecated key raises DeprecationWarning; silence it here
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            config = AppConfig.load({
                "app": {"store_id": 1},
                "pipeline": {"flow": flow},
            })

        actual = config.get("run.data_source")
        assert actual == expected, (
            f"pipeline.flow={flow!r} 应映射为 run.data_source={expected!r},"
            f"实际为 {actual!r}"
        )
|
||||
@@ -0,0 +1,279 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""配置层属性测试 — 验证 AppConfig 的深度合并、store_id 验证、DSN 组装、点号路径 get。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
使用 hypothesis 对 AppConfig 的 4 个核心正确性属性进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume, HealthCheck
|
||||
from hypothesis import strategies as st
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from config.settings import AppConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 公共策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Safe dictionary keys: non-empty lowercase/underscore strings
# (avoids special characters that could interfere with dotted paths)
_safe_key = st.text(
    alphabet="abcdefghijklmnopqrstuvwxyz_",
    min_size=1,
    max_size=10,
)

# Leaf values: small integers, short strings, or booleans
_leaf_value = st.one_of(
    st.integers(min_value=-1000, max_value=1000),
    st.text(min_size=0, max_size=20),
    st.booleans(),
)

# Nested dictionaries (at most two levels deep)
_nested_dict = st.fixed_dictionaries({}).flatmap(
    lambda _: st.dictionaries(
        keys=_safe_key,
        values=st.one_of(
            _leaf_value,
            st.dictionaries(keys=_safe_key, values=_leaf_value, min_size=0, max_size=3),
        ),
        min_size=1,
        max_size=5,
    )
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 13: AppConfig 优先级合并
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(
    defaults=_nested_dict,
    overrides=_nested_dict,
)
def test_property13_appconfig_deep_merge_priority(defaults: dict, overrides: dict):
    """Property 13: AppConfig priority merge.

    For arbitrary nested DEFAULTS and CLI override dicts, after _deep_merge
    the keys present in the CLI dict must override same-named keys in
    DEFAULTS, while keys that were not overridden keep their original values.

    **Validates: Requirements 7.1**
    """
    dst = deepcopy(defaults)
    src = deepcopy(overrides)
    original_defaults = deepcopy(defaults)

    AppConfig._deep_merge(dst, src)

    # Check 1: every key/value from overrides must appear in the merged result
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(original_defaults.get(k), dict):
            # Nested dict: sub-keys coming from overrides must win
            for sk, sv in v.items():
                assert dst[k][sk] == sv, (
                    f"嵌套键 {k}.{sk} 应被覆盖为 {sv!r},实际为 {dst[k][sk]!r}"
                )
        else:
            assert dst[k] == v, (
                f"键 {k} 应被覆盖为 {v!r},实际为 {dst[k]!r}"
            )

    # Check 2: keys of defaults not touched by overrides keep their values
    for k, v in original_defaults.items():
        if k not in src:
            assert dst[k] == v, (
                f"未被覆盖的键 {k} 应保持原值 {v!r},实际为 {dst[k]!r}"
            )
        elif isinstance(v, dict) and isinstance(src.get(k), dict):
            # Sub-keys of a nested dict absent from overrides keep their values
            for sk, sv in v.items():
                if sk not in src[k]:
                    assert dst[k][sk] == sv, (
                        f"未被覆盖的嵌套键 {k}.{sk} 应保持原值 {sv!r},"
                        f"实际为 {dst[k][sk]!r}"
                    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 14: AppConfig store_id 验证
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strings that int() cannot parse (the alphabet contains no digits;
# empty-after-strip strings are filtered out)
_non_integer_text = st.text(
    alphabet="abcdefghijklmnopqrstuvwxyz!@#$%^&*()_+-=[]{}|;:,.<>?/~` ",
    min_size=1,
    max_size=20,
).filter(lambda s: s.strip() != "")
|
||||
|
||||
|
||||
def _try_int(s: str) -> bool:
|
||||
"""检查字符串是否能被 int() 解析"""
|
||||
try:
|
||||
int(str(s).strip())
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(bad_store_id=_non_integer_text)
def test_property14_appconfig_store_id_validation(bad_store_id: str):
    """Property 14: AppConfig store_id validation.

    For any non-integer string used as app.store_id, AppConfig._normalize
    must raise SystemExit.

    **Validates: Requirements 7.2**
    """
    assume(not _try_int(bad_store_id))

    from config.defaults import DEFAULTS

    cfg = deepcopy(DEFAULTS)
    cfg["app"]["store_id"] = bad_store_id

    with pytest.raises(SystemExit, match="app.store_id 必须为整数"):
        AppConfig._normalize(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 15: AppConfig DSN 组装
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# DSN components: non-empty strings without :/@ characters,
# to avoid ambiguity when the URL is parsed
_dsn_safe_text = st.text(
    alphabet="abcdefghijklmnopqrstuvwxyz0123456789_-.",
    min_size=1,
    max_size=15,
)

# Port numbers rendered as strings ("1".."65535")
_port_strategy = st.integers(min_value=1, max_value=65535).map(str)
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(
    host=_dsn_safe_text,
    port=_port_strategy,
    name=_dsn_safe_text,
    user=_dsn_safe_text,
    password=_dsn_safe_text,
)
def test_property15_appconfig_dsn_assembly(
    host: str, port: str, name: str, user: str, password: str
):
    """Property 15: AppConfig DSN assembly.

    For any host/port/name/user/password combination (with db.dsn left empty),
    AppConfig._normalize must assemble a DSN string of the form
    postgresql://{user}:{password}@{host}:{port}/{name}.

    **Validates: Requirements 7.3**
    """
    from config.defaults import DEFAULTS

    cfg = deepcopy(DEFAULTS)
    # Use a valid store_id so the other _normalize checks pass
    cfg["app"]["store_id"] = "1"
    cfg["db"]["dsn"] = ""
    cfg["db"]["host"] = host
    cfg["db"]["port"] = port
    cfg["db"]["name"] = name
    cfg["db"]["user"] = user
    cfg["db"]["password"] = password

    AppConfig._normalize(cfg)

    expected_dsn = f"postgresql://{user}:{password}@{host}:{port}/{name}"
    assert cfg["db"]["dsn"] == expected_dsn, (
        f"DSN 应为 {expected_dsn!r},实际为 {cfg['db']['dsn']!r}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 16: AppConfig 点号路径 get
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(
    nested=st.fixed_dictionaries({
        "a": st.fixed_dictionaries({
            "b": st.fixed_dictionaries({
                "c": _leaf_value,
            }),
            "d": _leaf_value,
        }),
        "e": _leaf_value,
    }),
)
def test_property16_appconfig_dot_path_get(nested: dict):
    """Property 16: AppConfig dotted-path get.

    For any nested dict and a valid dotted path, config.get(path) must return
    the value at that path; for an invalid path it must return the supplied
    default.

    **Validates: Requirements 7.4**
    """
    config = AppConfig(deepcopy(nested))

    # Valid paths return the corresponding value
    assert config.get("a.b.c") == nested["a"]["b"]["c"]
    assert config.get("a.d") == nested["a"]["d"]
    assert config.get("e") == nested["e"]

    # Invalid paths return the default
    sentinel = object()
    assert config.get("x.y.z", sentinel) is sentinel
    assert config.get("a.b.nonexistent", sentinel) is sentinel
    assert config.get("a.b.c.too_deep", sentinel) is sentinel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 16 补充:对任意生成的嵌套字典和路径进行测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_paths(d: dict, prefix: str = "") -> list[tuple[str, object]]:
|
||||
"""递归收集字典中所有叶子节点的点号路径和值"""
|
||||
paths = []
|
||||
for k, v in d.items():
|
||||
path = f"{prefix}.{k}" if prefix else k
|
||||
if isinstance(v, dict):
|
||||
paths.extend(_collect_paths(v, path))
|
||||
else:
|
||||
paths.append((path, v))
|
||||
return paths
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(nested=_nested_dict)
def test_property16_appconfig_dot_path_get_arbitrary(nested: dict):
    """Property 16 (supplement): verify get() against arbitrary nested dicts.

    **Validates: Requirements 7.4**
    """
    config = AppConfig(deepcopy(nested))

    # Every leaf path must return the correct value
    for path, expected in _collect_paths(nested):
        actual = config.get(path)
        # get() returns the default (None) when the stored value is None,
        # so a None leaf is indistinguishable from "missing" and is skipped
        if expected is not None:
            assert actual == expected, (
                f"路径 {path!r} 应返回 {expected!r},实际为 {actual!r}"
            )

    # A nonexistent path must return the default
    sentinel = object()
    assert config.get("__nonexistent_key__", sentinel) is sentinel
|
||||
@@ -0,0 +1,275 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWD/DWS 层属性测试 — 验证 DwdLoadTask 和 BaseTask 的核心正确性属性。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
使用 hypothesis 对 DWD 列映射完整性、only_tables 过滤和 DWS 分段累加进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume, HealthCheck
|
||||
from hypothesis import strategies as st
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
from tasks.base_task import BaseTask
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 6 辅助:判断 ods_expr 是否为"简单列名"
|
||||
# 排除 JSON 路径表达式、CASE WHEN、函数调用、CAST 等复杂表达式
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _is_simple_column_name(expr: str) -> bool:
|
||||
"""判断 ods_expr 是否为简单列名(非表达式)。
|
||||
|
||||
排除条件:包含 ->、CASE、空格、括号、::、单引号等。
|
||||
"""
|
||||
if not expr:
|
||||
return False
|
||||
if any(ch in expr for ch in ("->", "'", "(", ")", "::", " ")):
|
||||
return False
|
||||
# 纯标识符:字母/数字/下划线
|
||||
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", expr))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 6 辅助:收集所有 ODS 源表名(从 TABLE_MAP 的值去重)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_ods_table_names() -> set[str]:
    """Return the bare table_name (schema stripped) of every ODS source table."""
    # "ods.payment_transactions" -> "payment_transactions"
    return {full_name.split(".")[-1] for full_name in DwdLoadTask.TABLE_MAP.values()}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 6: DWD FACT_MAPPINGS 列映射完整性
|
||||
# Feature: etl-pipeline-debug, Property 6: DWD FACT_MAPPINGS 列映射完整性
|
||||
# Validates: Requirements 2.4
|
||||
#
|
||||
# 验证策略:静态检查 FACT_MAPPINGS 中每个映射条目,当 ods_expr 是简单列名时,
|
||||
# 该列应存在于对应 ODS 源表的真实列定义中。由于无法在纯单元测试中查询
|
||||
# information_schema,我们通过 hypothesis 从 FACT_MAPPINGS 中随机抽取条目,
|
||||
# 验证简单列名的格式合法性和映射关系的内部一致性。
|
||||
# ===========================================================================
|
||||
|
||||
# Flatten FACT_MAPPINGS into (dwd_table, dwd_col, ods_expr, cast_type) tuples.
_ALL_MAPPING_ENTRIES: list[tuple[str, str, str, str | None]] = [
    (dwd_tbl, dwd_col, ods_expr, cast_type)
    for dwd_tbl, entries in DwdLoadTask.FACT_MAPPINGS.items()
    for dwd_col, ods_expr, cast_type in entries
]

# Entries whose ods_expr is a plain column name (directly checkable against ODS tables).
_SIMPLE_COL_ENTRIES: list[tuple[str, str, str, str | None]] = [
    entry for entry in _ALL_MAPPING_ENTRIES if _is_simple_column_name(entry[2])
]
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(idx=st.integers(min_value=0, max_value=max(len(_ALL_MAPPING_ENTRIES) - 1, 0)))
def test_property6_fact_mappings_column_integrity(idx):
    """For any FACT_MAPPINGS entry, the mapping must be internally consistent.

    **Validates: Requirements 2.4**

    Properties:
      1. every dwd_table appears in TABLE_MAP (has an ODS source table)
      2. dwd_col is a legal identifier
      3. a simple-column ods_expr is a legal identifier
      4. cast_type is either None or a non-empty string
    """
    assume(len(_ALL_MAPPING_ENTRIES) > 0)
    dwd_table, dwd_col, ods_expr, cast_type = (
        _ALL_MAPPING_ENTRIES[idx % len(_ALL_MAPPING_ENTRIES)]
    )

    # 1. The mapping target must have a registered ODS source table.
    assert dwd_table in DwdLoadTask.TABLE_MAP, (
        f"FACT_MAPPINGS 中的 {dwd_table} 不在 TABLE_MAP 中"
    )

    identifier = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")

    # 2. The target column name must be a legal identifier.
    assert identifier.match(dwd_col), f"dwd_col '{dwd_col}' 不是合法标识符"

    # 3. When the source expression is a bare column, it must be well-formed.
    if _is_simple_column_name(ods_expr):
        assert identifier.match(ods_expr), f"简单列名 '{ods_expr}' 格式不合法"

    # 4. cast_type, when present, must be a non-empty string.
    if cast_type is not None:
        assert isinstance(cast_type, str) and cast_type, (
            f"cast_type 应为 None 或非空字符串,实际: {cast_type!r}"
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 7: DWD only_tables 过滤
|
||||
# Feature: etl-pipeline-debug, Property 7: DWD only_tables 过滤
|
||||
# Validates: Requirements 2.6
|
||||
#
|
||||
# 验证策略:模拟 DwdLoadTask.load() 中的 only_tables 过滤逻辑,
|
||||
# 对任意非空的 only_tables 配置列表,验证过滤后的表集合等于
|
||||
# 配置列表与 TABLE_MAP 键集合的交集。
|
||||
# ===========================================================================
|
||||
|
||||
_TABLE_MAP_KEYS = list(DwdLoadTask.TABLE_MAP.keys())
# Schema-less base names as well, used when generating test inputs.
_TABLE_MAP_BASE_NAMES = list({key.split(".")[-1] for key in _TABLE_MAP_KEYS})

# Strategy: an only_tables config mixing qualified names, base names and noise.
_only_tables_strategy = st.lists(
    st.one_of(
        st.sampled_from(_TABLE_MAP_KEYS),        # schema-qualified, e.g. "dwd.dim_site"
        st.sampled_from(_TABLE_MAP_BASE_NAMES),  # bare name, e.g. "dim_site"
        # Unknown table names — the filter must drop these.
        st.text(alphabet="abcdefghijklmnopqrstuvwxyz_", min_size=3, max_size=15),
    ),
    min_size=1,
    max_size=10,
)
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(only_tables_cfg=_only_tables_strategy)
def test_property7_dwd_only_tables_filter(only_tables_cfg):
    """For any non-empty only_tables config, the filtered set must equal the
    intersection of the config with TABLE_MAP.

    **Validates: Requirements 2.6**

    Mirrors the filter logic inside DwdLoadTask.load():
    - the config list is normalised to a lowercase set
    - a table is kept when dwd_table.lower() or _table_base(dwd_table).lower()
      is contained in that set
    """
    # Reproduce the only_tables normalisation from load().
    only_tables = {
        str(t).strip().lower()
        for t in only_tables_cfg
        if str(t).strip()
    }
    assume(len(only_tables) > 0)

    # Reproduce the filter logic from load() (the behaviour under test).
    filtered_tables = set()
    for dwd_table in DwdLoadTask.TABLE_MAP:
        table_base = dwd_table.split(".")[-1]  # equivalent to _table_base
        if dwd_table.lower() in only_tables or table_base.lower() in only_tables:
            filtered_tables.add(dwd_table)

    # Independent oracle: walk the CONFIG side and collect every TABLE_MAP key
    # each config value matches.  (The original test rebuilt `expected` with
    # the exact same loop used for `filtered_tables`, which made the headline
    # assertion a tautology — it could never fail.)
    expected = {
        dwd_table
        for cfg_val in only_tables
        for dwd_table in DwdLoadTask.TABLE_MAP
        if cfg_val in (dwd_table.lower(), dwd_table.split(".")[-1].lower())
    }

    assert filtered_tables == expected, (
        f"过滤结果不一致: filtered={filtered_tables}, expected={expected}, "
        f"only_tables={only_tables}"
    )

    # The filter may only ever select tables that exist in TABLE_MAP.
    assert filtered_tables.issubset(set(DwdLoadTask.TABLE_MAP.keys())), (
        f"过滤结果包含不在 TABLE_MAP 中的表: "
        f"{filtered_tables - set(DwdLoadTask.TABLE_MAP.keys())}"
    )

    # A config value that matches no TABLE_MAP key must not cause any
    # table to be selected.
    for cfg_val in only_tables:
        matches_any = any(
            dwd_table.lower() == cfg_val or dwd_table.split(".")[-1].lower() == cfg_val
            for dwd_table in DwdLoadTask.TABLE_MAP
        )
        if not matches_any:
            for t in filtered_tables:
                assert t.lower() != cfg_val and t.split(".")[-1].lower() != cfg_val
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 8: DWS 分段累加一致性
|
||||
# Feature: etl-pipeline-debug, Property 8: DWS 分段累加一致性
|
||||
# Validates: Requirements 3.3
|
||||
#
|
||||
# 验证策略:直接测试 BaseTask._accumulate_counts 静态方法,
|
||||
# 对任意分段计数列表,累加结果应等于逐键求和。
|
||||
# ===========================================================================
|
||||
|
||||
# Strategy: per-segment count dictionaries with a fixed key vocabulary.
_count_key = st.sampled_from([
    "inserted", "updated", "skipped", "errors", "fetched",
    "processed", "deleted", "merged",
])

_count_value = st.integers(min_value=0, max_value=100000)

_single_counts = st.dictionaries(
    keys=_count_key,
    values=_count_value,
    min_size=1,
    max_size=6,
)

_segment_counts_list = st.lists(_single_counts, min_size=1, max_size=10)
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(segments=_segment_counts_list)
def test_property8_accumulate_counts_consistency(segments):
    """_accumulate_counts over any segment list must equal a per-key sum.

    **Validates: Requirements 3.3**

    Property: for every numeric key, accumulated value == sum of that key
    across all segments.
    """
    # Accumulate segment by segment through the method under test.
    total_via_method: dict = {}
    for segment in segments:
        BaseTask._accumulate_counts(total_via_method, segment)

    # Oracle: manual per-key summation (non-numeric values keep first-seen).
    expected: dict = {}
    for segment in segments:
        for key, value in segment.items():
            if isinstance(value, (int, float)):
                expected[key] = expected.get(key, 0) + value
            else:
                expected.setdefault(key, value)

    # Core property: both accumulation strategies agree exactly.
    assert total_via_method == expected, (
        f"累加结果不一致:\n"
        f"  _accumulate_counts: {total_via_method}\n"
        f"  手动逐键求和: {expected}\n"
        f"  输入分段: {segments}"
    )

    # Non-negative inputs can only accumulate to non-negative totals.
    for key, value in total_via_method.items():
        if isinstance(value, (int, float)):
            assert value >= 0, f"累加值 {key}={value} 不应为负"
|
||||
@@ -0,0 +1,419 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ODS 层属性测试 — 验证 BaseOdsTask 的核心正确性属性。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
使用 hypothesis 对 ODS 任务的关键行为进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume, HealthCheck
|
||||
from hypothesis import strategies as st
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.ods.ods_tasks import BaseOdsTask, ODS_TASK_CLASSES, ODS_TASK_SPECS, OdsTaskSpec
|
||||
from .task_test_utils import (
|
||||
create_test_config, FakeDBOperations, FakeAPIClient, FakeCursor, FakeConnection,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 增强版 FakeCursor:返回 PK 列信息,使冲突处理和主键跳过逻辑生效
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PkAwareFakeCursor(FakeCursor):
    """FakeCursor extension that answers catalog queries with configured metadata.

    Returns the configured PK columns for ``information_schema.table_constraints``
    queries and the configured column rows for ``information_schema.columns``
    queries, so the task's conflict handling and PK-skip logic become active.

    The maps are injected per instance (by PkAwareFakeDB) rather than stored
    at class level, so concurrent tests cannot interfere with each other.
    """

    def __init__(self, recorder, db_ops=None, pk_map=None, columns_map=None):
        # recorder: shared list collecting executed statements (see execute()).
        # pk_map: {table_name: [pk_col, ...]}
        # columns_map: {table_name: [column-definition rows]}
        super().__init__(recorder, db_ops)
        self._pk_map = pk_map or {}
        self._columns_map = columns_map or {}

    def execute(self, sql, params=None):
        """Record *sql* and synthesise result rows for catalog queries.

        Non-catalog statements fall through to the upsert-recording path
        inherited from FakeCursor (_pending_rows / _record_upserts).
        """
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []

        lowered = sql_text.lower()
        if "from information_schema.columns" in lowered:
            # Column-catalog query: the second bind parameter is taken as the
            # table name (assumes the real query's parameter order — TODO confirm).
            table_name = None
            if params and len(params) >= 2:
                table_name = params[1]
            if table_name and table_name in self._columns_map:
                self._fetchall_rows = list(self._columns_map[table_name])
            else:
                # Fall back to the generic fake column set from FakeCursor.
                self._fetchall_rows = self._fake_columns(table_name)
            return
        if "from information_schema.table_constraints" in lowered:
            table_name = None
            if params and len(params) >= 2:
                table_name = params[1]
            if table_name and table_name in self._pk_map:
                self._fetchall_rows = [(col,) for col in self._pk_map[table_name]]
            else:
                # Default: a single "id" primary-key column.
                self._fetchall_rows = [("id",)]
            return

        # Data statement: if execute_values queued rows, record them as an
        # upsert and report their count; otherwise report zero affected rows.
        if self._pending_rows:
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            self._pending_rows = []
        else:
            self.rowcount = 0
|
||||
|
||||
|
||||
class PkAwareFakeConnection(FakeConnection):
    """FakeConnection that hands out PkAwareFakeCursor instances."""

    def __init__(self, db_ops, pk_map=None, columns_map=None):
        super().__init__(db_ops)
        self._pk_map = {} if pk_map is None else pk_map
        self._columns_map = {} if columns_map is None else columns_map

    def cursor(self):
        # Every cursor shares this connection's statement recorder and maps.
        return PkAwareFakeCursor(
            self.statements,
            self._db_ops,
            pk_map=self._pk_map,
            columns_map=self._columns_map,
        )
|
||||
|
||||
|
||||
class PkAwareFakeDB(FakeDBOperations):
    """FakeDBOperations variant whose connection reports PK/column metadata."""

    def __init__(self, pk_map=None, columns_map=None):
        super().__init__()
        self._pk_map = pk_map or {}
        self._columns_map = columns_map or {}
        # Swap the stock connection for one that serves catalog metadata.
        self.conn = PkAwareFakeConnection(self, self._pk_map, self._columns_map)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_config(tmp_path: Path, conflict_mode: str = "update",
                  snapshot_missing_delete: bool = False) -> "AppConfig":
    """Build a test AppConfig rooted under *tmp_path* with the given run options."""
    cfg = create_test_config("ONLINE", tmp_path / "archive", tmp_path / "temp")
    run_cfg = cfg.config.setdefault("run", {})
    run_cfg["ods_conflict_mode"] = conflict_mode
    run_cfg["snapshot_missing_delete"] = snapshot_missing_delete
    return cfg
|
||||
|
||||
|
||||
def _get_task_spec(code: str) -> OdsTaskSpec:
    """Look up the OdsTaskSpec registered under *code*; raise KeyError if absent."""
    found = next((spec for spec in ODS_TASK_SPECS if spec.code == code), None)
    if found is None:
        raise KeyError(f"未找到任务 spec: {code}")
    return found
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# hypothesis 策略
|
||||
# 使用 id 字段名(与 DB PK 列名一致),确保记录不会因字段名不匹配被跳过
|
||||
# ---------------------------------------------------------------------------
|
||||
# Records use "id" (matching the DB PK column name) so they are never
# skipped for a field-name mismatch.
_valid_id = st.integers(min_value=1, max_value=2**53)

_ods_record_with_id = st.fixed_dictionaries({
    "id": _valid_id,
    "payAmount": st.text(alphabet="0123456789.", min_size=1, max_size=8),
})

_missing_pk_record = st.fixed_dictionaries({
    "id": st.sampled_from([None, ""]),
    "payAmount": st.just("0.00"),
})

# Endpoint/table constants for ODS_PAYMENT.
_PAYMENT_CODE = "ODS_PAYMENT"
_PAYMENT_ENDPOINT = "/PayLog/GetPayLogListPage"
_PAYMENT_TABLE = "payment_transactions"

# Endpoint/table constants for ODS_ASSISTANT_ACCOUNT.
_ASSISTANT_CODE = "ODS_ASSISTANT_ACCOUNT"
_ASSISTANT_TABLE = "assistant_accounts_master"

# Column definitions including is_delete (needed by Property 5).
_COLUMNS_WITH_IS_DELETE = [
    ("id", "bigint", "int8"),
    ("sitegoodsstockid", "bigint", "int8"),
    ("record_index", "integer", "int4"),
    ("content_hash", "text", "text"),
    ("source_file", "text", "text"),
    ("source_endpoint", "text", "text"),
    ("fetched_at", "timestamp with time zone", "timestamptz"),
    ("payload", "jsonb", "jsonb"),
    ("is_delete", "integer", "int4"),
]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: ODS 任务提取记录数一致性
|
||||
# Feature: etl-pipeline-debug, Property 1: ODS 任务提取记录数一致性
|
||||
# Validates: Requirements 1.1, 1.2
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(records=st.lists(_ods_record_with_id, min_size=1, max_size=20))
def test_property1_ods_record_count_consistency(tmp_path, records):
    """For any non-empty record list, fetched == inserted + updated + skipped.

    **Validates: Requirements 1.1, 1.2**
    """
    spec = _get_task_spec(_PAYMENT_CODE)
    task_cls = ODS_TASK_CLASSES[_PAYMENT_CODE]

    task = task_cls(
        _build_config(tmp_path),
        PkAwareFakeDB(),
        FakeAPIClient({spec.endpoint: records}),
        logging.getLogger("prop1"),
    )
    counts = task.execute()["counts"]

    fetched, inserted = counts["fetched"], counts["inserted"]
    updated, skipped = counts["updated"], counts["skipped"]

    # Core property: the counters must close.
    assert fetched == inserted + updated + skipped, (
        f"记录数不闭合: fetched={fetched}, "
        f"inserted={inserted}, updated={updated}, skipped={skipped}"
    )
    # fetched must equal the number of records served by the API stub.
    assert fetched == len(records), (
        f"fetched({fetched}) != len(records)({len(records)})"
    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 2: ODS 冲突处理策略正确性
|
||||
# Feature: etl-pipeline-debug, Property 2: ODS 冲突处理策略正确性
|
||||
# Validates: Requirements 1.3
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(conflict_mode=st.sampled_from(["nothing", "backfill", "update"]))
def test_property2_ods_conflict_mode_sql(tmp_path, conflict_mode):
    """Every ods_conflict_mode must yield SQL with the matching conflict clause.

    **Validates: Requirements 1.3**
    """
    spec = _get_task_spec(_PAYMENT_CODE)
    task_cls = ODS_TASK_CLASSES[_PAYMENT_CODE]

    config = _build_config(tmp_path, conflict_mode=conflict_mode)
    db = PkAwareFakeDB()
    # A record carrying an "id" field is never skipped.
    api = FakeAPIClient({spec.endpoint: [{"id": 1001, "payAmount": "50.00"}]})

    task = task_cls(config, db, api, logging.getLogger("prop2"))
    result = task.execute()

    # INSERT statements reach db.conn.statements via execute_values -> cur.execute.
    all_sql = [stmt["sql"] for stmt in db.conn.statements]
    insert_sqls = [s for s in all_sql if "INSERT" in s.upper() and "VALUES" in s.upper()]
    assert insert_sqls, (
        f"应至少有一条 INSERT...VALUES 语句,"
        f"实际 SQL: {[s[:100] for s in all_sql]}, counts={result['counts']}"
    )
    sql_upper = insert_sqls[0].upper()

    # Map each conflict mode to the clause its SQL must contain.
    expected_clause = {
        "nothing": "DO NOTHING",
        "backfill": "COALESCE",
        "update": "IS DISTINCT FROM",
    }[conflict_mode]
    assert expected_clause in sql_upper, (
        f"conflict_mode={conflict_mode} 时 SQL 应包含 {expected_clause}"
    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: ODS 跳过缺失主键记录
|
||||
# Feature: etl-pipeline-debug, Property 3: ODS 跳过缺失主键记录
|
||||
# Validates: Requirements 1.4
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
    valid_records=st.lists(_ods_record_with_id, min_size=0, max_size=10),
    missing_pk_records=st.lists(_missing_pk_record, min_size=1, max_size=10),
)
def test_property3_ods_skip_missing_pk(tmp_path, valid_records, missing_pk_records):
    """For any record set containing missing-PK rows, skipped >= missing count.

    **Validates: Requirements 1.4**
    """
    spec = _get_task_spec(_PAYMENT_CODE)
    task_cls = ODS_TASK_CLASSES[_PAYMENT_CODE]

    all_records = valid_records + missing_pk_records
    assume(len(all_records) > 0)

    db = PkAwareFakeDB()
    task = task_cls(
        _build_config(tmp_path),
        db,
        FakeAPIClient({spec.endpoint: all_records}),
        logging.getLogger("prop3"),
    )
    skipped = task.execute()["counts"]["skipped"]

    # Every missing-PK record must at least be skipped.
    assert skipped >= len(missing_pk_records), (
        f"skipped({skipped}) < missing_pk_count({len(missing_pk_records)})"
    )

    # No written row may carry a missing primary key.
    for upsert in db.upserts:
        for row in upsert.get("rows", []):
            pk_val = row.get("id")
            assert pk_val is not None and pk_val != "", (
                f"写入行中出现缺失主键: {row}"
            )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 4: ODS content_hash 去重
|
||||
# Feature: etl-pipeline-debug, Property 4: ODS content_hash 去重
|
||||
# Validates: Requirements 1.5
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(record=_ods_record_with_id)
def test_property4_ods_content_hash_deterministic(tmp_path, record):
    """The content_hash of a given payload + is_delete pair is deterministic.

    **Validates: Requirements 1.5, 5.1, 5.4**

    Signature: _compute_content_hash(record, payload=..., is_delete=...).
    The hash is derived from the raw (unflattened) payload plus is_delete,
    so fetched_at can never influence it.
    """
    merged = BaseOdsTask._merge_record_layers(record)
    is_delete = 0

    def hash_of(merged_rec):
        # payload/is_delete are held fixed; only merged_rec varies.
        return BaseOdsTask._compute_content_hash(
            merged_rec, payload=record, is_delete=is_delete
        )

    base_hash = hash_of(merged)
    assert base_hash == hash_of(merged), "同一 payload + is_delete 的 content_hash 应确定性相同"

    # Different fetched_at values in the merged record must not change the hash.
    h_early = hash_of({**merged, "fetched_at": "2025-01-01T10:00:00+08:00"})
    h_late = hash_of({**merged, "fetched_at": "2025-06-15T23:59:59+08:00"})
    assert h_early == h_late, "merged_rec 中 fetched_at 不同不影响 hash(基于 payload 计算)"
    assert base_hash == h_early, "hash 仅取决于 payload + is_delete,与 merged_rec 内容无关"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 5: ODS 快照删除标记(INSERT 语义)
|
||||
# Feature: etl-pipeline-debug, Property 5: ODS 快照删除标记
|
||||
# Validates: Requirements 1.7, 7.1, 7.4
|
||||
# ===========================================================================
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
    existing_ids=st.lists(
        st.integers(min_value=1, max_value=10000), min_size=2, max_size=10, unique=True
    ),
)
def test_property5_ods_snapshot_delete_marking(tmp_path, existing_ids):
    """With snapshot_missing_delete enabled, IDs absent from the API response
    must be soft-deleted via INSERTed delete-version rows (not via UPDATE).

    **Validates: Requirements 1.7, 7.1, 7.4**

    Strategy: use ODS_ASSISTANT_ACCOUNT (snapshot_mode=FULL_TABLE) and have
    the API return only a proper subset of existing_ids; the recorded SQL
    must contain INSERT statements (rather than standalone UPDATEs) writing
    is_delete=1 version rows.

    After the "plan 4" refactor, soft delete switched from UPDATE to INSERT:
    - _mark_missing_as_deleted reads the latest version row of each missing ID
    - it builds a new version row with is_delete=1 and INSERTs it
    - no standalone "UPDATE ... SET is_delete=1" statement may be issued
    """
    if _ASSISTANT_CODE not in ODS_TASK_CLASSES:
        pytest.skip(f"{_ASSISTANT_CODE} 不在 ODS_TASK_CLASSES 中")

    spec = _get_task_spec(_ASSISTANT_CODE)
    task_cls = ODS_TASK_CLASSES[_ASSISTANT_CODE]

    # The API returns only the first half of the IDs (a proper subset).
    subset_size = max(1, len(existing_ids) // 2)
    api_ids = existing_ids[:subset_size]
    assume(len(api_ids) < len(existing_ids))

    api_records = [
        {"id": eid, "assistant_no": f"A{eid}", "nickname": f"助教{eid}"}
        for eid in api_ids
    ]

    config = _build_config(tmp_path, snapshot_missing_delete=True)

    # Configure the enhanced fake DB so its catalog includes the is_delete column.
    db = PkAwareFakeDB(
        pk_map={_ASSISTANT_TABLE: ["id"]},
        columns_map={_ASSISTANT_TABLE: _COLUMNS_WITH_IS_DELETE},
    )
    api = FakeAPIClient({spec.endpoint: api_records})
    logger = logging.getLogger("prop5")

    task = task_cls(config, db, api, logger)
    result = task.execute()

    assert result["status"] == "SUCCESS"

    # After plan 4, _mark_missing_as_deleted INSERTs delete-version rows.
    # The normal write path's INSERT ... ON CONFLICT DO UPDATE contains the
    # word "update", but there must be no standalone
    # "UPDATE ods.xxx SET is_delete=1" statement.
    all_sql = [stmt["sql"] for stmt in db.conn.statements]

    # Old semantics (standalone UPDATE ... SET is_delete) must be gone.
    has_standalone_update = any(
        sql.lower().lstrip().startswith("update ")
        and "is_delete" in sql.lower()
        for sql in all_sql
    )
    assert not has_standalone_update, (
        f"方案 4 后不应有独立的 UPDATE ... SET is_delete 语句,"
        f"实际 SQL: {[s[:120] for s in all_sql if 'update' in s.lower()]}"
    )

    # At least one INSERT must exist (normal writes + possible soft-delete INSERTs).
    has_insert = any(
        "insert" in sql.lower() and "values" in sql.lower()
        for sql in all_sql
    )
    assert has_insert, (
        f"应至少有一条 INSERT 语句,"
        f"实际 SQL: {[s[:120] for s in all_sql]}"
    )
|
||||
@@ -0,0 +1,266 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""编排层属性测试 — 验证 FlowRunner、TaskExecutor、CLI 的核心正确性属性。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
使用 hypothesis 对 Flow 层解析、无效 Flow 拒绝、工具类任务跳过游标、
|
||||
CLI data_source 解析进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume, HealthCheck
|
||||
from hypothesis import strategies as st
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from orchestration.flow_runner import FlowRunner
|
||||
from orchestration.task_registry import default_registry, TaskRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 公共常量
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_FLOWS = list(FlowRunner.FLOW_LAYERS.keys())

# Task codes flagged requires_db_config=False in the default registry,
# i.e. utility tasks.
UTILITY_TASK_CODES = [
    code for code in default_registry.get_all_task_codes()
    if default_registry.is_utility_task(code)
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 9: FlowRunner Flow 层解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(flow_name=st.sampled_from(VALID_FLOWS))
def test_property9_pipeline_runner_flow_layer_resolution(flow_name: str):
    """Property 9: FlowRunner flow-to-layer resolution.

    For every valid flow name (a FLOW_LAYERS key) the resolved layer list
    must match the value stored in FLOW_LAYERS exactly.

    **Validates: Requirements 5.1, 10.2**
    """
    expected_layers = FlowRunner.FLOW_LAYERS[flow_name]

    # FlowRunner.run() resolves layers with a plain dict lookup
    # (`layers = self.FLOW_LAYERS[pipeline]`), so verify that lookup directly.
    assert flow_name in FlowRunner.FLOW_LAYERS
    assert FlowRunner.FLOW_LAYERS[flow_name] == expected_layers

    # The layer list is non-empty and contains only known layer names.
    assert len(expected_layers) > 0, f"Flow {flow_name} 的层列表不应为空"
    for layer in expected_layers:
        assert isinstance(layer, str)
        assert layer in ("ODS", "DWD", "DWS", "INDEX"), (
            f"Flow {flow_name} 包含未知层: {layer}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 10: FlowRunner 无效 Flow 拒绝
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strategy: strings guaranteed not to be FLOW_LAYERS keys.
_invalid_flow_strategy = st.text(
    alphabet=st.characters(whitelist_categories=("L", "N", "P")),
    min_size=1,
    max_size=50,
).filter(lambda candidate: candidate not in FlowRunner.FLOW_LAYERS)
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(invalid_flow=_invalid_flow_strategy)
def test_property10_pipeline_runner_rejects_invalid_flow(invalid_flow: str):
    """Property 10: FlowRunner rejects invalid flow names.

    Any string outside the FLOW_LAYERS key set must make FlowRunner.run()
    raise ValueError.

    **Validates: Requirements 5.2**

    A minimal FlowRunner is assembled from mocks so no real DB/API
    connection is required.
    """
    # Sanity: the strategy really produced an unknown flow name.
    assert invalid_flow not in FlowRunner.FLOW_LAYERS

    mock_config = MagicMock()
    mock_config.get = MagicMock(return_value="Asia/Shanghai")
    mock_config.__getitem__ = MagicMock(return_value={})

    runner = FlowRunner(
        config=mock_config,
        task_executor=MagicMock(),
        task_registry=MagicMock(),
        db_conn=MagicMock(),
        api_client=MagicMock(),
        logger=logging.getLogger("test"),
    )

    # run() guards with: if pipeline not in self.FLOW_LAYERS: raise ValueError(...)
    with pytest.raises(ValueError, match="无效的 Flow 名称"):
        runner.run(pipeline=invalid_flow)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 11: TaskExecutor 工具类任务跳过游标
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(task_code=st.sampled_from(UTILITY_TASK_CODES))
def test_property11_task_executor_utility_task_skips_cursor(task_code: str):
    """Property 11: TaskExecutor skips cursor handling for utility tasks.

    Any task code marked requires_db_config=False in the TaskRegistry must be
    executed through the _run_utility_task path, touching neither
    CursorManager nor RunTracker.

    **Validates: Requirements 5.5**

    Approach:
    1. confirm the registry marks task_code as a utility task
    2. build a TaskExecutor whose CursorManager and RunTracker are mocks
    3. call run_single_task and assert neither mock was invoked
    """
    from orchestration.task_executor import TaskExecutor

    # 1) Registry metadata must agree with the utility classification.
    assert default_registry.is_utility_task(task_code), (
        f"{task_code} 应被标记为工具类任务"
    )
    meta = default_registry.get_metadata(task_code)
    assert meta is not None
    assert meta.requires_db_config is False

    # 2) Minimal TaskExecutor wiring with mocked collaborators.
    mock_config = MagicMock()
    mock_config.get = MagicMock(return_value="Asia/Shanghai")
    mock_config.__getitem__ = MagicMock(return_value={
        "export_root": "/tmp/test_export",
        "log_root": "/tmp/test_log",
        "fetch_root": "/tmp/test_fetch",
    })

    # spec=[...] restricts the mocks so unexpected attribute access fails loudly.
    mock_cursor_mgr = MagicMock(spec=["get_or_create", "advance"])
    mock_run_tracker = MagicMock(spec=["create_run", "update_run"])

    # Fake task instance whose execute() reports success.
    mock_task_instance = MagicMock()
    mock_task_instance.execute = MagicMock(return_value={
        "status": "SUCCESS",
        "counts": {"processed": 1},
    })

    # Wrap the real default_registry but stub create_task so no real task
    # class is instantiated.
    mock_registry = MagicMock(wraps=default_registry)
    mock_registry.is_utility_task = default_registry.is_utility_task
    mock_registry.get_metadata = default_registry.get_metadata
    mock_registry.create_task = MagicMock(return_value=mock_task_instance)

    executor = TaskExecutor(
        config=mock_config,
        db_ops=MagicMock(),
        api_client=MagicMock(),
        cursor_mgr=mock_cursor_mgr,
        run_tracker=mock_run_tracker,
        task_registry=mock_registry,
        logger=logging.getLogger("test"),
    )

    # 3) Execute the task through the public entry point.
    result = executor.run_single_task(
        task_code=task_code,
        run_uuid="test-uuid-001",
        store_id=1,
    )

    # 4) Neither CursorManager nor RunTracker may have been touched.
    mock_cursor_mgr.get_or_create.assert_not_called()
    mock_cursor_mgr.advance.assert_not_called()
    mock_run_tracker.create_run.assert_not_called()
    mock_run_tracker.update_run.assert_not_called()

    # The utility task must still report success.
    assert result["status"] == "SUCCESS"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 12: CLI data_source 解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# resolve_data_source expects an argparse.Namespace-like object.
_DATA_SOURCE_VALUES = ["online", "offline", "hybrid"]
_PIPELINE_FLOW_VALUES = ["FULL", "FETCH_ONLY", "INGEST_ONLY"]
# Fallback mapping applied when --data-source is absent.
_FLOW_TO_DATA_SOURCE = {
    "FULL": "hybrid",
    "FETCH_ONLY": "online",
    "INGEST_ONLY": "offline",
}
|
||||
|
||||
|
||||
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
@given(
    data_source=st.one_of(st.none(), st.sampled_from(_DATA_SOURCE_VALUES)),
    pipeline_flow=st.one_of(st.none(), st.sampled_from(_PIPELINE_FLOW_VALUES)),
)
def test_property12_cli_data_source_resolution(data_source, pipeline_flow):
    """Property 12: CLI data_source resolution.

    For every --data-source value (online/offline/hybrid) and every
    --pipeline-flow value (FULL/FETCH_ONLY/INGEST_ONLY),
    resolve_data_source must return the mapped value, with
    --data-source taking precedence over --pipeline-flow.

    **Validates: Requirements 6.3, 6.4**
    """
    from cli.main import resolve_data_source

    # Build an argparse.Namespace-like object.
    args = types.SimpleNamespace(
        data_source=data_source,
        pipeline_flow=pipeline_flow,
    )
    result = resolve_data_source(args)

    # Determine the expected value and failure message per precedence rule,
    # then assert once.
    if data_source is not None:
        # Rule 1: an explicit --data-source wins outright.
        expected = data_source
        msg = (
            f"当 --data-source={data_source} 时,结果应为 {data_source},"
            f"实际为 {result}"
        )
    elif pipeline_flow is not None:
        # Rule 2: without --data-source, fall back to the --pipeline-flow mapping.
        expected = _FLOW_TO_DATA_SOURCE[pipeline_flow]
        msg = (
            f"当 --pipeline-flow={pipeline_flow} 时,结果应为 {expected},"
            f"实际为 {result}"
        )
    else:
        # Rule 3: neither flag given -> default to hybrid.
        expected = "hybrid"
        msg = f"两者都未指定时,结果应为 hybrid,实际为 {result}"
    assert result == expected, msg

    # Invariant: the resolved value is always a valid data_source.
    assert result in _DATA_SOURCE_VALUES, (
        f"返回值 {result} 不在有效值列表 {_DATA_SOURCE_VALUES} 中"
    )
|
||||
@@ -0,0 +1,152 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLI 参数和 Flow 类型文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 7.1, 7.2**
|
||||
|
||||
Property 6: 对于所有在 cli/main.py 的 parse_args() 中定义的 CLI 参数,
|
||||
README.md 或 base_task_mechanism.md 中应包含该参数的说明。
|
||||
|
||||
Property 7: 对于所有在 FlowRunner.FLOW_LAYERS 中定义的 Flow 类型,
|
||||
README.md 中应包含该 Flow 类型的层组合说明。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 6 & 7
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ── 常量 ──────────────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_CLI_MAIN_PATH = _PROJECT_ROOT / "cli" / "main.py"
|
||||
_README_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "README.md"
|
||||
_BASE_MECHANISM_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "base_task_mechanism.md"
|
||||
|
||||
|
||||
# ── 辅助函数:通过 AST 解析 parse_args() 中的 CLI 参数名 ─────
|
||||
|
||||
def _extract_cli_params_via_ast() -> list[str]:
    """Extract every long CLI option registered in parse_args() of cli/main.py.

    Walks the AST for the parse_args() function and collects, for each
    ``*.add_argument(...)`` call, the first positional string starting
    with '--'. Positional (non '--') argument names are ignored. The
    result is sorted and de-duplicated.
    """
    tree = ast.parse(
        _CLI_MAIN_PATH.read_text(encoding="utf-8"),
        filename=str(_CLI_MAIN_PATH),
    )

    found: set[str] = set()

    # Locate every parse_args function definition in the module.
    parse_args_defs = [
        n
        for n in ast.walk(tree)
        if isinstance(n, ast.FunctionDef) and n.name == "parse_args"
    ]
    for fn in parse_args_defs:
        # Scan all calls inside; match any xxx.add_argument(...).
        for call in ast.walk(fn):
            if not isinstance(call, ast.Call):
                continue
            callee = call.func
            if not (isinstance(callee, ast.Attribute) and callee.attr == "add_argument"):
                continue

            # Take only the first '--'-prefixed positional name
            # (add_argument may list aliases like '--api-token', '--token').
            long_names = (
                a.value
                for a in call.args
                if isinstance(a, ast.Constant)
                and isinstance(a.value, str)
                and a.value.startswith("--")
            )
            first = next(long_names, None)
            if first is not None:
                found.add(first)

    return sorted(found)
|
||||
|
||||
|
||||
# ── 辅助函数:提取 FLOW_LAYERS 的键 ──────────────────────
|
||||
|
||||
def _extract_pipeline_types() -> list[str]:
    """Collect the Flow type names, i.e. the keys of FlowRunner.FLOW_LAYERS.

    Parses orchestration/flow_runner.py with the ast module so that
    FlowRunner is never imported or instantiated.

    Raises:
        RuntimeError: if no FLOW_LAYERS dict literal is found on FlowRunner.
    """
    runner_path = _PROJECT_ROOT / "orchestration" / "flow_runner.py"
    tree = ast.parse(
        runner_path.read_text(encoding="utf-8"),
        filename=str(runner_path),
    )

    for cls in ast.walk(tree):
        if not (isinstance(cls, ast.ClassDef) and cls.name == "FlowRunner"):
            continue
        for stmt in cls.body:
            # Accept both `FLOW_LAYERS = {...}` and `FLOW_LAYERS: ... = {...}`.
            if isinstance(stmt, ast.AnnAssign):
                assign_targets = [stmt.target]
            elif isinstance(stmt, ast.Assign):
                assign_targets = stmt.targets
            else:
                continue

            is_flow_layers = any(
                isinstance(t, ast.Name) and t.id == "FLOW_LAYERS"
                for t in assign_targets
            )
            if is_flow_layers and isinstance(stmt.value, ast.Dict):
                return sorted(
                    k.value
                    for k in stmt.value.keys
                    if isinstance(k, ast.Constant) and isinstance(k.value, str)
                )

    raise RuntimeError("未能从 flow_runner.py 中解析出 FLOW_LAYERS")
|
||||
|
||||
|
||||
# ── Test data (computed once at module import time) ───────────────────
_CLI_PARAMS: list[str] = _extract_cli_params_via_ast()
_PIPELINE_TYPES: list[str] = _extract_pipeline_types()
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
def readme_content() -> str:
    """Full text of README.md, shared by every test in this module."""
    path = _README_PATH
    assert path.exists(), f"文档文件不存在: {path}"
    return path.read_text(encoding="utf-8")


@pytest.fixture(scope="module")
def base_mechanism_content() -> str:
    """Full text of base_task_mechanism.md, shared by every test in this module."""
    path = _BASE_MECHANISM_PATH
    assert path.exists(), f"文档文件不存在: {path}"
    return path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# ── Property 6: CLI 参数文档覆盖完整性 ────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("param", _CLI_PARAMS, ids=_CLI_PARAMS)
def test_cli_param_in_docs(param: str, readme_content: str, base_mechanism_content: str):
    """Property 6: every CLI argument defined in parse_args() is described
    in README.md or base_task_mechanism.md.

    **Validates: Requirements 7.1**
    """
    # The raw name is enough; it may appear backtick-quoted or plain.
    docs = "\n".join((readme_content, base_mechanism_content))
    assert param in docs, (
        f"CLI 参数 '{param}' 在 parse_args() 中定义,"
        f"但未在 README.md 或 base_task_mechanism.md 中找到对应说明"
    )
|
||||
|
||||
|
||||
# ── Property 7: Flow 类型文档覆盖完整性 ────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("pipeline_type", _PIPELINE_TYPES, ids=_PIPELINE_TYPES)
def test_pipeline_type_in_readme(pipeline_type: str, readme_content: str):
    """Property 7: every Flow type in FLOW_LAYERS has a layer-composition
    note in README.md.

    **Validates: Requirements 7.2**
    """
    missing_msg = (
        f"Flow 类型 '{pipeline_type}' 在 FLOW_LAYERS 中定义,"
        f"但未在 README.md 中找到对应的层组合说明"
    )
    assert pipeline_type in readme_content, missing_msg
|
||||
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWD 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 3.1**
|
||||
|
||||
从 task_registry.py 中提取所有 layer="DWD" 的任务代码,
|
||||
验证 docs/etl_tasks/dwd_tasks.md 中包含每个任务代码的说明章节。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 2: DWD 任务文档覆盖完整性
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_DWD_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "dwd_tasks.md"
|
||||
|
||||
# 从注册表动态获取所有 DWD 层任务代码
|
||||
_DWD_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("DWD")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def dwd_doc_content() -> str:
    """Full text of dwd_tasks.md, shared by every case in this module."""
    doc = _DWD_DOC_PATH
    assert doc.exists(), f"文档文件不存在: {doc}"
    return doc.read_text(encoding="utf-8")


# Every registered DWD task code must appear somewhere in the document.
@pytest.mark.parametrize("task_code", _DWD_TASK_CODES, ids=_DWD_TASK_CODES)
def test_dwd_task_code_in_doc(task_code: str, dwd_doc_content: str):
    """Property 2: each registered DWD task code is documented in dwd_tasks.md.

    **Validates: Requirements 3.1**
    """
    missing_msg = (
        f"DWD 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 dwd_tasks.md 中找到对应说明章节"
    )
    assert task_code in dwd_doc_content, missing_msg
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWS 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 4.1, 4.4**
|
||||
|
||||
从 task_registry.py 中提取所有 layer="DWS" 的任务代码,
|
||||
验证 docs/etl_tasks/dws_tasks.md 中包含每个任务代码的说明章节,
|
||||
并标注其更新策略。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 3: DWS 任务文档覆盖完整性
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_DWS_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "dws_tasks.md"
|
||||
|
||||
# 从注册表动态获取所有 DWS 层任务代码
|
||||
_DWS_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("DWS")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def dws_doc_content() -> str:
    """Full text of dws_tasks.md, shared by every case in this module."""
    doc = _DWS_DOC_PATH
    assert doc.exists(), f"文档文件不存在: {doc}"
    return doc.read_text(encoding="utf-8")


# Every registered DWS task code must appear somewhere in the document.
@pytest.mark.parametrize("task_code", _DWS_TASK_CODES, ids=_DWS_TASK_CODES)
def test_dws_task_code_in_doc(task_code: str, dws_doc_content: str):
    """Property 3: each registered DWS task code is documented in dws_tasks.md.

    **Validates: Requirements 4.1, 4.4**
    """
    missing_msg = (
        f"DWS 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 dws_tasks.md 中找到对应说明章节"
    )
    assert task_code in dws_doc_content, missing_msg
|
||||
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""INDEX 和 Utility 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 5.1, 6.1**
|
||||
|
||||
Property 4: 对于所有在 task_registry.py 中注册且 layer="INDEX" 的任务代码,
|
||||
index_tasks.md 中应包含该任务代码的说明章节。
|
||||
|
||||
Property 5: 对于所有在 task_registry.py 中注册且 task_type="utility" 的任务代码,
|
||||
utility_tasks.md 中应包含该任务代码的说明章节。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 4 & 5
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_INDEX_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "index_tasks.md"
|
||||
_UTILITY_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "utility_tasks.md"
|
||||
|
||||
# INDEX-layer tasks: obtained via the public get_tasks_by_layer accessor.
_INDEX_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("INDEX")

# Utility tasks: the registry exposes no get_tasks_by_type, so filter its
# internal dict directly.
# NOTE(review): this reaches into the private `_tasks` attribute — consider
# adding a public accessor to the registry instead.
_UTILITY_TASK_CODES: list[str] = [
    code
    for code, meta in default_registry._tasks.items()
    if meta.task_type == "utility"
]
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
def index_doc_content() -> str:
    """Full text of index_tasks.md."""
    doc = _INDEX_DOC_PATH
    assert doc.exists(), f"文档文件不存在: {doc}"
    return doc.read_text(encoding="utf-8")


@pytest.fixture(scope="module")
def utility_doc_content() -> str:
    """Full text of utility_tasks.md."""
    doc = _UTILITY_DOC_PATH
    assert doc.exists(), f"文档文件不存在: {doc}"
    return doc.read_text(encoding="utf-8")


# ── Property 4: INDEX task documentation coverage ─────────────────────
@pytest.mark.parametrize("task_code", _INDEX_TASK_CODES, ids=_INDEX_TASK_CODES)
def test_index_task_code_in_doc(task_code: str, index_doc_content: str):
    """Property 4: each registered INDEX task code is documented in index_tasks.md.

    **Validates: Requirements 5.1**
    """
    missing_msg = (
        f"INDEX 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 index_tasks.md 中找到对应说明章节"
    )
    assert task_code in index_doc_content, missing_msg


# ── Property 5: utility task documentation coverage ───────────────────
@pytest.mark.parametrize("task_code", _UTILITY_TASK_CODES, ids=_UTILITY_TASK_CODES)
def test_utility_task_code_in_doc(task_code: str, utility_doc_content: str):
    """Property 5: each registered task_type="utility" task code is documented
    in utility_tasks.md.

    **Validates: Requirements 6.1**
    """
    missing_msg = (
        f"Utility 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 utility_tasks.md 中找到对应说明章节"
    )
    assert task_code in utility_doc_content, missing_msg
|
||||
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ODS 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 2.1, 2.4**
|
||||
|
||||
从 task_registry.py 中提取所有 layer="ODS" 的任务代码,
|
||||
验证 docs/etl_tasks/ods_tasks.md 中包含每个任务代码的说明章节。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 1: ODS 任务文档覆盖完整性
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_ODS_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "ods_tasks.md"
|
||||
|
||||
# 从注册表动态获取所有 ODS 层任务代码
|
||||
_ODS_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("ODS")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def ods_doc_content() -> str:
    """Full text of ods_tasks.md, shared by every case in this module."""
    doc = _ODS_DOC_PATH
    assert doc.exists(), f"文档文件不存在: {doc}"
    return doc.read_text(encoding="utf-8")


# Every registered ODS task code must appear somewhere in the document.
@pytest.mark.parametrize("task_code", _ODS_TASK_CODES, ids=_ODS_TASK_CODES)
def test_ods_task_code_in_doc(task_code: str, ods_doc_content: str):
    """Property 1: each registered ODS task code is documented in ods_tasks.md.

    **Validates: Requirements 2.1, 2.4**
    """
    missing_msg = (
        f"ODS 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 ods_tasks.md 中找到对应说明章节"
    )
    assert task_code in ods_doc_content, missing_msg
|
||||
@@ -0,0 +1,202 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWD 第一阶段重构 — 单元测试。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tests.unit.task_test_utils import FakeCursor, FakeDBOperations
|
||||
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
|
||||
|
||||
def _make_task() -> DwdLoadTask:
    """Build a minimal DwdLoadTask wired only to fakes and mocks."""
    config = MagicMock()
    # Make config.get(key, default) behave like a lookup miss on a real
    # mapping: always hand back the caller-supplied default.
    config.get = MagicMock(side_effect=lambda key, default=None: default)
    return DwdLoadTask(
        config,
        FakeDBOperations(),
        MagicMock(),
        logging.getLogger("test_dwd_phase1"),
    )
|
||||
|
||||
|
||||
# ── 5.2 _build_column_mapping: fetched_at 缺失时使用 ods_table / cur 参数 ──
|
||||
|
||||
|
||||
class TestBuildColumnMappingMissingFetchedAt:
    """Behavior of _build_column_mapping() when fetched_at is absent (Req 4.1)."""

    def test_returns_zero_counts_when_fetched_at_missing(self):
        """ods_cols without fetched_at -> an all-zero counts dict is returned."""
        task = _make_task()
        cur = FakeCursor(recorder=[])
        result = task._build_column_mapping(
            cur,
            dwd_table="dwd.dwd_settlement_head",
            ods_table="ods.settlement_records",
            pk_cols=["id"],
            ods_cols=["id", "amount", "create_time"],  # no fetched_at
        )
        assert result == {"processed": 0, "inserted": 0, "updated": 0, "skipped": 0}

    def test_error_log_contains_ods_table_param(self, caplog):
        """The error log must contain the ods_table argument value, proving the
        method uses its parameter (not some other source) in the message."""
        task = _make_task()
        cur = FakeCursor(recorder=[])
        ods_table = "ods.my_custom_test_table"

        with caplog.at_level(logging.ERROR, logger="test_dwd_phase1"):
            task._build_column_mapping(
                cur,
                dwd_table="dwd.dwd_settlement_head",
                ods_table=ods_table,
                pk_cols=["id"],
                ods_cols=["id", "amount"],
            )

        assert any(ods_table in record.message for record in caplog.records), (
            f"日志中未找到 ods_table='{ods_table}',实际日志: {[r.message for r in caplog.records]}"
        )

    def test_calls_log_missing_fetched_at_with_cur(self):
        """fetched_at absent -> _build_column_mapping returns early and must
        NOT call _log_missing_fetched_at at all."""
        task = _make_task()
        cur = FakeCursor(recorder=[])
        ods_table = "ods.settlement_records"

        # Replace with a mock so the call (or its absence) can be verified.
        task._log_missing_fetched_at = MagicMock()

        task._build_column_mapping(
            cur,
            dwd_table="dwd.dwd_settlement_head",
            ods_table=ods_table,
            pk_cols=["id"],
            ods_cols=["id", "amount"],
        )

        task._log_missing_fetched_at.assert_not_called()
        # _log_missing_fetched_at is only reached when fetched_at IS present;
        # with fetched_at missing, the method returns before calling it.

    def test_fetched_at_present_calls_log_missing_with_cur(self):
        """fetched_at present -> _log_missing_fetched_at is called exactly once
        with the same cursor and ods_table that were passed in."""
        task = _make_task()
        cur = FakeCursor(recorder=[])
        ods_table = "ods.settlement_records"

        task._log_missing_fetched_at = MagicMock()

        task._build_column_mapping(
            cur,
            dwd_table="dwd.dwd_settlement_head",
            ods_table=ods_table,
            pk_cols=["id"],
            ods_cols=["id", "amount", "fetched_at"],
        )

        task._log_missing_fetched_at.assert_called_once_with(cur, ods_table)
|
||||
|
||||
|
||||
# ── 6.1 验证死代码已清理:hasattr 检查所有已删除的方法和常量 ──
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDeadCodeRemoved:
    """After the refactor, DwdLoadTask no longer carries the deleted methods
    and constants (Reqs 1.3, 2.1, 2.2, 3.2-3.7)."""

    @pytest.mark.parametrize(
        "attr, requirement",
        [
            ("_get_fact_watermark", "1.3"),
            ("_insert_missing_by_pk", "2.1"),
            ("_pick_order_column", "3.2"),
            ("_upsert_scd2_row", "3.4"),
            ("_close_current_dim", "3.5"),
            ("_insert_dim_row", "3.6"),
            ("_merge_dim_type1_upsert", "3.7"),
        ],
        # Strip the leading underscore from method names so test ids read
        # cleanly; requirement strings ("1.3", ...) pass through unchanged.
        ids=lambda v: v.lstrip("_") if v.startswith("_") else v,
    )
    def test_method_removed(self, attr: str, requirement: str):
        """Deleted methods must not exist on a DwdLoadTask instance."""
        task = _make_task()
        assert not hasattr(task, attr), (
            f"需求 {requirement}:方法 {attr} 应已删除,但仍存在于 DwdLoadTask"
        )

    @pytest.mark.parametrize(
        "attr, requirement",
        [
            ("FACT_ORDER_CANDIDATES", "3.3"),
            ("FACT_MISSING_FILL_TABLES", "2.2"),
        ],
        # The original ids lambda here was a no-op identity
        # (`v if not v.startswith("FACT") else v`); plain str produces the
        # same ids without the dead branch.
        ids=str,
    )
    def test_constant_removed(self, attr: str, requirement: str):
        """Deleted class-level constants must not exist on DwdLoadTask."""
        assert not hasattr(DwdLoadTask, attr), (
            f"需求 {requirement}:常量 {attr} 应已删除,但仍存在于 DwdLoadTask"
        )
|
||||
|
||||
|
||||
class TestBaseFileRemoved:
    """base_dwd_task.py must have been deleted (Req 3.1)."""

    def test_base_dwd_task_import_raises(self):
        """Importing BaseDwdTask must raise ImportError."""
        with pytest.raises(ImportError):
            from tasks.dwd.base_dwd_task import BaseDwdTask  # noqa: F401
|
||||
|
||||
|
||||
# ── 6.2 验证外部模块导入正常(debug_dwd.py、integrity_checker.py 无 ImportError) ──
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
class TestExternalModuleImports:
    """External modules must import cleanly and expose _TIME_COLUMN_CANDIDATES
    (Req 3.9)."""

    # Single source of truth for the expected candidate time-column names;
    # this list was previously duplicated verbatim in two test methods.
    EXPECTED_TIME_COLUMNS = [
        "pay_time", "create_time", "update_time",
        "occur_time", "settle_time", "start_use_time", "fetched_at",
    ]

    def test_debug_dwd_imports_without_error(self):
        """scripts.debug.debug_dwd imports without ImportError."""
        mod = importlib.import_module("scripts.debug.debug_dwd")
        assert mod is not None

    def test_integrity_checker_imports_without_error(self):
        """quality.integrity_checker imports without ImportError."""
        mod = importlib.import_module("quality.integrity_checker")
        assert mod is not None

    def test_debug_dwd_has_time_column_candidates(self):
        """debug_dwd exposes the _TIME_COLUMN_CANDIDATES constant."""
        mod = importlib.import_module("scripts.debug.debug_dwd")
        assert hasattr(mod, "_TIME_COLUMN_CANDIDATES"), (
            "debug_dwd 缺少 _TIME_COLUMN_CANDIDATES 常量"
        )

    def test_integrity_checker_has_time_column_candidates(self):
        """integrity_checker exposes the _TIME_COLUMN_CANDIDATES constant."""
        mod = importlib.import_module("quality.integrity_checker")
        assert hasattr(mod, "_TIME_COLUMN_CANDIDATES"), (
            "integrity_checker 缺少 _TIME_COLUMN_CANDIDATES 常量"
        )

    def test_debug_dwd_time_column_candidates_content(self):
        """debug_dwd._TIME_COLUMN_CANDIDATES matches the expected column list."""
        mod = importlib.import_module("scripts.debug.debug_dwd")
        assert mod._TIME_COLUMN_CANDIDATES == self.EXPECTED_TIME_COLUMNS

    def test_integrity_checker_time_column_candidates_content(self):
        """integrity_checker._TIME_COLUMN_CANDIDATES matches the expected column list."""
        mod = importlib.import_module("quality.integrity_checker")
        assert mod._TIME_COLUMN_CANDIDATES == self.EXPECTED_TIME_COLUMNS
|
||||
243
apps/etl/connectors/feiqiu/tests/unit/test_dws_helpers.py
Normal file
243
apps/etl/connectors/feiqiu/tests/unit/test_dws_helpers.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
dws_helpers 函数等价性 — 属性测试
|
||||
|
||||
Feature: etl-dws-flow-refactor, Property 3: dws_helpers 函数等价性
|
||||
测试位置:apps/etl/connectors/feiqiu/tests/unit/
|
||||
|
||||
使用 hypothesis 验证 dws_helpers 模块中各辅助函数的正确性属性。
|
||||
**Validates: Requirements 2.3**
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume
|
||||
import hypothesis.strategies as st
|
||||
|
||||
# ── 将 ETL 模块加入 sys.path ──
|
||||
_ETL_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_ETL_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_ETL_ROOT))
|
||||
|
||||
from tasks.dws.dws_helpers import (
|
||||
mask_mobile,
|
||||
calc_days_since,
|
||||
parse_id_list,
|
||||
safe_division,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
# Strategies
# ============================================================================

# 11-digit numeric strings (standard Chinese mobile-number length).
st_mobile_11 = st.from_regex(r"[0-9]{11}", fullmatch=True)

# Strings of 7-15 digits (mask_mobile performs masking on these).
st_mobile_long = st.from_regex(r"[0-9]{7,15}", fullmatch=True)

# Strings shorter than 7 characters (mask_mobile returns them unchanged).
# NOTE(review): category "Nd" also matches non-ASCII decimal digits —
# presumably fine since short inputs are returned as-is; confirm if
# mask_mobile ever validates digit characters.
st_mobile_short = st.text(
    alphabet=st.characters(whitelist_categories=("Nd",)),
    min_size=0,
    max_size=6,
)

# Date strategy.
st_date = st.dates(min_value=date(2000, 1, 1), max_value=date(2035, 12, 31))

# datetime strategy (exercises calc_days_since's datetime -> date auto-conversion).
st_datetime = st.datetimes(
    min_value=datetime(2000, 1, 1),
    max_value=datetime(2035, 12, 31),
)

# Comma-separated strings of non-negative integers.
st_comma_int_str = st.lists(
    st.integers(min_value=0, max_value=999999),
    min_size=0,
    max_size=20,
).map(lambda nums: ",".join(str(n) for n in nums))

# Numeric strategies for safe_division.
st_numerator = st.decimals(
    min_value=Decimal("-1000000"),
    max_value=Decimal("1000000"),
    allow_nan=False,
    allow_infinity=False,
)
# Non-zero denominators: draw a positive magnitude, then pick its sign.
st_denominator_nonzero = st.decimals(
    min_value=Decimal("0.001"),
    max_value=Decimal("1000000"),
    allow_nan=False,
    allow_infinity=False,
).flatmap(
    lambda d: st.sampled_from([d, -d])
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 3.1: mask_mobile — 11 位字符串脱敏
|
||||
# Feature: etl-dws-flow-refactor, Property 3: dws_helpers 函数等价性
|
||||
# ============================================================================
|
||||
|
||||
class TestMaskMobileProperty:
    """Property tests for mask_mobile."""

    @given(mobile=st_mobile_11)
    @settings(max_examples=100)
    def test_11_digit_mask_structure(self, mobile: str):
        """Any 11-digit string masks to first 3 chars + '****' + last 4 chars."""
        # **Validates: Requirements 2.3**
        result = mask_mobile(mobile)
        assert result is not None
        assert len(result) == 11  # 3 + 4 + 4
        assert result[:3] == mobile[:3]
        assert result[3:7] == "****"
        assert result[7:] == mobile[7:]

    @given(mobile=st_mobile_long)
    @settings(max_examples=100)
    def test_long_mobile_preserves_head_tail(self, mobile: str):
        """Any string of >= 7 digits keeps its first 3 and last 4 characters."""
        # **Validates: Requirements 2.3**
        result = mask_mobile(mobile)
        assert result is not None
        assert result[:3] == mobile[:3]
        assert result[-4:] == mobile[-4:]
        assert "****" in result

    @given(mobile=st_mobile_short)
    @settings(max_examples=100)
    def test_short_mobile_returned_as_is(self, mobile: str):
        """Numbers shorter than 7 characters are returned unchanged."""
        # **Validates: Requirements 2.3**
        result = mask_mobile(mobile)
        assert result == mobile

    def test_none_returns_none(self):
        """None input yields None."""
        assert mask_mobile(None) is None

    def test_empty_returns_empty(self):
        """Empty string is returned as-is (len < 7)."""
        assert mask_mobile("") == ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 3.2: calc_days_since — 日期差计算
|
||||
# Feature: etl-dws-flow-refactor, Property 3: dws_helpers 函数等价性
|
||||
# ============================================================================
|
||||
|
||||
class TestCalcDaysSinceProperty:
    """Property tests for calc_days_since."""

    @given(stat_date=st_date, last_date=st_date)
    @settings(max_examples=100)
    def test_date_diff_equals_subtraction(self, stat_date: date, last_date: date):
        """For any two dates: calc_days_since(a, b) == (a - b).days."""
        # **Validates: Requirements 2.3**
        result = calc_days_since(stat_date, last_date)
        expected = (stat_date - last_date).days
        assert result == expected

    @given(stat_date=st_date, last_dt=st_datetime)
    @settings(max_examples=100)
    def test_datetime_auto_converts_to_date(self, stat_date: date, last_dt: datetime):
        """A datetime last_date is reduced via .date() before the difference is taken."""
        # **Validates: Requirements 2.3**
        result = calc_days_since(stat_date, last_dt)
        expected = (stat_date - last_dt.date()).days
        assert result == expected

    @given(stat_date=st_date)
    @settings(max_examples=100)
    def test_none_last_date_returns_none(self, stat_date: date):
        """A None last_date yields None."""
        # **Validates: Requirements 2.3**
        result = calc_days_since(stat_date, None)
        assert result is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 3.3: parse_id_list — 逗号分隔整数解析
|
||||
# Feature: etl-dws-flow-refactor, Property 3: dws_helpers 函数等价性
|
||||
# ============================================================================
|
||||
|
||||
class TestParseIdListProperty:
    """Property tests for parse_id_list."""

    @given(nums=st.lists(st.integers(min_value=0, max_value=999999), min_size=1, max_size=20))
    @settings(max_examples=100)
    def test_comma_string_roundtrip(self, nums: list[int]):
        """Joining non-negative ints with commas and parsing back yields the same int set."""
        # **Validates: Requirements 2.3**
        csv_str = ",".join(str(n) for n in nums)
        result = parse_id_list(csv_str)
        assert result == set(nums)

    @given(nums=st.lists(st.integers(min_value=0, max_value=999999), min_size=0, max_size=20))
    @settings(max_examples=100)
    def test_list_input_returns_same_set(self, nums: list[int]):
        """A list input also yields the corresponding int set."""
        # **Validates: Requirements 2.3**
        result = parse_id_list(nums)
        assert result == set(nums)

    def test_empty_string_returns_empty_set(self):
        """The empty string parses to an empty set."""
        assert parse_id_list("") == set()

    def test_none_returns_empty_set(self):
        """None parses to an empty set."""
        assert parse_id_list(None) == set()

    def test_non_digit_items_skipped(self):
        """Non-numeric items are silently skipped."""
        assert parse_id_list("1,abc,3") == {1, 3}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 3.4: safe_division — 安全除法
|
||||
# Feature: etl-dws-flow-refactor, Property 3: dws_helpers 函数等价性
|
||||
# ============================================================================
|
||||
|
||||
class TestSafeDivisionProperty:
    """Property tests for safe_division."""

    @given(numerator=st_numerator, denominator=st_denominator_nonzero)
    @settings(max_examples=100)
    def test_nonzero_denominator_returns_quotient(self, numerator: Decimal, denominator: Decimal):
        """With a non-zero denominator, safe_division returns
        Decimal(numerator) / Decimal(denominator)."""
        # **Validates: Requirements 2.3**
        result = safe_division(numerator, denominator)
        expected = Decimal(str(numerator)) / Decimal(str(denominator))
        assert result == expected

    @given(numerator=st_numerator)
    @settings(max_examples=100)
    def test_zero_denominator_returns_default(self, numerator: Decimal):
        """A zero denominator yields the caller-supplied default."""
        # **Validates: Requirements 2.3**
        default = Decimal("99.99")
        result = safe_division(numerator, 0, default=default)
        assert result == default

    def test_zero_denominator_default_is_zero(self):
        """A zero denominator with no explicit default yields Decimal('0')."""
        assert safe_division(100, 0) == Decimal("0")

    def test_none_denominator_returns_default(self):
        """A None denominator yields the default."""
        assert safe_division(100, None) == Decimal("0")

    def test_result_is_decimal(self):
        """The return value is always a Decimal."""
        result = safe_division(10, 3)
        assert isinstance(result, Decimal)
|
||||
479
apps/etl/connectors/feiqiu/tests/unit/test_dws_tasks.py
Normal file
479
apps/etl/connectors/feiqiu/tests/unit/test_dws_tasks.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-14 | bugfix: 修复 3 个测试 bug
|
||||
# prompt: "继续。完成后检查所有任务是否全面"
|
||||
# 直接原因: (1) mock_config.get 返回 None 导致 timezone 异常;(2) _build_daily_record 缺少 gift_card 参数;(3) loaded_at naive/aware 不匹配
|
||||
# 变更: mock_config.get 改用 side_effect 返回 default;补充 gift_card 参数;loaded_at 改用 aware datetime
|
||||
# 验证: pytest tests/unit -x(449 passed)
|
||||
"""
|
||||
DWS任务单元测试
|
||||
|
||||
测试内容:
|
||||
- BaseDwsTask基类方法
|
||||
- 时间计算方法
|
||||
- 配置应用方法
|
||||
- 排名计算方法
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from tasks.dws.base_dws_task import (
|
||||
BaseDwsTask,
|
||||
TimeLayer,
|
||||
TimeWindow,
|
||||
CourseType,
|
||||
TimeRange,
|
||||
ConfigCache
|
||||
)
|
||||
from tasks.dws.finance_daily_task import FinanceDailyTask
|
||||
from tasks.dws.assistant_monthly_task import AssistantMonthlyTask
|
||||
|
||||
|
||||
class TestTimeLayerRange:
    """Tests for time-layer range computation (get_time_layer_range)."""

    def _range_for(self, layer, base_date):
        # Shared helper: build a stub task and compute the layer range.
        return create_mock_task().get_time_layer_range(layer, base_date)

    def test_last_2_days(self):
        """LAST_2_DAYS covers yesterday through the base date."""
        rng = self._range_for(TimeLayer.LAST_2_DAYS, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2026, 1, 31), date(2026, 2, 1))

    def test_last_1_month(self):
        """LAST_1_MONTH reaches back one calendar month from the base date."""
        rng = self._range_for(TimeLayer.LAST_1_MONTH, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2026, 1, 2), date(2026, 2, 1))

    def test_last_3_months(self):
        """LAST_3_MONTHS reaches back three calendar months from the base date."""
        rng = self._range_for(TimeLayer.LAST_3_MONTHS, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2025, 11, 3), date(2026, 2, 1))
|
||||
|
||||
|
||||
class TestTimeWindowRange:
    """Tests for calendar-window range computation (get_time_window_range)."""

    def _window(self, window, base_date):
        # Shared helper: build a stub task and compute the window range.
        return create_mock_task().get_time_window_range(window, base_date)

    def test_this_week_monday_start(self):
        """THIS_WEEK starts on Monday; 2026-02-01 is a Sunday."""
        rng = self._window(TimeWindow.THIS_WEEK, date(2026, 2, 1))
        # Monday of that week is 2026-01-26.
        assert (rng.start, rng.end) == (date(2026, 1, 26), date(2026, 2, 1))

    def test_last_week(self):
        """LAST_WEEK is the previous Monday..Sunday span."""
        rng = self._window(TimeWindow.LAST_WEEK, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2026, 1, 19), date(2026, 1, 25))

    def test_this_month(self):
        """THIS_MONTH runs from the 1st through the base date."""
        rng = self._window(TimeWindow.THIS_MONTH, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2026, 2, 1), date(2026, 2, 15))

    def test_last_month(self):
        """LAST_MONTH is the full previous calendar month."""
        rng = self._window(TimeWindow.LAST_MONTH, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2026, 1, 1), date(2026, 1, 31))

    def test_last_3_months_excl_current(self):
        """Three full months before, and excluding, the current month."""
        rng = self._window(TimeWindow.LAST_3_MONTHS_EXCL_CURRENT, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2025, 11, 1), date(2026, 1, 31))

    def test_last_3_months_incl_current(self):
        """Three months ending at the base date (current month included)."""
        rng = self._window(TimeWindow.LAST_3_MONTHS_INCL_CURRENT, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2025, 12, 1), date(2026, 2, 15))

    def test_this_quarter(self):
        """THIS_QUARTER runs from the quarter start through the base date."""
        rng = self._window(TimeWindow.THIS_QUARTER, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2026, 1, 1), date(2026, 2, 15))

    def test_last_6_months(self):
        """Six months ending at the last day of the previous month."""
        rng = self._window(TimeWindow.LAST_6_MONTHS, date(2026, 2, 15))
        assert rng.end == date(2026, 1, 31)
        assert rng.start == date(2025, 8, 1)
|
||||
|
||||
|
||||
class TestComparisonRange:
    """Tests for period-over-period comparison range computation."""

    def test_comparison_7_days(self):
        """The comparison window for Feb 1-7 is the immediately preceding Jan 25-31."""
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 2, 7))
        previous = create_mock_task().get_comparison_range(current)
        assert (previous.start, previous.end) == (date(2026, 1, 25), date(2026, 1, 31))

    def test_comparison_30_days(self):
        """The comparison window always has the same length as the current window."""
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 3, 2))
        previous = create_mock_task().get_comparison_range(current)
        assert (previous.end - previous.start).days == (current.end - current.start).days
|
||||
|
||||
|
||||
class TestFinanceDailyRecord:
    """Tests for the daily finance record builder (_build_daily_record)."""

    def test_groupbuy_and_cashflow(self):
        """Group-buy discount split and cash-flow aggregates are computed correctly."""
        task = create_finance_daily_task()
        stat_date = date(2026, 2, 1)

        zero = Decimal('0')
        settle = {
            'gross_amount': Decimal('1000'),
            'table_fee_amount': Decimal('1000'),
            'goods_amount': zero,
            'assistant_pd_amount': zero,
            'assistant_cx_amount': zero,
            'cash_pay_amount': Decimal('300'),
            'card_pay_amount': zero,
            'balance_pay_amount': zero,
            'gift_card_pay_amount': zero,
            'coupon_amount': Decimal('200'),
            'pl_coupon_sale_amount': zero,
            'adjust_amount': Decimal('50'),
            'member_discount_amount': Decimal('10'),
            'rounding_amount': zero,
            'order_count': 1,
            'member_order_count': 1,
            'guest_order_count': 0,
        }

        record = task._build_daily_record(
            stat_date,
            settle,
            {'groupbuy_pay_total': Decimal('80')},
            {'recharge_cash': Decimal('20')},
            {'gift_card_consume': zero},
            {'expense_amount': Decimal('40')},
            {
                'settlement_amount': Decimal('60'),
                'commission_amount': Decimal('5'),
                'service_fee': Decimal('5'),
            },
            {'big_customer_amount': Decimal('20')},
            1,
        )

        # Discount split: group-buy related vs everything else.
        assert record['discount_groupbuy'] == Decimal('120')
        assert record['discount_other'] == Decimal('30')
        # Platform settlement and combined fees (commission + service fee).
        assert record['platform_settlement_amount'] == Decimal('60')
        assert record['platform_fee_amount'] == Decimal('10')
        # Cash-flow view: inflow, outflow, and the resulting net change.
        assert record['cash_inflow_total'] == Decimal('380')
        assert record['cash_outflow_total'] == Decimal('50')
        assert record['cash_balance_change'] == Decimal('330')
|
||||
|
||||
|
||||
class TestNewHireTier:
    """Tests for new-hire tier assignment rules."""

    def test_new_hire_tier_hours(self):
        """Daily-average effective hours are scaled as average * 30."""
        task = create_assistant_monthly_task()
        # 15 hours over 5 work days -> 3/day -> 90 for the month.
        assert task._calc_new_hire_tier_hours(Decimal('15'), 5) == Decimal('90')

    def test_max_tier_level_cap(self):
        """A new hire's tier is capped at max_tier_level even if hours exceed it."""
        task = create_mock_task()

        def tier_row(level, lo, hi):
            # One non-new-hire tier row valid over the whole test horizon.
            return {
                'tier_id': level,
                'tier_level': level,
                'min_hours': lo,
                'max_hours': hi,
                'is_new_hire_tier': False,
                'effective_from': date(2020, 1, 1),
                'effective_to': date(2099, 1, 1),
            }

        task._config_cache = ConfigCache(
            performance_tiers=[
                tier_row(1, 0, 100),
                tier_row(2, 100, 200),
                tier_row(3, 200, 300),
                tier_row(4, 300, None),
            ],
            level_prices=[],
            bonus_rules=[],
            area_categories={},
            skill_types={},
            loaded_at=datetime.now(tz=task.tz),
        )

        # 350 hours would normally land in tier 4; the cap pins it at 3.
        result = task.get_performance_tier(
            Decimal('350'),
            is_new_hire=True,
            effective_date=date(2026, 2, 1),
            max_tier_level=3,
        )
        assert result['tier_level'] == 3
|
||||
|
||||
|
||||
class TestNewHireCheck:
    """Tests for the new-hire-in-month predicate (is_new_hire_in_month).

    Note: comparisons use ``is True`` / ``is False`` instead of ``== True`` /
    ``== False`` (PEP 8 / flake8 E712) — this also asserts the method returns
    a real bool rather than merely a truthy/falsy value.
    """

    def test_new_hire_in_month(self):
        """Hired during the statistics month -> new hire."""
        task = create_mock_task()
        assert task.is_new_hire_in_month(date(2026, 2, 5), date(2026, 2, 1)) is True

    def test_not_new_hire(self):
        """Hired before the statistics month -> not a new hire."""
        task = create_mock_task()
        assert task.is_new_hire_in_month(date(2026, 1, 15), date(2026, 2, 1)) is False

    def test_hire_on_first_day(self):
        """Hired on the 1st of the month counts as a new hire (inclusive boundary)."""
        task = create_mock_task()
        assert task.is_new_hire_in_month(date(2026, 2, 1), date(2026, 2, 1)) is True
|
||||
|
||||
|
||||
class TestRankWithTies:
    """Tests for tie-aware ranking (calculate_rank_with_ties)."""

    def test_no_ties(self):
        """Distinct values get consecutive ranks 1, 2, 3."""
        ranked = create_mock_task().calculate_rank_with_ties(
            [(1, Decimal('100')), (2, Decimal('90')), (3, Decimal('80'))]
        )
        assert ranked[0] == (1, 1, 1)
        assert ranked[1] == (2, 2, 2)
        assert ranked[2] == (3, 3, 3)

    def test_with_ties(self):
        """Equal values share a rank and the following rank is skipped (1, 1, 3)."""
        ranked = create_mock_task().calculate_rank_with_ties(
            [(1, Decimal('100')), (2, Decimal('100')), (3, Decimal('80'))]
        )
        assert [row[1] for row in ranked] == [1, 1, 3]

    def test_all_ties(self):
        """When every value is equal, everyone shares rank 1."""
        ranked = create_mock_task().calculate_rank_with_ties(
            [(member_id, Decimal('100')) for member_id in (1, 2, 3)]
        )
        assert all(row[1] == 1 for row in ranked)
|
||||
|
||||
|
||||
class TestGuestCheck:
    """Tests for the walk-in guest predicate (is_guest).

    Note: comparisons use ``is True`` / ``is False`` instead of ``== True`` /
    ``== False`` (PEP 8 / flake8 E712) — this also pins the return type to bool.
    """

    def test_guest_zero(self):
        """member_id == 0 marks a walk-in guest."""
        assert create_mock_task().is_guest(0) is True

    def test_guest_none(self):
        """member_id is None also marks a walk-in guest."""
        assert create_mock_task().is_guest(None) is True

    def test_not_guest(self):
        """A real member id is not a guest."""
        assert create_mock_task().is_guest(12345) is False
|
||||
|
||||
|
||||
class TestUtilityMethods:
    """Tests for the small conversion helpers on the DWS base task."""

    def test_safe_decimal(self):
        """safe_decimal converts numbers and numeric strings, else falls back to 0."""
        task = create_mock_task()
        cases = [
            (100, Decimal('100')),
            ('123.45', Decimal('123.45')),
            (None, Decimal('0')),
            ('invalid', Decimal('0')),
        ]
        for raw, expected in cases:
            assert task.safe_decimal(raw) == expected

    def test_safe_int(self):
        """safe_int converts numbers and numeric strings, else falls back to 0."""
        task = create_mock_task()
        for raw, expected in [(100, 100), ('123', 123), (None, 0), ('invalid', 0)]:
            assert task.safe_int(raw) == expected

    def test_seconds_to_hours(self):
        """Seconds are converted to (possibly fractional) hours."""
        task = create_mock_task()
        assert task.seconds_to_hours(3600) == Decimal('1')
        assert task.seconds_to_hours(5400) == Decimal('1.5')
        assert task.seconds_to_hours(0) == Decimal('0')

    def test_hours_to_seconds(self):
        """Hours are converted back to whole seconds."""
        task = create_mock_task()
        assert task.hours_to_seconds(Decimal('1')) == 3600
        assert task.hours_to_seconds(Decimal('1.5')) == 5400
|
||||
|
||||
|
||||
class TestCourseType:
    """Tests pinning the CourseType enum's string values."""

    def test_base_course(self):
        """The base-course member serializes as 'BASE'."""
        assert CourseType.BASE.value == 'BASE'

    def test_bonus_course(self):
        """The bonus-course member serializes as 'BONUS'."""
        assert CourseType.BONUS.value == 'BONUS'
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
|
||||
def create_mock_task():
    """Build a concrete BaseDwsTask stub wired to MagicMock dependencies.

    The config mock's ``get`` echoes the caller-supplied default, so lookups
    such as the timezone fall back to their built-in defaults.
    """

    class _StubDwsTask(BaseDwsTask):
        # Minimal concrete subclass: just enough to instantiate the abstract base.
        def get_task_code(self):
            return "TEST_DWS_TASK"

        def get_target_table(self):
            return "test_table"

        def get_primary_keys(self):
            return ["id"]

        def extract(self, context):
            return {}

        def load(self, transformed, context):
            return {}

    config = MagicMock()
    config.get.side_effect = lambda key, default=None: default

    # db / api / logger are plain mocks — nothing is touched in these tests.
    return _StubDwsTask(config, MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
|
||||
def create_finance_daily_task():
    """Build a FinanceDailyTask with mocked dependencies.

    The config mock answers ``app.tenant_id`` with 1 and echoes the supplied
    default for every other key.
    """
    config = MagicMock()

    def _fake_get(key, default=None):
        return 1 if key == "app.tenant_id" else default

    config.get.side_effect = _fake_get
    return FinanceDailyTask(config, MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
|
||||
def create_assistant_monthly_task():
    """Build an AssistantMonthlyTask with mocked dependencies.

    The config mock's ``get`` simply echoes the caller-supplied default.
    """
    config = MagicMock()
    config.get.side_effect = lambda key, default=None: default
    return AssistantMonthlyTask(config, MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
|
||||
# Allow running this test module directly for a quick verbose pytest pass.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
222
apps/etl/connectors/feiqiu/tests/unit/test_e2e_flow.py
Normal file
222
apps/etl/connectors/feiqiu/tests/unit/test_e2e_flow.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""端到端流程集成测试
|
||||
|
||||
验证 CLI → FlowRunner → TaskExecutor 完整调用链。
|
||||
使用 mock 依赖,不需要真实数据库。
|
||||
|
||||
需求: 9.4
|
||||
"""
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
import pytest
|
||||
|
||||
from orchestration.task_executor import TaskExecutor, DataSource
|
||||
from orchestration.flow_runner import FlowRunner
|
||||
from orchestration.task_registry import TaskRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:构造最小可用的 mock config
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_config(**overrides):
|
||||
"""构造一个行为类似 AppConfig 的 MagicMock。"""
|
||||
store = {
|
||||
"app.timezone": "Asia/Shanghai",
|
||||
"app.store_id": 1,
|
||||
"io.fetch_root": "/tmp/fetch",
|
||||
"io.ingest_source_dir": "",
|
||||
"io.write_pretty_json": False,
|
||||
"io.export_root": "/tmp/export",
|
||||
"io.log_root": "/tmp/logs",
|
||||
"pipeline.fetch_root": None,
|
||||
"pipeline.ingest_source_dir": None,
|
||||
"run.ods_tasks": [],
|
||||
"run.dws_tasks": [],
|
||||
"run.index_tasks": [],
|
||||
"run.data_source": "hybrid",
|
||||
"verification.ods_use_local_json": False,
|
||||
"verification.skip_ods_when_fetch_before_verify": True,
|
||||
}
|
||||
store.update(overrides)
|
||||
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda k, d=None: store.get(k, d))
|
||||
config.__getitem__ = MagicMock(side_effect=lambda k: {
|
||||
"io": {"export_root": "/tmp/export", "log_root": "/tmp/logs"},
|
||||
}[k])
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:构造一个可被 TaskRegistry 注册的假任务类
|
||||
# ---------------------------------------------------------------------------
|
||||
class _FakeTask:
|
||||
"""最小假任务,execute() 返回固定结果。"""
|
||||
def __init__(self, config, db_ops, api_client, logger):
|
||||
pass
|
||||
|
||||
def execute(self, cursor_data):
|
||||
return {"status": "SUCCESS", "counts": {"fetched": 5, "inserted": 3}}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 测试 1:传统模式 — TaskExecutor.run_tasks 端到端
|
||||
# ===========================================================================
|
||||
class TestTraditionalModeE2E:
    """Traditional mode: end-to-end through TaskExecutor.run_tasks."""

    def test_run_tasks_executes_utility_task_and_returns_results(self):
        """A utility task goes through _run_utility_task, skipping cursor/run tracking."""
        registry = TaskRegistry()
        registry.register(
            "FAKE_UTIL", _FakeTask,
            requires_db_config=False, task_type="utility",
        )

        cursor_mgr = MagicMock()
        run_tracker = MagicMock()
        executor = TaskExecutor(
            config=_make_config(),
            db_ops=MagicMock(),
            api_client=MagicMock(),
            cursor_mgr=cursor_mgr,
            run_tracker=run_tracker,
            task_registry=registry,
            logger=MagicMock(),
        )

        results = executor.run_tasks(["FAKE_UTIL"], data_source="hybrid")

        assert len(results) == 1
        # run_tasks wraps a successful utility task as "成功".
        assert results[0]["status"] in ("成功", "完成", "SUCCESS")
        # Utility tasks must never touch cursors or run records.
        cursor_mgr.get_or_create.assert_not_called()
        run_tracker.create_run.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 测试 2:Flow 模式 — FlowRunner → TaskExecutor 端到端
|
||||
# ===========================================================================
|
||||
class TestFlowModeE2E:
    """Flow mode: FlowRunner.run delegates to TaskExecutor.run_tasks end to end."""

    def _make_runner(self, executor, registry):
        # Shared wiring: a FlowRunner whose infrastructure is all mocked.
        return FlowRunner(
            config=_make_config(),
            task_executor=executor,
            task_registry=registry,
            db_conn=MagicMock(),
            api_client=MagicMock(),
            logger=MagicMock(),
        )

    def test_flow_delegates_to_executor_and_returns_structure(self):
        """After resolving layers and tasks, FlowRunner hands execution to TaskExecutor."""
        executor = MagicMock()
        executor.run_tasks.return_value = [
            {"task_code": "FAKE_ODS", "status": "成功", "counts": {"fetched": 10, "inserted": 8}},
        ]
        registry = TaskRegistry()
        registry.register("FAKE_ODS", _FakeTask, layer="ODS")

        result = self._make_runner(executor, registry).run(
            pipeline="api_ods",
            processing_mode="increment_only",
            data_source="hybrid",
        )

        # Result structure.
        assert result["status"] == "SUCCESS"
        assert result["pipeline"] == "api_ods"
        assert result["layers"] == ["ODS"]
        assert isinstance(result["results"], list)
        # Delegation happened exactly once, with the requested data source.
        executor.run_tasks.assert_called_once()
        assert executor.run_tasks.call_args[1]["data_source"] == "hybrid"

    def test_flow_verify_only_skips_increment(self):
        """verify_only runs verification only and never triggers incremental ETL."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = self._make_runner(executor, TaskRegistry())

        # The verification framework may be absent; stub _run_verification out.
        with patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}):
            result = runner.run(
                pipeline="api_ods",
                processing_mode="verify_only",
                data_source="hybrid",
            )

        assert result["status"] == "SUCCESS"
        # verify_only with fetch_before_verify disabled never calls run_tasks.
        executor.run_tasks.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 测试 3:ETLScheduler 薄包装层委托验证
|
||||
# ===========================================================================
|
||||
class TestSchedulerThinWrapper:
    """ETLScheduler's thin wrapper correctly delegates to TaskExecutor / FlowRunner."""

    def test_scheduler_delegates_run_tasks(self):
        """run_tasks() delegates to the internal task_executor."""
        from orchestration.scheduler import ETLScheduler

        # Config sections the scheduler reads while wiring its resources.
        mock_config = MagicMock()
        mock_config.__getitem__ = MagicMock(side_effect=lambda k: {
            "db": {
                "dsn": "postgresql://fake:5432/test",
                "session": {"timezone": "Asia/Shanghai"},
                "connect_timeout_sec": 5,
            },
            "api": {
                "base_url": "https://fake.api",
                "token": "fake-token",
                "timeout_sec": 30,
                "retries": {"max_attempts": 3},
            },
        }[k])
        mock_config.get = MagicMock(side_effect=lambda k, d=None: {
            "run.data_source": "hybrid",
            "run.tasks": ["FAKE"],
            "app.timezone": "Asia/Shanghai",
        }.get(k, d))

        # Patch resource construction so no real DB/API connections are made.
        with patch("orchestration.scheduler.DatabaseConnection"), \
             patch("orchestration.scheduler.DatabaseOperations"), \
             patch("orchestration.scheduler.APIClient"), \
             patch("orchestration.scheduler.CursorManager"), \
             patch("orchestration.scheduler.RunTracker"), \
             patch("orchestration.scheduler.TaskExecutor") as MockTE, \
             patch("orchestration.scheduler.FlowRunner") as MockPR:

            import warnings
            # Constructing ETLScheduler emits a DeprecationWarning; silence it
            # so the test focuses on delegation behavior.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                scheduler = ETLScheduler(mock_config, MagicMock())

            # run_tasks delegates to the wrapped TaskExecutor.
            scheduler.run_tasks(["TEST_TASK"])
            scheduler.task_executor.run_tasks.assert_called_once()

            # run_flow_with_verification delegates to the wrapped FlowRunner.
            scheduler.run_flow_with_verification(pipeline="api_ods")
            scheduler.flow_runner.run.assert_called_once()
|
||||
@@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for recent/former endpoint routing."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from api.endpoint_routing import plan_calls, recent_boundary
|
||||
|
||||
|
||||
TZ = ZoneInfo("Asia/Shanghai")
|
||||
|
||||
|
||||
def _now():
    """Fixed 'current time' for deterministic routing tests: 2025-12-18 10:00 (+08:00)."""
    return datetime(2025, 12, 18, 10, 0, 0, tzinfo=TZ)
|
||||
|
||||
|
||||
def test_recent_boundary_month_start():
    """The 'recent' boundary for 2025-12-18 is the start of September 2025."""
    boundary = recent_boundary(_now())
    assert boundary.isoformat() == "2025-09-01T00:00:00+08:00"
|
||||
|
||||
|
||||
def test_paylog_routes_to_former_when_old_window():
    """A pay-log window entirely before the boundary is routed to the Former endpoint."""
    params = {"siteId": 1, "StartPayTime": "2025-08-01 00:00:00", "EndPayTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/PayLog/GetPayLogListPage", params, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/PayLog/GetFormerPayLogListPage"]
|
||||
|
||||
|
||||
def test_coupon_usage_stays_same_path_even_when_old():
    """Coupon-usage queries keep their path even for a pre-boundary window."""
    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/Promotion/GetOfflineCouponConsumePageList", params, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/Promotion/GetOfflineCouponConsumePageList"]
|
||||
|
||||
|
||||
def test_goods_outbound_routes_to_queryformer_when_old():
    """An old goods-outbound window is rerouted to the QueryFormer endpoint."""
    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/GoodsStockManage/QueryGoodsOutboundReceipt", params, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/GoodsStockManage/QueryFormerGoodsOutboundReceipt"]
|
||||
|
||||
|
||||
def test_settlement_records_split_when_crossing_boundary():
    """A window straddling the boundary is split: one Former call plus one recent call."""
    params = {"siteId": 1, "rangeStartTime": "2025-08-15 00:00:00", "rangeEndTime": "2025-09-10 00:00:00"}
    planned = plan_calls("/Site/GetAllOrderSettleList", params, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/Site/GetFormerOrderSettleList", "/Site/GetAllOrderSettleList"]
    # The two sub-windows meet exactly at the boundary instant.
    assert planned[0].params["rangeEndTime"] == "2025-09-01 00:00:00"
    assert planned[1].params["rangeStartTime"] == "2025-09-01 00:00:00"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "endpoint",
    [
        "/PayLog/GetFormerPayLogListPage",
        "/Site/GetFormerOrderSettleList",
        "/GoodsStockManage/QueryFormerGoodsOutboundReceipt",
    ],
)
def test_explicit_former_endpoint_not_rerouted(endpoint):
    """An endpoint already pointing at a Former path is left untouched."""
    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls(endpoint, params, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == [endpoint]
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""filter_verify_tables 单元测试"""
|
||||
|
||||
import pytest
|
||||
from tasks.verification.models import filter_verify_tables
|
||||
|
||||
|
||||
class TestFilterVerifyTables:
    """Layer-based filtering of verification table names."""

    def test_none_input_returns_none(self):
        """None passes through as None (no filter requested)."""
        assert filter_verify_tables("DWD", None) is None

    def test_empty_list_returns_none(self):
        """An empty list is treated the same as no filter."""
        assert filter_verify_tables("DWD", []) is None

    def test_dwd_layer_filters_correctly(self):
        """DWD keeps dwd_/dim_/fact_ tables only."""
        candidates = ["dwd_order", "dim_member", "fact_payment", "ods_raw", "dws_daily"]
        assert filter_verify_tables("DWD", candidates) == ["dwd_order", "dim_member", "fact_payment"]

    def test_dws_layer_filters_correctly(self):
        """DWS keeps dws_ tables only."""
        candidates = ["dws_daily", "dwd_order", "dws_summary"]
        assert filter_verify_tables("DWS", candidates) == ["dws_daily", "dws_summary"]

    def test_index_layer_filters_correctly(self):
        """INDEX keeps v_* views and index tables."""
        candidates = ["v_score", "wbi_index", "dws_daily", "v_rank"]
        assert filter_verify_tables("INDEX", candidates) == ["v_score", "wbi_index", "v_rank"]

    def test_ods_layer_filters_correctly(self):
        """ODS keeps ods_ tables only."""
        candidates = ["ods_order", "dwd_order", "ods_member"]
        assert filter_verify_tables("ODS", candidates) == ["ods_order", "ods_member"]

    def test_unknown_layer_returns_normalized(self):
        """An unrecognized layer returns all names, lower-cased and trimmed."""
        assert filter_verify_tables("UNKNOWN", [" SomeTable ", "Another"]) == ["sometable", "another"]

    def test_layer_case_insensitive(self):
        """Layer matching ignores the case of the layer string."""
        candidates = ["dwd_order", "ods_raw"]
        for layer in ("dwd", "Dwd"):
            assert filter_verify_tables(layer, candidates) == ["dwd_order"]

    def test_whitespace_and_empty_entries_stripped(self):
        """Blank / None entries are dropped; surrounding whitespace is trimmed."""
        candidates = [" dwd_order ", "", " ", None, "dim_member"]
        assert filter_verify_tables("DWD", candidates) == ["dwd_order", "dim_member"]
|
||||
@@ -0,0 +1,903 @@
|
||||
"""审计一览表生成脚本 — 解析模块单元测试
|
||||
|
||||
覆盖:AuditEntry、parse_audit_file、classify_module、scan_audit_dir
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.gen_audit_dashboard import (
|
||||
AuditEntry,
|
||||
MODULE_MAP,
|
||||
VALID_MODULES,
|
||||
classify_module,
|
||||
parse_audit_file,
|
||||
scan_audit_dir,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify_module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClassifyModule:
    """classify_module should map file paths to the correct functional module."""

    @pytest.mark.parametrize(
        "path, expected",
        [
            ("api/recording_client.py", "API 层"),
            ("tasks/ods/ods_task.py", "ODS 层"),
            ("tasks/dwd/dwd_load_task.py", "DWD 层"),
            ("tasks/dws/base_dws_task.py", "DWS 层"),
            ("tasks/index/wbi.py", "指数算法"),
            ("loaders/fact_loader.py", "数据装载"),
            ("database/migrations/001.sql", "数据库"),
            ("orchestration/task_registry.py", "调度"),
            ("config/defaults.py", "配置"),
            ("cli/main.py", "CLI"),
            ("models/parser.py", "模型"),
            ("scd/scd2.py", "SCD2"),
            ("docs/README.md", "文档"),
            ("scripts/gen_audit_dashboard.py", "脚本工具"),
            ("tests/unit/test_foo.py", "测试"),
            ("quality/checker.py", "质量校验"),
            ("gui/main.py", "GUI"),
            ("utils/logging_utils.py", "工具库"),
        ],
    )
    def test_known_prefixes(self, path, expected):
        """Each known path prefix maps to its designated module label."""
        assert classify_module(path) == expected

    def test_unknown_path_returns_other(self):
        """Paths outside every known prefix fall back to the catch-all module."""
        for path in ("README.md", ".kiro/steering/foo.md"):
            assert classify_module(path) == "其他"

    def test_normalizes_backslash(self):
        """Windows-style backslash paths classify correctly too."""
        assert classify_module("tasks\\dws\\base.py") == "DWS 层"

    def test_strips_leading_dot_slash(self):
        """A leading './' prefix does not affect classification."""
        assert classify_module("./api/foo.py") == "API 层"

    def test_result_always_in_valid_modules(self):
        """Any input's return value must be a member of VALID_MODULES."""
        for path in ("", "x", "api/", "unknown/deep/path.py"):
            assert classify_module(path) in VALID_MODULES

    def test_longest_prefix_wins(self):
        """tasks/ods must match the ODS module rather than any generic tasks/ prefix."""
        assert classify_module("tasks/ods/foo.py") == "ODS 层"
        assert classify_module("tasks/dwd/bar.py") == "DWD 层"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_audit_file — 使用临时文件
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Template for a well-formed audit file (standard table-style file list).
_STANDARD_AUDIT = textwrap.dedent("""\
# 审计记录:测试变更

- 日期:2026-03-01(Asia/Shanghai)

## 直接原因
测试用例

## 修改文件清单

| 文件 | 变更类型 | 说明 |
|------|----------|------|
| `api/client.py` | 修改 | 测试 |
| `tasks/dws/foo.py` | 新增 | 测试 |

## 风险与回滚
- 风险:极低。纯测试变更
""")
|
||||
|
||||
|
||||
class TestParseAuditFile:
    """parse_audit_file: happy path plus format variants."""

    def test_standard_file(self, tmp_path):
        """A standard audit file parses into a fully-populated AuditEntry."""
        f = tmp_path / "2026-03-01__test-change.md"
        f.write_text(_STANDARD_AUDIT, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.date == "2026-03-01"
        assert entry.slug == "test-change"
        assert entry.title == "审计记录:测试变更"
        assert entry.filename == "2026-03-01__test-change.md"
        assert "api/client.py" in entry.changed_files
        assert "tasks/dws/foo.py" in entry.changed_files
        assert "API 层" in entry.modules
        assert "DWS 层" in entry.modules
        assert entry.risk_level == "极低"

    def test_missing_title_uses_slug(self, tmp_path):
        """When the H1 title is missing, the slug is used as the title."""
        content = "没有标题的文件\n\n一些内容\n"
        f = tmp_path / "2026-01-01__no-title.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "no-title"

    def test_missing_file_list_section(self, tmp_path):
        """Missing file-list section → empty list, modules become {"其他"}."""
        content = "# 标题\n\n没有文件清单\n"
        f = tmp_path / "2026-01-01__no-files.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.changed_files == []
        assert entry.modules == {"其他"}

    def test_invalid_filename_returns_none(self, tmp_path):
        """A filename not matching YYYY-MM-DD__slug.md returns None."""
        f = tmp_path / "invalid-name.md"
        f.write_text("# Title\n", encoding="utf-8")
        assert parse_audit_file(f) is None

    def test_non_md_gitkeep_returns_none(self, tmp_path):
        """Non-Markdown files such as .gitkeep return None."""
        f = tmp_path / ".gitkeep"
        f.write_text("", encoding="utf-8")
        assert parse_audit_file(f) is None

    def test_list_format_file_section(self, tmp_path):
        """A bullet-list file section parses correctly as well."""
        content = textwrap.dedent("""\
# 测试

## 文件清单(Files changed)
- docs/api-reference/summary/foo.md
- scripts/gen_api_docs.py
- tasks/base_task.py(补 AI_CHANGELOG)
""")
        f = tmp_path / "2026-02-01__list-format.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert len(entry.changed_files) == 3

    def test_risk_from_metadata_header(self, tmp_path):
        """The risk level can be extracted from the metadata header lines."""
        content = textwrap.dedent("""\
# 测试
- 日期:2026-01-01
- 风险等级:低(纯文档重组)

## 直接原因
测试
""")
        f = tmp_path / "2026-01-01__meta-risk.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.risk_level == "低"

    def test_change_type_bugfix(self, tmp_path):
        """Content containing 'bugfix' infers change_type == "bugfix"."""
        content = "# bugfix 修复\n\n修复了一个 bug\n"
        f = tmp_path / "2026-01-01__fix.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "bugfix"

    def test_change_type_doc(self, tmp_path):
        """Documentation-only wording infers change_type == "文档"."""
        content = "# 纯文档变更\n\n无逻辑改动\n"
        f = tmp_path / "2026-01-01__doc.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "文档"

    def test_arrow_path_extraction(self, tmp_path):
        """A '→' move line yields both the source and destination paths."""
        content = textwrap.dedent("""\
# 测试

## 变更摘要

### 文件移动
- `docs/index/algo.md` → `docs/database/DWS/algo.md`
""")
        f = tmp_path / "2026-01-01__arrow.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert "docs/index/algo.md" in entry.changed_files
        assert "docs/database/DWS/algo.md" in entry.changed_files
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_audit_dir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScanAuditDir:
    """scan_audit_dir: directory scanning, ordering, and filtering."""

    def test_empty_dir(self, tmp_path):
        """An empty directory yields an empty list."""
        assert scan_audit_dir(tmp_path) == []

    def test_nonexistent_dir(self):
        """A missing directory yields an empty list rather than raising."""
        assert scan_audit_dir("nonexistent_dir_xyz") == []

    def test_sorts_by_date_descending(self, tmp_path):
        """Entries are returned newest-first regardless of creation order."""
        for name in [
            "2026-01-01__first.md",
            "2026-03-01__third.md",
            "2026-02-01__second.md",
        ]:
            (tmp_path / name).write_text("# Title\n", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        dates = [e.date for e in entries]
        assert dates == ["2026-03-01", "2026-02-01", "2026-01-01"]

    def test_skips_non_md_files(self, tmp_path):
        """Non-Markdown files (.gitkeep, .txt) are ignored."""
        (tmp_path / "2026-01-01__valid.md").write_text("# OK\n", encoding="utf-8")
        (tmp_path / ".gitkeep").write_text("", encoding="utf-8")
        (tmp_path / "notes.txt").write_text("text", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 1

    def test_skips_invalid_filenames(self, tmp_path):
        """Files whose names do not match the date__slug pattern are skipped."""
        (tmp_path / "2026-01-01__valid.md").write_text("# OK\n", encoding="utf-8")
        (tmp_path / "bad-name.md").write_text("# Bad\n", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 1
        assert entries[0].slug == "valid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 真实审计文件集成测试(仅在项目目录中运行时有效)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REAL_AUDIT_DIR = Path("docs/audit/changes")  # only exists when running from the project root
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not _REAL_AUDIT_DIR.is_dir(),
    reason="真实审计目录不存在(非项目根目录运行)",
)
class TestRealAuditFiles:
    """Integration checks against the repo's real audit files (skipped elsewhere)."""

    def test_parses_all_real_files(self):
        """At least one real audit record should parse successfully."""
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        assert len(entries) > 0, "应至少解析出一条审计记录"

    def test_all_modules_valid(self):
        """Every module label of every real entry must be a known module."""
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        for e in entries:
            for m in e.modules:
                assert m in VALID_MODULES, (
                    f"模块 {m!r} 不在 VALID_MODULES 中 (文件: {e.filename})"
                )

    def test_dates_descending(self):
        """Real entries come back in non-increasing date order."""
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        dates = [e.date for e in entries]
        assert dates == sorted(dates, reverse=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 1: 审计记录解析-渲染完整性
|
||||
# Feature: docs-optimization, Property 1: 审计记录解析-渲染完整性
|
||||
# Validates: Requirements 2.1, 2.2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# --- 生成策略 ---
|
||||
|
||||
# Valid-date strategy: ISO YYYY-MM-DD strings in a bounded range.
# Fix: use the module-level `datetime` import (already used elsewhere in this
# file) instead of the `__import__("datetime")` hack.
_date_st = st.dates(
    min_value=datetime.date(2020, 1, 1),
    max_value=datetime.date(2030, 12, 31),
).map(lambda d: d.isoformat())
|
||||
|
||||
# slug strategy: lowercase letters, digits and hyphens, length 2-30
_slug_st = st.from_regex(r"[a-z][a-z0-9\-]{1,29}", fullmatch=True)

# title strategy: non-empty mixed Chinese/English text
_title_st = st.text(
    alphabet=st.sampled_from(
        list("审计记录变更修复新增重构清理测试文档优化迁移合并")
        + list("abcdefghijklmnopqrstuvwxyz ")
    ),
    min_size=2,
    max_size=40,
).map(lambda s: s.strip() or "默认标题")

# file-path strategy: pick from known prefixes so module classification is meaningful
_KNOWN_PREFIXES = [
    "api/", "tasks/ods/", "tasks/dwd/", "tasks/dws/", "tasks/index/",
    "loaders/", "database/migrations/", "orchestration/", "config/",
    "cli/", "models/", "scd/", "docs/", "scripts/", "tests/unit/",
    "quality/", "gui/", "utils/",
]

_file_path_st = st.tuples(
    st.sampled_from(_KNOWN_PREFIXES),
    st.from_regex(r"[a-z_]{1,15}\.(py|sql|md)", fullmatch=True),
).map(lambda t: t[0] + t[1])

# risk-level strategy
_risk_st = st.sampled_from(["极低", "低", "中", "高"])

# change-type keyword strategy (embedded in the content so _infer_change_type can detect it)
_change_kw_st = st.sampled_from(["bugfix", "修复", "重构", "清理", "纯文档", "功能新增"])
|
||||
|
||||
|
||||
def _build_audit_md(title: str, date: str, files: list[str], risk: str, change_kw: str) -> str:
|
||||
"""根据参数构造一份格式合规的审计 Markdown 内容。"""
|
||||
lines = [
|
||||
f"# {title}",
|
||||
"",
|
||||
f"- 日期:{date}(Asia/Shanghai)",
|
||||
f"- 风险等级:{risk}",
|
||||
"",
|
||||
"## 直接原因",
|
||||
f"本次变更为 {change_kw} 类型操作",
|
||||
"",
|
||||
"## 修改文件清单",
|
||||
"",
|
||||
"| 文件 | 变更类型 | 说明 |",
|
||||
"|------|----------|------|",
|
||||
]
|
||||
for fp in files:
|
||||
lines.append(f"| `{fp}` | 修改 | 自动生成 |")
|
||||
lines.append("")
|
||||
lines.append("## 风险与回滚")
|
||||
lines.append(f"- 风险:{risk}。自动生成的测试内容")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TestProperty1AuditParseCompleteness:
    """Property 1: audit-record parse/render completeness.

    For any format-conformant audit Markdown file, parse_audit_file must return
    an AuditEntry whose required fields (date, title, change_type, modules,
    filename) are all non-empty.

    **Validates: Requirements 2.1, 2.2**
    """

    @given(
        date=_date_st,
        slug=_slug_st,
        title=_title_st,
        files=st.lists(_file_path_st, min_size=1, max_size=8),
        risk=_risk_st,
        change_kw=_change_kw_st,
    )
    @settings(max_examples=150)
    def test_parsed_entry_has_all_required_fields(
        self, tmp_path_factory, date, slug, title, files, risk, change_kw
    ):
        """All required AuditEntry fields are non-empty after parsing."""
        # Materialize the generated document as a temp file.
        tmp_dir = tmp_path_factory.mktemp("audit")
        filename = f"{date}__{slug}.md"
        md_content = _build_audit_md(title, date, files, risk, change_kw)
        filepath = tmp_dir / filename
        filepath.write_text(md_content, encoding="utf-8")

        entry = parse_audit_file(filepath)

        # Core assertions: parse succeeds and all required fields are populated.
        # Fix: the failure message previously held a corrupted "(unknown)"
        # placeholder; report the actual path that failed to parse.
        assert entry is not None, f"格式合规的文件应能成功解析:{filepath}"
        assert entry.date, "date 字段不应为空"
        assert entry.title, "title 字段不应为空"
        assert entry.filename, "filename 字段不应为空"
        assert entry.change_type, "change_type 字段不应为空"
        assert len(entry.modules) > 0, "modules 集合不应为空"

        # The parsed date must match the date embedded in the filename.
        assert entry.date == date

        # filename must match the on-disk name.
        assert entry.filename == filename

        # Every module label must be a known module.
        for mod in entry.modules:
            assert mod in VALID_MODULES, f"模块 {mod!r} 不在 VALID_MODULES 中"

    @given(
        date=_date_st,
        slug=_slug_st,
        title=_title_st,
        files=st.lists(_file_path_st, min_size=1, max_size=5),
        risk=_risk_st,
        change_kw=_change_kw_st,
    )
    @settings(max_examples=150)
    def test_parsed_files_match_input(
        self, tmp_path_factory, date, slug, title, files, risk, change_kw
    ):
        """The parsed changed_files must contain every input file path."""
        tmp_dir = tmp_path_factory.mktemp("audit")
        filename = f"{date}__{slug}.md"
        md_content = _build_audit_md(title, date, files, risk, change_kw)
        filepath = tmp_dir / filename
        filepath.write_text(md_content, encoding="utf-8")

        entry = parse_audit_file(filepath)
        assert entry is not None

        # Every generated path must appear in the parse result.
        for fp in files:
            assert fp in entry.changed_files, (
                f"文件 {fp!r} 应出现在 changed_files 中,"
                f"实际结果:{entry.changed_files}"
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 2: 文件路径模块分类正确性
|
||||
# Feature: docs-optimization, Property 2: 文件路径模块分类正确性
|
||||
# Validates: Requirements 2.3
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# --- File-path generation strategies ---

# Known-prefix paths: cover every prefix in MODULE_MAP.
_known_prefix_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,20}\.(py|sql|md|txt|json)", fullmatch=True),
).map(lambda t: t[0] + t[1])

# Fully random paths: arbitrary Unicode strings (empty, special chars, etc.).
_random_path_st = st.text(min_size=0, max_size=200)

# Windows-style paths with backslashes.
_backslash_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,15}\.(py|md)", fullmatch=True),
).map(lambda t: t[0].replace("/", "\\") + t[1])

# Relative paths with a leading "./".
_dotslash_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,15}\.(py|md)", fullmatch=True),
).map(lambda t: "./" + t[0] + t[1])

# Deeply nested paths.
_deep_nested_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.lists(
        st.from_regex(r"[a-z]{1,8}", fullmatch=True),
        min_size=1, max_size=5,
    ),
    st.from_regex(r"[a-z_]{1,10}\.(py|md|sql)", fullmatch=True),
).map(lambda t: t[0] + "/".join(t[1]) + "/" + t[2])

# Unknown-prefix paths (match none of the known prefixes).
# Fix: the previous single-element st.tuples(...).map(lambda t: t[0]) wrapper
# was redundant; sample the strings directly.
_unknown_prefix_st = st.sampled_from([
    "README.md", ".kiro/foo.md", "setup.py", "Makefile",
    "pyproject.toml", ".env", "unknown/deep/path.py",
    "random_dir/file.txt", ".github/workflows/ci.yml",
])

# Mixed strategy: draw from any of the above.
_any_filepath_st = st.one_of(
    _known_prefix_path_st,
    _random_path_st,
    _backslash_path_st,
    _dotslash_path_st,
    _deep_nested_st,
    _unknown_prefix_st,
)
|
||||
|
||||
|
||||
class TestProperty2ModuleClassification:
    """Property 2: file-path module classification correctness.

    For any file-path string, the return value of classify_module must always
    belong to the predefined VALID_MODULES set.

    **Validates: Requirements 2.3**
    """

    @given(filepath=_any_filepath_st)
    @settings(max_examples=200)
    def test_classify_always_returns_valid_module(self, filepath: str):
        """Any path's classification result lies within VALID_MODULES."""
        result = classify_module(filepath)
        assert result in VALID_MODULES, (
            f"classify_module({filepath!r}) 返回 {result!r},"
            f"不在 VALID_MODULES 中"
        )

    @given(filepath=_known_prefix_path_st)
    @settings(max_examples=150)
    def test_known_prefix_never_returns_other(self, filepath: str):
        """Paths starting with a known prefix must not classify as '其他'."""
        result = classify_module(filepath)
        assert result in VALID_MODULES, (
            f"classify_module({filepath!r}) 返回 {result!r},"
            f"不在 VALID_MODULES 中"
        )
        assert result != "其他", (
            f"已知前缀路径 {filepath!r} 不应分类为 '其他',"
            f"实际返回 {result!r}"
        )

    @given(filepath=_unknown_prefix_st)
    @settings(max_examples=50)
    def test_unknown_prefix_returns_other(self, filepath: str):
        """Paths matching no known prefix must classify as '其他'."""
        result = classify_module(filepath)
        assert result == "其他", (
            f"未知前缀路径 {filepath!r} 应分类为 '其他',"
            f"实际返回 {result!r}"
        )

    @given(filepath=_backslash_path_st)
    @settings(max_examples=100)
    def test_backslash_paths_classified_correctly(self, filepath: str):
        """Windows backslash paths classify correctly (not '其他')."""
        result = classify_module(filepath)
        assert result in VALID_MODULES
        assert result != "其他", (
            f"反斜杠路径 {filepath!r} 应正确分类,"
            f"实际返回 '其他'"
        )

    @given(filepath=_dotslash_path_st)
    @settings(max_examples=100)
    def test_dotslash_paths_classified_correctly(self, filepath: str):
        """Paths with a leading './' classify correctly (not '其他')."""
        result = classify_module(filepath)
        assert result in VALID_MODULES
        assert result != "其他", (
            f"前导 ./ 路径 {filepath!r} 应正确分类,"
            f"实际返回 '其他'"
        )
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 渲染函数测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from scripts.gen_audit_dashboard import (
|
||||
render_timeline_table,
|
||||
render_module_index,
|
||||
render_dashboard,
|
||||
)
|
||||
|
||||
|
||||
def _make_entry(**overrides) -> AuditEntry:
    """Build a test AuditEntry; keyword arguments override the defaults."""
    fields = {
        "date": "2026-01-15",
        "slug": "test-change",
        "title": "测试变更",
        "filename": "2026-01-15__test-change.md",
        "changed_files": ["api/client.py"],
        "modules": {"API 层"},
        "risk_level": "低",
        "change_type": "功能",
    }
    fields.update(overrides)
    return AuditEntry(**fields)
|
||||
|
||||
|
||||
class TestRenderTimelineTable:
    """Unit tests for render_timeline_table."""

    def test_empty_entries(self):
        """An empty list renders the '暂无审计记录' placeholder."""
        result = render_timeline_table([])
        assert "暂无审计记录" in result

    def test_single_entry(self):
        """A single entry produces a correctly-formed table row."""
        entry = _make_entry()
        result = render_timeline_table([entry])
        assert "| 日期 |" in result
        assert "2026-01-15" in result
        assert "测试变更" in result
        assert "API 层" in result
        assert "低" in result
        assert "[链接](changes/2026-01-15__test-change.md)" in result

    def test_multiple_entries_order_preserved(self):
        """Multiple entries keep the input order (sorting is the caller's job)."""
        e1 = _make_entry(date="2026-02-01", title="后")
        e2 = _make_entry(date="2026-01-01", title="前")
        result = render_timeline_table([e1, e2])
        pos_e1 = result.index("2026-02-01")
        pos_e2 = result.index("2026-01-01")
        assert pos_e1 < pos_e2

    def test_multiple_modules_joined(self):
        """Multiple modules are sorted and comma-joined."""
        entry = _make_entry(modules={"文档", "API 层"})
        result = render_timeline_table([entry])
        assert "API 层, 文档" in result

    def test_table_header_present(self):
        """The table includes a header row and a separator row."""
        result = render_timeline_table([_make_entry()])
        assert "|------|" in result
        assert "需求摘要" in result
        assert "变更类型" in result
|
||||
|
||||
|
||||
class TestRenderModuleIndex:
    """Unit tests for render_module_index."""

    def test_empty_entries(self):
        """An empty list renders the '暂无审计记录' placeholder."""
        result = render_module_index([])
        assert "暂无审计记录" in result

    def test_single_module(self):
        """A single module yields one H3 section."""
        entry = _make_entry(modules={"API 层"})
        result = render_module_index([entry])
        assert "### API 层" in result
        assert "2026-01-15" in result
        # The per-module table omits the "影响模块" column.
        lines = result.strip().splitlines()
        header = [l for l in lines if l.startswith("| 日期")]
        assert header
        assert "影响模块" not in header[0]

    def test_multiple_modules_sorted(self):
        """Module sections are ordered alphabetically."""
        e1 = _make_entry(modules={"文档"})
        e2 = _make_entry(modules={"API 层"}, date="2026-02-01")
        result = render_module_index([e1, e2])
        pos_api = result.index("### API 层")
        pos_doc = result.index("### 文档")
        assert pos_api < pos_doc

    def test_entry_appears_in_multiple_modules(self):
        """An entry affecting several modules appears in each module section."""
        entry = _make_entry(modules={"API 层", "文档"})
        result = render_module_index([entry])
        assert "### API 层" in result
        assert "### 文档" in result
        # Both sections contain the entry's link.
        assert result.count("[链接](changes/2026-01-15__test-change.md)") == 2

    def test_link_format(self):
        """The detail-column link uses the expected relative format."""
        entry = _make_entry(filename="2026-03-01__my-slug.md")
        result = render_module_index([entry])
        assert "[链接](changes/2026-03-01__my-slug.md)" in result
|
||||
|
||||
|
||||
class TestRenderDashboard:
    """Unit tests for render_dashboard."""

    def test_empty_entries(self):
        """An empty list still yields a full document with the placeholder."""
        result = render_dashboard([])
        assert "# 审计一览表" in result
        assert "自动生成于" in result
        assert "暂无审计记录" in result

    def test_contains_both_views(self):
        """The full document contains both the timeline and module-index sections."""
        entry = _make_entry()
        result = render_dashboard([entry])
        assert "## 时间线视图" in result
        assert "## 模块索引" in result

    def test_contains_header_and_timestamp(self):
        """The document carries a title and a generation timestamp."""
        result = render_dashboard([_make_entry()])
        assert "# 审计一览表" in result
        assert "自动生成于" in result
        assert "请勿手动编辑" in result

    def test_timeline_before_module_index(self):
        """The timeline view comes before the module index."""
        result = render_dashboard([_make_entry()])
        pos_timeline = result.index("## 时间线视图")
        pos_module = result.index("## 模块索引")
        assert pos_timeline < pos_module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 3: 审计条目时间倒序排列
|
||||
# Feature: docs-optimization, Property 3: 审计条目时间倒序排列
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProperty3AuditEntriesDescendingOrder:
    """Property test: sorted audit entries have strictly non-increasing dates.

    **Validates: Requirements 2.4**
    """

    @given(
        dates=st.lists(
            st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ),
            min_size=0,
            max_size=30,
        )
    )
    @settings(max_examples=150)
    def test_sorted_entries_dates_non_increasing(self, dates):
        """Entries built from any date list, sorted the same way as
        scan_audit_dir, yield a non-increasing date sequence (each entry's
        date >= the next one's)."""
        # Build AuditEntry objects whose dates match scan_audit_dir's
        # format (YYYY-MM-DD strings).
        entries = [
            AuditEntry(
                date=d.isoformat(),
                slug=f"entry-{i}",
                title=f"条目 {i}",
                filename=f"{d.isoformat()}__entry-{i}.md",
            )
            for i, d in enumerate(dates)
        ]

        # Use exactly the same sorting logic as scan_audit_dir.
        entries.sort(key=lambda e: e.date, reverse=True)

        # Verify the non-increasing order.
        for i in range(len(entries) - 1):
            assert entries[i].date >= entries[i + 1].date, (
                f"位置 {i} 的日期 {entries[i].date} "
                f"不应小于位置 {i+1} 的日期 {entries[i+1].date}"
            )

    @given(
        dates=st.lists(
            st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ),
            min_size=2,
            max_size=30,
        )
    )
    @settings(max_examples=150)
    def test_first_entry_has_latest_date(self, dates):
        """After sorting, the first entry carries the maximum input date."""
        entries = [
            AuditEntry(
                date=d.isoformat(),
                slug=f"entry-{i}",
                title=f"条目 {i}",
                filename=f"{d.isoformat()}__entry-{i}.md",
            )
            for i, d in enumerate(dates)
        ]

        entries.sort(key=lambda e: e.date, reverse=True)

        expected_max = max(d.isoformat() for d in dates)
        assert entries[0].date == expected_max
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 补充边界情况测试 — Task 4.6
|
||||
# 覆盖:空内容、缺少风险章节、变更类型推断分支、同日期排序
|
||||
# Requirements: 2.1, 2.3
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseAuditFileEdgeCases:
    """Additional edge cases for parse_audit_file."""

    def test_empty_content(self, tmp_path):
        """Completely empty file → slug title, empty file list, modules {"其他"}."""
        f = tmp_path / "2026-01-01__empty.md"
        f.write_text("", encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "empty"
        assert entry.changed_files == []
        assert entry.modules == {"其他"}

    def test_whitespace_only_content(self, tmp_path):
        """Whitespace-only content behaves like an empty file."""
        f = tmp_path / "2026-01-01__blank.md"
        f.write_text(" \n\n \n", encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "blank"
        assert entry.changed_files == []

    def test_missing_risk_section_returns_unknown(self, tmp_path):
        """No risk-related content anywhere → risk level "未知"."""
        content = "# 简单标题\n\n一些普通内容,没有提到任何等级信息。\n"
        f = tmp_path / "2026-01-01__no-risk.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.risk_level == "未知"

    def test_change_type_refactor(self, tmp_path):
        """Keyword "重构" → change type "重构"."""
        content = "# 代码重构\n\n对模块进行了重构优化\n"
        f = tmp_path / "2026-01-01__refactor.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "重构"

    def test_change_type_cleanup(self, tmp_path):
        """Keyword "清理" → change type "清理"."""
        content = "# 遗留代码清理\n\n清理了废弃文件\n"
        f = tmp_path / "2026-01-01__cleanup.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "清理"

    def test_change_type_default_function(self, tmp_path):
        """No change-type keyword at all → defaults to "功能"."""
        content = "# 新增能力\n\n增加了一个全新的处理流程\n"
        f = tmp_path / "2026-01-01__feature.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "功能"

    def test_file_section_with_empty_table(self, tmp_path):
        """File-list section present but no table data rows → empty list."""
        content = textwrap.dedent("""\
# 测试

## 修改文件清单

| 文件 | 变更类型 | 说明 |
|------|----------|------|

## 风险与回滚
无
""")
        f = tmp_path / "2026-01-01__empty-table.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.changed_files == []
        assert entry.modules == {"其他"}
|
||||
|
||||
|
||||
class TestScanAuditDirEdgeCases:
    """Additional edge cases for scan_audit_dir."""

    def test_dir_with_only_invalid_files(self, tmp_path):
        """A directory containing only invalid files yields an empty list."""
        (tmp_path / "README.md").write_text("# 说明\n", encoding="utf-8")
        (tmp_path / ".gitkeep").write_text("", encoding="utf-8")
        (tmp_path / "notes.txt").write_text("备注", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        assert entries == []

    def test_same_date_multiple_files(self, tmp_path):
        """Multiple files sharing one date are all parsed with the same date."""
        for slug in ["alpha", "beta", "gamma"]:
            f = tmp_path / f"2026-03-01__{slug}.md"
            f.write_text(f"# {slug} 变更\n", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 3
        assert all(e.date == "2026-03-01" for e in entries)
|
||||
# ===== New file: apps/etl/connectors/feiqiu/tests/unit/test_layers_cli.py (295 lines) =====
|
||||
# -*- coding: utf-8 -*-
|
||||
"""--layers CLI 参数单元测试 + 属性测试
|
||||
|
||||
验证 parse_layers() 解析、--layers 与 --flow/--pipeline 互斥、
|
||||
以及 --layers 参数在 argparse 中的注册。
|
||||
|
||||
需求: 6.1, 6.2, 6.3, 6.4
|
||||
"""
|
||||
import pytest
|
||||
import string
|
||||
from unittest.mock import patch
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from cli.main import parse_layers, parse_args, VALID_LAYERS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. parse_layers() 解析
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestParseLayers:
    """Unit tests for the parse_layers() function."""

    def test_single_layer(self):
        """A single layer name parses to a one-element list."""
        assert parse_layers("ODS") == ["ODS"]

    def test_two_layers(self):
        """Two comma-separated layers preserve their order."""
        assert parse_layers("ODS,DWD") == ["ODS", "DWD"]

    def test_all_layers(self):
        """All four layers parse in the given order."""
        result = parse_layers("ODS,DWD,DWS,INDEX")
        assert result == ["ODS", "DWD", "DWS", "INDEX"]

    def test_case_insensitive(self):
        """Lowercase input is automatically upper-cased."""
        assert parse_layers("ods,dwd") == ["ODS", "DWD"]

    def test_mixed_case(self):
        """Mixed-case input normalizes to uppercase."""
        assert parse_layers("Ods,dWd,DWS") == ["ODS", "DWD", "DWS"]

    def test_whitespace_trimmed(self):
        """Whitespace around commas is stripped."""
        assert parse_layers(" ODS , DWD ") == ["ODS", "DWD"]

    def test_invalid_layer_raises(self):
        """An invalid layer name among valid ones raises ValueError."""
        with pytest.raises(ValueError, match="无效的层名"):
            parse_layers("ODS,INVALID")

    def test_all_invalid_raises(self):
        """Entirely invalid layer names raise ValueError."""
        with pytest.raises(ValueError, match="无效的层名"):
            parse_layers("FOO,BAR")

    def test_empty_string_raises(self):
        """An empty string raises ValueError."""
        with pytest.raises(ValueError, match="不能为空"):
            parse_layers("")

    def test_only_commas_raises(self):
        """A commas-only string raises ValueError."""
        with pytest.raises(ValueError, match="不能为空"):
            parse_layers(",,,")

    def test_valid_layers_constant(self):
        """VALID_LAYERS contains exactly the four legal layer names."""
        assert VALID_LAYERS == {"ODS", "DWD", "DWS", "INDEX"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. --layers argparse 注册
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestLayersArgParsing:
    """Behavior of the --layers argument at the argparse level."""

    def test_layers_parsed(self):
        """--layers is stored verbatim as a raw string."""
        with patch("sys.argv", ["cli", "--layers", "ODS,DWD"]):
            args = parse_args()
            assert args.layers == "ODS,DWD"

    def test_layers_default_is_none(self):
        """Without --layers the attribute defaults to None."""
        with patch("sys.argv", ["cli"]):
            args = parse_args()
            assert args.layers is None

    def test_pipeline_still_works(self):
        """--pipeline stays usable (req 6.3), stored as pipeline_deprecated (deprecated alias)."""
        with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
            args = parse_args()
            assert args.pipeline_deprecated == "api_full"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. --layers 与 --pipeline 互斥(需求 6.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestLayersPipelineMutualExclusion:
    """Mutual exclusion of --layers and --flow/--pipeline.

    The exclusivity check is implemented in main() (not at the argparse
    level); here we only verify that both arguments can be parsed together
    (main() enforces the conflict).
    """

    def test_both_args_can_be_parsed(self):
        """argparse accepts both --layers and --pipeline (deprecated alias); main() enforces exclusivity."""
        with patch("sys.argv", [
            "cli", "--layers", "ODS,DWD", "--pipeline", "api_full",
        ]):
            args = parse_args()
            assert args.layers == "ODS,DWD"
            assert args.pipeline_deprecated == "api_full"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. 属性测试 — Property 5: --layers 解析正确性
|
||||
# Feature: etl-dws-flow-refactor, Property 5: --layers 解析正确性
|
||||
# Validates: Requirements 6.1, 6.2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The set of legal layer names, sorted for deterministic sampling.
_ALL_LAYERS = sorted(VALID_LAYERS)  # ["DWD", "DWS", "INDEX", "ODS"]

# Strategy: non-empty subsets of {ODS, DWD, DWS, INDEX}
_layer_subsets = st.lists(
    st.sampled_from(_ALL_LAYERS),
    min_size=1,
    max_size=4,
    unique=True,
)

# Strategy: random strings that are not legal layer names
_invalid_layer = st.text(
    alphabet=string.ascii_letters,
    min_size=1,
    max_size=10,
).filter(lambda s: s.upper() not in VALID_LAYERS)
|
||||
|
||||
|
||||
class TestParseLayersProperties:
    """Property tests for parse_layers().

    # Feature: etl-dws-flow-refactor, Property 5: --layers parsing correctness
    """

    @given(subset=_layer_subsets)
    @settings(max_examples=100)
    def test_valid_subset_roundtrip(self, subset: list[str]):
        """Any legal layer subset, comma-joined, parses back to exactly that subset (uppercase).

        **Validates: Requirements 6.1, 6.2**
        """
        raw = ",".join(subset)
        result = parse_layers(raw)
        assert set(result) == set(subset)
        # Every element is uppercase.
        assert all(l == l.upper() for l in result)
        # Element count matches the input.
        assert len(result) == len(subset)

    @given(subset=_layer_subsets)
    @settings(max_examples=100)
    def test_case_insensitive_roundtrip(self, subset: list[str]):
        """Any case variant parses to the uppercase result.

        **Validates: Requirements 6.1, 6.2**
        """
        # Case variation: all-lowercase input.
        raw = ",".join(l.lower() for l in subset)
        result = parse_layers(raw)
        assert set(result) == set(subset)

    @given(subset=_layer_subsets)
    @settings(max_examples=100)
    def test_whitespace_tolerance(self, subset: list[str]):
        """Adding spaces around commas does not affect the parse result.

        **Validates: Requirements 6.1, 6.2**
        """
        raw = " , ".join(f" {l} " for l in subset)
        result = parse_layers(raw)
        assert set(result) == set(subset)

    @given(
        valid=_layer_subsets,
        bad=_invalid_layer,
    )
    @settings(max_examples=100)
    def test_invalid_layer_raises_valueerror(self, valid: list[str], bad: str):
        """Any invalid layer name in the input must raise ValueError.

        **Validates: Requirements 6.1, 6.2**
        """
        # The filter already guarantees this; assume() is a second line of defense.
        assume(bad.upper() not in VALID_LAYERS)
        raw = ",".join(valid + [bad])
        with pytest.raises(ValueError, match="无效的层名"):
            parse_layers(raw)

    @given(bad=_invalid_layer)
    @settings(max_examples=100)
    def test_all_invalid_raises_valueerror(self, bad: str):
        """A purely-invalid layer string must raise ValueError.

        **Validates: Requirements 6.1, 6.2**
        """
        assume(bad.upper() not in VALID_LAYERS)
        with pytest.raises(ValueError, match="无效的层名"):
            parse_layers(bad)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# 5. Deprecated alias tests for --flow / --pipeline (Requirements 9.3, 9.4)
# ---------------------------------------------------------------------------
class TestFlowParameter:
    """--flow is the primary parameter; --pipeline is a deprecated alias."""

    def test_flow_parsed(self):
        """--flow parses normally as the primary parameter."""
        with patch("sys.argv", ["cli", "--flow", "api_full"]):
            parsed = parse_args()
        assert parsed.flow == "api_full"

    def test_flow_default_is_none(self):
        """When --flow is omitted it defaults to None."""
        with patch("sys.argv", ["cli"]):
            parsed = parse_args()
        assert parsed.flow is None

    def test_pipeline_deprecated_parsed(self):
        """--pipeline still parses and is stored under pipeline_deprecated."""
        with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
            parsed = parse_args()
        assert parsed.pipeline_deprecated == "api_full"
        assert parsed.flow is None  # --flow was not given

    def test_pipeline_emits_deprecation_warning(self):
        """Using --pipeline must emit a DeprecationWarning (Requirement 9.4).

        Replays main()'s deprecation logic directly so no DB connection is made.
        """
        import warnings

        with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
            args = parse_args()

        # Replay the deprecation handling from main().
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            if args.pipeline_deprecated:
                warnings.warn(
                    "--pipeline 参数已弃用,请使用 --flow",
                    DeprecationWarning,
                    stacklevel=2,
                )
                if not args.flow:
                    args.flow = args.pipeline_deprecated

        deprecations = [c for c in caught if issubclass(c.category, DeprecationWarning)]
        assert len(deprecations) == 1
        assert "--pipeline 参数已弃用" in str(deprecations[0].message)
        # The value must have been merged into args.flow.
        assert args.flow == "api_full"

    def test_flow_and_pipeline_mutually_exclusive(self):
        """--flow and --pipeline must not be combined (Requirement 9.3).

        argparse accepts both; the mutual exclusion is enforced by main().
        """
        with patch("sys.argv", ["cli", "--flow", "api_full", "--pipeline", "api_ods"]):
            parsed = parse_args()
        # With both present, main() is expected to sys.exit(2).
        assert parsed.flow == "api_full"
        assert parsed.pipeline_deprecated == "api_ods"

    def test_layers_and_flow_mutually_exclusive(self):
        """--layers and --flow are exclusive (argparse parses both; main() validates)."""
        with patch("sys.argv", ["cli", "--layers", "ODS,DWD", "--flow", "api_full"]):
            parsed = parse_args()
        assert parsed.layers == "ODS,DWD"
        assert parsed.flow == "api_full"

    def test_layers_and_pipeline_deprecated_mutually_exclusive(self):
        """--layers is also exclusive with the deprecated --pipeline alias.

        --pipeline is merged into args.flow first, then the --layers vs --flow
        exclusion applies.
        """
        with patch("sys.argv", ["cli", "--layers", "ODS,DWD", "--pipeline", "api_full"]):
            parsed = parse_args()
        assert parsed.layers == "ODS,DWD"
        assert parsed.pipeline_deprecated == "api_full"

    def test_pipeline_value_merges_to_flow(self):
        """After deprecation handling, --pipeline's value lands in args.flow."""
        with patch("sys.argv", ["cli", "--pipeline", "dwd_dws"]):
            parsed = parse_args()
        # Replay main()'s merge logic.
        if parsed.pipeline_deprecated and not parsed.flow:
            parsed.flow = parsed.pipeline_deprecated
        assert parsed.flow == "dwd_dws"
|
||||
323
apps/etl/connectors/feiqiu/tests/unit/test_maintenance_task.py
Normal file
323
apps/etl/connectors/feiqiu/tests/unit/test_maintenance_task.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
DwsMaintenanceTask — 属性测试 + 单元测试
|
||||
|
||||
Feature: etl-dws-flow-refactor, Property 4: DwsMaintenanceTask 配置控制
|
||||
测试位置:apps/etl/connectors/feiqiu/tests/unit/
|
||||
|
||||
属性测试:验证任意 mv_enabled / retention_enabled 布尔组合下,
|
||||
load() 的行为和返回值结构始终正确。
|
||||
|
||||
单元测试:执行顺序(先刷新后清理)、TaskRegistry 注册项替换。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
import hypothesis.strategies as st
|
||||
|
||||
# ── 将 ETL 模块加入 sys.path ──
|
||||
_ETL_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_ETL_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_ETL_ROOT))
|
||||
|
||||
from tasks.base_task import TaskContext
|
||||
from tasks.dws.maintenance_task import DwsMaintenanceTask
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助工具
|
||||
# ============================================================================
|
||||
|
||||
def _build_context(store_id: int = 1) -> TaskContext:
    """Build a minimal valid TaskContext spanning 2025-01-01 → 2025-01-02 (Asia/Shanghai)."""
    from zoneinfo import ZoneInfo

    shanghai = ZoneInfo("Asia/Shanghai")
    return TaskContext(
        store_id=store_id,
        window_start=datetime(2025, 1, 1, tzinfo=shanghai),
        window_end=datetime(2025, 1, 2, tzinfo=shanghai),
        window_minutes=1440,
    )
|
||||
|
||||
|
||||
def _make_task(mv_enabled: bool = False, retention_enabled: bool = False) -> DwsMaintenanceTask:
    """Create a DwsMaintenanceTask whose feature switches come from a mocked config."""
    settings_map = {
        "dws.mv.enabled": mv_enabled,
        "dws.retention.enabled": retention_enabled,
        "app.timezone": "Asia/Shanghai",
    }
    mocked_config = MagicMock()
    mocked_config.get.side_effect = lambda key, default=None: settings_map.get(key, default)

    # DB / API / logger are plain mocks — the task under test never touches them directly.
    return DwsMaintenanceTask(mocked_config, MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 4: DwsMaintenanceTask 配置控制
|
||||
# Feature: etl-dws-flow-refactor, Property 4: DwsMaintenanceTask 配置控制
|
||||
# **Validates: Requirements 4.3, 4.4**
|
||||
# ============================================================================
|
||||
|
||||
class TestMaintenanceConfigProperty:
    """Property tests: for every (mv_enabled, retention_enabled) combination,
    load() always returns both the ``refreshed`` and ``cleaned`` counters, and
    each operation runs only when its corresponding switch is enabled.
    """

    @given(mv_enabled=st.booleans(), retention_enabled=st.booleans())
    @settings(max_examples=100)
    def test_load_always_returns_refreshed_and_cleaned(
        self, mv_enabled: bool, retention_enabled: bool
    ):
        """The stats dict always carries the refreshed and cleaned keys."""
        task = _make_task(mv_enabled=mv_enabled, retention_enabled=retention_enabled)
        # Stub the internals so no real DB work is attempted.
        task._refresh_all_views = MagicMock(return_value=3)
        task._cleanup_all_tables = MagicMock(return_value=42)

        outcome = task.load({"site_id": 1}, _build_context())

        assert "counts" in outcome
        stats = outcome["counts"]
        assert "refreshed" in stats
        assert "cleaned" in stats

    @given(mv_enabled=st.booleans(), retention_enabled=st.booleans())
    @settings(max_examples=100)
    def test_mv_refresh_only_when_enabled(
        self, mv_enabled: bool, retention_enabled: bool
    ):
        """Materialized views are refreshed only when mv_enabled is True."""
        task = _make_task(mv_enabled=mv_enabled, retention_enabled=retention_enabled)
        task._refresh_all_views = MagicMock(return_value=5)
        task._cleanup_all_tables = MagicMock(return_value=10)

        stats = task.load({"site_id": 1}, _build_context())["counts"]

        if not mv_enabled:
            task._refresh_all_views.assert_not_called()
            assert stats["refreshed"] == 0
        else:
            task._refresh_all_views.assert_called_once()
            assert stats["refreshed"] == 5

    @given(mv_enabled=st.booleans(), retention_enabled=st.booleans())
    @settings(max_examples=100)
    def test_retention_cleanup_only_when_enabled(
        self, mv_enabled: bool, retention_enabled: bool
    ):
        """Data cleanup runs only when retention_enabled is True."""
        task = _make_task(mv_enabled=mv_enabled, retention_enabled=retention_enabled)
        context = _build_context()
        task._refresh_all_views = MagicMock(return_value=2)
        task._cleanup_all_tables = MagicMock(return_value=99)

        stats = task.load({"site_id": 1}, context)["counts"]

        if not retention_enabled:
            task._cleanup_all_tables.assert_not_called()
            assert stats["cleaned"] == 0
        else:
            task._cleanup_all_tables.assert_called_once_with(context)
            assert stats["cleaned"] == 99

    @given(mv_enabled=st.booleans(), retention_enabled=st.booleans())
    @settings(max_examples=100)
    def test_counts_are_non_negative_integers(
        self, mv_enabled: bool, retention_enabled: bool
    ):
        """Both refreshed and cleaned are always non-negative integers."""
        task = _make_task(mv_enabled=mv_enabled, retention_enabled=retention_enabled)
        task._refresh_all_views = MagicMock(return_value=7)
        task._cleanup_all_tables = MagicMock(return_value=33)

        stats = task.load({"site_id": 1}, _build_context())["counts"]

        for key in ("refreshed", "cleaned"):
            assert isinstance(stats[key], int)
            assert stats[key] >= 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 单元测试:执行顺序(先刷新后清理)
|
||||
# ============================================================================
|
||||
|
||||
class TestMaintenanceExecutionOrder:
    """load() must refresh materialized views before cleaning data."""

    def test_refresh_before_cleanup(self):
        """With both switches enabled, the MV refresh runs before the cleanup."""
        task = _make_task(mv_enabled=True, retention_enabled=True)
        sequence: list = []

        def fake_refresh():
            sequence.append("refresh")
            return 2

        def fake_cleanup(_ctx):
            sequence.append("cleanup")
            return 10

        task._refresh_all_views = fake_refresh
        task._cleanup_all_tables = fake_cleanup

        outcome = task.load({"site_id": 1}, _build_context())

        call_order = sequence
        assert call_order == ["refresh", "cleanup"], (
            f"期望先 refresh 后 cleanup,实际顺序: {call_order}"
        )
        assert outcome["counts"]["refreshed"] == 2
        assert outcome["counts"]["cleaned"] == 10

    def test_only_refresh_when_retention_disabled(self):
        """Only the MV refresh runs when retention is disabled."""
        task = _make_task(mv_enabled=True, retention_enabled=False)
        task._refresh_all_views = MagicMock(return_value=4)
        task._cleanup_all_tables = MagicMock(return_value=0)

        outcome = task.load({"site_id": 1}, _build_context())

        task._refresh_all_views.assert_called_once()
        task._cleanup_all_tables.assert_not_called()
        assert outcome["counts"] == {"refreshed": 4, "cleaned": 0}

    def test_only_cleanup_when_mv_disabled(self):
        """Only the cleanup runs when MV refresh is disabled."""
        task = _make_task(mv_enabled=False, retention_enabled=True)
        task._refresh_all_views = MagicMock(return_value=0)
        task._cleanup_all_tables = MagicMock(return_value=50)

        outcome = task.load({"site_id": 1}, _build_context())

        task._refresh_all_views.assert_not_called()
        task._cleanup_all_tables.assert_called_once()
        assert outcome["counts"] == {"refreshed": 0, "cleaned": 50}

    def test_both_disabled_returns_zeros(self):
        """With both switches off, nothing runs and both counters stay zero."""
        task = _make_task(mv_enabled=False, retention_enabled=False)
        task._refresh_all_views = MagicMock()
        task._cleanup_all_tables = MagicMock()

        outcome = task.load({"site_id": 1}, _build_context())

        task._refresh_all_views.assert_not_called()
        task._cleanup_all_tables.assert_not_called()
        assert outcome["counts"] == {"refreshed": 0, "cleaned": 0}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 单元测试:TaskRegistry 注册项替换
|
||||
# ============================================================================
|
||||
|
||||
class TestRegistryReplacement:
    """The three legacy DWS tasks are gone and DWS_MAINTENANCE replaces them."""

    def test_dws_maintenance_registered(self):
        """DWS_MAINTENANCE is registered in default_registry."""
        from orchestration.task_registry import default_registry

        meta = default_registry.get_metadata("DWS_MAINTENANCE")
        assert meta is not None, "DWS_MAINTENANCE 未在 default_registry 中注册"
        assert meta.layer == "DWS"
        assert meta.task_class is DwsMaintenanceTask

    def test_old_mv_refresh_tasks_removed(self):
        """The old DWS_MV_REFRESH_FINANCE_DAILY / DWS_MV_REFRESH_ASSISTANT_DAILY are gone."""
        from orchestration.task_registry import default_registry

        registered = default_registry.get_all_task_codes()
        assert "DWS_MV_REFRESH_FINANCE_DAILY" not in registered, (
            "DWS_MV_REFRESH_FINANCE_DAILY 应已被移除"
        )
        assert "DWS_MV_REFRESH_ASSISTANT_DAILY" not in registered, (
            "DWS_MV_REFRESH_ASSISTANT_DAILY 应已被移除"
        )

    def test_old_retention_cleanup_removed(self):
        """The old DWS_RETENTION_CLEANUP task is gone."""
        from orchestration.task_registry import default_registry

        registered = default_registry.get_all_task_codes()
        assert "DWS_RETENTION_CLEANUP" not in registered, (
            "DWS_RETENTION_CLEANUP 应已被移除"
        )

    def test_maintenance_in_dws_layer(self):
        """DWS_MAINTENANCE shows up in the DWS layer task list."""
        from orchestration.task_registry import default_registry

        assert "DWS_MAINTENANCE" in default_registry.get_tasks_by_layer("DWS")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 单元测试:基本属性
|
||||
# ============================================================================
|
||||
|
||||
class TestMaintenanceBasicAttributes:
    """Sanity checks on DwsMaintenanceTask's static attributes and trivial hooks."""

    def test_task_code(self):
        assert _make_task().get_task_code() == "DWS_MAINTENANCE"

    def test_target_table(self):
        assert _make_task().get_target_table() == "dws_maintenance"

    def test_extract_returns_site_id(self):
        """extract() returns a dict carrying the store id as site_id."""
        extracted = _make_task().extract(_build_context(store_id=42))
        assert extracted == {"site_id": 42}

    def test_transform_passthrough(self):
        """transform() passes its input through unchanged."""
        payload = {"site_id": 1, "extra": "value"}
        assert _make_task().transform(payload, _build_context()) == payload
|
||||
1033
apps/etl/connectors/feiqiu/tests/unit/test_ods_dedup_properties.py
Normal file
1033
apps/etl/connectors/feiqiu/tests/unit/test_ods_dedup_properties.py
Normal file
File diff suppressed because it is too large
Load Diff
161
apps/etl/connectors/feiqiu/tests/unit/test_ods_tasks.py
Normal file
161
apps/etl/connectors/feiqiu/tests/unit/test_ods_tasks.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the new ODS ingestion tasks."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 确保在独立运行测试时能正确解析项目根目录
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.ods.ods_tasks import ODS_TASK_CLASSES
|
||||
from .task_test_utils import create_test_config, get_db_operations, FakeAPIClient
|
||||
|
||||
|
||||
def _build_config(tmp_path):
    """Online-mode test AppConfig rooted under pytest's tmp_path."""
    return create_test_config("ONLINE", tmp_path / "archive", tmp_path / "temp")
|
||||
|
||||
|
||||
def test_assistant_accounts_masters_ingest(tmp_path):
    """Ensure ODS_ASSISTANT_ACCOUNT stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    payload = [{"id": 5001, "assistant_no": "A01", "nickname": "小张"}]
    fake_api = FakeAPIClient({"/PersonnelManagement/SearchAssistantInfo": payload})

    with get_db_operations() as db_ops:
        task = ODS_TASK_CLASSES["ODS_ASSISTANT_ACCOUNT"](
            config, db_ops, fake_api, logging.getLogger("test_assistant_accounts_masters")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    assert result["counts"]["fetched"] == 1
    assert db_ops.commits == 1
    first_row = db_ops.upserts[0]["rows"][0]
    assert first_row["id"] == 5001
    assert first_row["record_index"] == 0
    # Online mode: no source file, so the column is empty.
    assert first_row["source_file"] is None or first_row["source_file"] == ""
    assert '"id": 5001' in first_row["payload"]
|
||||
|
||||
|
||||
def test_goods_stock_movements_ingest(tmp_path):
    """Ensure ODS_INVENTORY_CHANGE stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    payload = [{"siteGoodsStockId": 123456, "stockType": 1, "goodsName": "测试商品"}]
    fake_api = FakeAPIClient({"/GoodsStockManage/QueryGoodsOutboundReceipt": payload})

    with get_db_operations() as db_ops:
        task = ODS_TASK_CLASSES["ODS_INVENTORY_CHANGE"](
            config, db_ops, fake_api, logging.getLogger("test_goods_stock_movements")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    assert result["counts"]["fetched"] == 1
    assert db_ops.commits == 1
    first_row = db_ops.upserts[0]["rows"][0]
    assert first_row["sitegoodsstockid"] == 123456
    assert first_row["record_index"] == 0
    assert '"siteGoodsStockId": 123456' in first_row["payload"]
|
||||
|
||||
|
||||
def test_member_profiless_ingest(tmp_path):
    """Ensure ODS_MEMBER task stores tenantMemberInfos raw JSON."""
    config = _build_config(tmp_path)
    payload = [{"tenantMemberInfos": [{"id": 101, "mobile": "13800000000"}]}]
    fake_api = FakeAPIClient({"/MemberProfile/GetTenantMemberList": payload})

    with get_db_operations() as db_ops:
        task = ODS_TASK_CLASSES["ODS_MEMBER"](
            config, db_ops, fake_api, logging.getLogger("test_ods_member")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    first_row = db_ops.upserts[0]["rows"][0]
    assert first_row["record_index"] == 0
    assert '"id": 101' in first_row["payload"]
|
||||
|
||||
|
||||
def test_ods_payment_ingest(tmp_path):
    """Ensure ODS_PAYMENT task stores payment_transactions raw JSON."""
    config = _build_config(tmp_path)
    payload = [{"payId": 901, "payAmount": "100.00"}]
    fake_api = FakeAPIClient({"/PayLog/GetPayLogListPage": payload})

    with get_db_operations() as db_ops:
        task = ODS_TASK_CLASSES["ODS_PAYMENT"](
            config, db_ops, fake_api, logging.getLogger("test_ods_payment")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    first_row = db_ops.upserts[0]["rows"][0]
    assert first_row["record_index"] == 0
    assert '"payId": 901' in first_row["payload"]
|
||||
|
||||
|
||||
def test_ods_settlement_records_ingest(tmp_path):
    """Ensure ODS_SETTLEMENT_RECORDS stores settleList raw JSON."""
    config = _build_config(tmp_path)
    payload = [{"id": 701, "orderTradeNo": 8001}]
    fake_api = FakeAPIClient({"/Site/GetAllOrderSettleList": payload})

    with get_db_operations() as db_ops:
        task = ODS_TASK_CLASSES["ODS_SETTLEMENT_RECORDS"](
            config, db_ops, fake_api, logging.getLogger("test_settlement_records")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    first_row = db_ops.upserts[0]["rows"][0]
    assert first_row["record_index"] == 0
    assert '"orderTradeNo": 8001' in first_row["payload"]
|
||||
|
||||
|
||||
def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
    """Ensure settlement tickets are fetched per payment relate_id and skip existing ones."""
    config = _build_config(tmp_path)
    ticket = {"data": {"data": {"orderSettleId": 9001, "orderSettleNumber": "T001"}}}
    fake_api = FakeAPIClient({"/Order/GetOrderSettleTicketNew": [ticket]})

    with get_db_operations() as db_ops:
        # First query: ticket ids already stored; second query: payment relate ids.
        db_ops.query_results = [
            [{"order_settle_id": 9002}],
            [
                {"order_settle_id": 9001},
                {"order_settle_id": 9002},
                {"order_settle_id": None},
            ],
        ]
        task = ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"](
            config, db_ops, fake_api, logging.getLogger("test_ods_settlement_ticket")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    tallies = result["counts"]
    # Only id 9001 is new (9002 exists, None is ignored) → one fetch, one insert.
    assert tallies["fetched"] == 1
    assert tallies["inserted"] == 1
    assert tallies["updated"] == 0
    assert tallies["skipped"] == 0
    assert '"orderSettleId": 9001' in db_ops.upserts[0]["rows"][0]["payload"]
    matching_calls = [
        c
        for c in fake_api.calls
        if c["endpoint"] == "/Order/GetOrderSettleTicketNew"
        and c.get("params", {}).get("orderSettleId") == 9001
    ]
    assert matching_calls
|
||||
|
||||
39
apps/etl/connectors/feiqiu/tests/unit/test_parsers.py
Normal file
39
apps/etl/connectors/feiqiu/tests/unit/test_parsers.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""解析器测试"""
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
from models.parsers import TypeParser
|
||||
|
||||
def test_parse_decimal():
    """Amount parsing: rounds to the requested precision; None/garbage yield None."""
    assert TypeParser.parse_decimal("100.555", 2) == Decimal("100.56")
    for bad_input in (None, "invalid"):
        assert TypeParser.parse_decimal(bad_input) is None
|
||||
|
||||
def test_parse_int():
    """Integer parsing: accepts str and int; None/garbage yield None."""
    assert TypeParser.parse_int("123") == 123
    assert TypeParser.parse_int(456) == 456
    for bad_input in (None, "abc"):
        assert TypeParser.parse_int(bad_input) is None
|
||||
|
||||
def test_parse_timestamp():
    """Timestamp parsing: a datetime string yields the expected calendar date."""
    parsed = TypeParser.parse_timestamp("2025-01-15 10:30:00", ZoneInfo("Asia/Shanghai"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (2025, 1, 15)
|
||||
|
||||
|
||||
def test_parse_timestamp_zero_epoch():
    """0 must not be treated as empty; it parses to the Unix epoch."""
    epoch = TypeParser.parse_timestamp(0, ZoneInfo("Asia/Shanghai"))
    assert epoch is not None
    assert (epoch.year, epoch.month, epoch.day) == (1970, 1, 1)
|
||||
@@ -0,0 +1,449 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""FlowRunner 属性测试 - hypothesis 验证 Flow 编排器的通用正确性属性。"""
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.flow_runner import FlowRunner
|
||||
|
||||
# run() imports TaskLogger lazily, so we must patch the source module path.
_TASK_LOGGER_PATH = "utils.task_logger.TaskLogger"

FILE_VERSION = "v1_shell"

# ── strategy definitions ─────────────────────────────────────────

flow_name_st = st.sampled_from(list(FlowRunner.FLOW_LAYERS.keys()))

processing_mode_st = st.sampled_from(
    ["increment_only", "verify_only", "increment_verify"]
)

data_source_st = st.sampled_from(["online", "offline", "hybrid"])

_TASK_PREFIXES = ["ODS_", "DWD_", "DWS_", "INDEX_"]
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_TASK_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1,
        max_size=12,
    ),
)

# Generator for a single task-result dict.
task_result_st = st.fixed_dictionaries(
    {
        "task_code": task_code_st,
        "status": st.sampled_from(["SUCCESS", "FAIL", "SKIP"]),
        "counts": st.fixed_dictionaries(
            {
                "fetched": st.integers(min_value=0, max_value=10000),
                "inserted": st.integers(min_value=0, max_value=10000),
                "updated": st.integers(min_value=0, max_value=10000),
                "skipped": st.integers(min_value=0, max_value=10000),
                "errors": st.integers(min_value=0, max_value=100),
            }
        ),
        "dump_dir": st.none(),
    }
)

task_results_st = st.lists(task_result_st, min_size=0, max_size=10)
|
||||
|
||||
|
||||
# ── 辅助函数 ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_config():
|
||||
"""创建 mock 配置对象。"""
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda key, default=None: {
|
||||
"app.timezone": "Asia/Shanghai",
|
||||
"verification.ods_use_local_json": False,
|
||||
"verification.skip_ods_when_fetch_before_verify": True,
|
||||
"run.ods_tasks": [],
|
||||
"run.dws_tasks": [],
|
||||
"run.index_tasks": [],
|
||||
}.get(key, default))
|
||||
return config
|
||||
|
||||
|
||||
def _make_runner(task_executor=None, task_registry=None):
    """FlowRunner wired with mock collaborators; omitted ones default to benign mocks."""
    executor = task_executor
    if executor is None:
        executor = MagicMock()
        executor.run_tasks.return_value = []

    registry = task_registry
    if registry is None:
        registry = MagicMock()
        registry.get_tasks_by_layer.return_value = ["FAKE_TASK"]

    return FlowRunner(
        config=_make_config(),
        task_executor=executor,
        task_registry=registry,
        db_conn=MagicMock(),
        api_client=MagicMock(),
        logger=MagicMock(),
    )
|
||||
|
||||
|
||||
# ── Property 5: Flow 名称→层列表映射 ──────────────────────────────
|
||||
# Feature: scheduler-refactor, Property 5: Flow 名称→层列表映射
|
||||
# **Validates: Requirements 2.1**
|
||||
|
||||
|
||||
class TestProperty5FlowNameToLayers:
    """For any valid flow name, the layer list FlowRunner resolves must match
    the FLOW_LAYERS mapping exactly."""

    @given(pipeline=flow_name_st)
    @settings(max_examples=100)
    def test_layers_match_flow_definition(self, pipeline):
        """run()'s 'layers' field equals FLOW_LAYERS[pipeline]."""
        stub_executor = MagicMock()
        stub_executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=stub_executor)

        with patch(_TASK_LOGGER_PATH):
            summary = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        assert summary["layers"] == FlowRunner.FLOW_LAYERS[pipeline]

    @given(pipeline=flow_name_st)
    @settings(max_examples=100)
    def test_resolve_tasks_called_with_correct_layers(self, pipeline):
        """_resolve_tasks receives exactly the layers defined in FLOW_LAYERS."""
        stub_executor = MagicMock()
        stub_executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=stub_executor)

        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_resolve_tasks", wraps=runner._resolve_tasks) as resolve_spy,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        resolve_spy.assert_called_once_with(FlowRunner.FLOW_LAYERS[pipeline])
|
||||
|
||||
|
||||
# ── Property 6: processing_mode 控制执行流程 ─────────────────────
|
||||
# Feature: scheduler-refactor, Property 6: processing_mode 控制执行流程
|
||||
# **Validates: Requirements 2.3, 2.4**
|
||||
|
||||
|
||||
class TestProperty6ProcessingModeControlsFlow:
    """For any processing_mode: incremental ETL runs iff the mode contains
    'increment', and the verification flow runs iff the mode contains 'verify'."""

    @given(
        pipeline=flow_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_increment_executes_iff_mode_contains_increment(self, pipeline, mode, data_source):
        """task_executor.run_tasks is invoked exactly when mode contains 'increment'."""
        stub_executor = MagicMock()
        stub_executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=stub_executor)

        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}),
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )

        if "increment" in mode:
            assert stub_executor.run_tasks.called, (
                f"mode={mode} 包含 'increment',但 run_tasks 未被调用"
            )
        else:
            # verify_only with default fetch_before_verify=False must not run tasks.
            assert not stub_executor.run_tasks.called, (
                f"mode={mode} 不包含 'increment',但 run_tasks 被调用了"
            )

    @given(
        pipeline=flow_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_verification_executes_iff_mode_contains_verify(self, pipeline, mode, data_source):
        """_run_verification is invoked exactly when mode contains 'verify'."""
        stub_executor = MagicMock()
        stub_executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=stub_executor)

        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(
                runner, "_run_verification", return_value={"status": "COMPLETED"}
            ) as mock_verify,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )

        if "verify" in mode:
            assert mock_verify.called, (
                f"mode={mode} 包含 'verify',但 _run_verification 未被调用"
            )
        else:
            assert not mock_verify.called, (
                f"mode={mode} 不包含 'verify',但 _run_verification 被调用了"
            )
|
||||
|
||||
|
||||
# ── Property 7: Flow 结果汇总完整性 ──────────────────────────────
|
||||
# Feature: scheduler-refactor, Property 7: Flow 结果汇总完整性
|
||||
# **Validates: Requirements 2.6**
|
||||
|
||||
|
||||
class TestProperty7FlowSummaryCompleteness:
    """For any set of task execution results, the summary dict returned by
    FlowRunner must contain the status/pipeline/layers/results fields, and the
    length of results must equal the number of tasks actually executed.
    (The "pipeline" key name is kept in the returned dict for downstream
    consumer compatibility.)"""

    @given(
        pipeline=flow_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_summary_has_required_fields(self, pipeline, task_results):
        """The returned dict must contain status, pipeline, layers, results, verification_summary."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)

        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        required_keys = {"status", "pipeline", "layers", "results", "verification_summary"}
        assert required_keys.issubset(result.keys()), (
            f"缺少必要字段: {required_keys - result.keys()}"
        )

    @given(
        pipeline=flow_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_results_length_equals_executed_tasks(self, pipeline, task_results):
        """len(results) equals the number of entries returned by task_executor.run_tasks."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)

        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        assert len(result["results"]) == len(task_results), (
            f"results 长度 {len(result['results'])} != 实际任务数 {len(task_results)}"
        )

    @given(
        pipeline=flow_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_flow_and_layers_match_input(self, pipeline, task_results):
        """The returned flow identifier and layers field echo the inputs."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)

        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        assert result["pipeline"] == pipeline
        # Layers come from the class-level FLOW_LAYERS mapping, keyed by flow name.
        assert result["layers"] == FlowRunner.FLOW_LAYERS[pipeline]

    @given(
        pipeline=flow_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_increment_only_has_no_verification(self, pipeline, task_results):
        """In increment_only mode, verification_summary must be None."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)

        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        assert result["verification_summary"] is None
|
||||
|
||||
|
||||
# ── DWS/INDEX 层轻量级校验(需求 6.5)──────────────────────────────
|
||||
# Feature: etl-dws-flow-refactor, Task 5.5: DWS/INDEX 层跳过完整性校验
|
||||
# **Validates: Requirements 6.5**
|
||||
|
||||
|
||||
class TestDwsIndexLightweightVerification:
    """The DWS/INDEX layers must be skipped by the integrity check inside _run_verification."""

    def _make_verification_runner(self):
        """Build a runner for verification tests, with the verification module mocked out."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        registry = MagicMock()
        # At least one task per layer so verification has something to consider.
        registry.get_tasks_by_layer.return_value = ["FAKE_TASK"]
        runner = _make_runner(task_executor=executor, task_registry=registry)
        return runner

    @patch("orchestration.flow_runner.filter_verify_tables", return_value=None)
    def test_dws_layer_skipped_in_verification(self, _mock_filter):
        """DWS-layer verification must return SKIPPED with reason lightweight_dws_index."""
        runner = self._make_verification_runner()

        with patch(
            "tasks.verification.get_verifier_for_layer"
        ) as mock_verifier, patch(
            "tasks.verification.build_window_segments", return_value=[]
        ):
            result = runner._run_verification(
                layers=["DWS"],
                window_start=datetime(2025, 1, 1),
                window_end=datetime(2025, 1, 2),
            )

        # get_verifier_for_layer must never be reached (DWS is skipped upstream).
        mock_verifier.assert_not_called()
        assert result["layers"]["DWS"]["status"] == "SKIPPED"
        assert result["layers"]["DWS"]["reason"] == "lightweight_dws_index"

    @patch("orchestration.flow_runner.filter_verify_tables", return_value=None)
    def test_index_layer_skipped_in_verification(self, _mock_filter):
        """INDEX-layer verification must return SKIPPED with reason lightweight_dws_index."""
        runner = self._make_verification_runner()

        with patch(
            "tasks.verification.get_verifier_for_layer"
        ) as mock_verifier, patch(
            "tasks.verification.build_window_segments", return_value=[]
        ):
            result = runner._run_verification(
                layers=["INDEX"],
                window_start=datetime(2025, 1, 1),
                window_end=datetime(2025, 1, 2),
            )

        mock_verifier.assert_not_called()
        assert result["layers"]["INDEX"]["status"] == "SKIPPED"
        assert result["layers"]["INDEX"]["reason"] == "lightweight_dws_index"

    @patch("orchestration.flow_runner.filter_verify_tables", return_value=None)
    def test_ods_dwd_not_skipped(self, _mock_filter):
        """The ODS/DWD layers must NOT be skipped by the lightweight-verification logic."""
        runner = self._make_verification_runner()

        # A verifier stub whose summary reports one fully consistent table.
        mock_verifier_instance = MagicMock()
        mock_summary = MagicMock()
        mock_summary.to_dict.return_value = {"status": "OK"}
        mock_summary.total_tables = 1
        mock_summary.consistent_tables = 1
        mock_summary.total_backfilled = 0
        mock_summary.error_tables = 0
        mock_verifier_instance.verify_and_backfill.return_value = mock_summary

        with patch(
            "tasks.verification.get_verifier_for_layer",
            return_value=mock_verifier_instance,
        ), patch(
            "tasks.verification.build_window_segments", return_value=[]
        ):
            result = runner._run_verification(
                layers=["DWD"],
                window_start=datetime(2025, 1, 1),
                window_end=datetime(2025, 1, 2),
            )

        # The DWD layer must run real verification and not be skipped.
        assert result["layers"]["DWD"]["status"] != "SKIPPED"

    @patch("orchestration.flow_runner.filter_verify_tables", return_value=None)
    def test_mixed_layers_dws_index_skipped_others_verified(self, _mock_filter):
        """In a mixed layer list, DWS/INDEX are skipped while other layers verify normally."""
        runner = self._make_verification_runner()

        mock_verifier_instance = MagicMock()
        mock_summary = MagicMock()
        mock_summary.to_dict.return_value = {"status": "OK"}
        mock_summary.total_tables = 1
        mock_summary.consistent_tables = 1
        mock_summary.total_backfilled = 0
        mock_summary.error_tables = 0
        mock_verifier_instance.verify_and_backfill.return_value = mock_summary

        with patch(
            "tasks.verification.get_verifier_for_layer",
            return_value=mock_verifier_instance,
        ), patch(
            "tasks.verification.build_window_segments", return_value=[]
        ):
            result = runner._run_verification(
                layers=["DWD", "DWS", "INDEX"],
                window_start=datetime(2025, 1, 1),
                window_end=datetime(2025, 1, 2),
            )

        # DWS and INDEX are skipped.
        assert result["layers"]["DWS"]["status"] == "SKIPPED"
        assert result["layers"]["DWS"]["reason"] == "lightweight_dws_index"
        assert result["layers"]["INDEX"]["status"] == "SKIPPED"
        assert result["layers"]["INDEX"]["reason"] == "lightweight_dws_index"
        # DWD verifies normally.
        assert result["layers"]["DWD"]["status"] != "SKIPPED"

    @patch("orchestration.flow_runner.filter_verify_tables", return_value=None)
    def test_dws_index_skip_logs_message(self, _mock_filter):
        """Skipping DWS/INDEX must be logged."""
        runner = self._make_verification_runner()

        with patch(
            "tasks.verification.get_verifier_for_layer"
        ), patch(
            "tasks.verification.build_window_segments", return_value=[]
        ):
            runner._run_verification(
                layers=["DWS"],
                window_start=datetime(2025, 1, 1),
                window_end=datetime(2025, 1, 2),
            )

        # Inspect logger.info calls for the lightweight-verification message
        # (the "轻量级校验" marker is the runtime log text being asserted on).
        log_calls = [str(c) for c in runner.logger.info.call_args_list]
        assert any("轻量级校验" in c for c in log_calls), (
            f"未找到轻量级校验日志,实际日志: {log_calls}"
        )
|
||||
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""关系指数基础能力单测。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from tasks.dws.index.base_index_task import BaseIndexTask
|
||||
from tasks.dws.index.ml_manual_import_task import MlManualImportTask
|
||||
|
||||
|
||||
class _DummyConfig:
|
||||
"""最小配置桩对象。"""
|
||||
|
||||
def __init__(self, values: Optional[Dict[str, Any]] = None):
|
||||
self._values = values or {}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
class _DummyDB:
|
||||
"""最小数据库桩对象。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.query_calls: List[tuple] = []
|
||||
|
||||
def query(self, sql: str, params=None):
|
||||
self.query_calls.append((sql, params))
|
||||
index_type = (params or [None])[0]
|
||||
if index_type == "RS":
|
||||
return [{"param_name": "lookback_days", "param_value": 60}]
|
||||
if index_type == "MS":
|
||||
return [{"param_name": "lookback_days", "param_value": 30}]
|
||||
return []
|
||||
|
||||
|
||||
class _DummyIndexTask(BaseIndexTask):
    """Smallest concrete BaseIndexTask implementation used for testing."""

    # Default index type for this stub; individual calls may override it.
    INDEX_TYPE = "RS"

    def get_task_code(self) -> str:  # pragma: no cover - test stub
        return "DUMMY_INDEX"

    def get_target_table(self) -> str:  # pragma: no cover - test stub
        return "dummy_table"

    def get_primary_keys(self) -> List[str]:  # pragma: no cover - test stub
        return ["id"]

    def get_index_type(self) -> str:
        return self.INDEX_TYPE

    def extract(self, context):  # pragma: no cover - test stub
        return []

    def load(self, transformed, context):  # pragma: no cover - test stub
        return {}
|
||||
|
||||
|
||||
def test_load_index_parameters_cache_isolated_by_index_type():
    """The parameter cache must be isolated per index_type so a single task cannot cross-wire parameters."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_cache"),
    )

    rs_first = task.load_index_parameters(index_type="RS")
    ms_first = task.load_index_parameters(index_type="MS")
    rs_second = task.load_index_parameters(index_type="RS")

    # NOTE(review): comparing against floats assumes load_index_parameters
    # coerces param_value to float — the stub DB returns ints. Confirm in BaseIndexTask.
    assert rs_first["lookback_days"] == 60.0
    assert ms_first["lookback_days"] == 30.0
    assert rs_second["lookback_days"] == 60.0
    # Only two queries expected: one for RS + one for MS; the second RS hits the cache.
    assert len(task.db.query_calls) == 2
|
||||
|
||||
|
||||
def test_batch_normalize_passes_index_type_to_smoothing_chain():
    """batch_normalize_to_display must forward index_type into the smoothing chain."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_smoothing"),
    )

    # Captures the keyword the smoothing hook actually receives.
    captured: Dict[str, Any] = {}

    def _fake_apply(site_id, current_p5, current_p95, alpha=None, index_type=None):
        # Identity smoothing: record the index_type, return the bounds unchanged.
        captured["index_type"] = index_type
        return current_p5, current_p95

    task._apply_ewma_smoothing = _fake_apply  # type: ignore[method-assign]

    result = task.batch_normalize_to_display(
        raw_scores=[("a", 1.0), ("b", 2.0), ("c", 3.0)],
        use_smoothing=True,
        site_id=1,
        index_type="ML",
    )

    assert result
    # The explicit index_type must reach the smoothing hook verbatim.
    assert captured["index_type"] == "ML"
|
||||
|
||||
|
||||
def test_ml_manual_import_scope_day_and_p30_boundary():
    """Within the 30-day boundary, coverage is per-day; beyond 30 days it falls into the fixed-epoch 30-day bucket."""
    today = date(2026, 2, 8)

    day_scope = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 9),  # exactly 30 days before today
        today=today,
    )
    assert day_scope.scope_type == "DAY"
    assert day_scope.start_date == date(2026, 1, 9)
    assert day_scope.end_date == date(2026, 1, 9)

    p30_scope = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 8),  # 31 days before today — past the DAY boundary
        today=today,
    )
    assert p30_scope.scope_type == "P30"
    # Fixed epoch 2026-01-01: the first bucket must span 2026-01-01 ~ 2026-01-30.
    assert p30_scope.start_date == date(2026, 1, 1)
    assert p30_scope.end_date == date(2026, 1, 30)
|
||||
22
apps/etl/connectors/feiqiu/tests/unit/test_reporting.py
Normal file
22
apps/etl/connectors/feiqiu/tests/unit/test_reporting.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""汇总与报告工具的单测。"""
|
||||
from utils.reporting import summarize_counts, format_report
|
||||
|
||||
|
||||
def test_summarize_counts_and_format():
    """summarize_counts aggregates per-task counters and format_report renders them."""
    task_results = [
        {"task_code": "ORDERS", "counts": {"fetched": 2, "inserted": 2, "updated": 0, "skipped": 0, "errors": 0}},
        {"task_code": "PAYMENTS", "counts": {"fetched": 3, "inserted": 2, "updated": 1, "skipped": 0, "errors": 0}},
    ]

    summary = summarize_counts(task_results)
    # Totals are the element-wise sum across all task entries.
    assert summary["total"]["fetched"] == 5
    assert summary["total"]["inserted"] == 4
    assert summary["total"]["updated"] == 1
    assert summary["total"]["errors"] == 0
    assert len(summary["details"]) == 2

    report = format_report(summary)
    # The rendered report carries the total line plus one section per task.
    assert "TOTAL fetched=5" in report
    assert "ORDERS:" in report
    assert "PAYMENTS:" in report
|
||||
189
apps/etl/connectors/feiqiu/tests/unit/test_resolve_tasks.py
Normal file
189
apps/etl/connectors/feiqiu/tests/unit/test_resolve_tasks.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Property 6 + 单元测试:_resolve_tasks() 配置优先级
|
||||
|
||||
# Feature: etl-dws-flow-refactor, Property 6: 配置优先级
|
||||
# 验证:配置值优先于 Registry;配置为空时回退到 Registry;两者皆空返回空列表。
|
||||
"""
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.flow_runner import FlowRunner
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:构造最小可用的 FlowRunner 实例
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Layers exercised by the property tests, in pipeline order.
_LAYERS = ["ODS", "DWD", "DWS", "INDEX"]

# Layer -> config key mapping (must stay consistent with _resolve_tasks internals).
_LAYER_CONFIG_KEY = {
    "ODS": "run.ods_tasks",
    "DWD": "run.dwd_tasks",
    "DWS": "run.dws_tasks",
    "INDEX": "run.index_tasks",
}
|
||||
|
||||
|
||||
class _DictConfig:
|
||||
"""模拟 AppConfig.get(),支持点号路径查找。"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self._data = data
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
keys = key.split(".")
|
||||
node = self._data
|
||||
for k in keys:
|
||||
if isinstance(node, dict) and k in node:
|
||||
node = node[k]
|
||||
else:
|
||||
return default
|
||||
return node
|
||||
|
||||
|
||||
def _make_runner(config_data: dict, registry_by_layer: dict[str, list[str]]) -> FlowRunner:
    """Build a minimal FlowRunner satisfying only what _resolve_tasks touches."""
    registry = MagicMock()
    # Registry lookups are case-normalized, mirroring TaskRegistry semantics.
    registry.get_tasks_by_layer = MagicMock(
        side_effect=lambda layer: registry_by_layer.get(layer.upper(), [])
    )

    # Bypass __init__ entirely: only config, task_registry and logger are read.
    runner = object.__new__(FlowRunner)
    runner.config = _DictConfig(config_data)
    runner.task_registry = registry
    runner.logger = logging.getLogger("test_resolve_tasks")
    return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hypothesis 策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Task codes: uppercase snake case, letter-first, 3-21 characters.
_task_code = st.from_regex(r"[A-Z][A-Z0-9_]{2,20}", fullmatch=True)

# Non-empty, duplicate-free task lists.
_task_list = st.lists(_task_code, min_size=1, max_size=8, unique=True)

_layer = st.sampled_from(_LAYERS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 6: 配置优先级——配置值优先于 Registry
|
||||
# Feature: etl-dws-flow-refactor, Property 6: 配置优先级
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveTasksProperty:
    """Property 6: configuration priority — config values win over the Registry."""

    @given(layer=_layer, config_tasks=_task_list, registry_tasks=_task_list)
    @settings(max_examples=100)
    def test_config_overrides_registry(
        self, layer: str, config_tasks: list[str], registry_tasks: list[str]
    ):
        """With a non-empty config, _resolve_tasks returns the configured tasks and ignores the Registry.

        **Validates: Requirements 7.2**
        """
        # Keep config and Registry distinct so the source of the result is unambiguous.
        assume(set(config_tasks) != set(registry_tasks))

        config_key = _LAYER_CONFIG_KEY[layer]
        # Build the nested dict: e.g. "run.ods_tasks" -> {"run": {"ods_tasks": [...]}}
        parts = config_key.split(".")
        config_data = {parts[0]: {parts[1]: config_tasks}}

        runner = _make_runner(config_data, {layer: registry_tasks})
        result = runner._resolve_tasks([layer])

        assert result == config_tasks

    @given(layer=_layer, registry_tasks=_task_list)
    @settings(max_examples=100)
    def test_empty_config_falls_back_to_registry(
        self, layer: str, registry_tasks: list[str]
    ):
        """With an empty config, _resolve_tasks falls back to the Registry result.

        **Validates: Requirements 7.2**
        """
        runner = _make_runner({}, {layer: registry_tasks})
        result = runner._resolve_tasks([layer])

        assert result == registry_tasks

    @given(layer=_layer)
    @settings(max_examples=100)
    def test_both_empty_returns_empty(self, layer: str):
        """When both config and Registry are empty, an empty list is returned.

        **Validates: Requirements 7.2**
        """
        runner = _make_runner({}, {layer: []})
        result = runner._resolve_tasks([layer])

        assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 单元测试:具体示例和边界条件
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveTasksUnit:
    """_resolve_tasks configuration priority — concrete examples."""

    def test_ods_config_overrides_registry(self):
        """ODS layer: configured tasks take precedence over the Registry."""
        runner = _make_runner(
            {"run": {"ods_tasks": ["TASK_A", "TASK_B"]}},
            {"ODS": ["TASK_X", "TASK_Y"]},
        )
        assert runner._resolve_tasks(["ODS"]) == ["TASK_A", "TASK_B"]

    def test_dwd_config_overrides_registry(self):
        """DWD layer: configured tasks take precedence over the Registry."""
        runner = _make_runner(
            {"run": {"dwd_tasks": ["DWD_CUSTOM"]}},
            {"DWD": ["DWD_LOAD_FROM_ODS"]},
        )
        assert runner._resolve_tasks(["DWD"]) == ["DWD_CUSTOM"]

    def test_dws_falls_back_to_registry(self):
        """DWS layer: falls back to the Registry when no config is set."""
        registry_tasks = ["DWS_ASSISTANT_DAILY", "DWS_FINANCE_DAILY"]
        runner = _make_runner({}, {"DWS": registry_tasks})
        assert runner._resolve_tasks(["DWS"]) == registry_tasks

    def test_multiple_layers_mixed(self):
        """Mixed layers: some configured, others falling back to the Registry."""
        runner = _make_runner(
            {"run": {"ods_tasks": ["ODS_CUSTOM"]}},
            {"ODS": ["ODS_REG"], "DWD": ["DWD_LOAD_FROM_ODS"]},
        )
        result = runner._resolve_tasks(["ODS", "DWD"])
        # ODS comes from config, DWD from the Registry.
        assert result == ["ODS_CUSTOM", "DWD_LOAD_FROM_ODS"]

    def test_empty_config_list_falls_back(self):
        """An empty config list counts as no config and falls back to the Registry."""
        runner = _make_runner(
            {"run": {"ods_tasks": []}},
            {"ODS": ["ODS_FROM_REGISTRY"]},
        )
        assert runner._resolve_tasks(["ODS"]) == ["ODS_FROM_REGISTRY"]

    def test_both_empty_returns_empty_list(self):
        """Empty config and empty Registry yield an empty list."""
        runner = _make_runner({}, {"INDEX": []})
        assert runner._resolve_tasks(["INDEX"]) == []

    def test_unknown_layer_returns_empty(self):
        """An unknown layer (no config-key mapping) falls back to the Registry; empty there too means []."""
        runner = _make_runner({}, {})
        result = runner._resolve_tasks(["UNKNOWN"])
        assert result == []
|
||||
@@ -0,0 +1,207 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskExecutor 属性测试 - hypothesis 验证执行器的通用正确性属性。"""
|
||||
import re
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.task_executor import TaskExecutor, DataSource
|
||||
from orchestration.task_registry import TaskRegistry
|
||||
|
||||
FILE_VERSION = "v4_shell"

# The three supported data-source modes for TaskExecutor.
data_source_st = st.sampled_from(["online", "offline", "hybrid"])

# Prefixes chosen so generated codes never collide with real ODS task codes.
_NON_ODS_PREFIXES = ["DWD_", "DWS_", "TASK_", "ETL_", "TEST_"]
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_NON_ODS_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=15,
    ),
)

# Arbitrary window start within a plausible business-date range.
window_start_st = st.datetimes(min_value=datetime(2020, 1, 1), max_value=datetime(2030, 12, 31))
|
||||
|
||||
|
||||
def _make_fake_class(name="FakeTask"):
|
||||
return type(name, (), {"__init__": lambda self, *a, **kw: None})
|
||||
|
||||
|
||||
def _make_config():
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda key, default=None: {
|
||||
"app.timezone": "Asia/Shanghai",
|
||||
"io.fetch_root": "/tmp/fetch",
|
||||
"io.ingest_source_dir": "/tmp/ingest",
|
||||
"io.write_pretty_json": False,
|
||||
"pipeline.fetch_root": None,
|
||||
"pipeline.ingest_source_dir": None,
|
||||
"integrity.auto_check": False,
|
||||
"run.overlap_seconds": 600,
|
||||
}.get(key, default))
|
||||
config.__getitem__ = MagicMock(side_effect=lambda k: {
|
||||
"io": {"export_root": "/tmp/export", "log_root": "/tmp/log"},
|
||||
}[k])
|
||||
return config
|
||||
|
||||
|
||||
def _make_executor(registry):
    """Wire a TaskExecutor in which every collaborator except the registry is a mock."""
    return TaskExecutor(
        config=_make_config(),
        db_ops=MagicMock(),
        api_client=MagicMock(),
        cursor_mgr=MagicMock(),
        run_tracker=MagicMock(),
        task_registry=registry,
        logger=MagicMock(),
    )
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 1: data_source 参数决定执行路径
|
||||
# **Validates: Requirements 1.2**
|
||||
|
||||
class TestProperty1DataSourceDeterminesPath:
    """data_source alone decides whether the fetch and/or ingest phases run."""

    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_fetch(self, ds):
        # Fetch runs only for sources that hit the live API.
        result = TaskExecutor._flow_includes_fetch(ds)
        assert result == (ds in {"online", "hybrid"})

    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_ingest(self, ds):
        # Ingest runs only for sources that replay archived files.
        result = TaskExecutor._flow_includes_ingest(ds)
        assert result == (ds in {"offline", "hybrid"})

    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_fetch_and_ingest_consistency(self, ds):
        # The two predicates must agree pairwise for every mode.
        fetch = TaskExecutor._flow_includes_fetch(ds)
        ingest = TaskExecutor._flow_includes_ingest(ds)
        if ds == "hybrid":
            assert fetch and ingest
        elif ds == "online":
            assert fetch and not ingest
        elif ds == "offline":
            assert not fetch and ingest
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 2: 成功任务推进游标
|
||||
# **Validates: Requirements 1.3**
|
||||
|
||||
class TestProperty2SuccessAdvancesCursor:
    """A successful task result carrying a window must advance the cursor to that window."""

    @given(
        task_code=task_code_st,
        window_start=window_start_st,
        window_minutes=st.integers(min_value=1, max_value=1440),
    )
    @settings(max_examples=100)
    def test_success_with_window_advances_cursor(self, task_code, window_start, window_minutes):
        window_end = window_start + timedelta(minutes=window_minutes)
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        # The ingest stub reports success over an explicit window.
        task_result = {
            "status": "SUCCESS",
            "counts": {"fetched": 10, "inserted": 5},
            "window": {"start": window_start, "end": window_end, "minutes": window_minutes},
        }
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 100

        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 42, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
            patch.object(TaskExecutor, "_maybe_run_integrity_check"),
        ):
            executor.run_single_task(task_code, "test-uuid", 1, "offline")

        # The cursor must advance exactly once, to the executed window's bounds.
        executor.cursor_mgr.advance.assert_called_once()
        kw = executor.cursor_mgr.advance.call_args.kwargs
        assert kw["window_start"] == window_start
        assert kw["window_end"] == window_end
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 3: 失败任务标记 FAIL 并重新抛出
|
||||
# **Validates: Requirements 1.4**
|
||||
|
||||
class TestProperty3FailureMarksFailAndReraises:
    """An exception during ingest must mark the run FAIL and propagate to the caller."""

    @given(
        task_code=task_code_st,
        error_msg=st.text(
            alphabet=string.ascii_letters + string.digits + " _-",
            min_size=1, max_size=80,
        ),
    )
    @settings(max_examples=100)
    def test_exception_marks_fail_and_reraises(self, task_code, error_msg):
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 200

        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 99, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            # Ingest blows up with an arbitrary message.
            patch.object(TaskExecutor, "_execute_ingest", side_effect=RuntimeError(error_msg)),
        ):
            # The original exception must be re-raised verbatim
            # (re.escape because error_msg may contain regex metacharacters).
            with pytest.raises(RuntimeError, match=re.escape(error_msg)):
                executor.run_single_task(task_code, "fail-uuid", 1, "offline")

        # The run record must have been updated with status FAIL.
        executor.run_tracker.update_run.assert_called()
        kw = executor.run_tracker.update_run.call_args.kwargs
        assert kw["status"] == "FAIL"
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 4: 工具类任务由元数据决定
|
||||
# **Validates: Requirements 1.6, 4.2**
|
||||
|
||||
class TestProperty4UtilityTaskDeterminedByMetadata:
    """Whether a task is a utility task is decided by registry metadata, not its code."""

    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_utility_task_skips_cursor_and_run_tracker(self, task_code):
        registry = TaskRegistry()
        # requires_db_config=False + task_type="utility" marks a utility task.
        registry.register(task_code, _make_fake_class(), requires_db_config=False, task_type="utility")
        executor = _make_executor(registry)
        mock_task = MagicMock()
        mock_task.execute.return_value = {"status": "SUCCESS", "counts": {}}
        registry.create_task = MagicMock(return_value=mock_task)

        result = executor.run_single_task(task_code, "util-uuid", 1, "hybrid")

        # Utility tasks bypass cursor management and run tracking entirely.
        executor.cursor_mgr.get_or_create.assert_not_called()
        executor.cursor_mgr.advance.assert_not_called()
        executor.run_tracker.create_run.assert_not_called()
        assert result.get("status") == "SUCCESS"

    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_non_utility_task_uses_cursor_and_run_tracker(self, task_code):
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWS")
        task_result = {"status": "SUCCESS", "counts": {"fetched": 1}}
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 300

        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 77, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
        ):
            executor.run_single_task(task_code, "non-util-uuid", 1, "offline")

        # Regular ETL tasks must go through cursor management and run tracking.
        executor.cursor_mgr.get_or_create.assert_called_once()
        executor.run_tracker.create_run.assert_called_once()
|
||||
139
apps/etl/connectors/feiqiu/tests/unit/test_task_registry.py
Normal file
139
apps/etl/connectors/feiqiu/tests/unit/test_task_registry.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskRegistry 单元测试 — 验证 TaskMeta 元数据注册与查询"""
|
||||
import pytest
|
||||
from orchestration.task_registry import TaskRegistry, TaskMeta
|
||||
|
||||
|
||||
# ── 辅助:用作注册的假任务类 ──────────────────────────────────
|
||||
|
||||
class _FakeTask:
    """Placeholder task class used to exercise registry registration."""

    def __init__(self, config, db_connection, api_client, logger):
        # Only config is retained; the other collaborators are accepted and ignored.
        self.config = config
|
||||
|
||||
|
||||
class _AnotherFakeTask:
    """Second placeholder task class, used to register distinct classes side by side."""

    def __init__(self, config, db_connection, api_client, logger):
        pass
|
||||
|
||||
|
||||
# ── fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
def registry():
    """Provide a fresh, empty TaskRegistry per test."""
    return TaskRegistry()
|
||||
|
||||
|
||||
# ── register + get_metadata ───────────────────────────────────
|
||||
|
||||
class TestRegisterAndMetadata:
    """Registration and metadata lookup."""

    def test_register_with_defaults(self, registry):
        """With only task_code + task_class, metadata uses default values (backward compatible)."""
        registry.register("MY_TASK", _FakeTask)
        meta = registry.get_metadata("MY_TASK")
        assert meta is not None
        assert meta.task_class is _FakeTask
        assert meta.requires_db_config is True
        assert meta.layer is None
        assert meta.task_type == "etl"

    def test_register_with_full_metadata(self, registry):
        """Register with the full metadata set."""
        registry.register(
            "ODS_ORDERS", _FakeTask,
            requires_db_config=True, layer="ODS", task_type="etl",
        )
        meta = registry.get_metadata("ODS_ORDERS")
        assert meta.layer == "ODS"
        assert meta.task_type == "etl"

    def test_register_utility_task(self, registry):
        """Utility task: requires_db_config=False."""
        registry.register(
            "INIT_SCHEMA", _FakeTask,
            requires_db_config=False, task_type="utility",
        )
        meta = registry.get_metadata("INIT_SCHEMA")
        assert meta.requires_db_config is False
        assert meta.task_type == "utility"

    def test_case_insensitive_lookup(self, registry):
        """task_code lookup is case-insensitive."""
        registry.register("my_task", _FakeTask)
        assert registry.get_metadata("MY_TASK") is not None
        assert registry.get_metadata("my_task") is not None

    def test_get_metadata_unknown_returns_none(self, registry):
        """Looking up an unregistered task returns None."""
        assert registry.get_metadata("NONEXISTENT") is None
|
||||
|
||||
|
||||
# ── create_task(接口不变)────────────────────────────────────
|
||||
|
||||
class TestCreateTask:
    """create_task keeps its original interface after the metadata refactor."""

    def test_create_task_returns_instance(self, registry):
        registry.register("MY_TASK", _FakeTask)
        # Positional args: config, db_connection, api_client, logger.
        task = registry.create_task("MY_TASK", {"k": "v"}, None, None, None)
        assert isinstance(task, _FakeTask)
        assert task.config == {"k": "v"}

    def test_create_task_unknown_raises(self, registry):
        # The error message text ("未知的任务类型") is part of the registry contract.
        with pytest.raises(ValueError, match="未知的任务类型"):
            registry.create_task("NOPE", None, None, None, None)
|
||||
|
||||
|
||||
# ── get_tasks_by_layer ────────────────────────────────────────
|
||||
|
||||
class TestGetTasksByLayer:
    """Layer-based task lookup."""

    def test_returns_matching_tasks(self, registry):
        """Only tasks registered with the queried layer are returned."""
        registry.register("A", _FakeTask, layer="ODS")
        registry.register("B", _AnotherFakeTask, layer="ODS")
        registry.register("C", _FakeTask, layer="DWD")

        assert sorted(registry.get_tasks_by_layer("ODS")) == ["A", "B"]

    def test_case_insensitive_layer(self, registry):
        """Layer comparison ignores case."""
        registry.register("X", _FakeTask, layer="dws")

        assert registry.get_tasks_by_layer("DWS") == ["X"]

    def test_no_match_returns_empty(self, registry):
        """A layer with no registered tasks yields an empty list."""
        registry.register("A", _FakeTask, layer="ODS")

        assert registry.get_tasks_by_layer("INDEX") == []

    def test_none_layer_excluded(self, registry):
        """Tasks registered without a layer never match any layer query."""
        registry.register("UTIL", _FakeTask)  # layer defaults to None

        assert registry.get_tasks_by_layer("ODS") == []
|
||||
|
||||
|
||||
# ── is_utility_task ───────────────────────────────────────────
|
||||
|
||||
class TestIsUtilityTask:
    """is_utility_task mirrors the requires_db_config flag."""

    def test_utility_task(self, registry):
        """requires_db_config=False marks a task as a utility."""
        registry.register("INIT", _FakeTask, requires_db_config=False)

        assert registry.is_utility_task("INIT") is True

    def test_normal_task(self, registry):
        """requires_db_config=True is an ordinary ETL task."""
        registry.register("ETL", _FakeTask, requires_db_config=True)

        assert registry.is_utility_task("ETL") is False

    def test_unknown_task(self, registry):
        """Unknown codes are simply 'not utility' rather than an error."""
        assert registry.is_utility_task("NOPE") is False
|
||||
|
||||
|
||||
# ── get_all_task_codes(接口不变)──────────────────────────────
|
||||
|
||||
class TestGetAllTaskCodes:
    """get_all_task_codes keeps its original interface."""

    def test_returns_all_codes(self, registry):
        """Every registered code is reported."""
        for code, task_cls in (("A", _FakeTask), ("B", _AnotherFakeTask)):
            registry.register(code, task_cls)

        assert set(registry.get_all_task_codes()) == {"A", "B"}

    def test_empty_registry(self, registry):
        """A registry with no registrations reports an empty list."""
        assert registry.get_all_task_codes() == []
|
||||
@@ -0,0 +1,165 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskRegistry 属性测试 — 使用 hypothesis 验证注册表的通用正确性属性。"""
|
||||
import string
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.task_registry import TaskRegistry, TaskMeta
|
||||
|
||||
|
||||
# ── 辅助:动态生成假任务类 ────────────────────────────────────
|
||||
|
||||
def _make_fake_class(name: str = "FakeTask") -> type:
|
||||
"""创建一个最小化的假任务类,用于注册测试。"""
|
||||
return type(name, (), {"__init__": lambda self, *a, **kw: None})
|
||||
|
||||
|
||||
# ── 生成策略 ──────────────────────────────────────────────────
|
||||
|
||||
# Valid task codes: uppercase letters, digits and underscores, length 1-30.
task_code_st = st.text(
    alphabet=string.ascii_uppercase + string.digits + "_",
    min_size=1,
    max_size=30,
)

# Arbitrary value for the requires_db_config flag.
requires_db_config_st = st.booleans()

# Layer values, including None (a task registered without a layer).
layer_st = st.sampled_from([None, "ODS", "DWD", "DWS", "INDEX"])

# The three task_type values these tests exercise.
task_type_st = st.sampled_from(["etl", "utility", "verification"])
|
||||
|
||||
|
||||
# ── Property 8: TaskRegistry 元数据 round-trip ────────────────
|
||||
# Feature: scheduler-refactor, Property 8: TaskRegistry 元数据 round-trip
|
||||
# **Validates: Requirements 4.1**
|
||||
#
|
||||
# 对于任意任务代码、任务类和元数据组合(requires_db_config、layer、task_type),
|
||||
# 注册后通过 get_metadata 查询应返回相同的元数据值。
|
||||
|
||||
|
||||
class TestProperty8MetadataRoundTrip:
    """Property 8: metadata queried after registration equals what was registered."""

    @given(
        task_code=task_code_st,
        requires_db=requires_db_config_st,
        layer=layer_st,
        task_type=task_type_st,
    )
    @settings(max_examples=100)
    def test_metadata_round_trip(
        self, task_code: str, requires_db: bool, layer: str | None, task_type: str
    ):
        """For any metadata combination, get_metadata returns identical values."""
        # A fresh registry per example prevents state leaking between runs.
        registry = TaskRegistry()
        fake_cls = _make_fake_class()

        registry.register(
            task_code,
            fake_cls,
            requires_db_config=requires_db,
            layer=layer,
            task_type=task_type,
        )
        meta = registry.get_metadata(task_code)

        # Round-trip: every field comes back exactly as registered.
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is requires_db, (
            f"requires_db_config 应为 {requires_db},实际为 {meta.requires_db_config}"
        )
        assert meta.layer == layer, f"layer 应为 {layer!r},实际为 {meta.layer!r}"
        assert meta.task_type == task_type, (
            f"task_type 应为 {task_type!r},实际为 {meta.task_type!r}"
        )
|
||||
|
||||
|
||||
# ── Property 9: TaskRegistry 向后兼容默认值 ───────────────────
|
||||
# Feature: scheduler-refactor, Property 9: TaskRegistry 向后兼容默认值
|
||||
# **Validates: Requirements 4.4**
|
||||
#
|
||||
# 对于任意使用旧接口(仅 task_code 和 task_class)注册的任务,
|
||||
# 查询元数据应返回 requires_db_config=True、layer=None、task_type="etl"。
|
||||
|
||||
|
||||
class TestProperty9BackwardCompatibleDefaults:
    """Property 9: legacy register(code, cls) calls get the documented defaults."""

    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_legacy_register_uses_defaults(self, task_code: str):
        """Registering without any metadata keywords yields the default metadata."""
        registry = TaskRegistry()
        fake_cls = _make_fake_class()

        # Legacy call shape: task_code + task_class only.
        registry.register(task_code, fake_cls)
        meta = registry.get_metadata(task_code)

        # Default contract: requires_db_config=True, layer=None, task_type="etl".
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is True, (
            f"默认 requires_db_config 应为 True,实际为 {meta.requires_db_config}"
        )
        assert meta.layer is None, (
            f"默认 layer 应为 None,实际为 {meta.layer!r}"
        )
        assert meta.task_type == "etl", (
            f"默认 task_type 应为 'etl',实际为 {meta.task_type!r}"
        )
|
||||
|
||||
|
||||
# ── Property 10: 按层查询任务 ────────────────────────────────
|
||||
# Feature: scheduler-refactor, Property 10: 按层查询任务
|
||||
# **Validates: Requirements 4.3**
|
||||
#
|
||||
# 对于任意注册了 layer 元数据的任务集合,get_tasks_by_layer(layer)
|
||||
# 返回的任务代码集合应等于所有 layer 匹配的已注册任务代码集合。
|
||||
|
||||
# Strategy for non-None layer values, intended for query validation.
# NOTE(review): appears unused — the test below iterates a literal list
# instead of drawing from this strategy. Confirm before removing.
non_none_layer_st = st.sampled_from(["ODS", "DWD", "DWS", "INDEX"])
|
||||
|
||||
|
||||
class TestProperty10GetTasksByLayer:
    """Property 10: get_tasks_by_layer must agree with a manual filter."""

    @given(
        entries=st.lists(
            st.tuples(task_code_st, layer_st),
            min_size=1,
            max_size=20,
        ),
    )
    @settings(max_examples=100)
    def test_get_tasks_by_layer_matches_manual_filter(
        self, entries: list[tuple[str, str | None]],
    ):
        """After registering a batch of tasks, each layer query matches a manual filter."""
        registry = TaskRegistry()
        # Re-registering the same code overwrites it, so keep only the last
        # layer per (upper-cased) code — mirroring register() semantics.
        latest_layer: dict[str, str | None] = {}
        for code, layer in entries:
            registry.register(code, _make_fake_class(f"Fake_{code}"), layer=layer)
            latest_layer[code.upper()] = layer  # register() upper-cases codes

        # Verify every concrete (non-None) layer value.
        for query_layer in ["ODS", "DWD", "DWS", "INDEX"]:
            actual = set(registry.get_tasks_by_layer(query_layer))
            expected = {
                code
                for code, layer in latest_layer.items()
                if layer is not None and layer.upper() == query_layer.upper()
            }
            assert actual == expected, (
                f"查询 layer={query_layer!r} 时,"
                f"期望 {expected},实际 {actual}"
            )
|
||||
142
apps/etl/connectors/feiqiu/tests/unit/test_topological_sort.py
Normal file
142
apps/etl/connectors/feiqiu/tests/unit/test_topological_sort.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""拓扑排序单元测试
|
||||
|
||||
覆盖正常依赖、循环依赖、缺失依赖、空列表等场景。
|
||||
属性测试(Property 7, 8)位于 Monorepo 级 tests/test_etl_refactor_properties.py。
|
||||
"""
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.topological_sort import topological_sort
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:轻量 TaskMeta 替身 + 简易 Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class _FakeMeta:
    """Minimal TaskMeta stand-in — only the dependency list matters here."""

    # Codes of the tasks this one depends on; empty by default.
    depends_on: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class _FakeRegistry:
    """Minimal registry stand-in exposing only get_metadata().

    Maps task code -> list of dependency codes. Codes absent from the map
    behave as dependency-free tasks, matching the original branch logic.
    """

    def __init__(self, deps: dict[str, list[str]]):
        self._deps = deps

    def get_metadata(self, code: str):
        """Return a _FakeMeta carrying the code's dependencies (empty if unknown)."""
        # Single dict lookup with a fallback instead of the original
        # `in`-check + second indexing (LBYL double lookup).
        return _FakeMeta(depends_on=self._deps.get(code, []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTopologicalSort:
    """Concrete examples and edge cases for topological_sort."""

    def test_empty_list(self):
        """An empty task list sorts to an empty list."""
        assert topological_sort([], _FakeRegistry({})) == []

    def test_single_task_no_deps(self):
        """A single dependency-free task is returned unchanged."""
        assert topological_sort(["A"], _FakeRegistry({"A": []})) == ["A"]

    def test_linear_chain(self):
        """Linear chain A → B → C keeps A before B and B before C."""
        fake = _FakeRegistry({"A": [], "B": ["A"], "C": ["B"]})
        ordered = topological_sort(["A", "B", "C"], fake)
        assert ordered.index("A") < ordered.index("B") < ordered.index("C")

    def test_diamond_dependency(self):
        """Diamond A → B, A → C, B → D, C → D respects all four edges."""
        fake = _FakeRegistry({"A": [], "B": ["A"], "C": ["A"], "D": ["B", "C"]})
        ordered = topological_sort(["A", "B", "C", "D"], fake)
        position = {code: ordered.index(code) for code in "ABCD"}
        assert position["A"] < position["B"] < position["D"]
        assert position["A"] < position["C"] < position["D"]

    def test_no_dependencies_preserves_input_order(self):
        """With no edges, Kahn's FIFO behaviour preserves the input order."""
        fake = _FakeRegistry({"X": [], "Y": [], "Z": []})
        assert topological_sort(["X", "Y", "Z"], fake) == ["X", "Y", "Z"]

    def test_circular_dependency_raises(self):
        """A two-task cycle raises ValueError mentioning the cycle."""
        fake = _FakeRegistry({"A": ["B"], "B": ["A"]})
        with pytest.raises(ValueError, match="循环依赖"):
            topological_sort(["A", "B"], fake)

    def test_circular_three_tasks(self):
        """A three-task cycle A → B → C → A is rejected as well."""
        fake = _FakeRegistry({"A": ["C"], "B": ["A"], "C": ["B"]})
        with pytest.raises(ValueError, match="循环依赖"):
            topological_sort(["A", "B", "C"], fake)

    def test_partial_cycle_reports_cycle_tasks(self):
        """When only part of the graph cycles, the error names the cycle members."""
        # D is independent; A and B form the cycle.
        fake = _FakeRegistry({"D": [], "A": ["B"], "B": ["A"]})
        with pytest.raises(ValueError) as exc_info:
            topological_sort(["D", "A", "B"], fake)
        message = str(exc_info.value)
        assert "A" in message and "B" in message

    def test_missing_dependency_logs_warning(self, caplog):
        """A dependency outside the task list logs a warning but sorting proceeds."""
        fake = _FakeRegistry({"A": ["MISSING_TASK"], "B": []})
        with caplog.at_level(logging.WARNING):
            ordered = topological_sort(["A", "B"], fake)
        # Sorting completes despite the unknown dependency.
        assert set(ordered) == {"A", "B"}
        assert "MISSING_TASK" in caplog.text
        assert "不在当前执行列表中" in caplog.text

    def test_missing_dependency_does_not_affect_order(self):
        """External (missing) dependencies do not disturb in-list ordering."""
        # B depends on A (in the list) and EXTERNAL (not in the list).
        fake = _FakeRegistry({"A": [], "B": ["A", "EXTERNAL"]})
        ordered = topological_sort(["A", "B"], fake)
        assert ordered.index("A") < ordered.index("B")

    def test_metadata_none_treated_as_no_deps(self):
        """A registry whose get_metadata returns None is treated as dependency-free."""
        fake = MagicMock()
        fake.get_metadata.return_value = None
        assert set(topological_sort(["A", "B"], fake)) == {"A", "B"}

    def test_real_dws_dependency_scenario(self):
        """Realistic DWS graph: MONTHLY after DAILY, MAINTENANCE after everything."""
        fake = _FakeRegistry({
            "DWS_ASSISTANT_DAILY": [],
            "DWS_ASSISTANT_MONTHLY": ["DWS_ASSISTANT_DAILY"],
            "DWS_FINANCE_DAILY": [],
            "DWS_MAINTENANCE": [
                "DWS_ASSISTANT_DAILY",
                "DWS_ASSISTANT_MONTHLY",
                "DWS_FINANCE_DAILY",
            ],
        })
        ordered = topological_sort(
            [
                "DWS_ASSISTANT_DAILY",
                "DWS_ASSISTANT_MONTHLY",
                "DWS_FINANCE_DAILY",
                "DWS_MAINTENANCE",
            ],
            fake,
        )
        assert ordered.index("DWS_ASSISTANT_DAILY") < ordered.index("DWS_ASSISTANT_MONTHLY")
        # MAINTENANCE depends on every other task, so it must come last.
        assert ordered.index("DWS_MAINTENANCE") == len(ordered) - 1
|
||||
358
apps/etl/connectors/feiqiu/tests/unit/test_validate_bd_manual.py
Normal file
358
apps/etl/connectors/feiqiu/tests/unit/test_validate_bd_manual.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
scripts/validate_bd_manual.py 的单元测试。
|
||||
|
||||
不需要数据库连接,通过构造临时文件系统结构来验证各 check 函数。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# 被测模块
|
||||
from scripts.validate_bd_manual import (
|
||||
check_directory_structure,
|
||||
check_ods_doc_coverage,
|
||||
check_ods_doc_format,
|
||||
check_ods_doc_naming,
|
||||
check_mapping_doc_coverage,
|
||||
check_mapping_doc_content,
|
||||
check_mapping_doc_naming,
|
||||
check_ods_dictionary_coverage,
|
||||
BD_MANUAL_ROOT,
|
||||
ODS_MAIN_DIR,
|
||||
ODS_MAPPINGS_DIR,
|
||||
ODS_DICT_PATH,
|
||||
DATA_LAYERS,
|
||||
)
|
||||
import scripts.validate_bd_manual as mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:临时目录结构搭建
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _setup_layer_dirs(tmp_path: Path) -> None:
    """Create the full docs/bd_manual layer tree (main/ + changes/) under tmp_path."""
    root = tmp_path / "docs" / "bd_manual"
    for layer in DATA_LAYERS:
        for subdir in ("main", "changes"):
            (root / layer / subdir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
SAMPLE_ODS_DOC = textwrap.dedent("""\
|
||||
# test_table 测试表
|
||||
|
||||
> 生成时间:2026-01-01
|
||||
|
||||
## 表信息
|
||||
|
||||
| 属性 | 值 |
|
||||
|------|-----|
|
||||
| Schema | billiards_ods |
|
||||
| 表名 | test_table |
|
||||
| 主键 | id, content_hash |
|
||||
| 数据来源 | TestEndpoint |
|
||||
| 说明 | 测试表 |
|
||||
|
||||
## 字段说明
|
||||
|
||||
| 序号 | 字段名 | 类型 | 可空 | 说明 |
|
||||
|------|--------|------|------|------|
|
||||
| 1 | id | BIGINT | NO | 主键 |
|
||||
|
||||
## 使用说明
|
||||
|
||||
```sql
|
||||
SELECT * FROM billiards_ods.test_table LIMIT 10;
|
||||
```
|
||||
|
||||
## ETL 元数据字段
|
||||
|
||||
| 字段名 | 类型 | 说明 |
|
||||
|--------|------|------|
|
||||
| content_hash | TEXT | SHA256 |
|
||||
| source_file | TEXT | 文件名 |
|
||||
| source_endpoint | TEXT | 端点 |
|
||||
| fetched_at | TIMESTAMPTZ | 时间戳 |
|
||||
| payload | JSONB | 原始 JSON |
|
||||
|
||||
## 可回溯性
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| 可回溯 | ✅ 完全可回溯 |
|
||||
| 数据来源 | TestEndpoint |
|
||||
""")
|
||||
|
||||
|
||||
SAMPLE_MAPPING_DOC = textwrap.dedent("""\
|
||||
# TestEndpoint → test_table 字段映射
|
||||
|
||||
> 生成时间:2026-01-01
|
||||
|
||||
## 端点信息
|
||||
|
||||
| 属性 | 值 |
|
||||
|------|-----|
|
||||
| 接口路径 | `Test/TestEndpoint` |
|
||||
| 请求方法 | POST |
|
||||
| ODS 对应表 | `billiards_ods.test_table` |
|
||||
| JSON 数据路径 | `data.items` |
|
||||
|
||||
## 字段映射
|
||||
|
||||
| JSON 字段 | ODS 列名 | 类型转换 | 说明 |
|
||||
|-----------|----------|----------|------|
|
||||
| id | id | int→BIGINT | 主键 |
|
||||
|
||||
## ETL 补充字段
|
||||
|
||||
| ODS 列名 | 生成逻辑 |
|
||||
|-----------|----------|
|
||||
| content_hash | SHA256 |
|
||||
| source_file | 固定值 |
|
||||
| source_endpoint | 端点路径 |
|
||||
| fetched_at | 入库时间戳 |
|
||||
| payload | 原始 JSON |
|
||||
""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 1: 目录结构一致性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckDirectoryStructure:
    """Property 1: every data layer has the expected directory layout."""

    def test_pass_when_all_dirs_exist(self, tmp_path, monkeypatch):
        """A fully-built layer tree passes the structure check."""
        _setup_layer_dirs(tmp_path)
        monkeypatch.setattr(mod, "BD_MANUAL_ROOT", tmp_path / "docs" / "bd_manual")

        result = check_directory_structure()

        assert result.passed is True
        assert result.property_id == "Property 1"

    def test_fail_when_missing_subdir(self, tmp_path, monkeypatch):
        """Removing one changes/ subdirectory must fail and be named in details."""
        import shutil

        _setup_layer_dirs(tmp_path)
        # Knock out ETL_Admin/changes/ so the check has something to report.
        shutil.rmtree(tmp_path / "docs" / "bd_manual" / "ETL_Admin" / "changes")
        monkeypatch.setattr(mod, "BD_MANUAL_ROOT", tmp_path / "docs" / "bd_manual")

        result = check_directory_structure()

        assert result.passed is False
        assert any("ETL_Admin" in detail and "changes" in detail for detail in result.details)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4: ODS 文档覆盖率
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckOdsDocCoverage:
    """Property 4: every ODS table has a table-level doc."""

    def test_pass_when_all_docs_exist(self, tmp_path, monkeypatch):
        """Check passes when a BD_manual_<table>.md exists for each table."""
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        tables = ["member_profiles", "payment_transactions"]
        for table in tables:
            (main_dir / f"BD_manual_{table}.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

        assert check_ods_doc_coverage(tables).passed is True

    def test_fail_when_doc_missing(self, tmp_path, monkeypatch):
        """An undocumented table fails the check and is named in details."""
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        (main_dir / "BD_manual_member_profiles.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

        result = check_ods_doc_coverage(["member_profiles", "missing_table"])

        assert result.passed is False
        assert any("missing_table" in detail for detail in result.details)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 5: ODS 文档格式完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckOdsDocFormat:
    """Property 5: ODS table docs carry every required section and ETL field."""

    def _write_doc(self, tmp_path, monkeypatch, content):
        # Shared arrange step: create ODS/main, write the doc under test,
        # and point the module at the temp directory.
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        (main_dir / "BD_manual_test_table.md").write_text(content, encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

    def test_pass_with_complete_doc(self, tmp_path, monkeypatch):
        """The complete sample doc passes the format check."""
        self._write_doc(tmp_path, monkeypatch, SAMPLE_ODS_DOC)

        assert check_ods_doc_format().passed is True

    def test_fail_when_missing_section(self, tmp_path, monkeypatch):
        """Dropping the 可回溯性 section heading must fail and be reported."""
        self._write_doc(
            tmp_path, monkeypatch, SAMPLE_ODS_DOC.replace("## 可回溯性", "## 其他")
        )

        result = check_ods_doc_format()

        assert result.passed is False
        assert any("可回溯性" in detail for detail in result.details)

    def test_fail_when_missing_etl_field(self, tmp_path, monkeypatch):
        """Renaming content_hash away must trip the required-field check."""
        self._write_doc(
            tmp_path, monkeypatch, SAMPLE_ODS_DOC.replace("content_hash", "xxx_hash")
        )

        result = check_ods_doc_format()

        assert result.passed is False
        assert any("content_hash" in detail for detail in result.details)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 6: ODS 文档命名规范
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckOdsDocNaming:
    """Property 6: docs in ODS/main follow the BD_manual_<snake_case>.md rule."""

    def test_pass_with_valid_names(self, tmp_path, monkeypatch):
        """Well-formed snake_case names pass."""
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        for table in ("member_profiles", "payment_transactions"):
            (main_dir / f"BD_manual_{table}.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

        assert check_ods_doc_naming().passed is True

    def test_fail_with_invalid_name(self, tmp_path, monkeypatch):
        """Every .md under ODS/main must match the pattern; BadName.md is flagged."""
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        (main_dir / "BD_manual_member_profiles.md").write_text("# doc", encoding="utf-8")
        (main_dir / "BadName.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

        result = check_ods_doc_naming()

        assert result.passed is False
        assert any("BadName.md" in detail for detail in result.details)

    def test_fail_with_uppercase_table_name(self, tmp_path, monkeypatch):
        """Capitals after BD_manual_ violate snake_case and must fail."""
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        (main_dir / "BD_manual_MemberProfiles.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

        assert check_ods_doc_naming().passed is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 7: 映射文档覆盖率
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckMappingDocCoverage:
    """Property 7: each ODS table has a mapping doc."""

    def test_pass_when_all_mappings_exist(self, tmp_path, monkeypatch):
        """Check passes when every table has a mapping_<Endpoint>_<table>.md file."""
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        for filename in (
            "mapping_GetTest_test_table.md",
            "mapping_GetOther_other_table.md",
        ):
            (mapping_dir / filename).write_text("# map", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)

        assert check_mapping_doc_coverage(["test_table", "other_table"]).passed is True

    def test_fail_when_mapping_missing(self, tmp_path, monkeypatch):
        """A table without a mapping doc fails and is named in details."""
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        (mapping_dir / "mapping_GetTest_test_table.md").write_text("# map", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)

        result = check_mapping_doc_coverage(["test_table", "missing_table"])

        assert result.passed is False
        assert any("missing_table" in detail for detail in result.details)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 8: 映射文档内容完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckMappingDocContent:
    """Property 8: mapping docs contain every required section."""

    def test_pass_with_complete_mapping(self, tmp_path, monkeypatch):
        """The complete sample mapping doc passes the content check."""
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        (mapping_dir / "mapping_TestEndpoint_test_table.md").write_text(
            SAMPLE_MAPPING_DOC, encoding="utf-8"
        )
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)

        assert check_mapping_doc_content().passed is True

    def test_fail_when_missing_section(self, tmp_path, monkeypatch):
        """Renaming the 字段映射 heading must fail and be reported in details."""
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        broken = SAMPLE_MAPPING_DOC.replace("## 字段映射", "## 其他映射")
        (mapping_dir / "mapping_TestEndpoint_test_table.md").write_text(
            broken, encoding="utf-8"
        )
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)

        result = check_mapping_doc_content()

        assert result.passed is False
        assert any("字段映射" in detail for detail in result.details)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 9: 映射文档命名规范
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckMappingDocNaming:
    """Property 9: mapping docs follow mapping_<PascalEndpoint>_<snake_table>.md."""

    def test_pass_with_valid_names(self, tmp_path, monkeypatch):
        """A PascalCase endpoint + snake_case table name passes."""
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        (mapping_dir / "mapping_GetTenantMemberList_member_profiles.md").write_text(
            "# m", encoding="utf-8"
        )
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)

        assert check_mapping_doc_naming().passed is True

    def test_fail_with_lowercase_endpoint(self, tmp_path, monkeypatch):
        """An endpoint starting lowercase is not PascalCase and must fail."""
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        (mapping_dir / "mapping_getTenantMemberList_member_profiles.md").write_text(
            "# m", encoding="utf-8"
        )
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)

        assert check_mapping_doc_naming().passed is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 10: ODS 数据字典覆盖率
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckOdsDictionaryCoverage:
    """Property 10: every ODS table appears in the data dictionary."""

    def test_pass_when_all_tables_in_dict(self, tmp_path, monkeypatch):
        """Check passes when every table name is present in the dictionary file."""
        dictionary = tmp_path / "ods_tables_dictionary.md"
        dictionary.write_text(
            "| `member_profiles` | 会员 |\n| `payment_transactions` | 支付 |",
            encoding="utf-8",
        )
        monkeypatch.setattr(mod, "ODS_DICT_PATH", dictionary)

        result = check_ods_dictionary_coverage(["member_profiles", "payment_transactions"])

        assert result.passed is True

    def test_fail_when_table_missing_from_dict(self, tmp_path, monkeypatch):
        """A table absent from the dictionary fails and is named in details."""
        dictionary = tmp_path / "ods_tables_dictionary.md"
        dictionary.write_text("| `member_profiles` | 会员 |", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_DICT_PATH", dictionary)

        result = check_ods_dictionary_coverage(["member_profiles", "missing_table"])

        assert result.passed is False
        assert any("missing_table" in detail for detail in result.details)

    def test_fail_when_dict_file_missing(self, tmp_path, monkeypatch):
        """A nonexistent dictionary file cannot pass coverage."""
        monkeypatch.setattr(mod, "ODS_DICT_PATH", tmp_path / "nonexistent.md")

        assert check_ods_dictionary_coverage(["some_table"]).passed is False
|
||||
Reference in New Issue
Block a user