init: 项目初始提交 - NeoZQYY Monorepo 完整代码
This commit is contained in:
59
apps/etl/pipelines/feiqiu/tests/README.md
Normal file
59
apps/etl/pipelines/feiqiu/tests/README.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# tests/ — 测试套件
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
tests/
|
||||
├── unit/ # 单元测试(FakeDB/FakeAPI,无需真实数据库)
|
||||
│ ├── task_test_utils.py # 测试工具:FakeDBOperations、FakeAPIClient、OfflineAPIClient、TaskSpec
|
||||
│ ├── test_ods_tasks.py # ODS 任务在线/离线模式测试
|
||||
│ ├── test_cli_args.py # CLI 参数解析测试
|
||||
│ ├── test_config.py # 配置管理测试
|
||||
│ ├── test_e2e_flow.py # 端到端流程测试(CLI → PipelineRunner → TaskExecutor)
|
||||
│ ├── test_task_registry.py # 任务注册表测试
|
||||
│ ├── test_*_properties.py # 属性测试(hypothesis)
|
||||
│ └── test_audit_*.py # 仓库审计相关测试
|
||||
└── integration/ # 集成测试(需要真实数据库)
|
||||
├── test_database.py # 数据库连接与操作测试
|
||||
└── test_index_tasks.py # 指数任务集成测试
|
||||
```
|
||||
|
||||
## 运行测试
|
||||
|
||||
```bash
|
||||
# 安装测试依赖
|
||||
pip install pytest hypothesis
|
||||
|
||||
# 全部单元测试
|
||||
pytest tests/unit
|
||||
|
||||
# 指定测试文件
|
||||
pytest tests/unit/test_ods_tasks.py
|
||||
|
||||
# 按关键字过滤
|
||||
pytest tests/unit -k "online"
|
||||
|
||||
# 集成测试(需要设置 TEST_DB_DSN)
|
||||
TEST_DB_DSN="postgresql://user:pass@host:5432/db" pytest tests/integration
|
||||
|
||||
# 查看详细输出
|
||||
pytest tests/unit -v --tb=short
|
||||
```
|
||||
|
||||
## 测试工具(task_test_utils.py)
|
||||
|
||||
单元测试通过 `tests/unit/task_test_utils.py` 提供的桩对象避免依赖真实数据库和 API:
|
||||
|
||||
- `FakeDBOperations` — 拦截并记录 upsert/execute/commit/rollback,不触碰真实数据库
|
||||
- `FakeAPIClient` — 在线模式桩,直接返回预置的内存数据
|
||||
- `OfflineAPIClient` — 离线模式桩,从归档 JSON 文件回放数据
|
||||
- `TaskSpec` — 描述任务测试元数据(任务代码、端点、数据路径、样例记录)
|
||||
- `create_test_config()` — 构建测试用 `AppConfig`
|
||||
- `dump_offline_payload()` — 将样例数据写入归档目录供离线测试使用
|
||||
|
||||
## 编写新测试
|
||||
|
||||
- 单元测试放在 `tests/unit/`,文件名 `test_*.py`
|
||||
- 使用 `FakeDBOperations` 和 `FakeAPIClient` 避免外部依赖
|
||||
- 属性测试使用 `hypothesis`,文件名以 `_properties.py` 结尾
|
||||
- 集成测试放在 `tests/integration/`,通过 `TEST_DB_DSN` 环境变量控制是否执行
|
||||
0
apps/etl/pipelines/feiqiu/tests/__init__.py
Normal file
0
apps/etl/pipelines/feiqiu/tests/__init__.py
Normal file
33
apps/etl/pipelines/feiqiu/tests/integration/test_database.py
Normal file
33
apps/etl/pipelines/feiqiu/tests/integration/test_database.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库集成测试"""
|
||||
import pytest
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations
|
||||
|
||||
# 注意:这些测试需要实际的数据库连接
|
||||
# 在CI/CD环境中应使用测试数据库
|
||||
|
||||
@pytest.fixture
def db_connection():
    """Provide a live DatabaseConnection for integration tests.

    Reads the DSN from the TEST_DB_DSN environment variable and skips the
    test when it is not configured, so the suite stays green without a DB.
    """
    import os

    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        pytest.skip("未配置测试数据库")

    connection = DatabaseConnection(dsn)
    yield connection
    # Teardown: pytest resumes the generator here after the test finishes.
    connection.close()
|
||||
|
||||
def test_database_query(db_connection):
    """A trivial SELECT should return exactly one row holding the literal."""
    rows = db_connection.query("SELECT 1 AS test")
    assert len(rows) == 1
    assert rows[0]["test"] == 1
|
||||
|
||||
def test_database_operations(db_connection):
    """Smoke-check that DatabaseOperations can wrap a live connection.

    TODO: add real operation cases (batch upsert / batch execute).
    """
    ops = DatabaseOperations(db_connection)
    # Previously this test bound `ops` and asserted nothing, so it always
    # passed silently. At minimum, make constructor failure fail the test.
    assert ops is not None
|
||||
238
apps/etl/pipelines/feiqiu/tests/integration/test_index_tasks.py
Normal file
238
apps/etl/pipelines/feiqiu/tests/integration/test_index_tasks.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG [2026-02-13] 移除 dws_member_assistant_intimacy 表存在性检查
|
||||
"""Smoke test scripts for WBI/NCI index tasks."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if ROOT not in sys.path:
|
||||
sys.path.insert(0, ROOT)
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations
|
||||
from tasks.dws.index import NewconvIndexTask, WinbackIndexTask
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("test_index_tasks")
|
||||
|
||||
|
||||
def _make_db() -> tuple[AppConfig, DatabaseConnection, DatabaseOperations]:
    """Load the app config and open a DB connection plus its operations wrapper.

    Returns (config, connection, operations); the caller owns closing the
    connection.
    """
    config = AppConfig.load()
    # DSN comes from the loaded configuration tree, not from the environment.
    db_conn = DatabaseConnection(config.config["db"]["dsn"])
    db = DatabaseOperations(db_conn)
    return config, db_conn, db
|
||||
|
||||
|
||||
def _dict_rows(rows) -> List[Dict]:
|
||||
return [dict(r) for r in (rows or [])]
|
||||
|
||||
|
||||
def _fmt(value, digits: int = 2) -> str:
|
||||
if value is None:
|
||||
return "-"
|
||||
if isinstance(value, (int, float)):
|
||||
return f"{value:.{digits}f}"
|
||||
return str(value)
|
||||
|
||||
|
||||
def _check_required_tables() -> None:
    """Fail fast when any table needed by the index tasks is missing.

    Raises RuntimeError listing the absent tables; always closes the
    connection it opened.
    """
    _, db_conn, db = _make_db()
    try:
        sql = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'billiards_dws'
              AND table_name IN (
                  'cfg_index_parameters',
                  'dws_member_winback_index',
                  'dws_member_newconv_index'
              )
        """
        required = {
            "cfg_index_parameters",
            "dws_member_winback_index",
            "dws_member_newconv_index",
        }
        found = {row["table_name"] for row in _dict_rows(db.query(sql))}
        absent = sorted(required - found)
        if absent:
            raise RuntimeError(f"Missing required tables: {', '.join(absent)}")
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
def test_winback_index() -> Dict:
    """Run the WBI (winback index) task end to end and log summary statistics.

    Returns the task's result dict; callers inspect its "status" key.
    """
    logger.info("=" * 80)
    logger.info("Run WBI task")
    logger.info("=" * 80)

    config, db_conn, db = _make_db()
    try:
        task = WinbackIndexTask(config, db, None, logger)
        result = task.execute(None)
        logger.info("WBI result: %s", result)

        if result.get("status") == "success":
            # Summarize what the task just wrote so a reviewer can eyeball it.
            stats_sql = """
                SELECT
                    COUNT(*) AS total_count,
                    ROUND(AVG(display_score)::numeric, 2) AS avg_display,
                    ROUND(MIN(display_score)::numeric, 2) AS min_display,
                    ROUND(MAX(display_score)::numeric, 2) AS max_display,
                    ROUND(AVG(raw_score)::numeric, 4) AS avg_raw,
                    ROUND(AVG(overdue_old)::numeric, 4) AS avg_overdue,
                    ROUND(AVG(drop_old)::numeric, 4) AS avg_drop,
                    ROUND(AVG(recharge_old)::numeric, 4) AS avg_recharge,
                    ROUND(AVG(value_old)::numeric, 4) AS avg_value,
                    ROUND(AVG(t_v)::numeric, 2) AS avg_t_v
                FROM billiards_dws.dws_member_winback_index
            """
            stats_rows = _dict_rows(db.query(stats_sql))
            if stats_rows:
                s = stats_rows[0]
                logger.info(
                    "WBI stats | total=%s, display(avg/min/max)=%s/%s/%s, raw_avg=%s, overdue=%s, drop=%s, recharge=%s, value=%s, t_v=%s",
                    s.get("total_count"),
                    _fmt(s.get("avg_display")),
                    _fmt(s.get("min_display")),
                    _fmt(s.get("max_display")),
                    _fmt(s.get("avg_raw"), 4),
                    _fmt(s.get("avg_overdue"), 4),
                    _fmt(s.get("avg_drop"), 4),
                    _fmt(s.get("avg_recharge"), 4),
                    _fmt(s.get("avg_value"), 4),
                    _fmt(s.get("avg_t_v"), 2),
                )

            # Log the five members with the highest display score.
            top_sql = """
                SELECT member_id, display_score, raw_score, t_v, visits_14d, sv_balance
                FROM billiards_dws.dws_member_winback_index
                ORDER BY display_score DESC NULLS LAST
                LIMIT 5
            """
            for i, r in enumerate(_dict_rows(db.query(top_sql)), 1):
                logger.info(
                    "WBI TOP%d | member=%s, display=%s, raw=%s, t_v=%s, visits_14d=%s, sv_balance=%s",
                    i,
                    r.get("member_id"),
                    _fmt(r.get("display_score")),
                    _fmt(r.get("raw_score"), 4),
                    _fmt(r.get("t_v"), 2),
                    _fmt(r.get("visits_14d"), 0),
                    _fmt(r.get("sv_balance"), 2),
                )

        return result
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
def test_newconv_index() -> Dict:
    """Run the NCI (new-conversion index) task end to end and log statistics.

    Returns the task's result dict; callers inspect its "status" key.
    """
    logger.info("=" * 80)
    logger.info("Run NCI task")
    logger.info("=" * 80)

    config, db_conn, db = _make_db()
    try:
        task = NewconvIndexTask(config, db, None, logger)
        result = task.execute(None)
        logger.info("NCI result: %s", result)

        if result.get("status") == "success":
            # Aggregate stats over the freshly written NCI table, including the
            # welcome/convert sub-scores and each raw component.
            stats_sql = """
                SELECT
                    COUNT(*) AS total_count,
                    ROUND(AVG(display_score)::numeric, 2) AS avg_display,
                    ROUND(MIN(display_score)::numeric, 2) AS min_display,
                    ROUND(MAX(display_score)::numeric, 2) AS max_display,
                    ROUND(AVG(display_score_welcome)::numeric, 2) AS avg_display_welcome,
                    ROUND(AVG(display_score_convert)::numeric, 2) AS avg_display_convert,
                    ROUND(AVG(raw_score)::numeric, 4) AS avg_raw,
                    ROUND(AVG(raw_score_welcome)::numeric, 4) AS avg_raw_welcome,
                    ROUND(AVG(raw_score_convert)::numeric, 4) AS avg_raw_convert,
                    ROUND(AVG(need_new)::numeric, 4) AS avg_need,
                    ROUND(AVG(salvage_new)::numeric, 4) AS avg_salvage,
                    ROUND(AVG(recharge_new)::numeric, 4) AS avg_recharge,
                    ROUND(AVG(value_new)::numeric, 4) AS avg_value,
                    ROUND(AVG(welcome_new)::numeric, 4) AS avg_welcome,
                    ROUND(AVG(t_v)::numeric, 2) AS avg_t_v
                FROM billiards_dws.dws_member_newconv_index
            """
            stats_rows = _dict_rows(db.query(stats_sql))
            if stats_rows:
                s = stats_rows[0]
                logger.info(
                    "NCI stats | total=%s, display(avg/min/max)=%s/%s/%s, display_welcome=%s, display_convert=%s, raw_avg=%s, raw_welcome=%s, raw_convert=%s",
                    s.get("total_count"),
                    _fmt(s.get("avg_display")),
                    _fmt(s.get("min_display")),
                    _fmt(s.get("max_display")),
                    _fmt(s.get("avg_display_welcome")),
                    _fmt(s.get("avg_display_convert")),
                    _fmt(s.get("avg_raw"), 4),
                    _fmt(s.get("avg_raw_welcome"), 4),
                    _fmt(s.get("avg_raw_convert"), 4),
                )
                # Second line: the averaged raw components feeding the score.
                logger.info(
                    "NCI components | need=%s, salvage=%s, recharge=%s, value=%s, welcome=%s, t_v=%s",
                    _fmt(s.get("avg_need"), 4),
                    _fmt(s.get("avg_salvage"), 4),
                    _fmt(s.get("avg_recharge"), 4),
                    _fmt(s.get("avg_value"), 4),
                    _fmt(s.get("avg_welcome"), 4),
                    _fmt(s.get("avg_t_v"), 2),
                )

            # Log the five members with the highest display score.
            top_sql = """
                SELECT member_id, display_score, display_score_welcome, display_score_convert,
                       raw_score, raw_score_welcome, raw_score_convert, t_v, visits_14d
                FROM billiards_dws.dws_member_newconv_index
                ORDER BY display_score DESC NULLS LAST
                LIMIT 5
            """
            for i, r in enumerate(_dict_rows(db.query(top_sql)), 1):
                logger.info(
                    "NCI TOP%d | member=%s, nci=%s (welcome=%s, convert=%s), raw=%s (w=%s,c=%s), t_v=%s, visits_14d=%s",
                    i,
                    r.get("member_id"),
                    _fmt(r.get("display_score")),
                    _fmt(r.get("display_score_welcome")),
                    _fmt(r.get("display_score_convert")),
                    _fmt(r.get("raw_score"), 4),
                    _fmt(r.get("raw_score_welcome"), 4),
                    _fmt(r.get("raw_score_convert"), 4),
                    _fmt(r.get("t_v"), 2),
                    _fmt(r.get("visits_14d"), 0),
                )

        return result
    finally:
        db_conn.close()
|
||||
|
||||
|
||||
|
||||
|
||||
def main() -> None:
    """Verify required tables exist, then run both index smoke tests."""
    _check_required_tables()

    # Run WBI first, then NCI, mirroring the production task order.
    wbi_result = test_winback_index()
    nci_result = test_newconv_index()

    logger.info("=" * 80)
    logger.info("Test complete")
    logger.info("WBI=%s, NCI=%s", wbi_result.get("status"), nci_result.get("status"))
    logger.info("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
apps/etl/pipelines/feiqiu/tests/unit/__init__.py
Normal file
0
apps/etl/pipelines/feiqiu/tests/unit/__init__.py
Normal file
392
apps/etl/pipelines/feiqiu/tests/unit/task_test_utils.py
Normal file
392
apps/etl/pipelines/feiqiu/tests/unit/task_test_utils.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-14 | 删除废弃的 14 个独立 ODS 任务 TaskSpec 定义和对应 import;修复语法错误(TASK_SPECS=[] 后残留孤立代码块)
|
||||
# 直接原因: 之前清理只把 TASK_SPECS 赋值为空列表,但未删除后续 ~370 行废弃 TaskSpec 定义,导致 IndentationError
|
||||
# 验证: `python -c "import ast; ast.parse(open('tests/unit/task_test_utils.py','utf-8').read()); print('OK')"`
|
||||
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Tuple, Type
|
||||
|
||||
from config.settings import AppConfig
|
||||
from database.connection import DatabaseConnection
|
||||
from database.operations import DatabaseOperations as PgDBOperations
|
||||
from utils.json_store import endpoint_to_filename
|
||||
|
||||
DEFAULT_STORE_ID = 2790685415443269
|
||||
BASE_TS = "2025-01-01 10:00:00"
|
||||
END_TS = "2025-01-01 12:00:00"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TaskSpec:
    """Metadata describing how a single task is driven in tests.

    Bundles the task code, its API endpoint, the nesting of the record list
    inside the response body, and in-memory sample records.
    """

    code: str                        # unique task identifier
    task_cls: Type                   # task class under test
    endpoint: str                    # API endpoint the task fetches from
    data_path: Tuple[str, ...]       # nested keys wrapping the record list
    sample_records: List[Dict]       # fixture records for this endpoint

    @property
    def archive_filename(self) -> str:
        """Archive file name derived from the endpoint (utils.json_store rule)."""
        return endpoint_to_filename(self.endpoint)
|
||||
|
||||
|
||||
def wrap_records(records: List[Dict], data_path: Sequence[str]):
    """Nest *records* under each key of *data_path*, outermost key first.

    Produces a structure shaped like a real API response body so archived
    payloads can be replayed offline.
    """
    wrapped = records
    # Walk the path from innermost to outermost so the first key ends up on top.
    for level in reversed(list(data_path)):
        wrapped = {level: wrapped}
    return wrapped
|
||||
|
||||
|
||||
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
    """Build an AppConfig suitable for tests.

    Fills in storage, logging and archive directory parameters and makes sure
    the directories exist.

    Args:
        mode: "ONLINE" (case-insensitive) selects the FULL flow; anything
            else uses INGEST_ONLY (offline replay).
        archive_dir: directory holding archived JSON payloads.
        temp_dir: scratch directory for fetch/export/log output.
    """
    archive_dir = Path(archive_dir)
    temp_dir = Path(temp_dir)
    archive_dir.mkdir(parents=True, exist_ok=True)
    temp_dir.mkdir(parents=True, exist_ok=True)

    # Online tests exercise fetch + ingest; offline tests only ingest.
    flow = "FULL" if str(mode or "").upper() == "ONLINE" else "INGEST_ONLY"
    overrides = {
        "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Shanghai"},
        # DSN is a dummy; unit tests use FakeDBOperations and never connect.
        "db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
        "api": {
            "base_url": "https://api.example.com",
            "token": "test-token",
            "timeout_sec": 3,
            "page_size": 50,
        },
        "pipeline": {
            "flow": flow,
            "fetch_root": str(temp_dir / "json_fetch"),
            "ingest_source_dir": str(archive_dir),
        },
        "io": {
            "export_root": str(temp_dir / "export"),
            "log_root": str(temp_dir / "logs"),
        },
    }
    return AppConfig.load(overrides)
|
||||
|
||||
|
||||
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
    """Write the spec's sample data into *archive_dir* for offline replay.

    Returns the full path of the generated JSON file.
    """
    target = Path(archive_dir) / spec.archive_filename
    body = wrap_records(spec.sample_records, spec.data_path)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(body, handle, ensure_ascii=False)
    return target
|
||||
|
||||
|
||||
class FakeCursor:
    """Minimal cursor stub recording SQL/params with context-manager support.

    Used by FakeDBOperations and SCD2Handler. Fakes just enough of the
    psycopg cursor protocol (execute/fetchone/fetchall/mogrify/rowcount)
    for structure-aware writers to run without a database.
    """

    def __init__(self, recorder: List[Dict], db_ops=None):
        # Shared list; every executed statement is appended here.
        self.recorder = recorder
        self._db_ops = db_ops
        # Rows captured via mogrify() awaiting the flushing execute().
        self._pending_rows: List[Tuple] = []
        self._fetchall_rows: List[Tuple] = []
        self.rowcount = 0
        # Some callers read cursor.connection.encoding; expose a stub.
        self.connection = SimpleNamespace(encoding="UTF8")

    # pylint: disable=unused-argument
    def execute(self, sql: str, params=None):
        """Record the statement and answer information_schema probes with fakes."""
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []

        # Handle information_schema queries used for structure-aware writes.
        lowered = sql_text.lower()
        if "from information_schema.columns" in lowered:
            table_name = None
            if params and len(params) >= 2:
                # By convention the second bound parameter is the table name.
                table_name = params[1]
            self._fetchall_rows = self._fake_columns(table_name)
            return
        if "from information_schema.table_constraints" in lowered:
            self._fetchall_rows = []
            return

        if self._pending_rows:
            # mogrify() calls preceded this execute: flush them as a batch.
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            self._pending_rows = []
        else:
            self.rowcount = 0

    def fetchone(self):
        # Single-row results are never simulated.
        return None

    def fetchall(self):
        return list(self._fetchall_rows)

    def mogrify(self, template, args):
        """Capture one row of bound values; the rendered SQL is irrelevant here."""
        self._pending_rows.append(tuple(args))
        return b"(?)"

    def _record_upserts(self, sql_text: str):
        """Re-associate pending rows with column names and log them as upserts."""
        if not self._db_ops:
            return
        match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
        if not match:
            return
        columns = [c.strip().strip('"') for c in match.group(1).split(",")]
        rows = []
        for idx, row in enumerate(self._pending_rows):
            if len(row) != len(columns):
                # Arity mismatch: skip rather than mis-assign values to columns.
                continue
            row_dict = {}
            for col, val in zip(columns, row):
                if col == "record_index" and val in (None, ""):
                    # Backfill record_index from the row position when absent.
                    row_dict[col] = idx
                    continue
                if hasattr(val, "adapted"):
                    # Looks like a psycopg Json wrapper: store its serialized payload.
                    row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
                else:
                    row_dict[col] = val
            rows.append(row_dict)
        if rows:
            self._db_ops.upserts.append(
                {"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
            )

    @staticmethod
    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
        """Return a fixed column catalog regardless of which table was asked about."""
        return [
            ("id", "bigint", "int8"),
            ("sitegoodsstockid", "bigint", "int8"),
            ("record_index", "integer", "int4"),
            ("content_hash", "text", "text"),
            ("source_file", "text", "text"),
            ("source_endpoint", "text", "text"),
            ("fetched_at", "timestamp with time zone", "timestamptz"),
            ("payload", "jsonb", "jsonb"),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
|
||||
|
||||
|
||||
class FakeConnection:
    """psycopg-like connection stub satisfying SCD2Handler's minimal cursor needs.

    Every statement executed through its cursors is cached in ``statements``.
    """

    def __init__(self, db_ops):
        self._db_ops = db_ops
        self.statements: List[Dict] = []

    def cursor(self):
        # Each cursor shares the same recorder list, so all SQL lands here.
        return FakeCursor(self.statements, self._db_ops)
|
||||
|
||||
|
||||
class FakeDBOperations:
    """In-memory stand-in for DatabaseOperations.

    Records batch upserts/executes and counts commits/rollbacks instead of
    touching a real database.
    """

    def __init__(self):
        self.upserts: List[Dict] = []
        self.executes: List[Dict] = []
        self.commits = 0
        self.rollbacks = 0
        self.conn = FakeConnection(self)
        # Scripted query results (FIFO) controlling what query() returns.
        self.query_results: List[List[Dict]] = []

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        entry = {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(r) for r in rows],
        }
        self.upserts.append(entry)
        # Mimic "all rows inserted, none returned-updated".
        return len(rows), 0

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        entry = {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(r) for r in rows],
        }
        self.executes.append(entry)

    def execute(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params})

    def query(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
        return self.query_results.pop(0) if self.query_results else []

    def cursor(self):
        return self.conn.cursor()

    def commit(self):
        self.commits += 1

    def rollback(self):
        self.rollbacks += 1
|
||||
|
||||
|
||||
class FakeAPIClient:
    """Online-mode API stub returning preloaded in-memory data.

    Records every call so tests can assert the task forwarded its parameters
    correctly.
    """

    def __init__(self, data_map: Dict[str, List[Dict]]):
        self.data_map = data_map
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        self.calls.append({"endpoint": endpoint, "params": params})
        try:
            fixture = self.data_map[endpoint]
        except KeyError:
            raise AssertionError(f"Missing fixture for endpoint {endpoint}") from None
        # Everything fits on a single synthetic page.
        records = list(fixture)
        yield 1, records, dict(params or {}), {"data": records}

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        collected: List[Dict] = []
        page_log: List[Dict] = []
        for page_no, page_records, request, response in self.iter_paginated(endpoint, params, **kwargs):
            collected.extend(page_records)
            page_log.append({"page": page_no, "request": request, "response": response})
        return collected, page_log

    def get_source_hint(self, endpoint: str) -> str | None:
        # The online stub has no backing file, so there is never a hint.
        return None
|
||||
|
||||
|
||||
class OfflineAPIClient:
    """Offline-mode API stub replaying archived JSON files per endpoint.

    Loads the archive, drills into *data_path*, and yields the record list
    in page-sized chunks like the real paginated API would.
    """

    def __init__(self, file_map: Dict[str, Path]):
        self.file_map = {name: Path(p) for name, p in file_map.items()}
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.file_map:
            raise AssertionError(f"Missing archive for endpoint {endpoint}")

        payload = json.loads(self.file_map[endpoint].read_text(encoding="utf-8"))

        # Drill through the configured data path, then an optional list key.
        node = payload
        for key in data_path:
            if isinstance(node, dict):
                node = node.get(key, [])
        if list_key and isinstance(node, dict):
            node = node.get(list_key, [])
        rows = node if isinstance(node, list) else []

        total = len(rows)
        offset, page_no = 0, 1
        # Empty archives still yield one empty page (offset == total == 0).
        while offset < total or (offset == 0 and total == 0):
            chunk = rows[offset : offset + page_size]
            if total != 0 and not chunk:
                break
            yield page_no, list(chunk), dict(params or {}), payload
            if len(chunk) < page_size:
                break
            offset += page_size
            page_no += 1

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        collected: List[Dict] = []
        page_log: List[Dict] = []
        for page_no, page_records, request, response in self.iter_paginated(endpoint, params, **kwargs):
            collected.extend(page_records)
            page_log.append({"page": page_no, "request": request, "response": response})
        return collected, page_log

    def get_source_hint(self, endpoint: str) -> str | None:
        """Return the archive file path for *endpoint*, or None if unknown."""
        if endpoint not in self.file_map:
            return None
        return str(self.file_map[endpoint])
|
||||
|
||||
|
||||
class RealDBOperationsAdapter:
    """Adapter over a real PostgreSQL connection.

    Gives tasks batch upsert/execute plus transaction control; used when
    TEST_DB_DSN points at an actual database.
    """

    def __init__(self, dsn: str):
        self._conn = DatabaseConnection(dsn)
        self._ops = PgDBOperations(self._conn)
        # SCD2Handler accesses db.conn.cursor(), so expose the raw connection.
        self.conn = self._conn.conn

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Delegate batched upsert-with-RETURNING to the real operations layer."""
        return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Delegate batched execute to the real operations layer."""
        return self._ops.batch_execute(sql, rows, page_size=page_size)

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        self._conn.close()
|
||||
|
||||
|
||||
@contextmanager
def get_db_operations():
    """Test-only DB operations context.

    Connects to real PostgreSQL when TEST_DB_DSN is set; otherwise yields an
    in-memory FakeDBOperations stub.
    """
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        # No DSN configured: hand out the in-memory stub, nothing to clean up.
        yield FakeDBOperations()
        return
    adapter = RealDBOperationsAdapter(dsn)
    try:
        yield adapter
    finally:
        adapter.close()
|
||||
|
||||
|
||||
# The 14 standalone ODS task specs were removed: they wrote to the
# nonexistent billiards.* schema and were superseded by the generic ODS task.
TASK_SPECS: List[TaskSpec] = []
|
||||
696
apps/etl/pipelines/feiqiu/tests/unit/test_audit_doc_alignment.py
Normal file
696
apps/etl/pipelines/feiqiu/tests/unit/test_audit_doc_alignment.py
Normal file
@@ -0,0 +1,696 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试 — 文档对齐分析器 (doc_alignment_analyzer.py)
|
||||
|
||||
覆盖:
|
||||
- scan_docs 文档来源识别
|
||||
- extract_code_references 代码引用提取
|
||||
- check_reference_validity 引用有效性检查
|
||||
- find_undocumented_modules 缺失文档检测
|
||||
- check_ddl_vs_dictionary DDL 与数据字典比对
|
||||
- check_api_samples_vs_parsers API 样本与解析器比对
|
||||
- render_alignment_report 报告渲染
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.audit import AlignmentIssue, DocMapping
|
||||
from scripts.audit.doc_alignment_analyzer import (
|
||||
_parse_ddl_tables,
|
||||
_parse_dictionary_tables,
|
||||
build_mappings,
|
||||
check_api_samples_vs_parsers,
|
||||
check_ddl_vs_dictionary,
|
||||
check_reference_validity,
|
||||
extract_code_references,
|
||||
find_undocumented_modules,
|
||||
render_alignment_report,
|
||||
scan_docs,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_docs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScanDocs:
    """Tests for document-source discovery (scan_docs)."""

    def test_finds_docs_dir_md(self, tmp_path: Path) -> None:
        """Markdown files directly under docs/ are discovered."""
        (tmp_path / "docs").mkdir()
        (tmp_path / "docs" / "guide.md").write_text("# Guide", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/guide.md" in result

    def test_finds_root_readme(self, tmp_path: Path) -> None:
        """The repository-root README.md is discovered."""
        (tmp_path / "README.md").write_text("# Readme", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "README.md" in result

    def test_finds_docs_subdir_requirements(self, tmp_path: Path) -> None:
        """Files under docs/requirements/ should be scanned too."""
        req_dir = tmp_path / "docs" / "requirements"
        req_dir.mkdir(parents=True)
        (req_dir / "需求.md").write_text("需求", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/requirements/需求.md" in result

    def test_finds_module_readme(self, tmp_path: Path) -> None:
        """Per-module README files (e.g. gui/README.md) are discovered."""
        (tmp_path / "gui").mkdir()
        (tmp_path / "gui" / "README.md").write_text("# GUI", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "gui/README.md" in result

    def test_finds_steering_files(self, tmp_path: Path) -> None:
        """Markdown under .kiro/steering/ is discovered."""
        steering = tmp_path / ".kiro" / "steering"
        steering.mkdir(parents=True)
        (steering / "tech.md").write_text("# Tech", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert ".kiro/steering/tech.md" in result

    def test_finds_json_samples(self, tmp_path: Path) -> None:
        """JSON API samples under docs/test-json-doc/ are discovered."""
        sample_dir = tmp_path / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / "member.json").write_text("[]", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/test-json-doc/member.json" in result

    def test_empty_repo_returns_empty(self, tmp_path: Path) -> None:
        """An empty repository yields an empty list."""
        result = scan_docs(tmp_path)
        assert result == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        """Discovered paths come back sorted."""
        (tmp_path / "docs").mkdir()
        (tmp_path / "docs" / "z.md").write_text("z", encoding="utf-8")
        (tmp_path / "docs" / "a.md").write_text("a", encoding="utf-8")
        (tmp_path / "README.md").write_text("r", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert result == sorted(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_code_references
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractCodeReferences:
    """Tests for code-reference extraction from markdown documents."""

    def test_extracts_backtick_paths(self, tmp_path: Path) -> None:
        """Backtick-quoted file paths are extracted."""
        doc = tmp_path / "doc.md"
        doc.write_text("使用 `tasks/base_task.py` 作为基类", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "tasks/base_task.py" in refs

    def test_extracts_class_names(self, tmp_path: Path) -> None:
        """Backtick-quoted class names are extracted."""
        doc = tmp_path / "doc.md"
        doc.write_text("继承 `BaseTask` 类", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "BaseTask" in refs

    def test_skips_single_char(self, tmp_path: Path) -> None:
        """Single-character snippets (variables like `x`) are ignored."""
        doc = tmp_path / "doc.md"
        doc.write_text("变量 `x` 和 `y`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs == []

    def test_skips_pure_numbers(self, tmp_path: Path) -> None:
        """Version strings and numeric IDs are ignored."""
        doc = tmp_path / "doc.md"
        doc.write_text("版本 `2.0.0` 和 ID `12345`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs == []

    def test_deduplicates(self, tmp_path: Path) -> None:
        """Each reference appears at most once in the result."""
        doc = tmp_path / "doc.md"
        doc.write_text("`foo.py` 和 `foo.py` 重复", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs.count("foo.py") == 1

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        """A missing document yields an empty list rather than an error."""
        refs = extract_code_references(tmp_path / "nonexistent.md")
        assert refs == []

    def test_normalizes_backslash(self, tmp_path: Path) -> None:
        """Windows-style backslash paths are normalized to forward slashes."""
        doc = tmp_path / "doc.md"
        doc.write_text("路径 `tasks\\base_task.py`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "tasks/base_task.py" in refs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_reference_validity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckReferenceValidity:
    """Tests for reference-validity checking against the repository tree."""

    def test_valid_file_path(self, tmp_path: Path) -> None:
        """A reference to an existing file is valid."""
        (tmp_path / "tasks").mkdir()
        (tmp_path / "tasks" / "base.py").write_text("", encoding="utf-8")
        assert check_reference_validity("tasks/base.py", tmp_path) is True

    def test_invalid_file_path(self, tmp_path: Path) -> None:
        """A reference to a missing file is invalid."""
        assert check_reference_validity("nonexistent/file.py", tmp_path) is False

    def test_strips_legacy_prefix(self, tmp_path: Path) -> None:
        """Accepts the legacy package prefix (etl_billiards/) and the current
        root prefix (FQ-ETL/)."""
        (tmp_path / "tasks").mkdir()
        (tmp_path / "tasks" / "x.py").write_text("", encoding="utf-8")
        assert check_reference_validity("etl_billiards/tasks/x.py", tmp_path) is True
        assert check_reference_validity("FQ-ETL/tasks/x.py", tmp_path) is True

    def test_directory_path(self, tmp_path: Path) -> None:
        """A reference to an existing directory is valid."""
        (tmp_path / "loaders").mkdir()
        assert check_reference_validity("loaders", tmp_path) is True

    def test_dotted_module_path(self, tmp_path: Path) -> None:
        """Dotted module names resolve to their .py files."""
        (tmp_path / "config").mkdir()
        (tmp_path / "config" / "settings.py").write_text("", encoding="utf-8")
        assert check_reference_validity("config.settings", tmp_path) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_undocumented_modules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindUndocumentedModules:
    """Detection of core code modules that lack documentation coverage."""

    @staticmethod
    def _add_module(repo: Path, directory: str, filename: str) -> None:
        """Create an empty module file under repo/directory."""
        pkg = repo / directory
        pkg.mkdir(exist_ok=True)
        (pkg / filename).write_text("", encoding="utf-8")

    def test_finds_undocumented(self, tmp_path: Path) -> None:
        """A core module with no documentation shows up in the result."""
        self._add_module(tmp_path, "tasks", "__init__.py")
        self._add_module(tmp_path, "tasks", "ods_task.py")
        assert "tasks/ods_task.py" in find_undocumented_modules(tmp_path, set())

    def test_excludes_init(self, tmp_path: Path) -> None:
        """__init__.py files are never reported as undocumented."""
        self._add_module(tmp_path, "tasks", "__init__.py")
        result = find_undocumented_modules(tmp_path, set())
        assert all("__init__" not in entry for entry in result)

    def test_documented_module_excluded(self, tmp_path: Path) -> None:
        """A module present in the documented set is not reported."""
        self._add_module(tmp_path, "tasks", "ods_task.py")
        result = find_undocumented_modules(tmp_path, {"tasks/ods_task.py"})
        assert "tasks/ods_task.py" not in result

    def test_non_core_dirs_ignored(self, tmp_path: Path) -> None:
        """gui/ is not in the core code directory list and must not be scanned."""
        self._add_module(tmp_path, "gui", "main.py")
        result = find_undocumented_modules(tmp_path, set())
        assert all("gui/" not in entry for entry in result)

    def test_results_sorted(self, tmp_path: Path) -> None:
        """The undocumented list is returned in sorted order."""
        self._add_module(tmp_path, "tasks", "z_task.py")
        self._add_module(tmp_path, "tasks", "a_task.py")
        result = find_undocumented_modules(tmp_path, set())
        assert result == sorted(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_ddl_tables / _parse_dictionary_tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseDdlTables:
    """Parsing of CREATE TABLE DDL into table -> column-set mappings."""

    def test_extracts_table_and_columns(self) -> None:
        """Table name and each declared column are captured."""
        ddl = """
        CREATE TABLE IF NOT EXISTS dim_member (
            member_id BIGINT,
            nickname TEXT,
            mobile TEXT,
            PRIMARY KEY (member_id)
        );
        """
        parsed = _parse_ddl_tables(ddl)
        assert "dim_member" in parsed
        for column in ("member_id", "nickname", "mobile"):
            assert column in parsed["dim_member"]

    def test_handles_schema_prefix(self) -> None:
        """Schema-qualified names resolve to the bare table name."""
        parsed = _parse_ddl_tables(
            "CREATE TABLE billiards_dwd.dim_site (\n site_id BIGINT\n);"
        )
        assert "dim_site" in parsed

    def test_excludes_sql_keywords(self) -> None:
        """Constraint keywords such as PRIMARY are not treated as columns."""
        ddl = """
        CREATE TABLE test_tbl (
            id INTEGER,
            PRIMARY KEY (id)
        );
        """
        parsed = _parse_ddl_tables(ddl)
        assert "primary" not in parsed.get("test_tbl", set())
|
||||
|
||||
|
||||
class TestParseDictionaryTables:
    """Parsing of the Markdown data dictionary into table -> field sets."""

    def test_extracts_table_and_fields(self) -> None:
        """The heading names the table; body rows name the fields."""
        markdown = """## dim_member

| 字段 | 类型 | 说明 |
|------|------|------|
| member_id | BIGINT | 会员ID |
| nickname | TEXT | 昵称 |
"""
        parsed = _parse_dictionary_tables(markdown)
        assert "dim_member" in parsed
        assert "member_id" in parsed["dim_member"]
        assert "nickname" in parsed["dim_member"]

    def test_skips_header_row(self) -> None:
        """The Markdown table's header row is not treated as a field."""
        markdown = """## dim_test

| 字段 | 类型 |
|------|------|
| col_a | INT |
"""
        parsed = _parse_dictionary_tables(markdown)
        assert "字段" not in parsed.get("dim_test", set())

    def test_handles_backtick_table_name(self) -> None:
        """Backticks around the heading's table name are stripped."""
        parsed = _parse_dictionary_tables("## `dim_goods`\n\n| 字段 |\n| goods_id |")
        assert "dim_goods" in parsed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_ddl_vs_dictionary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckDdlVsDictionary:
    """Cross-checking DDL files against the Markdown data dictionary."""

    @staticmethod
    def _write_repo(repo: Path, ddl_sql: str, dictionary_md: str) -> None:
        """Lay out database/schema_test.sql and docs/dwd_main_tables_dictionary.md."""
        db_dir = repo / "database"
        db_dir.mkdir()
        (db_dir / "schema_test.sql").write_text(ddl_sql, encoding="utf-8")
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        (docs_dir / "dwd_main_tables_dictionary.md").write_text(
            dictionary_md, encoding="utf-8"
        )

    def test_detects_missing_table_in_dictionary(self, tmp_path: Path) -> None:
        """A table present in DDL but absent from the dictionary is 'missing'."""
        self._write_repo(
            tmp_path,
            "CREATE TABLE dim_orphan (\n id BIGINT\n);",
            "## dim_other\n\n| 字段 |\n| id |",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        missing = [issue for issue in issues if issue.issue_type == "missing"]
        assert any("dim_orphan" in issue.description for issue in missing)

    def test_detects_column_mismatch(self, tmp_path: Path) -> None:
        """A DDL column absent from the dictionary entry is a 'conflict'."""
        self._write_repo(
            tmp_path,
            "CREATE TABLE dim_x (\n id BIGINT,\n extra_col TEXT\n);",
            "## dim_x\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        conflicts = [issue for issue in issues if issue.issue_type == "conflict"]
        assert any("extra_col" in issue.description for issue in conflicts)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        """Fully aligned DDL and dictionary yield no issues at all."""
        self._write_repo(
            tmp_path,
            "CREATE TABLE dim_ok (\n id BIGINT\n);",
            "## dim_ok\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        assert len(issues) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_api_samples_vs_parsers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckApiSamplesVsParsers:
    """Comparing archived API JSON samples against ODS table columns."""

    @staticmethod
    def _write_sample(repo: Path, filename: str, records: list) -> None:
        """Write a JSON sample file under docs/test-json-doc/."""
        sample_dir = repo / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / filename).write_text(json.dumps(records), encoding="utf-8")

    @staticmethod
    def _write_ods_schema(repo: Path, ddl: str) -> None:
        """Write the ODS schema DDL under database/."""
        db_dir = repo / "database"
        db_dir.mkdir()
        (db_dir / "schema_ODS_doc.sql").write_text(ddl, encoding="utf-8")

    def test_detects_json_field_not_in_ods(self, tmp_path: Path) -> None:
        """A sample field missing from the ODS table is reported."""
        self._write_sample(
            tmp_path,
            "test_entity.json",
            [{"id": 1, "name": "a", "extra_field": "x"}],
        )
        self._write_ods_schema(
            tmp_path,
            "CREATE TABLE billiards_ods.test_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
        )
        issues = check_api_samples_vs_parsers(tmp_path)
        assert any("extra_field" in issue.description for issue in issues)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        """No issues when every sample field exists in the ODS table."""
        self._write_sample(tmp_path, "aligned_entity.json", [{"id": 1, "name": "a"}])
        self._write_ods_schema(
            tmp_path,
            "CREATE TABLE billiards_ods.aligned_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
        )
        issues = check_api_samples_vs_parsers(tmp_path)
        assert len(issues) == 0

    def test_skips_when_no_ods_table(self, tmp_path: Path) -> None:
        """A sample with no matching ODS table is skipped entirely."""
        self._write_sample(tmp_path, "unknown.json", [{"a": 1}])
        self._write_ods_schema(tmp_path, "-- empty")
        issues = check_api_samples_vs_parsers(tmp_path)
        assert len(issues) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_alignment_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRenderAlignmentReport:
    """Rendering of the doc-alignment Markdown report."""

    def test_contains_all_sections(self) -> None:
        """An empty report still carries every section heading."""
        report = render_alignment_report([], [], "/repo")
        for heading in ("## 映射关系", "## 过期点", "## 冲突点", "## 缺失点", "## 统计摘要"):
            assert heading in report

    def test_contains_header_metadata(self) -> None:
        """The header shows generation time and repository path."""
        report = render_alignment_report([], [], "/repo")
        assert "生成时间" in report
        assert "`/repo`" in report

    def test_contains_iso_timestamp(self) -> None:
        """The header timestamp is ISO-8601 with a trailing Z."""
        import re

        report = render_alignment_report([], [], "/repo")
        assert re.search(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", report)

    def test_mapping_table_rendered(self) -> None:
        """Mapping rows show the doc path, related code, and status."""
        mapping = DocMapping(
            doc_path="docs/guide.md",
            doc_topic="项目文档",
            related_code=["tasks/base.py"],
            status="aligned",
        )
        report = render_alignment_report([mapping], [], "/repo")
        assert "`docs/guide.md`" in report
        assert "`tasks/base.py`" in report
        assert "aligned" in report

    def test_stale_issues_rendered(self) -> None:
        """Stale issues appear under the stale section."""
        stale = AlignmentIssue(
            doc_path="docs/old.md",
            issue_type="stale",
            description="引用了已删除的文件",
            related_code="tasks/deleted.py",
        )
        report = render_alignment_report([], [stale], "/repo")
        assert "引用了已删除的文件" in report
        assert "## 过期点" in report

    def test_conflict_issues_rendered(self) -> None:
        """Conflict issue descriptions are included in the report."""
        conflict = AlignmentIssue(
            doc_path="docs/dict.md",
            issue_type="conflict",
            description="字段不一致",
            related_code="database/schema.sql",
        )
        report = render_alignment_report([], [conflict], "/repo")
        assert "字段不一致" in report

    def test_missing_issues_rendered(self) -> None:
        """Missing issue descriptions are included in the report."""
        missing = AlignmentIssue(
            doc_path="docs/dict.md",
            issue_type="missing",
            description="缺少表定义",
            related_code="database/schema.sql",
        )
        report = render_alignment_report([], [missing], "/repo")
        assert "缺少表定义" in report

    def test_summary_counts(self) -> None:
        """The summary tallies issues per type and total documents."""
        issues = [
            AlignmentIssue("a", "stale", "d1", "c1"),
            AlignmentIssue("b", "stale", "d2", "c2"),
            AlignmentIssue("c", "conflict", "d3", "c3"),
            AlignmentIssue("d", "missing", "d4", "c4"),
        ]
        mappings = [DocMapping("x", "t", [], "aligned")]
        report = render_alignment_report(mappings, issues, "/repo")
        assert "过期点数量:2" in report
        assert "冲突点数量:1" in report
        assert "缺失点数量:1" in report
        assert "文档总数:1" in report

    def test_empty_report(self) -> None:
        """With no data each section shows its 'none found' note."""
        report = render_alignment_report([], [], "/repo")
        for note in ("未发现过期点", "未发现冲突点", "未发现缺失点", "过期点数量:0"):
            assert note in report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 11 / 12 / 16 (hypothesis)
|
||||
# hypothesis 与 pytest 的 function-scoped fixture (tmp_path) 不兼容,
|
||||
# 因此在测试内部使用 tempfile.mkdtemp 自行管理临时目录。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit.doc_alignment_analyzer import _CORE_CODE_DIRS
|
||||
|
||||
|
||||
class TestPropertyStaleReferenceDetection:
    """Feature: repo-audit, Property 11: stale-reference detection.

    *For any* code reference extracted from a document, if the referenced
    file path does not exist in the repository, check_reference_validity
    must return False.

    Validates: Requirements 3.3
    """

    # Short, filesystem-safe names: lowercase letter followed by 1-12
    # lowercase alphanumerics/underscores.
    _safe_name = st.from_regex(r"[a-z][a-z0-9_]{1,12}", fullmatch=True)

    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
        missing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_nonexistent_path_returns_false(
        self,
        existing_names: list[str],
        missing_names: list[str],
    ) -> None:
        """References to nonexistent file paths must return False."""
        # hypothesis is incompatible with pytest's function-scoped tmp_path
        # fixture, so the temporary directory is managed manually.
        tmp = Path(tempfile.mkdtemp())
        try:
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")

            existing_set = set(existing_names)
            # Only check names that genuinely do not exist on disk.
            truly_missing = [n for n in missing_names if n not in existing_set]
            for name in truly_missing:
                ref = f"nonexistent_dir/{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is False, (
                    f"引用 '{ref}' 指向不存在的文件,但返回了 True"
                )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)

    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_existing_path_returns_true(
        self,
        existing_names: list[str],
    ) -> None:
        """References to existing file paths must return True."""
        tmp = Path(tempfile.mkdtemp())
        try:
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")

            for name in existing_names:
                ref = f"{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is True, (
                    f"引用 '{ref}' 指向存在的文件,但返回了 False"
                )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
|
||||
|
||||
|
||||
class TestPropertyMissingDocDetection:
    """Feature: repo-audit, Property 12: missing-documentation detection.

    *For any* set of core code modules and documented-module set,
    the list returned by find_undocumented_modules must equal exactly
    the (sorted) set difference of core modules minus documented modules.

    Validates: Requirements 3.5
    """

    # Directory names are drawn from the analyzer's own core-dir list so
    # every generated module is guaranteed to be scanned.
    _core_dir = st.sampled_from(list(_CORE_CODE_DIRS))
    _module_name = st.from_regex(r"[a-z][a-z0-9_]{1,10}", fullmatch=True)

    @given(
        core_dir=_core_dir,
        module_names=st.lists(
            _module_name, min_size=2, max_size=6, unique=True,
        ),
        doc_fraction=st.floats(min_value=0.0, max_value=1.0),
    )
    @settings(max_examples=100)
    def test_undocumented_equals_difference(
        self,
        core_dir: str,
        module_names: list[str],
        doc_fraction: float,
    ) -> None:
        """The result must equal core modules minus the documented set."""
        # Manual temp dir: hypothesis cannot use function-scoped tmp_path.
        tmp = Path(tempfile.mkdtemp())
        try:
            code_dir = tmp / core_dir
            code_dir.mkdir(parents=True, exist_ok=True)

            all_modules: set[str] = set()
            for name in module_names:
                (code_dir / f"{name}.py").write_text("# module", encoding="utf-8")
                all_modules.add(f"{core_dir}/{name}.py")

            # Mark a doc_fraction-sized prefix of the modules as documented.
            split_idx = int(len(module_names) * doc_fraction)
            documented = {
                f"{core_dir}/{n}.py" for n in module_names[:split_idx]
            }

            result = find_undocumented_modules(tmp, documented)
            expected = sorted(all_modules - documented)

            assert result == expected, (
                f"期望缺失列表 {expected},实际得到 {result}"
            )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
|
||||
|
||||
|
||||
class TestPropertyAlignmentReportSections:
    """Feature: repo-audit, Property 16: alignment-report section completeness.

    *For any* output of render_alignment_report, the Markdown text must
    contain the four section headings: 映射关系 (mappings), 过期点 (stale),
    冲突点 (conflicts), 缺失点 (missing).

    Validates: Requirements 3.8
    """

    _issue_type = st.sampled_from(["stale", "conflict", "missing"])
    # Arbitrary letters/numbers/punctuation; NUL is excluded because it
    # cannot appear in rendered report text.
    _text = st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P"),
            blacklist_characters="\x00",
        ),
        min_size=1,
        max_size=30,
    )

    # Arbitrary but structurally valid DocMapping instances.
    _mapping_st = st.builds(
        DocMapping,
        doc_path=_text,
        doc_topic=_text,
        related_code=st.lists(_text, max_size=3),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )

    # Arbitrary but structurally valid AlignmentIssue instances.
    _issue_st = st.builds(
        AlignmentIssue,
        doc_path=_text,
        issue_type=_issue_type,
        description=_text,
        related_code=_text,
    )

    @given(
        mappings=st.lists(_mapping_st, max_size=5),
        issues=st.lists(_issue_st, max_size=8),
    )
    @settings(max_examples=100)
    def test_report_contains_four_sections(
        self,
        mappings: list[DocMapping],
        issues: list[AlignmentIssue],
    ) -> None:
        """Any rendered report must contain all four section headings."""
        report = render_alignment_report(mappings, issues, "/test/repo")

        required_sections = ["## 映射关系", "## 过期点", "## 冲突点", "## 缺失点"]
        for section in required_sections:
            assert section in report, (
                f"报告中缺少分区标题 '{section}'"
            )
|
||||
667
apps/etl/pipelines/feiqiu/tests/unit/test_audit_flow.py
Normal file
667
apps/etl/pipelines/feiqiu/tests/unit/test_audit_flow.py
Normal file
@@ -0,0 +1,667 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试 — 流程树分析器 (flow_analyzer.py)
|
||||
|
||||
覆盖:
|
||||
- parse_imports: import 语句解析、标准库/第三方排除、语法错误容错
|
||||
- build_flow_tree: 递归构建、循环导入处理
|
||||
- find_orphan_modules: 孤立模块检测
|
||||
- render_flow_report: Markdown 渲染、Mermaid 图、统计摘要
|
||||
- discover_entry_points: 入口点识别
|
||||
- classify_task_type / classify_loader_type: 类型区分
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.audit import FileEntry, FlowNode
|
||||
from scripts.audit.flow_analyzer import (
|
||||
build_flow_tree,
|
||||
classify_loader_type,
|
||||
classify_task_type,
|
||||
discover_entry_points,
|
||||
find_orphan_modules,
|
||||
parse_imports,
|
||||
render_flow_report,
|
||||
_path_to_module_name,
|
||||
_parse_bat_python_target,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_imports 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseImports:
    """Import-statement extraction from Python source files."""

    @staticmethod
    def _source_file(tmp_path: Path, body: str, name: str = "test.py") -> Path:
        """Write body to a temp .py file and return its path."""
        module = tmp_path / name
        module.write_text(body, encoding="utf-8")
        return module

    def test_absolute_import(self, tmp_path: Path) -> None:
        """Absolute imports of project-internal modules are detected."""
        module = self._source_file(tmp_path, "import cli.main\nimport config.settings\n")
        found = parse_imports(module)
        assert "cli.main" in found
        assert "config.settings" in found

    def test_from_import(self, tmp_path: Path) -> None:
        """'from ... import ...' statements are detected."""
        module = self._source_file(tmp_path, "from tasks.base_task import BaseTask\n")
        assert "tasks.base_task" in parse_imports(module)

    def test_stdlib_excluded(self, tmp_path: Path) -> None:
        """Standard-library modules are excluded."""
        module = self._source_file(
            tmp_path,
            "import os\nimport sys\nimport json\nfrom pathlib import Path\n",
        )
        assert parse_imports(module) == []

    def test_third_party_excluded(self, tmp_path: Path) -> None:
        """Third-party packages are excluded."""
        module = self._source_file(
            tmp_path,
            "import requests\nfrom psycopg2 import sql\nimport flask\n",
        )
        assert parse_imports(module) == []

    def test_mixed_imports(self, tmp_path: Path) -> None:
        """Mixed imports keep only project-internal modules."""
        module = self._source_file(
            tmp_path,
            "import os\nimport cli.main\nimport requests\nfrom loaders.base_loader import BaseLoader\n",
        )
        found = parse_imports(module)
        assert "cli.main" in found
        assert "loaders.base_loader" in found
        assert "os" not in found
        assert "requests" not in found

    def test_syntax_error_returns_empty(self, tmp_path: Path) -> None:
        """A file with a syntax error yields an empty list."""
        module = self._source_file(tmp_path, "def broken(\n", name="bad.py")
        assert parse_imports(module) == []

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        """A nonexistent file yields an empty list."""
        assert parse_imports(tmp_path / "nonexistent.py") == []

    def test_deduplication(self, tmp_path: Path) -> None:
        """Duplicate imports of the same module are reported once."""
        module = self._source_file(
            tmp_path,
            "import cli.main\nimport cli.main\nfrom cli.main import main\n",
        )
        assert parse_imports(module).count("cli.main") == 1

    def test_empty_file(self, tmp_path: Path) -> None:
        """An empty file yields an empty list."""
        module = self._source_file(tmp_path, "", name="empty.py")
        assert parse_imports(module) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_flow_tree 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildFlowTree:
    """Construction of the module-import flow tree."""

    @staticmethod
    def _make_package(repo: Path, name: str, **modules: str) -> None:
        """Create a package directory with __init__.py plus named modules."""
        pkg = repo / name
        pkg.mkdir()
        (pkg / "__init__.py").write_text("", encoding="utf-8")
        for module_name, body in modules.items():
            (pkg / f"{module_name}.py").write_text(body, encoding="utf-8")

    def test_single_file_no_imports(self, tmp_path: Path) -> None:
        """A file with no imports becomes a leaf node."""
        self._make_package(tmp_path, "cli", main="def main(): pass\n")

        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert tree.source_file == "cli/main.py"
        assert tree.children == []

    def test_simple_import_chain(self, tmp_path: Path) -> None:
        """A one-hop import chain produces a single child node."""
        # cli/main.py → config/settings.py
        self._make_package(
            tmp_path, "cli", main="from config.settings import AppConfig\n"
        )
        self._make_package(
            tmp_path, "config", settings="class AppConfig: pass\n"
        )

        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert len(tree.children) == 1
        assert tree.children[0].name == "config.settings"

    def test_circular_import_no_infinite_loop(self, tmp_path: Path) -> None:
        """Circular imports must not cause unbounded recursion."""
        # a → b → a (cycle); building the tree must not raise RecursionError.
        self._make_package(
            tmp_path,
            "utils",
            a="from utils.b import func_b\n",
            b="from utils.a import func_a\n",
        )

        tree = build_flow_tree(tmp_path, "utils/a.py")
        assert tree.name == "utils.a"

    def test_entry_node_type(self, tmp_path: Path) -> None:
        """A CLI entry file is classified as an 'entry' node."""
        self._make_package(tmp_path, "cli", main="def main(): pass\n")

        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.node_type == "entry"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_orphan_modules 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindOrphanModules:
    """Detection of modules unreachable from any entry point."""

    def test_all_reachable(self, tmp_path: Path) -> None:
        """No orphans when every module is reachable."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("config/settings.py", False, 200, ".py", False),
        ]
        reachable = {"cli/main.py", "config/settings.py"}
        assert find_orphan_modules(tmp_path, entries, reachable) == []

    def test_orphan_detected(self, tmp_path: Path) -> None:
        """An unreachable module is flagged as an orphan."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("utils/orphan.py", False, 50, ".py", False),
        ]
        orphans = find_orphan_modules(tmp_path, entries, {"cli/main.py"})
        assert "utils/orphan.py" in orphans

    def test_init_files_excluded(self, tmp_path: Path) -> None:
        """__init__.py files are never counted as orphans."""
        entries = [FileEntry("cli/__init__.py", False, 0, ".py", False)]
        orphans = find_orphan_modules(tmp_path, entries, set())
        assert "cli/__init__.py" not in orphans

    def test_test_files_excluded(self, tmp_path: Path) -> None:
        """Test files are never counted as orphans."""
        entries = [FileEntry("tests/unit/test_something.py", False, 100, ".py", False)]
        assert find_orphan_modules(tmp_path, entries, set()) == []

    def test_audit_scripts_excluded(self, tmp_path: Path) -> None:
        """The audit scripts themselves are never counted as orphans."""
        entries = [FileEntry("scripts/audit/scanner.py", False, 100, ".py", False)]
        assert find_orphan_modules(tmp_path, entries, set()) == []

    def test_directories_excluded(self, tmp_path: Path) -> None:
        """Directory entries never appear in the orphan list."""
        entries = [FileEntry("cli", True, 0, "", False)]
        assert find_orphan_modules(tmp_path, entries, set()) == []

    def test_sorted_output(self, tmp_path: Path) -> None:
        """The orphan list is sorted by path."""
        entries = [
            FileEntry("utils/z.py", False, 50, ".py", False),
            FileEntry("utils/a.py", False, 50, ".py", False),
            FileEntry("cli/orphan.py", False, 50, ".py", False),
        ]
        orphans = find_orphan_modules(tmp_path, entries, set())
        assert orphans == sorted(orphans)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_flow_report 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRenderFlowReport:
    """Rendering of the flow-tree Markdown report."""

    @staticmethod
    def _entry_node(children: list | None = None) -> FlowNode:
        """Build the standard cli.main entry node used across these tests."""
        return FlowNode("cli.main", "cli/main.py", "entry", children or [])

    def test_header_contains_timestamp_and_path(self) -> None:
        """The header shows generation time and the repository path."""
        report = render_flow_report([self._entry_node()], [], "/repo")
        assert "生成时间:" in report
        assert "`/repo`" in report

    def test_contains_mermaid_block(self) -> None:
        """A Mermaid code block with a top-down graph is emitted."""
        report = render_flow_report([self._entry_node()], [], "/repo")
        assert "```mermaid" in report
        assert "graph TD" in report

    def test_contains_indented_text(self) -> None:
        """The indented-text flow tree includes parent and child names."""
        child = FlowNode("config.settings", "config/settings.py", "module", [])
        report = render_flow_report([self._entry_node([child])], [], "/repo")
        assert "`cli.main`" in report
        assert "`config.settings`" in report

    def test_orphan_section(self) -> None:
        """Orphan modules are listed in the report."""
        report = render_flow_report(
            [self._entry_node()], ["utils/orphan.py", "models/unused.py"], "/repo"
        )
        assert "孤立模块" in report
        assert "`utils/orphan.py`" in report
        assert "`models/unused.py`" in report

    def test_no_orphans_message(self) -> None:
        """With no orphans, the 'none found' note is shown."""
        report = render_flow_report([self._entry_node()], [], "/repo")
        assert "未发现孤立模块" in report

    def test_statistics_summary(self) -> None:
        """The summary covers entries, tasks, loaders, and orphans."""
        report = render_flow_report([self._entry_node()], ["a.py"], "/repo")
        for label in ("统计摘要", "入口点", "任务", "加载器", "孤立模块"):
            assert label in report

    def test_task_type_annotation(self) -> None:
        """Task modules carry a layer annotation (e.g. ODS)."""
        task_node = FlowNode("tasks.ods_member", "tasks/ods_member.py", "module", [])
        report = render_flow_report([self._entry_node([task_node])], [], "/repo")
        assert "ODS" in report

    def test_loader_type_annotation(self) -> None:
        """Loader modules carry a type annotation (dimension/SCD2)."""
        loader_node = FlowNode(
            "loaders.dimensions.member", "loaders/dimensions/member.py", "module", []
        )
        report = render_flow_report([self._entry_node([loader_node])], [], "/repo")
        assert "维度" in report or "SCD2" in report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# discover_entry_points 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDiscoverEntryPoints:
    """Discovery of CLI / GUI / batch-file / ops-script entry points."""

    def test_cli_entry(self, tmp_path: Path) -> None:
        """cli/main.py is recognized as a CLI entry."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")

        cli_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "CLI"
        ]
        assert len(cli_entries) == 1
        assert cli_entries[0]["file"] == "cli/main.py"

    def test_gui_entry(self, tmp_path: Path) -> None:
        """gui/main.py is recognized as a GUI entry."""
        gui_dir = tmp_path / "gui"
        gui_dir.mkdir()
        (gui_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")

        gui_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "GUI"
        ]
        assert len(gui_entries) == 1

    def test_bat_entry(self, tmp_path: Path) -> None:
        """A .bat launcher is recognized and its python target described."""
        (tmp_path / "run_etl.bat").write_text(
            "@echo off\npython -m cli.main %*\n", encoding="utf-8"
        )

        bat_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "批处理"
        ]
        assert len(bat_entries) == 1
        assert "cli.main" in bat_entries[0]["description"]

    def test_script_entry(self, tmp_path: Path) -> None:
        """A scripts/*.py file with a __main__ guard is an ops-script entry."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        (scripts_dir / "rebuild_db.py").write_text(
            'if __name__ == "__main__": pass\n', encoding="utf-8"
        )

        script_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "运维脚本"
        ]
        assert len(script_entries) == 1
        assert script_entries[0]["file"] == "scripts/rebuild_db.py"

    def test_init_py_excluded_from_scripts(self, tmp_path: Path) -> None:
        """scripts/__init__.py must not be treated as an entry point."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")

        script_entries = [
            e for e in discover_entry_points(tmp_path) if e["type"] == "运维脚本"
        ]
        assert all(e["file"] != "scripts/__init__.py" for e in script_entries)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify_task_type / classify_loader_type 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClassifyTypes:
    """Tests distinguishing task types and loader types."""

    def test_ods_task(self) -> None:
        """tasks/ods_* files classify as ODS tasks."""
        assert "ODS" in classify_task_type("tasks/ods_member.py")

    def test_dwd_task(self) -> None:
        """tasks/dwd_* files classify as DWD tasks."""
        assert "DWD" in classify_task_type("tasks/dwd_load.py")

    def test_dws_task(self) -> None:
        """Files under tasks/dws/ classify as DWS tasks."""
        assert "DWS" in classify_task_type("tasks/dws/assistant_daily.py")

    def test_verification_task(self) -> None:
        """Files under tasks/verification/ classify as verification tasks."""
        assert "校验" in classify_task_type("tasks/verification/balance_check.py")

    def test_schema_init_task(self) -> None:
        """tasks/init_*_schema files classify as schema-initialization tasks."""
        assert "Schema" in classify_task_type("tasks/init_ods_schema.py")

    def test_dimension_loader(self) -> None:
        """Files under loaders/dimensions/ classify as dimension/SCD2 loaders."""
        result = classify_loader_type("loaders/dimensions/member.py")
        assert "维度" in result or "SCD2" in result

    def test_fact_loader(self) -> None:
        """Files under loaders/facts/ classify as fact loaders."""
        assert "事实" in classify_loader_type("loaders/facts/order.py")

    def test_ods_loader(self) -> None:
        """Files under loaders/ods/ classify as ODS loaders."""
        assert "ODS" in classify_loader_type("loaders/ods/generic.py")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _path_to_module_name 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPathToModuleName:
    """Tests for path-to-module-name conversion."""

    def test_simple_file(self) -> None:
        """A plain .py path maps to its dotted module name."""
        assert _path_to_module_name("cli/main.py") == "cli.main"

    def test_init_file(self) -> None:
        """An __init__.py path maps to the package name itself."""
        assert _path_to_module_name("cli/__init__.py") == "cli"

    def test_nested_path(self) -> None:
        """Nested directories become nested dotted components."""
        assert _path_to_module_name("tasks/dws/assistant.py") == "tasks.dws.assistant"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_bat_python_target 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseBatPythonTarget:
    """Tests for parsing the Python command out of a batch file."""

    def test_module_invocation(self, tmp_path: Path) -> None:
        """A `python -m pkg.mod` line yields the dotted module name."""
        bat = tmp_path / "run.bat"
        bat.write_text("@echo off\npython -m cli.main %*\n", encoding="utf-8")
        assert _parse_bat_python_target(bat) == "cli.main"

    def test_no_python_command(self, tmp_path: Path) -> None:
        """A batch file with no python invocation yields None."""
        bat = tmp_path / "run.bat"
        bat.write_text("@echo off\necho hello\n", encoding="utf-8")
        assert _parse_bat_python_target(bat) is None

    def test_nonexistent_file(self, tmp_path: Path) -> None:
        """A missing batch file yields None rather than raising."""
        assert _parse_bat_python_target(tmp_path / "missing.bat") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 9 & 10(hypothesis)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import string
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Helper: project package names (kept in sync with _PROJECT_PACKAGES in
# flow_analyzer)
# ---------------------------------------------------------------------------

_PROJECT_PACKAGES_LIST = [
    "cli", "config", "api", "database", "tasks", "loaders",
    "scd", "orchestration", "quality", "models", "utils",
    "gui", "scripts",
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 9: 流程树节点 source_file 有效性
|
||||
# Feature: repo-audit, Property 9: 流程树节点 source_file 有效性
|
||||
# Validates: Requirements 2.7
|
||||
#
|
||||
# 策略:在临时目录中随机生成 1~5 个项目内部模块文件,
|
||||
# 其中一个作为入口,其他文件通过 import 语句相互引用。
|
||||
# 构建流程树后,遍历所有节点验证 source_file 非空且文件存在。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_all_nodes(node: FlowNode) -> list[FlowNode]:
|
||||
"""递归收集流程树中所有节点。"""
|
||||
result = [node]
|
||||
for child in node.children:
|
||||
result.extend(_collect_all_nodes(child))
|
||||
return result
|
||||
|
||||
|
||||
# Generate valid Python identifiers to use as module file names.
_module_name_st = st.from_regex(r"[a-z][a-z0-9_]{0,8}", fullmatch=True).filter(
    lambda s: s not in {"__init__", ""}
)
|
||||
|
||||
|
||||
@st.composite
def project_layout(draw):
    """Generate a random project layout: a package name, module file names,
    and the import relationships between the modules.

    Returns (package, module_names, imports_map):
    - package: project package name (e.g. "cli")
    - module_names: module file names (without the .py suffix); the first
      one is used as the entry point
    - imports_map: dict[str, list[str]] mapping each module to the other
      modules it imports
    """
    package = draw(st.sampled_from(_PROJECT_PACKAGES_LIST))
    n_modules = draw(st.integers(min_value=1, max_value=5))
    module_names = draw(
        st.lists(
            _module_name_st,
            min_size=n_modules,
            max_size=n_modules,
            unique=True,
        )
    )
    # Defensive guard: at least one module is required as the entry point.
    assume(len(module_names) >= 1)

    # For each module, pick a random subset of the other modules to import.
    # (The original used `enumerate` here but never read the index.)
    imports_map: dict[str, list[str]] = {}
    for mod in module_names:
        # A module may only import other modules from the generated list.
        others = [m for m in module_names if m != mod]
        if others:
            imported = draw(
                st.lists(st.sampled_from(others), max_size=len(others), unique=True)
            )
        else:
            imported = []
        imports_map[mod] = imported

    return package, module_names, imports_map
|
||||
|
||||
|
||||
@given(layout=project_layout())
@settings(max_examples=100)
def test_property9_flow_tree_source_file_validity(layout, tmp_path_factory):
    """Property 9: every node's source_file in the flow tree is a non-empty
    string and refers to a file that actually exists in the repository.

    **Feature: repo-audit, Property 9: 流程树节点 source_file 有效性**
    **Validates: Requirements 2.7**
    """
    package, module_names, imports_map = layout
    tmp_path = tmp_path_factory.mktemp("prop9")

    # Create the package directory and its __init__.py.
    pkg_dir = tmp_path / package
    pkg_dir.mkdir(parents=True, exist_ok=True)
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")

    # Write each module file with its generated import statements.
    for mod in module_names:
        lines = []
        for imp in imports_map[mod]:
            lines.append(f"from {package}.{imp} import *")
        lines.append("")  # ensure the file is non-empty
        (pkg_dir / f"{mod}.py").write_text("\n".join(lines), encoding="utf-8")

    # Build the flow tree with the first module as the entry point.
    entry_rel = f"{package}/{module_names[0]}.py"
    tree = build_flow_tree(tmp_path, entry_rel)

    # Walk every node and validate its source_file.
    all_nodes = _collect_all_nodes(tree)
    for node in all_nodes:
        # source_file must be a non-empty string
        assert isinstance(node.source_file, str), (
            f"source_file 应为字符串,实际为 {type(node.source_file)}"
        )
        assert node.source_file != "", "source_file 不应为空字符串"

        # the referenced file must exist inside the repository
        full_path = tmp_path / node.source_file
        assert full_path.exists(), (
            f"source_file '{node.source_file}' 对应的文件不存在: {full_path}"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 10: 孤立模块检测正确性
|
||||
# Feature: repo-audit, Property 10: 孤立模块检测正确性
|
||||
# Validates: Requirements 2.8
|
||||
#
|
||||
# 策略:生成随机的 FileEntry 列表(模拟项目中的 .py 文件),
|
||||
# 生成随机的 reachable 集合(是 FileEntry 路径的子集),
|
||||
# 调用 find_orphan_modules 验证:
|
||||
# 1. 返回的每个孤立模块都不在 reachable 集合中
|
||||
# 2. reachable 集合中的每个模块都不在孤立列表中
|
||||
#
|
||||
# 注意:find_orphan_modules 会排除 __init__.py、tests/、scripts/audit/ 下的文件,
|
||||
# 以及不属于 _PROJECT_PACKAGES 的子目录文件。生成器需要考虑这些排除规则。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Generate .py file paths that belong to project packages (excluding paths
# ignored by find_orphan_modules).
_eligible_packages = [
    p for p in _PROJECT_PACKAGES_LIST
    if p not in ("scripts",)  # only scripts/audit/ is excluded, but drop scripts entirely for simplicity
]
|
||||
|
||||
|
||||
@st.composite
def orphan_test_data(draw):
    """Generate (file_entries, reachable_set) for testing find_orphan_modules.

    Only "eligible" file entries are generated (inside a project package, not
    __init__.py, not under tests/ or scripts/audit/), so the mutual exclusion
    between reachable and orphan modules can be verified precisely.
    """
    # Generate 1~10 eligible .py file paths.
    n_files = draw(st.integers(min_value=1, max_value=10))
    paths: list[str] = []
    for _ in range(n_files):
        pkg = draw(st.sampled_from(_eligible_packages))
        fname = draw(_module_name_st)
        path = f"{pkg}/{fname}.py"
        paths.append(path)

    # Deduplicate while preserving order.
    paths = list(dict.fromkeys(paths))
    assume(len(paths) >= 1)

    # Build the FileEntry list.
    entries = [
        FileEntry(rel_path=p, is_dir=False, size_bytes=100, extension=".py", is_empty_dir=False)
        for p in paths
    ]

    # Pick a random subset of the paths as the reachable set.
    reachable = set(draw(
        st.lists(st.sampled_from(paths), max_size=len(paths), unique=True)
    ))

    return entries, reachable
|
||||
|
||||
|
||||
@given(data=orphan_test_data())
@settings(max_examples=100)
def test_property10_orphan_module_detection(data, tmp_path_factory):
    """Property 10: orphan and reachable modules are mutually exclusive — no
    orphan appears in the reachable set and no reachable module appears in
    the orphan list.

    **Feature: repo-audit, Property 10: 孤立模块检测正确性**
    **Validates: Requirements 2.8**
    """
    entries, reachable = data
    tmp_path = tmp_path_factory.mktemp("prop10")

    orphans = find_orphan_modules(tmp_path, entries, reachable)

    orphan_set = set(orphans)

    # Check 1: no orphan module may appear in the reachable set.
    overlap = orphan_set & reachable
    assert overlap == set(), (
        f"孤立模块与可达集合存在交集: {overlap}"
    )

    # Check 2: no reachable module may appear in the orphan list.
    for r in reachable:
        assert r not in orphan_set, (
            f"可达模块 '{r}' 不应出现在孤立列表中"
        )

    # Check 3: the orphan list must be sorted.
    assert orphans == sorted(orphans), "孤立模块列表应按路径排序"
|
||||
309
apps/etl/pipelines/feiqiu/tests/unit/test_audit_inventory.py
Normal file
309
apps/etl/pipelines/feiqiu/tests/unit/test_audit_inventory.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
属性测试 — classify 完整性
|
||||
|
||||
Feature: repo-audit, Property 1: classify 完整性
|
||||
Validates: Requirements 1.2, 1.3
|
||||
|
||||
对于任意 FileEntry,classify 函数返回的 InventoryItem 的 category 字段
|
||||
应属于 Category 枚举,disposition 字段应属于 Disposition 枚举,
|
||||
且 description 字段为非空字符串。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
|
||||
from scripts.audit.inventory_analyzer import classify
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生成器策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Common file extensions (the empty string models a file with no extension).
_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])

# Path segments: alphanumerics plus common special characters.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."

_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
)

# Relative paths 1~4 directory levels deep.
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=4,
).map(lambda parts: "/".join(parts))
|
||||
|
||||
|
||||
def _file_entry_strategy() -> st.SearchStrategy[FileEntry]:
    """Hypothesis strategy producing random FileEntry values.

    Covers a range of extensions, directory depths, sizes, and boolean-flag
    combinations.
    """
    return st.builds(
        FileEntry,
        rel_path=_rel_path,
        is_dir=st.booleans(),
        size_bytes=st.integers(min_value=0, max_value=10_000_000),
        extension=_EXTENSIONS,
        is_empty_dir=st.booleans(),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 1: classify 完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(entry=_file_entry_strategy())
@settings(max_examples=100)
def test_classify_completeness(entry: FileEntry) -> None:
    """Property 1: classify 完整性

    Feature: repo-audit, Property 1: classify 完整性
    Validates: Requirements 1.2, 1.3

    For any FileEntry, the InventoryItem returned by classify must satisfy:
    - category is a member of the Category enum
    - disposition is a member of the Disposition enum
    - description is a non-empty string
    """
    result = classify(entry)

    # The return type is correct.
    assert isinstance(result, InventoryItem), (
        f"classify 应返回 InventoryItem,实际返回 {type(result)}"
    )

    # category belongs to the Category enum.
    assert isinstance(result.category, Category), (
        f"category 应为 Category 枚举成员,实际为 {result.category!r}"
    )

    # disposition belongs to the Disposition enum.
    assert isinstance(result.disposition, Disposition), (
        f"disposition 应为 Disposition 枚举成员,实际为 {result.disposition!r}"
    )

    # description is a non-empty string.
    assert isinstance(result.description, str) and len(result.description) > 0, (
        f"description 应为非空字符串,实际为 {result.description!r}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Helper: high-priority directory prefixes (excluded in the lower-priority
# property tests)
# ---------------------------------------------------------------------------

# NOTE(review): this constant is not referenced by the tests visible here —
# _safe_rel_path relies on _SAFE_TOP_DIRS instead; confirm it is still needed.
_HIGH_PRIORITY_PREFIXES = ("tmp/", "logs/", "export/")

# Safe top-level directory names (do not trigger the high-priority rules).
_SAFE_TOP_DIRS = st.sampled_from([
    "src", "lib", "data", "misc", "vendor", "tools", "archive",
    "assets", "resources", "contrib", "extras",
])

# Extensions other than .lnk/.rar.
_SAFE_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])
|
||||
|
||||
|
||||
def _safe_rel_path() -> st.SearchStrategy[str]:
    """Generate relative paths that do not start with a high-priority directory."""
    return st.builds(
        lambda top, rest: f"{top}/{rest}" if rest else top,
        top=_SAFE_TOP_DIRS,
        rest=st.lists(_path_segment, min_size=0, max_size=3).map(
            lambda parts: "/".join(parts) if parts else ""
        ),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 3: 空目录标记为候选删除
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_empty_dir_candidate_delete(data: st.DataObject) -> None:
    """Property 3: 空目录标记为候选删除

    Feature: repo-audit, Property 3: 空目录标记为候选删除
    Validates: Requirements 1.5

    For any FileEntry with is_empty_dir=True (excluding paths starting with
    tmp/, logs/, reports/ or export/ and the .lnk/.rar extensions), classify
    must return Disposition.CANDIDATE_DELETE.
    """
    rel_path = data.draw(_safe_rel_path())
    ext = data.draw(_SAFE_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=True,
        size_bytes=0,
        extension=ext,
        is_empty_dir=True,
    )

    result = classify(entry)

    assert result.disposition == Disposition.CANDIDATE_DELETE, (
        f"空目录 '{entry.rel_path}' 应标记为候选删除,"
        f"实际为 {result.disposition.value}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4: .lnk/.rar 文件标记为候选删除
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_lnk_rar_candidate_delete(data: st.DataObject) -> None:
    """Property 4: .lnk/.rar 文件标记为候选删除

    Feature: repo-audit, Property 4: .lnk/.rar 文件标记为候选删除
    Validates: Requirements 1.6

    For any FileEntry whose extension is .lnk or .rar (path not starting with
    tmp/, logs/, reports/ or export/, and is_empty_dir=False), classify must
    return Disposition.CANDIDATE_DELETE.
    """
    rel_path = data.draw(_safe_rel_path())
    ext = data.draw(st.sampled_from([".lnk", ".rar"]))
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=False,
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=False,
    )

    result = classify(entry)

    assert result.disposition == Disposition.CANDIDATE_DELETE, (
        f"文件 '{entry.rel_path}' (ext={ext}) 应标记为候选删除,"
        f"实际为 {result.disposition.value}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 5: tmp/ 下文件处置范围
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Extensions seen under tmp/, including throwaway ones such as .tmp/.bak.
_TMP_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js", ".tmp", ".bak",
])


def _tmp_rel_path() -> st.SearchStrategy[str]:
    """Generate relative paths that start with tmp/."""
    return st.builds(
        lambda rest: f"tmp/{rest}",
        rest=st.lists(_path_segment, min_size=1, max_size=3).map(
            lambda parts: "/".join(parts)
        ),
    )
|
||||
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_tmp_disposition_range(data: st.DataObject) -> None:
    """Property 5: tmp/ 下文件处置范围

    Feature: repo-audit, Property 5: tmp/ 下文件处置范围
    Validates: Requirements 1.7

    For any FileEntry whose rel_path starts with tmp/, classify must return a
    disposition of either CANDIDATE_DELETE or CANDIDATE_ARCHIVE.
    """
    rel_path = data.draw(_tmp_rel_path())
    ext = data.draw(_TMP_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=data.draw(st.booleans()),
    )

    result = classify(entry)

    allowed = {Disposition.CANDIDATE_DELETE, Disposition.CANDIDATE_ARCHIVE}
    assert result.disposition in allowed, (
        f"tmp/ 下文件 '{entry.rel_path}' 的处置应为候选删除或候选归档,"
        f"实际为 {result.disposition.value}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 6: 运行时产出目录标记为候选归档
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_RUNTIME_DIRS = st.sampled_from(["logs", "export"])

# File basenames excluding __init__.py.
_NON_INIT_BASENAME = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
).filter(lambda s: s != "__init__.py")


def _runtime_output_rel_path() -> st.SearchStrategy[str]:
    """Generate relative paths starting with logs/ or export/ whose basename
    is not __init__.py.

    (Requirement 1.8 covers only logs/ and export/ — not reports/.)
    """
    return st.builds(
        lambda top, mid, name: (
            f"{top}/{'/'.join(mid)}/{name}" if mid else f"{top}/{name}"
        ),
        top=_RUNTIME_DIRS,
        mid=st.lists(_path_segment, min_size=0, max_size=2),
        name=_NON_INIT_BASENAME,
    )
|
||||
|
||||
|
||||
@given(data=st.data())
@settings(max_examples=100)
def test_runtime_output_candidate_archive(data: st.DataObject) -> None:
    """Property 6: 运行时产出目录标记为候选归档

    Feature: repo-audit, Property 6: 运行时产出目录标记为候选归档
    Validates: Requirements 1.8

    For any FileEntry whose rel_path starts with logs/ or export/ and whose
    basename is not __init__.py, classify must return CANDIDATE_ARCHIVE.
    Requirement 1.8 covers only logs/ and export/ (not reports/).
    """
    rel_path = data.draw(_runtime_output_rel_path())
    ext = data.draw(_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=data.draw(st.booleans()),
    )

    result = classify(entry)

    assert result.disposition == Disposition.CANDIDATE_ARCHIVE, (
        f"运行时产出 '{entry.rel_path}' 应标记为候选归档,"
        f"实际为 {result.disposition.value}"
    )
|
||||
@@ -0,0 +1,165 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
属性测试 — 清单渲染完整性与分类分组
|
||||
|
||||
Feature: repo-audit
|
||||
- Property 2: 清单渲染完整性
|
||||
- Property 8: 清单按分类分组
|
||||
|
||||
Validates: Requirements 1.4, 1.10
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit import Category, Disposition, InventoryItem
|
||||
from scripts.audit.inventory_analyzer import render_inventory_report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生成器策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATH_CHARS = string.ascii_letters + string.digits + "_-."

_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=15,
)

# Random relative paths (1~3 levels deep).
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=3,
).map(lambda parts: "/".join(parts))

# Random non-empty descriptions (no pipe or newline characters, so they
# cannot break Markdown table parsing).
_description = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=40,
)
|
||||
|
||||
|
||||
def _inventory_item_strategy() -> st.SearchStrategy[InventoryItem]:
    """Hypothesis strategy producing random InventoryItem values."""
    return st.builds(
        InventoryItem,
        rel_path=_rel_path,
        category=st.sampled_from(list(Category)),
        disposition=st.sampled_from(list(Disposition)),
        description=_description,
    )


# Lists of 0~20 InventoryItem values.
_inventory_list = st.lists(
    _inventory_item_strategy(),
    min_size=0,
    max_size=20,
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 2: 清单渲染完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_completeness(items: list[InventoryItem]) -> None:
    """Property 2: 清单渲染完整性

    Feature: repo-audit, Property 2: 清单渲染完整性
    Validates: Requirements 1.4

    For any list of InventoryItem, the Markdown produced by
    render_inventory_report must contain each item's rel_path,
    category.value, disposition.value and description.
    """
    report = render_inventory_report(items, "/tmp/test-repo")

    for item in items:
        # rel_path appears in a table row
        assert item.rel_path in report, (
            f"rel_path '{item.rel_path}' 未出现在报告中"
        )
        # category.value appears in a group heading
        assert item.category.value in report, (
            f"category '{item.category.value}' 未出现在报告中"
        )
        # disposition.value appears in a table row
        assert item.disposition.value in report, (
            f"disposition '{item.disposition.value}' 未出现在报告中"
        )
        # description appears in a table row
        assert item.description in report, (
            f"description '{item.description}' 未出现在报告中"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 8: 清单按分类分组
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_grouped_by_category(items: list[InventoryItem]) -> None:
    """Property 8: 清单按分类分组

    Feature: repo-audit, Property 8: 清单按分类分组
    Validates: Requirements 1.10

    For any list of InventoryItem, entries of the same Category must appear
    contiguously in the rendered Markdown (never interleaved with entries of
    another Category).
    """
    report = render_inventory_report(items, "/tmp/test-repo")

    if not items:
        return  # nothing to verify for an empty list

    # Recover the category order of the rendered entries, line by line.
    # Table row format:     | `{rel_path}` | {disposition} | {description} |
    # Group heading format: ## {category.value}
    lines = report.split("\n")

    # Record, in order of appearance, the category of each table data row.
    categories_in_order: list[Category] = []
    current_category: Category | None = None

    # Map category.value -> Category.
    value_to_cat = {c.value: c for c in Category}

    for line in lines:
        stripped = line.strip()
        # Detect a group heading "## {category.value}".
        if stripped.startswith("## ") and stripped[3:] in value_to_cat:
            current_category = value_to_cat[stripped[3:]]
            continue
        # Detect table data rows (skipping the header and separator rows).
        if (
            current_category is not None
            and stripped.startswith("| `")
            and not stripped.startswith("| 相对路径")
            and not stripped.startswith("|---")
        ):
            categories_in_order.append(current_category)

    # Verify that entries of the same Category appear contiguously: once a
    # category's run ends, it must never reappear later.
    seen: set[Category] = set()
    prev: Category | None = None
    for cat in categories_in_order:
        if cat != prev:
            assert cat not in seen, (
                f"Category '{cat.value}' 的条目不连续——"
                f"在其他分类条目之后再次出现"
            )
            seen.add(cat)
        prev = cat
|
||||
@@ -0,0 +1,485 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
属性测试 — 报告输出属性
|
||||
|
||||
Feature: repo-audit
|
||||
- Property 13: 统计摘要一致性
|
||||
- Property 14: 报告头部元信息
|
||||
- Property 15: 写操作仅限 docs/audit/
|
||||
|
||||
Validates: Requirements 4.2, 4.5, 4.6, 4.7, 5.2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from pathlib import Path
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from scripts.audit import (
|
||||
AlignmentIssue,
|
||||
Category,
|
||||
Disposition,
|
||||
DocMapping,
|
||||
FlowNode,
|
||||
InventoryItem,
|
||||
)
|
||||
from scripts.audit.inventory_analyzer import render_inventory_report
|
||||
from scripts.audit.flow_analyzer import render_flow_report
|
||||
from scripts.audit.doc_alignment_analyzer import render_alignment_report
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 共享生成器策略
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATH_CHARS = string.ascii_letters + string.digits + "_-."

_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=12,
)

# Random relative paths (1~3 levels deep).
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=3,
).map(lambda parts: "/".join(parts))

# Text safe for Markdown tables (no pipe or newline characters).
_safe_text = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=30,
)

# Absolute-looking repository-root path strings.
_repo_root_str = st.text(
    alphabet=string.ascii_letters + string.digits + "/_-.",
    min_size=3,
    max_size=40,
).map(lambda s: "/" + s.lstrip("/"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InventoryItem 生成器
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _inventory_item_st() -> st.SearchStrategy[InventoryItem]:
    """Hypothesis strategy producing random InventoryItem values."""
    return st.builds(
        InventoryItem,
        rel_path=_rel_path,
        category=st.sampled_from(list(Category)),
        disposition=st.sampled_from(list(Disposition)),
        description=_safe_text,
    )


# Lists of 0~20 InventoryItem values.
_inventory_list = st.lists(_inventory_item_st(), min_size=0, max_size=20)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FlowNode 生成器(限制深度和宽度)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _flow_node_st(max_depth: int = 2) -> st.SearchStrategy[FlowNode]:
    """Generate random FlowNode trees, depth-limited to avoid blow-up."""
    if max_depth <= 0:
        # Depth exhausted: generate leaf nodes only (no children).
        return st.builds(
            FlowNode,
            name=_path_segment,
            source_file=_rel_path,
            node_type=st.sampled_from(["entry", "module", "class", "function"]),
            children=st.just([]),
        )
    return st.builds(
        FlowNode,
        name=_path_segment,
        source_file=_rel_path,
        node_type=st.sampled_from(["entry", "module", "class", "function"]),
        children=st.lists(
            _flow_node_st(max_depth - 1),
            min_size=0,
            max_size=3,
        ),
    )


# Lists of flow trees and orphan-module paths used by the report tests.
_flow_tree_list = st.lists(_flow_node_st(), min_size=0, max_size=5)
_orphan_list = st.lists(_rel_path, min_size=0, max_size=10)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocMapping / AlignmentIssue 生成器
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Possible alignment-issue types.
_issue_type_st = st.sampled_from(["stale", "conflict", "missing"])


def _alignment_issue_st() -> st.SearchStrategy[AlignmentIssue]:
    """Hypothesis strategy producing random AlignmentIssue values."""
    return st.builds(
        AlignmentIssue,
        doc_path=_rel_path,
        issue_type=_issue_type_st,
        description=_safe_text,
        related_code=_rel_path,
    )


def _doc_mapping_st() -> st.SearchStrategy[DocMapping]:
    """Hypothesis strategy producing random DocMapping values."""
    return st.builds(
        DocMapping,
        doc_path=_rel_path,
        doc_topic=_safe_text,
        related_code=st.lists(_rel_path, min_size=0, max_size=5),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )


# Lists of mappings/issues used by the alignment-report tests.
_mapping_list = st.lists(_doc_mapping_st(), min_size=0, max_size=15)
_issue_list = st.lists(_alignment_issue_st(), min_size=0, max_size=15)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 13: 统计摘要一致性
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty13SummaryConsistency:
    """Property 13: summary-statistics consistency.

    Feature: repo-audit, Property 13: summary statistics consistency
    Validates: Requirements 4.5, 4.6, 4.7

    For any report, the per-category/per-label counts in the summary
    section must sum to the total length of the corresponding item list.
    """

    # --- 13a: category counts in render_inventory_report sum to list length ---

    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_category_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13: summary statistics consistency
        Validates: Requirements 4.5

        The per-purpose-category counts in render_inventory_report's summary
        must sum to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")

        # Locate the "按用途分类" (by purpose) table and sum its count column.
        cat_sum = _extract_summary_total(report, "按用途分类")
        assert cat_sum == len(items), (
            f"分类计数之和 {cat_sum} != 条目总数 {len(items)}"
        )

    # --- 13b: disposition-label counts sum to list length ---

    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_disposition_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13: summary statistics consistency
        Validates: Requirements 4.5

        The per-disposition-label counts in render_inventory_report's summary
        must sum to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")

        disp_sum = _extract_summary_total(report, "按处置标签")
        assert disp_sum == len(items), (
            f"处置标签计数之和 {disp_sum} != 条目总数 {len(items)}"
        )

    # --- 13c: orphan-module count in render_flow_report == len(orphans) ---

    @given(trees=_flow_tree_list, orphans=_orphan_list)
    @settings(max_examples=100)
    def test_flow_orphan_count_matches(
        self, trees: list[FlowNode], orphans: list[str]
    ) -> None:
        """Feature: repo-audit, Property 13: summary statistics consistency
        Validates: Requirements 4.6

        The orphan-module count in render_flow_report's summary must equal
        the length of the orphans list.
        """
        report = render_flow_report(trees, orphans, "/tmp/repo")

        # Pull the "孤立模块" (orphan modules) figure out of the summary table.
        orphan_count = _extract_flow_stat(report, "孤立模块")
        assert orphan_count == len(orphans), (
            f"报告中孤立模块数 {orphan_count} != orphans 列表长度 {len(orphans)}"
        )

    # --- 13d: issue-type counts in render_alignment_report are consistent ---

    @given(mappings=_mapping_list, issues=_issue_list)
    @settings(max_examples=100)
    def test_alignment_issue_counts_match(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue]
    ) -> None:
        """Feature: repo-audit, Property 13: summary statistics consistency
        Validates: Requirements 4.7

        The stale/conflict/missing counts in render_alignment_report's
        summary must match the actual number of issues of each type.
        """
        report = render_alignment_report(mappings, issues, "/tmp/repo")

        expected_stale = sum(1 for i in issues if i.issue_type == "stale")
        expected_conflict = sum(1 for i in issues if i.issue_type == "conflict")
        expected_missing = sum(1 for i in issues if i.issue_type == "missing")

        actual_stale = _extract_alignment_stat(report, "过期点数量")
        actual_conflict = _extract_alignment_stat(report, "冲突点数量")
        actual_missing = _extract_alignment_stat(report, "缺失点数量")

        assert actual_stale == expected_stale, (
            f"过期点: 报告 {actual_stale} != 实际 {expected_stale}"
        )
        assert actual_conflict == expected_conflict, (
            f"冲突点: 报告 {actual_conflict} != 实际 {expected_conflict}"
        )
        assert actual_missing == expected_missing, (
            f"缺失点: 报告 {actual_missing} != 实际 {expected_missing}"
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 14: 报告头部元信息
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty14ReportHeader:
    """Property 14: report header metadata.

    Feature: repo-audit, Property 14: report header metadata
    Validates: Requirements 4.2

    Any report's header must contain an ISO-format timestamp string and the
    repository root path string.
    """

    # ISO-8601 UTC timestamp, e.g. "2025-01-01T00:00:00Z".
    _ISO_TS_RE = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")

    @given(items=_inventory_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_inventory_report_header(
        self, items: list[InventoryItem], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: report header metadata
        Validates: Requirements 4.2

        render_inventory_report's header must contain the ISO timestamp and
        the repository path.
        """
        report = render_inventory_report(items, repo_root)
        # Only the first 500 characters are treated as the header region.
        header = report[:500]

        assert self._ISO_TS_RE.search(header), (
            "inventory 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"inventory 报告头部缺少仓库路径 '{repo_root}'"
        )

    @given(trees=_flow_tree_list, orphans=_orphan_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_flow_report_header(
        self, trees: list[FlowNode], orphans: list[str], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: report header metadata
        Validates: Requirements 4.2

        render_flow_report's header must contain the ISO timestamp and the
        repository path.
        """
        report = render_flow_report(trees, orphans, repo_root)
        header = report[:500]

        assert self._ISO_TS_RE.search(header), (
            "flow 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"flow 报告头部缺少仓库路径 '{repo_root}'"
        )

    @given(mappings=_mapping_list, issues=_issue_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_alignment_report_header(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: report header metadata
        Validates: Requirements 4.2

        render_alignment_report's header must contain the ISO timestamp and
        the repository path.
        """
        report = render_alignment_report(mappings, issues, repo_root)
        header = report[:500]

        assert self._ISO_TS_RE.search(header), (
            "alignment 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"alignment 报告头部缺少仓库路径 '{repo_root}'"
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 15: 写操作仅限 docs/audit/
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty15WritesOnlyDocsAudit:
    """Property 15: writes restricted to docs/audit/.

    Feature: repo-audit, Property 15: writes restricted to docs/audit/
    Validates: Requirements 5.2

    During any audit run, every file write must target a path prefixed with
    docs/audit/.  Uses few iterations because it touches the real filesystem.
    """

    @staticmethod
    def _make_minimal_repo(base: Path, variant: int) -> Path:
        """Build a minimal repo layout; *variant* varies the extra content."""
        repo = base / f"repo_{variant}"
        repo.mkdir()

        # Required CLI entry point.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )

        # config package.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")

        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()

        # Add different extra files depending on the variant.
        if variant % 3 == 0:
            (repo / "README.md").write_text("# 项目\n", encoding="utf-8")
        if variant % 3 == 1:
            scripts_dir = repo / "scripts"
            scripts_dir.mkdir()
            (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        if variant % 3 == 2:
            (docs_dir / "notes.md").write_text("# 笔记\n", encoding="utf-8")

        return repo

    @staticmethod
    def _snapshot_files(repo: Path) -> dict[str, float]:
        """Snapshot the mtime of every file outside docs/audit/.

        NOTE(review): mtimes are recorded here, but the test below only
        checks key membership (file creation), not modification — confirm
        whether an mtime comparison was intended.
        """
        snap: dict[str, float] = {}
        for p in repo.rglob("*"):
            if p.is_file():
                rel = p.relative_to(repo).as_posix()
                if not rel.startswith("docs/audit"):
                    snap[rel] = p.stat().st_mtime
        return snap

    @given(variant=st.integers(min_value=0, max_value=9))
    @settings(max_examples=10)
    def test_writes_only_under_docs_audit(self, variant: int) -> None:
        """Feature: repo-audit, Property 15: writes restricted to docs/audit/
        Validates: Requirements 5.2

        After run_audit, no new files may exist outside docs/audit/, and
        docs/audit/ itself must contain report files.
        """
        import tempfile
        from scripts.audit.run_audit import run_audit

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)
            repo = self._make_minimal_repo(tmp_path, variant)
            before_snap = self._snapshot_files(repo)

            run_audit(repo)

            # docs/audit/ must now exist and contain at least one report.
            audit_dir = repo / "docs" / "audit"
            assert audit_dir.is_dir(), "docs/audit/ 目录未创建"
            audit_files = list(audit_dir.iterdir())
            assert len(audit_files) > 0, "docs/audit/ 下无报告文件"

            # No new files may appear outside docs/audit/.
            for p in repo.rglob("*"):
                if p.is_file():
                    rel = p.relative_to(repo).as_posix()
                    if rel.startswith("docs/audit"):
                        continue
                    assert rel in before_snap, (
                        f"docs/audit/ 外出现了新文件: {rel}"
                    )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 辅助函数 — 从报告文本中提取统计数字
|
||||
# ===========================================================================
|
||||
|
||||
def _extract_summary_total(report: str, section_name: str) -> int:
|
||||
"""从 inventory 报告的统计摘要中提取指定分区的数字之和。
|
||||
|
||||
查找 "### {section_name}" 下的 Markdown 表格,
|
||||
累加每行最后一列的数字(排除合计行)。
|
||||
"""
|
||||
lines = report.split("\n")
|
||||
in_section = False
|
||||
total = 0
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped == f"### {section_name}":
|
||||
in_section = True
|
||||
continue
|
||||
if in_section and stripped.startswith("###"):
|
||||
# 进入下一个子节
|
||||
break
|
||||
if in_section and stripped.startswith("|") and "**合计**" not in stripped:
|
||||
# 跳过表头和分隔行
|
||||
if stripped.startswith("| 用途分类") or stripped.startswith("| 处置标签"):
|
||||
continue
|
||||
if stripped.startswith("|---"):
|
||||
continue
|
||||
# 提取最后一列的数字
|
||||
cells = [c.strip() for c in stripped.split("|") if c.strip()]
|
||||
if cells:
|
||||
try:
|
||||
total += int(cells[-1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return total
|
||||
|
||||
|
||||
def _extract_flow_stat(report: str, label: str) -> int:
|
||||
"""从 flow 报告统计摘要表格中提取指定指标的数字。"""
|
||||
# 匹配 "| 孤立模块 | 5 |" 格式
|
||||
pattern = re.compile(rf"\|\s*{re.escape(label)}\s*\|\s*(\d+)\s*\|")
|
||||
m = pattern.search(report)
|
||||
return int(m.group(1)) if m else -1
|
||||
|
||||
|
||||
def _extract_alignment_stat(report: str, label: str) -> int:
|
||||
"""从 alignment 报告统计摘要中提取指定指标的数字。
|
||||
|
||||
匹配 "- 过期点数量:3" 格式。
|
||||
"""
|
||||
# 兼容全角/半角冒号
|
||||
pattern = re.compile(rf"{re.escape(label)}[::]\s*(\d+)")
|
||||
m = pattern.search(report)
|
||||
return int(m.group(1)) if m else -1
|
||||
177
apps/etl/pipelines/feiqiu/tests/unit/test_audit_run.py
Normal file
177
apps/etl/pipelines/feiqiu/tests/unit/test_audit_run.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
run_audit 主入口的单元测试。
|
||||
|
||||
验证:
|
||||
- docs/audit/ 目录自动创建
|
||||
- 三份报告文件正确生成
|
||||
- 报告头部包含时间戳和仓库路径
|
||||
- 目录创建失败时抛出 RuntimeError
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEnsureReportDir:
    """Tests for the _ensure_report_dir directory-creation logic."""

    def test_creates_dir_when_missing(self, tmp_path: Path):
        """docs/audit/repo is created and returned when absent."""
        from scripts.audit.run_audit import _ensure_report_dir

        result = _ensure_report_dir(tmp_path)
        expected = tmp_path / "docs" / "audit" / "repo"
        assert result == expected
        assert expected.is_dir()

    def test_returns_existing_dir(self, tmp_path: Path):
        """An already-existing report directory is returned unchanged."""
        from scripts.audit.run_audit import _ensure_report_dir

        audit_dir = tmp_path / "docs" / "audit" / "repo"
        audit_dir.mkdir(parents=True)
        result = _ensure_report_dir(tmp_path)
        assert result == audit_dir

    def test_raises_on_creation_failure(self, tmp_path: Path):
        """A failed mkdir is surfaced as RuntimeError."""
        from scripts.audit.run_audit import _ensure_report_dir

        # Place a regular file at the docs/audit location so mkdir fails.
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "audit").write_text("block", encoding="utf-8")

        with pytest.raises(RuntimeError, match="无法创建报告输出目录"):
            _ensure_report_dir(tmp_path)
|
||||
|
||||
|
||||
class TestInjectHeader:
    """Tests for the _inject_header fallback-injection logic."""

    def test_skips_when_header_present(self):
        """A report that already carries a header is returned untouched."""
        from scripts.audit.run_audit import _inject_header

        report = "# 标题\n\n- 生成时间: 2025-01-01T00:00:00Z\n- 仓库路径: `/repo`\n"
        result = _inject_header(report, "2025-06-01T00:00:00Z", "/other")
        # The existing header must not be replaced.
        assert result == report

    def test_injects_when_header_missing(self):
        """Timestamp and repo path are injected when the header is absent."""
        from scripts.audit.run_audit import _inject_header

        report = "# 无头部报告\n\n内容..."
        result = _inject_header(report, "2025-06-01T00:00:00Z", "/repo")
        assert "生成时间: 2025-06-01T00:00:00Z" in result
        assert "仓库路径: `/repo`" in result
|
||||
|
||||
|
||||
class TestRunAudit:
    """End-to-end tests for run_audit using a minimal repository layout."""

    def _make_minimal_repo(self, tmp_path: Path) -> Path:
        """Build the smallest repo structure run_audit can process."""
        repo = tmp_path / "repo"
        repo.mkdir()

        # Core code directory with a CLI entry point.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )

        # config package.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        (config_dir / "defaults.py").write_text("DEFAULTS = {}\n", encoding="utf-8")

        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        (docs_dir / "README.md").write_text("# 文档\n", encoding="utf-8")

        # Root-level file.
        (repo / "README.md").write_text("# 项目\n", encoding="utf-8")

        return repo

    def test_creates_three_reports(self, tmp_path: Path):
        """run_audit must produce all three report files."""
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)

        audit_dir = repo / "docs" / "audit" / "repo"
        assert (audit_dir / "file_inventory.md").is_file()
        assert (audit_dir / "flow_tree.md").is_file()
        assert (audit_dir / "doc_alignment.md").is_file()

    def test_reports_contain_timestamp(self, tmp_path: Path):
        """Every report must contain an ISO-format timestamp."""
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)

        audit_dir = repo / "docs" / "audit" / "repo"
        # ISO timestamp format, e.g. 2025-01-01T00:00:00Z.
        ts_pattern = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")

        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert ts_pattern.search(content), f"{name} 缺少时间戳"

    def test_reports_contain_repo_path(self, tmp_path: Path):
        """Every report must contain the resolved repository path."""
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)

        audit_dir = repo / "docs" / "audit" / "repo"
        repo_str = str(repo.resolve())

        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert repo_str in content, f"{name} 缺少仓库路径"

    def test_writes_only_to_docs_audit(self, tmp_path: Path):
        """All write operations must stay inside docs/audit/ (Property 15)."""
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)

        # Snapshot every pre-existing path outside docs/audit/ together with
        # its mtime (None for directories) so both creation AND modification
        # can be detected after the run.
        before: dict[str, float | None] = {}
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if not rel.startswith("docs/audit"):
                before[rel] = p.stat().st_mtime if p.is_file() else None

        run_audit(repo)

        # Outside docs/audit/, no file may be new or modified.
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if rel.startswith("docs/audit"):
                continue
            if p.is_file():
                assert rel in before, f"意外创建了 docs/audit/ 外的文件: {rel}"
                # Fix: the original captured mtimes but never compared them,
                # so in-place modification of existing files went undetected.
                assert p.stat().st_mtime == before[rel], (
                    f"docs/audit/ 外的文件被修改: {rel}"
                )

    def test_auto_creates_docs_audit_dir(self, tmp_path: Path):
        """run_audit must create docs/audit/repo when it does not exist."""
        from scripts.audit.run_audit import run_audit

        repo = self._make_minimal_repo(tmp_path)
        # Ensure docs/audit/repo does not exist beforehand.
        audit_dir = repo / "docs" / "audit" / "repo"
        assert not audit_dir.exists()

        run_audit(repo)
        assert audit_dir.is_dir()
|
||||
428
apps/etl/pipelines/feiqiu/tests/unit/test_audit_scanner.py
Normal file
428
apps/etl/pipelines/feiqiu/tests/unit/test_audit_scanner.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单元测试 — 仓库扫描器 (scanner.py)
|
||||
|
||||
覆盖:
|
||||
- 排除模式匹配逻辑
|
||||
- 递归遍历与 FileEntry 构建
|
||||
- 空目录检测
|
||||
- 权限错误容错
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.audit import FileEntry
|
||||
from scripts.audit.scanner import EXCLUDED_PATTERNS, _is_excluded, scan_repo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_excluded 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsExcluded:
    """Tests for the exclusion-pattern matching logic."""

    def test_exact_match_git(self) -> None:
        assert _is_excluded(".git", EXCLUDED_PATTERNS) is True

    def test_exact_match_pycache(self) -> None:
        assert _is_excluded("__pycache__", EXCLUDED_PATTERNS) is True

    def test_exact_match_pytest_cache(self) -> None:
        assert _is_excluded(".pytest_cache", EXCLUDED_PATTERNS) is True

    def test_exact_match_kiro(self) -> None:
        assert _is_excluded(".kiro", EXCLUDED_PATTERNS) is True

    def test_wildcard_pyc(self) -> None:
        # Matched by the "*.pyc" wildcard pattern.
        assert _is_excluded("module.pyc", EXCLUDED_PATTERNS) is True

    def test_normal_py_not_excluded(self) -> None:
        assert _is_excluded("main.py", EXCLUDED_PATTERNS) is False

    def test_normal_dir_not_excluded(self) -> None:
        assert _is_excluded("src", EXCLUDED_PATTERNS) is False

    def test_empty_patterns(self) -> None:
        # With an empty pattern list, nothing is excluded.
        assert _is_excluded(".git", []) is False

    def test_custom_pattern(self) -> None:
        # Caller-supplied patterns are honoured.
        assert _is_excluded("data.csv", ["*.csv"]) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_repo 单元测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScanRepo:
    """Tests for scan_repo's recursive traversal."""

    def test_basic_structure(self, tmp_path: Path) -> None:
        """Ordinary files and directories are scanned correctly."""
        (tmp_path / "a.py").write_text("# code", encoding="utf-8")
        sub = tmp_path / "sub"
        sub.mkdir()
        (sub / "b.txt").write_text("hello", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "a.py" in paths
        assert "sub" in paths
        assert "sub/b.txt" in paths

    def test_file_entry_fields(self, tmp_path: Path) -> None:
        """All FileEntry fields are populated for a regular file."""
        (tmp_path / "hello.md").write_text("# hi", encoding="utf-8")

        entries = scan_repo(tmp_path)
        md = next(e for e in entries if e.rel_path == "hello.md")

        assert md.is_dir is False
        assert md.size_bytes > 0
        assert md.extension == ".md"
        assert md.is_empty_dir is False

    def test_directory_entry_fields(self, tmp_path: Path) -> None:
        """Directory entries have zero size, no extension, not empty."""
        sub = tmp_path / "mydir"
        sub.mkdir()
        (sub / "file.py").write_text("pass", encoding="utf-8")

        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "mydir")

        assert d.is_dir is True
        assert d.size_bytes == 0
        assert d.extension == ""
        assert d.is_empty_dir is False

    def test_excluded_git_dir(self, tmp_path: Path) -> None:
        """The .git directory and its contents are excluded."""
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert ".git" not in paths
        assert ".git/config" not in paths

    def test_excluded_pycache(self, tmp_path: Path) -> None:
        """__pycache__ directories are excluded."""
        cache = tmp_path / "pkg" / "__pycache__"
        cache.mkdir(parents=True)
        (cache / "mod.cpython-310.pyc").write_bytes(b"\x00")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert not any("__pycache__" in p for p in paths)

    def test_excluded_pyc_files(self, tmp_path: Path) -> None:
        """*.pyc files are excluded while the matching .py survives."""
        (tmp_path / "mod.pyc").write_bytes(b"\x00")
        (tmp_path / "mod.py").write_text("pass", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "mod.pyc" not in paths
        assert "mod.py" in paths

    def test_empty_directory_detection(self, tmp_path: Path) -> None:
        """An empty directory is flagged with is_empty_dir=True."""
        (tmp_path / "empty").mkdir()

        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "empty")

        assert d.is_dir is True
        assert d.is_empty_dir is True

    def test_dir_with_only_excluded_children(self, tmp_path: Path) -> None:
        """A directory whose only children are excluded counts as empty."""
        sub = tmp_path / "pkg"
        sub.mkdir()
        cache = sub / "__pycache__"
        cache.mkdir()
        (cache / "x.pyc").write_bytes(b"\x00")

        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "pkg")

        assert d.is_empty_dir is True

    def test_custom_exclude_patterns(self, tmp_path: Path) -> None:
        """Caller-supplied exclusion patterns are applied."""
        (tmp_path / "keep.py").write_text("pass", encoding="utf-8")
        (tmp_path / "skip.log").write_text("log", encoding="utf-8")

        entries = scan_repo(tmp_path, exclude=["*.log"])
        paths = {e.rel_path for e in entries}

        assert "keep.py" in paths
        assert "skip.log" not in paths

    def test_empty_repo(self, tmp_path: Path) -> None:
        """An empty repository yields an empty list."""
        entries = scan_repo(tmp_path)
        assert entries == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        """Results are sorted by rel_path."""
        (tmp_path / "z.py").write_text("", encoding="utf-8")
        (tmp_path / "a.py").write_text("", encoding="utf-8")
        sub = tmp_path / "m"
        sub.mkdir()
        (sub / "b.py").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = [e.rel_path for e in entries]

        assert paths == sorted(paths)

    @pytest.mark.skipif(
        os.name == "nt",
        reason="Windows 上 chmod 行为不同,跳过权限测试",
    )
    def test_permission_error_skipped(self, tmp_path: Path) -> None:
        """An unreadable directory is skipped without aborting the scan."""
        ok_file = tmp_path / "ok.py"
        ok_file.write_text("pass", encoding="utf-8")

        no_access = tmp_path / "secret"
        no_access.mkdir()
        (no_access / "data.txt").write_text("x", encoding="utf-8")
        no_access.chmod(0o000)

        try:
            entries = scan_repo(tmp_path)
            paths = {e.rel_path for e in entries}
            # ok.py must still be scanned.
            assert "ok.py" in paths
            # The "secret" directory itself may be recorded (the walk records
            # the directory before attempting iterdir), but its children
            # must not appear.
            assert "secret/data.txt" not in paths
        finally:
            # Restore permissions so the tmp dir can be cleaned up.
            no_access.chmod(0o755)

    def test_nested_directories(self, tmp_path: Path) -> None:
        """Deeply nested directories are traversed completely."""
        deep = tmp_path / "a" / "b" / "c"
        deep.mkdir(parents=True)
        (deep / "leaf.py").write_text("pass", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "a" in paths
        assert "a/b" in paths
        assert "a/b/c" in paths
        assert "a/b/c/leaf.py" in paths

    def test_extension_lowercase(self, tmp_path: Path) -> None:
        """Extensions are normalized to lower case."""
        (tmp_path / "README.MD").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        md = next(e for e in entries if "README" in e.rel_path)

        assert md.extension == ".md"

    def test_no_extension(self, tmp_path: Path) -> None:
        """Files without an extension get an empty string."""
        (tmp_path / "Makefile").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        f = next(e for e in entries if e.rel_path == "Makefile")

        assert f.extension == ""

    def test_root_not_in_entries(self, tmp_path: Path) -> None:
        """The root directory itself never appears in the results."""
        (tmp_path / "a.py").write_text("", encoding="utf-8")

        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}

        assert "." not in paths
        assert "" not in paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 7: 扫描器排除规则
|
||||
# Feature: repo-audit, Property 7: 扫描器排除规则
|
||||
# Validates: Requirements 1.1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# NOTE(review): these imports sit mid-file, grouped with the property tests
# they support; moving them to the top of the module would follow PEP 8.
import fnmatch
import string
import tempfile

from hypothesis import given, settings
from hypothesis import strategies as st


# --- Generator strategies ---

# Characters that are safe in file/directory names (no separators/specials).
_SAFE_CHARS = string.ascii_lowercase + string.digits + "_-"

# Strategy: ordinary names that do not collide with any exclusion pattern.
_safe_name = st.text(_SAFE_CHARS, min_size=1, max_size=8)

# Directory names that appear in the default exclusion patterns.
_EXCLUDED_DIR_NAMES = [".git", "__pycache__", ".pytest_cache", ".kiro"]

# File extension covered by the "*.pyc" exclusion pattern.
# NOTE(review): not referenced by the tests visible here — confirm usage
# elsewhere before removing.
_EXCLUDED_FILE_EXT = ".pyc"

# Strategy: pick one of the excluded directory names at random.
_excluded_dir_name = st.sampled_from(_EXCLUDED_DIR_NAMES)
|
||||
|
||||
|
||||
def _build_tree(tmp: Path, normal_names: list[str], excluded_dirs: list[str],
|
||||
include_pyc: bool) -> None:
|
||||
"""在临时目录中构建包含正常文件和被排除条目的文件树。"""
|
||||
# 创建正常文件
|
||||
for name in normal_names:
|
||||
safe = name or "f"
|
||||
filepath = tmp / f"{safe}.txt"
|
||||
if not filepath.exists():
|
||||
filepath.write_text("ok", encoding="utf-8")
|
||||
|
||||
# 创建被排除的目录(含子文件)
|
||||
for dirname in excluded_dirs:
|
||||
d = tmp / dirname
|
||||
d.mkdir(exist_ok=True)
|
||||
(d / "inner.txt").write_text("hidden", encoding="utf-8")
|
||||
|
||||
# 可选:创建 .pyc 文件
|
||||
if include_pyc:
|
||||
(tmp / "module.pyc").write_bytes(b"\x00")
|
||||
|
||||
|
||||
class TestProperty7ScannerExclusionRules:
    """
    Property 7: scanner exclusion rules.

    For any file tree, the FileEntry list returned by scan_repo must not
    contain any entry whose rel_path matches an exclusion pattern
    (.git, __pycache__, .pytest_cache, ...).

    Feature: repo-audit, Property 7: scanner exclusion rules
    Validates: Requirements 1.1
    """

    @given(
        normal_names=st.lists(_safe_name, min_size=0, max_size=5),
        excluded_dirs=st.lists(_excluded_dir_name, min_size=1, max_size=3),
        include_pyc=st.booleans(),
    )
    @settings(max_examples=100)
    def test_excluded_entries_never_in_results(
        self,
        normal_names: list[str],
        excluded_dirs: list[str],
        include_pyc: bool,
    ) -> None:
        """The scan result contains no entry matching any exclusion pattern."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            _build_tree(tmp, normal_names, excluded_dirs, include_pyc)

            entries = scan_repo(tmp)

            for entry in entries:
                # Check every path segment of rel_path against the patterns.
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in EXCLUDED_PATTERNS:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )

    @given(
        excluded_dir=_excluded_dir_name,
        depth=st.integers(min_value=1, max_value=3),
    )
    @settings(max_examples=100)
    def test_excluded_dirs_at_any_depth(
        self,
        excluded_dir: str,
        depth: int,
    ) -> None:
        """Excluded directories are dropped regardless of nesting depth."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)

            # Build nesting: level0/level1/.../excluded_dir/secret.txt
            current = tmp
            for i in range(depth):
                current = current / f"level{i}"
                current.mkdir(exist_ok=True)
                # Keep each parent non-empty with an ordinary file.
                (current / "keep.txt").write_text("ok", encoding="utf-8")

            # Place the excluded directory at the deepest level.
            excluded = current / excluded_dir
            excluded.mkdir(exist_ok=True)
            (excluded / "secret.txt").write_text("hidden", encoding="utf-8")

            entries = scan_repo(tmp)

            for entry in entries:
                parts = entry.rel_path.split("/")
                assert excluded_dir not in parts, (
                    f"被排除目录 '{excluded_dir}' 不应出现在结果中,"
                    f"但发现 rel_path='{entry.rel_path}'"
                )

    @given(
        custom_patterns=st.lists(
            st.sampled_from(["*.log", "*.tmp", "*.bak", "node_modules", ".venv"]),
            min_size=1,
            max_size=3,
        ),
    )
    @settings(max_examples=100)
    def test_custom_exclude_patterns_respected(
        self,
        custom_patterns: list[str],
    ) -> None:
        """Caller-supplied exclusion patterns are honoured by scan_repo."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)

            # One ordinary file that must survive the scan.
            (tmp / "main.py").write_text("pass", encoding="utf-8")

            # Create a matching file or directory for every custom pattern.
            for pat in custom_patterns:
                if pat.startswith("*."):
                    # Wildcard pattern -> create a matching file.
                    ext = pat[1:]  # e.g. ".log"
                    (tmp / f"data{ext}").write_text("x", encoding="utf-8")
                else:
                    # Exact-name pattern -> create a directory with a child.
                    d = tmp / pat
                    d.mkdir(exist_ok=True)
                    (d / "inner.txt").write_text("x", encoding="utf-8")

            entries = scan_repo(tmp, exclude=custom_patterns)

            for entry in entries:
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in custom_patterns:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"自定义排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )
|
||||
137
apps/etl/pipelines/feiqiu/tests/unit/test_cli_args.py
Normal file
137
apps/etl/pipelines/feiqiu/tests/unit/test_cli_args.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLI 参数解析单元测试
|
||||
|
||||
验证 --data-source 新参数、--pipeline-flow 弃用映射、
|
||||
--pipeline + --tasks 同时使用、以及 build_cli_overrides 集成行为。
|
||||
|
||||
需求: 3.1, 3.3, 3.5
|
||||
"""
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cli.main import parse_args, resolve_data_source, build_cli_overrides
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. --data-source 新参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDataSourceArg:
    """Parsing of the new ``--data-source`` CLI argument."""

    @pytest.mark.parametrize("value", ["online", "offline", "hybrid"])
    def test_data_source_valid_values(self, value):
        # Every documented choice must round-trip through parse_args().
        argv = ["cli", "--data-source", value]
        with patch("sys.argv", argv):
            parsed = parse_args()
        assert parsed.data_source == value

    def test_data_source_default_is_none(self):
        # Flag absent -> attribute defaults to None (resolved later).
        with patch("sys.argv", ["cli"]):
            parsed = parse_args()
        assert parsed.data_source is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. resolve_data_source() 弃用映射
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestResolveDataSource:
    """Deprecation mapping performed by resolve_data_source()."""

    def test_explicit_data_source_returns_directly(self):
        # An explicit --data-source is returned unchanged.
        ns = Namespace(data_source="online", pipeline_flow=None)
        assert resolve_data_source(ns) == "online"

    def test_data_source_takes_priority_over_pipeline_flow(self):
        """--data-source wins when both flags are supplied."""
        ns = Namespace(data_source="online", pipeline_flow="FULL")
        assert resolve_data_source(ns) == "online"

    @pytest.mark.parametrize(
        "flow, expected",
        [
            ("FULL", "hybrid"),
            ("FETCH_ONLY", "online"),
            ("INGEST_ONLY", "offline"),
        ],
    )
    def test_pipeline_flow_maps_with_deprecation_warning(self, flow, expected):
        """Legacy --pipeline-flow maps to a data_source and raises a deprecation warning."""
        ns = Namespace(data_source=None, pipeline_flow=flow)
        with pytest.warns(DeprecationWarning, match="--pipeline-flow 已弃用"):
            mapped = resolve_data_source(ns)
        assert mapped == expected

    def test_neither_arg_defaults_to_hybrid(self):
        """With neither flag given, the default is 'hybrid'."""
        ns = Namespace(data_source=None, pipeline_flow=None)
        assert resolve_data_source(ns) == "hybrid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. build_cli_overrides() 集成
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestBuildCliOverrides:
    """Integration tests for build_cli_overrides()."""

    def _make_args(self, **kwargs):
        """Build a minimal Namespace; unspecified parameters default to None/False."""
        # This defaults dict mirrors the full CLI surface; keep it in sync
        # with parse_args() when new flags are added.
        defaults = dict(
            store_id=None, tasks=None, dry_run=False,
            pipeline=None, processing_mode="increment_only",
            fetch_before_verify=False, verify_tables=None,
            window_split="none", lookback_hours=24, overlap_seconds=3600,
            pg_dsn=None, pg_host=None, pg_port=None, pg_name=None,
            pg_user=None, pg_password=None,
            api_base=None, api_token=None, api_timeout=None,
            api_page_size=None, api_retry_max=None,
            window_start=None, window_end=None,
            force_window_override=False,
            window_split_unit=None, window_split_days=None,
            window_compensation_hours=None,
            export_root=None, log_root=None,
            data_source=None, pipeline_flow=None,
            fetch_root=None, ingest_source=None, write_pretty_json=False,
            idle_start=None, idle_end=None, allow_empty_advance=False,
        )
        defaults.update(kwargs)
        return Namespace(**defaults)

    def test_data_source_online_sets_run_key(self):
        # --data-source online should land under overrides["run"].
        args = self._make_args(data_source="online")
        overrides = build_cli_overrides(args)
        assert overrides["run"]["data_source"] == "online"

    def test_pipeline_flow_sets_both_keys(self):
        """The legacy flag writes both pipeline.flow and run.data_source."""
        args = self._make_args(pipeline_flow="FULL")
        with warnings.catch_warnings():
            # The deprecated flag emits a DeprecationWarning; silence it here
            # so the test asserts only on the mapping behaviour.
            warnings.simplefilter("ignore", DeprecationWarning)
            overrides = build_cli_overrides(args)
        assert overrides["pipeline"]["flow"] == "FULL"
        assert overrides["run"]["data_source"] == "hybrid"

    def test_default_data_source_is_hybrid(self):
        """Without --data-source or --pipeline-flow, run.data_source defaults to 'hybrid'."""
        args = self._make_args()
        overrides = build_cli_overrides(args)
        assert overrides["run"]["data_source"] == "hybrid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. --pipeline + --tasks 同时使用
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPipelineAndTasks:
    """Behaviour when --pipeline and --tasks are supplied together."""

    def test_pipeline_and_tasks_both_parsed(self):
        # The two flags are independent; both must survive parsing.
        argv = [
            "cli",
            "--pipeline", "api_full",
            "--tasks", "ODS_MEMBER,ODS_ORDER",
        ]
        with patch("sys.argv", argv):
            parsed = parse_args()
        assert parsed.pipeline == "api_full"
        assert parsed.tasks == "ODS_MEMBER,ODS_ORDER"
|
||||
426
apps/etl/pipelines/feiqiu/tests/unit/test_compare_ddl.py
Normal file
426
apps/etl/pipelines/feiqiu/tests/unit/test_compare_ddl.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""DDL 解析器和对比逻辑的单元测试。
|
||||
|
||||
测试范围:
|
||||
- DDL 解析器正确提取表名、字段、类型、约束
|
||||
- 类型标准化逻辑
|
||||
- 差异检测逻辑识别各类差异
|
||||
- 边界情况:空文件、COMMENT 含特殊字符
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.compare_ddl_db import (
|
||||
ColumnDef,
|
||||
DiffKind,
|
||||
SchemaDiff,
|
||||
TableDef,
|
||||
compare_tables,
|
||||
normalize_type,
|
||||
parse_ddl,
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# normalize_type 测试
|
||||
# =========================================================================
|
||||
|
||||
class TestNormalizeType:
    """Type-normalization tests for normalize_type()."""

    # Simple scalar types and their canonical lowercase spellings,
    # including PostgreSQL aliases (INT8 -> bigint, BOOL -> boolean, ...).
    @pytest.mark.parametrize("raw,expected", [
        ("BIGINT", "bigint"),
        ("INT8", "bigint"),
        ("INTEGER", "integer"),
        ("INT", "integer"),
        ("INT4", "integer"),
        ("SMALLINT", "smallint"),
        ("INT2", "smallint"),
        ("BOOLEAN", "boolean"),
        ("BOOL", "boolean"),
        ("TEXT", "text"),
        ("JSONB", "jsonb"),
        ("JSON", "json"),
        ("DATE", "date"),
        ("BYTEA", "bytea"),
        ("UUID", "uuid"),
    ])
    def test_simple_types(self, raw, expected):
        assert normalize_type(raw) == expected

    # Precision/scale are kept; DECIMAL is folded into numeric.
    @pytest.mark.parametrize("raw,expected", [
        ("NUMERIC(18,2)", "numeric(18,2)"),
        ("NUMERIC(10,6)", "numeric(10,6)"),
        ("DECIMAL(5,2)", "numeric(5,2)"),
        ("NUMERIC(10)", "numeric(10)"),
        ("NUMERIC", "numeric"),
    ])
    def test_numeric_types(self, raw, expected):
        assert normalize_type(raw) == expected

    # CHARACTER VARYING collapses to varchar; length is preserved.
    @pytest.mark.parametrize("raw,expected", [
        ("VARCHAR(50)", "varchar(50)"),
        ("CHARACTER VARYING(100)", "varchar(100)"),
        ("VARCHAR", "varchar"),
        ("CHAR(1)", "char(1)"),
        ("CHARACTER(10)", "char(10)"),
    ])
    def test_string_types(self, raw, expected):
        assert normalize_type(raw) == expected

    # Verbose timestamp spellings map to the short canonical forms.
    @pytest.mark.parametrize("raw,expected", [
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP WITHOUT TIME ZONE", "timestamp"),
        ("TIMESTAMPTZ", "timestamptz"),
        ("TIMESTAMP WITH TIME ZONE", "timestamptz"),
    ])
    def test_timestamp_types(self, raw, expected):
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("BIGSERIAL", "bigint"),
        ("SERIAL", "integer"),
        ("SMALLSERIAL", "smallint"),
    ])
    def test_serial_types(self, raw, expected):
        """The serial family should map to its underlying integer type."""
        assert normalize_type(raw) == expected

    def test_case_insensitive(self):
        # Normalization must not depend on the input's letter case.
        assert normalize_type("bigint") == normalize_type("BIGINT")
        assert normalize_type("Numeric(18,2)") == normalize_type("NUMERIC(18,2)")
|
||||
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# parse_ddl 测试
|
||||
# =========================================================================
|
||||
|
||||
class TestParseDdl:
    """DDL parser tests for parse_ddl()."""

    def test_basic_create_table(self):
        """Basic CREATE TABLE parsing: names, types, NOT NULL, table-level PK."""
        sql = """
        CREATE TABLE IF NOT EXISTS myschema.users (
            id BIGINT NOT NULL,
            name TEXT,
            age INTEGER,
            PRIMARY KEY (id)
        );
        """
        tables = parse_ddl(sql, target_schema="myschema")
        assert "users" in tables
        t = tables["users"]
        # The PRIMARY KEY line is a constraint, not a column.
        assert len(t.columns) == 3
        assert t.pk_columns == ["id"]
        assert t.columns["id"].data_type == "bigint"
        assert t.columns["id"].nullable is False
        assert t.columns["name"].data_type == "text"
        # Columns without NOT NULL default to nullable.
        assert t.columns["name"].nullable is True
        assert t.columns["age"].data_type == "integer"

    def test_inline_primary_key(self):
        """Inline PRIMARY KEY constraint on a column definition."""
        sql = """
        CREATE TABLE test_schema.items (
            item_id BIGSERIAL PRIMARY KEY,
            label TEXT NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="test_schema")
        t = tables["items"]
        assert t.columns["item_id"].is_pk is True
        # BIGSERIAL -> bigint
        assert t.columns["item_id"].data_type == "bigint"
        # Inline PK implies NOT NULL.
        assert t.columns["item_id"].nullable is False
        assert t.columns["label"].nullable is False

    def test_composite_primary_key(self):
        """Composite primary key declared at table level."""
        sql = """
        CREATE TABLE IF NOT EXISTS billiards_ods.member_profiles (
            id BIGINT,
            content_hash TEXT NOT NULL,
            name TEXT,
            PRIMARY KEY (id, content_hash)
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        t = tables["member_profiles"]
        assert t.pk_columns == ["id", "content_hash"]
        assert t.columns["id"].is_pk is True
        assert t.columns["content_hash"].is_pk is True
        # PK membership implies NOT NULL even without an explicit constraint.
        assert t.columns["id"].nullable is False

    def test_various_data_types(self):
        """A spread of PostgreSQL data types, with DEFAULT clauses present."""
        sql = """
        CREATE TABLE s.t (
            a BIGINT,
            b VARCHAR(50),
            c NUMERIC(18,2),
            d TIMESTAMP,
            e TIMESTAMPTZ DEFAULT now(),
            f BOOLEAN DEFAULT TRUE,
            g JSONB NOT NULL,
            h TEXT,
            i INTEGER
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["a"].data_type == "bigint"
        assert t.columns["b"].data_type == "varchar(50)"
        assert t.columns["c"].data_type == "numeric(18,2)"
        assert t.columns["d"].data_type == "timestamp"
        assert t.columns["e"].data_type == "timestamptz"
        assert t.columns["f"].data_type == "boolean"
        assert t.columns["g"].data_type == "jsonb"
        assert t.columns["g"].nullable is False
        assert t.columns["h"].data_type == "text"
        assert t.columns["i"].data_type == "integer"

    def test_without_schema_prefix(self):
        """CREATE TABLE without a schema prefix (e.g. DWD DDL after SET search_path)."""
        sql = """
        SET search_path TO billiards_dwd;
        CREATE TABLE IF NOT EXISTS dim_site (
            site_id BIGINT,
            shop_name TEXT,
            PRIMARY KEY (site_id)
        );
        """
        # When target_schema is given, unprefixed tables should also be accepted.
        tables = parse_ddl(sql, target_schema="billiards_dwd")
        assert "dim_site" in tables

    def test_schema_filter(self):
        """Schema filtering: only tables in the target schema are kept."""
        sql = """
        CREATE TABLE schema_a.t1 (id BIGINT);
        CREATE TABLE schema_b.t2 (id BIGINT);
        """
        tables = parse_ddl(sql, target_schema="schema_a")
        assert "t1" in tables
        assert "t2" not in tables

    def test_empty_ddl(self):
        """An empty DDL file should yield an empty dict."""
        tables = parse_ddl("", target_schema="any")
        assert tables == {}

    def test_comments_ignored(self):
        """SQL comments (line, block, inline) must not affect parsing."""
        sql = """
        -- 这是注释
        /* 块注释 */
        CREATE TABLE s.t (
            id BIGINT, -- 行内注释
            name TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        assert "t" in tables
        assert len(tables["t"].columns) == 2

    def test_comment_on_statements_ignored(self):
        """COMMENT ON statements — including ones with quotes/parens — must not affect table parsing."""
        sql = """
        CREATE TABLE billiards_ods.test_table (
            id BIGINT NOT NULL,
            name TEXT,
            PRIMARY KEY (id)
        );
        COMMENT ON TABLE billiards_ods.test_table IS '测试表:含特殊字符 ''引号'' 和 (括号)';
        COMMENT ON COLUMN billiards_ods.test_table.id IS '【说明】主键 ID。【示例】12345。';
        COMMENT ON COLUMN billiards_ods.test_table.name IS '【说明】名称,含 ''单引号'' 和 "双引号"。';
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        assert "test_table" in tables
        assert len(tables["test_table"].columns) == 2

    def test_drop_then_create(self):
        """A DROP TABLE preceding CREATE TABLE should parse normally."""
        sql = """
        DROP TABLE IF EXISTS billiards_dws.cfg_test CASCADE;
        CREATE TABLE billiards_dws.cfg_test (
            id SERIAL PRIMARY KEY,
            value TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_dws")
        assert "cfg_test" in tables
        # SERIAL -> integer
        assert tables["cfg_test"].columns["id"].data_type == "integer"

    def test_default_values_parsed(self):
        """DEFAULT clauses must not disturb type or constraint extraction."""
        sql = """
        CREATE TABLE s.t (
            enabled BOOLEAN DEFAULT TRUE,
            created_at TIMESTAMPTZ DEFAULT now(),
            count INTEGER DEFAULT 0 NOT NULL,
            label VARCHAR(20) NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["enabled"].data_type == "boolean"
        # DEFAULT alone does not imply NOT NULL.
        assert t.columns["enabled"].nullable is True
        assert t.columns["created_at"].data_type == "timestamptz"
        assert t.columns["count"].data_type == "integer"
        assert t.columns["count"].nullable is False
        assert t.columns["label"].data_type == "varchar(20)"
        assert t.columns["label"].nullable is False

    def test_constraint_lines_skipped(self):
        """Table-level constraint lines (CONSTRAINT, UNIQUE, FOREIGN KEY) should be skipped."""
        sql = """
        CREATE TABLE etl_admin.etl_task (
            task_id BIGSERIAL PRIMARY KEY,
            task_code TEXT NOT NULL,
            store_id BIGINT NOT NULL,
            UNIQUE (task_code, store_id)
        );
        """
        tables = parse_ddl(sql, target_schema="etl_admin")
        t = tables["etl_task"]
        # The UNIQUE line must not become a column.
        assert len(t.columns) == 3
        assert "task_id" in t.columns
        assert "task_code" in t.columns
        assert "store_id" in t.columns

    def test_real_ods_ddl_parseable(self):
        """Smoke test: the real ODS DDL file should be parseable when present."""
        from pathlib import Path
        ddl_path = Path("database/schema_ODS_doc.sql")
        if not ddl_path.exists():
            # Path is relative to the working directory; skip when absent.
            pytest.skip("DDL 文件不存在")
        sql = ddl_path.read_text(encoding="utf-8")
        tables = parse_ddl(sql, target_schema="billiards_ods")
        # The schema is expected to hold at least 20 tables.
        assert len(tables) >= 20
        # Every table must have at least one column.
        for tbl in tables.values():
            assert len(tbl.columns) > 0
|
||||
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# compare_tables 测试
|
||||
# =========================================================================
|
||||
|
||||
class TestCompareTables:
    """Tests for the schema diff-detection logic."""

    def _make_table(self, name: str, columns: dict[str, tuple[str, bool]],
                    pk: list[str] | None = None) -> TableDef:
        """Helper: quickly build a TableDef.

        columns: {col_name: (data_type, nullable)}
        """
        pk_list = pk or []
        col_defs = {
            col_name: ColumnDef(
                name=col_name,
                data_type=dtype,
                nullable=nullable,
                is_pk=col_name in pk_list,
            )
            for col_name, (dtype, nullable) in columns.items()
        }
        return TableDef(name=name, columns=col_defs, pk_columns=pk_list)

    def test_no_diff(self):
        """Identical definitions yield an empty diff list."""
        tbl = self._make_table("t", {"id": ("bigint", False), "name": ("text", True)}, pk=["id"])
        assert compare_tables({"t": tbl}, {"t": tbl}) == []

    def test_missing_table(self):
        """Present in DB but absent from DDL -> MISSING_TABLE."""
        db_side = {"extra": self._make_table("extra", {"id": ("bigint", False)})}
        diffs = compare_tables({}, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_TABLE
        assert diffs[0].table == "extra"

    def test_extra_table(self):
        """Present in DDL but absent from DB -> EXTRA_TABLE."""
        ddl_side = {"orphan": self._make_table("orphan", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl_side, {})
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_TABLE
        assert diffs[0].table == "orphan"

    def test_missing_column(self):
        """Column present in DB but absent from DDL -> MISSING_COLUMN."""
        ddl_side = {"t": self._make_table("t", {"id": ("bigint", False)})}
        db_side = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "new_col": ("text", True),
        })}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_COLUMN
        assert diffs[0].column == "new_col"

    def test_extra_column(self):
        """Column present in DDL but absent from DB -> EXTRA_COLUMN."""
        ddl_side = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "old_col": ("text", True),
        })}
        db_side = {"t": self._make_table("t", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_COLUMN
        assert diffs[0].column == "old_col"

    def test_type_mismatch(self):
        """Differing column types -> TYPE_MISMATCH with both values reported."""
        ddl_side = {"t": self._make_table("t", {"val": ("text", True)})}
        db_side = {"t": self._make_table("t", {"val": ("varchar(100)", True)})}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.TYPE_MISMATCH
        assert diffs[0].ddl_value == "text"
        assert diffs[0].db_value == "varchar(100)"

    def test_nullable_mismatch(self):
        """Differing nullability -> NULLABLE_MISMATCH."""
        ddl_side = {"t": self._make_table("t", {"val": ("text", True)})}
        db_side = {"t": self._make_table("t", {"val": ("text", False)})}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.NULLABLE_MISMATCH
        assert diffs[0].ddl_value == "NULL"
        assert diffs[0].db_value == "NOT NULL"

    def test_multiple_diffs(self):
        """Several kinds of differences can coexist in one comparison."""
        ddl_side = {
            "t1": self._make_table("t1", {
                "id": ("bigint", False),
                "extra": ("text", True),
            }),
            "ddl_only": self._make_table("ddl_only", {"x": ("integer", True)}),
        }
        db_side = {
            "t1": self._make_table("t1", {
                "id": ("integer", False),   # TYPE_MISMATCH
                "missing": ("text", True),  # MISSING_COLUMN
            }),
            "db_only": self._make_table("db_only", {"y": ("text", True)}),
        }
        found = {d.kind for d in compare_tables(ddl_side, db_side)}
        assert DiffKind.MISSING_TABLE in found   # db_only
        assert DiffKind.EXTRA_TABLE in found     # ddl_only
        assert DiffKind.MISSING_COLUMN in found  # t1.missing
        assert DiffKind.EXTRA_COLUMN in found    # t1.extra
        assert DiffKind.TYPE_MISMATCH in found   # t1.id

    def test_empty_both(self):
        """Both sides empty -> empty diff list."""
        assert compare_tables({}, {}) == []
|
||||
545
apps/etl/pipelines/feiqiu/tests/unit/test_compare_ddl_pbt.py
Normal file
545
apps/etl/pipelines/feiqiu/tests/unit/test_compare_ddl_pbt.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""DDL 对比脚本差异检测完整性 — 属性测试。
|
||||
|
||||
# AI_CHANGELOG [2026-02-13] max_examples 从 100 降至 30 以加速测试运行
|
||||
|
||||
**Property 2: DDL 对比脚本差异检测完整性**
|
||||
**Validates: Requirements 2.1, 2.2, 2.3, 2.4**
|
||||
|
||||
使用 hypothesis 生成随机的 DDL 表定义和数据库表定义,
|
||||
注入已知差异,验证 compare_tables 能检测到所有差异。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given, settings, assume
|
||||
|
||||
from scripts.compare_ddl_db import (
|
||||
ColumnDef,
|
||||
DiffKind,
|
||||
TableDef,
|
||||
compare_tables,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 自定义 Strategy:生成随机的 ColumnDef / TableDef
|
||||
# =========================================================================
|
||||
|
||||
# Pool of normalized type names (kept in line with normalize_type's output).
_NORMALIZED_TYPES = [
    "bigint", "integer", "smallint", "boolean", "text",
    "jsonb", "json", "date", "bytea", "uuid",
    "timestamp", "timestamptz", "double precision",
    "varchar(50)", "varchar(100)", "varchar(255)",
    "char(1)", "char(10)",
    "numeric(18,2)", "numeric(10,6)", "numeric(5,0)",
]

# Legal identifier character sets (lowercase letters + underscore + digits,
# first character a letter).
# NOTE(review): st_identifier() below generates identifiers from a regex and
# does not reference these two constants; they appear unused in the visible
# portion of this file — confirm before removing.
_IDENT_ALPHABET = string.ascii_lowercase + "_"
_IDENT_FULL = _IDENT_ALPHABET + string.digits
|
||||
|
||||
|
||||
def st_identifier() -> st.SearchStrategy[str]:
    """Strategy yielding legal PostgreSQL identifiers (lowercase, 3-20 chars)."""
    # First char is a letter; 2-19 more chars gives 3-20 total.
    pattern = r"[a-z][a-z0-9_]{2,19}"
    return st.from_regex(pattern, fullmatch=True)
|
||||
|
||||
|
||||
def st_data_type() -> st.SearchStrategy[str]:
    """Strategy drawing one entry from the normalized type pool."""
    return st.sampled_from(_NORMALIZED_TYPES)
|
||||
|
||||
|
||||
def st_column_def(name: str | None = None) -> st.SearchStrategy[ColumnDef]:
    """Strategy for a random ColumnDef, optionally with a fixed name."""
    name_strategy = st.just(name) if name else st_identifier()
    return st.builds(
        ColumnDef,
        name=name_strategy,
        data_type=st_data_type(),
        nullable=st.booleans(),
        # Primary-key status is decided at the TableDef level, not here.
        is_pk=st.just(False),
        default=st.just(None),
    )
|
||||
|
||||
|
||||
@st.composite
def st_table_def(draw, name: str | None = None) -> TableDef:
    """Generate a random TableDef (1-8 columns, optional primary key)."""
    tbl_name = name or draw(st_identifier())
    num_cols = draw(st.integers(min_value=1, max_value=8))

    # Draw a set of unique column names.
    col_names = draw(
        st.lists(st_identifier(), min_size=num_cols, max_size=num_cols, unique=True)
    )

    columns: dict[str, ColumnDef] = {}
    for cn in col_names:
        col = draw(st_column_def(name=cn))
        columns[cn] = col

    # Randomly pick 0-2 columns to serve as the primary key.
    pk_count = draw(st.integers(min_value=0, max_value=min(2, len(col_names))))
    pk_cols = draw(
        st.lists(
            st.sampled_from(col_names),
            min_size=pk_count, max_size=pk_count, unique=True,
        )
    ) if pk_count > 0 and col_names else []

    # A PK column is implicitly NOT NULL; the drawn ColumnDefs are
    # mutated in place (ColumnDef must therefore be mutable).
    for pk_col in pk_cols:
        columns[pk_col].is_pk = True
        columns[pk_col].nullable = False

    return TableDef(name=tbl_name, columns=columns, pk_columns=pk_cols)
|
||||
|
||||
|
||||
@st.composite
def st_table_dict(draw, min_tables: int = 0, max_tables: int = 5) -> dict[str, TableDef]:
    """Generate a {table_name: TableDef} mapping with unique table names."""
    count = draw(st.integers(min_value=min_tables, max_value=max_tables))
    table_names = draw(
        st.lists(st_identifier(), min_size=count, max_size=count, unique=True)
    )
    return {n: draw(st_table_def(name=n)) for n in table_names}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Property 2: DDL 对比脚本差异检测完整性
|
||||
# **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestProperty2DiffDetection:
    """Property 2: the comparison script detects every injected difference.

    **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
    """

    @given(base=st_table_dict(min_tables=1, max_tables=4))
    @settings(max_examples=30)
    def test_identical_tables_produce_no_diffs(self, base: dict[str, TableDef]):
        """Identical definitions must produce no differences at all.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        diffs = compare_tables(base, base)
        assert diffs == [], f"相同定义不应有差异,但得到: {diffs}"

    @given(
        common=st_table_dict(min_tables=0, max_tables=3),
        db_only=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_missing_table_detected(
        self,
        common: dict[str, TableDef],
        db_only: dict[str, TableDef],
    ):
        """Tables present in the DB but absent from DDL must be reported as MISSING_TABLE.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        # Ensure db_only table names do not overlap with common ones.
        overlap = set(common.keys()) & set(db_only.keys())
        assume(not overlap)

        ddl_tables = dict(common)
        db_tables = {**common, **db_only}

        diffs = compare_tables(ddl_tables, db_tables)
        missing_tables = {d.table for d in diffs if d.kind == DiffKind.MISSING_TABLE}

        for tbl in db_only:
            assert tbl in missing_tables, (
                f"表 '{tbl}' 仅存在于数据库中,应报告为 MISSING_TABLE"
            )

    @given(
        common=st_table_dict(min_tables=0, max_tables=3),
        ddl_only=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_extra_table_detected(
        self,
        common: dict[str, TableDef],
        ddl_only: dict[str, TableDef],
    ):
        """Tables present in DDL but absent from the DB must be reported as EXTRA_TABLE.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        overlap = set(common.keys()) & set(ddl_only.keys())
        assume(not overlap)

        ddl_tables = {**common, **ddl_only}
        db_tables = dict(common)

        diffs = compare_tables(ddl_tables, db_tables)
        extra_tables = {d.table for d in diffs if d.kind == DiffKind.EXTRA_TABLE}

        for tbl in ddl_only:
            assert tbl in extra_tables, (
                f"表 '{tbl}' 仅存在于 DDL 中,应报告为 EXTRA_TABLE"
            )

    @given(
        table=st_table_def(),
        extra_cols=st.lists(
            st.tuples(st_identifier(), st_data_type(), st.booleans()),
            min_size=1, max_size=4, unique_by=lambda x: x[0],
        ),
    )
    @settings(max_examples=30)
    def test_missing_column_detected(
        self,
        table: TableDef,
        extra_cols: list[tuple[str, str, bool]],
    ):
        """Columns present in the DB but absent from DDL must be reported as MISSING_COLUMN.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        # Injected column names must not clash with existing ones.
        existing_names = set(table.columns.keys())
        injected_names = {c[0] for c in extra_cols}
        assume(not (existing_names & injected_names))

        # DDL side: the original table.
        ddl_tables = {table.name: table}

        # DB side: original table + injected extra columns.
        # (dict(table.columns) is a shallow copy — ColumnDefs are shared.)
        db_table = TableDef(
            name=table.name,
            columns=dict(table.columns),
            pk_columns=list(table.pk_columns),
        )
        for col_name, col_type, nullable in extra_cols:
            db_table.columns[col_name] = ColumnDef(
                name=col_name, data_type=col_type, nullable=nullable,
            )
        db_tables = {table.name: db_table}

        diffs = compare_tables(ddl_tables, db_tables)
        missing_cols = {
            d.column for d in diffs
            if d.kind == DiffKind.MISSING_COLUMN and d.table == table.name
        }

        for col_name, _, _ in extra_cols:
            assert col_name in missing_cols, (
                f"字段 '{table.name}.{col_name}' 仅存在于数据库中,"
                f"应报告为 MISSING_COLUMN"
            )

    @given(
        table=st_table_def(),
        extra_cols=st.lists(
            st.tuples(st_identifier(), st_data_type(), st.booleans()),
            min_size=1, max_size=4, unique_by=lambda x: x[0],
        ),
    )
    @settings(max_examples=30)
    def test_extra_column_detected(
        self,
        table: TableDef,
        extra_cols: list[tuple[str, str, bool]],
    ):
        """Columns present in DDL but absent from the DB must be reported as EXTRA_COLUMN.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        existing_names = set(table.columns.keys())
        injected_names = {c[0] for c in extra_cols}
        assume(not (existing_names & injected_names))

        # DDL side: original table + injected extra columns.
        ddl_table = TableDef(
            name=table.name,
            columns=dict(table.columns),
            pk_columns=list(table.pk_columns),
        )
        for col_name, col_type, nullable in extra_cols:
            ddl_table.columns[col_name] = ColumnDef(
                name=col_name, data_type=col_type, nullable=nullable,
            )
        ddl_tables = {table.name: ddl_table}

        # DB side: the original table.
        db_tables = {table.name: table}

        diffs = compare_tables(ddl_tables, db_tables)
        extra_col_set = {
            d.column for d in diffs
            if d.kind == DiffKind.EXTRA_COLUMN and d.table == table.name
        }

        for col_name, _, _ in extra_cols:
            assert col_name in extra_col_set, (
                f"字段 '{table.name}.{col_name}' 仅存在于 DDL 中,"
                f"应报告为 EXTRA_COLUMN"
            )

    @given(table=st_table_def())
    @settings(max_examples=30)
    def test_type_mismatch_detected(self, table: TableDef):
        """A differing column type must be reported as TYPE_MISMATCH.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        assume(len(table.columns) >= 1)

        # Pick the first column and change its type.
        target_col_name = next(iter(table.columns))
        original_type = table.columns[target_col_name].data_type

        # Choose some different type from the pool.
        alt_types = [t for t in _NORMALIZED_TYPES if t != original_type]
        assume(len(alt_types) > 0)
        new_type = alt_types[0]

        # DDL side: the original definition.
        ddl_tables = {table.name: table}

        # DB side: same table with the target column's type replaced.
        db_table = TableDef(
            name=table.name,
            columns={
                cn: ColumnDef(
                    name=c.name,
                    data_type=new_type if cn == target_col_name else c.data_type,
                    nullable=c.nullable,
                    is_pk=c.is_pk,
                    default=c.default,
                )
                for cn, c in table.columns.items()
            },
            pk_columns=list(table.pk_columns),
        )
        db_tables = {table.name: db_table}

        diffs = compare_tables(ddl_tables, db_tables)
        type_mismatches = [
            d for d in diffs
            if d.kind == DiffKind.TYPE_MISMATCH
            and d.table == table.name
            and d.column == target_col_name
        ]

        assert len(type_mismatches) == 1, (
            f"字段 '{table.name}.{target_col_name}' 类型从 "
            f"'{original_type}' 变为 '{new_type}',应报告 TYPE_MISMATCH"
        )
        assert type_mismatches[0].ddl_value == original_type
        assert type_mismatches[0].db_value == new_type

    @given(
        ddl_tables=st_table_dict(min_tables=0, max_tables=4),
        db_tables=st_table_dict(min_tables=0, max_tables=4),
    )
    @settings(max_examples=30)
    def test_all_diff_kinds_are_accounted_for(
        self,
        ddl_tables: dict[str, TableDef],
        db_tables: dict[str, TableDef],
    ):
        """Every diff produced for arbitrary input must use a known DiffKind.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**

        This is a "meta-property": it verifies the comparison never emits an
        undefined diff kind, and that each diff carries sensible table/column
        references.
        """
        diffs = compare_tables(ddl_tables, db_tables)

        valid_kinds = set(DiffKind)
        all_table_names = set(ddl_tables.keys()) | set(db_tables.keys())

        for d in diffs:
            # The diff kind must be a known enum member.
            assert d.kind in valid_kinds, f"未知差异类型: {d.kind}"
            # The referenced table must exist on at least one side.
            assert d.table in all_table_names, (
                f"差异引用了不存在的表: {d.table}"
            )
            # Table-level diffs must not carry a column reference.
            if d.kind in (DiffKind.MISSING_TABLE, DiffKind.EXTRA_TABLE):
                assert d.column is None, (
                    f"表级差异 {d.kind} 不应包含 column 字段"
                )
            # Column-level diffs must carry a column reference.
            if d.kind in (
                DiffKind.MISSING_COLUMN, DiffKind.EXTRA_COLUMN,
                DiffKind.TYPE_MISMATCH, DiffKind.NULLABLE_MISMATCH,
            ):
                assert d.column is not None, (
                    f"字段级差异 {d.kind} 必须包含 column 字段"
                )
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Property 3: DDL 修正后零差异(不动点)
|
||||
# **Validates: Requirements 2.5**
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestProperty3FixpointZeroDiff:
    """Property 3: zero diffs after the DDL has been corrected (fixpoint).

    Core idea: once the DDL files are rewritten from the database's actual
    state, re-running the comparison must yield an empty diff list, i.e.
    compare_tables(db_state, db_state) == [].

    **Validates: Requirements 2.5**
    """

    @given(db_state=st_table_dict(min_tables=0, max_tables=5))
    @settings(max_examples=30)
    def test_identical_state_yields_no_diff(self, db_state: dict[str, TableDef]):
        """Feeding the same definitions as both DDL and DB must yield no diff.

        Simulates the state right after a fix: DDL and database fully agree.

        **Validates: Requirements 2.5**
        """
        diffs = compare_tables(db_state, db_state)
        assert diffs == [], (
            f"不动点违反:相同定义应产生零差异,但得到 {len(diffs)} 项: "
            f"{[str(d) for d in diffs]}"
        )

    @given(db_state=st_table_dict(min_tables=1, max_tables=5))
    @settings(max_examples=30)
    def test_copy_db_as_ddl_yields_no_diff(self, db_state: dict[str, TableDef]):
        """A DDL deep-copied from the DB definitions must diff to zero.

        This verifies the fixpoint property of "regenerate the DDL from the
        database".

        **Validates: Requirements 2.5**
        """
        # Simulate the fix: rebuild every table/column from the DB side.
        ddl_fixed: dict[str, TableDef] = {
            tbl_name: TableDef(
                name=tbl.name,
                columns={
                    col_name: ColumnDef(
                        name=col.name,
                        data_type=col.data_type,
                        nullable=col.nullable,
                        is_pk=col.is_pk,
                        default=col.default,
                    )
                    for col_name, col in tbl.columns.items()
                },
                pk_columns=list(tbl.pk_columns),
            )
            for tbl_name, tbl in db_state.items()
        }

        diffs = compare_tables(ddl_fixed, db_state)
        assert diffs == [], (
            f"不动点违反:从 DB 复制的 DDL 定义应与 DB 零差异,"
            f"但得到 {len(diffs)} 项: {[str(d) for d in diffs]}"
        )

    @given(
        db_state=st_table_dict(min_tables=1, max_tables=4),
        extra_tables=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_fixpoint_after_adding_missing_tables(
        self,
        db_state: dict[str, TableDef],
        extra_tables: dict[str, TableDef],
    ):
        """Fix flow: DDL lacks some tables → add them → diff must be zero.

        Steps:
        1. DB holds common + extra tables, DDL only the common ones
        2. MISSING_TABLE diffs are reported
        3. the missing tables are merged into the DDL (simulated fix)
        4. the second comparison yields no diff

        **Validates: Requirements 2.5**
        """
        # Table names must not collide between the two generated dicts.
        assume(not (set(db_state.keys()) & set(extra_tables.keys())))

        full_db: dict[str, TableDef] = dict(db_state)
        full_db.update(extra_tables)
        ddl_before_fix = dict(db_state)  # extra_tables are absent here

        # First pass: the missing tables must be flagged.
        diffs_before = compare_tables(ddl_before_fix, full_db)
        missing = {d.table for d in diffs_before if d.kind == DiffKind.MISSING_TABLE}
        assert missing == set(extra_tables.keys()), (
            f"修正前应检测到缺失表 {set(extra_tables.keys())},实际 {missing}"
        )

        # Simulated fix: merge the missing tables into the DDL.
        ddl_after_fix = dict(ddl_before_fix)
        ddl_after_fix.update(extra_tables)

        # Second pass: fixpoint — no diffs left.
        diffs_after = compare_tables(ddl_after_fix, full_db)
        assert diffs_after == [], (
            f"不动点违反:修正后应零差异,但得到 {len(diffs_after)} 项: "
            f"{[str(d) for d in diffs_after]}"
        )

    @given(table=st_table_def())
    @settings(max_examples=30)
    def test_fixpoint_after_correcting_type_mismatch(self, table: TableDef):
        """Fix flow: column type mismatch → adopt the DB type → zero diffs.

        Steps:
        1. one DDL column type differs from the DB
        2. TYPE_MISMATCH is reported
        3. the DDL column type is replaced by the DB type (simulated fix)
        4. the second comparison yields no diff

        **Validates: Requirements 2.5**
        """
        assume(len(table.columns) >= 1)

        target_col = next(iter(table.columns))
        original_type = table.columns[target_col].data_type
        alternatives = [t for t in _NORMALIZED_TYPES if t != original_type]
        assume(len(alternatives) > 0)
        wrong_type = alternatives[0]

        # DDL side: the target column carries the wrong type.
        ddl_table = TableDef(
            name=table.name,
            columns={
                cn: ColumnDef(
                    name=c.name,
                    data_type=wrong_type if cn == target_col else c.data_type,
                    nullable=c.nullable,
                    is_pk=c.is_pk,
                    default=c.default,
                )
                for cn, c in table.columns.items()
            },
            pk_columns=list(table.pk_columns),
        )
        ddl_before = {table.name: ddl_table}
        db_tables = {table.name: table}

        # First pass: the mismatch must be flagged.
        diffs_before = compare_tables(ddl_before, db_tables)
        type_diffs = [
            d for d in diffs_before
            if d.kind == DiffKind.TYPE_MISMATCH and d.column == target_col
        ]
        assert len(type_diffs) >= 1, "修正前应检测到 TYPE_MISMATCH"

        # Simulated fix: take the DB definitions as the new DDL.
        ddl_after = dict(db_tables)

        # Second pass: fixpoint.
        diffs_after = compare_tables(ddl_after, db_tables)
        assert diffs_after == [], (
            f"不动点违反:类型修正后应零差异,但得到 {len(diffs_after)} 项: "
            f"{[str(d) for d in diffs_after]}"
        )
|
||||
24
apps/etl/pipelines/feiqiu/tests/unit/test_config.py
Normal file
24
apps/etl/pipelines/feiqiu/tests/unit/test_config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""配置管理测试"""
|
||||
import pytest
|
||||
from config.settings import AppConfig
|
||||
from config.defaults import DEFAULTS
|
||||
|
||||
def test_config_load():
    """Loading with only a store_id override falls back to packaged defaults."""
    config = AppConfig.load({"app": {"store_id": 1}})
    expected_tz = DEFAULTS["app"]["timezone"]
    assert config.get("app.timezone") == expected_tz
|
||||
|
||||
def test_config_override():
    """An explicit override value wins over the packaged default."""
    config = AppConfig.load({"app": {"store_id": 12345}})
    assert config.get("app.store_id") == 12345
|
||||
|
||||
def test_config_get_nested():
    """Dotted-key lookup reaches nested values; missing keys yield the fallback."""
    config = AppConfig.load({"app": {"store_id": 1}})
    batch_size = config.get("db.batch_size")
    assert batch_size == 1000
    assert config.get("nonexistent.key", "default") == "default"
|
||||
@@ -0,0 +1,55 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""配置映射属性测试 — 使用 hypothesis 验证配置键兼容映射的通用正确性属性。"""
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from config.settings import AppConfig, _FLOW_TO_DATA_SOURCE
|
||||
|
||||
|
||||
# ── 确保测试不读取 .env 文件 ──────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
def skip_dotenv(monkeypatch):
    """Keep tests hermetic: tell the config loader not to read any .env file."""
    monkeypatch.setenv("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
|
||||
# ── 生成策略 ──────────────────────────────────────────────────
|
||||
|
||||
flow_st = st.sampled_from(["FULL", "FETCH_ONLY", "INGEST_ONLY"])
|
||||
|
||||
|
||||
# ── Property 11: pipeline_flow → data_source 映射一致性 ──────
|
||||
# Feature: scheduler-refactor, Property 11: pipeline_flow → data_source 映射一致性
|
||||
# **Validates: Requirements 8.1, 8.2, 8.3, 5.2, 8.4**
|
||||
#
|
||||
# 对于任意旧 pipeline_flow 值(FULL/FETCH_ONLY/INGEST_ONLY),
|
||||
# 映射到 data_source 的结果应与预定义映射表一致:
|
||||
# FULL→hybrid、FETCH_ONLY→online、INGEST_ONLY→offline。
|
||||
# 同样,配置键 pipeline.flow 应自动映射到 run.data_source。
|
||||
|
||||
|
||||
class TestProperty11FlowToDataSourceMapping:
    """Property 11: pipeline_flow → data_source mapping consistency."""

    @given(flow=flow_st)
    @settings(max_examples=100)
    def test_pipeline_flow_maps_to_data_source(self, flow):
        """Setting the legacy pipeline.flow key must surface the mapped run.data_source."""
        expected = _FLOW_TO_DATA_SOURCE[flow]

        # The legacy key emits a DeprecationWarning; silence it for the load.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            config = AppConfig.load(
                {"app": {"store_id": 1}, "pipeline": {"flow": flow}}
            )

        actual = config.get("run.data_source")
        assert actual == expected, (
            f"pipeline.flow={flow!r} 应映射为 run.data_source={expected!r},"
            f"实际为 {actual!r}"
        )
|
||||
@@ -0,0 +1,152 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""CLI 参数和管道类型文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 7.1, 7.2**
|
||||
|
||||
Property 6: 对于所有在 cli/main.py 的 parse_args() 中定义的 CLI 参数,
|
||||
README.md 或 base_task_mechanism.md 中应包含该参数的说明。
|
||||
|
||||
Property 7: 对于所有在 PipelineRunner.PIPELINE_LAYERS 中定义的管道类型,
|
||||
README.md 中应包含该管道类型的层组合说明。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 6 & 7
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ── 常量 ──────────────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_CLI_MAIN_PATH = _PROJECT_ROOT / "cli" / "main.py"
|
||||
_README_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "README.md"
|
||||
_BASE_MECHANISM_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "base_task_mechanism.md"
|
||||
|
||||
|
||||
# ── 辅助函数:通过 AST 解析 parse_args() 中的 CLI 参数名 ─────
|
||||
|
||||
def _extract_cli_params_via_ast() -> list[str]:
|
||||
"""从 cli/main.py 的 parse_args() 函数中,通过 AST 提取所有 add_argument 的参数名。
|
||||
|
||||
只提取以 '--' 开头的长参数名(如 '--store-id'),忽略位置参数。
|
||||
当 add_argument 有多个名称时(如 '--api-token', '--token'),取第一个 '--' 开头的名称。
|
||||
"""
|
||||
source = _CLI_MAIN_PATH.read_text(encoding="utf-8")
|
||||
tree = ast.parse(source, filename=str(_CLI_MAIN_PATH))
|
||||
|
||||
params: list[str] = []
|
||||
|
||||
# 找到 parse_args 函数定义
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.FunctionDef) or node.name != "parse_args":
|
||||
continue
|
||||
|
||||
# 遍历函数体中的所有 add_argument 调用
|
||||
for child in ast.walk(node):
|
||||
if not isinstance(child, ast.Call):
|
||||
continue
|
||||
# 匹配 parser.add_argument(...) 或 xxx.add_argument(...)
|
||||
func = child.func
|
||||
if not (isinstance(func, ast.Attribute) and func.attr == "add_argument"):
|
||||
continue
|
||||
|
||||
# 从位置参数中提取 '--xxx' 形式的参数名
|
||||
for arg in child.args:
|
||||
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
|
||||
val = arg.value
|
||||
if val.startswith("--"):
|
||||
params.append(val)
|
||||
break # 取第一个 '--' 开头的名称即可
|
||||
|
||||
return sorted(set(params))
|
||||
|
||||
|
||||
# ── 辅助函数:提取 PIPELINE_LAYERS 的键 ──────────────────────
|
||||
|
||||
def _extract_pipeline_types() -> list[str]:
|
||||
"""从 PipelineRunner.PIPELINE_LAYERS 获取所有管道类型名称。
|
||||
|
||||
直接导入 PIPELINE_LAYERS 字典,避免实例化 PipelineRunner。
|
||||
"""
|
||||
# 通过 AST 解析 pipeline_runner.py 提取 PIPELINE_LAYERS 的键
|
||||
pr_path = _PROJECT_ROOT / "orchestration" / "pipeline_runner.py"
|
||||
source = pr_path.read_text(encoding="utf-8")
|
||||
tree = ast.parse(source, filename=str(pr_path))
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.ClassDef) or node.name != "PipelineRunner":
|
||||
continue
|
||||
for item in node.body:
|
||||
if not isinstance(item, (ast.Assign, ast.AnnAssign)):
|
||||
continue
|
||||
# 匹配 PIPELINE_LAYERS = {...} 或 PIPELINE_LAYERS: ... = {...}
|
||||
targets = (
|
||||
[item.target] if isinstance(item, ast.AnnAssign) else item.targets
|
||||
)
|
||||
for target in targets:
|
||||
if isinstance(target, ast.Name) and target.id == "PIPELINE_LAYERS":
|
||||
value = item.value
|
||||
if isinstance(value, ast.Dict):
|
||||
keys: list[str] = []
|
||||
for k in value.keys:
|
||||
if isinstance(k, ast.Constant) and isinstance(k.value, str):
|
||||
keys.append(k.value)
|
||||
return sorted(keys)
|
||||
|
||||
raise RuntimeError("未能从 pipeline_runner.py 中解析出 PIPELINE_LAYERS")
|
||||
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_CLI_PARAMS: list[str] = _extract_cli_params_via_ast()
|
||||
_PIPELINE_TYPES: list[str] = _extract_pipeline_types()
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
def readme_content() -> str:
    """Module-scoped fixture: the full text of README.md."""
    assert _README_PATH.exists(), f"文档文件不存在: {_README_PATH}"
    text = _README_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def base_mechanism_content() -> str:
    """Module-scoped fixture: the full text of base_task_mechanism.md."""
    assert _BASE_MECHANISM_PATH.exists(), f"文档文件不存在: {_BASE_MECHANISM_PATH}"
    text = _BASE_MECHANISM_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
# ── Property 6: CLI 参数文档覆盖完整性 ────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("param", _CLI_PARAMS, ids=_CLI_PARAMS)
def test_cli_param_in_docs(param: str, readme_content: str, base_mechanism_content: str):
    """Property 6: every CLI flag is described in README.md or base_task_mechanism.md.

    **Validates: Requirements 7.1**
    """
    # The flag may appear backtick-quoted or bare; substring match covers both.
    haystack = "\n".join([readme_content, base_mechanism_content])
    assert param in haystack, (
        f"CLI 参数 '{param}' 在 parse_args() 中定义,"
        f"但未在 README.md 或 base_task_mechanism.md 中找到对应说明"
    )
|
||||
|
||||
|
||||
# ── Property 7: 管道类型文档覆盖完整性 ────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("pipeline_type", _PIPELINE_TYPES, ids=_PIPELINE_TYPES)
def test_pipeline_type_in_readme(pipeline_type: str, readme_content: str):
    """Property 7: every pipeline type has a layer-composition section in README.md.

    **Validates: Requirements 7.2**
    """
    documented = pipeline_type in readme_content
    assert documented, (
        f"管道类型 '{pipeline_type}' 在 PIPELINE_LAYERS 中定义,"
        f"但未在 README.md 中找到对应的层组合说明"
    )
|
||||
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWD 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 3.1**
|
||||
|
||||
从 task_registry.py 中提取所有 layer="DWD" 的任务代码,
|
||||
验证 docs/etl_tasks/dwd_tasks.md 中包含每个任务代码的说明章节。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 2: DWD 任务文档覆盖完整性
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_DWD_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "dwd_tasks.md"
|
||||
|
||||
# 从注册表动态获取所有 DWD 层任务代码
|
||||
_DWD_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("DWD")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def dwd_doc_content() -> str:
    """Module-scoped fixture: the full text of dwd_tasks.md, shared by all cases."""
    assert _DWD_DOC_PATH.exists(), f"文档文件不存在: {_DWD_DOC_PATH}"
    text = _DWD_DOC_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
# ── 参数化验证:每个 DWD 任务代码必须出现在文档中 ─────────────
|
||||
|
||||
@pytest.mark.parametrize("task_code", _DWD_TASK_CODES, ids=_DWD_TASK_CODES)
def test_dwd_task_code_in_doc(task_code: str, dwd_doc_content: str):
    """Property 2: every registered DWD task code has a section in dwd_tasks.md.

    **Validates: Requirements 3.1**
    """
    documented = task_code in dwd_doc_content
    assert documented, (
        f"DWD 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 dwd_tasks.md 中找到对应说明章节"
    )
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWS 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 4.1, 4.4**
|
||||
|
||||
从 task_registry.py 中提取所有 layer="DWS" 的任务代码,
|
||||
验证 docs/etl_tasks/dws_tasks.md 中包含每个任务代码的说明章节,
|
||||
并标注其更新策略。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 3: DWS 任务文档覆盖完整性
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_DWS_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "dws_tasks.md"
|
||||
|
||||
# 从注册表动态获取所有 DWS 层任务代码
|
||||
_DWS_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("DWS")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def dws_doc_content() -> str:
    """Module-scoped fixture: the full text of dws_tasks.md, shared by all cases."""
    assert _DWS_DOC_PATH.exists(), f"文档文件不存在: {_DWS_DOC_PATH}"
    text = _DWS_DOC_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
# ── 参数化验证:每个 DWS 任务代码必须出现在文档中 ─────────────
|
||||
|
||||
@pytest.mark.parametrize("task_code", _DWS_TASK_CODES, ids=_DWS_TASK_CODES)
def test_dws_task_code_in_doc(task_code: str, dws_doc_content: str):
    """Property 3: every registered DWS task code has a section in dws_tasks.md.

    **Validates: Requirements 4.1, 4.4**
    """
    documented = task_code in dws_doc_content
    assert documented, (
        f"DWS 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 dws_tasks.md 中找到对应说明章节"
    )
|
||||
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""INDEX 和 Utility 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 5.1, 6.1**
|
||||
|
||||
Property 4: 对于所有在 task_registry.py 中注册且 layer="INDEX" 的任务代码,
|
||||
index_tasks.md 中应包含该任务代码的说明章节。
|
||||
|
||||
Property 5: 对于所有在 task_registry.py 中注册且 task_type="utility" 的任务代码,
|
||||
utility_tasks.md 中应包含该任务代码的说明章节。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 4 & 5
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_INDEX_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "index_tasks.md"
|
||||
_UTILITY_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "utility_tasks.md"
|
||||
|
||||
# INDEX-layer tasks come straight from the registry's public API.
_INDEX_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("INDEX")

# Utility tasks: the registry exposes no get_tasks_by_type, so the internal
# mapping is filtered directly.
# NOTE(review): this reaches into the private `_tasks` attribute — confirm the
# registry really has no public accessor before keeping this.
_UTILITY_TASK_CODES: list[str] = [
    code
    for code, meta in default_registry._tasks.items()
    if meta.task_type == "utility"
]
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="module")
def index_doc_content() -> str:
    """Module-scoped fixture: the full text of index_tasks.md."""
    assert _INDEX_DOC_PATH.exists(), f"文档文件不存在: {_INDEX_DOC_PATH}"
    text = _INDEX_DOC_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def utility_doc_content() -> str:
    """Module-scoped fixture: the full text of utility_tasks.md."""
    assert _UTILITY_DOC_PATH.exists(), f"文档文件不存在: {_UTILITY_DOC_PATH}"
    text = _UTILITY_DOC_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
# ── Property 4: INDEX 任务文档覆盖完整性 ──────────────────────
|
||||
|
||||
@pytest.mark.parametrize("task_code", _INDEX_TASK_CODES, ids=_INDEX_TASK_CODES)
def test_index_task_code_in_doc(task_code: str, index_doc_content: str):
    """Property 4: every registered INDEX task code has a section in index_tasks.md.

    **Validates: Requirements 5.1**
    """
    documented = task_code in index_doc_content
    assert documented, (
        f"INDEX 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 index_tasks.md 中找到对应说明章节"
    )
|
||||
|
||||
|
||||
# ── Property 5: Utility 任务文档覆盖完整性 ────────────────────
|
||||
|
||||
@pytest.mark.parametrize("task_code", _UTILITY_TASK_CODES, ids=_UTILITY_TASK_CODES)
def test_utility_task_code_in_doc(task_code: str, utility_doc_content: str):
    """Property 5: every task_type="utility" task code has a section in utility_tasks.md.

    **Validates: Requirements 6.1**
    """
    documented = task_code in utility_doc_content
    assert documented, (
        f"Utility 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 utility_tasks.md 中找到对应说明章节"
    )
|
||||
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ODS 任务文档覆盖完整性验证。
|
||||
|
||||
**Validates: Requirements 2.1, 2.4**
|
||||
|
||||
从 task_registry.py 中提取所有 layer="ODS" 的任务代码,
|
||||
验证 docs/etl_tasks/ods_tasks.md 中包含每个任务代码的说明章节。
|
||||
"""
|
||||
# Feature: etl-task-documentation, Property 1: ODS 任务文档覆盖完整性
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from orchestration.task_registry import default_registry
|
||||
|
||||
# ── 测试数据准备 ──────────────────────────────────────────────
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_ODS_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "ods_tasks.md"
|
||||
|
||||
# 从注册表动态获取所有 ODS 层任务代码
|
||||
_ODS_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("ODS")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def ods_doc_content() -> str:
    """Module-scoped fixture: the full text of ods_tasks.md, shared by all cases."""
    assert _ODS_DOC_PATH.exists(), f"文档文件不存在: {_ODS_DOC_PATH}"
    text = _ODS_DOC_PATH.read_text(encoding="utf-8")
    return text
|
||||
|
||||
|
||||
# ── 参数化验证:每个 ODS 任务代码必须出现在文档中 ─────────────
|
||||
|
||||
@pytest.mark.parametrize("task_code", _ODS_TASK_CODES, ids=_ODS_TASK_CODES)
def test_ods_task_code_in_doc(task_code: str, ods_doc_content: str):
    """Property 1: every registered ODS task code has a section in ods_tasks.md.

    **Validates: Requirements 2.1, 2.4**
    """
    documented = task_code in ods_doc_content
    assert documented, (
        f"ODS 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 ods_tasks.md 中找到对应说明章节"
    )
|
||||
479
apps/etl/pipelines/feiqiu/tests/unit/test_dws_tasks.py
Normal file
479
apps/etl/pipelines/feiqiu/tests/unit/test_dws_tasks.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# AI_CHANGELOG
|
||||
# - 2026-02-14 | bugfix: 修复 3 个测试 bug
|
||||
# prompt: "继续。完成后检查所有任务是否全面"
|
||||
# 直接原因: (1) mock_config.get 返回 None 导致 timezone 异常;(2) _build_daily_record 缺少 gift_card 参数;(3) loaded_at naive/aware 不匹配
|
||||
# 变更: mock_config.get 改用 side_effect 返回 default;补充 gift_card 参数;loaded_at 改用 aware datetime
|
||||
# 验证: pytest tests/unit -x(449 passed)
|
||||
"""
|
||||
DWS任务单元测试
|
||||
|
||||
测试内容:
|
||||
- BaseDwsTask基类方法
|
||||
- 时间计算方法
|
||||
- 配置应用方法
|
||||
- 排名计算方法
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from tasks.dws.base_dws_task import (
|
||||
BaseDwsTask,
|
||||
TimeLayer,
|
||||
TimeWindow,
|
||||
CourseType,
|
||||
TimeRange,
|
||||
ConfigCache
|
||||
)
|
||||
from tasks.dws.finance_daily_task import FinanceDailyTask
|
||||
from tasks.dws.assistant_monthly_task import AssistantMonthlyTask
|
||||
|
||||
|
||||
class TestTimeLayerRange:
    """Range computation for the TimeLayer buckets."""

    def test_last_2_days(self):
        """LAST_2_DAYS covers the two days ending on the base date."""
        task = create_mock_task()
        result = task.get_time_layer_range(TimeLayer.LAST_2_DAYS, date(2026, 2, 1))
        assert (result.start, result.end) == (date(2026, 1, 31), date(2026, 2, 1))

    def test_last_1_month(self):
        """LAST_1_MONTH covers one rolling month ending on the base date."""
        task = create_mock_task()
        result = task.get_time_layer_range(TimeLayer.LAST_1_MONTH, date(2026, 2, 1))
        assert (result.start, result.end) == (date(2026, 1, 2), date(2026, 2, 1))

    def test_last_3_months(self):
        """LAST_3_MONTHS covers three rolling months ending on the base date."""
        task = create_mock_task()
        result = task.get_time_layer_range(TimeLayer.LAST_3_MONTHS, date(2026, 2, 1))
        assert (result.start, result.end) == (date(2025, 11, 3), date(2026, 2, 1))
|
||||
|
||||
|
||||
class TestTimeWindowRange:
    """Range computation for the TimeWindow buckets."""

    def test_this_week_monday_start(self):
        """THIS_WEEK starts on Monday; 2026-02-01 is a Sunday."""
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.THIS_WEEK, date(2026, 2, 1))
        # Monday of that week is 2026-01-26.
        assert (result.start, result.end) == (date(2026, 1, 26), date(2026, 2, 1))

    def test_last_week(self):
        """LAST_WEEK spans the previous Monday through Sunday."""
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_WEEK, date(2026, 2, 1))
        # Previous Monday 2026-01-19 through Sunday 2026-01-25.
        assert (result.start, result.end) == (date(2026, 1, 19), date(2026, 1, 25))

    def test_this_month(self):
        """THIS_MONTH spans the 1st through the base date."""
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.THIS_MONTH, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2026, 2, 1), date(2026, 2, 15))

    def test_last_month(self):
        """LAST_MONTH spans the whole previous calendar month."""
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_MONTH, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2026, 1, 1), date(2026, 1, 31))

    def test_last_3_months_excl_current(self):
        """Previous three full months, excluding the current one."""
        task = create_mock_task()
        result = task.get_time_window_range(
            TimeWindow.LAST_3_MONTHS_EXCL_CURRENT, date(2026, 2, 15)
        )
        assert (result.start, result.end) == (date(2025, 11, 1), date(2026, 1, 31))

    def test_last_3_months_incl_current(self):
        """Previous three months including the current partial month."""
        task = create_mock_task()
        result = task.get_time_window_range(
            TimeWindow.LAST_3_MONTHS_INCL_CURRENT, date(2026, 2, 15)
        )
        assert (result.start, result.end) == (date(2025, 12, 1), date(2026, 2, 15))

    def test_this_quarter(self):
        """THIS_QUARTER spans the quarter start through the base date."""
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.THIS_QUARTER, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2026, 1, 1), date(2026, 2, 15))

    def test_last_6_months(self):
        """LAST_6_MONTHS excludes the current month: six months back from last month's end."""
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_6_MONTHS, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2025, 8, 1), date(2026, 1, 31))
|
||||
|
||||
|
||||
class TestComparisonRange:
    """Period-over-period comparison range computation."""

    def test_comparison_7_days(self):
        """The comparison window for a 7-day range is the preceding 7 days."""
        task = create_mock_task()
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 2, 7))
        previous = task.get_comparison_range(current)
        # Preceding 7-day window: Jan 25 – Jan 31.
        assert (previous.start, previous.end) == (date(2026, 1, 25), date(2026, 1, 31))

    def test_comparison_30_days(self):
        """The comparison window always has the same length as the current range."""
        task = create_mock_task()
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 3, 2))
        previous = task.get_comparison_range(current)
        assert (previous.end - previous.start).days == (current.end - current.start).days
