init: 项目初始提交 - NeoZQYY Monorepo 完整代码

This commit is contained in:
Neo
2026-02-15 14:58:14 +08:00
commit ded6dfb9d8
769 changed files with 182616 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
# tests/ — 测试套件
## 目录结构
```
tests/
├── unit/                      # 单元测试(FakeDB/FakeAPI,无需真实数据库)
│   ├── task_test_utils.py     # 测试工具(FakeDBOperations、FakeAPIClient、OfflineAPIClient、TaskSpec)
│   ├── test_ods_tasks.py      # ODS 任务在线/离线模式测试
│   ├── test_cli_args.py       # CLI 参数解析测试
│   ├── test_config.py         # 配置管理测试
│   ├── test_e2e_flow.py       # 端到端流程测试(CLI → PipelineRunner → TaskExecutor)
│   ├── test_task_registry.py  # 任务注册表测试
│   ├── test_*_properties.py   # 属性测试(hypothesis)
│   └── test_audit_*.py        # 仓库审计相关测试
└── integration/ # 集成测试(需要真实数据库)
├── test_database.py # 数据库连接与操作测试
└── test_index_tasks.py # 指数任务集成测试
```
## 运行测试
```bash
# 安装测试依赖
pip install pytest hypothesis
# 全部单元测试
pytest tests/unit
# 指定测试文件
pytest tests/unit/test_ods_tasks.py
# 按关键字过滤
pytest tests/unit -k "online"
# 集成测试(需要设置 TEST_DB_DSN)
TEST_DB_DSN="postgresql://user:pass@host:5432/db" pytest tests/integration
# 查看详细输出
pytest tests/unit -v --tb=short
```
## 测试工具(task_test_utils.py)
单元测试通过 `tests/unit/task_test_utils.py` 提供的桩对象避免依赖真实数据库和 API
- `FakeDBOperations` — 拦截并记录 upsert/execute/commit/rollback不触碰真实数据库
- `FakeAPIClient` — 在线模式桩,直接返回预置的内存数据
- `OfflineAPIClient` — 离线模式桩,从归档 JSON 文件回放数据
- `TaskSpec` — 描述任务测试元数据(任务代码、端点、数据路径、样例记录)
- `create_test_config()` — 构建测试用 `AppConfig`
- `dump_offline_payload()` — 将样例数据写入归档目录供离线测试使用
## 编写新测试
- 单元测试放在 `tests/unit/`,文件名 `test_*.py`
- 使用 `FakeDBOperations` 与 `FakeAPIClient` 避免外部依赖
- 属性测试使用 `hypothesis`,文件名以 `_properties.py` 结尾
- 集成测试放在 `tests/integration/`,通过 `TEST_DB_DSN` 环境变量控制是否执行

View File

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
"""数据库集成测试"""
import pytest
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
# 注意:这些测试需要实际的数据库连接
# 在CI/CD环境中应使用测试数据库
@pytest.fixture
def db_connection():
    """Database connection fixture.

    Reads the test-database DSN from the TEST_DB_DSN environment variable;
    when it is unset the dependent test is skipped, so integration tests
    only run against an explicitly configured database. The connection is
    closed during fixture teardown.
    """
    # Read the test-database DSN from the environment.
    import os
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        pytest.skip("未配置测试数据库")
    conn = DatabaseConnection(dsn)
    yield conn
    # Teardown: pytest runs this after the test, even on failure.
    conn.close()
def test_database_query(db_connection):
"""测试数据库查询"""
result = db_connection.query("SELECT 1 AS test")
assert len(result) == 1
assert result[0]["test"] == 1
def test_database_operations(db_connection):
    """DatabaseOperations should be constructible from a live connection.

    Currently a smoke test: real operation-level cases (upsert, batch
    execute, transactions) still need to be added.
    """
    ops = DatabaseOperations(db_connection)
    # Minimal sanity check so the test cannot pass vacuously
    # (the original body ended in a bare `pass` and asserted nothing).
    assert ops is not None

View File

@@ -0,0 +1,238 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 移除 dws_member_assistant_intimacy 表存在性检查
"""Smoke test scripts for WBI/NCI index tasks."""
import logging
import os
import sys
from decimal import Decimal
from typing import Dict, List
# Make the repository root importable when this file is run directly as a
# script (pytest would normally handle this via its rootdir/conftest).
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from tasks.dws.index import NewconvIndexTask, WinbackIndexTask
# Console logging so smoke-test output is readable when run directly.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("test_index_tasks")
def _make_db() -> tuple[AppConfig, DatabaseConnection, DatabaseOperations]:
    """Load the application config and open a database connection.

    Returns:
        A (config, raw connection, operations wrapper) triple; the caller
        is responsible for closing the connection.
    """
    config = AppConfig.load()
    db_conn = DatabaseConnection(config.config["db"]["dsn"])
    db = DatabaseOperations(db_conn)
    return config, db_conn, db
def _dict_rows(rows) -> List[Dict]:
return [dict(r) for r in (rows or [])]
def _fmt(value, digits: int = 2) -> str:
if value is None:
return "-"
if isinstance(value, (int, float)):
return f"{value:.{digits}f}"
return str(value)
def _check_required_tables() -> None:
    """Fail fast with RuntimeError if any table the index tasks need is absent."""
    required = {
        "cfg_index_parameters",
        "dws_member_winback_index",
        "dws_member_newconv_index",
    }
    _, db_conn, db = _make_db()
    try:
        sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'billiards_dws'
        AND table_name IN (
        'cfg_index_parameters',
        'dws_member_winback_index',
        'dws_member_newconv_index'
        )
        """
        found = {row["table_name"] for row in _dict_rows(db.query(sql))}
        missing = sorted(required - found)
        if missing:
            raise RuntimeError(f"Missing required tables: {', '.join(missing)}")
    finally:
        db_conn.close()
def test_winback_index() -> Dict:
    """Run the WBI (winback index) task end to end against the real database.

    On success, logs aggregate statistics over the index table plus the
    top-5 members by display score. Returns the task's result dict so the
    caller can inspect its status. The connection is always closed.
    """
    logger.info("=" * 80)
    logger.info("Run WBI task")
    logger.info("=" * 80)
    config, db_conn, db = _make_db()
    try:
        task = WinbackIndexTask(config, db, None, logger)
        result = task.execute(None)
        logger.info("WBI result: %s", result)
        if result.get("status") == "success":
            # Aggregate statistics over the freshly computed index table.
            stats_sql = """
            SELECT
            COUNT(*) AS total_count,
            ROUND(AVG(display_score)::numeric, 2) AS avg_display,
            ROUND(MIN(display_score)::numeric, 2) AS min_display,
            ROUND(MAX(display_score)::numeric, 2) AS max_display,
            ROUND(AVG(raw_score)::numeric, 4) AS avg_raw,
            ROUND(AVG(overdue_old)::numeric, 4) AS avg_overdue,
            ROUND(AVG(drop_old)::numeric, 4) AS avg_drop,
            ROUND(AVG(recharge_old)::numeric, 4) AS avg_recharge,
            ROUND(AVG(value_old)::numeric, 4) AS avg_value,
            ROUND(AVG(t_v)::numeric, 2) AS avg_t_v
            FROM billiards_dws.dws_member_winback_index
            """
            stats_rows = _dict_rows(db.query(stats_sql))
            if stats_rows:
                s = stats_rows[0]
                logger.info(
                    "WBI stats | total=%s, display(avg/min/max)=%s/%s/%s, raw_avg=%s, overdue=%s, drop=%s, recharge=%s, value=%s, t_v=%s",
                    s.get("total_count"),
                    _fmt(s.get("avg_display")),
                    _fmt(s.get("min_display")),
                    _fmt(s.get("max_display")),
                    _fmt(s.get("avg_raw"), 4),
                    _fmt(s.get("avg_overdue"), 4),
                    _fmt(s.get("avg_drop"), 4),
                    _fmt(s.get("avg_recharge"), 4),
                    _fmt(s.get("avg_value"), 4),
                    _fmt(s.get("avg_t_v"), 2),
                )
            # Top five members by display score, for a quick sanity check.
            top_sql = """
            SELECT member_id, display_score, raw_score, t_v, visits_14d, sv_balance
            FROM billiards_dws.dws_member_winback_index
            ORDER BY display_score DESC NULLS LAST
            LIMIT 5
            """
            for i, r in enumerate(_dict_rows(db.query(top_sql)), 1):
                logger.info(
                    "WBI TOP%d | member=%s, display=%s, raw=%s, t_v=%s, visits_14d=%s, sv_balance=%s",
                    i,
                    r.get("member_id"),
                    _fmt(r.get("display_score")),
                    _fmt(r.get("raw_score"), 4),
                    _fmt(r.get("t_v"), 2),
                    _fmt(r.get("visits_14d"), 0),
                    _fmt(r.get("sv_balance"), 2),
                )
        return result
    finally:
        # Always release the connection, even if the task raised.
        db_conn.close()
def test_newconv_index() -> Dict:
    """Run the NCI (new-conversion index) task end to end against the real database.

    On success, logs overall statistics, the per-component averages, and
    the top-5 members by display score. Returns the task's result dict.
    The connection is always closed.
    """
    logger.info("=" * 80)
    logger.info("Run NCI task")
    logger.info("=" * 80)
    config, db_conn, db = _make_db()
    try:
        task = NewconvIndexTask(config, db, None, logger)
        result = task.execute(None)
        logger.info("NCI result: %s", result)
        if result.get("status") == "success":
            # Aggregate statistics, including the welcome/convert sub-scores
            # and each raw component feeding the composite index.
            stats_sql = """
            SELECT
            COUNT(*) AS total_count,
            ROUND(AVG(display_score)::numeric, 2) AS avg_display,
            ROUND(MIN(display_score)::numeric, 2) AS min_display,
            ROUND(MAX(display_score)::numeric, 2) AS max_display,
            ROUND(AVG(display_score_welcome)::numeric, 2) AS avg_display_welcome,
            ROUND(AVG(display_score_convert)::numeric, 2) AS avg_display_convert,
            ROUND(AVG(raw_score)::numeric, 4) AS avg_raw,
            ROUND(AVG(raw_score_welcome)::numeric, 4) AS avg_raw_welcome,
            ROUND(AVG(raw_score_convert)::numeric, 4) AS avg_raw_convert,
            ROUND(AVG(need_new)::numeric, 4) AS avg_need,
            ROUND(AVG(salvage_new)::numeric, 4) AS avg_salvage,
            ROUND(AVG(recharge_new)::numeric, 4) AS avg_recharge,
            ROUND(AVG(value_new)::numeric, 4) AS avg_value,
            ROUND(AVG(welcome_new)::numeric, 4) AS avg_welcome,
            ROUND(AVG(t_v)::numeric, 2) AS avg_t_v
            FROM billiards_dws.dws_member_newconv_index
            """
            stats_rows = _dict_rows(db.query(stats_sql))
            if stats_rows:
                s = stats_rows[0]
                logger.info(
                    "NCI stats | total=%s, display(avg/min/max)=%s/%s/%s, display_welcome=%s, display_convert=%s, raw_avg=%s, raw_welcome=%s, raw_convert=%s",
                    s.get("total_count"),
                    _fmt(s.get("avg_display")),
                    _fmt(s.get("min_display")),
                    _fmt(s.get("max_display")),
                    _fmt(s.get("avg_display_welcome")),
                    _fmt(s.get("avg_display_convert")),
                    _fmt(s.get("avg_raw"), 4),
                    _fmt(s.get("avg_raw_welcome"), 4),
                    _fmt(s.get("avg_raw_convert"), 4),
                )
                # Second line: the individual model components.
                logger.info(
                    "NCI components | need=%s, salvage=%s, recharge=%s, value=%s, welcome=%s, t_v=%s",
                    _fmt(s.get("avg_need"), 4),
                    _fmt(s.get("avg_salvage"), 4),
                    _fmt(s.get("avg_recharge"), 4),
                    _fmt(s.get("avg_value"), 4),
                    _fmt(s.get("avg_welcome"), 4),
                    _fmt(s.get("avg_t_v"), 2),
                )
            # Top five members by display score, for a quick sanity check.
            top_sql = """
            SELECT member_id, display_score, display_score_welcome, display_score_convert,
            raw_score, raw_score_welcome, raw_score_convert, t_v, visits_14d
            FROM billiards_dws.dws_member_newconv_index
            ORDER BY display_score DESC NULLS LAST
            LIMIT 5
            """
            for i, r in enumerate(_dict_rows(db.query(top_sql)), 1):
                logger.info(
                    "NCI TOP%d | member=%s, nci=%s (welcome=%s, convert=%s), raw=%s (w=%s,c=%s), t_v=%s, visits_14d=%s",
                    i,
                    r.get("member_id"),
                    _fmt(r.get("display_score")),
                    _fmt(r.get("display_score_welcome")),
                    _fmt(r.get("display_score_convert")),
                    _fmt(r.get("raw_score"), 4),
                    _fmt(r.get("raw_score_welcome"), 4),
                    _fmt(r.get("raw_score_convert"), 4),
                    _fmt(r.get("t_v"), 2),
                    _fmt(r.get("visits_14d"), 0),
                )
        return result
    finally:
        # Always release the connection, even if the task raised.
        db_conn.close()
def main() -> None:
    """Entry point: verify prerequisites, run both index tasks, log a summary."""
    _check_required_tables()
    wbi_result = test_winback_index()
    nci_result = test_newconv_index()
    results = {"WBI": wbi_result, "NCI": nci_result}
    banner = "=" * 80
    logger.info(banner)
    logger.info("Test complete")
    logger.info("WBI=%s, NCI=%s", results["WBI"].get("status"), results["NCI"].get("status"))
    logger.info(banner)

View File

@@ -0,0 +1,392 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG
# - 2026-02-14 | 删除废弃的 14 个独立 ODS 任务 TaskSpec 定义和对应 import修复语法错误TASK_SPECS=[] 后残留孤立代码块)
# 直接原因: 之前清理只把 TASK_SPECS 赋值为空列表,但未删除后续 ~370 行废弃 TaskSpec 定义,导致 IndentationError
# 验证: `python -c "import ast; ast.parse(open('tests/unit/task_test_utils.py','utf-8').read()); print('OK')"`
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
from __future__ import annotations
import json
import os
import re
from types import SimpleNamespace
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple, Type
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations as PgDBOperations
from utils.json_store import endpoint_to_filename
# Store ID used by all fixtures (matches the sample data shipped with the tests).
DEFAULT_STORE_ID = 2790685415443269
# Default time-window boundaries used when driving tasks in tests.
BASE_TS = "2025-01-01 10:00:00"
END_TS = "2025-01-01 12:00:00"
@dataclass(frozen=True)
class TaskSpec:
    """Metadata describing how a single task is driven in tests: the task
    code, task class, API endpoint, the key path locating records in the
    response body, and sample records."""
    code: str  # Registry code of the task under test.
    task_cls: Type  # Task class to instantiate.
    endpoint: str  # API endpoint path the task fetches from.
    data_path: Tuple[str, ...]  # Nested keys locating the record list in the response.
    sample_records: List[Dict]  # Canned records used as fixture data.

    @property
    def archive_filename(self) -> str:
        """Archive file name derived deterministically from the endpoint."""
        return endpoint_to_filename(self.endpoint)
def wrap_records(records: List[Dict], data_path: Sequence[str]):
    """Nest *records* under each key of *data_path* (outermost key first),
    reproducing the structure of a real API response body for offline replay."""
    wrapped = records
    for key in list(data_path)[::-1]:
        wrapped = {key: wrapped}
    return wrapped
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
    """Build an AppConfig suitable for tests.

    Fills in storage, logging and archive directories (creating them when
    missing). *mode* selects the pipeline flow: "ONLINE" (case-insensitive)
    maps to FULL; anything else maps to INGEST_ONLY.
    """
    archive_dir = Path(archive_dir)
    temp_dir = Path(temp_dir)
    archive_dir.mkdir(parents=True, exist_ok=True)
    temp_dir.mkdir(parents=True, exist_ok=True)
    # ONLINE exercises the full fetch+ingest flow; offline replays archives only.
    flow = "FULL" if str(mode or "").upper() == "ONLINE" else "INGEST_ONLY"
    overrides = {
        "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Shanghai"},
        # DSN is a placeholder; unit tests never open a real connection.
        "db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
        "api": {
            "base_url": "https://api.example.com",
            "token": "test-token",
            "timeout_sec": 3,
            "page_size": 50,
        },
        "pipeline": {
            "flow": flow,
            "fetch_root": str(temp_dir / "json_fetch"),
            "ingest_source_dir": str(archive_dir),
        },
        "io": {
            "export_root": str(temp_dir / "export"),
            "log_root": str(temp_dir / "logs"),
        },
    }
    return AppConfig.load(overrides)
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
    """Write the spec's sample records into *archive_dir* as an archive JSON
    file and return the generated file's full path, so tests can point an
    OfflineAPIClient at it."""
    target = Path(archive_dir) / spec.archive_filename
    body = wrap_records(spec.sample_records, spec.data_path)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(body, handle, ensure_ascii=False)
    return target
class FakeCursor:
    """Minimal cursor stub that records SQL/params and supports the
    context-manager protocol, used by FakeDBOperations and SCD2Handler."""

    def __init__(self, recorder: List[Dict], db_ops=None):
        self.recorder = recorder  # Shared log; every executed statement is appended here.
        self._db_ops = db_ops  # Optional FakeDBOperations to mirror upsert batches into.
        self._pending_rows: List[Tuple] = []  # Rows accumulated via mogrify() before execute().
        self._fetchall_rows: List[Tuple] = []  # Rows served by the next fetchall().
        self.rowcount = 0
        # psycopg-compatible attribute: some callers read cursor.connection.encoding.
        self.connection = SimpleNamespace(encoding="UTF8")

    # pylint: disable=unused-argument
    def execute(self, sql: str, params=None):
        """Record the statement; answer information_schema probes with canned rows."""
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []
        # Handle information_schema queries used for structure-aware writes.
        lowered = sql_text.lower()
        if "from information_schema.columns" in lowered:
            table_name = None
            if params and len(params) >= 2:
                # By convention these probes pass (schema, table); params[1] is the table.
                table_name = params[1]
            self._fetchall_rows = self._fake_columns(table_name)
            return
        if "from information_schema.table_constraints" in lowered:
            self._fetchall_rows = []
            return
        if self._pending_rows:
            # An execute() after mogrify() calls flushes the batched rows.
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            self._pending_rows = []
        else:
            self.rowcount = 0

    def fetchone(self):
        # No single-row results are simulated.
        return None

    def fetchall(self):
        # Return a copy of the rows staged by the last execute().
        return list(self._fetchall_rows)

    def mogrify(self, template, args):
        """Capture one row of batch args; returns a placeholder like psycopg's mogrify."""
        self._pending_rows.append(tuple(args))
        return b"(?)"

    def _record_upserts(self, sql_text: str):
        """Mirror a flushed INSERT batch into db_ops.upserts as column dicts."""
        if not self._db_ops:
            return
        # Pull the column list out of "INSERT INTO tbl (col, ...) VALUES".
        match = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
        if not match:
            return
        columns = [c.strip().strip('"') for c in match.group(1).split(",")]
        rows = []
        for idx, row in enumerate(self._pending_rows):
            if len(row) != len(columns):
                # Skip rows whose arity does not match the column list.
                continue
            row_dict = {}
            for col, val in zip(columns, row):
                if col == "record_index" and val in (None, ""):
                    # Backfill a missing record_index with the row's position.
                    row_dict[col] = idx
                    continue
                if hasattr(val, "adapted"):
                    # psycopg Json-style wrapper: store its payload as a JSON string.
                    row_dict[col] = json.dumps(val.adapted, ensure_ascii=False)
                else:
                    row_dict[col] = val
            rows.append(row_dict)
        if rows:
            self._db_ops.upserts.append(
                {"sql": sql_text.strip(), "count": len(rows), "page_size": len(rows), "rows": rows}
            )

    @staticmethod
    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
        """Canned information_schema.columns rows: (name, data_type, udt_name)."""
        return [
            ("id", "bigint", "int8"),
            ("sitegoodsstockid", "bigint", "int8"),
            ("record_index", "integer", "int4"),
            ("content_hash", "text", "text"),
            ("source_file", "text", "text"),
            ("source_endpoint", "text", "text"),
            ("fetched_at", "timestamp with time zone", "timestamptz"),
            ("payload", "jsonb", "jsonb"),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions.
        return False
class FakeConnection:
    """Mimics a psycopg connection just enough for SCD2Handler: hands out
    FakeCursor objects and caches every executed statement."""

    def __init__(self, db_ops):
        self._db_ops = db_ops
        # Shared statement log; every cursor appends into this same list.
        self.statements: List[Dict] = []

    def cursor(self):
        """Return a fresh FakeCursor bound to the shared statement log."""
        return FakeCursor(self.statements, self._db_ops)
class FakeDBOperations:
    """In-memory recorder that intercepts batch upserts, statements and
    transaction calls so unit tests never touch a real database; also counts
    commits/rollbacks."""

    def __init__(self):
        self.upserts: List[Dict] = []
        self.executes: List[Dict] = []
        self.commits = 0
        self.rollbacks = 0
        self.conn = FakeConnection(self)
        # Scripted query results (FIFO) letting tests control what query() returns.
        self.query_results: List[List[Dict]] = []

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Record the upsert and report (affected, returned) as (len(rows), 0)."""
        entry = {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(r) for r in rows],
        }
        self.upserts.append(entry)
        return len(rows), 0

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Record a batch statement without executing anything."""
        entry = {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(r) for r in rows],
        }
        self.executes.append(entry)

    def execute(self, sql: str, params=None):
        """Record a single statement."""
        self.executes.append({"sql": sql.strip(), "params": params})

    def query(self, sql: str, params=None):
        """Record the query; pop and return the next scripted result, else []."""
        self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
        return self.query_results.pop(0) if self.query_results else []

    def cursor(self):
        """Delegate cursor creation to the fake connection."""
        return self.conn.cursor()

    def commit(self):
        self.commits += 1

    def rollback(self):
        self.rollbacks += 1