|
||||
|
||||
|
||||
class TestFinanceDailyRecord:
    """Daily finance record computation."""

    def test_groupbuy_and_cashflow(self):
        """Group-buy discount attribution and cash-flow calibration."""
        D = Decimal  # local alias keeps the fixture readable
        task = create_finance_daily_task()
        stat_date = date(2026, 2, 1)

        settle = {
            'gross_amount': D('1000'),
            'table_fee_amount': D('1000'),
            'goods_amount': D('0'),
            'assistant_pd_amount': D('0'),
            'assistant_cx_amount': D('0'),
            'cash_pay_amount': D('300'),
            'card_pay_amount': D('0'),
            'balance_pay_amount': D('0'),
            'gift_card_pay_amount': D('0'),
            'coupon_amount': D('200'),
            'pl_coupon_sale_amount': D('0'),
            'adjust_amount': D('50'),
            'member_discount_amount': D('10'),
            'rounding_amount': D('0'),
            'order_count': 1,
            'member_order_count': 1,
            'guest_order_count': 0,
        }
        groupbuy = {'groupbuy_pay_total': D('80')}
        recharge = {'recharge_cash': D('20')}
        gift_card = {'gift_card_consume': D('0')}
        expense = {'expense_amount': D('40')}
        platform = {
            'settlement_amount': D('60'),
            'commission_amount': D('5'),
            'service_fee': D('5'),
        }
        big_customer = {'big_customer_amount': D('20')}

        record = task._build_daily_record(
            stat_date, settle, groupbuy, recharge, gift_card, expense, platform, big_customer, 1
        )

        expected = {
            'discount_groupbuy': D('120'),
            'discount_other': D('30'),
            'platform_settlement_amount': D('60'),
            'platform_fee_amount': D('10'),
            'cash_inflow_total': D('380'),
            'cash_outflow_total': D('50'),
            'cash_balance_change': D('330'),
        }
        for field, value in expected.items():
            assert record[field] == value
|
||||
|
||||
|
||||
class TestNewHireTier:
    """New-hire tier assignment rules."""

    def test_new_hire_tier_hours(self):
        """Daily average * 30 conversion: 15h over 5 days projects to 90h."""
        task = create_assistant_monthly_task()
        converted = task._calc_new_hire_tier_hours(Decimal('15'), 5)
        assert converted == Decimal('90')

    def test_max_tier_level_cap(self):
        """A new hire's assigned tier never exceeds max_tier_level."""
        task = create_mock_task()

        def tier(tier_id, level, lo, hi):
            # Shared template for the fixture tiers below.
            return {
                'tier_id': tier_id, 'tier_level': level,
                'min_hours': lo, 'max_hours': hi,
                'is_new_hire_tier': False,
                'effective_from': date(2020, 1, 1),
                'effective_to': date(2099, 1, 1),
            }

        task._config_cache = ConfigCache(
            performance_tiers=[
                tier(1, 1, 0, 100),
                tier(2, 2, 100, 200),
                tier(3, 3, 200, 300),
                tier(4, 4, 300, None),
            ],
            level_prices=[],
            bonus_rules=[],
            area_categories={},
            skill_types={},
            loaded_at=datetime.now(tz=task.tz),
        )

        # 350h would map to level 4, but the cap forces level 3.
        result = task.get_performance_tier(
            Decimal('350'),
            is_new_hire=True,
            effective_date=date(2026, 2, 1),
            max_tier_level=3,
        )
        assert result['tier_level'] == 3