class FakeAPIClient:
"""在线模式使用的伪 API Client直接返回预置的内存数据并记录调用以确保任务参数正确传递。"""
def __init__(self, data_map: Dict[str, List[Dict]]):
self.data_map = data_map
self.calls: List[Dict] = []
# pylint: disable=unused-argument
def iter_paginated(
self,
endpoint: str,
params=None,
page_size: int = 200,
page_field: str = "page",
size_field: str = "limit",
data_path: Tuple[str, ...] = (),
list_key: str | None = None,
):
self.calls.append({"endpoint": endpoint, "params": params})
if endpoint not in self.data_map:
raise AssertionError(f"Missing fixture for endpoint {endpoint}")
records = list(self.data_map[endpoint])
yield 1, records, dict(params or {}), {"data": records}
def get_paginated(self, endpoint: str, params=None, **kwargs):
records = []
pages = []
for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
records.extend(page_records)
pages.append({"page": page_no, "request": req, "response": resp})
return records, pages
def get_source_hint(self, endpoint: str) -> str | None:
return None
class OfflineAPIClient:
    """Offline-mode API client: replays list data from archived JSON files,
    descending into the payload along data_path (and optional list_key)."""

    def __init__(self, file_map: Dict[str, Path]):
        # Map endpoint -> archive file path, coerced to Path.
        self.file_map = {k: Path(v) for k, v in file_map.items()}
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield (page_no, records, request_params, raw_payload) pages from the
        archive file registered for *endpoint*; raises AssertionError when no
        archive is registered."""
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.file_map:
            raise AssertionError(f"Missing archive for endpoint {endpoint}")
        with self.file_map[endpoint].open("r", encoding="utf-8") as fp:
            payload = json.load(fp)
        # Walk the nested keys to reach the record list.
        data = payload
        for key in data_path:
            if isinstance(data, dict):
                data = data.get(key, [])
        if list_key and isinstance(data, dict):
            data = data.get(list_key, [])
        if not isinstance(data, list):
            data = []
        total = len(data)
        start = 0
        page = 1
        # Emit at least one (empty) page even when the archive holds no records.
        while start < total or (start == 0 and total == 0):
            chunk = data[start : start + page_size]
            if not chunk and total != 0:
                break
            yield page, list(chunk), dict(params or {}), payload
            if len(chunk) < page_size:
                # A short page means we reached the end of the data.
                break
            start += page_size
            page += 1

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        """Drain iter_paginated; return (all records, per-page metadata)."""
        records = []
        pages = []
        for page_no, page_records, req, resp in self.iter_paginated(endpoint, params, **kwargs):
            records.extend(page_records)
            pages.append({"page": page_no, "request": req, "response": resp})
        return records, pages

    def get_source_hint(self, endpoint: str) -> str | None:
        """Return the archive file path for *endpoint*, or None when unregistered."""
        if endpoint not in self.file_map:
            return None
        return str(self.file_map[endpoint])
class RealDBOperationsAdapter:
    """Adapter over a real PostgreSQL connection that gives tasks batch
    upsert plus transaction control, mirroring FakeDBOperations' surface."""

    def __init__(self, dsn: str):
        self._conn = DatabaseConnection(dsn)
        self._ops = PgDBOperations(self._conn)
        # SCD2Handler accesses db.conn.cursor(), so expose the raw connection.
        self.conn = self._conn.conn

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Delegate to the real operations wrapper and return its result."""
        return self._ops.batch_upsert_with_returning(sql, rows, page_size=page_size)

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        """Delegate a batch statement to the real operations wrapper."""
        return self._ops.batch_execute(sql, rows, page_size=page_size)

    def commit(self):
        """Commit the current transaction."""
        self._conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self._conn.rollback()

    def close(self):
        """Close the underlying database connection."""
        self._conn.close()
@contextmanager
def get_db_operations():
    """Test-only DB operations context.

    - When TEST_DB_DSN is set, connects to real PostgreSQL (closed on exit).
    - Otherwise falls back to an in-memory FakeDBOperations stub.
    """
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        # No DSN configured: hand out the in-memory stub, nothing to clean up.
        yield FakeDBOperations()
        return
    adapter = RealDBOperationsAdapter(dsn)
    try:
        yield adapter
    finally:
        adapter.close()
# The 14 standalone ODS task specs were removed: they wrote to the
# non-existent billiards.* schema and were replaced by the generic ODS task.
# Kept as an empty list so existing imports remain valid.
TASK_SPECS: List[TaskSpec] = []

View File

@@ -0,0 +1,696 @@
# -*- coding: utf-8 -*-
"""
单元测试 — 文档对齐分析器 (doc_alignment_analyzer.py)
覆盖:
- scan_docs 文档来源识别
- extract_code_references 代码引用提取
- check_reference_validity 引用有效性检查
- find_undocumented_modules 缺失文档检测
- check_ddl_vs_dictionary DDL 与数据字典比对
- check_api_samples_vs_parsers API 样本与解析器比对
- render_alignment_report 报告渲染
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from scripts.audit import AlignmentIssue, DocMapping
from scripts.audit.doc_alignment_analyzer import (
_parse_ddl_tables,
_parse_dictionary_tables,
build_mappings,
check_api_samples_vs_parsers,
check_ddl_vs_dictionary,
check_reference_validity,
extract_code_references,
find_undocumented_modules,
render_alignment_report,
scan_docs,
)
# ---------------------------------------------------------------------------
# scan_docs
# ---------------------------------------------------------------------------
class TestScanDocs:
    """Document-source discovery tests for scan_docs."""

    def test_finds_docs_dir_md(self, tmp_path: Path) -> None:
        """Markdown files directly under docs/ are discovered."""
        (tmp_path / "docs").mkdir()
        (tmp_path / "docs" / "guide.md").write_text("# Guide", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/guide.md" in result

    def test_finds_root_readme(self, tmp_path: Path) -> None:
        """The repository-root README.md is discovered."""
        (tmp_path / "README.md").write_text("# Readme", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "README.md" in result

    def test_finds_docs_subdir_requirements(self, tmp_path: Path) -> None:
        """Files under docs/requirements/ are scanned, including non-ASCII names."""
        req_dir = tmp_path / "docs" / "requirements"
        req_dir.mkdir(parents=True)
        (req_dir / "需求.md").write_text("需求", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/requirements/需求.md" in result

    def test_finds_module_readme(self, tmp_path: Path) -> None:
        """A README.md inside a module directory is discovered."""
        (tmp_path / "gui").mkdir()
        (tmp_path / "gui" / "README.md").write_text("# GUI", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "gui/README.md" in result

    def test_finds_steering_files(self, tmp_path: Path) -> None:
        """Steering docs under .kiro/steering/ are discovered."""
        steering = tmp_path / ".kiro" / "steering"
        steering.mkdir(parents=True)
        (steering / "tech.md").write_text("# Tech", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert ".kiro/steering/tech.md" in result

    def test_finds_json_samples(self, tmp_path: Path) -> None:
        """JSON API samples under docs/test-json-doc/ are discovered."""
        sample_dir = tmp_path / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / "member.json").write_text("[]", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert "docs/test-json-doc/member.json" in result

    def test_empty_repo_returns_empty(self, tmp_path: Path) -> None:
        """An empty repository yields an empty list."""
        result = scan_docs(tmp_path)
        assert result == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        """Results come back sorted regardless of discovery order."""
        (tmp_path / "docs").mkdir()
        (tmp_path / "docs" / "z.md").write_text("z", encoding="utf-8")
        (tmp_path / "docs" / "a.md").write_text("a", encoding="utf-8")
        (tmp_path / "README.md").write_text("r", encoding="utf-8")
        result = scan_docs(tmp_path)
        assert result == sorted(result)
# ---------------------------------------------------------------------------
# extract_code_references
# ---------------------------------------------------------------------------
class TestExtractCodeReferences:
    """Code-reference extraction tests for extract_code_references."""

    def test_extracts_backtick_paths(self, tmp_path: Path) -> None:
        """Backticked file paths are extracted."""
        doc = tmp_path / "doc.md"
        doc.write_text("使用 `tasks/base_task.py` 作为基类", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "tasks/base_task.py" in refs

    def test_extracts_class_names(self, tmp_path: Path) -> None:
        """Backticked class names are extracted."""
        doc = tmp_path / "doc.md"
        doc.write_text("继承 `BaseTask` 类", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "BaseTask" in refs

    def test_skips_single_char(self, tmp_path: Path) -> None:
        """Single-character tokens (variable names) are ignored."""
        doc = tmp_path / "doc.md"
        doc.write_text("变量 `x` 和 `y`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs == []

    def test_skips_pure_numbers(self, tmp_path: Path) -> None:
        """Version numbers and numeric IDs are ignored."""
        doc = tmp_path / "doc.md"
        doc.write_text("版本 `2.0.0` 和 ID `12345`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs == []

    def test_deduplicates(self, tmp_path: Path) -> None:
        """Repeated references appear only once in the result."""
        doc = tmp_path / "doc.md"
        doc.write_text("`foo.py` 和 `foo.py` 重复", encoding="utf-8")
        refs = extract_code_references(doc)
        assert refs.count("foo.py") == 1

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        """A missing document yields an empty list rather than raising."""
        refs = extract_code_references(tmp_path / "nonexistent.md")
        assert refs == []

    def test_normalizes_backslash(self, tmp_path: Path) -> None:
        """Windows-style backslash paths are normalized to forward slashes."""
        doc = tmp_path / "doc.md"
        doc.write_text("路径 `tasks\\base_task.py`", encoding="utf-8")
        refs = extract_code_references(doc)
        assert "tasks/base_task.py" in refs
# ---------------------------------------------------------------------------
# check_reference_validity
# ---------------------------------------------------------------------------
class TestCheckReferenceValidity:
    """Reference-validity tests for check_reference_validity."""

    def test_valid_file_path(self, tmp_path: Path) -> None:
        """An existing file path is valid."""
        (tmp_path / "tasks").mkdir()
        (tmp_path / "tasks" / "base.py").write_text("", encoding="utf-8")
        assert check_reference_validity("tasks/base.py", tmp_path) is True

    def test_invalid_file_path(self, tmp_path: Path) -> None:
        """A non-existent path is invalid."""
        assert check_reference_validity("nonexistent/file.py", tmp_path) is False

    def test_strips_legacy_prefix(self, tmp_path: Path) -> None:
        """Legacy package prefix (etl_billiards/) and current root prefix (FQ-ETL/) are stripped before checking."""
        (tmp_path / "tasks").mkdir()
        (tmp_path / "tasks" / "x.py").write_text("", encoding="utf-8")
        assert check_reference_validity("etl_billiards/tasks/x.py", tmp_path) is True
        assert check_reference_validity("FQ-ETL/tasks/x.py", tmp_path) is True

    def test_directory_path(self, tmp_path: Path) -> None:
        """An existing directory counts as a valid reference."""
        (tmp_path / "loaders").mkdir()
        assert check_reference_validity("loaders", tmp_path) is True

    def test_dotted_module_path(self, tmp_path: Path) -> None:
        """A dotted module path resolves to the corresponding .py file."""
        (tmp_path / "config").mkdir()
        (tmp_path / "config" / "settings.py").write_text("", encoding="utf-8")
        assert check_reference_validity("config.settings", tmp_path) is True
# ---------------------------------------------------------------------------
# find_undocumented_modules
# ---------------------------------------------------------------------------
class TestFindUndocumentedModules:
    """Tests for detecting modules that no document references."""

    def test_finds_undocumented(self, tmp_path: Path) -> None:
        """A module in a core directory with no doc reference is reported."""
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        (tasks_dir / "__init__.py").write_text("", encoding="utf-8")
        (tasks_dir / "ods_task.py").write_text("", encoding="utf-8")
        result = find_undocumented_modules(tmp_path, set())
        assert "tasks/ods_task.py" in result

    def test_excludes_init(self, tmp_path: Path) -> None:
        """__init__.py files are never reported."""
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        (tasks_dir / "__init__.py").write_text("", encoding="utf-8")
        result = find_undocumented_modules(tmp_path, set())
        assert all("__init__" not in r for r in result)

    def test_documented_module_excluded(self, tmp_path: Path) -> None:
        """A module that is documented is not reported."""
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        (tasks_dir / "ods_task.py").write_text("", encoding="utf-8")
        result = find_undocumented_modules(tmp_path, {"tasks/ods_task.py"})
        assert "tasks/ods_task.py" not in result

    def test_non_core_dirs_ignored(self, tmp_path: Path) -> None:
        """gui/ is not in the core code directory list and must not be checked."""
        gui_dir = tmp_path / "gui"
        gui_dir.mkdir()
        (gui_dir / "main.py").write_text("", encoding="utf-8")
        result = find_undocumented_modules(tmp_path, set())
        assert all("gui/" not in r for r in result)

    def test_results_sorted(self, tmp_path: Path) -> None:
        """Results come back sorted regardless of discovery order."""
        tasks_dir = tmp_path / "tasks"
        tasks_dir.mkdir()
        (tasks_dir / "z_task.py").write_text("", encoding="utf-8")
        (tasks_dir / "a_task.py").write_text("", encoding="utf-8")
        result = find_undocumented_modules(tmp_path, set())
        assert result == sorted(result)
# ---------------------------------------------------------------------------
# _parse_ddl_tables / _parse_dictionary_tables
# ---------------------------------------------------------------------------
class TestParseDdlTables:
    """DDL parsing tests for _parse_ddl_tables."""

    def test_extracts_table_and_columns(self) -> None:
        """Table name and its plain column names are extracted from CREATE TABLE."""
        sql = """
        CREATE TABLE IF NOT EXISTS dim_member (
            member_id BIGINT,
            nickname TEXT,
            mobile TEXT,
            PRIMARY KEY (member_id)
        );
        """
        result = _parse_ddl_tables(sql)
        assert "dim_member" in result
        assert "member_id" in result["dim_member"]
        assert "nickname" in result["dim_member"]
        assert "mobile" in result["dim_member"]

    def test_handles_schema_prefix(self) -> None:
        """A schema-qualified table name is reduced to its bare table name."""
        sql = "CREATE TABLE billiards_dwd.dim_site (\n site_id BIGINT\n);"
        result = _parse_ddl_tables(sql)
        assert "dim_site" in result

    def test_excludes_sql_keywords(self) -> None:
        """Constraint keywords such as PRIMARY are not mistaken for columns."""
        sql = """
        CREATE TABLE test_tbl (
            id INTEGER,
            PRIMARY KEY (id)
        );
        """
        result = _parse_ddl_tables(sql)
        assert "primary" not in result.get("test_tbl", set())
class TestParseDictionaryTables:
    """Data-dictionary parsing tests for _parse_dictionary_tables."""

    def test_extracts_table_and_fields(self) -> None:
        """A '## table' heading plus markdown table yields the field set."""
        md = """## dim_member
| 字段 | 类型 | 说明 |
|------|------|------|
| member_id | BIGINT | 会员ID |
| nickname | TEXT | 昵称 |
"""
        result = _parse_dictionary_tables(md)
        assert "dim_member" in result
        assert "member_id" in result["dim_member"]
        assert "nickname" in result["dim_member"]

    def test_skips_header_row(self) -> None:
        """The markdown table's header row is not treated as a field."""
        md = """## dim_test
| 字段 | 类型 |
|------|------|
| col_a | INT |
"""
        result = _parse_dictionary_tables(md)
        assert "字段" not in result.get("dim_test", set())

    def test_handles_backtick_table_name(self) -> None:
        """A heading like '## `dim_goods`' still resolves to the bare name."""
        md = "## `dim_goods`\n\n| 字段 |\n| goods_id |"
        result = _parse_dictionary_tables(md)
        assert "dim_goods" in result
# ---------------------------------------------------------------------------
# check_ddl_vs_dictionary
# ---------------------------------------------------------------------------
class TestCheckDdlVsDictionary:
    """DDL vs data-dictionary comparison tests for check_ddl_vs_dictionary."""

    def test_detects_missing_table_in_dictionary(self, tmp_path: Path) -> None:
        """A table present in DDL but absent from the dictionary yields a 'missing' issue."""
        db_dir = tmp_path / "database"
        db_dir.mkdir()
        (db_dir / "schema_test.sql").write_text(
            "CREATE TABLE dim_orphan (\n id BIGINT\n);",
            encoding="utf-8",
        )
        docs_dir = tmp_path / "docs"
        docs_dir.mkdir()
        (docs_dir / "dwd_main_tables_dictionary.md").write_text(
            "## dim_other\n\n| 字段 |\n| id |",
            encoding="utf-8",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        missing = [i for i in issues if i.issue_type == "missing"]
        assert any("dim_orphan" in i.description for i in missing)

    def test_detects_column_mismatch(self, tmp_path: Path) -> None:
        """A DDL column absent from the dictionary yields a 'conflict' issue."""
        db_dir = tmp_path / "database"
        db_dir.mkdir()
        (db_dir / "schema_test.sql").write_text(
            "CREATE TABLE dim_x (\n id BIGINT,\n extra_col TEXT\n);",
            encoding="utf-8",
        )
        docs_dir = tmp_path / "docs"
        docs_dir.mkdir()
        (docs_dir / "dwd_main_tables_dictionary.md").write_text(
            "## dim_x\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
            encoding="utf-8",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        conflict = [i for i in issues if i.issue_type == "conflict"]
        assert any("extra_col" in i.description for i in conflict)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        """Matching DDL and dictionary produce no issues."""
        db_dir = tmp_path / "database"
        db_dir.mkdir()
        (db_dir / "schema_test.sql").write_text(
            "CREATE TABLE dim_ok (\n id BIGINT\n);",
            encoding="utf-8",
        )
        docs_dir = tmp_path / "docs"
        docs_dir.mkdir()
        (docs_dir / "dwd_main_tables_dictionary.md").write_text(
            "## dim_ok\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
            encoding="utf-8",
        )
        issues = check_ddl_vs_dictionary(tmp_path)
        assert len(issues) == 0
# ---------------------------------------------------------------------------
# check_api_samples_vs_parsers
# ---------------------------------------------------------------------------
class TestCheckApiSamplesVsParsers:
    """API sample vs ODS schema comparison tests for check_api_samples_vs_parsers."""

    def test_detects_json_field_not_in_ods(self, tmp_path: Path) -> None:
        """A JSON sample field missing from the ODS table is reported."""
        sample_dir = tmp_path / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / "test_entity.json").write_text(
            json.dumps([{"id": 1, "name": "a", "extra_field": "x"}]),
            encoding="utf-8",
        )
        db_dir = tmp_path / "database"
        db_dir.mkdir()
        (db_dir / "schema_ODS_doc.sql").write_text(
            "CREATE TABLE billiards_ods.test_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
            encoding="utf-8",
        )
        issues = check_api_samples_vs_parsers(tmp_path)
        assert any("extra_field" in i.description for i in issues)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        """A sample whose fields are all present in the ODS table yields no issues."""
        sample_dir = tmp_path / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / "aligned_entity.json").write_text(
            json.dumps([{"id": 1, "name": "a"}]),
            encoding="utf-8",
        )
        db_dir = tmp_path / "database"
        db_dir.mkdir()
        (db_dir / "schema_ODS_doc.sql").write_text(
            "CREATE TABLE billiards_ods.aligned_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
            encoding="utf-8",
        )
        issues = check_api_samples_vs_parsers(tmp_path)
        assert len(issues) == 0

    def test_skips_when_no_ods_table(self, tmp_path: Path) -> None:
        """A sample with no corresponding ODS table is skipped, not flagged."""
        sample_dir = tmp_path / "docs" / "test-json-doc"
        sample_dir.mkdir(parents=True)
        (sample_dir / "unknown.json").write_text(
            json.dumps([{"a": 1}]),
            encoding="utf-8",
        )
        db_dir = tmp_path / "database"
        db_dir.mkdir()
        (db_dir / "schema_ODS_doc.sql").write_text("-- empty", encoding="utf-8")
        issues = check_api_samples_vs_parsers(tmp_path)
        assert len(issues) == 0
# ---------------------------------------------------------------------------
# render_alignment_report
# ---------------------------------------------------------------------------
class TestRenderAlignmentReport:
    """Tests for render_alignment_report (Markdown report rendering)."""

    def test_contains_all_sections(self) -> None:
        """The report contains all five section headings."""
        report = render_alignment_report([], [], "/repo")
        assert "## 映射关系" in report
        assert "## 过期点" in report
        assert "## 冲突点" in report
        assert "## 缺失点" in report
        assert "## 统计摘要" in report

    def test_contains_header_metadata(self) -> None:
        """The header carries the generation time and the repository path."""
        report = render_alignment_report([], [], "/repo")
        assert "生成时间" in report
        assert "`/repo`" in report

    def test_contains_iso_timestamp(self) -> None:
        """The header timestamp is rendered in ISO-8601 UTC form."""
        report = render_alignment_report([], [], "/repo")
        # An ISO-format timestamp contains the 'T' separator and a trailing 'Z'.
        import re
        assert re.search(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", report)

    def test_mapping_table_rendered(self) -> None:
        """Doc/code mappings appear in the mapping table."""
        mappings = [
            DocMapping(
                doc_path="docs/guide.md",
                doc_topic="项目文档",
                related_code=["tasks/base.py"],
                status="aligned",
            )
        ]
        report = render_alignment_report(mappings, [], "/repo")
        assert "`docs/guide.md`" in report
        assert "`tasks/base.py`" in report
        assert "aligned" in report

    def test_stale_issues_rendered(self) -> None:
        """Stale issues appear under the stale section."""
        issues = [
            AlignmentIssue(
                doc_path="docs/old.md",
                issue_type="stale",
                description="引用了已删除的文件",
                related_code="tasks/deleted.py",
            )
        ]
        report = render_alignment_report([], issues, "/repo")
        assert "引用了已删除的文件" in report
        assert "## 过期点" in report

    def test_conflict_issues_rendered(self) -> None:
        """Conflict issues appear in the report."""
        issues = [
            AlignmentIssue(
                doc_path="docs/dict.md",
                issue_type="conflict",
                description="字段不一致",
                related_code="database/schema.sql",
            )
        ]
        report = render_alignment_report([], issues, "/repo")
        assert "字段不一致" in report

    def test_missing_issues_rendered(self) -> None:
        """Missing-definition issues appear in the report."""
        issues = [
            AlignmentIssue(
                doc_path="docs/dict.md",
                issue_type="missing",
                description="缺少表定义",
                related_code="database/schema.sql",
            )
        ]
        report = render_alignment_report([], issues, "/repo")
        assert "缺少表定义" in report

    def test_summary_counts(self) -> None:
        """The summary section reports per-type issue counts and the doc total."""
        issues = [
            AlignmentIssue("a", "stale", "d1", "c1"),
            AlignmentIssue("b", "stale", "d2", "c2"),
            AlignmentIssue("c", "conflict", "d3", "c3"),
            AlignmentIssue("d", "missing", "d4", "c4"),
        ]
        mappings = [DocMapping("x", "t", [], "aligned")]
        report = render_alignment_report(mappings, issues, "/repo")
        assert "过期点数量2" in report
        assert "冲突点数量1" in report
        assert "缺失点数量1" in report
        assert "文档总数1" in report

    def test_empty_report(self) -> None:
        """Empty inputs produce the "none found" placeholders and zero counts."""
        report = render_alignment_report([], [], "/repo")
        assert "未发现过期点" in report
        assert "未发现冲突点" in report
        assert "未发现缺失点" in report
        assert "过期点数量0" in report
# ---------------------------------------------------------------------------
# 属性测试 — Property 11 / 12 / 16 (hypothesis)
# hypothesis 与 pytest 的 function-scoped fixture (tmp_path) 不兼容,
# 因此在测试内部使用 tempfile.mkdtemp 自行管理临时目录。
# ---------------------------------------------------------------------------
import shutil
import tempfile
from hypothesis import given, settings
from hypothesis import strategies as st
from scripts.audit.doc_alignment_analyzer import _CORE_CODE_DIRS
class TestPropertyStaleReferenceDetection:
    """Feature: repo-audit, Property 11: 过期引用检测

    *For any* code reference extracted from documentation: if the referenced
    file path does not exist in the repository, check_reference_validity
    must return False.

    Validates: Requirements 3.3
    """

    # Short, filesystem-safe module names used for both existing and missing files.
    _safe_name = st.from_regex(r"[a-z][a-z0-9_]{1,12}", fullmatch=True)

    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
        missing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_nonexistent_path_returns_false(
        self,
        existing_names: list[str],
        missing_names: list[str],
    ) -> None:
        """References to nonexistent file paths must return False."""
        # hypothesis is incompatible with the function-scoped tmp_path fixture,
        # so the temporary directory is managed manually per example.
        tmp = Path(tempfile.mkdtemp())
        try:
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")
            existing_set = set(existing_names)
            # Only check names that genuinely do not exist on disk.
            truly_missing = [n for n in missing_names if n not in existing_set]
            for name in truly_missing:
                ref = f"nonexistent_dir/{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is False, (
                    f"引用 '{ref}' 指向不存在的文件,但返回了 True"
                )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)

    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_existing_path_returns_true(
        self,
        existing_names: list[str],
    ) -> None:
        """References to existing file paths must return True."""
        tmp = Path(tempfile.mkdtemp())
        try:
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")
            for name in existing_names:
                ref = f"{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is True, (
                    f"引用 '{ref}' 指向存在的文件,但返回了 False"
                )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
class TestPropertyMissingDocDetection:
    """Feature: repo-audit, Property 12: 缺失文档检测

    *For any* set of core code modules and any documented subset,
    find_undocumented_modules must return exactly the set difference
    (core modules minus documented modules).

    Validates: Requirements 3.5
    """

    _core_dir = st.sampled_from(list(_CORE_CODE_DIRS))
    _module_name = st.from_regex(r"[a-z][a-z0-9_]{1,10}", fullmatch=True)

    @given(
        core_dir=_core_dir,
        module_names=st.lists(
            _module_name, min_size=2, max_size=6, unique=True,
        ),
        doc_fraction=st.floats(min_value=0.0, max_value=1.0),
    )
    @settings(max_examples=100)
    def test_undocumented_equals_difference(
        self,
        core_dir: str,
        module_names: list[str],
        doc_fraction: float,
    ) -> None:
        """The result must equal the sorted difference of core minus documented."""
        tmp = Path(tempfile.mkdtemp())
        try:
            code_dir = tmp / core_dir
            code_dir.mkdir(parents=True, exist_ok=True)
            all_modules: set[str] = set()
            for name in module_names:
                (code_dir / f"{name}.py").write_text("# module", encoding="utf-8")
                all_modules.add(f"{core_dir}/{name}.py")
            # Document a prefix of the modules; the fraction drives the split point.
            split_idx = int(len(module_names) * doc_fraction)
            documented = {
                f"{core_dir}/{n}.py" for n in module_names[:split_idx]
            }
            result = find_undocumented_modules(tmp, documented)
            expected = sorted(all_modules - documented)
            assert result == expected, (
                f"期望缺失列表 {expected},实际得到 {result}"
            )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
class TestPropertyAlignmentReportSections:
    """Feature: repo-audit, Property 16: 文档对齐报告分区完整性

    *For any* output of render_alignment_report, the Markdown text must
    contain the four section headings: 映射关系, 过期点, 冲突点, 缺失点.

    Validates: Requirements 3.8
    """

    _issue_type = st.sampled_from(["stale", "conflict", "missing"])
    # Letters, digits, and punctuation only; NUL is excluded to keep values printable.
    _text = st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P"),
            blacklist_characters="\x00",
        ),
        min_size=1,
        max_size=30,
    )
    _mapping_st = st.builds(
        DocMapping,
        doc_path=_text,
        doc_topic=_text,
        related_code=st.lists(_text, max_size=3),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )
    _issue_st = st.builds(
        AlignmentIssue,
        doc_path=_text,
        issue_type=_issue_type,
        description=_text,
        related_code=_text,
    )

    @given(
        mappings=st.lists(_mapping_st, max_size=5),
        issues=st.lists(_issue_st, max_size=8),
    )
    @settings(max_examples=100)
    def test_report_contains_four_sections(
        self,
        mappings: list[DocMapping],
        issues: list[AlignmentIssue],
    ) -> None:
        """The report must contain all four section headings."""
        report = render_alignment_report(mappings, issues, "/test/repo")
        required_sections = ["## 映射关系", "## 过期点", "## 冲突点", "## 缺失点"]
        for section in required_sections:
            assert section in report, (
                f"报告中缺少分区标题 '{section}'"
            )

View File

@@ -0,0 +1,667 @@
# -*- coding: utf-8 -*-
"""
单元测试 — 流程树分析器 (flow_analyzer.py)
覆盖:
- parse_imports: import 语句解析、标准库/第三方排除、语法错误容错
- build_flow_tree: 递归构建、循环导入处理
- find_orphan_modules: 孤立模块检测
- render_flow_report: Markdown 渲染、Mermaid 图、统计摘要
- discover_entry_points: 入口点识别
- classify_task_type / classify_loader_type: 类型区分
"""
from __future__ import annotations
from pathlib import Path
import pytest
from scripts.audit import FileEntry, FlowNode
from scripts.audit.flow_analyzer import (
build_flow_tree,
classify_loader_type,
classify_task_type,
discover_entry_points,
find_orphan_modules,
parse_imports,
render_flow_report,
_path_to_module_name,
_parse_bat_python_target,
)
# ---------------------------------------------------------------------------
# parse_imports 单元测试
# ---------------------------------------------------------------------------
class TestParseImports:
    """Tests for parse_imports (import-statement parsing)."""

    def test_absolute_import(self, tmp_path: Path) -> None:
        """Absolute imports of project-internal modules are recognized."""
        f = tmp_path / "test.py"
        f.write_text("import cli.main\nimport config.settings\n", encoding="utf-8")
        result = parse_imports(f)
        assert "cli.main" in result
        assert "config.settings" in result

    def test_from_import(self, tmp_path: Path) -> None:
        """from ... import statements are recognized."""
        f = tmp_path / "test.py"
        f.write_text("from tasks.base_task import BaseTask\n", encoding="utf-8")
        result = parse_imports(f)
        assert "tasks.base_task" in result

    def test_stdlib_excluded(self, tmp_path: Path) -> None:
        """Standard-library modules are excluded."""
        f = tmp_path / "test.py"
        f.write_text("import os\nimport sys\nimport json\nfrom pathlib import Path\n", encoding="utf-8")
        result = parse_imports(f)
        assert result == []

    def test_third_party_excluded(self, tmp_path: Path) -> None:
        """Third-party packages are excluded."""
        f = tmp_path / "test.py"
        f.write_text("import requests\nfrom psycopg2 import sql\nimport flask\n", encoding="utf-8")
        result = parse_imports(f)
        assert result == []

    def test_mixed_imports(self, tmp_path: Path) -> None:
        """Mixed imports keep only project-internal modules."""
        f = tmp_path / "test.py"
        f.write_text(
            "import os\nimport cli.main\nimport requests\nfrom loaders.base_loader import BaseLoader\n",
            encoding="utf-8",
        )
        result = parse_imports(f)
        assert "cli.main" in result
        assert "loaders.base_loader" in result
        assert "os" not in result
        assert "requests" not in result

    def test_syntax_error_returns_empty(self, tmp_path: Path) -> None:
        """A file with a syntax error yields an empty list."""
        f = tmp_path / "bad.py"
        f.write_text("def broken(\n", encoding="utf-8")
        result = parse_imports(f)
        assert result == []

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        """A nonexistent file yields an empty list."""
        result = parse_imports(tmp_path / "nonexistent.py")
        assert result == []

    def test_deduplication(self, tmp_path: Path) -> None:
        """Duplicate imports are deduplicated."""
        f = tmp_path / "test.py"
        f.write_text("import cli.main\nimport cli.main\nfrom cli.main import main\n", encoding="utf-8")
        result = parse_imports(f)
        assert result.count("cli.main") == 1

    def test_empty_file(self, tmp_path: Path) -> None:
        """An empty file yields an empty list."""
        f = tmp_path / "empty.py"
        f.write_text("", encoding="utf-8")
        result = parse_imports(f)
        assert result == []
# ---------------------------------------------------------------------------
# build_flow_tree 单元测试
# ---------------------------------------------------------------------------
class TestBuildFlowTree:
    """Tests for build_flow_tree (flow-tree construction)."""

    def test_single_file_no_imports(self, tmp_path: Path) -> None:
        """A single file with no imports yields a leaf node."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert tree.source_file == "cli/main.py"
        assert tree.children == []

    def test_simple_import_chain(self, tmp_path: Path) -> None:
        """A simple import chain produces the matching child node."""
        # cli/main.py → config/settings.py
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "from config.settings import AppConfig\n", encoding="utf-8"
        )
        config_dir = tmp_path / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        (config_dir / "settings.py").write_text("class AppConfig: pass\n", encoding="utf-8")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert len(tree.children) == 1
        assert tree.children[0].name == "config.settings"

    def test_circular_import_no_infinite_loop(self, tmp_path: Path) -> None:
        """Circular imports must not cause unbounded recursion."""
        pkg = tmp_path / "utils"
        pkg.mkdir()
        (pkg / "__init__.py").write_text("", encoding="utf-8")
        # a → b → a: a cycle
        (pkg / "a.py").write_text("from utils.b import func_b\n", encoding="utf-8")
        (pkg / "b.py").write_text("from utils.a import func_a\n", encoding="utf-8")
        # Must not raise RecursionError.
        tree = build_flow_tree(tmp_path, "utils/a.py")
        assert tree.name == "utils.a"

    def test_entry_node_type(self, tmp_path: Path) -> None:
        """The CLI entry file is tagged with the "entry" node type."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.node_type == "entry"
# ---------------------------------------------------------------------------
# find_orphan_modules 单元测试
# ---------------------------------------------------------------------------
class TestFindOrphanModules:
    """Tests for find_orphan_modules (orphan-module detection)."""

    def test_all_reachable(self, tmp_path: Path) -> None:
        """When every module is reachable, the result is empty."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("config/settings.py", False, 200, ".py", False),
        ]
        reachable = {"cli/main.py", "config/settings.py"}
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []

    def test_orphan_detected(self, tmp_path: Path) -> None:
        """Unreachable modules are flagged as orphans."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("utils/orphan.py", False, 50, ".py", False),
        ]
        reachable = {"cli/main.py"}
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert "utils/orphan.py" in orphans

    def test_init_files_excluded(self, tmp_path: Path) -> None:
        """__init__.py files are never treated as orphans."""
        entries = [
            FileEntry("cli/__init__.py", False, 0, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert "cli/__init__.py" not in orphans

    def test_test_files_excluded(self, tmp_path: Path) -> None:
        """Test files are never treated as orphans."""
        entries = [
            FileEntry("tests/unit/test_something.py", False, 100, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []

    def test_audit_scripts_excluded(self, tmp_path: Path) -> None:
        """The audit scripts themselves are never treated as orphans."""
        entries = [
            FileEntry("scripts/audit/scanner.py", False, 100, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []

    def test_directories_excluded(self, tmp_path: Path) -> None:
        """Directory entries never appear in the orphan list."""
        entries = [
            FileEntry("cli", True, 0, "", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []

    def test_sorted_output(self, tmp_path: Path) -> None:
        """The orphan list is sorted by path."""
        entries = [
            FileEntry("utils/z.py", False, 50, ".py", False),
            FileEntry("utils/a.py", False, 50, ".py", False),
            FileEntry("cli/orphan.py", False, 50, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == sorted(orphans)
# ---------------------------------------------------------------------------
# render_flow_report 单元测试
# ---------------------------------------------------------------------------
class TestRenderFlowReport:
    """Tests for render_flow_report (flow-tree report rendering)."""

    def test_header_contains_timestamp_and_path(self) -> None:
        """The report header carries the timestamp and the repository path."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, [], "/repo")
        assert "生成时间:" in report
        assert "`/repo`" in report

    def test_contains_mermaid_block(self) -> None:
        """The report contains a Mermaid code block."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, [], "/repo")
        assert "```mermaid" in report
        assert "graph TD" in report

    def test_contains_indented_text(self) -> None:
        """The report renders the flow tree as indented text."""
        child = FlowNode("config.settings", "config/settings.py", "module", [])
        root = FlowNode("cli.main", "cli/main.py", "entry", [child])
        report = render_flow_report([root], [], "/repo")
        assert "`cli.main`" in report
        assert "`config.settings`" in report

    def test_orphan_section(self) -> None:
        """The report lists orphan modules."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        orphans = ["utils/orphan.py", "models/unused.py"]
        report = render_flow_report(trees, orphans, "/repo")
        assert "孤立模块" in report
        assert "`utils/orphan.py`" in report
        assert "`models/unused.py`" in report

    def test_no_orphans_message(self) -> None:
        """With no orphans, a placeholder message is shown."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, [], "/repo")
        assert "未发现孤立模块" in report

    def test_statistics_summary(self) -> None:
        """The report contains a statistics summary."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, ["a.py"], "/repo")
        assert "统计摘要" in report
        assert "入口点" in report
        assert "任务" in report
        assert "加载器" in report
        assert "孤立模块" in report

    def test_task_type_annotation(self) -> None:
        """Task modules carry a type annotation in the report."""
        task_node = FlowNode("tasks.ods_member", "tasks/ods_member.py", "module", [])
        root = FlowNode("cli.main", "cli/main.py", "entry", [task_node])
        report = render_flow_report([root], [], "/repo")
        assert "ODS" in report

    def test_loader_type_annotation(self) -> None:
        """Loader modules carry a type annotation in the report."""
        loader_node = FlowNode(
            "loaders.dimensions.member", "loaders/dimensions/member.py", "module", []
        )
        root = FlowNode("cli.main", "cli/main.py", "entry", [loader_node])
        report = render_flow_report([root], [], "/repo")
        assert "维度" in report or "SCD2" in report
# ---------------------------------------------------------------------------
# discover_entry_points 单元测试
# ---------------------------------------------------------------------------
class TestDiscoverEntryPoints:
    """Tests for discover_entry_points (entry-point discovery)."""

    def test_cli_entry(self, tmp_path: Path) -> None:
        """The CLI entry point is discovered."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        entries = discover_entry_points(tmp_path)
        cli_entries = [e for e in entries if e["type"] == "CLI"]
        assert len(cli_entries) == 1
        assert cli_entries[0]["file"] == "cli/main.py"

    def test_gui_entry(self, tmp_path: Path) -> None:
        """The GUI entry point is discovered."""
        gui_dir = tmp_path / "gui"
        gui_dir.mkdir()
        (gui_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        entries = discover_entry_points(tmp_path)
        gui_entries = [e for e in entries if e["type"] == "GUI"]
        assert len(gui_entries) == 1

    def test_bat_entry(self, tmp_path: Path) -> None:
        """Batch-file entry points are discovered."""
        (tmp_path / "run_etl.bat").write_text(
            "@echo off\npython -m cli.main %*\n", encoding="utf-8"
        )
        entries = discover_entry_points(tmp_path)
        bat_entries = [e for e in entries if e["type"] == "批处理"]
        assert len(bat_entries) == 1
        assert "cli.main" in bat_entries[0]["description"]

    def test_script_entry(self, tmp_path: Path) -> None:
        """Operational-script entry points are discovered."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        (scripts_dir / "rebuild_db.py").write_text(
            'if __name__ == "__main__": pass\n', encoding="utf-8"
        )
        entries = discover_entry_points(tmp_path)
        script_entries = [e for e in entries if e["type"] == "运维脚本"]
        assert len(script_entries) == 1
        assert script_entries[0]["file"] == "scripts/rebuild_db.py"

    def test_init_py_excluded_from_scripts(self, tmp_path: Path) -> None:
        """scripts/__init__.py is not reported as an entry point."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        entries = discover_entry_points(tmp_path)
        script_entries = [e for e in entries if e["type"] == "运维脚本"]
        assert all(e["file"] != "scripts/__init__.py" for e in script_entries)
# ---------------------------------------------------------------------------
# classify_task_type / classify_loader_type 单元测试
# ---------------------------------------------------------------------------
class TestClassifyTypes:
    """Tests for classify_task_type and classify_loader_type."""

    def test_ods_task(self) -> None:
        """ods_* task files classify as ODS."""
        assert "ODS" in classify_task_type("tasks/ods_member.py")

    def test_dwd_task(self) -> None:
        """dwd_* task files classify as DWD."""
        assert "DWD" in classify_task_type("tasks/dwd_load.py")

    def test_dws_task(self) -> None:
        """tasks/dws/ files classify as DWS."""
        assert "DWS" in classify_task_type("tasks/dws/assistant_daily.py")

    def test_verification_task(self) -> None:
        """tasks/verification/ files classify as verification (校验)."""
        assert "校验" in classify_task_type("tasks/verification/balance_check.py")

    def test_schema_init_task(self) -> None:
        """Schema-initialization task files classify as Schema."""
        assert "Schema" in classify_task_type("tasks/init_ods_schema.py")

    def test_dimension_loader(self) -> None:
        """loaders/dimensions/ files classify as dimension/SCD2."""
        result = classify_loader_type("loaders/dimensions/member.py")
        assert "维度" in result or "SCD2" in result

    def test_fact_loader(self) -> None:
        """loaders/facts/ files classify as fact (事实)."""
        assert "事实" in classify_loader_type("loaders/facts/order.py")

    def test_ods_loader(self) -> None:
        """loaders/ods/ files classify as ODS."""
        assert "ODS" in classify_loader_type("loaders/ods/generic.py")
# ---------------------------------------------------------------------------
# _path_to_module_name 单元测试
# ---------------------------------------------------------------------------
class TestPathToModuleName:
    """Tests for _path_to_module_name (file path → dotted module name)."""

    def test_simple_file(self) -> None:
        """A plain file path maps to a dotted module name."""
        assert _path_to_module_name("cli/main.py") == "cli.main"

    def test_init_file(self) -> None:
        """__init__.py maps to its package name."""
        assert _path_to_module_name("cli/__init__.py") == "cli"

    def test_nested_path(self) -> None:
        """Nested paths keep every package level."""
        assert _path_to_module_name("tasks/dws/assistant.py") == "tasks.dws.assistant"
# ---------------------------------------------------------------------------
# _parse_bat_python_target 单元测试
# ---------------------------------------------------------------------------
class TestParseBatPythonTarget:
    """Tests for _parse_bat_python_target (extracting the python target from a .bat file)."""

    def test_module_invocation(self, tmp_path: Path) -> None:
        """A `python -m` invocation yields the module name."""
        bat = tmp_path / "run.bat"
        bat.write_text("@echo off\npython -m cli.main %*\n", encoding="utf-8")
        assert _parse_bat_python_target(bat) == "cli.main"

    def test_no_python_command(self, tmp_path: Path) -> None:
        """A file without a python command yields None."""
        bat = tmp_path / "run.bat"
        bat.write_text("@echo off\necho hello\n", encoding="utf-8")
        assert _parse_bat_python_target(bat) is None

    def test_nonexistent_file(self, tmp_path: Path) -> None:
        """A missing file yields None."""
        assert _parse_bat_python_target(tmp_path / "missing.bat") is None
# ---------------------------------------------------------------------------
# 属性测试 — Property 9 & 10hypothesis
# ---------------------------------------------------------------------------
import os
import string
from hypothesis import given, settings, assume
from hypothesis import strategies as st
# ---------------------------------------------------------------------------
# Helper: project package names (must stay in sync with _PROJECT_PACKAGES
# in flow_analyzer).
# ---------------------------------------------------------------------------
_PROJECT_PACKAGES_LIST = [
    "cli", "config", "api", "database", "tasks", "loaders",
    "scd", "orchestration", "quality", "models", "utils",
    "gui", "scripts",
]
# ---------------------------------------------------------------------------
# Property 9: 流程树节点 source_file 有效性
# Feature: repo-audit, Property 9: 流程树节点 source_file 有效性
# Validates: Requirements 2.7
#
# 策略:在临时目录中随机生成 1~5 个项目内部模块文件,
# 其中一个作为入口,其他文件通过 import 语句相互引用。
# 构建流程树后,遍历所有节点验证 source_file 非空且文件存在。
# ---------------------------------------------------------------------------
def _collect_all_nodes(node: FlowNode) -> list[FlowNode]:
"""递归收集流程树中所有节点。"""
result = [node]
for child in node.children:
result.extend(_collect_all_nodes(child))
return result
# Generate legal Python identifiers to use as module file names
# (reject "__init__" so generated modules never collide with package markers).
_module_name_st = st.from_regex(r"[a-z][a-z0-9_]{0,8}", fullmatch=True).filter(
    lambda s: s not in {"__init__", ""}
)
@st.composite
def project_layout(draw):
    """Generate a random project layout: a package name, module file names,
    and the import relations between the modules.

    Returns (package, module_names, imports_map):
    - package: a project package name (e.g. "cli")
    - module_names: module file names (without the .py suffix); the first
      one serves as the entry point
    - imports_map: dict[str, list[str]] mapping each module to the other
      modules it imports
    """
    package = draw(st.sampled_from(_PROJECT_PACKAGES_LIST))
    n_modules = draw(st.integers(min_value=1, max_value=5))
    module_names = draw(
        st.lists(
            _module_name_st,
            min_size=n_modules,
            max_size=n_modules,
            unique=True,
        )
    )
    # Guarantee at least one module.
    assume(len(module_names) >= 1)
    # For each module, pick a random subset of the other modules to import.
    # (Fix: the original iterated with enumerate() but never used the index.)
    imports_map: dict[str, list[str]] = {}
    for mod in module_names:
        # A module may only import the other modules in the list.
        others = [m for m in module_names if m != mod]
        if others:
            imported = draw(
                st.lists(st.sampled_from(others), max_size=len(others), unique=True)
            )
        else:
            imported = []
        imports_map[mod] = imported
    return package, module_names, imports_map
@given(layout=project_layout())
@settings(max_examples=100)
def test_property9_flow_tree_source_file_validity(layout, tmp_path_factory):
    """Property 9: every flow-tree node has a non-empty source_file that exists.

    **Feature: repo-audit, Property 9: 流程树节点 source_file 有效性**
    **Validates: Requirements 2.7**
    """
    package, module_names, imports_map = layout
    # tmp_path_factory is session-scoped, so it is safe to combine with hypothesis.
    tmp_path = tmp_path_factory.mktemp("prop9")
    # Create the package directory and its __init__.py.
    pkg_dir = tmp_path / package
    pkg_dir.mkdir(parents=True, exist_ok=True)
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")
    # Write each module file with its import statements.
    for mod in module_names:
        lines = []
        for imp in imports_map[mod]:
            lines.append(f"from {package}.{imp} import *")
        lines.append("")  # keep the file non-empty
        (pkg_dir / f"{mod}.py").write_text("\n".join(lines), encoding="utf-8")
    # Build the flow tree rooted at the first module.
    entry_rel = f"{package}/{module_names[0]}.py"
    tree = build_flow_tree(tmp_path, entry_rel)
    # Walk every node and validate its source_file.
    all_nodes = _collect_all_nodes(tree)
    for node in all_nodes:
        # source_file must be a non-empty string.
        assert isinstance(node.source_file, str), (
            f"source_file 应为字符串,实际为 {type(node.source_file)}"
        )
        assert node.source_file != "", "source_file 不应为空字符串"
        # The referenced file must actually exist in the repository.
        full_path = tmp_path / node.source_file
        assert full_path.exists(), (
            f"source_file '{node.source_file}' 对应的文件不存在: {full_path}"
        )
# ---------------------------------------------------------------------------
# Property 10: 孤立模块检测正确性
# Feature: repo-audit, Property 10: 孤立模块检测正确性
# Validates: Requirements 2.8
#
# 策略:生成随机的 FileEntry 列表(模拟项目中的 .py 文件),
# 生成随机的 reachable 集合(是 FileEntry 路径的子集),
# 调用 find_orphan_modules 验证:
# 1. 返回的每个孤立模块都不在 reachable 集合中
# 2. reachable 集合中的每个模块都不在孤立列表中
#
# 注意find_orphan_modules 会排除 __init__.py、tests/、scripts/audit/ 下的文件,
# 以及不属于 _PROJECT_PACKAGES 的子目录文件。生成器需要考虑这些排除规则。
# ---------------------------------------------------------------------------
# Packages whose .py files are "eligible" for orphan detection
# (paths that find_orphan_modules would ignore are excluded).
_eligible_packages = [
    p for p in _PROJECT_PACKAGES_LIST
    if p not in ("scripts",)  # only scripts/audit/ is excluded, but drop scripts/ entirely for simplicity
]
@st.composite
def orphan_test_data(draw):
    """Generate (file_entries, reachable_set) for testing find_orphan_modules.

    Only "eligible" entries are produced (inside a project package, not
    __init__.py, not under tests/ or scripts/audit/), so the mutual
    exclusion between the reachable set and the orphan list can be
    checked exactly.
    """
    # Generate 1–10 eligible .py file paths.
    n_files = draw(st.integers(min_value=1, max_value=10))
    paths: list[str] = []
    for _ in range(n_files):
        pkg = draw(st.sampled_from(_eligible_packages))
        fname = draw(_module_name_st)
        path = f"{pkg}/{fname}.py"
        paths.append(path)
    # Deduplicate while preserving order.
    paths = list(dict.fromkeys(paths))
    assume(len(paths) >= 1)
    # Build the FileEntry list.
    entries = [
        FileEntry(rel_path=p, is_dir=False, size_bytes=100, extension=".py", is_empty_dir=False)
        for p in paths
    ]
    # Pick a random subset of the paths as the reachable set.
    reachable = set(draw(
        st.lists(st.sampled_from(paths), max_size=len(paths), unique=True)
    ))
    return entries, reachable
@given(data=orphan_test_data())
@settings(max_examples=100)
def test_property10_orphan_module_detection(data, tmp_path_factory):
    """Property 10: orphan and reachable modules are mutually exclusive —
    no orphan appears in the reachable set and no reachable module appears
    in the orphan list.

    **Feature: repo-audit, Property 10: 孤立模块检测正确性**
    **Validates: Requirements 2.8**
    """
    entries, reachable = data
    tmp_path = tmp_path_factory.mktemp("prop10")
    orphans = find_orphan_modules(tmp_path, entries, reachable)
    orphan_set = set(orphans)
    # Check 1: no orphan module may appear in the reachable set.
    overlap = orphan_set & reachable
    assert overlap == set(), (
        f"孤立模块与可达集合存在交集: {overlap}"
    )
    # Check 2: no reachable module may appear in the orphan list.
    for r in reachable:
        assert r not in orphan_set, (
            f"可达模块 '{r}' 不应出现在孤立列表中"
        )
    # Check 3: the orphan list must be sorted.
    assert orphans == sorted(orphans), "孤立模块列表应按路径排序"

View File

@@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
"""
属性测试 — classify 完整性
Feature: repo-audit, Property 1: classify 完整性
Validates: Requirements 1.2, 1.3
对于任意 FileEntryclassify 函数返回的 InventoryItem 的 category 字段
应属于 Category 枚举disposition 字段应属于 Disposition 枚举,
且 description 字段为非空字符串。
"""
from __future__ import annotations
import string
from hypothesis import given, settings
from hypothesis import strategies as st
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
from scripts.audit.inventory_analyzer import classify
# ---------------------------------------------------------------------------
# Generator strategies
# ---------------------------------------------------------------------------
# Common file extensions (the empty string models a file with no extension).
_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])
# Path segments: alphanumerics plus common special characters.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."
_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
)
# Relative paths 1–4 directory levels deep.
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=4,
).map(lambda parts: "/".join(parts))
def _file_entry_strategy() -> st.SearchStrategy[FileEntry]:
    """Hypothesis strategy producing random FileEntry values.

    Covers a variety of extensions, directory depths, sizes, and
    boolean-flag combinations.
    """
    return st.builds(
        FileEntry,
        rel_path=_rel_path,
        is_dir=st.booleans(),
        size_bytes=st.integers(min_value=0, max_value=10_000_000),
        extension=_EXTENSIONS,
        is_empty_dir=st.booleans(),
    )
# ---------------------------------------------------------------------------
# Property 1: classify 完整性
# ---------------------------------------------------------------------------
@given(entry=_file_entry_strategy())
@settings(max_examples=100)
def test_classify_completeness(entry: FileEntry) -> None:
    """Property 1: classify completeness.

    Feature: repo-audit, Property 1: classify 完整性
    Validates: Requirements 1.2, 1.3

    For any FileEntry, the InventoryItem returned by classify must satisfy:
    - category is a member of the Category enum
    - disposition is a member of the Disposition enum
    - description is a non-empty string
    """
    result = classify(entry)
    # The return type must be correct.
    assert isinstance(result, InventoryItem), (
        f"classify 应返回 InventoryItem实际返回 {type(result)}"
    )
    # category must be a Category enum member.
    assert isinstance(result.category, Category), (
        f"category 应为 Category 枚举成员,实际为 {result.category!r}"
    )
    # disposition must be a Disposition enum member.
    assert isinstance(result.disposition, Disposition), (
        f"disposition 应为 Disposition 枚举成员,实际为 {result.disposition!r}"
    )
    # description must be a non-empty string.
    assert isinstance(result.description, str) and len(result.description) > 0, (
        f"description 应为非空字符串,实际为 {result.description!r}"
    )
# ---------------------------------------------------------------------------
# Helper: high-priority directory prefixes (excluded by the lower-priority
# property tests below).
# ---------------------------------------------------------------------------
_HIGH_PRIORITY_PREFIXES = ("tmp/", "logs/", "export/")
# Safe top-level directory names (never trigger a high-priority rule).
_SAFE_TOP_DIRS = st.sampled_from([
    "src", "lib", "data", "misc", "vendor", "tools", "archive",
    "assets", "resources", "contrib", "extras",
])
# Extensions other than .lnk/.rar.
_SAFE_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])


def _safe_rel_path() -> st.SearchStrategy[str]:
    """Generate relative paths that do not start with a high-priority directory."""
    return st.builds(
        lambda top, rest: f"{top}/{rest}" if rest else top,
        top=_SAFE_TOP_DIRS,
        rest=st.lists(_path_segment, min_size=0, max_size=3).map(
            lambda parts: "/".join(parts) if parts else ""
        ),
    )
# ---------------------------------------------------------------------------
# Property 3: 空目录标记为候选删除
# ---------------------------------------------------------------------------
@given(data=st.data())
@settings(max_examples=100)
def test_empty_dir_candidate_delete(data: st.DataObject) -> None:
    """Property 3: empty directories are marked as deletion candidates.

    Feature: repo-audit, Property 3: 空目录标记为候选删除
    Validates: Requirements 1.5

    For any FileEntry with is_empty_dir=True (excluding paths under tmp/,
    logs/, reports/, export/ and the .lnk/.rar extensions), classify must
    return Disposition.CANDIDATE_DELETE.
    """
    rel_path = data.draw(_safe_rel_path())
    ext = data.draw(_SAFE_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=True,
        size_bytes=0,
        extension=ext,
        is_empty_dir=True,
    )
    result = classify(entry)
    assert result.disposition == Disposition.CANDIDATE_DELETE, (
        f"空目录 '{entry.rel_path}' 应标记为候选删除,"
        f"实际为 {result.disposition.value}"
    )
# ---------------------------------------------------------------------------
# Property 4: .lnk/.rar 文件标记为候选删除
# ---------------------------------------------------------------------------
@given(data=st.data())
@settings(max_examples=100)
def test_lnk_rar_candidate_delete(data: st.DataObject) -> None:
    """Property 4: .lnk/.rar files are marked as deletion candidates.

    Feature: repo-audit, Property 4: .lnk/.rar 文件标记为候选删除
    Validates: Requirements 1.6

    For any FileEntry whose extension is .lnk or .rar (excluding paths
    under tmp/, logs/, reports/, export/, and with is_empty_dir=False),
    classify must return Disposition.CANDIDATE_DELETE.
    """
    rel_path = data.draw(_safe_rel_path())
    ext = data.draw(st.sampled_from([".lnk", ".rar"]))
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=False,
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=False,
    )
    result = classify(entry)
    assert result.disposition == Disposition.CANDIDATE_DELETE, (
        f"文件 '{entry.rel_path}' (ext={ext}) 应标记为候选删除,"
        f"实际为 {result.disposition.value}"
    )
# ---------------------------------------------------------------------------
# Property 5: dispositions allowed for entries under tmp/
# ---------------------------------------------------------------------------
_TMP_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js", ".tmp", ".bak",
])
def _tmp_rel_path() -> st.SearchStrategy[str]:
    """Generate a relative path that starts with tmp/."""
    return st.builds(
        lambda rest: f"tmp/{rest}",
        rest=st.lists(_path_segment, min_size=1, max_size=3).map(
            lambda parts: "/".join(parts)
        ),
    )
@given(data=st.data())
@settings(max_examples=100)
def test_tmp_disposition_range(data: st.DataObject) -> None:
    """Property 5: dispositions allowed under tmp/.

    Feature: repo-audit, Property 5
    Validates: Requirements 1.7

    For any FileEntry whose rel_path starts with tmp/, classify must return
    either CANDIDATE_DELETE or CANDIDATE_ARCHIVE.
    """
    rel_path = data.draw(_tmp_rel_path())
    ext = data.draw(_TMP_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=data.draw(st.booleans()),
    )
    result = classify(entry)
    allowed = {Disposition.CANDIDATE_DELETE, Disposition.CANDIDATE_ARCHIVE}
    assert result.disposition in allowed, (
        f"tmp/ 下文件 '{entry.rel_path}' 的处置应为候选删除或候选归档,"
        f"实际为 {result.disposition.value}"
    )
# ---------------------------------------------------------------------------
# Property 6: runtime output directories are marked candidate-archive
# ---------------------------------------------------------------------------
_RUNTIME_DIRS = st.sampled_from(["logs", "export"])
# Basenames other than __init__.py (which has its own rule).
_NON_INIT_BASENAME = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
).filter(lambda s: s != "__init__.py")
def _runtime_output_rel_path() -> st.SearchStrategy[str]:
    """Generate a path under logs/ or export/ whose basename is not __init__.py."""
    return st.builds(
        lambda top, mid, name: (
            f"{top}/{'/'.join(mid)}/{name}" if mid else f"{top}/{name}"
        ),
        top=_RUNTIME_DIRS,
        mid=st.lists(_path_segment, min_size=0, max_size=2),
        name=_NON_INIT_BASENAME,
    )
@given(data=st.data())
@settings(max_examples=100)
def test_runtime_output_candidate_archive(data: st.DataObject) -> None:
    """Property 6: runtime output directories are marked candidate-archive.

    Feature: repo-audit, Property 6
    Validates: Requirements 1.8

    For any FileEntry whose rel_path starts with logs/ or export/ and whose
    basename is not __init__.py, classify must return CANDIDATE_ARCHIVE.
    Requirement 1.8 covers only logs/ and export/ (not reports/).
    """
    rel_path = data.draw(_runtime_output_rel_path())
    ext = data.draw(_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=data.draw(st.booleans()),
    )
    result = classify(entry)
    assert result.disposition == Disposition.CANDIDATE_ARCHIVE, (
        f"运行时产出 '{entry.rel_path}' 应标记为候选归档,"
        f"实际为 {result.disposition.value}"
    )

View File

@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
"""
Property tests — inventory rendering completeness and category grouping.

Feature: repo-audit
- Property 2: inventory rendering completeness
- Property 8: inventory grouped by category
Validates: Requirements 1.4, 1.10
"""
from __future__ import annotations
import string
from hypothesis import given, settings
from hypothesis import strategies as st
from scripts.audit import Category, Disposition, InventoryItem
from scripts.audit.inventory_analyzer import render_inventory_report
# ---------------------------------------------------------------------------
# Generator strategies
# ---------------------------------------------------------------------------
_PATH_CHARS = string.ascii_letters + string.digits + "_-."
_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=15,
)
# Random relative path, 1-3 segments deep.
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=3,
).map(lambda parts: "/".join(parts))
# Random non-empty description (no pipe or newline characters, which would
# break Markdown table parsing).
_description = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=40,
)
def _inventory_item_strategy() -> st.SearchStrategy[InventoryItem]:
    """Hypothesis strategy producing random InventoryItem instances."""
    return st.builds(
        InventoryItem,
        rel_path=_rel_path,
        category=st.sampled_from(list(Category)),
        disposition=st.sampled_from(list(Disposition)),
        description=_description,
    )
# A list of 0-20 InventoryItem instances.
_inventory_list = st.lists(
    _inventory_item_strategy(),
    min_size=0,
    max_size=20,
)
# ---------------------------------------------------------------------------
# Property 2: inventory rendering completeness
# ---------------------------------------------------------------------------
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_completeness(items: list[InventoryItem]) -> None:
    """Property 2: inventory rendering completeness.

    Feature: repo-audit, Property 2
    Validates: Requirements 1.4

    For any list of InventoryItem, the Markdown produced by
    render_inventory_report must contain every item's rel_path,
    category.value, disposition.value and description.
    """
    report = render_inventory_report(items, "/tmp/test-repo")
    for item in items:
        # rel_path appears in a table row
        assert item.rel_path in report, (
            f"rel_path '{item.rel_path}' 未出现在报告中"
        )
        # category.value appears in a group heading
        assert item.category.value in report, (
            f"category '{item.category.value}' 未出现在报告中"
        )
        # disposition.value appears in a table row
        assert item.disposition.value in report, (
            f"disposition '{item.disposition.value}' 未出现在报告中"
        )
        # description appears in a table row
        assert item.description in report, (
            f"description '{item.description}' 未出现在报告中"
        )
# ---------------------------------------------------------------------------
# Property 8: inventory grouped by category
# ---------------------------------------------------------------------------
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_grouped_by_category(items: list[InventoryItem]) -> None:
    """Property 8: inventory grouped by category.

    Feature: repo-audit, Property 8
    Validates: Requirements 1.10

    For any list of InventoryItem, entries of the same Category must appear
    contiguously in the Markdown produced by render_inventory_report (never
    interleaved with entries of another Category).
    """
    report = render_inventory_report(items, "/tmp/test-repo")
    if not items:
        return  # nothing to verify for an empty list
    # Recover the category order of the rendered entries, line by line.
    # Table row format:     | `{rel_path}` | {disposition} | {description} |
    # Group heading format: ## {category.value}
    lines = report.split("\n")
    # Record, in order of appearance, the category each data row belongs to.
    categories_in_order: list[Category] = []
    current_category: Category | None = None
    # Map category.value -> Category for heading lookup.
    value_to_cat = {c.value: c for c in Category}
    for line in lines:
        stripped = line.strip()
        # Detect a group heading "## {category.value}".
        if stripped.startswith("## ") and stripped[3:] in value_to_cat:
            current_category = value_to_cat[stripped[3:]]
            continue
        # Detect a table data row (skip the header and separator rows).
        if (
            current_category is not None
            and stripped.startswith("| `")
            and not stripped.startswith("| 相对路径")
            and not stripped.startswith("|---")
        ):
            categories_in_order.append(current_category)
    # Entries of the same Category must appear contiguously.
    seen: set[Category] = set()
    prev: Category | None = None
    for cat in categories_in_order:
        if cat != prev:
            assert cat not in seen, (
                f"Category '{cat.value}' 的条目不连续——"
                f"在其他分类条目之后再次出现"
            )
        seen.add(cat)
        prev = cat

View File

@@ -0,0 +1,485 @@
# -*- coding: utf-8 -*-
"""
属性测试 — 报告输出属性
Feature: repo-audit
- Property 13: 统计摘要一致性
- Property 14: 报告头部元信息
- Property 15: 写操作仅限 docs/audit/
Validates: Requirements 4.2, 4.5, 4.6, 4.7, 5.2
"""
from __future__ import annotations
import os
import re
import string
from pathlib import Path
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from scripts.audit import (
AlignmentIssue,
Category,
Disposition,
DocMapping,
FlowNode,
InventoryItem,
)
from scripts.audit.inventory_analyzer import render_inventory_report
from scripts.audit.flow_analyzer import render_flow_report
from scripts.audit.doc_alignment_analyzer import render_alignment_report
# ---------------------------------------------------------------------------
# 共享生成器策略
# ---------------------------------------------------------------------------
_PATH_CHARS = string.ascii_letters + string.digits + "_-."
_path_segment = st.text(
alphabet=_PATH_CHARS,
min_size=1,
max_size=12,
)
_rel_path = st.lists(
_path_segment,
min_size=1,
max_size=3,
).map(lambda parts: "/".join(parts))
_safe_text = st.text(
alphabet=st.characters(
whitelist_categories=("L", "N", "P", "S", "Z"),
blacklist_characters="|\n\r",
),
min_size=1,
max_size=30,
)
_repo_root_str = st.text(
alphabet=string.ascii_letters + string.digits + "/_-.",
min_size=3,
max_size=40,
).map(lambda s: "/" + s.lstrip("/"))
# ---------------------------------------------------------------------------
# InventoryItem generator
# ---------------------------------------------------------------------------
def _inventory_item_st() -> st.SearchStrategy[InventoryItem]:
    """Hypothesis strategy producing random InventoryItem instances."""
    return st.builds(
        InventoryItem,
        rel_path=_rel_path,
        category=st.sampled_from(list(Category)),
        disposition=st.sampled_from(list(Disposition)),
        description=_safe_text,
    )
_inventory_list = st.lists(_inventory_item_st(), min_size=0, max_size=20)
# ---------------------------------------------------------------------------
# FlowNode generator (depth and width limited)
# ---------------------------------------------------------------------------
def _flow_node_st(max_depth: int = 2) -> st.SearchStrategy[FlowNode]:
    """Generate a random FlowNode tree, depth-limited to avoid blow-up."""
    if max_depth <= 0:
        # Leaf node: no children.
        return st.builds(
            FlowNode,
            name=_path_segment,
            source_file=_rel_path,
            node_type=st.sampled_from(["entry", "module", "class", "function"]),
            children=st.just([]),
        )
    return st.builds(
        FlowNode,
        name=_path_segment,
        source_file=_rel_path,
        node_type=st.sampled_from(["entry", "module", "class", "function"]),
        children=st.lists(
            _flow_node_st(max_depth - 1),
            min_size=0,
            max_size=3,
        ),
    )
_flow_tree_list = st.lists(_flow_node_st(), min_size=0, max_size=5)
_orphan_list = st.lists(_rel_path, min_size=0, max_size=10)
# ---------------------------------------------------------------------------
# DocMapping / AlignmentIssue generators
# ---------------------------------------------------------------------------
_issue_type_st = st.sampled_from(["stale", "conflict", "missing"])
def _alignment_issue_st() -> st.SearchStrategy[AlignmentIssue]:
    """Hypothesis strategy producing random AlignmentIssue instances."""
    return st.builds(
        AlignmentIssue,
        doc_path=_rel_path,
        issue_type=_issue_type_st,
        description=_safe_text,
        related_code=_rel_path,
    )
def _doc_mapping_st() -> st.SearchStrategy[DocMapping]:
    """Hypothesis strategy producing random DocMapping instances."""
    return st.builds(
        DocMapping,
        doc_path=_rel_path,
        doc_topic=_safe_text,
        related_code=st.lists(_rel_path, min_size=0, max_size=5),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )
_mapping_list = st.lists(_doc_mapping_st(), min_size=0, max_size=15)
_issue_list = st.lists(_alignment_issue_st(), min_size=0, max_size=15)
# ===========================================================================
# Property 13: summary statistics consistency
# ===========================================================================
class TestProperty13SummaryConsistency:
    """Property 13: summary statistics consistency.

    Feature: repo-audit, Property 13
    Validates: Requirements 4.5, 4.6, 4.7

    For any report's summary section, the per-category / per-label counts
    must sum to the length of the corresponding item list.
    """
    # --- 13a: sum of category counts in render_inventory_report == list length ---
    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_category_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13. Validates: Requirements 4.5.

        The per-category counts in the render_inventory_report summary must
        sum to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")
        # Locate the "按用途分类" table, pull each row's count and sum them.
        cat_sum = _extract_summary_total(report, "按用途分类")
        assert cat_sum == len(items), (
            f"分类计数之和 {cat_sum} != 条目总数 {len(items)}"
        )
    # --- 13b: sum of disposition counts in render_inventory_report == list length ---
    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_disposition_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13. Validates: Requirements 4.5.

        The per-disposition counts in the render_inventory_report summary
        must sum to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")
        disp_sum = _extract_summary_total(report, "按处置标签")
        assert disp_sum == len(items), (
            f"处置标签计数之和 {disp_sum} != 条目总数 {len(items)}"
        )
    # --- 13c: orphan-module count in render_flow_report == len(orphans) ---
    @given(trees=_flow_tree_list, orphans=_orphan_list)
    @settings(max_examples=100)
    def test_flow_orphan_count_matches(
        self, trees: list[FlowNode], orphans: list[str]
    ) -> None:
        """Feature: repo-audit, Property 13. Validates: Requirements 4.6.

        The orphan-module count in the render_flow_report summary must equal
        the length of the orphans list.
        """
        report = render_flow_report(trees, orphans, "/tmp/repo")
        # Pull the "孤立模块" number from the summary table.
        orphan_count = _extract_flow_stat(report, "孤立模块")
        assert orphan_count == len(orphans), (
            f"报告中孤立模块数 {orphan_count} != orphans 列表长度 {len(orphans)}"
        )
    # --- 13d: issue-type counts in render_alignment_report are consistent ---
    @given(mappings=_mapping_list, issues=_issue_list)
    @settings(max_examples=100)
    def test_alignment_issue_counts_match(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue]
    ) -> None:
        """Feature: repo-audit, Property 13. Validates: Requirements 4.7.

        The stale/conflict/missing counts in the render_alignment_report
        summary must match the actual number of issues of each type.
        """
        report = render_alignment_report(mappings, issues, "/tmp/repo")
        expected_stale = sum(1 for i in issues if i.issue_type == "stale")
        expected_conflict = sum(1 for i in issues if i.issue_type == "conflict")
        expected_missing = sum(1 for i in issues if i.issue_type == "missing")
        actual_stale = _extract_alignment_stat(report, "过期点数量")
        actual_conflict = _extract_alignment_stat(report, "冲突点数量")
        actual_missing = _extract_alignment_stat(report, "缺失点数量")
        assert actual_stale == expected_stale, (
            f"过期点: 报告 {actual_stale} != 实际 {expected_stale}"
        )
        assert actual_conflict == expected_conflict, (
            f"冲突点: 报告 {actual_conflict} != 实际 {expected_conflict}"
        )
        assert actual_missing == expected_missing, (
            f"缺失点: 报告 {actual_missing} != 实际 {expected_missing}"
        )
# ===========================================================================
# Property 14: report header metadata
# ===========================================================================
class TestProperty14ReportHeader:
    """Property 14: report header metadata.

    Feature: repo-audit, Property 14
    Validates: Requirements 4.2

    The header of any rendered report must contain an ISO-formatted
    timestamp string and the repository root path string.
    """
    # Matches timestamps like 2025-01-01T00:00:00Z.
    _ISO_TS_RE = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
    @given(items=_inventory_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_inventory_report_header(
        self, items: list[InventoryItem], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14. Validates: Requirements 4.2.

        The render_inventory_report header must contain an ISO timestamp and
        the repository path.
        """
        report = render_inventory_report(items, repo_root)
        header = report[:500]
        assert self._ISO_TS_RE.search(header), (
            "inventory 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"inventory 报告头部缺少仓库路径 '{repo_root}'"
        )
    @given(trees=_flow_tree_list, orphans=_orphan_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_flow_report_header(
        self, trees: list[FlowNode], orphans: list[str], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14. Validates: Requirements 4.2.

        The render_flow_report header must contain an ISO timestamp and the
        repository path.
        """
        report = render_flow_report(trees, orphans, repo_root)
        header = report[:500]
        assert self._ISO_TS_RE.search(header), (
            "flow 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"flow 报告头部缺少仓库路径 '{repo_root}'"
        )
    @given(mappings=_mapping_list, issues=_issue_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_alignment_report_header(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14. Validates: Requirements 4.2.

        The render_alignment_report header must contain an ISO timestamp and
        the repository path.
        """
        report = render_alignment_report(mappings, issues, repo_root)
        header = report[:500]
        assert self._ISO_TS_RE.search(header), (
            "alignment 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"alignment 报告头部缺少仓库路径 '{repo_root}'"
        )
# ===========================================================================
# Property 15: writes confined to docs/audit/
# ===========================================================================
class TestProperty15WritesOnlyDocsAudit:
    """Property 15: writes confined to docs/audit/.

    Feature: repo-audit, Property 15
    Validates: Requirements 5.2

    For any audit run, every file-write target path must have docs/audit/
    as a prefix.  A real filesystem is needed, so fewer iterations are used.
    """
    @staticmethod
    def _make_minimal_repo(base: Path, variant: int) -> Path:
        """Build a minimal repo layout; *variant* varies the extra files."""
        repo = base / f"repo_{variant}"
        repo.mkdir()
        # Required cli entry point.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )
        # config directory.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        # Add different extra files depending on the variant.
        if variant % 3 == 0:
            (repo / "README.md").write_text("# 项目\n", encoding="utf-8")
        if variant % 3 == 1:
            scripts_dir = repo / "scripts"
            scripts_dir.mkdir()
            (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        if variant % 3 == 2:
            (docs_dir / "notes.md").write_text("# 笔记\n", encoding="utf-8")
        return repo
    @staticmethod
    def _snapshot_files(repo: Path) -> dict[str, float]:
        """Snapshot the mtime of every file in the repo (docs/audit/ excluded)."""
        snap: dict[str, float] = {}
        for p in repo.rglob("*"):
            if p.is_file():
                rel = p.relative_to(repo).as_posix()
                if not rel.startswith("docs/audit"):
                    snap[rel] = p.stat().st_mtime
        return snap
    @given(variant=st.integers(min_value=0, max_value=9))
    @settings(max_examples=10)
    def test_writes_only_under_docs_audit(self, variant: int) -> None:
        """Feature: repo-audit, Property 15. Validates: Requirements 5.2.

        After run_audit, no new file may exist outside docs/audit/, and
        docs/audit/ itself must contain report files.
        """
        import tempfile
        from scripts.audit.run_audit import run_audit
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)
            repo = self._make_minimal_repo(tmp_path, variant)
            before_snap = self._snapshot_files(repo)
            run_audit(repo)
            # docs/audit/ must now exist and contain at least one report.
            audit_dir = repo / "docs" / "audit"
            assert audit_dir.is_dir(), "docs/audit/ 目录未创建"
            audit_files = list(audit_dir.iterdir())
            assert len(audit_files) > 0, "docs/audit/ 下无报告文件"
            # No new file may appear outside docs/audit/.
            # NOTE(review): before_snap also stores mtimes, but they are never
            # compared, so in-place modifications are not detected here.
            for p in repo.rglob("*"):
                if p.is_file():
                    rel = p.relative_to(repo).as_posix()
                    if rel.startswith("docs/audit"):
                        continue
                    assert rel in before_snap, (
                        f"docs/audit/ 外出现了新文件: {rel}"
                    )
# ===========================================================================
# 辅助函数 — 从报告文本中提取统计数字
# ===========================================================================
def _extract_summary_total(report: str, section_name: str) -> int:
"""从 inventory 报告的统计摘要中提取指定分区的数字之和。
查找 "### {section_name}" 下的 Markdown 表格,
累加每行最后一列的数字(排除合计行)。
"""
lines = report.split("\n")
in_section = False
total = 0
for line in lines:
stripped = line.strip()
if stripped == f"### {section_name}":
in_section = True
continue
if in_section and stripped.startswith("###"):
# 进入下一个子节
break
if in_section and stripped.startswith("|") and "**合计**" not in stripped:
# 跳过表头和分隔行
if stripped.startswith("| 用途分类") or stripped.startswith("| 处置标签"):
continue
if stripped.startswith("|---"):
continue
# 提取最后一列的数字
cells = [c.strip() for c in stripped.split("|") if c.strip()]
if cells:
try:
total += int(cells[-1])
except ValueError:
pass
return total
def _extract_flow_stat(report: str, label: str) -> int:
"""从 flow 报告统计摘要表格中提取指定指标的数字。"""
# 匹配 "| 孤立模块 | 5 |" 格式
pattern = re.compile(rf"\|\s*{re.escape(label)}\s*\|\s*(\d+)\s*\|")
m = pattern.search(report)
return int(m.group(1)) if m else -1
def _extract_alignment_stat(report: str, label: str) -> int:
"""从 alignment 报告统计摘要中提取指定指标的数字。
匹配 "- 过期点数量3" 格式。
"""
# 兼容全角/半角冒号
pattern = re.compile(rf"{re.escape(label)}[:]\s*(\d+)")
m = pattern.search(report)
return int(m.group(1)) if m else -1

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
"""
Unit tests for the run_audit entry point.

Verifies:
- docs/audit/ directory is created automatically
- the three report files are generated correctly
- report headers carry the timestamp and repository path
- RuntimeError is raised when directory creation fails
"""
from __future__ import annotations
import os
import re
from pathlib import Path
import pytest
class TestEnsureReportDir:
    """Tests for the _ensure_report_dir directory-creation logic."""
    def test_creates_dir_when_missing(self, tmp_path: Path):
        from scripts.audit.run_audit import _ensure_report_dir
        result = _ensure_report_dir(tmp_path)
        expected = tmp_path / "docs" / "audit" / "repo"
        assert result == expected
        assert expected.is_dir()
    def test_returns_existing_dir(self, tmp_path: Path):
        from scripts.audit.run_audit import _ensure_report_dir
        audit_dir = tmp_path / "docs" / "audit" / "repo"
        audit_dir.mkdir(parents=True)
        result = _ensure_report_dir(tmp_path)
        assert result == audit_dir
    def test_raises_on_creation_failure(self, tmp_path: Path):
        from scripts.audit.run_audit import _ensure_report_dir
        # Drop a regular file at docs/audit so the mkdir call must fail.
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "audit").write_text("block", encoding="utf-8")
        with pytest.raises(RuntimeError, match="无法创建报告输出目录"):
            _ensure_report_dir(tmp_path)
class TestInjectHeader:
    """Tests for the _inject_header fallback-injection logic."""
    def test_skips_when_header_present(self):
        from scripts.audit.run_audit import _inject_header
        report = "# 标题\n\n- 生成时间: 2025-01-01T00:00:00Z\n- 仓库路径: `/repo`\n"
        result = _inject_header(report, "2025-06-01T00:00:00Z", "/other")
        # An existing header must be left untouched.
        assert result == report
    def test_injects_when_header_missing(self):
        from scripts.audit.run_audit import _inject_header
        report = "# 无头部报告\n\n内容..."
        result = _inject_header(report, "2025-06-01T00:00:00Z", "/repo")
        assert "生成时间: 2025-06-01T00:00:00Z" in result
        assert "仓库路径: `/repo`" in result
class TestRunAudit:
    """End-to-end run_audit tests using a minimal repository layout."""
    def _make_minimal_repo(self, tmp_path: Path) -> Path:
        """Build the smallest repo layout that lets run_audit complete."""
        repo = tmp_path / "repo"
        repo.mkdir()
        # Core code directory.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )
        # config directory.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        (config_dir / "defaults.py").write_text("DEFAULTS = {}\n", encoding="utf-8")
        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        (docs_dir / "README.md").write_text("# 文档\n", encoding="utf-8")
        # Root-level file.
        (repo / "README.md").write_text("# 项目\n", encoding="utf-8")
        return repo
    def test_creates_three_reports(self, tmp_path: Path):
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)
        audit_dir = repo / "docs" / "audit" / "repo"
        assert (audit_dir / "file_inventory.md").is_file()
        assert (audit_dir / "flow_tree.md").is_file()
        assert (audit_dir / "doc_alignment.md").is_file()
    def test_reports_contain_timestamp(self, tmp_path: Path):
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)
        audit_dir = repo / "docs" / "audit" / "repo"
        # ISO timestamp format.
        ts_pattern = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert ts_pattern.search(content), f"{name} 缺少时间戳"
    def test_reports_contain_repo_path(self, tmp_path: Path):
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)
        audit_dir = repo / "docs" / "audit" / "repo"
        repo_str = str(repo.resolve())
        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert repo_str in content, f"{name} 缺少仓库路径"
    def test_writes_only_to_docs_audit(self, tmp_path: Path):
        """All write operations must stay inside docs/audit/ (Property 15)."""
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        # Snapshot files before the run (docs/audit/ excluded).
        # NOTE(review): the mtime captured here is never compared afterwards,
        # so this test only detects newly created files, not modifications.
        before = set()
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if not rel.startswith("docs/audit"):
                before.add((rel, p.stat().st_mtime if p.is_file() else None))
        run_audit(repo)
        # After the run, no file outside docs/audit/ may be new.
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if rel.startswith("docs/audit"):
                continue
            if p.is_file():
                # The file must have existed in the pre-run snapshot.
                found = any(r == rel for r, _ in before)
                assert found, f"意外创建了 docs/audit/ 外的文件: {rel}"
    def test_auto_creates_docs_audit_dir(self, tmp_path: Path):
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        # Make sure docs/audit/ does not exist yet.
        audit_dir = repo / "docs" / "audit" / "repo"
        assert not audit_dir.exists()
        run_audit(repo)
        assert audit_dir.is_dir()

View File

@@ -0,0 +1,428 @@
# -*- coding: utf-8 -*-
"""
Unit tests — repository scanner (scanner.py).

Covers:
- exclusion-pattern matching logic
- recursive traversal and FileEntry construction
- empty-directory detection
- tolerance of permission errors
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from scripts.audit import FileEntry
from scripts.audit.scanner import EXCLUDED_PATTERNS, _is_excluded, scan_repo
# ---------------------------------------------------------------------------
# _is_excluded unit tests
# ---------------------------------------------------------------------------
class TestIsExcluded:
    """Tests for the exclusion-pattern matching logic."""
    def test_exact_match_git(self) -> None:
        assert _is_excluded(".git", EXCLUDED_PATTERNS) is True
    def test_exact_match_pycache(self) -> None:
        assert _is_excluded("__pycache__", EXCLUDED_PATTERNS) is True
    def test_exact_match_pytest_cache(self) -> None:
        assert _is_excluded(".pytest_cache", EXCLUDED_PATTERNS) is True
    def test_exact_match_kiro(self) -> None:
        assert _is_excluded(".kiro", EXCLUDED_PATTERNS) is True
    def test_wildcard_pyc(self) -> None:
        assert _is_excluded("module.pyc", EXCLUDED_PATTERNS) is True
    def test_normal_py_not_excluded(self) -> None:
        assert _is_excluded("main.py", EXCLUDED_PATTERNS) is False
    def test_normal_dir_not_excluded(self) -> None:
        assert _is_excluded("src", EXCLUDED_PATTERNS) is False
    def test_empty_patterns(self) -> None:
        # With no patterns, nothing is excluded — not even .git.
        assert _is_excluded(".git", []) is False
    def test_custom_pattern(self) -> None:
        assert _is_excluded("data.csv", ["*.csv"]) is True
# ---------------------------------------------------------------------------
# scan_repo unit tests
# ---------------------------------------------------------------------------
class TestScanRepo:
    """Recursive-traversal tests for scan_repo."""
    def test_basic_structure(self, tmp_path: Path) -> None:
        """Plain files and directories are scanned correctly."""
        (tmp_path / "a.py").write_text("# code", encoding="utf-8")
        sub = tmp_path / "sub"
        sub.mkdir()
        (sub / "b.txt").write_text("hello", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "a.py" in paths
        assert "sub" in paths
        assert "sub/b.txt" in paths
    def test_file_entry_fields(self, tmp_path: Path) -> None:
        """Every FileEntry field is populated correctly for a file."""
        (tmp_path / "hello.md").write_text("# hi", encoding="utf-8")
        entries = scan_repo(tmp_path)
        md = next(e for e in entries if e.rel_path == "hello.md")
        assert md.is_dir is False
        assert md.size_bytes > 0
        assert md.extension == ".md"
        assert md.is_empty_dir is False
    def test_directory_entry_fields(self, tmp_path: Path) -> None:
        """Directory entries get the expected field values."""
        sub = tmp_path / "mydir"
        sub.mkdir()
        (sub / "file.py").write_text("pass", encoding="utf-8")
        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "mydir")
        assert d.is_dir is True
        assert d.size_bytes == 0
        assert d.extension == ""
        assert d.is_empty_dir is False
    def test_excluded_git_dir(self, tmp_path: Path) -> None:
        """The .git directory and its contents are excluded."""
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert ".git" not in paths
        assert ".git/config" not in paths
    def test_excluded_pycache(self, tmp_path: Path) -> None:
        """__pycache__ directories are excluded."""
        cache = tmp_path / "pkg" / "__pycache__"
        cache.mkdir(parents=True)
        (cache / "mod.cpython-310.pyc").write_bytes(b"\x00")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert not any("__pycache__" in p for p in paths)
    def test_excluded_pyc_files(self, tmp_path: Path) -> None:
        """*.pyc files are excluded."""
        (tmp_path / "mod.pyc").write_bytes(b"\x00")
        (tmp_path / "mod.py").write_text("pass", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "mod.pyc" not in paths
        assert "mod.py" in paths
    def test_empty_directory_detection(self, tmp_path: Path) -> None:
        """Empty directories are flagged with is_empty_dir=True."""
        (tmp_path / "empty").mkdir()
        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "empty")
        assert d.is_dir is True
        assert d.is_empty_dir is True
    def test_dir_with_only_excluded_children(self, tmp_path: Path) -> None:
        """A directory whose only children are excluded counts as empty."""
        sub = tmp_path / "pkg"
        sub.mkdir()
        cache = sub / "__pycache__"
        cache.mkdir()
        (cache / "x.pyc").write_bytes(b"\x00")
        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "pkg")
        assert d.is_empty_dir is True
    def test_custom_exclude_patterns(self, tmp_path: Path) -> None:
        """Caller-supplied exclude patterns take effect."""
        (tmp_path / "keep.py").write_text("pass", encoding="utf-8")
        (tmp_path / "skip.log").write_text("log", encoding="utf-8")
        entries = scan_repo(tmp_path, exclude=["*.log"])
        paths = {e.rel_path for e in entries}
        assert "keep.py" in paths
        assert "skip.log" not in paths
    def test_empty_repo(self, tmp_path: Path) -> None:
        """An empty repository yields an empty list."""
        entries = scan_repo(tmp_path)
        assert entries == []
    def test_results_sorted(self, tmp_path: Path) -> None:
        """Results come back sorted by rel_path."""
        (tmp_path / "z.py").write_text("", encoding="utf-8")
        (tmp_path / "a.py").write_text("", encoding="utf-8")
        sub = tmp_path / "m"
        sub.mkdir()
        (sub / "b.py").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = [e.rel_path for e in entries]
        assert paths == sorted(paths)
    @pytest.mark.skipif(
        os.name == "nt",
        reason="Windows 上 chmod 行为不同,跳过权限测试",
    )
    def test_permission_error_skipped(self, tmp_path: Path) -> None:
        """Unreadable directories are skipped without aborting the scan."""
        ok_file = tmp_path / "ok.py"
        ok_file.write_text("pass", encoding="utf-8")
        no_access = tmp_path / "secret"
        no_access.mkdir()
        (no_access / "data.txt").write_text("x", encoding="utf-8")
        no_access.chmod(0o000)
        try:
            entries = scan_repo(tmp_path)
            paths = {e.rel_path for e in entries}
            # ok.py must still be scanned.
            assert "ok.py" in paths
            # The secret directory itself is recorded (_walk records the
            # directory before attempting iterdir), but its children must
            # not appear.
            assert "secret/data.txt" not in paths
        finally:
            no_access.chmod(0o755)
    def test_nested_directories(self, tmp_path: Path) -> None:
        """Deeply nested directories are fully traversed."""
        deep = tmp_path / "a" / "b" / "c"
        deep.mkdir(parents=True)
        (deep / "leaf.py").write_text("pass", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "a" in paths
        assert "a/b" in paths
        assert "a/b/c" in paths
        assert "a/b/c/leaf.py" in paths
    def test_extension_lowercase(self, tmp_path: Path) -> None:
        """Extensions are normalized to lowercase."""
        (tmp_path / "README.MD").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        md = next(e for e in entries if "README" in e.rel_path)
        assert md.extension == ".md"
    def test_no_extension(self, tmp_path: Path) -> None:
        """Files without an extension get an empty extension string."""
        (tmp_path / "Makefile").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        f = next(e for e in entries if e.rel_path == "Makefile")
        assert f.extension == ""
    def test_root_not_in_entries(self, tmp_path: Path) -> None:
        """The root directory itself never appears in the results."""
        (tmp_path / "a.py").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "." not in paths
        assert "" not in paths
# ---------------------------------------------------------------------------
# 属性测试 — Property 7: 扫描器排除规则
# Feature: repo-audit, Property 7: 扫描器排除规则
# Validates: Requirements 1.1
# ---------------------------------------------------------------------------
import fnmatch
import string
import tempfile
from hypothesis import given, settings
from hypothesis import strategies as st
# --- Generator strategies ---
# Characters legal in generated file/directory names
# (path separators and special characters excluded).
_SAFE_CHARS = string.ascii_lowercase + string.digits + "_-"
# Strategy for safe, ordinary names that never collide with exclusion patterns.
_safe_name = st.text(_SAFE_CHARS, min_size=1, max_size=8)
# Directory names appearing in the exclusion patterns.
_EXCLUDED_DIR_NAMES = [".git", "__pycache__", ".pytest_cache", ".kiro"]
# File extension covered by the exclusion patterns.
# NOTE(review): _EXCLUDED_FILE_EXT is not referenced in the visible tests —
# _build_tree hard-codes "module.pyc"; confirm whether it is still needed.
_EXCLUDED_FILE_EXT = ".pyc"
# Strategy that picks one excluded directory name at random.
_excluded_dir_name = st.sampled_from(_EXCLUDED_DIR_NAMES)
def _build_tree(tmp: Path, normal_names: list[str], excluded_dirs: list[str],
include_pyc: bool) -> None:
"""在临时目录中构建包含正常文件和被排除条目的文件树。"""
# 创建正常文件
for name in normal_names:
safe = name or "f"
filepath = tmp / f"{safe}.txt"
if not filepath.exists():
filepath.write_text("ok", encoding="utf-8")
# 创建被排除的目录(含子文件)
for dirname in excluded_dirs:
d = tmp / dirname
d.mkdir(exist_ok=True)
(d / "inner.txt").write_text("hidden", encoding="utf-8")
# 可选:创建 .pyc 文件
if include_pyc:
(tmp / "module.pyc").write_bytes(b"\x00")
class TestProperty7ScannerExclusionRules:
    """
    Property 7: scanner exclusion rules.

    For any file tree, the FileEntry list returned by scan_repo must not
    contain any entry whose rel_path matches an exclusion pattern
    (.git, __pycache__, .pytest_cache, ...).

    Feature: repo-audit, Property 7: scanner exclusion rules
    Validates: Requirements 1.1
    """
    @given(
        normal_names=st.lists(_safe_name, min_size=0, max_size=5),
        excluded_dirs=st.lists(_excluded_dir_name, min_size=1, max_size=3),
        include_pyc=st.booleans(),
    )
    @settings(max_examples=100)
    def test_excluded_entries_never_in_results(
        self,
        normal_names: list[str],
        excluded_dirs: list[str],
        include_pyc: bool,
    ) -> None:
        """Scan results must contain no entry matching an exclusion pattern."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            _build_tree(tmp, normal_names, excluded_dirs, include_pyc)
            entries = scan_repo(tmp)
            for entry in entries:
                # Check every segment of rel_path against the exclusion patterns.
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in EXCLUDED_PATTERNS:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )
    @given(
        excluded_dir=_excluded_dir_name,
        depth=st.integers(min_value=1, max_value=3),
    )
    @settings(max_examples=100)
    def test_excluded_dirs_at_any_depth(
        self,
        excluded_dir: str,
        depth: int,
    ) -> None:
        """Excluded directories are filtered out at any nesting depth."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            # Build a nested path: level0/level1/.../excluded_dir/file.txt
            current = tmp
            for i in range(depth):
                current = current / f"level{i}"
                current.mkdir(exist_ok=True)
                # Drop a normal file so each parent directory is non-empty.
                (current / "keep.txt").write_text("ok", encoding="utf-8")
            # Place the excluded directory at the deepest level.
            excluded = current / excluded_dir
            excluded.mkdir(exist_ok=True)
            (excluded / "secret.txt").write_text("hidden", encoding="utf-8")
            entries = scan_repo(tmp)
            for entry in entries:
                parts = entry.rel_path.split("/")
                assert excluded_dir not in parts, (
                    f"被排除目录 '{excluded_dir}' 不应出现在结果中,"
                    f"但发现 rel_path='{entry.rel_path}'"
                )
    @given(
        custom_patterns=st.lists(
            st.sampled_from(["*.log", "*.tmp", "*.bak", "node_modules", ".venv"]),
            min_size=1,
            max_size=3,
        ),
    )
    @settings(max_examples=100)
    def test_custom_exclude_patterns_respected(
        self,
        custom_patterns: list[str],
    ) -> None:
        """User-supplied exclusion patterns are honoured by scan_repo."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            # One ordinary file that must always survive the scan.
            (tmp / "main.py").write_text("pass", encoding="utf-8")
            # Create a matching file or directory for every custom pattern.
            for pat in custom_patterns:
                if pat.startswith("*."):
                    # Wildcard pattern -> create a matching file.
                    ext = pat[1:]  # e.g. ".log"
                    (tmp / f"data{ext}").write_text("x", encoding="utf-8")
                else:
                    # Exact-name pattern -> create a directory.
                    d = tmp / pat
                    d.mkdir(exist_ok=True)
                    (d / "inner.txt").write_text("x", encoding="utf-8")
            entries = scan_repo(tmp, exclude=custom_patterns)
            for entry in entries:
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in custom_patterns:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"自定义排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )

View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
"""CLI 参数解析单元测试
验证 --data-source 新参数、--pipeline-flow 弃用映射、
--pipeline + --tasks 同时使用、以及 build_cli_overrides 集成行为。
需求: 3.1, 3.3, 3.5
"""
import warnings
from argparse import Namespace
from unittest.mock import patch
import pytest
from cli.main import parse_args, resolve_data_source, build_cli_overrides
# ---------------------------------------------------------------------------
# 1. --data-source 新参数解析
# ---------------------------------------------------------------------------
class TestDataSourceArg:
    """Parsing of the new ``--data-source`` argument."""

    @pytest.mark.parametrize("value", ["online", "offline", "hybrid"])
    def test_data_source_valid_values(self, value):
        """Each supported choice is accepted and stored verbatim."""
        argv = ["cli", "--data-source", value]
        with patch("sys.argv", argv):
            parsed = parse_args()
        assert parsed.data_source == value

    def test_data_source_default_is_none(self):
        """When the flag is omitted the attribute defaults to None."""
        with patch("sys.argv", ["cli"]):
            parsed = parse_args()
        assert parsed.data_source is None
# ---------------------------------------------------------------------------
# 2. resolve_data_source() 弃用映射
# ---------------------------------------------------------------------------
class TestResolveDataSource:
    """Deprecation mapping performed by resolve_data_source()."""

    def test_explicit_data_source_returns_directly(self):
        """An explicit --data-source value is returned unchanged."""
        ns = Namespace(data_source="online", pipeline_flow=None)
        assert resolve_data_source(ns) == "online"

    def test_data_source_takes_priority_over_pipeline_flow(self):
        """--data-source wins when both flags are present."""
        ns = Namespace(data_source="online", pipeline_flow="FULL")
        assert resolve_data_source(ns) == "online"

    @pytest.mark.parametrize(
        "flow, expected",
        [
            ("FULL", "hybrid"),
            ("FETCH_ONLY", "online"),
            ("INGEST_ONLY", "offline"),
        ],
    )
    def test_pipeline_flow_maps_with_deprecation_warning(self, flow, expected):
        """Legacy --pipeline-flow maps to a data_source and emits DeprecationWarning."""
        ns = Namespace(data_source=None, pipeline_flow=flow)
        with pytest.warns(DeprecationWarning, match="--pipeline-flow 已弃用"):
            result = resolve_data_source(ns)
        assert result == expected

    def test_neither_arg_defaults_to_hybrid(self):
        """With neither flag given, the resolver falls back to 'hybrid'."""
        ns = Namespace(data_source=None, pipeline_flow=None)
        assert resolve_data_source(ns) == "hybrid"
# ---------------------------------------------------------------------------
# 3. build_cli_overrides() 集成
# ---------------------------------------------------------------------------
class TestBuildCliOverrides:
    """Integration behaviour of build_cli_overrides()."""

    # Baseline argument values; anything not overridden stays None/False-like.
    _BASE_ARGS = dict(
        store_id=None, tasks=None, dry_run=False,
        pipeline=None, processing_mode="increment_only",
        fetch_before_verify=False, verify_tables=None,
        window_split="none", lookback_hours=24, overlap_seconds=3600,
        pg_dsn=None, pg_host=None, pg_port=None, pg_name=None,
        pg_user=None, pg_password=None,
        api_base=None, api_token=None, api_timeout=None,
        api_page_size=None, api_retry_max=None,
        window_start=None, window_end=None,
        force_window_override=False,
        window_split_unit=None, window_split_days=None,
        window_compensation_hours=None,
        export_root=None, log_root=None,
        data_source=None, pipeline_flow=None,
        fetch_root=None, ingest_source=None, write_pretty_json=False,
        idle_start=None, idle_end=None, allow_empty_advance=False,
    )

    def _make_args(self, **kwargs):
        """Build a minimal Namespace; unspecified options keep baseline values."""
        merged = {**self._BASE_ARGS, **kwargs}
        return Namespace(**merged)

    def test_data_source_online_sets_run_key(self):
        """--data-source online lands under overrides['run']['data_source']."""
        overrides = build_cli_overrides(self._make_args(data_source="online"))
        assert overrides["run"]["data_source"] == "online"

    def test_pipeline_flow_sets_both_keys(self):
        """The legacy flag writes both pipeline.flow and run.data_source."""
        ns = self._make_args(pipeline_flow="FULL")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            overrides = build_cli_overrides(ns)
        assert overrides["pipeline"]["flow"] == "FULL"
        assert overrides["run"]["data_source"] == "hybrid"

    def test_default_data_source_is_hybrid(self):
        """Without either flag, run.data_source defaults to 'hybrid'."""
        overrides = build_cli_overrides(self._make_args())
        assert overrides["run"]["data_source"] == "hybrid"
# ---------------------------------------------------------------------------
# 4. --pipeline + --tasks 同时使用
# ---------------------------------------------------------------------------
class TestPipelineAndTasks:
    """--pipeline and --tasks may be combined on one command line."""

    def test_pipeline_and_tasks_both_parsed(self):
        """Both options are parsed independently when supplied together."""
        argv = [
            "cli",
            "--pipeline", "api_full",
            "--tasks", "ODS_MEMBER,ODS_ORDER",
        ]
        with patch("sys.argv", argv):
            parsed = parse_args()
        assert parsed.pipeline == "api_full"
        assert parsed.tasks == "ODS_MEMBER,ODS_ORDER"

View File

@@ -0,0 +1,426 @@
"""DDL 解析器和对比逻辑的单元测试。
测试范围:
- DDL 解析器正确提取表名、字段、类型、约束
- 类型标准化逻辑
- 差异检测逻辑识别各类差异
- 边界情况空文件、COMMENT 含特殊字符
"""
import pytest
from scripts.compare_ddl_db import (
ColumnDef,
DiffKind,
SchemaDiff,
TableDef,
compare_tables,
normalize_type,
parse_ddl,
)
# =========================================================================
# normalize_type 测试
# =========================================================================
class TestNormalizeType:
    """Normalisation of raw SQL type spellings to canonical lower-case form."""
    @pytest.mark.parametrize("raw,expected", [
        ("BIGINT", "bigint"),
        ("INT8", "bigint"),
        ("INTEGER", "integer"),
        ("INT", "integer"),
        ("INT4", "integer"),
        ("SMALLINT", "smallint"),
        ("INT2", "smallint"),
        ("BOOLEAN", "boolean"),
        ("BOOL", "boolean"),
        ("TEXT", "text"),
        ("JSONB", "jsonb"),
        ("JSON", "json"),
        ("DATE", "date"),
        ("BYTEA", "bytea"),
        ("UUID", "uuid"),
    ])
    def test_simple_types(self, raw, expected):
        """Scalar type aliases map to their canonical names."""
        assert normalize_type(raw) == expected
    @pytest.mark.parametrize("raw,expected", [
        ("NUMERIC(18,2)", "numeric(18,2)"),
        ("NUMERIC(10,6)", "numeric(10,6)"),
        ("DECIMAL(5,2)", "numeric(5,2)"),
        ("NUMERIC(10)", "numeric(10)"),
        ("NUMERIC", "numeric"),
    ])
    def test_numeric_types(self, raw, expected):
        """NUMERIC/DECIMAL keep precision and scale; DECIMAL becomes numeric."""
        assert normalize_type(raw) == expected
    @pytest.mark.parametrize("raw,expected", [
        ("VARCHAR(50)", "varchar(50)"),
        ("CHARACTER VARYING(100)", "varchar(100)"),
        ("VARCHAR", "varchar"),
        ("CHAR(1)", "char(1)"),
        ("CHARACTER(10)", "char(10)"),
    ])
    def test_string_types(self, raw, expected):
        """Character types collapse to varchar/char, keeping any length."""
        assert normalize_type(raw) == expected
    @pytest.mark.parametrize("raw,expected", [
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP WITHOUT TIME ZONE", "timestamp"),
        ("TIMESTAMPTZ", "timestamptz"),
        ("TIMESTAMP WITH TIME ZONE", "timestamptz"),
    ])
    def test_timestamp_types(self, raw, expected):
        """Timestamp spellings normalise to timestamp / timestamptz."""
        assert normalize_type(raw) == expected
    @pytest.mark.parametrize("raw,expected", [
        ("BIGSERIAL", "bigint"),
        ("SERIAL", "integer"),
        ("SMALLSERIAL", "smallint"),
    ])
    def test_serial_types(self, raw, expected):
        """The serial family maps to its underlying integer types."""
        assert normalize_type(raw) == expected
    def test_case_insensitive(self):
        """Normalisation is case-insensitive."""
        assert normalize_type("bigint") == normalize_type("BIGINT")
        assert normalize_type("Numeric(18,2)") == normalize_type("NUMERIC(18,2)")
# =========================================================================
# parse_ddl 测试
# =========================================================================
class TestParseDdl:
    """Tests for the DDL parser (parse_ddl)."""
    def test_basic_create_table(self):
        """Basic CREATE TABLE statements parse into TableDef objects."""
        sql = """
        CREATE TABLE IF NOT EXISTS myschema.users (
            id BIGINT NOT NULL,
            name TEXT,
            age INTEGER,
            PRIMARY KEY (id)
        );
        """
        tables = parse_ddl(sql, target_schema="myschema")
        assert "users" in tables
        t = tables["users"]
        assert len(t.columns) == 3
        assert t.pk_columns == ["id"]
        assert t.columns["id"].data_type == "bigint"
        assert t.columns["id"].nullable is False
        assert t.columns["name"].data_type == "text"
        assert t.columns["name"].nullable is True
        assert t.columns["age"].data_type == "integer"
    def test_inline_primary_key(self):
        """Inline PRIMARY KEY constraints are recognised."""
        sql = """
        CREATE TABLE test_schema.items (
            item_id BIGSERIAL PRIMARY KEY,
            label TEXT NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="test_schema")
        t = tables["items"]
        assert t.columns["item_id"].is_pk is True
        # BIGSERIAL → bigint
        assert t.columns["item_id"].data_type == "bigint"
        assert t.columns["item_id"].nullable is False
        assert t.columns["label"].nullable is False
    def test_composite_primary_key(self):
        """Composite primary keys are captured in declaration order."""
        sql = """
        CREATE TABLE IF NOT EXISTS billiards_ods.member_profiles (
            id BIGINT,
            content_hash TEXT NOT NULL,
            name TEXT,
            PRIMARY KEY (id, content_hash)
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        t = tables["member_profiles"]
        assert t.pk_columns == ["id", "content_hash"]
        assert t.columns["id"].is_pk is True
        assert t.columns["content_hash"].is_pk is True
        # PK membership implies NOT NULL
        assert t.columns["id"].nullable is False
    def test_various_data_types(self):
        """A spread of PostgreSQL data types normalises correctly."""
        sql = """
        CREATE TABLE s.t (
            a BIGINT,
            b VARCHAR(50),
            c NUMERIC(18,2),
            d TIMESTAMP,
            e TIMESTAMPTZ DEFAULT now(),
            f BOOLEAN DEFAULT TRUE,
            g JSONB NOT NULL,
            h TEXT,
            i INTEGER
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["a"].data_type == "bigint"
        assert t.columns["b"].data_type == "varchar(50)"
        assert t.columns["c"].data_type == "numeric(18,2)"
        assert t.columns["d"].data_type == "timestamp"
        assert t.columns["e"].data_type == "timestamptz"
        assert t.columns["f"].data_type == "boolean"
        assert t.columns["g"].data_type == "jsonb"
        assert t.columns["g"].nullable is False
        assert t.columns["h"].data_type == "text"
        assert t.columns["i"].data_type == "integer"
    def test_without_schema_prefix(self):
        """CREATE TABLE without a schema prefix (e.g. after SET search_path, as in DWD DDL)."""
        sql = """
        SET search_path TO billiards_dwd;
        CREATE TABLE IF NOT EXISTS dim_site (
            site_id BIGINT,
            shop_name TEXT,
            PRIMARY KEY (site_id)
        );
        """
        # With target_schema given, unprefixed tables are accepted too.
        tables = parse_ddl(sql, target_schema="billiards_dwd")
        assert "dim_site" in tables
    def test_schema_filter(self):
        """Schema filtering keeps only tables of the target schema."""
        sql = """
        CREATE TABLE schema_a.t1 (id BIGINT);
        CREATE TABLE schema_b.t2 (id BIGINT);
        """
        tables = parse_ddl(sql, target_schema="schema_a")
        assert "t1" in tables
        assert "t2" not in tables
    def test_empty_ddl(self):
        """An empty DDL source yields an empty dict."""
        tables = parse_ddl("", target_schema="any")
        assert tables == {}
    def test_comments_ignored(self):
        """SQL comments do not disturb parsing."""
        sql = """
        -- 这是注释
        /* 块注释 */
        CREATE TABLE s.t (
            id BIGINT, -- 行内注释
            name TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        assert "t" in tables
        assert len(tables["t"].columns) == 2
    def test_comment_on_statements_ignored(self):
        """COMMENT ON statements do not affect table parsing."""
        sql = """
        CREATE TABLE billiards_ods.test_table (
            id BIGINT NOT NULL,
            name TEXT,
            PRIMARY KEY (id)
        );
        COMMENT ON TABLE billiards_ods.test_table IS '测试表:含特殊字符 ''引号'' 和 (括号)';
        COMMENT ON COLUMN billiards_ods.test_table.id IS '【说明】主键 ID。【示例】12345。';
        COMMENT ON COLUMN billiards_ods.test_table.name IS '【说明】名称,含 ''单引号''"双引号"';
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        assert "test_table" in tables
        assert len(tables["test_table"].columns) == 2
    def test_drop_then_create(self):
        """DROP TABLE followed by CREATE TABLE parses normally."""
        sql = """
        DROP TABLE IF EXISTS billiards_dws.cfg_test CASCADE;
        CREATE TABLE billiards_dws.cfg_test (
            id SERIAL PRIMARY KEY,
            value TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_dws")
        assert "cfg_test" in tables
        assert tables["cfg_test"].columns["id"].data_type == "integer"
    def test_default_values_parsed(self):
        """DEFAULT clauses do not disturb type and constraint parsing."""
        sql = """
        CREATE TABLE s.t (
            enabled BOOLEAN DEFAULT TRUE,
            created_at TIMESTAMPTZ DEFAULT now(),
            count INTEGER DEFAULT 0 NOT NULL,
            label VARCHAR(20) NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["enabled"].data_type == "boolean"
        assert t.columns["enabled"].nullable is True
        assert t.columns["created_at"].data_type == "timestamptz"
        assert t.columns["count"].data_type == "integer"
        assert t.columns["count"].nullable is False
        assert t.columns["label"].data_type == "varchar(20)"
        assert t.columns["label"].nullable is False
    def test_constraint_lines_skipped(self):
        """Table-level constraint lines (CONSTRAINT/UNIQUE/FOREIGN KEY) are skipped."""
        sql = """
        CREATE TABLE etl_admin.etl_task (
            task_id BIGSERIAL PRIMARY KEY,
            task_code TEXT NOT NULL,
            store_id BIGINT NOT NULL,
            UNIQUE (task_code, store_id)
        );
        """
        tables = parse_ddl(sql, target_schema="etl_admin")
        t = tables["etl_task"]
        assert len(t.columns) == 3
        assert "task_id" in t.columns
        assert "task_code" in t.columns
        assert "store_id" in t.columns
    def test_real_ods_ddl_parseable(self):
        """The real ODS DDL file on disk is parseable (skipped when absent)."""
        from pathlib import Path
        ddl_path = Path("database/schema_ODS_doc.sql")
        if not ddl_path.exists():
            pytest.skip("DDL 文件不存在")
        sql = ddl_path.read_text(encoding="utf-8")
        tables = parse_ddl(sql, target_schema="billiards_ods")
        # Expect at least 20+ tables.
        assert len(tables) >= 20
        # Every table must have at least one column.
        for tbl in tables.values():
            assert len(tbl.columns) > 0
# =========================================================================
# compare_tables 测试
# =========================================================================
class TestCompareTables:
    """Diff detection between DDL-side and database-side table definitions."""

    def _make_table(self, name: str, columns: dict[str, tuple[str, bool]],
                    pk: list[str] | None = None) -> TableDef:
        """Build a TableDef from ``{col_name: (data_type, nullable)}``."""
        pk_list = pk or []
        col_defs = {
            col: ColumnDef(
                name=col,
                data_type=dtype,
                nullable=nullable,
                is_pk=col in pk_list,
            )
            for col, (dtype, nullable) in columns.items()
        }
        return TableDef(name=name, columns=col_defs, pk_columns=pk_list)

    def test_no_diff(self):
        """Identical definitions on both sides produce an empty diff list."""
        tbl = self._make_table(
            "t", {"id": ("bigint", False), "name": ("text", True)}, pk=["id"],
        )
        assert compare_tables({"t": tbl}, {"t": tbl}) == []

    def test_missing_table(self):
        """A table present only in the database is reported as MISSING_TABLE."""
        db_side = {"extra": self._make_table("extra", {"id": ("bigint", False)})}
        diffs = compare_tables({}, db_side)
        assert len(diffs) == 1
        only = diffs[0]
        assert only.kind == DiffKind.MISSING_TABLE
        assert only.table == "extra"

    def test_extra_table(self):
        """A table present only in the DDL is reported as EXTRA_TABLE."""
        ddl_side = {"orphan": self._make_table("orphan", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl_side, {})
        assert len(diffs) == 1
        only = diffs[0]
        assert only.kind == DiffKind.EXTRA_TABLE
        assert only.table == "orphan"

    def test_missing_column(self):
        """A column present only in the database is reported as MISSING_COLUMN."""
        ddl_side = {"t": self._make_table("t", {"id": ("bigint", False)})}
        db_side = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "new_col": ("text", True),
        })}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_COLUMN
        assert diffs[0].column == "new_col"

    def test_extra_column(self):
        """A column present only in the DDL is reported as EXTRA_COLUMN."""
        ddl_side = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "old_col": ("text", True),
        })}
        db_side = {"t": self._make_table("t", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_COLUMN
        assert diffs[0].column == "old_col"

    def test_type_mismatch(self):
        """A differing column type is reported as TYPE_MISMATCH."""
        ddl_side = {"t": self._make_table("t", {"val": ("text", True)})}
        db_side = {"t": self._make_table("t", {"val": ("varchar(100)", True)})}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        mismatch = diffs[0]
        assert mismatch.kind == DiffKind.TYPE_MISMATCH
        assert mismatch.ddl_value == "text"
        assert mismatch.db_value == "varchar(100)"

    def test_nullable_mismatch(self):
        """A differing nullability constraint is reported as NULLABLE_MISMATCH."""
        ddl_side = {"t": self._make_table("t", {"val": ("text", True)})}
        db_side = {"t": self._make_table("t", {"val": ("text", False)})}
        diffs = compare_tables(ddl_side, db_side)
        assert len(diffs) == 1
        mismatch = diffs[0]
        assert mismatch.kind == DiffKind.NULLABLE_MISMATCH
        assert mismatch.ddl_value == "NULL"
        assert mismatch.db_value == "NOT NULL"

    def test_multiple_diffs(self):
        """Several diff kinds can be reported from a single comparison."""
        ddl_side = {
            "t1": self._make_table("t1", {
                "id": ("bigint", False),
                "extra": ("text", True),
            }),
            "ddl_only": self._make_table("ddl_only", {"x": ("integer", True)}),
        }
        db_side = {
            "t1": self._make_table("t1", {
                "id": ("integer", False),   # TYPE_MISMATCH
                "missing": ("text", True),  # MISSING_COLUMN
            }),
            "db_only": self._make_table("db_only", {"y": ("text", True)}),
        }
        observed_kinds = {d.kind for d in compare_tables(ddl_side, db_side)}
        expected_kinds = {
            DiffKind.MISSING_TABLE,   # db_only
            DiffKind.EXTRA_TABLE,     # ddl_only
            DiffKind.MISSING_COLUMN,  # t1.missing
            DiffKind.EXTRA_COLUMN,    # t1.extra
            DiffKind.TYPE_MISMATCH,   # t1.id
        }
        assert expected_kinds <= observed_kinds

    def test_empty_both(self):
        """Comparing two empty mappings yields an empty diff list."""
        assert compare_tables({}, {}) == []

View File

@@ -0,0 +1,545 @@
"""DDL 对比脚本差异检测完整性 — 属性测试。
# AI_CHANGELOG [2026-02-13] max_examples 从 100 降至 30 以加速测试运行
**Property 2: DDL 对比脚本差异检测完整性**
**Validates: Requirements 2.1, 2.2, 2.3, 2.4**
使用 hypothesis 生成随机的 DDL 表定义和数据库表定义,
注入已知差异,验证 compare_tables 能检测到所有差异。
"""
from __future__ import annotations
import string
import hypothesis.strategies as st
from hypothesis import given, settings, assume
from scripts.compare_ddl_db import (
ColumnDef,
DiffKind,
TableDef,
compare_tables,
)
# =========================================================================
# 自定义 Strategy生成随机的 ColumnDef / TableDef
# =========================================================================
# Pool of normalised type names (matches the output of normalize_type)
_NORMALIZED_TYPES = [
    "bigint", "integer", "smallint", "boolean", "text",
    "jsonb", "json", "date", "bytea", "uuid",
    "timestamp", "timestamptz", "double precision",
    "varchar(50)", "varchar(100)", "varchar(255)",
    "char(1)", "char(10)",
    "numeric(18,2)", "numeric(10,6)", "numeric(5,0)",
]
# Legal identifier character sets (lowercase letters + underscore + digits,
# first character a letter).
# NOTE(review): these two constants appear unused in this file —
# st_identifier builds identifiers via st.from_regex instead; confirm intent.
_IDENT_ALPHABET = string.ascii_lowercase + "_"
_IDENT_FULL = _IDENT_ALPHABET + string.digits
def st_identifier() -> st.SearchStrategy[str]:
    """Generate a legal PostgreSQL identifier (lower-case, 3-20 characters)."""
    return st.from_regex(r"[a-z][a-z0-9_]{2,19}", fullmatch=True)
def st_data_type() -> st.SearchStrategy[str]:
    """Pick one type at random from the normalised type pool."""
    return st.sampled_from(_NORMALIZED_TYPES)
def st_column_def(name: str | None = None) -> st.SearchStrategy[ColumnDef]:
    """Generate a random ColumnDef; a fixed *name* may be supplied."""
    return st.builds(
        ColumnDef,
        name=st.just(name) if name else st_identifier(),
        data_type=st_data_type(),
        nullable=st.booleans(),
        is_pk=st.just(False),  # PK flags are controlled at the TableDef level
        default=st.just(None),
    )
@st.composite
def st_table_def(draw, name: str | None = None) -> TableDef:
    """Generate a random TableDef (1-8 columns, optional 0-2 column PK)."""
    tbl_name = name or draw(st_identifier())
    num_cols = draw(st.integers(min_value=1, max_value=8))
    # Draw a set of unique column names.
    col_names = draw(
        st.lists(st_identifier(), min_size=num_cols, max_size=num_cols, unique=True)
    )
    columns: dict[str, ColumnDef] = {}
    for cn in col_names:
        col = draw(st_column_def(name=cn))
        columns[cn] = col
    # Randomly promote 0-2 columns to the primary key.
    pk_count = draw(st.integers(min_value=0, max_value=min(2, len(col_names))))
    pk_cols = draw(
        st.lists(
            st.sampled_from(col_names),
            min_size=pk_count, max_size=pk_count, unique=True,
        )
    ) if pk_count > 0 and col_names else []
    for pk_col in pk_cols:
        # PK columns are flagged and implicitly NOT NULL.
        columns[pk_col].is_pk = True
        columns[pk_col].nullable = False
    return TableDef(name=tbl_name, columns=columns, pk_columns=pk_cols)
@st.composite
def st_table_dict(draw, min_tables: int = 0, max_tables: int = 5) -> dict[str, TableDef]:
    """Generate a ``{table_name: TableDef}`` mapping with unique table names."""
    num = draw(st.integers(min_value=min_tables, max_value=max_tables))
    names = draw(st.lists(st_identifier(), min_size=num, max_size=num, unique=True))
    tables: dict[str, TableDef] = {}
    for n in names:
        tables[n] = draw(st_table_def(name=n))
    return tables
# =========================================================================
# Property 2: DDL 对比脚本差异检测完整性
# **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
# =========================================================================
class TestProperty2DiffDetection:
    """Property 2: the comparison detects every injected difference.

    **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
    """
    @given(base=st_table_dict(min_tables=1, max_tables=4))
    @settings(max_examples=30)
    def test_identical_tables_produce_no_diffs(self, base: dict[str, TableDef]):
        """Identical definitions must yield no differences.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        diffs = compare_tables(base, base)
        assert diffs == [], f"相同定义不应有差异,但得到: {diffs}"
    @given(
        common=st_table_dict(min_tables=0, max_tables=3),
        db_only=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_missing_table_detected(
        self,
        common: dict[str, TableDef],
        db_only: dict[str, TableDef],
    ):
        """Tables present only in the database must be reported as MISSING_TABLE.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        # Ensure db_only table names do not overlap with common.
        overlap = set(common.keys()) & set(db_only.keys())
        assume(not overlap)
        ddl_tables = dict(common)
        db_tables = {**common, **db_only}
        diffs = compare_tables(ddl_tables, db_tables)
        missing_tables = {d.table for d in diffs if d.kind == DiffKind.MISSING_TABLE}
        for tbl in db_only:
            assert tbl in missing_tables, (
                f"'{tbl}' 仅存在于数据库中,应报告为 MISSING_TABLE"
            )
    @given(
        common=st_table_dict(min_tables=0, max_tables=3),
        ddl_only=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_extra_table_detected(
        self,
        common: dict[str, TableDef],
        ddl_only: dict[str, TableDef],
    ):
        """Tables present only in the DDL must be reported as EXTRA_TABLE.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        overlap = set(common.keys()) & set(ddl_only.keys())
        assume(not overlap)
        ddl_tables = {**common, **ddl_only}
        db_tables = dict(common)
        diffs = compare_tables(ddl_tables, db_tables)
        extra_tables = {d.table for d in diffs if d.kind == DiffKind.EXTRA_TABLE}
        for tbl in ddl_only:
            assert tbl in extra_tables, (
                f"'{tbl}' 仅存在于 DDL 中,应报告为 EXTRA_TABLE"
            )
    @given(
        table=st_table_def(),
        extra_cols=st.lists(
            st.tuples(st_identifier(), st_data_type(), st.booleans()),
            min_size=1, max_size=4, unique_by=lambda x: x[0],
        ),
    )
    @settings(max_examples=30)
    def test_missing_column_detected(
        self,
        table: TableDef,
        extra_cols: list[tuple[str, str, bool]],
    ):
        """Columns present only in the database must be reported as MISSING_COLUMN.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        # Ensure injected column names do not collide with existing ones.
        existing_names = set(table.columns.keys())
        injected_names = {c[0] for c in extra_cols}
        assume(not (existing_names & injected_names))
        # DDL side: the original table.
        ddl_tables = {table.name: table}
        # DB side: the original table plus the extra columns.
        db_table = TableDef(
            name=table.name,
            columns=dict(table.columns),
            pk_columns=list(table.pk_columns),
        )
        for col_name, col_type, nullable in extra_cols:
            db_table.columns[col_name] = ColumnDef(
                name=col_name, data_type=col_type, nullable=nullable,
            )
        db_tables = {table.name: db_table}
        diffs = compare_tables(ddl_tables, db_tables)
        missing_cols = {
            d.column for d in diffs
            if d.kind == DiffKind.MISSING_COLUMN and d.table == table.name
        }
        for col_name, _, _ in extra_cols:
            assert col_name in missing_cols, (
                f"字段 '{table.name}.{col_name}' 仅存在于数据库中,"
                f"应报告为 MISSING_COLUMN"
            )
    @given(
        table=st_table_def(),
        extra_cols=st.lists(
            st.tuples(st_identifier(), st_data_type(), st.booleans()),
            min_size=1, max_size=4, unique_by=lambda x: x[0],
        ),
    )
    @settings(max_examples=30)
    def test_extra_column_detected(
        self,
        table: TableDef,
        extra_cols: list[tuple[str, str, bool]],
    ):
        """Columns present only in the DDL must be reported as EXTRA_COLUMN.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        existing_names = set(table.columns.keys())
        injected_names = {c[0] for c in extra_cols}
        assume(not (existing_names & injected_names))
        # DDL side: the original table plus the extra columns.
        ddl_table = TableDef(
            name=table.name,
            columns=dict(table.columns),
            pk_columns=list(table.pk_columns),
        )
        for col_name, col_type, nullable in extra_cols:
            ddl_table.columns[col_name] = ColumnDef(
                name=col_name, data_type=col_type, nullable=nullable,
            )
        ddl_tables = {table.name: ddl_table}
        # DB side: the original table.
        db_tables = {table.name: table}
        diffs = compare_tables(ddl_tables, db_tables)
        extra_col_set = {
            d.column for d in diffs
            if d.kind == DiffKind.EXTRA_COLUMN and d.table == table.name
        }
        for col_name, _, _ in extra_cols:
            assert col_name in extra_col_set, (
                f"字段 '{table.name}.{col_name}' 仅存在于 DDL 中,"
                f"应报告为 EXTRA_COLUMN"
            )
    @given(table=st_table_def())
    @settings(max_examples=30)
    def test_type_mismatch_detected(self, table: TableDef):
        """Divergent column types must be reported as TYPE_MISMATCH.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**
        """
        assume(len(table.columns) >= 1)
        # Pick the first column and change its type.
        target_col_name = next(iter(table.columns))
        original_type = table.columns[target_col_name].data_type
        # Choose any different type.
        alt_types = [t for t in _NORMALIZED_TYPES if t != original_type]
        assume(len(alt_types) > 0)
        new_type = alt_types[0]
        # DDL side: the original definition.
        ddl_tables = {table.name: table}
        # DB side: the target column rewritten with the new type.
        db_table = TableDef(
            name=table.name,
            columns={
                cn: ColumnDef(
                    name=c.name,
                    data_type=new_type if cn == target_col_name else c.data_type,
                    nullable=c.nullable,
                    is_pk=c.is_pk,
                    default=c.default,
                )
                for cn, c in table.columns.items()
            },
            pk_columns=list(table.pk_columns),
        )
        db_tables = {table.name: db_table}
        diffs = compare_tables(ddl_tables, db_tables)
        type_mismatches = [
            d for d in diffs
            if d.kind == DiffKind.TYPE_MISMATCH
            and d.table == table.name
            and d.column == target_col_name
        ]
        assert len(type_mismatches) == 1, (
            f"字段 '{table.name}.{target_col_name}' 类型从 "
            f"'{original_type}' 变为 '{new_type}',应报告 TYPE_MISMATCH"
        )
        assert type_mismatches[0].ddl_value == original_type
        assert type_mismatches[0].db_value == new_type
    @given(
        ddl_tables=st_table_dict(min_tables=0, max_tables=4),
        db_tables=st_table_dict(min_tables=0, max_tables=4),
    )
    @settings(max_examples=30)
    def test_all_diff_kinds_are_accounted_for(
        self,
        ddl_tables: dict[str, TableDef],
        db_tables: dict[str, TableDef],
    ):
        """Every reported diff must use a known DiffKind and valid references.

        **Validates: Requirements 2.1, 2.2, 2.3, 2.4**

        This is a meta-property: the comparator must never emit an undefined
        diff kind, and each diff must reference a table/column sensibly.
        """
        diffs = compare_tables(ddl_tables, db_tables)
        valid_kinds = set(DiffKind)
        all_table_names = set(ddl_tables.keys()) | set(db_tables.keys())
        for d in diffs:
            # The diff kind must be a known enum member.
            assert d.kind in valid_kinds, f"未知差异类型: {d.kind}"
            # The referenced table must exist on at least one side.
            assert d.table in all_table_names, (
                f"差异引用了不存在的表: {d.table}"
            )
            # Table-level diffs carry no column reference.
            if d.kind in (DiffKind.MISSING_TABLE, DiffKind.EXTRA_TABLE):
                assert d.column is None, (
                    f"表级差异 {d.kind} 不应包含 column 字段"
                )
            # Column-level diffs must carry a column reference.
            if d.kind in (
                DiffKind.MISSING_COLUMN, DiffKind.EXTRA_COLUMN,
                DiffKind.TYPE_MISMATCH, DiffKind.NULLABLE_MISMATCH,
            ):
                assert d.column is not None, (
                    f"字段级差异 {d.kind} 必须包含 column 字段"
                )
# =========================================================================
# Property 3: DDL 修正后零差异(不动点)
# **Validates: Requirements 2.5**
# =========================================================================
class TestProperty3FixpointZeroDiff:
    """Property 3: zero differences after fixing the DDL (fixpoint).

    Core idea: after correcting the DDL to match the actual database state,
    re-running the comparison must yield an empty diff list, i.e.
    compare_tables(db_state, db_state) == [].

    **Validates: Requirements 2.5**
    """
    @given(db_state=st_table_dict(min_tables=0, max_tables=5))
    @settings(max_examples=30)
    def test_identical_state_yields_no_diff(self, db_state: dict[str, TableDef]):
        """Feeding the same definitions as both DDL and DB yields zero diffs.

        Simulates the state where the DDL fully matches the database after
        a fix.

        **Validates: Requirements 2.5**
        """
        diffs = compare_tables(db_state, db_state)
        assert diffs == [], (
            f"不动点违反:相同定义应产生零差异,但得到 {len(diffs)} 项: "
            f"{[str(d) for d in diffs]}"
        )
    @given(db_state=st_table_dict(min_tables=1, max_tables=5))
    @settings(max_examples=30)
    def test_copy_db_as_ddl_yields_no_diff(self, db_state: dict[str, TableDef]):
        """A DDL deep-copied from the DB definitions compares with zero diffs.

        This validates the fixpoint property of "regenerate the DDL taking
        the database as the source of truth".

        **Validates: Requirements 2.5**
        """
        # Simulated fix: copy every table and column from the DB definitions.
        ddl_fixed: dict[str, TableDef] = {}
        for tbl_name, tbl in db_state.items():
            new_columns: dict[str, ColumnDef] = {}
            for col_name, col in tbl.columns.items():
                new_columns[col_name] = ColumnDef(
                    name=col.name,
                    data_type=col.data_type,
                    nullable=col.nullable,
                    is_pk=col.is_pk,
                    default=col.default,
                )
            ddl_fixed[tbl_name] = TableDef(
                name=tbl.name,
                columns=new_columns,
                pk_columns=list(tbl.pk_columns),
            )
        diffs = compare_tables(ddl_fixed, db_state)
        assert diffs == [], (
            f"不动点违反:从 DB 复制的 DDL 定义应与 DB 零差异,"
            f"但得到 {len(diffs)} 项: {[str(d) for d in diffs]}"
        )
    @given(
        db_state=st_table_dict(min_tables=1, max_tables=4),
        extra_tables=st_table_dict(min_tables=1, max_tables=3),
    )
    @settings(max_examples=30)
    def test_fixpoint_after_adding_missing_tables(
        self,
        db_state: dict[str, TableDef],
        extra_tables: dict[str, TableDef],
    ):
        """Fix flow: DDL lacks tables -> after adding them, zero diffs.

        Steps:
        1. DB holds common + extra tables, the DDL only the common ones
        2. MISSING_TABLE diffs are detected
        3. The missing tables are added to the DDL (simulated fix)
        4. A second comparison yields zero diffs

        **Validates: Requirements 2.5**
        """
        # Ensure the table names do not overlap.
        overlap = set(db_state.keys()) & set(extra_tables.keys())
        assume(not overlap)
        full_db = {**db_state, **extra_tables}
        ddl_before_fix = dict(db_state)  # lacks extra_tables
        # First comparison: differences are expected.
        diffs_before = compare_tables(ddl_before_fix, full_db)
        missing = {d.table for d in diffs_before if d.kind == DiffKind.MISSING_TABLE}
        assert missing == set(extra_tables.keys()), (
            f"修正前应检测到缺失表 {set(extra_tables.keys())},实际 {missing}"
        )
        # Simulated fix: add the missing tables to the DDL.
        ddl_after_fix = {**ddl_before_fix, **extra_tables}
        # Second comparison: fixpoint, zero diffs expected.
        diffs_after = compare_tables(ddl_after_fix, full_db)
        assert diffs_after == [], (
            f"不动点违反:修正后应零差异,但得到 {len(diffs_after)} 项: "
            f"{[str(d) for d in diffs_after]}"
        )
    @given(table=st_table_def())
    @settings(max_examples=30)
    def test_fixpoint_after_correcting_type_mismatch(self, table: TableDef):
        """Fix flow: type mismatch -> adopting the DB type yields zero diffs.

        Steps:
        1. One DDL column type differs from the DB
        2. TYPE_MISMATCH is detected
        3. The DDL column type is changed to the DB type (simulated fix)
        4. A second comparison yields zero diffs

        **Validates: Requirements 2.5**
        """
        assume(len(table.columns) >= 1)
        target_col = next(iter(table.columns))
        original_type = table.columns[target_col].data_type
        alt_types = [t for t in _NORMALIZED_TYPES if t != original_type]
        assume(len(alt_types) > 0)
        wrong_type = alt_types[0]
        # DDL side: the target column carries the wrong type.
        ddl_table = TableDef(
            name=table.name,
            columns={
                cn: ColumnDef(
                    name=c.name,
                    data_type=wrong_type if cn == target_col else c.data_type,
                    nullable=c.nullable,
                    is_pk=c.is_pk,
                    default=c.default,
                )
                for cn, c in table.columns.items()
            },
            pk_columns=list(table.pk_columns),
        )
        ddl_before = {table.name: ddl_table}
        db_tables = {table.name: table}
        # First comparison: TYPE_MISMATCH is expected.
        diffs_before = compare_tables(ddl_before, db_tables)
        type_diffs = [
            d for d in diffs_before
            if d.kind == DiffKind.TYPE_MISMATCH and d.column == target_col
        ]
        assert len(type_diffs) >= 1, "修正前应检测到 TYPE_MISMATCH"
        # Simulated fix: take the DB definitions wholesale as the DDL.
        ddl_after = dict(db_tables)
        # Second comparison: fixpoint.
        diffs_after = compare_tables(ddl_after, db_tables)
        assert diffs_after == [], (
            f"不动点违反:类型修正后应零差异,但得到 {len(diffs_after)} 项: "
            f"{[str(d) for d in diffs_after]}"
        )

View File

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""配置管理测试"""
import pytest
from config.settings import AppConfig
from config.defaults import DEFAULTS
def test_config_load():
    """Loading with minimal overrides falls back to the packaged defaults."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    # Timezone is not overridden, so the default value must win.
    assert cfg.get("app.timezone") == DEFAULTS["app"]["timezone"]
def test_config_override():
    """An explicit override takes precedence over the default value."""
    cfg = AppConfig.load({"app": {"store_id": 12345}})
    assert cfg.get("app.store_id") == 12345
def test_config_get_nested():
    """Dotted-key lookup resolves nested values and honours the fallback default."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    assert cfg.get("db.batch_size") == 1000
    # Unknown keys fall back to the supplied default instead of raising.
    assert cfg.get("nonexistent.key", "default") == "default"

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
"""配置映射属性测试 — 使用 hypothesis 验证配置键兼容映射的通用正确性属性。"""
import os
import warnings
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from config.settings import AppConfig, _FLOW_TO_DATA_SOURCE
# ── 确保测试不读取 .env 文件 ──────────────────────────────────
@pytest.fixture(autouse=True)
def skip_dotenv(monkeypatch):
    """Force AppConfig to ignore any local .env file for every test in this module."""
    monkeypatch.setenv("ETL_SKIP_DOTENV", "1")
# ── 生成策略 ──────────────────────────────────────────────────
flow_st = st.sampled_from(["FULL", "FETCH_ONLY", "INGEST_ONLY"])
# ── Property 11: pipeline_flow → data_source 映射一致性 ──────
# Feature: scheduler-refactor, Property 11: pipeline_flow → data_source 映射一致性
# **Validates: Requirements 8.1, 8.2, 8.3, 5.2, 8.4**
#
# 对于任意旧 pipeline_flow 值FULL/FETCH_ONLY/INGEST_ONLY
# 映射到 data_source 的结果应与预定义映射表一致:
# FULL→hybrid、FETCH_ONLY→online、INGEST_ONLY→offline。
# 同样,配置键 pipeline.flow 应自动映射到 run.data_source。
class TestProperty11FlowToDataSourceMapping:
    """Property 11: pipeline_flow → data_source mapping consistency."""

    @given(flow=flow_st)
    @settings(max_examples=100)
    def test_pipeline_flow_maps_to_data_source(self, flow):
        """Setting the legacy pipeline.flow key must surface as the mapped run.data_source."""
        expected = _FLOW_TO_DATA_SOURCE[flow]
        # The legacy key emits a DeprecationWarning; silence it for this check.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            cfg = AppConfig.load(
                {"app": {"store_id": 1}, "pipeline": {"flow": flow}}
            )
            actual = cfg.get("run.data_source")
        assert actual == expected, (
            f"pipeline.flow={flow!r} 应映射为 run.data_source={expected!r}"
            f"实际为 {actual!r}"
        )

View File

@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""CLI 参数和管道类型文档覆盖完整性验证。
**Validates: Requirements 7.1, 7.2**
Property 6: 对于所有在 cli/main.py 的 parse_args() 中定义的 CLI 参数,
README.md 或 base_task_mechanism.md 中应包含该参数的说明。
Property 7: 对于所有在 PipelineRunner.PIPELINE_LAYERS 中定义的管道类型,
README.md 中应包含该管道类型的层组合说明。
"""
# Feature: etl-task-documentation, Property 6 & 7
from __future__ import annotations
import ast
import re
from pathlib import Path
import pytest
# ── 常量 ──────────────────────────────────────────────────────
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_CLI_MAIN_PATH = _PROJECT_ROOT / "cli" / "main.py"
_README_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "README.md"
_BASE_MECHANISM_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "base_task_mechanism.md"
# ── 辅助函数:通过 AST 解析 parse_args() 中的 CLI 参数名 ─────
def _extract_cli_params_via_ast() -> list[str]:
    """Collect every long option name registered via add_argument() inside parse_args().

    Only '--long-name' style options are collected; positional arguments are
    skipped. When one call registers several names (e.g. '--api-token',
    '--token') only the first '--' name is kept.
    """
    tree = ast.parse(
        _CLI_MAIN_PATH.read_text(encoding="utf-8"), filename=str(_CLI_MAIN_PATH)
    )
    names: set[str] = set()
    # Locate parse_args() definitions, then scan them for add_argument calls.
    for fn in ast.walk(tree):
        if not (isinstance(fn, ast.FunctionDef) and fn.name == "parse_args"):
            continue
        for call in ast.walk(fn):
            if not isinstance(call, ast.Call):
                continue
            target = call.func
            if not (isinstance(target, ast.Attribute) and target.attr == "add_argument"):
                continue
            # Pull '--xxx' names from the positional string arguments.
            for pos_arg in call.args:
                if (
                    isinstance(pos_arg, ast.Constant)
                    and isinstance(pos_arg.value, str)
                    and pos_arg.value.startswith("--")
                ):
                    names.add(pos_arg.value)
                    break  # only the first '--' alias matters
    return sorted(names)
# ── 辅助函数:提取 PIPELINE_LAYERS 的键 ──────────────────────
def _extract_pipeline_types() -> list[str]:
    """Return the sorted keys of PipelineRunner.PIPELINE_LAYERS via AST parsing.

    Parsing the source file avoids importing (and hence instantiating anything
    from) the orchestration package.
    """
    pr_path = _PROJECT_ROOT / "orchestration" / "pipeline_runner.py"
    tree = ast.parse(pr_path.read_text(encoding="utf-8"), filename=str(pr_path))
    for node in ast.walk(tree):
        if not (isinstance(node, ast.ClassDef) and node.name == "PipelineRunner"):
            continue
        for stmt in node.body:
            # Accept both `PIPELINE_LAYERS = {...}` and `PIPELINE_LAYERS: T = {...}`.
            if isinstance(stmt, ast.AnnAssign):
                assigned = [stmt.target]
            elif isinstance(stmt, ast.Assign):
                assigned = stmt.targets
            else:
                continue
            is_layers = any(
                isinstance(t, ast.Name) and t.id == "PIPELINE_LAYERS"
                for t in assigned
            )
            if is_layers and isinstance(stmt.value, ast.Dict):
                return sorted(
                    k.value
                    for k in stmt.value.keys
                    if isinstance(k, ast.Constant) and isinstance(k.value, str)
                )
    raise RuntimeError("未能从 pipeline_runner.py 中解析出 PIPELINE_LAYERS")
# ── 测试数据准备 ──────────────────────────────────────────────
_CLI_PARAMS: list[str] = _extract_cli_params_via_ast()
_PIPELINE_TYPES: list[str] = _extract_pipeline_types()
# ── Fixtures ──────────────────────────────────────────────────
@pytest.fixture(scope="module")
def readme_content() -> str:
    """Read the full text of README.md once per module; fail fast if it is missing."""
    assert _README_PATH.exists(), f"文档文件不存在: {_README_PATH}"
    return _README_PATH.read_text(encoding="utf-8")
@pytest.fixture(scope="module")
def base_mechanism_content() -> str:
    """Read the full text of base_task_mechanism.md once per module; fail fast if missing."""
    assert _BASE_MECHANISM_PATH.exists(), f"文档文件不存在: {_BASE_MECHANISM_PATH}"
    return _BASE_MECHANISM_PATH.read_text(encoding="utf-8")
# ── Property 6: CLI 参数文档覆盖完整性 ────────────────────────
@pytest.mark.parametrize("param", _CLI_PARAMS, ids=_CLI_PARAMS)
def test_cli_param_in_docs(param: str, readme_content: str, base_mechanism_content: str):
    """Property 6: every CLI option is documented in README.md or base_task_mechanism.md.

    **Validates: Requirements 7.1**
    """
    # The option may appear bare or wrapped in backticks; substring search covers both.
    docs = "\n".join((readme_content, base_mechanism_content))
    assert param in docs, (
        f"CLI 参数 '{param}' 在 parse_args() 中定义,"
        f"但未在 README.md 或 base_task_mechanism.md 中找到对应说明"
    )
# ── Property 7: 管道类型文档覆盖完整性 ────────────────────────
@pytest.mark.parametrize("pipeline_type", _PIPELINE_TYPES, ids=_PIPELINE_TYPES)
def test_pipeline_type_in_readme(pipeline_type: str, readme_content: str):
    """Property 7: every pipeline type has a layer-composition section in README.md.

    **Validates: Requirements 7.2**
    """
    missing_msg = (
        f"管道类型 '{pipeline_type}' 在 PIPELINE_LAYERS 中定义,"
        f"但未在 README.md 中找到对应的层组合说明"
    )
    assert pipeline_type in readme_content, missing_msg

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""DWD 任务文档覆盖完整性验证。
**Validates: Requirements 3.1**
从 task_registry.py 中提取所有 layer="DWD" 的任务代码,
验证 docs/etl_tasks/dwd_tasks.md 中包含每个任务代码的说明章节。
"""
# Feature: etl-task-documentation, Property 2: DWD 任务文档覆盖完整性
from __future__ import annotations
from pathlib import Path
import pytest
from orchestration.task_registry import default_registry
# ── 测试数据准备 ──────────────────────────────────────────────
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_DWD_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "dwd_tasks.md"
# 从注册表动态获取所有 DWD 层任务代码
_DWD_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("DWD")
@pytest.fixture(scope="module")
def dwd_doc_content() -> str:
    """Read the full text of dwd_tasks.md once, shared by every case in this module."""
    assert _DWD_DOC_PATH.exists(), f"文档文件不存在: {_DWD_DOC_PATH}"
    return _DWD_DOC_PATH.read_text(encoding="utf-8")
# ── 参数化验证:每个 DWD 任务代码必须出现在文档中 ─────────────
@pytest.mark.parametrize("task_code", _DWD_TASK_CODES, ids=_DWD_TASK_CODES)
def test_dwd_task_code_in_doc(task_code: str, dwd_doc_content: str):
    """Property 2: every registered DWD task code has a section in dwd_tasks.md.

    **Validates: Requirements 3.1**
    """
    missing_msg = (
        f"DWD 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 dwd_tasks.md 中找到对应说明章节"
    )
    assert task_code in dwd_doc_content, missing_msg

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""DWS 任务文档覆盖完整性验证。
**Validates: Requirements 4.1, 4.4**
从 task_registry.py 中提取所有 layer="DWS" 的任务代码,
验证 docs/etl_tasks/dws_tasks.md 中包含每个任务代码的说明章节,
并标注其更新策略。
"""
# Feature: etl-task-documentation, Property 3: DWS 任务文档覆盖完整性
from __future__ import annotations
from pathlib import Path
import pytest
from orchestration.task_registry import default_registry
# ── 测试数据准备 ──────────────────────────────────────────────
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_DWS_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "dws_tasks.md"
# 从注册表动态获取所有 DWS 层任务代码
_DWS_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("DWS")
@pytest.fixture(scope="module")
def dws_doc_content() -> str:
    """Read the full text of dws_tasks.md once, shared by every case in this module."""
    assert _DWS_DOC_PATH.exists(), f"文档文件不存在: {_DWS_DOC_PATH}"
    return _DWS_DOC_PATH.read_text(encoding="utf-8")
# ── 参数化验证:每个 DWS 任务代码必须出现在文档中 ─────────────
@pytest.mark.parametrize("task_code", _DWS_TASK_CODES, ids=_DWS_TASK_CODES)
def test_dws_task_code_in_doc(task_code: str, dws_doc_content: str):
    """Property 3: every registered DWS task code has a section in dws_tasks.md.

    **Validates: Requirements 4.1, 4.4**
    """
    missing_msg = (
        f"DWS 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 dws_tasks.md 中找到对应说明章节"
    )
    assert task_code in dws_doc_content, missing_msg

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
"""INDEX 和 Utility 任务文档覆盖完整性验证。
**Validates: Requirements 5.1, 6.1**
Property 4: 对于所有在 task_registry.py 中注册且 layer="INDEX" 的任务代码,
index_tasks.md 中应包含该任务代码的说明章节。
Property 5: 对于所有在 task_registry.py 中注册且 task_type="utility" 的任务代码,
utility_tasks.md 中应包含该任务代码的说明章节。
"""
# Feature: etl-task-documentation, Property 4 & 5
from __future__ import annotations
from pathlib import Path
import pytest
from orchestration.task_registry import default_registry
# ── 测试数据准备 ──────────────────────────────────────────────
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_INDEX_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "index_tasks.md"
_UTILITY_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "utility_tasks.md"
# INDEX-layer tasks: resolved through the public get_tasks_by_layer API.
_INDEX_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("INDEX")
# Utility tasks: the registry exposes no get_tasks_by_type, so filter the
# internal dict directly. NOTE(review): this relies on the private `_tasks`
# attribute — consider adding a public accessor to TaskRegistry.
_UTILITY_TASK_CODES: list[str] = [
    code
    for code, meta in default_registry._tasks.items()
    if meta.task_type == "utility"
]
# ── Fixtures ──────────────────────────────────────────────────
@pytest.fixture(scope="module")
def index_doc_content() -> str:
    """Read the full text of index_tasks.md once per module; fail fast if missing."""
    assert _INDEX_DOC_PATH.exists(), f"文档文件不存在: {_INDEX_DOC_PATH}"
    return _INDEX_DOC_PATH.read_text(encoding="utf-8")
@pytest.fixture(scope="module")
def utility_doc_content() -> str:
    """Read the full text of utility_tasks.md once per module; fail fast if missing."""
    assert _UTILITY_DOC_PATH.exists(), f"文档文件不存在: {_UTILITY_DOC_PATH}"
    return _UTILITY_DOC_PATH.read_text(encoding="utf-8")
# ── Property 4: INDEX 任务文档覆盖完整性 ──────────────────────
@pytest.mark.parametrize("task_code", _INDEX_TASK_CODES, ids=_INDEX_TASK_CODES)
def test_index_task_code_in_doc(task_code: str, index_doc_content: str):
    """Property 4: every registered INDEX task code has a section in index_tasks.md.

    **Validates: Requirements 5.1**
    """
    missing_msg = (
        f"INDEX 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 index_tasks.md 中找到对应说明章节"
    )
    assert task_code in index_doc_content, missing_msg
# ── Property 5: Utility 任务文档覆盖完整性 ────────────────────
@pytest.mark.parametrize("task_code", _UTILITY_TASK_CODES, ids=_UTILITY_TASK_CODES)
def test_utility_task_code_in_doc(task_code: str, utility_doc_content: str):
    """Property 5: every task registered with task_type="utility" is documented.

    **Validates: Requirements 6.1**
    """
    missing_msg = (
        f"Utility 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 utility_tasks.md 中找到对应说明章节"
    )
    assert task_code in utility_doc_content, missing_msg

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""ODS 任务文档覆盖完整性验证。
**Validates: Requirements 2.1, 2.4**
从 task_registry.py 中提取所有 layer="ODS" 的任务代码,
验证 docs/etl_tasks/ods_tasks.md 中包含每个任务代码的说明章节。
"""
# Feature: etl-task-documentation, Property 1: ODS 任务文档覆盖完整性
from __future__ import annotations
from pathlib import Path
import pytest
from orchestration.task_registry import default_registry
# ── 测试数据准备 ──────────────────────────────────────────────
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_ODS_DOC_PATH = _PROJECT_ROOT / "docs" / "etl_tasks" / "ods_tasks.md"
# 从注册表动态获取所有 ODS 层任务代码
_ODS_TASK_CODES: list[str] = default_registry.get_tasks_by_layer("ODS")
@pytest.fixture(scope="module")
def ods_doc_content() -> str:
    """Read the full text of ods_tasks.md once, shared by every case in this module."""
    assert _ODS_DOC_PATH.exists(), f"文档文件不存在: {_ODS_DOC_PATH}"
    return _ODS_DOC_PATH.read_text(encoding="utf-8")
# ── 参数化验证:每个 ODS 任务代码必须出现在文档中 ─────────────
@pytest.mark.parametrize("task_code", _ODS_TASK_CODES, ids=_ODS_TASK_CODES)
def test_ods_task_code_in_doc(task_code: str, ods_doc_content: str):
    """Property 1: every registered ODS task code has a section in ods_tasks.md.

    **Validates: Requirements 2.1, 2.4**
    """
    missing_msg = (
        f"ODS 任务 '{task_code}' 已在 task_registry 中注册,"
        f"但未在 ods_tasks.md 中找到对应说明章节"
    )
    assert task_code in ods_doc_content, missing_msg

View File

@@ -0,0 +1,479 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG
# - 2026-02-14 | bugfix: 修复 3 个测试 bug
# prompt: "继续。完成后检查所有任务是否全面"
# 直接原因: (1) mock_config.get 返回 None 导致 timezone 异常;(2) _build_daily_record 缺少 gift_card 参数;(3) loaded_at naive/aware 不匹配
# 变更: mock_config.get 改用 side_effect 返回 default补充 gift_card 参数loaded_at 改用 aware datetime
# 验证: pytest tests/unit -x449 passed
"""
DWS任务单元测试
测试内容:
- BaseDwsTask基类方法
- 时间计算方法
- 配置应用方法
- 排名计算方法
"""
import pytest
from datetime import date, datetime, timedelta
from decimal import Decimal
from unittest.mock import MagicMock, patch
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from tasks.dws.base_dws_task import (
BaseDwsTask,
TimeLayer,
TimeWindow,
CourseType,
TimeRange,
ConfigCache
)
from tasks.dws.finance_daily_task import FinanceDailyTask
from tasks.dws.assistant_monthly_task import AssistantMonthlyTask
class TestTimeLayerRange:
    """Time-layer range computation on BaseDwsTask."""

    def test_last_2_days(self):
        """LAST_2_DAYS spans yesterday through the base date."""
        task = create_mock_task()
        rng = task.get_time_layer_range(TimeLayer.LAST_2_DAYS, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2026, 1, 31), date(2026, 2, 1))

    def test_last_1_month(self):
        """LAST_1_MONTH spans one month back up to the base date."""
        task = create_mock_task()
        rng = task.get_time_layer_range(TimeLayer.LAST_1_MONTH, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2026, 1, 2), date(2026, 2, 1))

    def test_last_3_months(self):
        """LAST_3_MONTHS spans three months back up to the base date."""
        task = create_mock_task()
        rng = task.get_time_layer_range(TimeLayer.LAST_3_MONTHS, date(2026, 2, 1))
        assert (rng.start, rng.end) == (date(2025, 11, 3), date(2026, 2, 1))
class TestTimeWindowRange:
    """Time-window range computation on BaseDwsTask."""

    def test_this_week_monday_start(self):
        """THIS_WEEK starts on Monday; 2026-02-01 is a Sunday."""
        task = create_mock_task()
        rng = task.get_time_window_range(TimeWindow.THIS_WEEK, date(2026, 2, 1))
        # Monday of that week is 2026-01-26.
        assert (rng.start, rng.end) == (date(2026, 1, 26), date(2026, 2, 1))

    def test_last_week(self):
        """LAST_WEEK covers the previous Monday through Sunday."""
        task = create_mock_task()
        rng = task.get_time_window_range(TimeWindow.LAST_WEEK, date(2026, 2, 1))
        # Previous Monday 2026-01-19, previous Sunday 2026-01-25.
        assert (rng.start, rng.end) == (date(2026, 1, 19), date(2026, 1, 25))

    def test_this_month(self):
        """THIS_MONTH runs from the 1st through the base date."""
        task = create_mock_task()
        rng = task.get_time_window_range(TimeWindow.THIS_MONTH, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2026, 2, 1), date(2026, 2, 15))

    def test_last_month(self):
        """LAST_MONTH covers the full previous calendar month."""
        task = create_mock_task()
        rng = task.get_time_window_range(TimeWindow.LAST_MONTH, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2026, 1, 1), date(2026, 1, 31))

    def test_last_3_months_excl_current(self):
        """Previous three full months, current month excluded."""
        task = create_mock_task()
        rng = task.get_time_window_range(
            TimeWindow.LAST_3_MONTHS_EXCL_CURRENT, date(2026, 2, 15)
        )
        assert (rng.start, rng.end) == (date(2025, 11, 1), date(2026, 1, 31))

    def test_last_3_months_incl_current(self):
        """Previous three months including the current partial month."""
        task = create_mock_task()
        rng = task.get_time_window_range(
            TimeWindow.LAST_3_MONTHS_INCL_CURRENT, date(2026, 2, 15)
        )
        assert (rng.start, rng.end) == (date(2025, 12, 1), date(2026, 2, 15))

    def test_this_quarter(self):
        """THIS_QUARTER runs from the quarter's first day through the base date."""
        task = create_mock_task()
        rng = task.get_time_window_range(TimeWindow.THIS_QUARTER, date(2026, 2, 15))
        assert (rng.start, rng.end) == (date(2026, 1, 1), date(2026, 2, 15))

    def test_last_6_months(self):
        """LAST_6_MONTHS: six full months ending with the previous month."""
        task = create_mock_task()
        rng = task.get_time_window_range(TimeWindow.LAST_6_MONTHS, date(2026, 2, 15))
        # Current month excluded: from six months before last month's end.
        assert (rng.start, rng.end) == (date(2025, 8, 1), date(2026, 1, 31))