|
||||
|
||||
|
||||
class TestNewHireCheck:
    """New-hire determination (is_new_hire_in_month)."""

    def test_new_hire_in_month(self):
        """Hired within the stat month counts as a new hire."""
        task = create_mock_task()
        # Plain truthiness assert instead of `== True` (PEP 8: never compare to True).
        assert task.is_new_hire_in_month(date(2026, 2, 5), date(2026, 2, 1))

    def test_not_new_hire(self):
        """Hired before the stat month is not a new hire."""
        task = create_mock_task()
        assert not task.is_new_hire_in_month(date(2026, 1, 15), date(2026, 2, 1))

    def test_hire_on_first_day(self):
        """Hired on the 1st of the stat month counts as a new hire."""
        task = create_mock_task()
        assert task.is_new_hire_in_month(date(2026, 2, 1), date(2026, 2, 1))
|
||||
|
||||
|
||||
class TestRankWithTies:
    """Rank computation with tie handling (standard competition ranking)."""

    def test_no_ties(self):
        """Distinct values rank 1, 2, 3 in order."""
        task = create_mock_task()
        ranked = task.calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('90')),
            (3, Decimal('80')),
        ])
        assert ranked[0] == (1, 1, 1)
        assert ranked[1] == (2, 2, 2)
        assert ranked[2] == (3, 3, 3)

    def test_with_ties(self):
        """Tied values share a rank and the following rank is skipped."""
        task = create_mock_task()
        ranked = task.calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('100')),  # tied for first
            (3, Decimal('80')),
        ])
        # Two firsts, then rank 3 (rank 2 is skipped).
        assert [row[1] for row in ranked] == [1, 1, 3]

    def test_all_ties(self):
        """When every value is equal, everyone ranks first."""
        task = create_mock_task()
        ranked = task.calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('100')),
            (3, Decimal('100')),
        ])
        assert all(row[1] == 1 for row in ranked)
|
||||
|
||||
|
||||
class TestGuestCheck:
    """Tests for guest (walk-in customer) detection."""

    def test_guest_zero(self):
        """member_id == 0 denotes a guest."""
        task = create_mock_task()
        # `is True` (not `== True`, PEP8 E712) also pins the bool return type.
        assert task.is_guest(0) is True

    def test_guest_none(self):
        """member_id is None denotes a guest."""
        task = create_mock_task()
        assert task.is_guest(None) is True

    def test_not_guest(self):
        """A regular member id is not a guest."""
        task = create_mock_task()
        assert task.is_guest(12345) is False
|
||||
|
||||
|
||||
class TestUtilityMethods:
    """Tests for the small conversion helpers on the task base class."""

    def test_safe_decimal(self):
        """safe_decimal coerces ints/strings; None or garbage falls back to 0."""
        task = create_mock_task()

        assert task.safe_decimal(100) == Decimal('100')
        assert task.safe_decimal('123.45') == Decimal('123.45')
        assert task.safe_decimal(None) == Decimal('0')
        assert task.safe_decimal('invalid') == Decimal('0')

    def test_safe_int(self):
        """safe_int coerces ints/strings; None or garbage falls back to 0."""
        task = create_mock_task()

        assert task.safe_int(100) == 100
        assert task.safe_int('123') == 123
        assert task.safe_int(None) == 0
        assert task.safe_int('invalid') == 0

    def test_seconds_to_hours(self):
        """seconds_to_hours: 3600 s == 1 h, fractional hours preserved."""
        task = create_mock_task()

        assert task.seconds_to_hours(3600) == Decimal('1')
        assert task.seconds_to_hours(5400) == Decimal('1.5')
        assert task.seconds_to_hours(0) == Decimal('0')

    def test_hours_to_seconds(self):
        """hours_to_seconds is the inverse conversion: Decimal hours -> seconds."""
        task = create_mock_task()

        assert task.hours_to_seconds(Decimal('1')) == 3600
        assert task.hours_to_seconds(Decimal('1.5')) == 5400