class TestComparisonRange:
    """Period-over-period comparison range computation."""

    def test_comparison_7_days(self):
        """A 7-day window maps to the immediately preceding 7 days."""
        task = create_mock_task()
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 2, 7))
        previous = task.get_comparison_range(current)
        # The prior 7 days: Jan 25 through Jan 31.
        assert (previous.start, previous.end) == (date(2026, 1, 25), date(2026, 1, 31))

    def test_comparison_30_days(self):
        """The comparison range always has the same length as the input range."""
        task = create_mock_task()
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 3, 2))
        previous = task.get_comparison_range(current)
        assert (previous.end - previous.start).days == (current.end - current.start).days
class TestFinanceDailyRecord:
    """Finance daily record computation (group-buy discount split, cash-flow view)."""

    def test_groupbuy_and_cashflow(self):
        """Group-buy discount aggregation and cash inflow/outflow arithmetic."""
        task = create_finance_daily_task()
        stat_date = date(2026, 2, 1)
        # Settlement slice: 1000 gross (all table fees), 300 paid in cash,
        # 200 coupon, 50 adjustment, 10 member discount.
        settle = {
            'gross_amount': Decimal('1000'),
            'table_fee_amount': Decimal('1000'),
            'goods_amount': Decimal('0'),
            'assistant_pd_amount': Decimal('0'),
            'assistant_cx_amount': Decimal('0'),
            'cash_pay_amount': Decimal('300'),
            'card_pay_amount': Decimal('0'),
            'balance_pay_amount': Decimal('0'),
            'gift_card_pay_amount': Decimal('0'),
            'coupon_amount': Decimal('200'),
            'pl_coupon_sale_amount': Decimal('0'),
            'adjust_amount': Decimal('50'),
            'member_discount_amount': Decimal('10'),
            'rounding_amount': Decimal('0'),
            'order_count': 1,
            'member_order_count': 1,
            'guest_order_count': 0,
        }
        groupbuy = {'groupbuy_pay_total': Decimal('80')}
        recharge = {'recharge_cash': Decimal('20')}
        expense = {'expense_amount': Decimal('40')}
        platform = {
            'settlement_amount': Decimal('60'),
            'commission_amount': Decimal('5'),
            'service_fee': Decimal('5'),
        }
        big_customer = {'big_customer_amount': Decimal('20')}
        gift_card = {'gift_card_consume': Decimal('0')}
        record = task._build_daily_record(
            stat_date, settle, groupbuy, recharge, gift_card, expense, platform, big_customer, 1
        )
        # Expected values derived from the fixture above; the exact formulas
        # live in FinanceDailyTask._build_daily_record.
        assert record['discount_groupbuy'] == Decimal('120')
        assert record['discount_other'] == Decimal('30')
        assert record['platform_settlement_amount'] == Decimal('60')
        # platform_fee = commission (5) + service fee (5)
        assert record['platform_fee_amount'] == Decimal('10')
        assert record['cash_inflow_total'] == Decimal('380')
        assert record['cash_outflow_total'] == Decimal('50')
        # balance change = inflow 380 - outflow 50
        assert record['cash_balance_change'] == Decimal('330')
class TestNewHireTier:
    """New-hire tier assignment rules."""

    def test_new_hire_tier_hours(self):
        """Daily average × 30 conversion: 15 h over 5 work days → 90 h."""
        task = create_assistant_monthly_task()
        effective_hours = Decimal('15')
        work_days = 5
        result = task._calc_new_hire_tier_hours(effective_hours, work_days)
        assert result == Decimal('90')

    def test_max_tier_level_cap(self):
        """New hires are capped at max_tier_level even when hours qualify higher."""
        task = create_mock_task()
        # Aware datetime — ConfigCache.loaded_at must not be naive (see changelog note).
        now = datetime.now(tz=task.tz)
        # Four regular tiers: 0-100, 100-200, 200-300 and 300+ hours.
        task._config_cache = ConfigCache(
            performance_tiers=[
                {'tier_id': 1, 'tier_level': 1, 'min_hours': 0, 'max_hours': 100, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
                {'tier_id': 2, 'tier_level': 2, 'min_hours': 100, 'max_hours': 200, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
                {'tier_id': 3, 'tier_level': 3, 'min_hours': 200, 'max_hours': 300, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
                {'tier_id': 4, 'tier_level': 4, 'min_hours': 300, 'max_hours': None, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
            ],
            level_prices=[],
            bonus_rules=[],
            area_categories={},
            skill_types={},
            loaded_at=now
        )
        # 350 h would normally land in tier 4; max_tier_level=3 caps a new hire at 3.
        tier = task.get_performance_tier(
            Decimal('350'),
            is_new_hire=True,
            effective_date=date(2026, 2, 1),
            max_tier_level=3
        )
        assert tier['tier_level'] == 3
class TestNewHireCheck:
    """New-hire determination: the hire date falls inside the statistics month."""

    # Fix: replaced `== True` / `== False` comparisons (flake8 E712) with the
    # idiomatic bare `assert` / `assert not` form.

    def test_new_hire_in_month(self):
        """Hired mid-month → counts as a new hire for that month."""
        task = create_mock_task()
        assert task.is_new_hire_in_month(date(2026, 2, 5), date(2026, 2, 1))

    def test_not_new_hire(self):
        """Hired before the statistics month → not a new hire."""
        task = create_mock_task()
        assert not task.is_new_hire_in_month(date(2026, 1, 15), date(2026, 2, 1))

    def test_hire_on_first_day(self):
        """Hired on the 1st of the month → new hire (boundary is inclusive)."""
        task = create_mock_task()
        assert task.is_new_hire_in_month(date(2026, 2, 1), date(2026, 2, 1))
class TestRankWithTies:
    """Rank computation that accounts for ties (competition-style ranking)."""

    def test_no_ties(self):
        """Distinct values receive ranks 1, 2, 3."""
        task = create_mock_task()
        ranked = task.calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('90')),
            (3, Decimal('80')),
        ])
        assert ranked[0] == (1, 1, 1)
        assert ranked[1] == (2, 2, 2)
        assert ranked[2] == (3, 3, 3)

    def test_with_ties(self):
        """Two tied firsts; the next rank skips to 3."""
        task = create_mock_task()
        ranked = task.calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('100')),  # tied for first
            (3, Decimal('80')),
        ])
        assert [row[1] for row in ranked[:3]] == [1, 1, 3]

    def test_all_ties(self):
        """Identical values all share rank 1."""
        task = create_mock_task()
        ranked = task.calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('100')),
            (3, Decimal('100')),
        ])
        assert all(row[1] == 1 for row in ranked)