|
||||
|
||||
|
||||
class TestCourseType:
    """CourseType enum members expose the expected string values."""

    def test_base_course(self):
        """The base-course member serializes to 'BASE'."""
        assert CourseType.BASE.value == 'BASE'

    def test_bonus_course(self):
        """The bonus (add-on) course member serializes to 'BONUS'."""
        assert CourseType.BONUS.value == 'BONUS'
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
|
||||
def create_mock_task():
    """
    Create a concrete BaseDwsTask instance with fully mocked dependencies.

    BaseDwsTask apparently requires several hooks to be implemented
    (task code, target table, primary keys, extract, load) — a throwaway
    subclass stubs them with trivial values; config/db/api/logger are
    MagicMocks, and config.get simply echoes the caller-supplied default.
    """
    # Minimal concrete subclass used only for tests: every required hook
    # returns a fixed, harmless value.
    class TestDwsTask(BaseDwsTask):
        def get_task_code(self):
            return "TEST_DWS_TASK"

        def get_target_table(self):
            return "test_table"

        def get_primary_keys(self):
            return ["id"]

        def extract(self, context):
            return {}

        def load(self, transformed, context):
            return {}

    # Mocked collaborators — no real config, database, API, or logger.
    mock_config = MagicMock()
    mock_config.get.side_effect = lambda key, default=None: default

    mock_db = MagicMock()
    mock_api = MagicMock()
    mock_logger = MagicMock()

    task = TestDwsTask(mock_config, mock_db, mock_api, mock_logger)
    return task
|
||||
|
||||
|
||||
def create_finance_daily_task():
    """Build a FinanceDailyTask wired to mocked config/db/api/logger."""
    config = MagicMock()

    def _fake_get(key, default=None):
        # Only app.tenant_id has a concrete value (1); every other key
        # echoes the caller-supplied default.
        return 1 if key == "app.tenant_id" else default

    config.get.side_effect = _fake_get
    return FinanceDailyTask(config, MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
|
||||
def create_assistant_monthly_task():
    """Build an AssistantMonthlyTask with fully mocked dependencies."""
    config = MagicMock()
    # Every config key falls back to its default — no real settings needed.
    config.get.side_effect = lambda key, default=None: default

    db_stub, api_stub, logger_stub = MagicMock(), MagicMock(), MagicMock()
    return AssistantMonthlyTask(config, db_stub, api_stub, logger_stub)
|
||||
|
||||
|
||||
# Allow running this test module directly (python <file>) outside the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
222
apps/etl/pipelines/feiqiu/tests/unit/test_e2e_flow.py
Normal file
222
apps/etl/pipelines/feiqiu/tests/unit/test_e2e_flow.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""端到端流程集成测试
|
||||
|
||||
验证 CLI → PipelineRunner → TaskExecutor 完整调用链。
|
||||
使用 mock 依赖,不需要真实数据库。
|
||||
|
||||
需求: 9.4
|
||||
"""
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
import pytest
|
||||
|
||||
from orchestration.task_executor import TaskExecutor, DataSource
|
||||
from orchestration.pipeline_runner import PipelineRunner
|
||||
from orchestration.task_registry import TaskRegistry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:构造最小可用的 mock config
|
||||
# ---------------------------------------------------------------------------
|
||||
def _make_config(**overrides):
|
||||
"""构造一个行为类似 AppConfig 的 MagicMock。"""
|
||||
store = {
|
||||
"app.timezone": "Asia/Shanghai",
|
||||
"app.store_id": 1,
|
||||
"io.fetch_root": "/tmp/fetch",
|
||||
"io.ingest_source_dir": "",
|
||||
"io.write_pretty_json": False,
|
||||
"io.export_root": "/tmp/export",
|
||||
"io.log_root": "/tmp/logs",
|
||||
"pipeline.fetch_root": None,
|
||||
"pipeline.ingest_source_dir": None,
|
||||
"run.ods_tasks": [],
|
||||
"run.dws_tasks": [],
|
||||
"run.index_tasks": [],
|
||||
"run.data_source": "hybrid",
|
||||
"verification.ods_use_local_json": False,
|
||||
"verification.skip_ods_when_fetch_before_verify": True,
|
||||
}
|
||||
store.update(overrides)
|
||||
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda k, d=None: store.get(k, d))
|
||||
config.__getitem__ = MagicMock(side_effect=lambda k: {
|
||||
"io": {"export_root": "/tmp/export", "log_root": "/tmp/logs"},
|
||||
}[k])
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:构造一个可被 TaskRegistry 注册的假任务类
|
||||
# ---------------------------------------------------------------------------
|
||||
class _FakeTask:
|
||||
"""最小假任务,execute() 返回固定结果。"""
|
||||
def __init__(self, config, db_ops, api_client, logger):
|
||||
pass
|
||||
|
||||
def execute(self, cursor_data):
|
||||
return {"status": "SUCCESS", "counts": {"fetched": 5, "inserted": 3}}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 测试 1:传统模式 — TaskExecutor.run_tasks 端到端
|
||||
# ===========================================================================
|
||||
class TestTraditionalModeE2E:
    """Traditional mode: TaskExecutor.run_tasks end-to-end (mocked deps)."""

    def test_run_tasks_executes_utility_task_and_returns_results(self):
        """A utility task takes the _run_utility_task path, skipping cursor and run tracking."""
        config = _make_config()
        registry = TaskRegistry()
        # Registered as a utility task that needs no DB configuration.
        registry.register(
            "FAKE_UTIL", _FakeTask,
            requires_db_config=False, task_type="utility",
        )

        cursor_mgr = MagicMock()
        run_tracker = MagicMock()

        executor = TaskExecutor(
            config=config,
            db_ops=MagicMock(),
            api_client=MagicMock(),
            cursor_mgr=cursor_mgr,
            run_tracker=run_tracker,
            task_registry=registry,
            logger=MagicMock(),
        )

        results = executor.run_tasks(["FAKE_UTIL"], data_source="hybrid")

        assert len(results) == 1
        # run_tasks wraps a successful utility task as "成功" (tolerate variants).
        assert results[0]["status"] in ("成功", "完成", "SUCCESS")
        # Utility tasks must not touch the cursor store or the run tracker.
        cursor_mgr.get_or_create.assert_not_called()
        run_tracker.create_run.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 测试 2:管道模式 — PipelineRunner → TaskExecutor 端到端
|
||||
# ===========================================================================
|
||||
class TestPipelineModeE2E:
    """Pipeline mode: PipelineRunner.run -> TaskExecutor.run_tasks end-to-end."""

    def test_pipeline_delegates_to_executor_and_returns_structure(self):
        """PipelineRunner resolves layers -> tasks, then delegates to TaskExecutor."""
        executor = MagicMock()
        executor.run_tasks.return_value = [
            {"task_code": "FAKE_ODS", "status": "成功", "counts": {"fetched": 10, "inserted": 8}},
        ]

        registry = TaskRegistry()
        registry.register("FAKE_ODS", _FakeTask, layer="ODS")

        config = _make_config()

        runner = PipelineRunner(
            config=config,
            task_executor=executor,
            task_registry=registry,
            db_conn=MagicMock(),
            api_client=MagicMock(),
            logger=MagicMock(),
        )

        result = runner.run(
            pipeline="api_ods",
            processing_mode="increment_only",
            data_source="hybrid",
        )

        # Structural checks on the returned summary.
        assert result["status"] == "SUCCESS"
        assert result["pipeline"] == "api_ods"
        assert result["layers"] == ["ODS"]
        assert isinstance(result["results"], list)
        # TaskExecutor was invoked exactly once with the requested data source.
        executor.run_tasks.assert_called_once()
        call_args = executor.run_tasks.call_args
        assert call_args[1]["data_source"] == "hybrid"

    def test_pipeline_verify_only_skips_increment(self):
        """verify_only mode skips incremental ETL and only runs verification."""
        executor = MagicMock()
        executor.run_tasks.return_value = []

        registry = TaskRegistry()
        config = _make_config()

        runner = PipelineRunner(
            config=config,
            task_executor=executor,
            task_registry=registry,
            db_conn=MagicMock(),
            api_client=MagicMock(),
            logger=MagicMock(),
        )

        # The verification framework may not be installed: stub _run_verification.
        with patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}):
            result = runner.run(
                pipeline="api_ods",
                processing_mode="verify_only",
                data_source="hybrid",
            )

        assert result["status"] == "SUCCESS"
        # With verify_only and fetch_before_verify=False, run_tasks is never called.
        executor.run_tasks.assert_not_called()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 测试 3:ETLScheduler 薄包装层委托验证
|
||||
# ===========================================================================
|
||||
class TestSchedulerThinWrapper:
    """ETLScheduler thin wrapper correctly delegates to TaskExecutor / PipelineRunner."""

    def test_scheduler_delegates_run_tasks(self):
        """run_tasks() and run_pipeline_with_verification() delegate to the inner objects."""
        from orchestration.scheduler import ETLScheduler

        # Config stub: item access serves the "db" and "api" sections the
        # scheduler constructor reads; get() serves a few dotted keys.
        mock_config = MagicMock()
        mock_config.__getitem__ = MagicMock(side_effect=lambda k: {
            "db": {
                "dsn": "postgresql://fake:5432/test",
                "session": {"timezone": "Asia/Shanghai"},
                "connect_timeout_sec": 5,
            },
            "api": {
                "base_url": "https://fake.api",
                "token": "fake-token",
                "timeout_sec": 30,
                "retries": {"max_attempts": 3},
            },
        }[k])
        mock_config.get = MagicMock(side_effect=lambda k, d=None: {
            "run.data_source": "hybrid",
            "run.tasks": ["FAKE"],
            "app.timezone": "Asia/Shanghai",
        }.get(k, d))

        # Patch every resource factory so no real connections are created.
        with patch("orchestration.scheduler.DatabaseConnection"), \
             patch("orchestration.scheduler.DatabaseOperations"), \
             patch("orchestration.scheduler.APIClient"), \
             patch("orchestration.scheduler.CursorManager"), \
             patch("orchestration.scheduler.RunTracker"), \
             patch("orchestration.scheduler.TaskExecutor") as MockTE, \
             patch("orchestration.scheduler.PipelineRunner") as MockPR:

            # ETLScheduler construction apparently emits a DeprecationWarning
            # (it is a legacy wrapper) — silence it for the test.
            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                scheduler = ETLScheduler(mock_config, MagicMock())

            # run_tasks delegates to the inner task_executor.
            scheduler.run_tasks(["TEST_TASK"])
            scheduler.task_executor.run_tasks.assert_called_once()

            # run_pipeline_with_verification delegates to the inner pipeline_runner.
            scheduler.run_pipeline_with_verification(pipeline="api_ods")
            scheduler.pipeline_runner.run.assert_called_once()
|
||||
@@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for recent/former endpoint routing."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from api.endpoint_routing import plan_calls, recent_boundary
|
||||
|
||||
|
||||
TZ = ZoneInfo("Asia/Shanghai")
|
||||
|
||||
|
||||
def _now():
|
||||
return datetime(2025, 12, 18, 10, 0, 0, tzinfo=TZ)
|
||||
|
||||
|
||||
def test_recent_boundary_month_start():
    # With now = 2025-12-18, the recent/former boundary is 2025-09-01 00:00
    # local time — month-aligned, a few months back. (Exact window length is
    # defined inside recent_boundary; confirm there.)
    b = recent_boundary(_now())
    assert b.isoformat() == "2025-09-01T00:00:00+08:00"
|
||||
|
||||
|
||||
def test_paylog_routes_to_former_when_old_window():
    # A query window entirely before the boundary must be routed to the
    # "Former" (historical) PayLog endpoint in a single call.
    params = {"siteId": 1, "StartPayTime": "2025-08-01 00:00:00", "EndPayTime": "2025-08-02 00:00:00"}
    calls = plan_calls("/PayLog/GetPayLogListPage", params, now=_now(), tz=TZ)
    assert [c.endpoint for c in calls] == ["/PayLog/GetFormerPayLogListPage"]
|
||||
|
||||
|
||||
def test_coupon_usage_stays_same_path_even_when_old():
    # Coupon consumption has no "Former" variant: the endpoint must stay
    # unchanged even for a fully historical window.
    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    calls = plan_calls("/Promotion/GetOfflineCouponConsumePageList", params, now=_now(), tz=TZ)
    assert [c.endpoint for c in calls] == ["/Promotion/GetOfflineCouponConsumePageList"]
|
||||
|
||||
|
||||
def test_goods_outbound_routes_to_queryformer_when_old():
    # Goods-outbound queries for an old window are rerouted to the
    # QueryFormer... historical endpoint.
    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    calls = plan_calls("/GoodsStockManage/QueryGoodsOutboundReceipt", params, now=_now(), tz=TZ)
    assert [c.endpoint for c in calls] == ["/GoodsStockManage/QueryFormerGoodsOutboundReceipt"]
|
||||
|
||||
|
||||
def test_settlement_records_split_when_crossing_boundary():
    # A window straddling the boundary is split into two calls — Former first,
    # then the current endpoint — with both sub-windows meeting exactly at
    # the boundary timestamp (2025-09-01 00:00:00).
    params = {"siteId": 1, "rangeStartTime": "2025-08-15 00:00:00", "rangeEndTime": "2025-09-10 00:00:00"}
    calls = plan_calls("/Site/GetAllOrderSettleList", params, now=_now(), tz=TZ)
    assert [c.endpoint for c in calls] == ["/Site/GetFormerOrderSettleList", "/Site/GetAllOrderSettleList"]
    assert calls[0].params["rangeEndTime"] == "2025-09-01 00:00:00"
    assert calls[1].params["rangeStartTime"] == "2025-09-01 00:00:00"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "endpoint",
    [
        "/PayLog/GetFormerPayLogListPage",
        "/Site/GetFormerOrderSettleList",
        "/GoodsStockManage/QueryFormerGoodsOutboundReceipt",
    ],
)
def test_explicit_former_endpoint_not_rerouted(endpoint):
    # Explicitly requesting a Former endpoint must be honored as-is:
    # the planner never re-routes an already-historical endpoint.
    params = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    calls = plan_calls(endpoint, params, now=_now(), tz=TZ)
    assert [c.endpoint for c in calls] == [endpoint]
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""filter_verify_tables 单元测试"""
|
||||
|
||||
import pytest
|
||||
from tasks.verification.models import filter_verify_tables
|
||||
|
||||
|
||||
class TestFilterVerifyTables:
    """filter_verify_tables: filter verification table names by layer."""

    def test_none_input_returns_none(self):
        """None input short-circuits to None (no filtering performed)."""
        assert filter_verify_tables("DWD", None) is None

    def test_empty_list_returns_none(self):
        """An empty list is treated the same as None."""
        assert filter_verify_tables("DWD", []) is None

    def test_dwd_layer_filters_correctly(self):
        """DWD keeps dwd_/dim_/fact_-prefixed tables only."""
        tables = ["dwd_order", "dim_member", "fact_payment", "ods_raw", "dws_daily"]
        result = filter_verify_tables("DWD", tables)
        assert result == ["dwd_order", "dim_member", "fact_payment"]

    def test_dws_layer_filters_correctly(self):
        """DWS keeps dws_-prefixed tables only."""
        tables = ["dws_daily", "dwd_order", "dws_summary"]
        result = filter_verify_tables("DWS", tables)
        assert result == ["dws_daily", "dws_summary"]

    def test_index_layer_filters_correctly(self):
        """INDEX keeps v_/wbi_-prefixed tables."""
        tables = ["v_score", "wbi_index", "dws_daily", "v_rank"]
        result = filter_verify_tables("INDEX", tables)
        assert result == ["v_score", "wbi_index", "v_rank"]

    def test_ods_layer_filters_correctly(self):
        """ODS keeps ods_-prefixed tables only."""
        tables = ["ods_order", "dwd_order", "ods_member"]
        result = filter_verify_tables("ODS", tables)
        assert result == ["ods_order", "ods_member"]

    def test_unknown_layer_returns_normalized(self):
        """Unknown layer: no prefix filter, but names are trimmed and lowercased."""
        tables = [" SomeTable ", "Another"]
        result = filter_verify_tables("UNKNOWN", tables)
        assert result == ["sometable", "another"]

    def test_layer_case_insensitive(self):
        """Layer matching ignores case ('dwd' == 'DWD')."""
        tables = ["dwd_order", "ods_raw"]
        assert filter_verify_tables("dwd", tables) == ["dwd_order"]
        assert filter_verify_tables("Dwd", tables) == ["dwd_order"]

    def test_whitespace_and_empty_entries_stripped(self):
        """Blank/None entries are dropped; surrounding whitespace is stripped."""
        tables = [" dwd_order ", "", " ", None, "dim_member"]
        result = filter_verify_tables("DWD", tables)
        assert result == ["dwd_order", "dim_member"]
|
||||
903
apps/etl/pipelines/feiqiu/tests/unit/test_gen_audit_dashboard.py
Normal file
903
apps/etl/pipelines/feiqiu/tests/unit/test_gen_audit_dashboard.py
Normal file
@@ -0,0 +1,903 @@
|
||||
"""审计一览表生成脚本 — 解析模块单元测试
|
||||
|
||||
覆盖:AuditEntry、parse_audit_file、classify_module、scan_audit_dir
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.gen_audit_dashboard import (
|
||||
AuditEntry,
|
||||
MODULE_MAP,
|
||||
VALID_MODULES,
|
||||
classify_module,
|
||||
parse_audit_file,
|
||||
scan_audit_dir,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify_module
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClassifyModule:
    """classify_module should map file paths to the correct functional module."""

    @pytest.mark.parametrize(
        "path, expected",
        [
            ("api/recording_client.py", "API 层"),
            ("tasks/ods/ods_task.py", "ODS 层"),
            ("tasks/dwd/dwd_load_task.py", "DWD 层"),
            ("tasks/dws/base_dws_task.py", "DWS 层"),
            ("tasks/index/wbi.py", "指数算法"),
            ("loaders/fact_loader.py", "数据装载"),
            ("database/migrations/001.sql", "数据库"),
            ("orchestration/task_registry.py", "调度"),
            ("config/defaults.py", "配置"),
            ("cli/main.py", "CLI"),
            ("models/parser.py", "模型"),
            ("scd/scd2.py", "SCD2"),
            ("docs/README.md", "文档"),
            ("scripts/gen_audit_dashboard.py", "脚本工具"),
            ("tests/unit/test_foo.py", "测试"),
            ("quality/checker.py", "质量校验"),
            ("gui/main.py", "GUI"),
            ("utils/logging_utils.py", "工具库"),
        ],
    )
    def test_known_prefixes(self, path, expected):
        """Each known path prefix classifies into its dedicated module."""
        assert classify_module(path) == expected

    def test_unknown_path_returns_other(self):
        """Paths with no registered prefix fall back to the catch-all module."""
        assert classify_module("README.md") == "其他"
        assert classify_module(".kiro/steering/foo.md") == "其他"

    def test_normalizes_backslash(self):
        """Windows backslash paths are normalized before classification."""
        assert classify_module("tasks\\dws\\base.py") == "DWS 层"

    def test_strips_leading_dot_slash(self):
        """A leading './' does not affect classification."""
        assert classify_module("./api/foo.py") == "API 层"

    def test_result_always_in_valid_modules(self):
        """For any input, the result is a member of VALID_MODULES."""
        for path in ["", "x", "api/", "unknown/deep/path.py"]:
            assert classify_module(path) in VALID_MODULES

    def test_longest_prefix_wins(self):
        """tasks/ods must match the ODS module, not a generic tasks/ prefix."""
        # MODULE_MAP has no generic "tasks/" prefix, but tasks/ods must hit ODS.
        assert classify_module("tasks/ods/foo.py") == "ODS 层"
        assert classify_module("tasks/dwd/bar.py") == "DWD 层"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_audit_file — 使用临时文件
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Canonical, format-compliant audit-record fixture: H1 title, date metadata,
# a Markdown table of changed files, and a risk/rollback section.
_STANDARD_AUDIT = textwrap.dedent("""\
# 审计记录:测试变更

- 日期:2026-03-01(Asia/Shanghai)

## 直接原因
测试用例

## 修改文件清单

| 文件 | 变更类型 | 说明 |
|------|----------|------|
| `api/client.py` | 修改 | 测试 |
| `tasks/dws/foo.py` | 新增 | 测试 |

## 风险与回滚
- 风险:极低。纯测试变更
""")
|
||||
|
||||
|
||||
class TestParseAuditFile:
    """parse_audit_file: turn one audit Markdown file into an AuditEntry."""

    def test_standard_file(self, tmp_path):
        """A fully standard file yields every AuditEntry field populated."""
        f = tmp_path / "2026-03-01__test-change.md"
        f.write_text(_STANDARD_AUDIT, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.date == "2026-03-01"
        assert entry.slug == "test-change"
        assert entry.title == "审计记录:测试变更"
        assert entry.filename == "2026-03-01__test-change.md"
        assert "api/client.py" in entry.changed_files
        assert "tasks/dws/foo.py" in entry.changed_files
        assert "API 层" in entry.modules
        assert "DWS 层" in entry.modules
        assert entry.risk_level == "极低"

    def test_missing_title_uses_slug(self, tmp_path):
        """Without an H1 heading, the filename slug is used as the title."""
        content = "没有标题的文件\n\n一些内容\n"
        f = tmp_path / "2026-01-01__no-title.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "no-title"

    def test_missing_file_list_section(self, tmp_path):
        """No file-list section: empty changed_files, modules default to the catch-all."""
        content = "# 标题\n\n没有文件清单\n"
        f = tmp_path / "2026-01-01__no-files.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.changed_files == []
        assert entry.modules == {"其他"}

    def test_invalid_filename_returns_none(self, tmp_path):
        """Filenames outside the YYYY-MM-DD__slug.md pattern return None."""
        f = tmp_path / "invalid-name.md"
        f.write_text("# Title\n", encoding="utf-8")
        assert parse_audit_file(f) is None

    def test_non_md_gitkeep_returns_none(self, tmp_path):
        """Non-Markdown housekeeping files (.gitkeep) are rejected."""
        f = tmp_path / ".gitkeep"
        f.write_text("", encoding="utf-8")
        assert parse_audit_file(f) is None

    def test_list_format_file_section(self, tmp_path):
        """A bullet-list file section parses just like a table section."""
        content = textwrap.dedent("""\
# 测试

## 文件清单(Files changed)
- docs/api-reference/summary/foo.md
- scripts/gen_api_docs.py
- tasks/base_task.py(补 AI_CHANGELOG)
""")
        f = tmp_path / "2026-02-01__list-format.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry is not None
        assert len(entry.changed_files) == 3

    def test_risk_from_metadata_header(self, tmp_path):
        """The risk level can come from a header metadata line."""
        content = textwrap.dedent("""\
# 测试
- 日期:2026-01-01
- 风险等级:低(纯文档重组)

## 直接原因
测试
""")
        f = tmp_path / "2026-01-01__meta-risk.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.risk_level == "低"

    def test_change_type_bugfix(self, tmp_path):
        """'bugfix'/'修复' wording classifies the record as a bugfix."""
        content = "# bugfix 修复\n\n修复了一个 bug\n"
        f = tmp_path / "2026-01-01__fix.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "bugfix"

    def test_change_type_doc(self, tmp_path):
        """'纯文档' wording classifies the record as documentation-only."""
        content = "# 纯文档变更\n\n无逻辑改动\n"
        f = tmp_path / "2026-01-01__doc.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert entry.change_type == "文档"

    def test_arrow_path_extraction(self, tmp_path):
        """A move line containing → yields both the source and target paths."""
        content = textwrap.dedent("""\
# 测试

## 变更摘要

### 文件移动
- `docs/index/algo.md` → `docs/database/DWS/algo.md`
""")
        f = tmp_path / "2026-01-01__arrow.md"
        f.write_text(content, encoding="utf-8")

        entry = parse_audit_file(f)
        assert "docs/index/algo.md" in entry.changed_files
        assert "docs/database/DWS/algo.md" in entry.changed_files
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# scan_audit_dir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestScanAuditDir:
    """scan_audit_dir: discover and parse all audit files in a directory."""

    def test_empty_dir(self, tmp_path):
        """An empty directory yields an empty list."""
        assert scan_audit_dir(tmp_path) == []

    def test_nonexistent_dir(self):
        """A missing directory yields an empty list instead of raising."""
        assert scan_audit_dir("nonexistent_dir_xyz") == []

    def test_sorts_by_date_descending(self, tmp_path):
        """Entries come back newest-first regardless of creation order."""
        for name in [
            "2026-01-01__first.md",
            "2026-03-01__third.md",
            "2026-02-01__second.md",
        ]:
            (tmp_path / name).write_text("# Title\n", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        dates = [e.date for e in entries]
        assert dates == ["2026-03-01", "2026-02-01", "2026-01-01"]

    def test_skips_non_md_files(self, tmp_path):
        """Only .md files are scanned; .gitkeep and .txt are ignored."""
        (tmp_path / "2026-01-01__valid.md").write_text("# OK\n", encoding="utf-8")
        (tmp_path / ".gitkeep").write_text("", encoding="utf-8")
        (tmp_path / "notes.txt").write_text("text", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 1

    def test_skips_invalid_filenames(self, tmp_path):
        """Files whose names don't match YYYY-MM-DD__slug.md are skipped."""
        (tmp_path / "2026-01-01__valid.md").write_text("# OK\n", encoding="utf-8")
        (tmp_path / "bad-name.md").write_text("# Bad\n", encoding="utf-8")

        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 1
        assert entries[0].slug == "valid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 真实审计文件集成测试(仅在项目目录中运行时有效)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Integration checks against the repository's real audit directory; they only
# run when pytest is invoked from a location where docs/audit/changes exists.
_REAL_AUDIT_DIR = Path("docs/audit/changes")


@pytest.mark.skipif(
    not _REAL_AUDIT_DIR.is_dir(),
    reason="真实审计目录不存在(非项目根目录运行)",
)
class TestRealAuditFiles:
    """Smoke-tests over the project's actual audit records."""

    def test_parses_all_real_files(self):
        """At least one real audit record should parse successfully."""
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        assert len(entries) > 0, "应至少解析出一条审计记录"

    def test_all_modules_valid(self):
        """Every classified module of every real record must be a known one."""
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        for e in entries:
            for m in e.modules:
                assert m in VALID_MODULES, (
                    f"模块 {m!r} 不在 VALID_MODULES 中 (文件: {e.filename})"
                )

    def test_dates_descending(self):
        """scan_audit_dir returns entries sorted by date, newest first."""
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        dates = [e.date for e in entries]
        assert dates == sorted(dates, reverse=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 1: 审计记录解析-渲染完整性
|
||||
# Feature: docs-optimization, Property 1: 审计记录解析-渲染完整性
|
||||
# Validates: Requirements 2.1, 2.2
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from hypothesis import given, settings, assume
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# --- 生成策略 ---
|
||||
|
||||
# Valid-date strategy: ISO YYYY-MM-DD strings between 2020 and 2030.
# Fixed: use the already-imported `datetime` module directly instead of the
# obscure __import__("datetime") indirection.
_date_st = st.dates(
    min_value=datetime.date(2020, 1, 1),
    max_value=datetime.date(2030, 12, 31),
).map(lambda d: d.isoformat())

# Slug strategy: lowercase letter first, then letters/digits/hyphens, length 2-30.
_slug_st = st.from_regex(r"[a-z][a-z0-9\-]{1,29}", fullmatch=True)

# Title strategy: non-empty mixed Chinese/English text (stripped, with a fallback).
_title_st = st.text(
    alphabet=st.sampled_from(
        list("审计记录变更修复新增重构清理测试文档优化迁移合并")
        + list("abcdefghijklmnopqrstuvwxyz ")
    ),
    min_size=2,
    max_size=40,
).map(lambda s: s.strip() or "默认标题")

# File-path strategy: draw from known prefixes so module classification is meaningful.
_KNOWN_PREFIXES = [
    "api/", "tasks/ods/", "tasks/dwd/", "tasks/dws/", "tasks/index/",
    "loaders/", "database/migrations/", "orchestration/", "config/",
    "cli/", "models/", "scd/", "docs/", "scripts/", "tests/unit/",
    "quality/", "gui/", "utils/",
]

_file_path_st = st.tuples(
    st.sampled_from(_KNOWN_PREFIXES),
    st.from_regex(r"[a-z_]{1,15}\.(py|sql|md)", fullmatch=True),
).map(lambda t: t[0] + t[1])

# Risk-level strategy.
_risk_st = st.sampled_from(["极低", "低", "中", "高"])

# Change-type keyword strategy (embedded in the document body so the
# parser's change-type inference has something to latch onto).
_change_kw_st = st.sampled_from(["bugfix", "修复", "重构", "清理", "纯文档", "功能新增"])
|
||||
|
||||
|
||||
def _build_audit_md(title: str, date: str, files: list[str], risk: str, change_kw: str) -> str:
|
||||
"""根据参数构造一份格式合规的审计 Markdown 内容。"""
|
||||
lines = [
|
||||
f"# {title}",
|
||||
"",
|
||||
f"- 日期:{date}(Asia/Shanghai)",
|
||||
f"- 风险等级:{risk}",
|
||||
"",
|
||||
"## 直接原因",
|
||||
f"本次变更为 {change_kw} 类型操作",
|
||||
"",
|
||||
"## 修改文件清单",
|
||||
"",
|
||||
"| 文件 | 变更类型 | 说明 |",
|
||||
"|------|----------|------|",
|
||||
]
|
||||
for fp in files:
|
||||
lines.append(f"| `{fp}` | 修改 | 自动生成 |")
|
||||
lines.append("")
|
||||
lines.append("## 风险与回滚")
|
||||
lines.append(f"- 风险:{risk}。自动生成的测试内容")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TestProperty1AuditParseCompleteness:
    """Property 1: audit-record parse/render completeness.

    For any format-compliant audit Markdown file, parse_audit_file must
    return an AuditEntry whose required fields (date, title, change_type,
    modules, filename) are all present and non-empty.

    **Validates: Requirements 2.1, 2.2**
    """

    @given(
        date=_date_st,
        slug=_slug_st,
        title=_title_st,
        files=st.lists(_file_path_st, min_size=1, max_size=8),
        risk=_risk_st,
        change_kw=_change_kw_st,
    )
    @settings(max_examples=150)
    def test_parsed_entry_has_all_required_fields(
        self, tmp_path_factory, date, slug, title, files, risk, change_kw
    ):
        """After parsing a compliant file, every required AuditEntry field is non-empty."""
        # Materialize the generated document as a real file on disk.
        tmp_dir = tmp_path_factory.mktemp("audit")
        filename = f"{date}__{slug}.md"
        md_content = _build_audit_md(title, date, files, risk, change_kw)
        filepath = tmp_dir / filename
        filepath.write_text(md_content, encoding="utf-8")

        entry = parse_audit_file(filepath)

        # Core assertions: parsing succeeds and every required field is non-empty.
        # Fixed: the f-string previously interpolated nothing — include the
        # actual path so a failing example is diagnosable.
        assert entry is not None, f"格式合规的文件应能成功解析:{filepath}"
        assert entry.date, "date 字段不应为空"
        assert entry.title, "title 字段不应为空"
        assert entry.filename, "filename 字段不应为空"
        assert entry.change_type, "change_type 字段不应为空"
        assert len(entry.modules) > 0, "modules 集合不应为空"

        # The parsed date must match the one embedded in the filename.
        assert entry.date == date

        # filename must equal the actual on-disk name.
        assert entry.filename == filename

        # Every classified module must be a registered one.
        for mod in entry.modules:
            assert mod in VALID_MODULES, f"模块 {mod!r} 不在 VALID_MODULES 中"

    @given(
        date=_date_st,
        slug=_slug_st,
        title=_title_st,
        files=st.lists(_file_path_st, min_size=1, max_size=5),
        risk=_risk_st,
        change_kw=_change_kw_st,
    )
    @settings(max_examples=150)
    def test_parsed_files_match_input(
        self, tmp_path_factory, date, slug, title, files, risk, change_kw
    ):
        """Parsed changed_files must contain every input file path."""
        tmp_dir = tmp_path_factory.mktemp("audit")
        filename = f"{date}__{slug}.md"
        md_content = _build_audit_md(title, date, files, risk, change_kw)
        filepath = tmp_dir / filename
        filepath.write_text(md_content, encoding="utf-8")

        entry = parse_audit_file(filepath)
        assert entry is not None

        # Each generated file path must survive the render/parse round-trip.
        for fp in files:
            assert fp in entry.changed_files, (
                f"文件 {fp!r} 应出现在 changed_files 中,"
                f"实际结果:{entry.changed_files}"
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 属性测试 — Property 2: 文件路径模块分类正确性
|
||||
# Feature: docs-optimization, Property 2: 文件路径模块分类正确性
|
||||
# Validates: Requirements 2.3
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# --- File-path generation strategies ---

# Known-prefix paths: guarantees coverage of every prefix in MODULE_MAP.
_known_prefix_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,20}\.(py|sql|md|txt|json)", fullmatch=True),
).map(lambda t: t[0] + t[1])

# Fully random paths: arbitrary Unicode text (possibly empty, possibly
# containing special characters) to exercise the classifier's fallback.
_random_path_st = st.text(min_size=0, max_size=200)

# Windows-style paths using backslashes as separators.
_backslash_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,15}\.(py|md)", fullmatch=True),
).map(lambda t: t[0].replace("/", "\\") + t[1])