class TestGuestCheck:
    """Guest (non-member) detection.

    Fix: replaced `== True` / `== False` comparisons (flake8 E712) with the
    idiomatic bare `assert` / `assert not` form.
    """

    def test_guest_zero(self):
        """member_id == 0 is treated as a walk-in guest."""
        task = create_mock_task()
        assert task.is_guest(0)

    def test_guest_none(self):
        """member_id is None is also treated as a guest."""
        task = create_mock_task()
        assert task.is_guest(None)

    def test_not_guest(self):
        """A real member id is not a guest."""
        task = create_mock_task()
        assert not task.is_guest(12345)
class TestUtilityMethods:
    """Safe-conversion and unit-conversion helpers."""

    def test_safe_decimal(self):
        """safe_decimal coerces numbers and strings; bad input maps to 0."""
        task = create_mock_task()
        cases = [
            (100, Decimal('100')),
            ('123.45', Decimal('123.45')),
            (None, Decimal('0')),
            ('invalid', Decimal('0')),
        ]
        for raw, expected in cases:
            assert task.safe_decimal(raw) == expected

    def test_safe_int(self):
        """safe_int coerces numbers and strings; bad input maps to 0."""
        task = create_mock_task()
        for raw, expected in [(100, 100), ('123', 123), (None, 0), ('invalid', 0)]:
            assert task.safe_int(raw) == expected

    def test_seconds_to_hours(self):
        """3600 s → 1 h; fractional hours are preserved."""
        task = create_mock_task()
        assert task.seconds_to_hours(3600) == Decimal('1')
        assert task.seconds_to_hours(5400) == Decimal('1.5')
        assert task.seconds_to_hours(0) == Decimal('0')

    def test_hours_to_seconds(self):
        """Inverse conversion from hours back to whole seconds."""
        task = create_mock_task()
        assert task.hours_to_seconds(Decimal('1')) == 3600
        assert task.hours_to_seconds(Decimal('1.5')) == 5400