# Relative paths with a leading "./".
_dotslash_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,15}\.(py|md)", fullmatch=True),
).map(lambda t: "./" + t[0] + t[1])

# Deeply nested paths under a known prefix.
_deep_nested_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.lists(
        st.from_regex(r"[a-z]{1,8}", fullmatch=True),
        min_size=1, max_size=5,
    ),
    st.from_regex(r"[a-z_]{1,10}\.(py|md|sql)", fullmatch=True),
).map(lambda t: t[0] + "/".join(t[1]) + "/" + t[2])

# Unknown-prefix paths (start with none of the known prefixes).
# Simplified: the original wrapped a single sampled_from in
# st.tuples(...).map(lambda t: t[0]), which was a no-op indirection.
_unknown_prefix_st = st.sampled_from([
    "README.md", ".kiro/foo.md", "setup.py", "Makefile",
    "pyproject.toml", ".env", "unknown/deep/path.py",
    "random_dir/file.txt", ".github/workflows/ci.yml",
])

# Mixed strategy: draw from any of the above.
_any_filepath_st = st.one_of(
    _known_prefix_path_st,
    _random_path_st,
    _backslash_path_st,
    _dotslash_path_st,
    _deep_nested_st,
    _unknown_prefix_st,
)
|
||||
|
||||
|
||||
class TestProperty2ModuleClassification:
    """Property 2: module classification of file paths.

    For any file-path string, the value returned by ``classify_module``
    must be a member of the predefined VALID_MODULES set.

    **Validates: Requirements 2.3**
    """

    @given(filepath=_any_filepath_st)
    @settings(max_examples=200)
    def test_classify_always_returns_valid_module(self, filepath: str):
        """Every path, however malformed, classifies into VALID_MODULES."""
        module_name = classify_module(filepath)
        assert module_name in VALID_MODULES, (
            f"classify_module({filepath!r}) 返回 {module_name!r},"
            f"不在 VALID_MODULES 中"
        )

    @given(filepath=_known_prefix_path_st)
    @settings(max_examples=150)
    def test_known_prefix_never_returns_other(self, filepath: str):
        """Paths under a known prefix never fall back to '其他'."""
        module_name = classify_module(filepath)
        assert module_name in VALID_MODULES, (
            f"classify_module({filepath!r}) 返回 {module_name!r},"
            f"不在 VALID_MODULES 中"
        )
        assert module_name != "其他", (
            f"已知前缀路径 {filepath!r} 不应分类为 '其他',"
            f"实际返回 {module_name!r}"
        )

    @given(filepath=_unknown_prefix_st)
    @settings(max_examples=50)
    def test_unknown_prefix_returns_other(self, filepath: str):
        """Paths matching no known prefix classify as '其他'."""
        module_name = classify_module(filepath)
        assert module_name == "其他", (
            f"未知前缀路径 {filepath!r} 应分类为 '其他',"
            f"实际返回 {module_name!r}"
        )

    @given(filepath=_backslash_path_st)
    @settings(max_examples=100)
    def test_backslash_paths_classified_correctly(self, filepath: str):
        """Windows backslash paths classify into a real module (not '其他')."""
        module_name = classify_module(filepath)
        assert module_name in VALID_MODULES
        assert module_name != "其他", (
            f"反斜杠路径 {filepath!r} 应正确分类,"
            f"实际返回 '其他'"
        )

    @given(filepath=_dotslash_path_st)
    @settings(max_examples=100)
    def test_dotslash_paths_classified_correctly(self, filepath: str):
        """Paths with a leading './' classify into a real module (not '其他')."""
        module_name = classify_module(filepath)
        assert module_name in VALID_MODULES
        assert module_name != "其他", (
            f"前导 ./ 路径 {filepath!r} 应正确分类,"
            f"实际返回 '其他'"
        )
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 渲染函数测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from scripts.gen_audit_dashboard import (
|
||||
render_timeline_table,
|
||||
render_module_index,
|
||||
render_dashboard,
|
||||
)
|
||||
|
||||
|
||||
def _make_entry(**overrides) -> AuditEntry:
    """Build an AuditEntry for tests; keyword arguments override the defaults."""
    params = {
        "date": "2026-01-15",
        "slug": "test-change",
        "title": "测试变更",
        "filename": "2026-01-15__test-change.md",
        "changed_files": ["api/client.py"],
        "modules": {"API 层"},
        "risk_level": "低",
        "change_type": "功能",
        **overrides,  # caller-specified fields win over the defaults
    }
    return AuditEntry(**params)
|
||||
|
||||
|
||||
class TestRenderTimelineTable:
    """Unit tests for render_timeline_table."""

    def test_empty_entries(self):
        """An empty list renders the "no audit records" placeholder."""
        assert "暂无审计记录" in render_timeline_table([])

    def test_single_entry(self):
        """A single entry produces a correct table row."""
        rendered = render_timeline_table([_make_entry()])
        for expected in (
            "| 日期 |",
            "2026-01-15",
            "测试变更",
            "API 层",
            "低",
            "[链接](changes/2026-01-15__test-change.md)",
        ):
            assert expected in rendered

    def test_multiple_entries_order_preserved(self):
        """Input order is preserved (the caller is responsible for sorting)."""
        later = _make_entry(date="2026-02-01", title="后")
        earlier = _make_entry(date="2026-01-01", title="前")
        rendered = render_timeline_table([later, earlier])
        assert rendered.index("2026-02-01") < rendered.index("2026-01-01")

    def test_multiple_modules_joined(self):
        """Multiple modules are sorted and comma-joined."""
        rendered = render_timeline_table([_make_entry(modules={"文档", "API 层"})])
        assert "API 层, 文档" in rendered

    def test_table_header_present(self):
        """The table carries a header row and a separator row."""
        rendered = render_timeline_table([_make_entry()])
        assert "|------|" in rendered
        assert "需求摘要" in rendered
        assert "变更类型" in rendered
|
||||
|
||||
|
||||
class TestRenderModuleIndex:
    """Unit tests for render_module_index."""

    def test_empty_entries(self):
        """An empty list renders the "no audit records" placeholder."""
        assert "暂无审计记录" in render_module_index([])

    def test_single_module(self):
        """A single module yields one level-3 heading section."""
        rendered = render_module_index([_make_entry(modules={"API 层"})])
        assert "### API 层" in rendered
        assert "2026-01-15" in rendered
        # The module-index table omits the "affected modules" column.
        header_rows = [
            line for line in rendered.strip().splitlines()
            if line.startswith("| 日期")
        ]
        assert header_rows
        assert "影响模块" not in header_rows[0]

    def test_multiple_modules_sorted(self):
        """Module sections appear in lexicographic order."""
        doc_entry = _make_entry(modules={"文档"})
        api_entry = _make_entry(modules={"API 层"}, date="2026-02-01")
        rendered = render_module_index([doc_entry, api_entry])
        assert rendered.index("### API 层") < rendered.index("### 文档")

    def test_entry_appears_in_multiple_modules(self):
        """An entry touching several modules is listed under each section."""
        rendered = render_module_index([_make_entry(modules={"API 层", "文档"})])
        assert "### API 层" in rendered
        assert "### 文档" in rendered
        # Both sections link to the same record.
        assert rendered.count("[链接](changes/2026-01-15__test-change.md)") == 2

    def test_link_format(self):
        """The detail column links to changes/<filename>."""
        rendered = render_module_index([_make_entry(filename="2026-03-01__my-slug.md")])
        assert "[链接](changes/2026-03-01__my-slug.md)" in rendered
|
||||
|
||||
|
||||
class TestRenderDashboard:
    """Unit tests for render_dashboard."""

    def test_empty_entries(self):
        """An empty list still yields a complete document with the placeholder."""
        rendered = render_dashboard([])
        assert "# 审计一览表" in rendered
        assert "自动生成于" in rendered
        assert "暂无审计记录" in rendered

    def test_contains_both_views(self):
        """The document includes both the timeline and module-index sections."""
        rendered = render_dashboard([_make_entry()])
        assert "## 时间线视图" in rendered
        assert "## 模块索引" in rendered

    def test_contains_header_and_timestamp(self):
        """The document includes the title, generation timestamp and edit warning."""
        rendered = render_dashboard([_make_entry()])
        assert "# 审计一览表" in rendered
        assert "自动生成于" in rendered
        assert "请勿手动编辑" in rendered

    def test_timeline_before_module_index(self):
        """The timeline view precedes the module index."""
        rendered = render_dashboard([_make_entry()])
        assert rendered.index("## 时间线视图") < rendered.index("## 模块索引")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 3: 审计条目时间倒序排列
|
||||
# Feature: docs-optimization, Property 3: 审计条目时间倒序排列
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProperty3AuditEntriesDescendingOrder:
    """Property test: sorted audit entries have non-increasing dates.

    **Validates: Requirements 2.4**
    """

    @staticmethod
    def _entries_from_dates(dates):
        """Build AuditEntry objects from date objects, mirroring scan_audit_dir's
        YYYY-MM-DD string format. Extracted to remove the verbatim duplication
        the two test methods previously carried."""
        return [
            AuditEntry(
                date=d.isoformat(),
                slug=f"entry-{i}",
                title=f"条目 {i}",
                filename=f"{d.isoformat()}__entry-{i}.md",
            )
            for i, d in enumerate(dates)
        ]

    @given(
        dates=st.lists(
            st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ),
            min_size=0,
            max_size=30,
        )
    )
    @settings(max_examples=150)
    def test_sorted_entries_dates_non_increasing(self, dates):
        """After the sort scan_audit_dir uses, every entry's date is >= its successor's."""
        entries = self._entries_from_dates(dates)

        # Identical sort key and direction to scan_audit_dir.
        entries.sort(key=lambda e: e.date, reverse=True)

        # Verify the non-increasing order pairwise.
        for i in range(len(entries) - 1):
            assert entries[i].date >= entries[i + 1].date, (
                f"位置 {i} 的日期 {entries[i].date} "
                f"不应小于位置 {i+1} 的日期 {entries[i+1].date}"
            )

    @given(
        dates=st.lists(
            st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ),
            min_size=2,
            max_size=30,
        )
    )
    @settings(max_examples=150)
    def test_first_entry_has_latest_date(self, dates):
        """After sorting, the first entry carries the maximum input date."""
        entries = self._entries_from_dates(dates)
        entries.sort(key=lambda e: e.date, reverse=True)

        # ISO-format strings order lexicographically the same as dates.
        expected_max = max(d.isoformat() for d in dates)
        assert entries[0].date == expected_max
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 补充边界情况测试 — Task 4.6
|
||||
# 覆盖:空内容、缺少风险章节、变更类型推断分支、同日期排序
|
||||
# Requirements: 2.1, 2.3
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseAuditFileEdgeCases:
    """Additional edge cases for parse_audit_file."""

    def test_empty_content(self, tmp_path):
        """Empty file: title falls back to the slug, no files, modules == {"其他"}."""
        md_file = tmp_path / "2026-01-01__empty.md"
        md_file.write_text("", encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry is not None
        assert entry.title == "empty"
        assert entry.changed_files == []
        assert entry.modules == {"其他"}

    def test_whitespace_only_content(self, tmp_path):
        """Whitespace-only content behaves like an empty file."""
        md_file = tmp_path / "2026-01-01__blank.md"
        md_file.write_text(" \n\n \n", encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry is not None
        assert entry.title == "blank"
        assert entry.changed_files == []

    def test_missing_risk_section_returns_unknown(self, tmp_path):
        """No risk-related content at all → risk level is "未知"."""
        content = "# 简单标题\n\n一些普通内容,没有提到任何等级信息。\n"
        md_file = tmp_path / "2026-01-01__no-risk.md"
        md_file.write_text(content, encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry is not None
        assert entry.risk_level == "未知"

    def test_change_type_refactor(self, tmp_path):
        """The "重构" keyword maps the change type to "重构"."""
        md_file = tmp_path / "2026-01-01__refactor.md"
        md_file.write_text("# 代码重构\n\n对模块进行了重构优化\n", encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry.change_type == "重构"

    def test_change_type_cleanup(self, tmp_path):
        """The "清理" keyword maps the change type to "清理"."""
        md_file = tmp_path / "2026-01-01__cleanup.md"
        md_file.write_text("# 遗留代码清理\n\n清理了废弃文件\n", encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry.change_type == "清理"

    def test_change_type_default_function(self, tmp_path):
        """No change-type keyword at all → defaults to "功能"."""
        md_file = tmp_path / "2026-01-01__feature.md"
        md_file.write_text("# 新增能力\n\n增加了一个全新的处理流程\n", encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry.change_type == "功能"

    def test_file_section_with_empty_table(self, tmp_path):
        """A file-list section whose table has no data rows → empty list."""
        content = textwrap.dedent("""\
            # 测试

            ## 修改文件清单

            | 文件 | 变更类型 | 说明 |
            |------|----------|------|

            ## 风险与回滚
            无
            """)
        md_file = tmp_path / "2026-01-01__empty-table.md"
        md_file.write_text(content, encoding="utf-8")

        entry = parse_audit_file(md_file)
        assert entry is not None
        assert entry.changed_files == []
        assert entry.modules == {"其他"}
|
||||
|
||||
|
||||
class TestScanAuditDirEdgeCases:
    """Additional edge cases for scan_audit_dir."""

    def test_dir_with_only_invalid_files(self, tmp_path):
        """A directory containing only invalid files yields an empty list."""
        (tmp_path / "README.md").write_text("# 说明\n", encoding="utf-8")
        (tmp_path / ".gitkeep").write_text("", encoding="utf-8")
        (tmp_path / "notes.txt").write_text("备注", encoding="utf-8")

        assert scan_audit_dir(tmp_path) == []

    def test_same_date_multiple_files(self, tmp_path):
        """Several files sharing one date are all parsed, each with that date."""
        for slug in ("alpha", "beta", "gamma"):
            (tmp_path / f"2026-03-01__{slug}.md").write_text(
                f"# {slug} 变更\n", encoding="utf-8"
            )

        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 3
        assert all(e.date == "2026-03-01" for e in entries)
|
||||
161
apps/etl/pipelines/feiqiu/tests/unit/test_ods_tasks.py
Normal file
161
apps/etl/pipelines/feiqiu/tests/unit/test_ods_tasks.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the new ODS ingestion tasks."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 确保在独立运行测试时能正确解析项目根目录
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.ods.ods_tasks import ODS_TASK_CLASSES
|
||||
from .task_test_utils import create_test_config, get_db_operations, FakeAPIClient
|
||||
|
||||
|
||||
def _build_config(tmp_path):
    """Build an ONLINE-mode test config with archive/temp dirs under *tmp_path*."""
    return create_test_config(
        "ONLINE",
        tmp_path / "archive",
        tmp_path / "temp",
    )
|
||||
|
||||
|
||||
def test_assistant_accounts_masters_ingest(tmp_path):
    """Ensure ODS_ASSISTANT_ACCOUNT stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    records = [
        {
            "id": 5001,
            "assistant_no": "A01",
            "nickname": "小张",
        }
    ]
    fake_api = FakeAPIClient({"/PersonnelManagement/SearchAssistantInfo": records})
    task_cls = ODS_TASK_CLASSES["ODS_ASSISTANT_ACCOUNT"]

    with get_db_operations() as db_ops:
        task = task_cls(
            config, db_ops, fake_api,
            logging.getLogger("test_assistant_accounts_masters"),
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    assert result["counts"]["fetched"] == 1
    assert db_ops.commits == 1
    row = db_ops.upserts[0]["rows"][0]
    assert row["id"] == 5001
    assert row["record_index"] == 0
    # An in-memory fetch has no backing archive file.
    assert row["source_file"] in (None, "")
    assert '"id": 5001' in row["payload"]
|
||||
|
||||
|
||||
def test_goods_stock_movements_ingest(tmp_path):
    """Ensure ODS_INVENTORY_CHANGE stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    records = [
        {
            "siteGoodsStockId": 123456,
            "stockType": 1,
            "goodsName": "测试商品",
        }
    ]
    fake_api = FakeAPIClient({"/GoodsStockManage/QueryGoodsOutboundReceipt": records})
    task_cls = ODS_TASK_CLASSES["ODS_INVENTORY_CHANGE"]

    with get_db_operations() as db_ops:
        task = task_cls(
            config, db_ops, fake_api,
            logging.getLogger("test_goods_stock_movements"),
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    assert result["counts"]["fetched"] == 1
    assert db_ops.commits == 1
    row = db_ops.upserts[0]["rows"][0]
    # Column names are lower-cased by the ingestion layer.
    assert row["sitegoodsstockid"] == 123456
    assert row["record_index"] == 0
    assert '"siteGoodsStockId": 123456' in row["payload"]
|
||||
|
||||
|
||||
def test_member_profiless_ingest(tmp_path):
    """Ensure ODS_MEMBER task stores tenantMemberInfos raw JSON."""
    # NOTE(review): "profiless" looks like a typo for "profiles"; kept because
    # the function name is the pytest-collected identifier.
    config = _build_config(tmp_path)
    records = [{"tenantMemberInfos": [{"id": 101, "mobile": "13800000000"}]}]
    fake_api = FakeAPIClient({"/MemberProfile/GetTenantMemberList": records})
    task_cls = ODS_TASK_CLASSES["ODS_MEMBER"]

    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, fake_api, logging.getLogger("test_ods_member"))
        result = task.execute()

    assert result["status"] == "SUCCESS"
    row = db_ops.upserts[0]["rows"][0]
    assert row["record_index"] == 0
    assert '"id": 101' in row["payload"]
|
||||
|
||||
|
||||
def test_ods_payment_ingest(tmp_path):
    """Ensure ODS_PAYMENT task stores payment_transactions raw JSON."""
    config = _build_config(tmp_path)
    records = [{"payId": 901, "payAmount": "100.00"}]
    fake_api = FakeAPIClient({"/PayLog/GetPayLogListPage": records})
    task_cls = ODS_TASK_CLASSES["ODS_PAYMENT"]

    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, fake_api, logging.getLogger("test_ods_payment"))
        result = task.execute()

    assert result["status"] == "SUCCESS"
    row = db_ops.upserts[0]["rows"][0]
    assert row["record_index"] == 0
    assert '"payId": 901' in row["payload"]
|
||||
|
||||
|
||||
def test_ods_settlement_records_ingest(tmp_path):
    """Ensure ODS_SETTLEMENT_RECORDS stores settleList raw JSON."""
    config = _build_config(tmp_path)
    records = [{"id": 701, "orderTradeNo": 8001}]
    fake_api = FakeAPIClient({"/Site/GetAllOrderSettleList": records})
    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_RECORDS"]

    with get_db_operations() as db_ops:
        task = task_cls(
            config, db_ops, fake_api, logging.getLogger("test_settlement_records")
        )
        result = task.execute()

    assert result["status"] == "SUCCESS"
    row = db_ops.upserts[0]["rows"][0]
    assert row["record_index"] == 0
    assert '"orderTradeNo": 8001' in row["payload"]
|
||||
|
||||
|
||||
def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
    """Ensure settlement tickets are fetched per payment relate_id and skip existing ones."""
    config = _build_config(tmp_path)
    # Double-nested "data" envelope mirrors the upstream API response shape.
    ticket_payload = {"data": {"data": {"orderSettleId": 9001, "orderSettleNumber": "T001"}}}
    api = FakeAPIClient({"/Order/GetOrderSettleTicketNew": [ticket_payload]})
    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"]

    with get_db_operations() as db_ops:
        # Queued query results, consumed in order:
        # first query  -> settle IDs already ingested (9002);
        # second query -> settle IDs referenced by payments (includes a None
        #                 entry that must be skipped, not fetched).
        db_ops.query_results = [
            [{"order_settle_id": 9002}],
            [
                {"order_settle_id": 9001},
                {"order_settle_id": 9002},
                {"order_settle_id": None},
            ],
        ]
        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_settlement_ticket"))
        result = task.execute()

    assert result["status"] == "SUCCESS"
    counts = result["counts"]
    # Only 9001 is new: 9002 already exists and None is ignored.
    assert counts["fetched"] == 1
    assert counts["inserted"] == 1
    assert counts["updated"] == 0
    assert counts["skipped"] == 0
    assert '"orderSettleId": 9001' in db_ops.upserts[0]["rows"][0]["payload"]
    # The ticket endpoint must have been hit with the new settle ID as a parameter.
    assert any(
        call["endpoint"] == "/Order/GetOrderSettleTicketNew"
        and call.get("params", {}).get("orderSettleId") == 9001
        for call in api.calls
    )
|
||||
|
||||
39
apps/etl/pipelines/feiqiu/tests/unit/test_parsers.py
Normal file
39
apps/etl/pipelines/feiqiu/tests/unit/test_parsers.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""解析器测试"""
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
from models.parsers import TypeParser
|
||||
|
||||
def test_parse_decimal():
    """Amount parsing: rounds to the given scale; None/garbage map to None."""
    rounded = TypeParser.parse_decimal("100.555", 2)
    assert rounded == Decimal("100.56")
    for bad_input in (None, "invalid"):
        assert TypeParser.parse_decimal(bad_input) is None
|
||||
|
||||
def test_parse_int():
    """Integer parsing: accepts str and int; None/garbage map to None."""
    assert TypeParser.parse_int("123") == 123
    assert TypeParser.parse_int(456) == 456
    for bad_input in (None, "abc"):
        assert TypeParser.parse_int(bad_input) is None
|
||||
|
||||
def test_parse_timestamp():
    """Timestamp parsing keeps the calendar parts of the input string."""
    parsed = TypeParser.parse_timestamp("2025-01-15 10:30:00", ZoneInfo("Asia/Shanghai"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (2025, 1, 15)
|
||||
|
||||
|
||||
def test_parse_timestamp_zero_epoch():
    """0 must not be treated as an empty value; it parses to the Unix epoch."""
    parsed = TypeParser.parse_timestamp(0, ZoneInfo("Asia/Shanghai"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (1970, 1, 1)
|
||||
@@ -0,0 +1,304 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""PipelineRunner 属性测试 - hypothesis 验证管道编排器的通用正确性属性。"""
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.pipeline_runner import PipelineRunner
|
||||
|
||||
# run() 内部延迟导入 TaskLogger,需要 mock 源模块路径
|
||||
_TASK_LOGGER_PATH = "utils.task_logger.TaskLogger"
|
||||
|
||||
FILE_VERSION = "v1_shell"
|
||||
|
||||
# ── 策略定义 ──────────────────────────────────────────────────────
|
||||
|
||||
# Valid pipeline names come straight from the runner's layer mapping.
pipeline_name_st = st.sampled_from(list(PipelineRunner.PIPELINE_LAYERS.keys()))

processing_mode_st = st.sampled_from(
    ["increment_only", "verify_only", "increment_verify"]
)

data_source_st = st.sampled_from(["online", "offline", "hybrid"])

_TASK_PREFIXES = ["ODS_", "DWD_", "DWS_", "INDEX_"]

# Task codes: a layer prefix followed by an uppercase/digit/underscore suffix.
task_code_st = st.builds(
    lambda layer_prefix, tail: layer_prefix + tail,
    st.sampled_from(_TASK_PREFIXES),
    st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=12,
    ),
)

# Non-negative counter shared by the bulk count fields.
_count_st = st.integers(min_value=0, max_value=10000)

# Generator for a single task-result dict.
task_result_st = st.fixed_dictionaries({
    "task_code": task_code_st,
    "status": st.sampled_from(["SUCCESS", "FAIL", "SKIP"]),
    "counts": st.fixed_dictionaries({
        "fetched": _count_st,
        "inserted": _count_st,
        "updated": _count_st,
        "skipped": _count_st,
        "errors": st.integers(min_value=0, max_value=100),
    }),
    "dump_dir": st.none(),
})

task_results_st = st.lists(task_result_st, min_size=0, max_size=10)
|
||||
|
||||
|
||||
# ── 辅助函数 ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_config():
|
||||
"""创建 mock 配置对象。"""
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda key, default=None: {
|
||||
"app.timezone": "Asia/Shanghai",
|
||||
"verification.ods_use_local_json": False,
|
||||
"verification.skip_ods_when_fetch_before_verify": True,
|
||||
"run.ods_tasks": [],
|
||||
"run.dws_tasks": [],
|
||||
"run.index_tasks": [],
|
||||
}.get(key, default))
|
||||
return config
|
||||
|
||||
|
||||
def _make_runner(task_executor=None, task_registry=None):
    """Create a PipelineRunner wired entirely with mock collaborators.

    Callers may inject their own executor/registry mocks; missing ones are
    replaced with defaults that execute no tasks / resolve one fake task.
    """
    executor = task_executor
    if executor is None:
        executor = MagicMock()
        executor.run_tasks.return_value = []

    registry = task_registry
    if registry is None:
        registry = MagicMock()
        registry.get_tasks_by_layer.return_value = ["FAKE_TASK"]

    return PipelineRunner(
        config=_make_config(),
        task_executor=executor,
        task_registry=registry,
        db_conn=MagicMock(),
        api_client=MagicMock(),
        logger=MagicMock(),
    )
|
||||
|
||||
|
||||
# ── Property 5: 管道名称→层列表映射 ──────────────────────────────
|
||||
# Feature: scheduler-refactor, Property 5: 管道名称→层列表映射
|
||||
# **Validates: Requirements 2.1**
|
||||
|
||||
|
||||
class TestProperty5PipelineNameToLayers:
    """Property 5: pipeline name → layer-list mapping.

    For any valid pipeline name, the layer list PipelineRunner resolves must
    match the definition in the PIPELINE_LAYERS dictionary exactly.
    """

    @given(pipeline=pipeline_name_st)
    @settings(max_examples=100)
    def test_layers_match_pipeline_definition(self, pipeline):
        """The 'layers' field returned by run() equals PIPELINE_LAYERS[pipeline]."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)

        # run() imports TaskLogger lazily, so patch it at its source module path.
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        expected_layers = PipelineRunner.PIPELINE_LAYERS[pipeline]
        assert result["layers"] == expected_layers

    @given(pipeline=pipeline_name_st)
    @settings(max_examples=100)
    def test_resolve_tasks_called_with_correct_layers(self, pipeline):
        """_resolve_tasks receives exactly the layer list defined in PIPELINE_LAYERS."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)

        # Spy on _resolve_tasks while keeping its real behaviour (wraps=).
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_resolve_tasks", wraps=runner._resolve_tasks) as spy,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

        expected_layers = PipelineRunner.PIPELINE_LAYERS[pipeline]
        spy.assert_called_once_with(expected_layers)
|
||||
|
||||
|
||||
# ── Property 6: processing_mode 控制执行流程 ─────────────────────
|
||||
# Feature: scheduler-refactor, Property 6: processing_mode 控制执行流程
|
||||
# **Validates: Requirements 2.3, 2.4**
|
||||
|
||||
|
||||
class TestProperty6ProcessingModeControlsFlow:
    """Property 6: processing_mode drives the execution flow.

    For any processing_mode, incremental ETL runs if and only if the mode
    contains "increment", and verification runs if and only if the mode
    contains "verify".
    """

    @given(
        pipeline=pipeline_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_increment_executes_iff_mode_contains_increment(self, pipeline, mode, data_source):
        """Incremental ETL (task_executor.run_tasks) runs iff mode contains 'increment'."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)

        # Stub verification so only the increment path is under observation.
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}),
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )

        should_increment = "increment" in mode
        if should_increment:
            assert executor.run_tasks.called, (
                f"mode={mode} 包含 'increment',但 run_tasks 未被调用"
            )
        else:
            # verify_only with fetch_before_verify=False (the default):
            # run_tasks must not be invoked at all.
            assert not executor.run_tasks.called, (
                f"mode={mode} 不包含 'increment',但 run_tasks 被调用了"
            )

    @given(
        pipeline=pipeline_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_verification_executes_iff_mode_contains_verify(self, pipeline, mode, data_source):
        """Verification (_run_verification) runs iff mode contains 'verify'."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)

        # Replace _run_verification with a mock so its invocation can be asserted.
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}) as mock_verify,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )

        should_verify = "verify" in mode
        if should_verify:
            assert mock_verify.called, (
                f"mode={mode} 包含 'verify',但 _run_verification 未被调用"
            )
        else:
            assert not mock_verify.called, (
                f"mode={mode} 不包含 'verify',但 _run_verification 被调用了"
            )
|
||||
|
||||
|
||||
# ── Property 7: 管道结果汇总完整性 ──────────────────────────────
|
||||
# Feature: scheduler-refactor, Property 7: 管道结果汇总完整性
|
||||
# **Validates: Requirements 2.6**
|
||||
|
||||
|
||||
class TestProperty7PipelineSummaryCompleteness:
    """For any set of task results, the summary dict returned by
    PipelineRunner.run must contain status/pipeline/layers/results, with
    ``results`` exactly as long as the number of executed tasks."""

    @staticmethod
    def _run_pipeline(pipeline, task_results):
        """Shared driver: wire a mocked executor and run increment_only/offline."""
        mock_executor = MagicMock()
        mock_executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=mock_executor)
        with patch(_TASK_LOGGER_PATH):
            return runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )

    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_summary_has_required_fields(self, pipeline, task_results):
        """The summary must expose status, pipeline, layers, results and verification_summary."""
        result = self._run_pipeline(pipeline, task_results)

        required_keys = {"status", "pipeline", "layers", "results", "verification_summary"}
        assert required_keys.issubset(result.keys()), (
            f"缺少必要字段: {required_keys - result.keys()}"
        )

    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_results_length_equals_executed_tasks(self, pipeline, task_results):
        """len(results) must equal the number of results task_executor.run_tasks returned."""
        result = self._run_pipeline(pipeline, task_results)

        assert len(result["results"]) == len(task_results), (
            f"results 长度 {len(result['results'])} != 实际任务数 {len(task_results)}"
        )

    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_pipeline_and_layers_match_input(self, pipeline, task_results):
        """The returned pipeline and layers fields must mirror the inputs."""
        result = self._run_pipeline(pipeline, task_results)

        assert result["pipeline"] == pipeline
        assert result["layers"] == PipelineRunner.PIPELINE_LAYERS[pipeline]

    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_increment_only_has_no_verification(self, pipeline, task_results):
        """In increment_only mode, verification_summary must be None."""
        result = self._run_pipeline(pipeline, task_results)

        assert result["verification_summary"] is None
|
||||
133
apps/etl/pipelines/feiqiu/tests/unit/test_relation_index_base.py
Normal file
133
apps/etl/pipelines/feiqiu/tests/unit/test_relation_index_base.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""关系指数基础能力单测。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from tasks.dws.index.base_index_task import BaseIndexTask
|
||||
from tasks.dws.index.ml_manual_import_task import MlManualImportTask
|
||||
|
||||
|
||||
class _DummyConfig:
|
||||
"""最小配置桩对象。"""
|
||||
|
||||
def __init__(self, values: Optional[Dict[str, Any]] = None):
|
||||
self._values = values or {}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
class _DummyDB:
|
||||
"""最小数据库桩对象。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.query_calls: List[tuple] = []
|
||||
|
||||
def query(self, sql: str, params=None):
|
||||
self.query_calls.append((sql, params))
|
||||
index_type = (params or [None])[0]
|
||||
if index_type == "RS":
|
||||
return [{"param_name": "lookback_days", "param_value": 60}]
|
||||
if index_type == "MS":
|
||||
return [{"param_name": "lookback_days", "param_value": 30}]
|
||||
return []
|
||||
|
||||
|
||||
class _DummyIndexTask(BaseIndexTask):
    """Smallest concrete BaseIndexTask implementation used to exercise the base class."""

    # Default index type for this stub; tests also pass other types explicitly.
    INDEX_TYPE = "RS"

    def get_task_code(self) -> str:  # pragma: no cover - test stub
        return "DUMMY_INDEX"

    def get_target_table(self) -> str:  # pragma: no cover - test stub
        return "dummy_table"

    def get_primary_keys(self) -> List[str]:  # pragma: no cover - test stub
        return ["id"]

    def get_index_type(self) -> str:
        return self.INDEX_TYPE

    def extract(self, context):  # pragma: no cover - test stub
        return []

    def load(self, transformed, context):  # pragma: no cover - test stub
        return {}
|
||||
|
||||
|
||||
def test_load_index_parameters_cache_isolated_by_index_type():
    """The parameter cache must be keyed by index_type so values never bleed across types."""
    index_task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_cache"),
    )

    first_rs = index_task.load_index_parameters(index_type="RS")
    first_ms = index_task.load_index_parameters(index_type="MS")
    second_rs = index_task.load_index_parameters(index_type="RS")

    assert first_rs["lookback_days"] == 60.0
    assert first_ms["lookback_days"] == 30.0
    assert second_rs["lookback_days"] == 60.0
    # Exactly two real queries expected: one for RS, one for MS; the repeat RS hits the cache.
    assert len(index_task.db.query_calls) == 2
|
||||
|
||||
|
||||
def test_batch_normalize_passes_index_type_to_smoothing_chain():
    """batch_normalize_to_display must forward index_type into the smoothing chain."""
    index_task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_smoothing"),
    )

    seen: Dict[str, Any] = {}

    def _capture_apply(site_id, current_p5, current_p95, alpha=None, index_type=None):
        # Record the index_type and return the percentiles untouched.
        seen["index_type"] = index_type
        return current_p5, current_p95

    index_task._apply_ewma_smoothing = _capture_apply  # type: ignore[method-assign]

    normalized = index_task.batch_normalize_to_display(
        raw_scores=[("a", 1.0), ("b", 2.0), ("c", 3.0)],
        use_smoothing=True,
        site_id=1,
        index_type="ML",
    )

    assert normalized
    assert seen["index_type"] == "ML"
|
||||
|
||||
|
||||
def test_ml_manual_import_scope_day_and_p30_boundary():
    """Within 30 days the scope is per-day; beyond 30 days it falls into the fixed-epoch 30-day bucket."""
    reference_day = date(2026, 2, 8)

    scope_within = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 9),  # exactly 30 days before reference_day
        today=reference_day,
    )
    assert scope_within.scope_type == "DAY"
    assert scope_within.start_date == date(2026, 1, 9)
    assert scope_within.end_date == date(2026, 1, 9)

    scope_beyond = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 8),  # 31 days before reference_day
        today=reference_day,
    )
    assert scope_beyond.scope_type == "P30"
    # Fixed epoch 2026-01-01: the first bucket spans 2026-01-01 .. 2026-01-30.
    assert scope_beyond.start_date == date(2026, 1, 1)
    assert scope_beyond.end_date == date(2026, 1, 30)
|
||||
22
apps/etl/pipelines/feiqiu/tests/unit/test_reporting.py
Normal file
22
apps/etl/pipelines/feiqiu/tests/unit/test_reporting.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""汇总与报告工具的单测。"""
|
||||
from utils.reporting import summarize_counts, format_report
|
||||
|
||||
|
||||
def test_summarize_counts_and_format():
    """summarize_counts aggregates per-task counters; format_report renders them as text."""
    task_results = [
        {"task_code": "ORDERS", "counts": {"fetched": 2, "inserted": 2, "updated": 0, "skipped": 0, "errors": 0}},
        {"task_code": "PAYMENTS", "counts": {"fetched": 3, "inserted": 2, "updated": 1, "skipped": 0, "errors": 0}},
    ]

    aggregated = summarize_counts(task_results)
    assert aggregated["total"]["fetched"] == 5
    assert aggregated["total"]["inserted"] == 4
    assert aggregated["total"]["updated"] == 1
    assert aggregated["total"]["errors"] == 0
    assert len(aggregated["details"]) == 2

    rendered = format_report(aggregated)
    assert "TOTAL fetched=5" in rendered
    assert "ORDERS:" in rendered
    assert "PAYMENTS:" in rendered
|
||||
@@ -0,0 +1,207 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskExecutor 属性测试 - hypothesis 验证执行器的通用正确性属性。"""
|
||||
import re
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.task_executor import TaskExecutor, DataSource
|
||||
from orchestration.task_registry import TaskRegistry
|
||||
|
||||
# Marker used by tooling to identify this variant of the test file.
FILE_VERSION = "v4_shell"


# Strategy: every data-source mode accepted by TaskExecutor.
data_source_st = st.sampled_from(["online", "offline", "hybrid"])

# Prefixes that keep generated task codes out of the ODS namespace, so
# ODS-specific code paths are not triggered by accident.
_NON_ODS_PREFIXES = ["DWD_", "DWS_", "TASK_", "ETL_", "TEST_"]
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_NON_ODS_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=15,
    ),
)

# Strategy: arbitrary window start timestamps within a plausible range.
window_start_st = st.datetimes(min_value=datetime(2020, 1, 1), max_value=datetime(2030, 12, 31))
|
||||
|
||||
|
||||
def _make_fake_class(name="FakeTask"):
|
||||
return type(name, (), {"__init__": lambda self, *a, **kw: None})
|
||||
|
||||
|
||||
def _make_config():
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda key, default=None: {
|
||||
"app.timezone": "Asia/Shanghai",
|
||||
"io.fetch_root": "/tmp/fetch",
|
||||
"io.ingest_source_dir": "/tmp/ingest",
|
||||
"io.write_pretty_json": False,
|
||||
"pipeline.fetch_root": None,
|
||||
"pipeline.ingest_source_dir": None,
|
||||
"integrity.auto_check": False,
|
||||
"run.overlap_seconds": 600,
|
||||
}.get(key, default))
|
||||
config.__getitem__ = MagicMock(side_effect=lambda k: {
|
||||
"io": {"export_root": "/tmp/export", "log_root": "/tmp/log"},
|
||||
}[k])
|
||||
return config
|
||||
|
||||
|
||||
def _make_executor(registry):
    """Build a TaskExecutor wired with mocked collaborators and the given registry."""
    collaborators = dict(
        config=_make_config(),
        db_ops=MagicMock(),
        api_client=MagicMock(),
        cursor_mgr=MagicMock(),
        run_tracker=MagicMock(),
        task_registry=registry,
        logger=MagicMock(),
    )
    return TaskExecutor(**collaborators)
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 1: data_source 参数决定执行路径
|
||||
# **Validates: Requirements 1.2**
|
||||
|
||||
class TestProperty1DataSourceDeterminesPath:
    """Property 1: the data_source argument alone decides which phases run."""

    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_fetch(self, ds):
        """Fetch runs exactly for the online and hybrid sources."""
        includes_fetch = TaskExecutor._flow_includes_fetch(ds)
        assert includes_fetch == (ds in {"online", "hybrid"})

    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_ingest(self, ds):
        """Ingest runs exactly for the offline and hybrid sources."""
        includes_ingest = TaskExecutor._flow_includes_ingest(ds)
        assert includes_ingest == (ds in {"offline", "hybrid"})

    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_fetch_and_ingest_consistency(self, ds):
        """The fetch/ingest pair must be jointly consistent with the mode."""
        fetch_on = TaskExecutor._flow_includes_fetch(ds)
        ingest_on = TaskExecutor._flow_includes_ingest(ds)
        if ds == "hybrid":
            assert fetch_on and ingest_on
        elif ds == "online":
            assert fetch_on and not ingest_on
        elif ds == "offline":
            assert not fetch_on and ingest_on
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 2: 成功任务推进游标
|
||||
# **Validates: Requirements 1.3**
|
||||
|
||||
class TestProperty2SuccessAdvancesCursor:
    """Property 2: a successful task with a window advances the cursor to that window."""

    @given(
        task_code=task_code_st,
        window_start=window_start_st,
        window_minutes=st.integers(min_value=1, max_value=1440),
    )
    @settings(max_examples=100)
    def test_success_with_window_advances_cursor(self, task_code, window_start, window_minutes):
        """After a SUCCESS result carrying a window, cursor_mgr.advance must be
        called exactly once with that window's start and end."""
        window_end = window_start + timedelta(minutes=window_minutes)
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        # A successful ingest result carrying the processed window.
        task_result = {
            "status": "SUCCESS",
            "counts": {"fetched": 10, "inserted": 5},
            "window": {"start": window_start, "end": window_end, "minutes": window_minutes},
        }
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 100

        # Stub out config loading, source resolution, ingest and the integrity check
        # so only the cursor bookkeeping is exercised.
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 42, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
            patch.object(TaskExecutor, "_maybe_run_integrity_check"),
        ):
            executor.run_single_task(task_code, "test-uuid", 1, "offline")

        executor.cursor_mgr.advance.assert_called_once()
        kw = executor.cursor_mgr.advance.call_args.kwargs
        assert kw["window_start"] == window_start
        assert kw["window_end"] == window_end
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 3: 失败任务标记 FAIL 并重新抛出
|
||||
# **Validates: Requirements 1.4**
|
||||
|
||||
class TestProperty3FailureMarksFailAndReraises:
    """Property 3: a failing task marks the run FAIL and re-raises the original exception."""

    @given(
        task_code=task_code_st,
        error_msg=st.text(
            alphabet=string.ascii_letters + string.digits + " _-",
            min_size=1, max_size=80,
        ),
    )
    @settings(max_examples=100)
    def test_exception_marks_fail_and_reraises(self, task_code, error_msg):
        """An exception raised during ingest must propagate unchanged, and the
        run tracker must record status=FAIL for the run."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 200

        # Ingest is stubbed to raise; the executor must not swallow the error.
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 99, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", side_effect=RuntimeError(error_msg)),
        ):
            with pytest.raises(RuntimeError, match=re.escape(error_msg)):
                executor.run_single_task(task_code, "fail-uuid", 1, "offline")

        # The final update_run call must carry status=FAIL.
        executor.run_tracker.update_run.assert_called()
        kw = executor.run_tracker.update_run.call_args.kwargs
        assert kw["status"] == "FAIL"
|
||||
|
||||
|
||||
# Feature: scheduler-refactor, Property 4: 工具类任务由元数据决定
|
||||
# **Validates: Requirements 1.6, 4.2**
|
||||
|
||||
class TestProperty4UtilityTaskDeterminedByMetadata:
    """Property 4: whether a task is treated as a utility task is decided solely
    by its registry metadata, not by its task code."""

    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_utility_task_skips_cursor_and_run_tracker(self, task_code):
        """Utility tasks must bypass cursor management and run tracking entirely."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=False, task_type="utility")
        executor = _make_executor(registry)
        mock_task = MagicMock()
        mock_task.execute.return_value = {"status": "SUCCESS", "counts": {}}
        registry.create_task = MagicMock(return_value=mock_task)

        result = executor.run_single_task(task_code, "util-uuid", 1, "hybrid")

        # No cursor or run-tracker bookkeeping for utility tasks.
        executor.cursor_mgr.get_or_create.assert_not_called()
        executor.cursor_mgr.advance.assert_not_called()
        executor.run_tracker.create_run.assert_not_called()
        assert result.get("status") == "SUCCESS"

    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_non_utility_task_uses_cursor_and_run_tracker(self, task_code):
        """Non-utility tasks must go through cursor and run-tracker bookkeeping."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWS")
        task_result = {"status": "SUCCESS", "counts": {"fetched": 1}}
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 300

        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 77, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
        ):
            executor.run_single_task(task_code, "non-util-uuid", 1, "offline")

        executor.cursor_mgr.get_or_create.assert_called_once()
        executor.run_tracker.create_run.assert_called_once()
|
||||
139
apps/etl/pipelines/feiqiu/tests/unit/test_task_registry.py
Normal file
139
apps/etl/pipelines/feiqiu/tests/unit/test_task_registry.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskRegistry 单元测试 — 验证 TaskMeta 元数据注册与查询"""
|
||||
import pytest
|
||||
from orchestration.task_registry import TaskRegistry, TaskMeta
|
||||
|
||||
|
||||
# ── 辅助:用作注册的假任务类 ──────────────────────────────────
|
||||
|
||||
class _FakeTask:
|
||||
"""占位任务类,用于测试注册"""
|
||||
def __init__(self, config, db_connection, api_client, logger):
|
||||
self.config = config
|
||||
|
||||
|
||||
class _AnotherFakeTask:
|
||||
def __init__(self, config, db_connection, api_client, logger):
|
||||
pass
|
||||
|
||||
|
||||
# ── fixtures ──────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
def registry():
    """Provide a fresh TaskRegistry per test so registrations never leak between cases."""
    return TaskRegistry()
|
||||
|
||||
|
||||
# ── register + get_metadata ───────────────────────────────────
|
||||
|
||||
class TestRegisterAndMetadata:
    """Registration and metadata lookup."""

    def test_register_with_defaults(self, registry):
        """With only task_code + task_class, metadata falls back to defaults (backward compatible)."""
        registry.register("MY_TASK", _FakeTask)
        metadata = registry.get_metadata("MY_TASK")
        assert metadata is not None
        assert metadata.task_class is _FakeTask
        assert metadata.requires_db_config is True
        assert metadata.layer is None
        assert metadata.task_type == "etl"

    def test_register_with_full_metadata(self, registry):
        """All metadata fields supplied explicitly round-trip through the registry."""
        registry.register(
            "ODS_ORDERS", _FakeTask,
            requires_db_config=True, layer="ODS", task_type="etl",
        )
        metadata = registry.get_metadata("ODS_ORDERS")
        assert metadata.layer == "ODS"
        assert metadata.task_type == "etl"

    def test_register_utility_task(self, registry):
        """Utility tasks register with requires_db_config=False."""
        registry.register(
            "INIT_SCHEMA", _FakeTask,
            requires_db_config=False, task_type="utility",
        )
        metadata = registry.get_metadata("INIT_SCHEMA")
        assert metadata.requires_db_config is False
        assert metadata.task_type == "utility"

    def test_case_insensitive_lookup(self, registry):
        """task_code lookups ignore case."""
        registry.register("my_task", _FakeTask)
        assert registry.get_metadata("MY_TASK") is not None
        assert registry.get_metadata("my_task") is not None

    def test_get_metadata_unknown_returns_none(self, registry):
        """Unregistered task codes resolve to None."""
        assert registry.get_metadata("NONEXISTENT") is None
|
||||
|
||||
|
||||
# ── create_task(接口不变)────────────────────────────────────
|
||||
|
||||
class TestCreateTask:
    """create_task keeps its pre-refactor interface."""

    def test_create_task_returns_instance(self, registry):
        """A registered code yields an instance of the registered class."""
        registry.register("MY_TASK", _FakeTask)
        instance = registry.create_task("MY_TASK", {"k": "v"}, None, None, None)
        assert isinstance(instance, _FakeTask)
        assert instance.config == {"k": "v"}

    def test_create_task_unknown_raises(self, registry):
        """An unknown code raises ValueError with the expected message."""
        with pytest.raises(ValueError, match="未知的任务类型"):
            registry.create_task("NOPE", None, None, None, None)
|
||||
|
||||
|
||||
# ── get_tasks_by_layer ────────────────────────────────────────
|
||||
|
||||
class TestGetTasksByLayer:
    """Layer-based task lookup."""

    def test_returns_matching_tasks(self, registry):
        """Only tasks registered under the queried layer are returned."""
        registry.register("A", _FakeTask, layer="ODS")
        registry.register("B", _AnotherFakeTask, layer="ODS")
        registry.register("C", _FakeTask, layer="DWD")
        matched = registry.get_tasks_by_layer("ODS")
        assert set(matched) == {"A", "B"}

    def test_case_insensitive_layer(self, registry):
        """Layer comparison ignores case."""
        registry.register("X", _FakeTask, layer="dws")
        assert registry.get_tasks_by_layer("DWS") == ["X"]

    def test_no_match_returns_empty(self, registry):
        """A layer with no registered tasks yields an empty list."""
        registry.register("A", _FakeTask, layer="ODS")
        assert registry.get_tasks_by_layer("INDEX") == []

    def test_none_layer_excluded(self, registry):
        """Tasks registered with layer=None never match any layer query."""
        registry.register("UTIL", _FakeTask)  # layer defaults to None
        assert registry.get_tasks_by_layer("ODS") == []
|
||||
|
||||
|
||||
# ── is_utility_task ───────────────────────────────────────────
|
||||
|
||||
class TestIsUtilityTask:
    """is_utility_task is derived from the requires_db_config flag."""

    def test_utility_task(self, registry):
        """requires_db_config=False marks a task as utility."""
        registry.register("INIT", _FakeTask, requires_db_config=False)
        assert registry.is_utility_task("INIT") is True

    def test_normal_task(self, registry):
        """requires_db_config=True marks a task as a regular ETL task."""
        registry.register("ETL", _FakeTask, requires_db_config=True)
        assert registry.is_utility_task("ETL") is False

    def test_unknown_task(self, registry):
        """Unknown codes are not utility tasks."""
        assert registry.is_utility_task("NOPE") is False
|
||||
|
||||
|
||||
# ── get_all_task_codes(接口不变)──────────────────────────────
|
||||
|
||||
class TestGetAllTaskCodes:
    """get_all_task_codes keeps its pre-refactor interface."""

    def test_returns_all_codes(self, registry):
        """Every registered code is reported."""
        registry.register("A", _FakeTask)
        registry.register("B", _AnotherFakeTask)
        assert set(registry.get_all_task_codes()) == {"A", "B"}

    def test_empty_registry(self, registry):
        """A fresh registry exposes no codes."""
        assert registry.get_all_task_codes() == []
|
||||
@@ -0,0 +1,165 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskRegistry 属性测试 — 使用 hypothesis 验证注册表的通用正确性属性。"""
|
||||
import string
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from orchestration.task_registry import TaskRegistry, TaskMeta
|
||||
|
||||
|
||||
# ── 辅助:动态生成假任务类 ────────────────────────────────────
|
||||
|
||||
def _make_fake_class(name: str = "FakeTask") -> type:
|
||||
"""创建一个最小化的假任务类,用于注册测试。"""
|
||||
return type(name, (), {"__init__": lambda self, *a, **kw: None})
|
||||
|
||||
|
||||
# ── 生成策略 ──────────────────────────────────────────────────
|
||||
|
||||
# Valid task codes: uppercase letters + digits + underscore, length 1..30.
task_code_st = st.text(
    alphabet=string.ascii_uppercase + string.digits + "_",
    min_size=1,
    max_size=30,
)

# Both possible values of the requires_db_config flag.
requires_db_config_st = st.booleans()

# Every layer value, including None for layer-less (utility) tasks.
layer_st = st.sampled_from([None, "ODS", "DWD", "DWS", "INDEX"])

# The task_type values the registry recognizes.
task_type_st = st.sampled_from(["etl", "utility", "verification"])
|
||||
|
||||
|
||||
# ── Property 8: TaskRegistry 元数据 round-trip ────────────────
|
||||
# Feature: scheduler-refactor, Property 8: TaskRegistry 元数据 round-trip
|
||||
# **Validates: Requirements 4.1**
|
||||
#
|
||||
# 对于任意任务代码、任务类和元数据组合(requires_db_config、layer、task_type),
|
||||
# 注册后通过 get_metadata 查询应返回相同的元数据值。
|
||||
|
||||
|
||||
class TestProperty8MetadataRoundTrip:
    """Property 8: metadata read back after registration must equal what was registered."""

    @given(
        task_code=task_code_st,
        requires_db=requires_db_config_st,
        layer=layer_st,
        task_type=task_type_st,
    )
    @settings(max_examples=100)
    def test_metadata_round_trip(
        self, task_code: str, requires_db: bool, layer: str | None, task_type: str
    ):
        """For any metadata combination, get_metadata must return the same values."""
        # Arrange — a fresh registry per example, so no state leaks between iterations
        registry = TaskRegistry()
        fake_cls = _make_fake_class()

        # Act — register, then query
        registry.register(
            task_code,
            fake_cls,
            requires_db_config=requires_db,
            layer=layer,
            task_type=task_type,
        )
        meta = registry.get_metadata(task_code)

        # Assert — the metadata round-trips unchanged
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is requires_db, (
            f"requires_db_config 应为 {requires_db},实际为 {meta.requires_db_config}"
        )
        assert meta.layer == layer, f"layer 应为 {layer!r},实际为 {meta.layer!r}"
        assert meta.task_type == task_type, (
            f"task_type 应为 {task_type!r},实际为 {meta.task_type!r}"
        )
|
||||
|
||||
|
||||
# ── Property 9: TaskRegistry 向后兼容默认值 ───────────────────
|
||||
# Feature: scheduler-refactor, Property 9: TaskRegistry 向后兼容默认值
|
||||
# **Validates: Requirements 4.4**
|
||||
#
|
||||
# 对于任意使用旧接口(仅 task_code 和 task_class)注册的任务,
|
||||
# 查询元数据应返回 requires_db_config=True、layer=None、task_type="etl"。
|
||||
|
||||
|
||||
class TestProperty9BackwardCompatibleDefaults:
    """Property 9: registering with only task_code + task_class must yield default metadata."""

    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_legacy_register_uses_defaults(self, task_code: str):
        """Legacy-style registration (no metadata kwargs) falls back to documented defaults."""
        # Arrange
        registry = TaskRegistry()
        fake_cls = _make_fake_class()

        # Act — register with task_code and task_class only, no metadata arguments
        registry.register(task_code, fake_cls)
        meta = registry.get_metadata(task_code)

        # Assert — the default-value contract holds
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is True, (
            f"默认 requires_db_config 应为 True,实际为 {meta.requires_db_config}"
        )
        assert meta.layer is None, (
            f"默认 layer 应为 None,实际为 {meta.layer!r}"
        )
        assert meta.task_type == "etl", (
            f"默认 task_type 应为 'etl',实际为 {meta.task_type!r}"
        )
|
||||
|
||||
|
||||
# ── Property 10: 按层查询任务 ────────────────────────────────
|
||||
# Feature: scheduler-refactor, Property 10: 按层查询任务
|
||||
# **Validates: Requirements 4.3**
|
||||
#
|
||||
# 对于任意注册了 layer 元数据的任务集合,get_tasks_by_layer(layer)
|
||||
# 返回的任务代码集合应等于所有 layer 匹配的已注册任务代码集合。
|
||||
|
||||
# Non-None layer values, used to drive the by-layer queries under test.
non_none_layer_st = st.sampled_from(["ODS", "DWD", "DWS", "INDEX"])
|
||||
|
||||
|
||||
class TestProperty10GetTasksByLayer:
    """Property 10: get_tasks_by_layer must agree with a manual filter over registrations."""

    @given(
        entries=st.lists(
            st.tuples(task_code_st, layer_st),
            min_size=1,
            max_size=20,
        ),
    )
    @settings(max_examples=100)
    def test_get_tasks_by_layer_matches_manual_filter(
        self, entries: list[tuple[str, str | None]],
    ):
        """After registering a batch of tasks, per-layer queries must equal manual filtering."""
        # Arrange
        registry = TaskRegistry()
        # Deduplicate: only the last registration per task_code survives,
        # matching register()'s overwrite semantics.
        unique_entries: dict[str, str | None] = {}
        for code, layer in entries:
            fake_cls = _make_fake_class(f"Fake_{code}")
            registry.register(code, fake_cls, layer=layer)
            unique_entries[code.upper()] = layer  # register() uppercases codes internally

        # Act & Assert — verify every non-None layer value
        for query_layer in ["ODS", "DWD", "DWS", "INDEX"]:
            actual = set(registry.get_tasks_by_layer(query_layer))
            expected = {
                code for code, layer in unique_entries.items()
                if layer is not None and layer.upper() == query_layer.upper()
            }
            assert actual == expected, (
                f"查询 layer={query_layer!r} 时,"
                f"期望 {expected},实际 {actual}"
            )
|
||||
358
apps/etl/pipelines/feiqiu/tests/unit/test_validate_bd_manual.py
Normal file
358
apps/etl/pipelines/feiqiu/tests/unit/test_validate_bd_manual.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
scripts/validate_bd_manual.py 的单元测试。
|
||||
|
||||
不需要数据库连接,通过构造临时文件系统结构来验证各 check 函数。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# 被测模块
|
||||
from scripts.validate_bd_manual import (
|
||||
check_directory_structure,
|
||||
check_ods_doc_coverage,
|
||||
check_ods_doc_format,
|
||||
check_ods_doc_naming,
|
||||
check_mapping_doc_coverage,
|
||||
check_mapping_doc_content,
|
||||
check_mapping_doc_naming,
|
||||
check_ods_dictionary_coverage,
|
||||
BD_MANUAL_ROOT,
|
||||
ODS_MAIN_DIR,
|
||||
ODS_MAPPINGS_DIR,
|
||||
ODS_DICT_PATH,
|
||||
DATA_LAYERS,
|
||||
)
|
||||
import scripts.validate_bd_manual as mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:临时目录结构搭建
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _setup_layer_dirs(tmp_path: Path) -> None:
    """Create the complete four-layer directory tree under tmp_path."""
    root = tmp_path / "docs" / "bd_manual"
    for layer in DATA_LAYERS:
        for subdir in ("main", "changes"):
            (root / layer / subdir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
SAMPLE_ODS_DOC = textwrap.dedent("""\
|
||||
# test_table 测试表
|
||||
|
||||
> 生成时间:2026-01-01
|
||||
|
||||
## 表信息
|
||||
|
||||
| 属性 | 值 |
|
||||
|------|-----|
|
||||
| Schema | billiards_ods |
|
||||
| 表名 | test_table |
|
||||
| 主键 | id, content_hash |
|
||||
| 数据来源 | TestEndpoint |
|
||||
| 说明 | 测试表 |
|
||||
|
||||
## 字段说明
|
||||
|
||||
| 序号 | 字段名 | 类型 | 可空 | 说明 |
|
||||
|------|--------|------|------|------|
|
||||
| 1 | id | BIGINT | NO | 主键 |
|
||||
|
||||
## 使用说明
|
||||
|
||||
```sql
|
||||
SELECT * FROM billiards_ods.test_table LIMIT 10;
|
||||
```
|
||||
|
||||
## ETL 元数据字段
|
||||
|
||||
| 字段名 | 类型 | 说明 |
|
||||
|--------|------|------|
|
||||
| content_hash | TEXT | SHA256 |
|
||||
| source_file | TEXT | 文件名 |
|
||||
| source_endpoint | TEXT | 端点 |
|
||||
| fetched_at | TIMESTAMPTZ | 时间戳 |
|
||||
| payload | JSONB | 原始 JSON |
|
||||
|
||||
## 可回溯性
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| 可回溯 | ✅ 完全可回溯 |
|
||||
| 数据来源 | TestEndpoint |
|
||||
""")
|
||||
|
||||
|
||||
SAMPLE_MAPPING_DOC = textwrap.dedent("""\
|
||||
# TestEndpoint → test_table 字段映射
|
||||
|
||||
> 生成时间:2026-01-01
|
||||
|
||||
## 端点信息
|
||||
|
||||
| 属性 | 值 |
|
||||
|------|-----|
|
||||
| 接口路径 | `Test/TestEndpoint` |
|
||||
| 请求方法 | POST |
|
||||
| ODS 对应表 | `billiards_ods.test_table` |
|
||||
| JSON 数据路径 | `data.items` |
|
||||
|
||||
## 字段映射
|
||||
|
||||
| JSON 字段 | ODS 列名 | 类型转换 | 说明 |
|
||||
|-----------|----------|----------|------|
|
||||
| id | id | int→BIGINT | 主键 |
|
||||
|
||||
## ETL 补充字段
|
||||
|
||||
| ODS 列名 | 生成逻辑 |
|
||||
|-----------|----------|
|
||||
| content_hash | SHA256 |
|
||||
| source_file | 固定值 |
|
||||
| source_endpoint | 端点路径 |
|
||||
| fetched_at | 入库时间戳 |
|
||||
| payload | 原始 JSON |
|
||||
""")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 1: 目录结构一致性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckDirectoryStructure:
    """Property 1: the data-layer directory structure must be complete and consistent."""

    def test_pass_when_all_dirs_exist(self, tmp_path, monkeypatch):
        """With every main/ and changes/ subdirectory present, the check passes."""
        _setup_layer_dirs(tmp_path)
        monkeypatch.setattr(mod, "BD_MANUAL_ROOT", tmp_path / "docs" / "bd_manual")
        outcome = check_directory_structure()
        assert outcome.passed is True
        assert outcome.property_id == "Property 1"

    def test_fail_when_missing_subdir(self, tmp_path, monkeypatch):
        """Removing one changes/ subdirectory must fail the check and name it in the details."""
        import shutil

        _setup_layer_dirs(tmp_path)
        # Drop ETL_Admin/changes/ to simulate an incomplete tree.
        shutil.rmtree(tmp_path / "docs" / "bd_manual" / "ETL_Admin" / "changes")
        monkeypatch.setattr(mod, "BD_MANUAL_ROOT", tmp_path / "docs" / "bd_manual")
        outcome = check_directory_structure()
        assert outcome.passed is False
        assert any("ETL_Admin" in d and "changes" in d for d in outcome.details)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Property 4: ODS 文档覆盖率
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckOdsDocCoverage:
    """Property 4: every ODS table must be covered by a table-level document."""

    def test_pass_when_all_docs_exist(self, tmp_path, monkeypatch):
        """One doc file per table -> the coverage check passes."""
        docs_dir = tmp_path / "ODS" / "main"
        docs_dir.mkdir(parents=True)
        for table in ("member_profiles", "payment_transactions"):
            (docs_dir / f"BD_manual_{table}.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", docs_dir)
        outcome = check_ods_doc_coverage(["member_profiles", "payment_transactions"])
        assert outcome.passed is True

    def test_fail_when_doc_missing(self, tmp_path, monkeypatch):
        """A table without a doc file must fail the check and be named in the details."""
        docs_dir = tmp_path / "ODS" / "main"
        docs_dir.mkdir(parents=True)
        (docs_dir / "BD_manual_member_profiles.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", docs_dir)
        outcome = check_ods_doc_coverage(["member_profiles", "missing_table"])
        assert outcome.passed is False
        assert any("missing_table" in d for d in outcome.details)


# ---------------------------------------------------------------------------
# Property 5: ODS 文档格式完整性
# ---------------------------------------------------------------------------

class TestCheckOdsDocFormat:
    """Property 5: completeness of the per-table ODS document format."""

    @staticmethod
    def _write_doc(root, body):
        # Helper: materialise ODS/main/BD_manual_test_table.md with *body*.
        doc_dir = root / "ODS" / "main"
        doc_dir.mkdir(parents=True)
        (doc_dir / "BD_manual_test_table.md").write_text(body, encoding="utf-8")
        return doc_dir

    def test_pass_with_complete_doc(self, tmp_path, monkeypatch):
        doc_dir = self._write_doc(tmp_path, SAMPLE_ODS_DOC)
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", doc_dir)
        outcome = check_ods_doc_format()
        assert outcome.passed is True

    def test_fail_when_missing_section(self, tmp_path, monkeypatch):
        # Rename the "可回溯性" section heading so the mandatory section is gone.
        mutated = SAMPLE_ODS_DOC.replace("## 可回溯性", "## 其他")
        doc_dir = self._write_doc(tmp_path, mutated)
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", doc_dir)
        outcome = check_ods_doc_format()
        assert outcome.passed is False
        assert any("可回溯性" in item for item in outcome.details)

    def test_fail_when_missing_etl_field(self, tmp_path, monkeypatch):
        # Rename content_hash so the mandatory ETL field goes missing.
        mutated = SAMPLE_ODS_DOC.replace("content_hash", "xxx_hash")
        doc_dir = self._write_doc(tmp_path, mutated)
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", doc_dir)
        outcome = check_ods_doc_format()
        assert outcome.passed is False
        assert any("content_hash" in item for item in outcome.details)


# ---------------------------------------------------------------------------
# Property 6: ODS 文档命名规范
# ---------------------------------------------------------------------------

class TestCheckOdsDocNaming:
    """Property 6: per-table ODS document naming convention."""

    def test_pass_with_valid_names(self, tmp_path, monkeypatch):
        doc_dir = tmp_path / "ODS" / "main"
        doc_dir.mkdir(parents=True)
        for doc_name in ("BD_manual_member_profiles.md", "BD_manual_payment_transactions.md"):
            (doc_dir / doc_name).write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", doc_dir)
        outcome = check_ods_doc_naming()
        assert outcome.passed is True

    def test_fail_with_invalid_name(self, tmp_path, monkeypatch):
        """Every .md under ODS/main/ must match the convention; BadName.md is flagged."""
        doc_dir = tmp_path / "ODS" / "main"
        doc_dir.mkdir(parents=True)
        (doc_dir / "BD_manual_member_profiles.md").write_text("# doc", encoding="utf-8")
        (doc_dir / "BadName.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", doc_dir)
        outcome = check_ods_doc_naming()
        assert outcome.passed is False
        assert any("BadName.md" in item for item in outcome.details)

    def test_fail_with_uppercase_table_name(self, tmp_path, monkeypatch):
        doc_dir = tmp_path / "ODS" / "main"
        doc_dir.mkdir(parents=True)
        # Uppercase letters after BD_manual_ break the snake_case requirement.
        (doc_dir / "BD_manual_MemberProfiles.md").write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", doc_dir)
        outcome = check_ods_doc_naming()
        assert outcome.passed is False


# ---------------------------------------------------------------------------
# Property 7: 映射文档覆盖率
# ---------------------------------------------------------------------------

class TestCheckMappingDocCoverage:
    """Property 7: mapping-document coverage."""

    def test_pass_when_all_mappings_exist(self, tmp_path, monkeypatch):
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        for doc_name in ("mapping_GetTest_test_table.md", "mapping_GetOther_other_table.md"):
            (mapping_dir / doc_name).write_text("# map", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)
        outcome = check_mapping_doc_coverage(["test_table", "other_table"])
        assert outcome.passed is True

    def test_fail_when_mapping_missing(self, tmp_path, monkeypatch):
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        (mapping_dir / "mapping_GetTest_test_table.md").write_text("# map", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)
        outcome = check_mapping_doc_coverage(["test_table", "missing_table"])
        assert outcome.passed is False
        assert any("missing_table" in item for item in outcome.details)


# ---------------------------------------------------------------------------
# Property 8: 映射文档内容完整性
# ---------------------------------------------------------------------------

class TestCheckMappingDocContent:
    """Property 8: completeness of mapping-document content."""

    def test_pass_with_complete_mapping(self, tmp_path, monkeypatch):
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        target = mapping_dir / "mapping_TestEndpoint_test_table.md"
        target.write_text(SAMPLE_MAPPING_DOC, encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)
        outcome = check_mapping_doc_content()
        assert outcome.passed is True

    def test_fail_when_missing_section(self, tmp_path, monkeypatch):
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        # Rename the "## 字段映射" heading so the mandatory section disappears.
        mutated = SAMPLE_MAPPING_DOC.replace("## 字段映射", "## 其他映射")
        target = mapping_dir / "mapping_TestEndpoint_test_table.md"
        target.write_text(mutated, encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)
        outcome = check_mapping_doc_content()
        assert outcome.passed is False
        assert any("字段映射" in item for item in outcome.details)


# ---------------------------------------------------------------------------
# Property 9: 映射文档命名规范
# ---------------------------------------------------------------------------

class TestCheckMappingDocNaming:
    """Property 9: mapping-document naming convention."""

    def test_pass_with_valid_names(self, tmp_path, monkeypatch):
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        valid_doc = mapping_dir / "mapping_GetTenantMemberList_member_profiles.md"
        valid_doc.write_text("# m", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)
        outcome = check_mapping_doc_naming()
        assert outcome.passed is True

    def test_fail_with_lowercase_endpoint(self, tmp_path, monkeypatch):
        mapping_dir = tmp_path / "ODS" / "mappings"
        mapping_dir.mkdir(parents=True)
        # Endpoint segment starts lowercase, violating the PascalCase rule.
        bad_doc = mapping_dir / "mapping_getTenantMemberList_member_profiles.md"
        bad_doc.write_text("# m", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", mapping_dir)
        outcome = check_mapping_doc_naming()
        assert outcome.passed is False


# ---------------------------------------------------------------------------
# Property 10: ODS 数据字典覆盖率
# ---------------------------------------------------------------------------

class TestCheckOdsDictionaryCoverage:
    """Property 10: ODS data-dictionary coverage."""

    def test_pass_when_all_tables_in_dict(self, tmp_path, monkeypatch):
        dictionary = tmp_path / "ods_tables_dictionary.md"
        rows = "| `member_profiles` | 会员 |\n| `payment_transactions` | 支付 |"
        dictionary.write_text(rows, encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_DICT_PATH", dictionary)
        outcome = check_ods_dictionary_coverage(["member_profiles", "payment_transactions"])
        assert outcome.passed is True

    def test_fail_when_table_missing_from_dict(self, tmp_path, monkeypatch):
        dictionary = tmp_path / "ods_tables_dictionary.md"
        dictionary.write_text("| `member_profiles` | 会员 |", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_DICT_PATH", dictionary)
        outcome = check_ods_dictionary_coverage(["member_profiles", "missing_table"])
        assert outcome.passed is False
        assert any("missing_table" in item for item in outcome.details)

    def test_fail_when_dict_file_missing(self, tmp_path, monkeypatch):
        # Point the checker at a path that was never created.
        monkeypatch.setattr(mod, "ODS_DICT_PATH", tmp_path / "nonexistent.md")
        outcome = check_ods_dictionary_coverage(["some_table"])
        assert outcome.passed is False