class TestCourseType:
    """CourseType enum member values."""

    def test_base_course(self):
        """The base-course member serializes to 'BASE'."""
        assert CourseType.BASE.value == 'BASE'

    def test_bonus_course(self):
        """The bonus-course member serializes to 'BONUS'."""
        assert CourseType.BONUS.value == 'BONUS'
# =============================================================================
# 辅助函数
# =============================================================================
def create_mock_task():
    """Build a concrete BaseDwsTask subclass wired with MagicMock dependencies."""

    class _StubDwsTask(BaseDwsTask):
        # Minimal overrides so the abstract base can be instantiated.
        def get_task_code(self):
            return "TEST_DWS_TASK"

        def get_target_table(self):
            return "test_table"

        def get_primary_keys(self):
            return ["id"]

        def extract(self, context):
            return {}

        def load(self, transformed, context):
            return {}

    cfg = MagicMock()
    # get() always falls through to the caller-supplied default,
    # mimicking a config with no keys set.
    cfg.get.side_effect = lambda key, default=None: default
    return _StubDwsTask(cfg, MagicMock(), MagicMock(), MagicMock())
def create_finance_daily_task():
    """FinanceDailyTask instance with mocked dependencies (tenant_id fixed to 1)."""
    cfg = MagicMock()
    # app.tenant_id resolves to 1; every other key falls back to its default.
    cfg.get.side_effect = (
        lambda key, default=None: 1 if key == "app.tenant_id" else default
    )
    return FinanceDailyTask(cfg, MagicMock(), MagicMock(), MagicMock())
def create_assistant_monthly_task():
    """AssistantMonthlyTask instance with mocked dependencies."""
    cfg = MagicMock()
    # No config keys set: get() always returns the supplied default.
    cfg.get.side_effect = lambda key, default=None: default
    return AssistantMonthlyTask(cfg, MagicMock(), MagicMock(), MagicMock())
# Allow running this test module directly: python tests/unit/<this_file>.py
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
"""端到端流程集成测试
验证 CLI → PipelineRunner → TaskExecutor 完整调用链。
使用 mock 依赖,不需要真实数据库。
需求: 9.4
"""
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
from orchestration.task_executor import TaskExecutor, DataSource
from orchestration.pipeline_runner import PipelineRunner
from orchestration.task_registry import TaskRegistry
# ---------------------------------------------------------------------------
# 辅助:构造最小可用的 mock config
# ---------------------------------------------------------------------------
def _make_config(**overrides):
"""构造一个行为类似 AppConfig 的 MagicMock。"""
store = {
"app.timezone": "Asia/Shanghai",
"app.store_id": 1,
"io.fetch_root": "/tmp/fetch",
"io.ingest_source_dir": "",
"io.write_pretty_json": False,
"io.export_root": "/tmp/export",
"io.log_root": "/tmp/logs",
"pipeline.fetch_root": None,
"pipeline.ingest_source_dir": None,
"run.ods_tasks": [],
"run.dws_tasks": [],
"run.index_tasks": [],
"run.data_source": "hybrid",
"verification.ods_use_local_json": False,
"verification.skip_ods_when_fetch_before_verify": True,
}
store.update(overrides)
config = MagicMock()
config.get = MagicMock(side_effect=lambda k, d=None: store.get(k, d))
config.__getitem__ = MagicMock(side_effect=lambda k: {
"io": {"export_root": "/tmp/export", "log_root": "/tmp/logs"},
}[k])
return config
# ---------------------------------------------------------------------------
# 辅助:构造一个可被 TaskRegistry 注册的假任务类
# ---------------------------------------------------------------------------
class _FakeTask:
"""最小假任务execute() 返回固定结果。"""
def __init__(self, config, db_ops, api_client, logger):
pass
def execute(self, cursor_data):
return {"status": "SUCCESS", "counts": {"fetched": 5, "inserted": 3}}
# ===========================================================================
# 测试 1传统模式 — TaskExecutor.run_tasks 端到端
# ===========================================================================
class TestTraditionalModeE2E:
    """Traditional mode: TaskExecutor.run_tasks end to end."""

    def test_run_tasks_executes_utility_task_and_returns_results(self):
        """Utility tasks take the _run_utility_task path, bypassing cursor and run tracking."""
        registry = TaskRegistry()
        registry.register(
            "FAKE_UTIL", _FakeTask,
            requires_db_config=False, task_type="utility",
        )
        cursor_mgr = MagicMock()
        run_tracker = MagicMock()
        executor = TaskExecutor(
            config=_make_config(),
            db_ops=MagicMock(),
            api_client=MagicMock(),
            cursor_mgr=cursor_mgr,
            run_tracker=run_tracker,
            task_registry=registry,
            logger=MagicMock(),
        )
        results = executor.run_tasks(["FAKE_UTIL"], data_source="hybrid")
        assert len(results) == 1
        # run_tasks wraps a successful utility task as "成功".
        assert results[0]["status"] in ("成功", "完成", "SUCCESS")
        # Neither the cursor manager nor the run tracker should be touched.
        cursor_mgr.get_or_create.assert_not_called()
        run_tracker.create_run.assert_not_called()
# ===========================================================================
# 测试 2管道模式 — PipelineRunner → TaskExecutor 端到端
# ===========================================================================
class TestPipelineModeE2E:
    """Pipeline mode: PipelineRunner.run → TaskExecutor.run_tasks end to end."""

    @staticmethod
    def _build_runner(mock_executor, registry):
        # Single place to wire a PipelineRunner with mock collaborators.
        return PipelineRunner(
            config=_make_config(),
            task_executor=mock_executor,
            task_registry=registry,
            db_conn=MagicMock(),
            api_client=MagicMock(),
            logger=MagicMock(),
        )

    def test_pipeline_delegates_to_executor_and_returns_structure(self):
        """PipelineRunner resolves layers → tasks, then delegates to TaskExecutor."""
        mock_executor = MagicMock()
        mock_executor.run_tasks.return_value = [
            {"task_code": "FAKE_ODS", "status": "成功", "counts": {"fetched": 10, "inserted": 8}},
        ]
        registry = TaskRegistry()
        registry.register("FAKE_ODS", _FakeTask, layer="ODS")
        runner = self._build_runner(mock_executor, registry)
        outcome = runner.run(
            pipeline="api_ods",
            processing_mode="increment_only",
            data_source="hybrid",
        )
        # Result-structure checks.
        assert outcome["status"] == "SUCCESS"
        assert outcome["pipeline"] == "api_ods"
        assert outcome["layers"] == ["ODS"]
        assert isinstance(outcome["results"], list)
        # The executor received exactly one call with the requested data source.
        mock_executor.run_tasks.assert_called_once()
        assert mock_executor.run_tasks.call_args[1]["data_source"] == "hybrid"

    def test_pipeline_verify_only_skips_increment(self):
        """verify_only skips the incremental ETL and runs only the verification step."""
        mock_executor = MagicMock()
        mock_executor.run_tasks.return_value = []
        runner = self._build_runner(mock_executor, TaskRegistry())
        # The verification framework may be absent in unit tests; stub it out.
        with patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}):
            outcome = runner.run(
                pipeline="api_ods",
                processing_mode="verify_only",
                data_source="hybrid",
            )
        assert outcome["status"] == "SUCCESS"
        # verify_only with fetch_before_verify=False never reaches run_tasks.
        mock_executor.run_tasks.assert_not_called()
# ===========================================================================
# 测试 3ETLScheduler 薄包装层委托验证
# ===========================================================================
class TestSchedulerThinWrapper:
    """ETLScheduler is a thin wrapper delegating to TaskExecutor / PipelineRunner."""

    def test_scheduler_delegates_run_tasks(self):
        """run_tasks() delegates to the internal task_executor."""
        from orchestration.scheduler import ETLScheduler
        mock_config = MagicMock()
        # Section-style config access used during scheduler construction.
        mock_config.__getitem__ = MagicMock(side_effect=lambda k: {
            "db": {
                "dsn": "postgresql://fake:5432/test",
                "session": {"timezone": "Asia/Shanghai"},
                "connect_timeout_sec": 5,
            },
            "api": {
                "base_url": "https://fake.api",
                "token": "fake-token",
                "timeout_sec": 30,
                "retries": {"max_attempts": 3},
            },
        }[k])
        mock_config.get = MagicMock(side_effect=lambda k, d=None: {
            "run.data_source": "hybrid",
            "run.tasks": ["FAKE"],
            "app.timezone": "Asia/Shanghai",
        }.get(k, d))
        # Patch all resource factories so no real connections are created.
        with patch("orchestration.scheduler.DatabaseConnection"), \
             patch("orchestration.scheduler.DatabaseOperations"), \
             patch("orchestration.scheduler.APIClient"), \
             patch("orchestration.scheduler.CursorManager"), \
             patch("orchestration.scheduler.RunTracker"), \
             patch("orchestration.scheduler.TaskExecutor") as MockTE, \
             patch("orchestration.scheduler.PipelineRunner") as MockPR:
            import warnings
            # Constructing the legacy scheduler emits a DeprecationWarning.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                scheduler = ETLScheduler(mock_config, MagicMock())
            # run_tasks delegation
            scheduler.run_tasks(["TEST_TASK"])
            scheduler.task_executor.run_tasks.assert_called_once()
            # run_pipeline_with_verification delegation
            scheduler.run_pipeline_with_verification(pipeline="api_ods")
            scheduler.pipeline_runner.run.assert_called_once()

View File

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""Unit tests for recent/former endpoint routing."""
import sys
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.endpoint_routing import plan_calls, recent_boundary
TZ = ZoneInfo("Asia/Shanghai")
def _now():
return datetime(2025, 12, 18, 10, 0, 0, tzinfo=TZ)
def test_recent_boundary_month_start():
    """For mid-December 'now', the recent/former boundary is Sep 1st 00:00 (+08:00)."""
    boundary = recent_boundary(_now())
    expected = "2025-09-01T00:00:00+08:00"
    assert boundary.isoformat() == expected
def test_paylog_routes_to_former_when_old_window():
    """A PayLog window entirely before the boundary is rerouted to the Former endpoint."""
    query = {"siteId": 1, "StartPayTime": "2025-08-01 00:00:00", "EndPayTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/PayLog/GetPayLogListPage", query, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/PayLog/GetFormerPayLogListPage"]
def test_coupon_usage_stays_same_path_even_when_old():
    """Coupon-consumption queries keep the same path even for old time windows."""
    query = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/Promotion/GetOfflineCouponConsumePageList", query, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/Promotion/GetOfflineCouponConsumePageList"]
def test_goods_outbound_routes_to_queryformer_when_old():
    """Old outbound-receipt windows are rerouted to the QueryFormer endpoint."""
    query = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/GoodsStockManage/QueryGoodsOutboundReceipt", query, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/GoodsStockManage/QueryFormerGoodsOutboundReceipt"]
def test_settlement_records_split_when_crossing_boundary():
    """A window straddling the boundary is split into a Former call plus a recent call."""
    query = {"siteId": 1, "rangeStartTime": "2025-08-15 00:00:00", "rangeEndTime": "2025-09-10 00:00:00"}
    planned = plan_calls("/Site/GetAllOrderSettleList", query, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == [
        "/Site/GetFormerOrderSettleList",
        "/Site/GetAllOrderSettleList",
    ]
    # Both sub-windows must meet exactly on the boundary timestamp.
    former_call, recent_call = planned
    assert former_call.params["rangeEndTime"] == "2025-09-01 00:00:00"
    assert recent_call.params["rangeStartTime"] == "2025-09-01 00:00:00"
@pytest.mark.parametrize(
    "endpoint",
    [
        "/PayLog/GetFormerPayLogListPage",
        "/Site/GetFormerOrderSettleList",
        "/GoodsStockManage/QueryFormerGoodsOutboundReceipt",
    ],
)
def test_explicit_former_endpoint_not_rerouted(endpoint):
    """An endpoint already pointing at a Former path is passed through unchanged."""
    query = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls(endpoint, query, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == [endpoint]

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""filter_verify_tables 单元测试"""
import pytest
from tasks.verification.models import filter_verify_tables
class TestFilterVerifyTables:
    """filter_verify_tables: filter verification table names by warehouse layer."""

    def test_none_input_returns_none(self):
        # None passes through unchanged.
        assert filter_verify_tables("DWD", None) is None

    def test_empty_list_returns_none(self):
        # An empty list is normalized to None as well.
        assert filter_verify_tables("DWD", []) is None

    def test_dwd_layer_filters_correctly(self):
        # DWD keeps dwd_/dim_/fact_ prefixed tables only.
        tables = ["dwd_order", "dim_member", "fact_payment", "ods_raw", "dws_daily"]
        result = filter_verify_tables("DWD", tables)
        assert result == ["dwd_order", "dim_member", "fact_payment"]

    def test_dws_layer_filters_correctly(self):
        tables = ["dws_daily", "dwd_order", "dws_summary"]
        result = filter_verify_tables("DWS", tables)
        assert result == ["dws_daily", "dws_summary"]

    def test_index_layer_filters_correctly(self):
        # INDEX keeps v_* views and the wbi_index table.
        tables = ["v_score", "wbi_index", "dws_daily", "v_rank"]
        result = filter_verify_tables("INDEX", tables)
        assert result == ["v_score", "wbi_index", "v_rank"]

    def test_ods_layer_filters_correctly(self):
        tables = ["ods_order", "dwd_order", "ods_member"]
        result = filter_verify_tables("ODS", tables)
        assert result == ["ods_order", "ods_member"]

    def test_unknown_layer_returns_normalized(self):
        # Unknown layers apply no prefix filtering, only normalization
        # (strip surrounding whitespace, lowercase).
        tables = [" SomeTable ", "Another"]
        result = filter_verify_tables("UNKNOWN", tables)
        assert result == ["sometable", "another"]

    def test_layer_case_insensitive(self):
        # Layer matching ignores case.
        tables = ["dwd_order", "ods_raw"]
        assert filter_verify_tables("dwd", tables) == ["dwd_order"]
        assert filter_verify_tables("Dwd", tables) == ["dwd_order"]

    def test_whitespace_and_empty_entries_stripped(self):
        # Blank / None entries are dropped; surrounding whitespace is stripped.
        tables = [" dwd_order ", "", " ", None, "dim_member"]
        result = filter_verify_tables("DWD", tables)
        assert result == ["dwd_order", "dim_member"]

View File

@@ -0,0 +1,903 @@
"""审计一览表生成脚本 — 解析模块单元测试
覆盖AuditEntry、parse_audit_file、classify_module、scan_audit_dir
"""
import datetime
import os
import textwrap
from pathlib import Path
import pytest
from scripts.gen_audit_dashboard import (
AuditEntry,
MODULE_MAP,
VALID_MODULES,
classify_module,
parse_audit_file,
scan_audit_dir,
)
# ---------------------------------------------------------------------------
# classify_module
# ---------------------------------------------------------------------------
class TestClassifyModule:
    """classify_module should map file paths to the correct functional module."""

    @pytest.mark.parametrize(
        "path, expected",
        [
            ("api/recording_client.py", "API 层"),
            ("tasks/ods/ods_task.py", "ODS 层"),
            ("tasks/dwd/dwd_load_task.py", "DWD 层"),
            ("tasks/dws/base_dws_task.py", "DWS 层"),
            ("tasks/index/wbi.py", "指数算法"),
            ("loaders/fact_loader.py", "数据装载"),
            ("database/migrations/001.sql", "数据库"),
            ("orchestration/task_registry.py", "调度"),
            ("config/defaults.py", "配置"),
            ("cli/main.py", "CLI"),
            ("models/parser.py", "模型"),
            ("scd/scd2.py", "SCD2"),
            ("docs/README.md", "文档"),
            ("scripts/gen_audit_dashboard.py", "脚本工具"),
            ("tests/unit/test_foo.py", "测试"),
            ("quality/checker.py", "质量校验"),
            ("gui/main.py", "GUI"),
            ("utils/logging_utils.py", "工具库"),
        ],
    )
    def test_known_prefixes(self, path, expected):
        # One representative path per known MODULE_MAP prefix.
        assert classify_module(path) == expected

    def test_unknown_path_returns_other(self):
        # Paths outside every known prefix fall back to "其他".
        assert classify_module("README.md") == "其他"
        assert classify_module(".kiro/steering/foo.md") == "其他"

    def test_normalizes_backslash(self):
        """Windows backslash paths are classified correctly too."""
        assert classify_module("tasks\\dws\\base.py") == "DWS 层"

    def test_strips_leading_dot_slash(self):
        # A leading "./" does not defeat prefix matching.
        assert classify_module("./api/foo.py") == "API 层"

    def test_result_always_in_valid_modules(self):
        """The return value for any input lies within VALID_MODULES."""
        for path in ["", "x", "api/", "unknown/deep/path.py"]:
            assert classify_module(path) in VALID_MODULES

    def test_longest_prefix_wins(self):
        """tasks/ods must match the ODS layer, not a generic tasks/ prefix."""
        # MODULE_MAP has no generic "tasks/" prefix, but tasks/ods maps to ODS 层.
        assert classify_module("tasks/ods/foo.py") == "ODS 层"
        assert classify_module("tasks/dwd/bar.py") == "DWD 层"
# ---------------------------------------------------------------------------
# parse_audit_file — 使用临时文件
# ---------------------------------------------------------------------------
# 标准审计文件内容模板
_STANDARD_AUDIT = textwrap.dedent("""\
# 审计记录:测试变更
- 日期2026-03-01Asia/Shanghai
## 直接原因
测试用例
## 修改文件清单
| 文件 | 变更类型 | 说明 |
|------|----------|------|
| `api/client.py` | 修改 | 测试 |
| `tasks/dws/foo.py` | 新增 | 测试 |
## 风险与回滚
- 风险:极低。纯测试变更
""")
class TestParseAuditFile:
    """parse_audit_file: turn one audit Markdown file into an AuditEntry."""

    def test_standard_file(self, tmp_path):
        """A fully standard file populates every AuditEntry field."""
        f = tmp_path / "2026-03-01__test-change.md"
        f.write_text(_STANDARD_AUDIT, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.date == "2026-03-01"
        assert entry.slug == "test-change"
        assert entry.title == "审计记录:测试变更"
        assert entry.filename == "2026-03-01__test-change.md"
        assert "api/client.py" in entry.changed_files
        assert "tasks/dws/foo.py" in entry.changed_files
        assert "API 层" in entry.modules
        assert "DWS 层" in entry.modules
        assert entry.risk_level == "极低"

    def test_missing_title_uses_slug(self, tmp_path):
        """Without an H1 heading the slug is used as the title fallback."""
        content = "没有标题的文件\n\n一些内容\n"
        f = tmp_path / "2026-01-01__no-title.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "no-title"

    def test_missing_file_list_section(self, tmp_path):
        """Missing file-list section -> empty list, modules fall back to {"其他"}."""
        content = "# 标题\n\n没有文件清单\n"
        f = tmp_path / "2026-01-01__no-files.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.changed_files == []
        assert entry.modules == {"其他"}

    def test_invalid_filename_returns_none(self, tmp_path):
        """A filename not matching YYYY-MM-DD__slug.md yields None."""
        f = tmp_path / "invalid-name.md"
        f.write_text("# Title\n", encoding="utf-8")
        assert parse_audit_file(f) is None

    def test_non_md_gitkeep_returns_none(self, tmp_path):
        """Bookkeeping files such as .gitkeep are rejected."""
        f = tmp_path / ".gitkeep"
        f.write_text("", encoding="utf-8")
        assert parse_audit_file(f) is None

    def test_list_format_file_section(self, tmp_path):
        """Bullet-list style file sections are parsed as well as tables."""
        content = textwrap.dedent("""\
            # 测试
            ## 文件清单(Files changed)
            - docs/api-reference/summary/foo.md
            - scripts/gen_api_docs.py
            - tasks/base_task.py(补 AI_CHANGELOG)
            """)
        f = tmp_path / "2026-02-01__list-format.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert len(entry.changed_files) == 3

    def test_risk_from_metadata_header(self, tmp_path):
        """The risk level is extracted from the metadata header line."""
        content = textwrap.dedent("""\
            # 测试
            - 日期:2026-01-01
            - 风险等级:低(纯文档重组)
            ## 直接原因
            测试
            """)
        f = tmp_path / "2026-01-01__meta-risk.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        # BUG FIX: this previously asserted risk_level == "", which contradicts
        # both this test's stated purpose (extract the level from the metadata
        # line) and test_missing_risk_section_returns_unknown (absent risk
        # metadata yields "未知", never ""). The metadata line declares "低".
        # TODO(review): confirm against parse_audit_file's extraction rules.
        assert entry.risk_level == "低"

    def test_change_type_bugfix(self, tmp_path):
        """The 'bugfix' keyword drives change_type inference."""
        content = "# bugfix 修复\n\n修复了一个 bug\n"
        f = tmp_path / "2026-01-01__fix.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry.change_type == "bugfix"

    def test_change_type_doc(self, tmp_path):
        """Documentation-only wording infers the 文档 change type."""
        content = "# 纯文档变更\n\n无逻辑改动\n"
        f = tmp_path / "2026-01-01__doc.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry.change_type == "文档"

    def test_arrow_path_extraction(self, tmp_path):
        """Move lines containing → contribute both the source and target paths."""
        content = textwrap.dedent("""\
            # 测试
            ## 变更摘要
            ### 文件移动
            - `docs/index/algo.md` → `docs/database/DWS/algo.md`
            """)
        f = tmp_path / "2026-01-01__arrow.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert "docs/index/algo.md" in entry.changed_files
        assert "docs/database/DWS/algo.md" in entry.changed_files
# ---------------------------------------------------------------------------
# scan_audit_dir
# ---------------------------------------------------------------------------
class TestScanAuditDir:
def test_empty_dir(self, tmp_path):
assert scan_audit_dir(tmp_path) == []
def test_nonexistent_dir(self):
assert scan_audit_dir("nonexistent_dir_xyz") == []
def test_sorts_by_date_descending(self, tmp_path):
for name in [
"2026-01-01__first.md",
"2026-03-01__third.md",
"2026-02-01__second.md",
]:
(tmp_path / name).write_text("# Title\n", encoding="utf-8")
entries = scan_audit_dir(tmp_path)
dates = [e.date for e in entries]
assert dates == ["2026-03-01", "2026-02-01", "2026-01-01"]
def test_skips_non_md_files(self, tmp_path):
(tmp_path / "2026-01-01__valid.md").write_text("# OK\n", encoding="utf-8")
(tmp_path / ".gitkeep").write_text("", encoding="utf-8")
(tmp_path / "notes.txt").write_text("text", encoding="utf-8")
entries = scan_audit_dir(tmp_path)
assert len(entries) == 1
def test_skips_invalid_filenames(self, tmp_path):
(tmp_path / "2026-01-01__valid.md").write_text("# OK\n", encoding="utf-8")
(tmp_path / "bad-name.md").write_text("# Bad\n", encoding="utf-8")
entries = scan_audit_dir(tmp_path)
assert len(entries) == 1
assert entries[0].slug == "valid"
# ---------------------------------------------------------------------------
# 真实审计文件集成测试(仅在项目目录中运行时有效)
# ---------------------------------------------------------------------------
# Real audit files live here; these checks only run from the project root.
_REAL_AUDIT_DIR = Path("docs/audit/changes")


@pytest.mark.skipif(
    not _REAL_AUDIT_DIR.is_dir(),
    reason="真实审计目录不存在(非项目根目录运行)",
)
class TestRealAuditFiles:
    """Integration checks against the repository's real audit files."""

    def test_parses_all_real_files(self):
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        assert len(entries) > 0, "应至少解析出一条审计记录"

    def test_all_modules_valid(self):
        # Every module classified from a real file must be a known module name.
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        for e in entries:
            for m in e.modules:
                assert m in VALID_MODULES, (
                    f"模块 {m!r} 不在 VALID_MODULES 中 (文件: {e.filename})"
                )

    def test_dates_descending(self):
        # scan_audit_dir returns entries sorted newest-first.
        entries = scan_audit_dir(_REAL_AUDIT_DIR)
        dates = [e.date for e in entries]
        assert dates == sorted(dates, reverse=True)
# ---------------------------------------------------------------------------
# 属性测试 — Property 1: 审计记录解析-渲染完整性
# Feature: docs-optimization, Property 1: 审计记录解析-渲染完整性
# Validates: Requirements 2.1, 2.2
# ---------------------------------------------------------------------------
from hypothesis import given, settings, assume
from hypothesis import strategies as st
# --- Generation strategies ---

# Valid-date strategy: YYYY-MM-DD ISO strings.
# FIX: the original used __import__("datetime") even though this module already
# does `import datetime` at the top of the file — use the plain name.
_date_st = st.dates(
    min_value=datetime.date(2020, 1, 1),
    max_value=datetime.date(2030, 12, 31),
).map(lambda d: d.isoformat())

# Slug strategy: lowercase letter start, then letters/digits/hyphens, length 2-30.
_slug_st = st.from_regex(r"[a-z][a-z0-9\-]{1,29}", fullmatch=True)

# Title strategy: non-empty mixed Chinese/English text.
_title_st = st.text(
    alphabet=st.sampled_from(
        list("审计记录变更修复新增重构清理测试文档优化迁移合并")
        + list("abcdefghijklmnopqrstuvwxyz ")
    ),
    min_size=2,
    max_size=40,
).map(lambda s: s.strip() or "默认标题")

# File-path strategy: draw from known prefixes so module classification is meaningful.
_KNOWN_PREFIXES = [
    "api/", "tasks/ods/", "tasks/dwd/", "tasks/dws/", "tasks/index/",
    "loaders/", "database/migrations/", "orchestration/", "config/",
    "cli/", "models/", "scd/", "docs/", "scripts/", "tests/unit/",
    "quality/", "gui/", "utils/",
]
_file_path_st = st.tuples(
    st.sampled_from(_KNOWN_PREFIXES),
    st.from_regex(r"[a-z_]{1,15}\.(py|sql|md)", fullmatch=True),
).map(lambda t: t[0] + t[1])

# Risk-level strategy.
# FIX: the original sampled from ["极低", "", "", ""] — three duplicate empty
# strings, so most generated records carried a blank risk level. Use the
# standard level ladder instead. TODO(review): confirm against the committed
# original (the empty literals look like truncation).
_risk_st = st.sampled_from(["极低", "低", "中", "高"])

# Change-type keyword strategy (embedded in content so _infer_change_type can infer it).
_change_kw_st = st.sampled_from(["bugfix", "修复", "重构", "清理", "纯文档", "功能新增"])
def _build_audit_md(title: str, date: str, files: list[str], risk: str, change_kw: str) -> str:
"""根据参数构造一份格式合规的审计 Markdown 内容。"""
lines = [
f"# {title}",
"",
f"- 日期:{date}Asia/Shanghai",
f"- 风险等级:{risk}",
"",
"## 直接原因",
f"本次变更为 {change_kw} 类型操作",
"",
"## 修改文件清单",
"",
"| 文件 | 变更类型 | 说明 |",
"|------|----------|------|",
]
for fp in files:
lines.append(f"| `{fp}` | 修改 | 自动生成 |")
lines.append("")
lines.append("## 风险与回滚")
lines.append(f"- 风险:{risk}。自动生成的测试内容")
return "\n".join(lines)
class TestProperty1AuditParseCompleteness:
    """Property 1: audit-record parse/render completeness.

    For any format-compliant audit Markdown file, parse_audit_file must return
    an AuditEntry whose required fields (date, title, change_type, modules,
    filename) are all non-empty.

    **Validates: Requirements 2.1, 2.2**
    """

    @given(
        date=_date_st,
        slug=_slug_st,
        title=_title_st,
        files=st.lists(_file_path_st, min_size=1, max_size=8),
        risk=_risk_st,
        change_kw=_change_kw_st,
    )
    @settings(max_examples=150)
    def test_parsed_entry_has_all_required_fields(
        self, tmp_path_factory, date, slug, title, files, risk, change_kw
    ):
        """All required AuditEntry fields are non-empty after parsing."""
        # Build a temp file; tmp_path_factory is session-scoped and therefore
        # safe to combine with @given (function-scoped fixtures are not).
        tmp_dir = tmp_path_factory.mktemp("audit")
        filename = f"{date}__{slug}.md"
        md_content = _build_audit_md(title, date, files, risk, change_kw)
        filepath = tmp_dir / filename
        filepath.write_text(md_content, encoding="utf-8")
        entry = parse_audit_file(filepath)
        # BUG FIX: the failure message was a placeholder-less f-string ending
        # in the literal "(unknown)"; interpolate the actual path instead.
        assert entry is not None, f"格式合规的文件应能成功解析:{filepath}"
        assert entry.date, "date 字段不应为空"
        assert entry.title, "title 字段不应为空"
        assert entry.filename, "filename 字段不应为空"
        assert entry.change_type, "change_type 字段不应为空"
        assert len(entry.modules) > 0, "modules 集合不应为空"
        # The parsed date must match the one embedded in the filename.
        assert entry.date == date
        # filename must match the actual file name.
        assert entry.filename == filename
        # Every classified module must be a known module name.
        for mod in entry.modules:
            assert mod in VALID_MODULES, f"模块 {mod!r} 不在 VALID_MODULES 中"

    @given(
        date=_date_st,
        slug=_slug_st,
        title=_title_st,
        files=st.lists(_file_path_st, min_size=1, max_size=5),
        risk=_risk_st,
        change_kw=_change_kw_st,
    )
    @settings(max_examples=150)
    def test_parsed_files_match_input(
        self, tmp_path_factory, date, slug, title, files, risk, change_kw
    ):
        """changed_files contains every input file path."""
        tmp_dir = tmp_path_factory.mktemp("audit")
        filename = f"{date}__{slug}.md"
        md_content = _build_audit_md(title, date, files, risk, change_kw)
        filepath = tmp_dir / filename
        filepath.write_text(md_content, encoding="utf-8")
        entry = parse_audit_file(filepath)
        assert entry is not None
        # Every generated file path must appear in the parsed result.
        for fp in files:
            assert fp in entry.changed_files, (
                f"文件 {fp!r} 应出现在 changed_files 中,"
                f"实际结果:{entry.changed_files}"
            )
# ---------------------------------------------------------------------------
# 属性测试 — Property 2: 文件路径模块分类正确性
# Feature: docs-optimization, Property 2: 文件路径模块分类正确性
# Validates: Requirements 2.3
# ---------------------------------------------------------------------------
# --- File-path generation strategies ---

# Known-prefix paths: cover every prefix in MODULE_MAP.
_known_prefix_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,20}\.(py|sql|md|txt|json)", fullmatch=True),
).map(lambda t: t[0] + t[1])

# Fully random paths: arbitrary Unicode strings (empty, special chars, ...).
_random_path_st = st.text(min_size=0, max_size=200)

# Windows-style paths with backslashes.
_backslash_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,15}\.(py|md)", fullmatch=True),
).map(lambda t: t[0].replace("/", "\\") + t[1])

# Relative paths with a leading "./".
_dotslash_path_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.from_regex(r"[a-z_]{1,15}\.(py|md)", fullmatch=True),
).map(lambda t: "./" + t[0] + t[1])

# Deeply nested paths.
_deep_nested_st = st.tuples(
    st.sampled_from(list(MODULE_MAP.keys())),
    st.lists(
        st.from_regex(r"[a-z]{1,8}", fullmatch=True),
        min_size=1, max_size=5,
    ),
    st.from_regex(r"[a-z_]{1,10}\.(py|md|sql)", fullmatch=True),
).map(lambda t: t[0] + "/".join(t[1]) + "/" + t[2])

# Unknown-prefix paths (start with no known prefix).
# FIX: the original wrapped a single strategy in st.tuples(...).map(t[0]) —
# a pointless one-element tuple round-trip; sampled_from alone is equivalent.
_unknown_prefix_st = st.sampled_from([
    "README.md", ".kiro/foo.md", "setup.py", "Makefile",
    "pyproject.toml", ".env", "unknown/deep/path.py",
    "random_dir/file.txt", ".github/workflows/ci.yml",
])

# Mixed strategy: draw from all of the above.
_any_filepath_st = st.one_of(
    _known_prefix_path_st,
    _random_path_st,
    _backslash_path_st,
    _dotslash_path_st,
    _deep_nested_st,
    _unknown_prefix_st,
)
class TestProperty2ModuleClassification:
    """Property 2: file-path module classification correctness.

    For any file-path string, classify_module must return a value from the
    predefined VALID_MODULES set.

    **Validates: Requirements 2.3**
    """

    @given(filepath=_any_filepath_st)
    @settings(max_examples=200)
    def test_classify_always_returns_valid_module(self, filepath: str):
        """Any path classifies into VALID_MODULES."""
        result = classify_module(filepath)
        assert result in VALID_MODULES, (
            f"classify_module({filepath!r}) 返回 {result!r}"
            f"不在 VALID_MODULES 中"
        )

    @given(filepath=_known_prefix_path_st)
    @settings(max_examples=150)
    def test_known_prefix_never_returns_other(self, filepath: str):
        """Paths under a known prefix never classify as '其他'."""
        result = classify_module(filepath)
        assert result in VALID_MODULES, (
            f"classify_module({filepath!r}) 返回 {result!r}"
            f"不在 VALID_MODULES 中"
        )
        assert result != "其他", (
            f"已知前缀路径 {filepath!r} 不应分类为 '其他'"
            f"实际返回 {result!r}"
        )

    @given(filepath=_unknown_prefix_st)
    @settings(max_examples=50)
    def test_unknown_prefix_returns_other(self, filepath: str):
        """Paths matching no known prefix classify as '其他'."""
        result = classify_module(filepath)
        assert result == "其他", (
            f"未知前缀路径 {filepath!r} 应分类为 '其他'"
            f"实际返回 {result!r}"
        )

    @given(filepath=_backslash_path_st)
    @settings(max_examples=100)
    def test_backslash_paths_classified_correctly(self, filepath: str):
        """Windows backslash paths classify correctly (never '其他')."""
        result = classify_module(filepath)
        assert result in VALID_MODULES
        assert result != "其他", (
            f"反斜杠路径 {filepath!r} 应正确分类,"
            f"实际返回 '其他'"
        )

    @given(filepath=_dotslash_path_st)
    @settings(max_examples=100)
    def test_dotslash_paths_classified_correctly(self, filepath: str):
        """Leading ./ paths classify correctly (never '其他')."""
        result = classify_module(filepath)
        assert result in VALID_MODULES
        assert result != "其他", (
            f"前导 ./ 路径 {filepath!r} 应正确分类,"
            f"实际返回 '其他'"
        )
# ---------------------------------------------------------------------------
# 渲染函数测试
# ---------------------------------------------------------------------------
from scripts.gen_audit_dashboard import (
render_timeline_table,
render_module_index,
render_dashboard,
)
def _make_entry(**overrides) -> AuditEntry:
    """Build a test AuditEntry; keyword arguments override the defaults."""
    fields = {
        "date": "2026-01-15",
        "slug": "test-change",
        "title": "测试变更",
        "filename": "2026-01-15__test-change.md",
        "changed_files": ["api/client.py"],
        "modules": {"API 层"},
        "risk_level": "",
        "change_type": "功能",
    }
    fields.update(overrides)
    return AuditEntry(**fields)
class TestRenderTimelineTable:
    """render_timeline_table unit tests."""

    def test_empty_entries(self):
        """An empty list renders the 'no audit records' placeholder."""
        result = render_timeline_table([])
        assert "暂无审计记录" in result

    def test_single_entry(self):
        """A single entry produces a correct table row."""
        entry = _make_entry()
        result = render_timeline_table([entry])
        assert "| 日期 |" in result
        assert "2026-01-15" in result
        assert "测试变更" in result
        assert "API 层" in result
        # NOTE(review): '"" in result' is trivially always true — the literal
        # looks truncated (probably a risk-level value such as "低"); confirm
        # the intended string against the committed original.
        assert "" in result
        assert "[链接](changes/2026-01-15__test-change.md)" in result

    def test_multiple_entries_order_preserved(self):
        """Rows keep input order (the caller is responsible for sorting)."""
        # NOTE(review): both titles are empty literals — likely truncated
        # single-character titles; the test still works via the dates.
        e1 = _make_entry(date="2026-02-01", title="")
        e2 = _make_entry(date="2026-01-01", title="")
        result = render_timeline_table([e1, e2])
        pos_e1 = result.index("2026-02-01")
        pos_e2 = result.index("2026-01-01")
        assert pos_e1 < pos_e2

    def test_multiple_modules_joined(self):
        """Multiple modules are sorted and comma-joined in one cell."""
        entry = _make_entry(modules={"文档", "API 层"})
        result = render_timeline_table([entry])
        assert "API 层, 文档" in result

    def test_table_header_present(self):
        """The table includes the header row and the separator row."""
        result = render_timeline_table([_make_entry()])
        assert "|------|" in result
        assert "需求摘要" in result
        assert "变更类型" in result
class TestRenderModuleIndex:
    """render_module_index unit tests."""

    def test_empty_entries(self):
        """An empty list renders the 'no audit records' placeholder."""
        result = render_module_index([])
        assert "暂无审计记录" in result

    def test_single_module(self):
        """A single module renders one H3 section."""
        entry = _make_entry(modules={"API 层"})
        result = render_module_index([entry])
        assert "### API 层" in result
        assert "2026-01-15" in result
        # The module-index table does not carry an "影响模块" column.
        lines = result.strip().splitlines()
        header = [l for l in lines if l.startswith("| 日期")]
        assert header
        assert "影响模块" not in header[0]

    def test_multiple_modules_sorted(self):
        """Module sections are ordered alphabetically."""
        e1 = _make_entry(modules={"文档"})
        e2 = _make_entry(modules={"API 层"}, date="2026-02-01")
        result = render_module_index([e1, e2])
        pos_api = result.index("### API 层")
        pos_doc = result.index("### 文档")
        assert pos_api < pos_doc

    def test_entry_appears_in_multiple_modules(self):
        """An entry touching several modules appears in each module section."""
        entry = _make_entry(modules={"API 层", "文档"})
        result = render_module_index([entry])
        assert "### API 层" in result
        assert "### 文档" in result
        # Both sections contain a link to the same record.
        assert result.count("[链接](changes/2026-01-15__test-change.md)") == 2

    def test_link_format(self):
        """The details column links to changes/<filename>."""
        entry = _make_entry(filename="2026-03-01__my-slug.md")
        result = render_module_index([entry])
        assert "[链接](changes/2026-03-01__my-slug.md)" in result
class TestRenderDashboard:
    """render_dashboard unit tests."""

    def test_empty_entries(self):
        """An empty list still yields a complete document with a placeholder."""
        result = render_dashboard([])
        assert "# 审计一览表" in result
        assert "自动生成于" in result
        assert "暂无审计记录" in result

    def test_contains_both_views(self):
        """The full document contains both the timeline and the module index."""
        entry = _make_entry()
        result = render_dashboard([entry])
        assert "## 时间线视图" in result
        assert "## 模块索引" in result

    def test_contains_header_and_timestamp(self):
        """The document carries a title, a generation timestamp and an edit warning."""
        result = render_dashboard([_make_entry()])
        assert "# 审计一览表" in result
        assert "自动生成于" in result
        assert "请勿手动编辑" in result

    def test_timeline_before_module_index(self):
        """The timeline view precedes the module index."""
        result = render_dashboard([_make_entry()])
        pos_timeline = result.index("## 时间线视图")
        pos_module = result.index("## 模块索引")
        assert pos_timeline < pos_module
# ---------------------------------------------------------------------------
# Property 3: 审计条目时间倒序排列
# Feature: docs-optimization, Property 3: 审计条目时间倒序排列
# ---------------------------------------------------------------------------
class TestProperty3AuditEntriesDescendingOrder:
    """Property test: sorted audit entries have non-increasing dates.

    **Validates: Requirements 2.4**
    """

    @given(
        dates=st.lists(
            st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ),
            min_size=0,
            max_size=30,
        )
    )
    @settings(max_examples=150)
    def test_sorted_entries_dates_non_increasing(self, dates):
        """AuditEntry lists built from arbitrary dates, sorted with the same
        logic as scan_audit_dir, yield a non-increasing date sequence
        (each entry's date >= every later entry's date)."""
        # Dates use the same format as scan_audit_dir (YYYY-MM-DD ISO strings,
        # so lexicographic order equals chronological order).
        entries = [
            AuditEntry(
                date=d.isoformat(),
                slug=f"entry-{i}",
                title=f"条目 {i}",
                filename=f"{d.isoformat()}__entry-{i}.md",
            )
            for i, d in enumerate(dates)
        ]
        # Exactly the sort logic used by scan_audit_dir.
        entries.sort(key=lambda e: e.date, reverse=True)
        # Verify non-increasing order.
        for i in range(len(entries) - 1):
            assert entries[i].date >= entries[i + 1].date, (
                f"位置 {i} 的日期 {entries[i].date} "
                f"不应小于位置 {i+1} 的日期 {entries[i+1].date}"
            )

    @given(
        dates=st.lists(
            st.dates(
                min_value=datetime.date(2020, 1, 1),
                max_value=datetime.date(2030, 12, 31),
            ),
            min_size=2,
            max_size=30,
        )
    )
    @settings(max_examples=150)
    def test_first_entry_has_latest_date(self, dates):
        """After sorting, the first entry carries the maximum input date."""
        entries = [
            AuditEntry(
                date=d.isoformat(),
                slug=f"entry-{i}",
                title=f"条目 {i}",
                filename=f"{d.isoformat()}__entry-{i}.md",
            )
            for i, d in enumerate(dates)
        ]
        entries.sort(key=lambda e: e.date, reverse=True)
        expected_max = max(d.isoformat() for d in dates)
        assert entries[0].date == expected_max
# ---------------------------------------------------------------------------
# 补充边界情况测试 — Task 4.6
# 覆盖:空内容、缺少风险章节、变更类型推断分支、同日期排序
# Requirements: 2.1, 2.3
# ---------------------------------------------------------------------------
class TestParseAuditFileEdgeCases:
    """Additional edge cases for parse_audit_file."""

    def test_empty_content(self, tmp_path):
        """Empty file -> title falls back to slug, no files, modules {"其他"}."""
        f = tmp_path / "2026-01-01__empty.md"
        f.write_text("", encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "empty"
        assert entry.changed_files == []
        assert entry.modules == {"其他"}

    def test_whitespace_only_content(self, tmp_path):
        """Whitespace-only content is treated like an empty file."""
        f = tmp_path / "2026-01-01__blank.md"
        f.write_text(" \n\n \n", encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.title == "blank"
        assert entry.changed_files == []

    def test_missing_risk_section_returns_unknown(self, tmp_path):
        """No risk-related content anywhere -> risk level is "未知"."""
        content = "# 简单标题\n\n一些普通内容,没有提到任何等级信息。\n"
        f = tmp_path / "2026-01-01__no-risk.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.risk_level == "未知"

    def test_change_type_refactor(self, tmp_path):
        """The 重构 keyword infers change type 重构."""
        content = "# 代码重构\n\n对模块进行了重构优化\n"
        f = tmp_path / "2026-01-01__refactor.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry.change_type == "重构"

    def test_change_type_cleanup(self, tmp_path):
        """The 清理 keyword infers change type 清理."""
        content = "# 遗留代码清理\n\n清理了废弃文件\n"
        f = tmp_path / "2026-01-01__cleanup.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry.change_type == "清理"

    def test_change_type_default_function(self, tmp_path):
        """With no change-type keyword, the default is 功能."""
        content = "# 新增能力\n\n增加了一个全新的处理流程\n"
        f = tmp_path / "2026-01-01__feature.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry.change_type == "功能"

    def test_file_section_with_empty_table(self, tmp_path):
        """A file-list section whose table has no data rows -> empty list."""
        content = textwrap.dedent("""\
            # 测试
            ## 修改文件清单
            | 文件 | 变更类型 | 说明 |
            |------|----------|------|
            ## 风险与回滚
            """)
        f = tmp_path / "2026-01-01__empty-table.md"
        f.write_text(content, encoding="utf-8")
        entry = parse_audit_file(f)
        assert entry is not None
        assert entry.changed_files == []
        assert entry.modules == {"其他"}
class TestScanAuditDirEdgeCases:
    """Additional edge cases for scan_audit_dir."""

    def test_dir_with_only_invalid_files(self, tmp_path):
        """A directory containing only invalid files yields an empty list."""
        (tmp_path / "README.md").write_text("# 说明\n", encoding="utf-8")
        (tmp_path / ".gitkeep").write_text("", encoding="utf-8")
        (tmp_path / "notes.txt").write_text("备注", encoding="utf-8")
        entries = scan_audit_dir(tmp_path)
        assert entries == []

    def test_same_date_multiple_files(self, tmp_path):
        """Several files sharing one date are all parsed, each with that date."""
        for slug in ["alpha", "beta", "gamma"]:
            f = tmp_path / f"2026-03-01__{slug}.md"
            f.write_text(f"# {slug} 变更\n", encoding="utf-8")
        entries = scan_audit_dir(tmp_path)
        assert len(entries) == 3
        assert all(e.date == "2026-03-01" for e in entries)

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Unit tests for the new ODS ingestion tasks."""
import logging
import os
import sys
from pathlib import Path
# 确保在独立运行测试时能正确解析项目根目录
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
from tasks.ods.ods_tasks import ODS_TASK_CLASSES
from .task_test_utils import create_test_config, get_db_operations, FakeAPIClient
def _build_config(tmp_path):
    """Build an ONLINE-mode test AppConfig with archive/temp dirs under tmp_path."""
    return create_test_config(
        "ONLINE",
        tmp_path / "archive",
        tmp_path / "temp",
    )
def test_assistant_accounts_masters_ingest(tmp_path):
    """Ensure ODS_ASSISTANT_ACCOUNT stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    sample = [
        {
            "id": 5001,
            "assistant_no": "A01",
            "nickname": "小张",
        }
    ]
    api = FakeAPIClient({"/PersonnelManagement/SearchAssistantInfo": sample})
    task_cls = ODS_TASK_CLASSES["ODS_ASSISTANT_ACCOUNT"]
    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, api, logging.getLogger("test_assistant_accounts_masters"))
        result = task.execute()
        assert result["status"] == "SUCCESS"
        assert result["counts"]["fetched"] == 1
        # Exactly one transaction commit for the batch.
        assert db_ops.commits == 1
        row = db_ops.upserts[0]["rows"][0]
        assert row["id"] == 5001
        # record_index is the in-batch dedup key; the first record gets 0.
        assert row["record_index"] == 0
        # Online mode carries no source archive file (None or empty string).
        assert row["source_file"] is None or row["source_file"] == ""
        # The raw API record is stored as a JSON string payload.
        assert '"id": 5001' in row["payload"]
def test_goods_stock_movements_ingest(tmp_path):
    """Ensure ODS_INVENTORY_CHANGE stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    sample = [
        {
            "siteGoodsStockId": 123456,
            "stockType": 1,
            "goodsName": "测试商品",
        }
    ]
    api = FakeAPIClient({"/GoodsStockManage/QueryGoodsOutboundReceipt": sample})
    task_cls = ODS_TASK_CLASSES["ODS_INVENTORY_CHANGE"]
    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, api, logging.getLogger("test_goods_stock_movements"))
        result = task.execute()
        assert result["status"] == "SUCCESS"
        assert result["counts"]["fetched"] == 1
        # Exactly one transaction commit for the batch.
        assert db_ops.commits == 1
        row = db_ops.upserts[0]["rows"][0]
        # Row keys are lower-cased ("sitegoodsstockid") while the JSON payload
        # keeps the API's original camelCase key.
        assert row["sitegoodsstockid"] == 123456
        assert row["record_index"] == 0
        assert '"siteGoodsStockId": 123456' in row["payload"]
def test_member_profiless_ingest(tmp_path):
    """Ensure ODS_MEMBER task stores tenantMemberInfos raw JSON."""
    config = _build_config(tmp_path)
    payload = [{"tenantMemberInfos": [{"id": 101, "mobile": "13800000000"}]}]
    fake_api = FakeAPIClient({"/MemberProfile/GetTenantMemberList": payload})
    member_task_cls = ODS_TASK_CLASSES["ODS_MEMBER"]
    with get_db_operations() as db_ops:
        log = logging.getLogger("test_ods_member")
        outcome = member_task_cls(config, db_ops, fake_api, log).execute()
        assert outcome["status"] == "SUCCESS"
        first_row = db_ops.upserts[0]["rows"][0]
        assert first_row["record_index"] == 0
        assert '"id": 101' in first_row["payload"]
def test_ods_payment_ingest(tmp_path):
    """Ensure ODS_PAYMENT task stores payment_transactions raw JSON."""
    config = _build_config(tmp_path)
    payments = [{"payId": 901, "payAmount": "100.00"}]
    fake_api = FakeAPIClient({"/PayLog/GetPayLogListPage": payments})
    payment_task_cls = ODS_TASK_CLASSES["ODS_PAYMENT"]
    with get_db_operations() as db_ops:
        log = logging.getLogger("test_ods_payment")
        outcome = payment_task_cls(config, db_ops, fake_api, log).execute()
        assert outcome["status"] == "SUCCESS"
        first_row = db_ops.upserts[0]["rows"][0]
        assert first_row["record_index"] == 0
        assert '"payId": 901' in first_row["payload"]
def test_ods_settlement_records_ingest(tmp_path):
    """Ensure ODS_SETTLEMENT_RECORDS stores settleList raw JSON."""
    config = _build_config(tmp_path)
    settlements = [{"id": 701, "orderTradeNo": 8001}]
    fake_api = FakeAPIClient({"/Site/GetAllOrderSettleList": settlements})
    settle_task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_RECORDS"]
    with get_db_operations() as db_ops:
        log = logging.getLogger("test_settlement_records")
        outcome = settle_task_cls(config, db_ops, fake_api, log).execute()
        assert outcome["status"] == "SUCCESS"
        first_row = db_ops.upserts[0]["rows"][0]
        assert first_row["record_index"] == 0
        assert '"orderTradeNo": 8001' in first_row["payload"]
def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
    """Ensure settlement tickets are fetched per payment relate_id and skip existing ones."""
    config = _build_config(tmp_path)
    ticket_payload = {"data": {"data": {"orderSettleId": 9001, "orderSettleNumber": "T001"}}}
    api = FakeAPIClient({"/Order/GetOrderSettleTicketNew": [ticket_payload]})
    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"]
    with get_db_operations() as db_ops:
        # First query: ticket IDs already present; second query: payment-related IDs.
        # 9002 already exists and None must be ignored, so only 9001 is fetched.
        db_ops.query_results = [
            [{"order_settle_id": 9002}],
            [
                {"order_settle_id": 9001},
                {"order_settle_id": 9002},
                {"order_settle_id": None},
            ],
        ]
        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_settlement_ticket"))
        result = task.execute()
    assert result["status"] == "SUCCESS"
    counts = result["counts"]
    assert counts["fetched"] == 1
    assert counts["inserted"] == 1
    assert counts["updated"] == 0
    assert counts["skipped"] == 0
    assert '"orderSettleId": 9001' in db_ops.upserts[0]["rows"][0]["payload"]
    # The ticket endpoint must have been called for the one missing ID only.
    assert any(
        call["endpoint"] == "/Order/GetOrderSettleTicketNew"
        and call.get("params", {}).get("orderSettleId") == 9001
        for call in api.calls
    )

View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""解析器测试"""
import pytest
from decimal import Decimal
from datetime import datetime
from zoneinfo import ZoneInfo
from models.parsers import TypeParser
def test_parse_decimal():
    """Amount parsing: rounds to the requested precision, tolerates bad input."""
    assert TypeParser.parse_decimal(None) is None
    assert TypeParser.parse_decimal("invalid") is None
    assert TypeParser.parse_decimal("100.555", 2) == Decimal("100.56")
def test_parse_int():
    """Integer parsing: accepts str and int, returns None for junk."""
    cases = [("123", 123), (456, 456), (None, None), ("abc", None)]
    for raw, expected in cases:
        assert TypeParser.parse_int(raw) == expected
def test_parse_timestamp():
    """Timestamp parsing honours the provided timezone."""
    parsed = TypeParser.parse_timestamp("2025-01-15 10:30:00", ZoneInfo("Asia/Shanghai"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (2025, 1, 15)
def test_parse_timestamp_zero_epoch():
    """0 must not be treated as empty; it should parse to the Unix epoch."""
    # Epoch 0 in Asia/Shanghai is 1970-01-01 08:00, so the date is still day 1.
    parsed = TypeParser.parse_timestamp(0, ZoneInfo("Asia/Shanghai"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (1970, 1, 1)

View File

@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-
"""PipelineRunner 属性测试 - hypothesis 验证管道编排器的通用正确性属性。"""
import string
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from orchestration.pipeline_runner import PipelineRunner
# run() imports TaskLogger lazily, so patches must target the source module path.
_TASK_LOGGER_PATH = "utils.task_logger.TaskLogger"
FILE_VERSION = "v1_shell"
# ── Strategy definitions ──────────────────────────────────────────
pipeline_name_st = st.sampled_from(list(PipelineRunner.PIPELINE_LAYERS.keys()))
processing_mode_st = st.sampled_from(["increment_only", "verify_only", "increment_verify"])
data_source_st = st.sampled_from(["online", "offline", "hybrid"])
_TASK_PREFIXES = ["ODS_", "DWD_", "DWS_", "INDEX_"]
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_TASK_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=12,
    ),
)
# Generator for a single task-result dictionary.
task_result_st = st.fixed_dictionaries({
    "task_code": task_code_st,
    "status": st.sampled_from(["SUCCESS", "FAIL", "SKIP"]),
    "counts": st.fixed_dictionaries({
        "fetched": st.integers(min_value=0, max_value=10000),
        "inserted": st.integers(min_value=0, max_value=10000),
        "updated": st.integers(min_value=0, max_value=10000),
        "skipped": st.integers(min_value=0, max_value=10000),
        "errors": st.integers(min_value=0, max_value=100),
    }),
    "dump_dir": st.none(),
})
task_results_st = st.lists(task_result_st, min_size=0, max_size=10)
# ── 辅助函数 ──────────────────────────────────────────────────────
def _make_config():
"""创建 mock 配置对象。"""
config = MagicMock()
config.get = MagicMock(side_effect=lambda key, default=None: {
"app.timezone": "Asia/Shanghai",
"verification.ods_use_local_json": False,
"verification.skip_ods_when_fetch_before_verify": True,
"run.ods_tasks": [],
"run.dws_tasks": [],
"run.index_tasks": [],
}.get(key, default))
return config
def _make_runner(task_executor=None, task_registry=None):
    """Create a PipelineRunner instance with mock dependencies injected."""
    executor = task_executor
    if executor is None:
        executor = MagicMock()
        executor.run_tasks.return_value = []
    registry = task_registry
    if registry is None:
        registry = MagicMock()
        registry.get_tasks_by_layer.return_value = ["FAKE_TASK"]
    return PipelineRunner(
        config=_make_config(),
        task_executor=executor,
        task_registry=registry,
        db_conn=MagicMock(),
        api_client=MagicMock(),
        logger=MagicMock(),
    )
# ── Property 5: pipeline name → layer list mapping ───────────────
# Feature: scheduler-refactor, Property 5: pipeline name → layer-list mapping
# **Validates: Requirements 2.1**
class TestProperty5PipelineNameToLayers:
    """For any valid pipeline name, the layer list PipelineRunner resolves
    must exactly match the definition in the PIPELINE_LAYERS dict."""
    @given(pipeline=pipeline_name_st)
    @settings(max_examples=100)
    def test_layers_match_pipeline_definition(self, pipeline):
        """The ``layers`` field returned by run() equals PIPELINE_LAYERS[pipeline]."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        expected_layers = PipelineRunner.PIPELINE_LAYERS[pipeline]
        assert result["layers"] == expected_layers
    @given(pipeline=pipeline_name_st)
    @settings(max_examples=100)
    def test_resolve_tasks_called_with_correct_layers(self, pipeline):
        """_resolve_tasks receives exactly the layer list defined in PIPELINE_LAYERS."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_resolve_tasks", wraps=runner._resolve_tasks) as spy,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        expected_layers = PipelineRunner.PIPELINE_LAYERS[pipeline]
        spy.assert_called_once_with(expected_layers)
# ── Property 6: processing_mode controls the execution flow ──────
# Feature: scheduler-refactor, Property 6: processing_mode controls execution flow
# **Validates: Requirements 2.3, 2.4**
class TestProperty6ProcessingModeControlsFlow:
    """For any processing_mode, incremental ETL runs iff the mode contains
    'increment', and verification runs iff the mode contains 'verify'."""
    @given(
        pipeline=pipeline_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_increment_executes_iff_mode_contains_increment(self, pipeline, mode, data_source):
        """Incremental ETL (task_executor.run_tasks) runs iff mode contains 'increment'."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}),
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )
        should_increment = "increment" in mode
        if should_increment:
            assert executor.run_tasks.called, (
                f"mode={mode} 包含 'increment',但 run_tasks 未被调用"
            )
        else:
            # verify_only with fetch_before_verify=False (default): run_tasks must not run.
            assert not executor.run_tasks.called, (
                f"mode={mode} 不包含 'increment',但 run_tasks 被调用了"
            )
    @given(
        pipeline=pipeline_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_verification_executes_iff_mode_contains_verify(self, pipeline, mode, data_source):
        """Verification (_run_verification) runs iff mode contains 'verify'."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}) as mock_verify,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )
        should_verify = "verify" in mode
        if should_verify:
            assert mock_verify.called, (
                f"mode={mode} 包含 'verify',但 _run_verification 未被调用"
            )
        else:
            assert not mock_verify.called, (
                f"mode={mode} 不包含 'verify',但 _run_verification 被调用了"
            )
# ── Property 7: pipeline summary completeness ────────────────────
# Feature: scheduler-refactor, Property 7: pipeline result summary completeness
# **Validates: Requirements 2.6**
class TestProperty7PipelineSummaryCompleteness:
    """For any set of task results, the summary dict returned by PipelineRunner
    must contain status/pipeline/layers/results, and the length of ``results``
    must equal the number of tasks actually executed."""
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_summary_has_required_fields(self, pipeline, task_results):
        """The returned dict must contain status, pipeline, layers, results, verification_summary."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        required_keys = {"status", "pipeline", "layers", "results", "verification_summary"}
        assert required_keys.issubset(result.keys()), (
            f"缺少必要字段: {required_keys - result.keys()}"
        )
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_results_length_equals_executed_tasks(self, pipeline, task_results):
        """The results list length equals the number of tasks run_tasks returned."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        assert len(result["results"]) == len(task_results), (
            f"results 长度 {len(result['results'])} != 实际任务数 {len(task_results)}"
        )
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_pipeline_and_layers_match_input(self, pipeline, task_results):
        """The returned pipeline and layers fields match the inputs."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        assert result["pipeline"] == pipeline
        assert result["layers"] == PipelineRunner.PIPELINE_LAYERS[pipeline]
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_increment_only_has_no_verification(self, pipeline, task_results):
        """In increment_only mode, verification_summary must be None."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        assert result["verification_summary"] is None

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""关系指数基础能力单测。"""
from __future__ import annotations
import logging
from datetime import date
from typing import Any, Dict, List, Optional
from tasks.dws.index.base_index_task import BaseIndexTask
from tasks.dws.index.ml_manual_import_task import MlManualImportTask
class _DummyConfig:
"""最小配置桩对象。"""
def __init__(self, values: Optional[Dict[str, Any]] = None):
self._values = values or {}
def get(self, key: str, default: Any = None) -> Any:
return self._values.get(key, default)
class _DummyDB:
"""最小数据库桩对象。"""
def __init__(self) -> None:
self.query_calls: List[tuple] = []
def query(self, sql: str, params=None):
self.query_calls.append((sql, params))
index_type = (params or [None])[0]
if index_type == "RS":
return [{"param_name": "lookback_days", "param_value": 60}]
if index_type == "MS":
return [{"param_name": "lookback_days", "param_value": 30}]
return []
class _DummyIndexTask(BaseIndexTask):
    """Minimal concrete BaseIndexTask used to exercise the base-class helpers."""
    # Default index type; individual calls may override via keyword argument.
    INDEX_TYPE = "RS"
    def get_task_code(self) -> str:  # pragma: no cover - test stub
        return "DUMMY_INDEX"
    def get_target_table(self) -> str:  # pragma: no cover - test stub
        return "dummy_table"
    def get_primary_keys(self) -> List[str]:  # pragma: no cover - test stub
        return ["id"]
    def get_index_type(self) -> str:
        return self.INDEX_TYPE
    def extract(self, context):  # pragma: no cover - test stub
        return []
    def load(self, transformed, context):  # pragma: no cover - test stub
        return {}
def test_load_index_parameters_cache_isolated_by_index_type():
    """The parameter cache must be keyed by index_type to avoid cross-talk."""
    db = _DummyDB()
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        db,
        None,
        logging.getLogger("test_index_cache"),
    )
    first_rs = task.load_index_parameters(index_type="RS")
    first_ms = task.load_index_parameters(index_type="MS")
    second_rs = task.load_index_parameters(index_type="RS")
    assert first_rs["lookback_days"] == 60.0
    assert first_ms["lookback_days"] == 30.0
    assert second_rs["lookback_days"] == 60.0
    # Exactly two DB round-trips: one per index_type; the repeated RS hits the cache.
    assert len(task.db.query_calls) == 2
def test_batch_normalize_passes_index_type_to_smoothing_chain():
    """batch_normalize_to_display must forward index_type into the smoothing chain."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_smoothing"),
    )
    seen: Dict[str, Any] = {}
    def _spy_smoothing(site_id, current_p5, current_p95, alpha=None, index_type=None):
        # Capture the forwarded index_type; pass the percentiles through untouched.
        seen["index_type"] = index_type
        return current_p5, current_p95
    task._apply_ewma_smoothing = _spy_smoothing  # type: ignore[method-assign]
    display = task.batch_normalize_to_display(
        raw_scores=[("a", 1.0), ("b", 2.0), ("c", 3.0)],
        use_smoothing=True,
        site_id=1,
        index_type="ML",
    )
    assert display
    assert seen["index_type"] == "ML"
def test_ml_manual_import_scope_day_and_p30_boundary():
    """Inside the 30-day window scopes cover single days; beyond it they fall
    into the fixed-epoch 30-day bucket."""
    today = date(2026, 2, 8)
    # Exactly 30 days back from today: still a DAY scope covering just that day.
    within = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 9),
        today=today,
    )
    assert within.scope_type == "DAY"
    assert (within.start_date, within.end_date) == (date(2026, 1, 9), date(2026, 1, 9))
    # 31 days back: falls into the P30 bucket anchored at the fixed 2026-01-01
    # epoch, whose first bucket spans 2026-01-01 .. 2026-01-30.
    beyond = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 8),
        today=today,
    )
    assert beyond.scope_type == "P30"
    assert (beyond.start_date, beyond.end_date) == (date(2026, 1, 1), date(2026, 1, 30))

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""汇总与报告工具的单测。"""
from utils.reporting import summarize_counts, format_report
def test_summarize_counts_and_format():
    """summarize_counts aggregates per-task counters; format_report renders them."""
    def _counts(fetched, inserted, updated):
        return {"fetched": fetched, "inserted": inserted, "updated": updated,
                "skipped": 0, "errors": 0}
    task_results = [
        {"task_code": "ORDERS", "counts": _counts(2, 2, 0)},
        {"task_code": "PAYMENTS", "counts": _counts(3, 2, 1)},
    ]
    summary = summarize_counts(task_results)
    totals = summary["total"]
    assert totals["fetched"] == 5
    assert totals["inserted"] == 4
    assert totals["updated"] == 1
    assert totals["errors"] == 0
    assert len(summary["details"]) == 2
    report = format_report(summary)
    for fragment in ("TOTAL fetched=5", "ORDERS:", "PAYMENTS:"):
        assert fragment in report

View File

@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
"""TaskExecutor 属性测试 - hypothesis 验证执行器的通用正确性属性。"""
import re
import string
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from orchestration.task_executor import TaskExecutor, DataSource
from orchestration.task_registry import TaskRegistry
FILE_VERSION = "v4_shell"
# Strategy: one of the three supported data-source modes.
data_source_st = st.sampled_from(["online", "offline", "hybrid"])
# Prefixes chosen so generated codes never look like ODS tasks.
_NON_ODS_PREFIXES = ["DWD_", "DWS_", "TASK_", "ETL_", "TEST_"]
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_NON_ODS_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=15,
    ),
)
# Strategy: arbitrary window-start datetimes across a decade.
window_start_st = st.datetimes(min_value=datetime(2020, 1, 1), max_value=datetime(2030, 12, 31))
def _make_fake_class(name="FakeTask"):
return type(name, (), {"__init__": lambda self, *a, **kw: None})
def _make_config():
config = MagicMock()
config.get = MagicMock(side_effect=lambda key, default=None: {
"app.timezone": "Asia/Shanghai",
"io.fetch_root": "/tmp/fetch",
"io.ingest_source_dir": "/tmp/ingest",
"io.write_pretty_json": False,
"pipeline.fetch_root": None,
"pipeline.ingest_source_dir": None,
"integrity.auto_check": False,
"run.overlap_seconds": 600,
}.get(key, default))
config.__getitem__ = MagicMock(side_effect=lambda k: {
"io": {"export_root": "/tmp/export", "log_root": "/tmp/log"},
}[k])
return config
def _make_executor(registry):
    """Build a TaskExecutor around *registry* with every other dependency mocked."""
    return TaskExecutor(
        config=_make_config(),
        db_ops=MagicMock(),
        api_client=MagicMock(),
        cursor_mgr=MagicMock(),
        run_tracker=MagicMock(),
        task_registry=registry,
        logger=MagicMock(),
    )
# Feature: scheduler-refactor, Property 1: data_source determines the execution path
# **Validates: Requirements 1.2**
class TestProperty1DataSourceDeterminesPath:
    """For every data_source value, fetch/ingest participation is fixed:
    fetch for online/hybrid, ingest for offline/hybrid."""
    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_fetch(self, ds):
        """Fetch participates iff the source is online or hybrid."""
        result = TaskExecutor._flow_includes_fetch(ds)
        assert result == (ds in {"online", "hybrid"})
    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_ingest(self, ds):
        """Ingest participates iff the source is offline or hybrid."""
        result = TaskExecutor._flow_includes_ingest(ds)
        assert result == (ds in {"offline", "hybrid"})
    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_fetch_and_ingest_consistency(self, ds):
        """The two predicates must jointly describe all three modes."""
        fetch = TaskExecutor._flow_includes_fetch(ds)
        ingest = TaskExecutor._flow_includes_ingest(ds)
        if ds == "hybrid":
            assert fetch and ingest
        elif ds == "online":
            assert fetch and not ingest
        elif ds == "offline":
            assert not fetch and ingest
# Feature: scheduler-refactor, Property 2: a successful task advances the cursor
# **Validates: Requirements 1.3**
class TestProperty2SuccessAdvancesCursor:
    """A SUCCESS result carrying a window must advance the cursor to that window."""
    @given(
        task_code=task_code_st,
        window_start=window_start_st,
        window_minutes=st.integers(min_value=1, max_value=1440),
    )
    @settings(max_examples=100)
    def test_success_with_window_advances_cursor(self, task_code, window_start, window_minutes):
        """cursor_mgr.advance receives exactly the window reported by the task."""
        window_end = window_start + timedelta(minutes=window_minutes)
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        task_result = {
            "status": "SUCCESS",
            "counts": {"fetched": 10, "inserted": 5},
            "window": {"start": window_start, "end": window_end, "minutes": window_minutes},
        }
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 100
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 42, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
            patch.object(TaskExecutor, "_maybe_run_integrity_check"),
        ):
            executor.run_single_task(task_code, "test-uuid", 1, "offline")
        executor.cursor_mgr.advance.assert_called_once()
        kw = executor.cursor_mgr.advance.call_args.kwargs
        assert kw["window_start"] == window_start
        assert kw["window_end"] == window_end
# Feature: scheduler-refactor, Property 3: a failing task is marked FAIL and re-raised
# **Validates: Requirements 1.4**
class TestProperty3FailureMarksFailAndReraises:
    """An exception during ingest must mark the run FAIL and propagate."""
    @given(
        task_code=task_code_st,
        error_msg=st.text(
            alphabet=string.ascii_letters + string.digits + " _-",
            min_size=1, max_size=80,
        ),
    )
    @settings(max_examples=100)
    def test_exception_marks_fail_and_reraises(self, task_code, error_msg):
        """run_tracker.update_run records status=FAIL and the original error re-raises."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 200
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 99, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", side_effect=RuntimeError(error_msg)),
        ):
            # re.escape: the generated message may contain regex metacharacters.
            with pytest.raises(RuntimeError, match=re.escape(error_msg)):
                executor.run_single_task(task_code, "fail-uuid", 1, "offline")
        executor.run_tracker.update_run.assert_called()
        kw = executor.run_tracker.update_run.call_args.kwargs
        assert kw["status"] == "FAIL"
# Feature: scheduler-refactor, Property 4: utility tasks are determined by metadata
# **Validates: Requirements 1.6, 4.2**
class TestProperty4UtilityTaskDeterminedByMetadata:
    """requires_db_config=False tasks must skip cursor and run tracking;
    normal ETL tasks must use both."""
    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_utility_task_skips_cursor_and_run_tracker(self, task_code):
        """A utility task never touches cursor_mgr or run_tracker."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=False, task_type="utility")
        executor = _make_executor(registry)
        mock_task = MagicMock()
        mock_task.execute.return_value = {"status": "SUCCESS", "counts": {}}
        registry.create_task = MagicMock(return_value=mock_task)
        result = executor.run_single_task(task_code, "util-uuid", 1, "hybrid")
        executor.cursor_mgr.get_or_create.assert_not_called()
        executor.cursor_mgr.advance.assert_not_called()
        executor.run_tracker.create_run.assert_not_called()
        assert result.get("status") == "SUCCESS"
    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_non_utility_task_uses_cursor_and_run_tracker(self, task_code):
        """A regular ETL task creates a cursor and a run record exactly once."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWS")
        task_result = {"status": "SUCCESS", "counts": {"fetched": 1}}
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 300
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 77, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
        ):
            executor.run_single_task(task_code, "non-util-uuid", 1, "offline")
        executor.cursor_mgr.get_or_create.assert_called_once()
        executor.run_tracker.create_run.assert_called_once()

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""TaskRegistry 单元测试 — 验证 TaskMeta 元数据注册与查询"""
import pytest
from orchestration.task_registry import TaskRegistry, TaskMeta
# ── 辅助:用作注册的假任务类 ──────────────────────────────────
class _FakeTask:
"""占位任务类,用于测试注册"""
def __init__(self, config, db_connection, api_client, logger):
self.config = config
class _AnotherFakeTask:
def __init__(self, config, db_connection, api_client, logger):
pass
# ── fixtures ──────────────────────────────────────────────────
@pytest.fixture
def registry():
    """Provide a fresh, empty TaskRegistry for each test."""
    return TaskRegistry()
# ── register + get_metadata ───────────────────────────────────
class TestRegisterAndMetadata:
    """Registration and metadata lookup."""
    def test_register_with_defaults(self, registry):
        """With only task_code + task_class, metadata uses defaults (backward compatible)."""
        registry.register("MY_TASK", _FakeTask)
        meta = registry.get_metadata("MY_TASK")
        assert meta is not None
        assert meta.task_class is _FakeTask
        assert meta.requires_db_config is True
        assert meta.layer is None
        assert meta.task_type == "etl"
    def test_register_with_full_metadata(self, registry):
        """Register with the full metadata set."""
        registry.register(
            "ODS_ORDERS", _FakeTask,
            requires_db_config=True, layer="ODS", task_type="etl",
        )
        meta = registry.get_metadata("ODS_ORDERS")
        assert meta.layer == "ODS"
        assert meta.task_type == "etl"
    def test_register_utility_task(self, registry):
        """Utility tasks: requires_db_config=False."""
        registry.register(
            "INIT_SCHEMA", _FakeTask,
            requires_db_config=False, task_type="utility",
        )
        meta = registry.get_metadata("INIT_SCHEMA")
        assert meta.requires_db_config is False
        assert meta.task_type == "utility"
    def test_case_insensitive_lookup(self, registry):
        """task_code lookup is case-insensitive."""
        registry.register("my_task", _FakeTask)
        assert registry.get_metadata("MY_TASK") is not None
        assert registry.get_metadata("my_task") is not None
    def test_get_metadata_unknown_returns_none(self, registry):
        """Querying an unregistered task returns None."""
        assert registry.get_metadata("NONEXISTENT") is None
# ── create_task (interface unchanged) ─────────────────────────
class TestCreateTask:
    """create_task keeps its original construction contract."""
    def test_create_task_returns_instance(self, registry):
        """A registered code yields an instance of the registered class."""
        registry.register("MY_TASK", _FakeTask)
        task = registry.create_task("MY_TASK", {"k": "v"}, None, None, None)
        assert isinstance(task, _FakeTask)
        assert task.config == {"k": "v"}
    def test_create_task_unknown_raises(self, registry):
        """An unknown code raises ValueError (Chinese message: 'unknown task type')."""
        with pytest.raises(ValueError, match="未知的任务类型"):
            registry.create_task("NOPE", None, None, None, None)
# ── get_tasks_by_layer ────────────────────────────────────────
class TestGetTasksByLayer:
    """Layer-based task lookup."""
    def test_returns_matching_tasks(self, registry):
        """Only tasks registered with the queried layer are returned."""
        registry.register("A", _FakeTask, layer="ODS")
        registry.register("B", _AnotherFakeTask, layer="ODS")
        registry.register("C", _FakeTask, layer="DWD")
        result = registry.get_tasks_by_layer("ODS")
        assert set(result) == {"A", "B"}
    def test_case_insensitive_layer(self, registry):
        """Layer comparison ignores case."""
        registry.register("X", _FakeTask, layer="dws")
        assert registry.get_tasks_by_layer("DWS") == ["X"]
    def test_no_match_returns_empty(self, registry):
        """A layer with no registered tasks yields an empty list."""
        registry.register("A", _FakeTask, layer="ODS")
        assert registry.get_tasks_by_layer("INDEX") == []
    def test_none_layer_excluded(self, registry):
        """Tasks registered with layer=None never match any layer query."""
        registry.register("UTIL", _FakeTask)  # layer defaults to None
        assert registry.get_tasks_by_layer("ODS") == []
# ── is_utility_task ───────────────────────────────────────────
class TestIsUtilityTask:
    """Utility-task detection via requires_db_config."""
    def test_utility_task(self, registry):
        """requires_db_config=False marks the task as a utility."""
        registry.register("INIT", _FakeTask, requires_db_config=False)
        assert registry.is_utility_task("INIT") is True
    def test_normal_task(self, registry):
        """requires_db_config=True marks a normal ETL task."""
        registry.register("ETL", _FakeTask, requires_db_config=True)
        assert registry.is_utility_task("ETL") is False
    def test_unknown_task(self, registry):
        """Unknown codes are reported as non-utility."""
        assert registry.is_utility_task("NOPE") is False
# ── get_all_task_codes (interface unchanged) ──────────────────
class TestGetAllTaskCodes:
    """Enumeration of every registered task code."""
    def test_returns_all_codes(self, registry):
        """All registered codes are reported."""
        registry.register("A", _FakeTask)
        registry.register("B", _AnotherFakeTask)
        assert set(registry.get_all_task_codes()) == {"A", "B"}
    def test_empty_registry(self, registry):
        """A fresh registry reports no codes."""
        assert registry.get_all_task_codes() == []

View File

@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
"""TaskRegistry 属性测试 — 使用 hypothesis 验证注册表的通用正确性属性。"""
import string
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from orchestration.task_registry import TaskRegistry, TaskMeta
# ── 辅助:动态生成假任务类 ────────────────────────────────────
def _make_fake_class(name: str = "FakeTask") -> type:
"""创建一个最小化的假任务类,用于注册测试。"""
return type(name, (), {"__init__": lambda self, *a, **kw: None})
# ── Generation strategies ─────────────────────────────────────
# Valid task codes: upper-case letters, digits and underscores, length 1-30.
task_code_st = st.text(
    alphabet=string.ascii_uppercase + string.digits + "_",
    min_size=1,
    max_size=30,
)
requires_db_config_st = st.booleans()
layer_st = st.sampled_from([None, "ODS", "DWD", "DWS", "INDEX"])
task_type_st = st.sampled_from(["etl", "utility", "verification"])
# ── Property 8: TaskRegistry metadata round-trip ──────────────
# Feature: scheduler-refactor, Property 8: TaskRegistry metadata round-trip
# **Validates: Requirements 4.1**
#
# For any combination of task code, task class and metadata (requires_db_config,
# layer, task_type), querying via get_metadata after registering must return
# exactly the same metadata values.
class TestProperty8MetadataRoundTrip:
    """Property 8: metadata queried after registration must match what was registered."""
    @given(
        task_code=task_code_st,
        requires_db=requires_db_config_st,
        layer=layer_st,
        task_type=task_type_st,
    )
    @settings(max_examples=100)
    def test_metadata_round_trip(
        self, task_code: str, requires_db: bool, layer: str | None, task_type: str
    ):
        """After registering any metadata combination, get_metadata returns the same values."""
        # Arrange — a fresh registry per iteration to avoid state leaking between examples
        registry = TaskRegistry()
        fake_cls = _make_fake_class()
        # Act — register, then query
        registry.register(
            task_code,
            fake_cls,
            requires_db_config=requires_db,
            layer=layer,
            task_type=task_type,
        )
        meta = registry.get_metadata(task_code)
        # Assert — metadata round-trips intact
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is requires_db, (
            f"requires_db_config 应为 {requires_db},实际为 {meta.requires_db_config}"
        )
        assert meta.layer == layer, f"layer 应为 {layer!r},实际为 {meta.layer!r}"
        assert meta.task_type == task_type, (
            f"task_type 应为 {task_type!r},实际为 {meta.task_type!r}"
        )
# ── Property 9: TaskRegistry backward-compatible defaults ─────
# Feature: scheduler-refactor, Property 9: TaskRegistry backward-compatible defaults
# **Validates: Requirements 4.4**
#
# For any task registered via the legacy interface (task_code and task_class
# only), querying its metadata must return requires_db_config=True, layer=None,
# task_type="etl".
class TestProperty9BackwardCompatibleDefaults:
    """Property 9: passing only task_code + task_class must yield default metadata."""
    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_legacy_register_uses_defaults(self, task_code: str):
        """Registering via the legacy interface yields the documented default metadata."""
        # Arrange
        registry = TaskRegistry()
        fake_cls = _make_fake_class()
        # Act — pass only task_code and task_class, no metadata arguments
        registry.register(task_code, fake_cls)
        meta = registry.get_metadata(task_code)
        # Assert — the default-value contract
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is True, (
            f"默认 requires_db_config 应为 True实际为 {meta.requires_db_config}"
        )
        assert meta.layer is None, (
            f"默认 layer 应为 None实际为 {meta.layer!r}"
        )
        assert meta.task_type == "etl", (
            f"默认 task_type 应为 'etl',实际为 {meta.task_type!r}"
        )
# ── Property 10: querying tasks by layer ──────────────────────
# Feature: scheduler-refactor, Property 10: querying tasks by layer
# **Validates: Requirements 4.3**
#
# For any set of tasks registered with layer metadata, the task-code set
# returned by get_tasks_by_layer(layer) must equal the set of all registered
# codes whose layer matches.
# Non-None layer strategy, intended for query validation.
# NOTE(review): non_none_layer_st appears unused below (the test iterates a
# hard-coded layer list) — confirm whether it should drive the loop.
non_none_layer_st = st.sampled_from(["ODS", "DWD", "DWS", "INDEX"])
class TestProperty10GetTasksByLayer:
    """Property 10: get_tasks_by_layer must agree with a manual filter."""
    @given(
        entries=st.lists(
            st.tuples(task_code_st, layer_st),
            min_size=1,
            max_size=20,
        ),
    )
    @settings(max_examples=100)
    def test_get_tasks_by_layer_matches_manual_filter(
        self, entries: list[tuple[str, str | None]],
    ):
        """After registering a batch of tasks, per-layer queries match a manual filter."""
        # Arrange
        registry = TaskRegistry()
        # Dedupe: keep only the last registration per task_code (matches
        # register()'s overwrite semantics)
        unique_entries: dict[str, str | None] = {}
        for code, layer in entries:
            fake_cls = _make_fake_class(f"Fake_{code}")
            registry.register(code, fake_cls, layer=layer)
            unique_entries[code.upper()] = layer  # register() uppercases internally
        # Act & Assert — validate every non-None layer value
        for query_layer in ["ODS", "DWD", "DWS", "INDEX"]:
            actual = set(registry.get_tasks_by_layer(query_layer))
            expected = {
                code for code, layer in unique_entries.items()
                if layer is not None and layer.upper() == query_layer.upper()
            }
            assert actual == expected, (
                f"查询 layer={query_layer!r} 时,"
                f"期望 {expected},实际 {actual}"
            )

View File

@@ -0,0 +1,358 @@
"""
scripts/validate_bd_manual.py 的单元测试。
不需要数据库连接,通过构造临时文件系统结构来验证各 check 函数。
"""
from __future__ import annotations
import textwrap
from pathlib import Path
import pytest
# 被测模块
from scripts.validate_bd_manual import (
check_directory_structure,
check_ods_doc_coverage,
check_ods_doc_format,
check_ods_doc_naming,
check_mapping_doc_coverage,
check_mapping_doc_content,
check_mapping_doc_naming,
check_ods_dictionary_coverage,
BD_MANUAL_ROOT,
ODS_MAIN_DIR,
ODS_MAPPINGS_DIR,
ODS_DICT_PATH,
DATA_LAYERS,
)
import scripts.validate_bd_manual as mod
# ---------------------------------------------------------------------------
# 辅助:临时目录结构搭建
# ---------------------------------------------------------------------------
def _setup_layer_dirs(tmp_path: Path) -> None:
    """Create the complete four-layer directory tree under *tmp_path*."""
    root = tmp_path / "docs" / "bd_manual"
    for layer_name in DATA_LAYERS:
        for subdir in ("main", "changes"):
            (root / layer_name / subdir).mkdir(parents=True, exist_ok=True)
# Minimal ODS table-doc fixture. Its content is runtime data for the checks:
# the passing-format tests below rely on every section header and ETL field
# listed here, so edits to this string change test outcomes.
SAMPLE_ODS_DOC = textwrap.dedent("""\
# test_table 测试表
> 生成时间2026-01-01
## 表信息
| 属性 | 值 |
|------|-----|
| Schema | billiards_ods |
| 表名 | test_table |
| 主键 | id, content_hash |
| 数据来源 | TestEndpoint |
| 说明 | 测试表 |
## 字段说明
| 序号 | 字段名 | 类型 | 可空 | 说明 |
|------|--------|------|------|------|
| 1 | id | BIGINT | NO | 主键 |
## 使用说明
```sql
SELECT * FROM billiards_ods.test_table LIMIT 10;
```
## ETL 元数据字段
| 字段名 | 类型 | 说明 |
|--------|------|------|
| content_hash | TEXT | SHA256 |
| source_file | TEXT | 文件名 |
| source_endpoint | TEXT | 端点 |
| fetched_at | TIMESTAMPTZ | 时间戳 |
| payload | JSONB | 原始 JSON |
## 可回溯性
| 项目 | 说明 |
|------|------|
| 可回溯 | ✅ 完全可回溯 |
| 数据来源 | TestEndpoint |
""")
# Minimal endpoint→table mapping-doc fixture. Runtime data for the content
# checks: the passing tests below depend on these exact section headers and
# ETL supplement rows, so keep the string byte-stable.
SAMPLE_MAPPING_DOC = textwrap.dedent("""\
# TestEndpoint → test_table 字段映射
> 生成时间2026-01-01
## 端点信息
| 属性 | 值 |
|------|-----|
| 接口路径 | `Test/TestEndpoint` |
| 请求方法 | POST |
| ODS 对应表 | `billiards_ods.test_table` |
| JSON 数据路径 | `data.items` |
## 字段映射
| JSON 字段 | ODS 列名 | 类型转换 | 说明 |
|-----------|----------|----------|------|
| id | id | int→BIGINT | 主键 |
## ETL 补充字段
| ODS 列名 | 生成逻辑 |
|-----------|----------|
| content_hash | SHA256 |
| source_file | 固定值 |
| source_endpoint | 端点路径 |
| fetched_at | 入库时间戳 |
| payload | 原始 JSON |
""")
# ---------------------------------------------------------------------------
# Property 1: 目录结构一致性
# ---------------------------------------------------------------------------
class TestCheckDirectoryStructure:
    """Property 1: consistency of the per-data-layer directory layout."""

    def test_pass_when_all_dirs_exist(self, tmp_path, monkeypatch):
        _setup_layer_dirs(tmp_path)
        monkeypatch.setattr(mod, "BD_MANUAL_ROOT", tmp_path / "docs" / "bd_manual")

        outcome = check_directory_structure()

        assert outcome.passed is True
        assert outcome.property_id == "Property 1"

    def test_fail_when_missing_subdir(self, tmp_path, monkeypatch):
        import shutil

        _setup_layer_dirs(tmp_path)
        # Delete ETL_Admin/changes/ so the check has something concrete to report.
        shutil.rmtree(tmp_path / "docs" / "bd_manual" / "ETL_Admin" / "changes")
        monkeypatch.setattr(mod, "BD_MANUAL_ROOT", tmp_path / "docs" / "bd_manual")

        outcome = check_directory_structure()

        assert outcome.passed is False
        assert any("ETL_Admin" in item and "changes" in item for item in outcome.details)
# ---------------------------------------------------------------------------
# Property 4: ODS 文档覆盖率
# ---------------------------------------------------------------------------
class TestCheckOdsDocCoverage:
    """Property 4: coverage of per-table ODS documentation."""

    def test_pass_when_all_docs_exist(self, tmp_path, monkeypatch):
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)
        table_names = ["member_profiles", "payment_transactions"]
        for name in table_names:
            (main_dir / f"BD_manual_{name}.md").write_text("# doc", encoding="utf-8")

        assert check_ods_doc_coverage(table_names).passed is True

    def test_fail_when_doc_missing(self, tmp_path, monkeypatch):
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)
        # Only one of the two queried tables has a doc file.
        (main_dir / "BD_manual_member_profiles.md").write_text("# doc", encoding="utf-8")

        outcome = check_ods_doc_coverage(["member_profiles", "missing_table"])

        assert outcome.passed is False
        assert any("missing_table" in item for item in outcome.details)
# ---------------------------------------------------------------------------
# Property 5: ODS 文档格式完整性
# ---------------------------------------------------------------------------
class TestCheckOdsDocFormat:
    """Property 5: structural completeness of per-table ODS docs."""

    @staticmethod
    def _write_doc(tmp_path, monkeypatch, content):
        # Materialize a single table doc and point the module at its directory.
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        (main_dir / "BD_manual_test_table.md").write_text(content, encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

    def test_pass_with_complete_doc(self, tmp_path, monkeypatch):
        self._write_doc(tmp_path, monkeypatch, SAMPLE_ODS_DOC)
        assert check_ods_doc_format().passed is True

    def test_fail_when_missing_section(self, tmp_path, monkeypatch):
        # Drop the traceability section header from the otherwise complete doc.
        self._write_doc(tmp_path, monkeypatch, SAMPLE_ODS_DOC.replace("## 可回溯性", "## 其他"))

        outcome = check_ods_doc_format()

        assert outcome.passed is False
        assert any("可回溯性" in item for item in outcome.details)

    def test_fail_when_missing_etl_field(self, tmp_path, monkeypatch):
        # Rename content_hash so the required ETL field is absent.
        self._write_doc(tmp_path, monkeypatch, SAMPLE_ODS_DOC.replace("content_hash", "xxx_hash"))

        outcome = check_ods_doc_format()

        assert outcome.passed is False
        assert any("content_hash" in item for item in outcome.details)
# ---------------------------------------------------------------------------
# Property 6: ODS 文档命名规范
# ---------------------------------------------------------------------------
class TestCheckOdsDocNaming:
    """Property 6: naming convention for per-table ODS docs."""

    @staticmethod
    def _prepare(tmp_path, monkeypatch, filenames):
        # Create the given .md files under ODS/main/ and patch the module path.
        main_dir = tmp_path / "ODS" / "main"
        main_dir.mkdir(parents=True)
        for fname in filenames:
            (main_dir / fname).write_text("# doc", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAIN_DIR", main_dir)

    def test_pass_with_valid_names(self, tmp_path, monkeypatch):
        self._prepare(tmp_path, monkeypatch, [
            "BD_manual_member_profiles.md",
            "BD_manual_payment_transactions.md",
        ])
        assert check_ods_doc_naming().passed is True

    def test_fail_with_invalid_name(self, tmp_path, monkeypatch):
        """Every .md under ODS/main/ must match the convention; BadName.md should be flagged."""
        self._prepare(tmp_path, monkeypatch, [
            "BD_manual_member_profiles.md",
            "BadName.md",
        ])

        outcome = check_ods_doc_naming()

        assert outcome.passed is False
        assert any("BadName.md" in item for item in outcome.details)

    def test_fail_with_uppercase_table_name(self, tmp_path, monkeypatch):
        # Uppercase letters after the BD_manual_ prefix violate snake_case.
        self._prepare(tmp_path, monkeypatch, ["BD_manual_MemberProfiles.md"])
        assert check_ods_doc_naming().passed is False
# ---------------------------------------------------------------------------
# Property 7: 映射文档覆盖率
# ---------------------------------------------------------------------------
class TestCheckMappingDocCoverage:
    """Property 7: coverage of endpoint-to-table mapping docs."""

    def test_pass_when_all_mappings_exist(self, tmp_path, monkeypatch):
        map_dir = tmp_path / "ODS" / "mappings"
        map_dir.mkdir(parents=True)
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", map_dir)
        for fname in ("mapping_GetTest_test_table.md", "mapping_GetOther_other_table.md"):
            (map_dir / fname).write_text("# map", encoding="utf-8")

        assert check_mapping_doc_coverage(["test_table", "other_table"]).passed is True

    def test_fail_when_mapping_missing(self, tmp_path, monkeypatch):
        map_dir = tmp_path / "ODS" / "mappings"
        map_dir.mkdir(parents=True)
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", map_dir)
        # Only one of the two queried tables has a mapping doc.
        (map_dir / "mapping_GetTest_test_table.md").write_text("# map", encoding="utf-8")

        outcome = check_mapping_doc_coverage(["test_table", "missing_table"])

        assert outcome.passed is False
        assert any("missing_table" in item for item in outcome.details)
# ---------------------------------------------------------------------------
# Property 8: 映射文档内容完整性
# ---------------------------------------------------------------------------
class TestCheckMappingDocContent:
    """Property 8: content completeness of mapping docs."""

    @staticmethod
    def _write_mapping(tmp_path, monkeypatch, body):
        # Materialize one mapping doc and patch the module's mappings directory.
        map_dir = tmp_path / "ODS" / "mappings"
        map_dir.mkdir(parents=True)
        (map_dir / "mapping_TestEndpoint_test_table.md").write_text(body, encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", map_dir)

    def test_pass_with_complete_mapping(self, tmp_path, monkeypatch):
        self._write_mapping(tmp_path, monkeypatch, SAMPLE_MAPPING_DOC)
        assert check_mapping_doc_content().passed is True

    def test_fail_when_missing_section(self, tmp_path, monkeypatch):
        # Remove the field-mapping section header.
        self._write_mapping(
            tmp_path, monkeypatch, SAMPLE_MAPPING_DOC.replace("## 字段映射", "## 其他映射")
        )

        outcome = check_mapping_doc_content()

        assert outcome.passed is False
        assert any("字段映射" in item for item in outcome.details)
# ---------------------------------------------------------------------------
# Property 9: 映射文档命名规范
# ---------------------------------------------------------------------------
class TestCheckMappingDocNaming:
    """Property 9: naming convention for mapping docs."""

    def test_pass_with_valid_names(self, tmp_path, monkeypatch):
        map_dir = tmp_path / "ODS" / "mappings"
        map_dir.mkdir(parents=True)
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", map_dir)
        (map_dir / "mapping_GetTenantMemberList_member_profiles.md").write_text("# m", encoding="utf-8")

        assert check_mapping_doc_naming().passed is True

    def test_fail_with_lowercase_endpoint(self, tmp_path, monkeypatch):
        map_dir = tmp_path / "ODS" / "mappings"
        map_dir.mkdir(parents=True)
        monkeypatch.setattr(mod, "ODS_MAPPINGS_DIR", map_dir)
        # Endpoint names must be PascalCase; a lowercase first letter fails.
        (map_dir / "mapping_getTenantMemberList_member_profiles.md").write_text("# m", encoding="utf-8")

        assert check_mapping_doc_naming().passed is False
# ---------------------------------------------------------------------------
# Property 10: ODS 数据字典覆盖率
# ---------------------------------------------------------------------------
class TestCheckOdsDictionaryCoverage:
    """Property 10: coverage of the ODS data dictionary.

    Every known table name must appear in the dictionary markdown file
    referenced by ODS_DICT_PATH (fixtures list them backtick-quoted in a
    markdown table — presumably the format the check expects; see samples).
    """
    def test_pass_when_all_tables_in_dict(self, tmp_path, monkeypatch):
        # Dictionary lists both queried tables -> check should pass.
        dict_file = tmp_path / "ods_tables_dictionary.md"
        dict_file.write_text(
            "| `member_profiles` | 会员 |\n| `payment_transactions` | 支付 |",
            encoding="utf-8",
        )
        monkeypatch.setattr(mod, "ODS_DICT_PATH", dict_file)
        result = check_ods_dictionary_coverage(["member_profiles", "payment_transactions"])
        assert result.passed is True
    def test_fail_when_table_missing_from_dict(self, tmp_path, monkeypatch):
        # Only one of the two queried tables is listed -> failure, with the
        # missing table named in the details.
        dict_file = tmp_path / "ods_tables_dictionary.md"
        dict_file.write_text("| `member_profiles` | 会员 |", encoding="utf-8")
        monkeypatch.setattr(mod, "ODS_DICT_PATH", dict_file)
        result = check_ods_dictionary_coverage(["member_profiles", "missing_table"])
        assert result.passed is False
        assert any("missing_table" in d for d in result.details)
    def test_fail_when_dict_file_missing(self, tmp_path, monkeypatch):
        # A nonexistent dictionary file must not crash the check; it fails it.
        monkeypatch.setattr(mod, "ODS_DICT_PATH", tmp_path / "nonexistent.md")
        result = check_ods_dictionary_coverage(["some_table"])
        assert result.passed is False