初始提交:飞球 ETL 系统全量代码

This commit is contained in:
Neo
2026-02-13 08:05:34 +08:00
commit 3c51f5485d
441 changed files with 117631 additions and 0 deletions

0
tests/unit/__init__.py Normal file
View File

View File

@@ -0,0 +1,794 @@
# -*- coding: utf-8 -*-
"""ETL 任务测试的共用辅助模块,涵盖在线/离线模式所需的伪造数据、客户端与配置等工具函数。"""
from __future__ import annotations
import json
import os
import re
from types import SimpleNamespace
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple, Type
from config.settings import AppConfig
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations as PgDBOperations
from tasks.ods.assistant_abolish_task import AssistantAbolishTask
from tasks.ods.assistants_task import AssistantsTask
from tasks.ods.coupon_usage_task import CouponUsageTask
from tasks.ods.inventory_change_task import InventoryChangeTask
from tasks.ods.ledger_task import LedgerTask
from tasks.ods.members_task import MembersTask
from tasks.ods.orders_task import OrdersTask
from tasks.ods.packages_task import PackagesDefTask
from tasks.ods.payments_task import PaymentsTask
from tasks.ods.products_task import ProductsTask
from tasks.ods.refunds_task import RefundsTask
from tasks.ods.table_discount_task import TableDiscountTask
from tasks.ods.tables_task import TablesTask
from tasks.ods.topups_task import TopupsTask
from utils.json_store import endpoint_to_filename
DEFAULT_STORE_ID = 2790685415443269
BASE_TS = "2025-01-01 10:00:00"
END_TS = "2025-01-01 12:00:00"
@dataclass(frozen=True)
class TaskSpec:
    """Metadata describing how a single task is driven in tests.

    Bundles the task code, task class, API endpoint, the data path into the
    response body, and the sample records used as fixtures.
    """
    code: str  # short task identifier, e.g. "PRODUCTS"
    task_cls: Type  # ODS task class under test
    endpoint: str  # API endpoint path the task fetches from
    data_path: Tuple[str, ...]  # keys locating the record list in the response
    sample_records: List[Dict]  # canned records replayed as fixture data
    @property
    def archive_filename(self) -> str:
        # Archive JSON file name derived from the endpoint path.
        return endpoint_to_filename(self.endpoint)
def wrap_records(records: List[Dict], data_path: Sequence[str]):
    """Nest *records* under each key of *data_path* (outermost key first).

    The result mirrors the shape of a real API response body, which makes the
    records usable for offline replay.
    """
    wrapped = records
    for level in data_path[::-1]:
        wrapped = {level: wrapped}
    return wrapped
def create_test_config(mode: str, archive_dir: Path, temp_dir: Path) -> AppConfig:
    """Build an AppConfig suited for tests.

    Fills in storage, logging and archive directories automatically and makes
    sure the directories exist before the config is loaded.
    """
    archive_path = Path(archive_dir)
    scratch = Path(temp_dir)
    for directory in (archive_path, scratch):
        directory.mkdir(parents=True, exist_ok=True)
    # ONLINE mode runs the full pipeline; anything else only ingests archives.
    is_online = str(mode or "").upper() == "ONLINE"
    overrides = {
        "app": {"store_id": DEFAULT_STORE_ID, "timezone": "Asia/Taipei"},
        "db": {"dsn": "postgresql://user:pass@localhost:5432/fq_etl_test"},
        "api": {
            "base_url": "https://api.example.com",
            "token": "test-token",
            "timeout_sec": 3,
            "page_size": 50,
        },
        "pipeline": {
            "flow": "FULL" if is_online else "INGEST_ONLY",
            "fetch_root": str(scratch / "json_fetch"),
            "ingest_source_dir": str(archive_path),
        },
        "io": {
            "export_root": str(scratch / "export"),
            "log_root": str(scratch / "logs"),
        },
    }
    return AppConfig.load(overrides)
def dump_offline_payload(spec: TaskSpec, archive_dir: Path) -> Path:
    """Write the spec's sample records into *archive_dir* for offline replay.

    Returns the full path of the generated archive file.
    """
    target = Path(archive_dir) / spec.archive_filename
    body = wrap_records(spec.sample_records, spec.data_path)
    target.write_text(json.dumps(body, ensure_ascii=False), encoding="utf-8")
    return target
class FakeCursor:
    """Minimal cursor stub used by FakeDBOperations and SCD2Handler.

    Records every SQL statement and its parameters, answers a couple of
    information_schema queries with fake metadata, and supports use as a
    context manager.
    """

    def __init__(self, recorder: List[Dict], db_ops=None):
        # Shared list that every executed statement gets appended to.
        self.recorder = recorder
        self._db_ops = db_ops
        self._pending_rows: List[Tuple] = []
        self._fetchall_rows: List[Tuple] = []
        self.rowcount = 0
        # psycopg exposes connection.encoding on its cursors; mimic that.
        self.connection = SimpleNamespace(encoding="UTF8")

    # pylint: disable=unused-argument
    def execute(self, sql: str, params=None):
        if isinstance(sql, (bytes, bytearray)):
            text = sql.decode("utf-8", errors="ignore")
        else:
            text = str(sql)
        self.recorder.append({"sql": text.strip(), "params": params})
        self._fetchall_rows = []
        lowered = text.lower()
        # Answer information_schema probes used by structure-aware writers.
        if "from information_schema.columns" in lowered:
            target_table = params[1] if params and len(params) >= 2 else None
            self._fetchall_rows = self._fake_columns(target_table)
            return
        if "from information_schema.table_constraints" in lowered:
            self._fetchall_rows = []
            return
        if not self._pending_rows:
            self.rowcount = 0
            return
        # Rows staged via mogrify() are treated as an executed batch insert.
        self.rowcount = len(self._pending_rows)
        self._record_upserts(text)
        self._pending_rows = []

    def fetchone(self):
        return None

    def fetchall(self):
        return list(self._fetchall_rows)

    def mogrify(self, template, args):
        # Stage the row values; the placeholder return mimics psycopg output.
        self._pending_rows.append(tuple(args))
        return b"(?)"

    def _record_upserts(self, sql_text: str):
        # Mirror staged insert rows into the owning FakeDBOperations recorder.
        if not self._db_ops:
            return
        found = re.search(r"insert\s+into\s+[^\(]+\(([^)]*)\)\s+values", sql_text, re.I)
        if not found:
            return
        column_names = [part.strip().strip('"') for part in found.group(1).split(",")]
        captured = []
        for idx, values in enumerate(self._pending_rows):
            if len(values) != len(column_names):
                continue
            entry = {}
            for name, value in zip(column_names, values):
                if name == "record_index" and value in (None, ""):
                    # Empty record_index falls back to the row's position.
                    entry[name] = idx
                elif hasattr(value, "adapted"):
                    # psycopg Json wrappers expose .adapted; serialize it.
                    entry[name] = json.dumps(value.adapted, ensure_ascii=False)
                else:
                    entry[name] = value
            captured.append(entry)
        if captured:
            self._db_ops.upserts.append(
                {
                    "sql": sql_text.strip(),
                    "count": len(captured),
                    "page_size": len(captured),
                    "rows": captured,
                }
            )

    @staticmethod
    def _fake_columns(_table_name: str | None) -> List[Tuple[str, str, str]]:
        # Static column metadata returned for any table name.
        return [
            ("id", "bigint", "int8"),
            ("sitegoodsstockid", "bigint", "int8"),
            ("record_index", "integer", "int4"),
            ("content_hash", "text", "text"),
            ("source_file", "text", "text"),
            ("source_endpoint", "text", "text"),
            ("fetched_at", "timestamp with time zone", "timestamptz"),
            ("payload", "jsonb", "jsonb"),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False
class FakeConnection:
    """Mimics a psycopg connection just enough for SCD2Handler.

    Hands out FakeCursor objects and caches every executed statement.
    """

    def __init__(self, db_ops):
        self.statements: List[Dict] = []
        self._db_ops = db_ops

    def cursor(self):
        # Every cursor shares this connection's statement log and db-ops sink.
        return FakeCursor(self.statements, self._db_ops)
class FakeDBOperations:
    """In-memory stand-in for the database operations layer.

    Intercepts and records batch upsert/execute calls so tests never touch a
    real database, and counts commit/rollback invocations.
    """

    def __init__(self):
        self.upserts: List[Dict] = []
        self.executes: List[Dict] = []
        self.commits = 0
        self.rollbacks = 0
        self.conn = FakeConnection(self)
        # Pre-seeded query results (FIFO) controlling what query() returns.
        self.query_results: List[List[Dict]] = []

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        entry = {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(row) for row in rows],
        }
        self.upserts.append(entry)
        # Pretend every row was inserted, none updated.
        return len(rows), 0

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        entry = {
            "sql": sql.strip(),
            "count": len(rows),
            "page_size": page_size,
            "rows": [dict(row) for row in rows],
        }
        self.executes.append(entry)

    def execute(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params})

    def query(self, sql: str, params=None):
        self.executes.append({"sql": sql.strip(), "params": params, "type": "query"})
        # Pop the next canned result if one was seeded, else an empty result.
        return self.query_results.pop(0) if self.query_results else []

    def cursor(self):
        return self.conn.cursor()

    def commit(self):
        self.commits += 1

    def rollback(self):
        self.rollbacks += 1
class FakeAPIClient:
    """Fake API client for online-mode tests.

    Serves preloaded in-memory data as a single page and records every call so
    tests can assert that task parameters were passed correctly.
    """

    def __init__(self, data_map: Dict[str, List[Dict]]):
        self.data_map = data_map
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield a single (page_no, records, request, raw_response) tuple."""
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.data_map:
            raise AssertionError(f"Missing fixture for endpoint {endpoint}")
        rows = list(self.data_map[endpoint])
        yield 1, rows, dict(params or {}), {"data": rows}

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        all_rows: List[Dict] = []
        page_meta: List[Dict] = []
        for page_no, rows, request, response in self.iter_paginated(endpoint, params, **kwargs):
            all_rows.extend(rows)
            page_meta.append({"page": page_no, "request": request, "response": response})
        return all_rows, page_meta

    def get_source_hint(self, endpoint: str) -> str | None:
        # In-memory data has no backing file to hint at.
        return None
class OfflineAPIClient:
    """API client for offline mode.

    Reads an archived JSON file per endpoint, drills into the payload via
    data_path / list_key, and replays the list of records in pages.
    """

    def __init__(self, file_map: Dict[str, Path]):
        self.file_map = {endpoint: Path(path) for endpoint, path in file_map.items()}
        self.calls: List[Dict] = []

    # pylint: disable=unused-argument
    def iter_paginated(
        self,
        endpoint: str,
        params=None,
        page_size: int = 200,
        page_field: str = "page",
        size_field: str = "limit",
        data_path: Tuple[str, ...] = (),
        list_key: str | None = None,
    ):
        """Yield (page_no, records, request_params, raw_payload) per page."""
        self.calls.append({"endpoint": endpoint, "params": params})
        if endpoint not in self.file_map:
            raise AssertionError(f"Missing archive for endpoint {endpoint}")
        payload = json.loads(self.file_map[endpoint].read_text(encoding="utf-8"))
        node = payload
        for segment in data_path:
            if isinstance(node, dict):
                node = node.get(segment, [])
        if list_key and isinstance(node, dict):
            node = node.get(list_key, [])
        records = node if isinstance(node, list) else []
        if not records:
            # An empty archive still yields exactly one (empty) page.
            yield 1, [], dict(params or {}), payload
            return
        offset = 0
        page_no = 1
        while offset < len(records):
            batch = records[offset : offset + page_size]
            yield page_no, list(batch), dict(params or {}), payload
            if len(batch) < page_size:
                # A short page means the data is exhausted.
                break
            offset += page_size
            page_no += 1

    def get_paginated(self, endpoint: str, params=None, **kwargs):
        all_rows: List[Dict] = []
        page_meta: List[Dict] = []
        for page_no, rows, request, response in self.iter_paginated(endpoint, params, **kwargs):
            all_rows.extend(rows)
            page_meta.append({"page": page_no, "request": request, "response": response})
        return all_rows, page_meta

    def get_source_hint(self, endpoint: str) -> str | None:
        # Point at the archive file backing this endpoint, if any.
        if endpoint not in self.file_map:
            return None
        return str(self.file_map[endpoint])
class RealDBOperationsAdapter:
    """Adapter over a real PostgreSQL connection.

    Gives tasks batch-upsert plus transaction control via the project's
    DatabaseConnection / DatabaseOperations pair.
    """

    def __init__(self, dsn: str):
        self._connection = DatabaseConnection(dsn)
        self._operations = PgDBOperations(self._connection)
        # SCD2Handler accesses db.conn.cursor(), so expose the raw connection.
        self.conn = self._connection.conn

    def batch_upsert_with_returning(self, sql: str, rows: List[Dict], page_size: int = 1000):
        return self._operations.batch_upsert_with_returning(sql, rows, page_size=page_size)

    def batch_execute(self, sql: str, rows: List[Dict], page_size: int = 1000):
        return self._operations.batch_execute(sql, rows, page_size=page_size)

    def commit(self):
        self._connection.commit()

    def rollback(self):
        self._connection.rollback()

    def close(self):
        self._connection.close()
@contextmanager
def get_db_operations():
    """Yield a DB-operations object for tests.

    If TEST_DB_DSN is set, connect to a real PostgreSQL database and close the
    connection afterwards; otherwise fall back to the in-memory
    FakeDBOperations stub.
    """
    dsn = os.environ.get("TEST_DB_DSN")
    if not dsn:
        yield FakeDBOperations()
        return
    adapter = RealDBOperationsAdapter(dsn)
    try:
        yield adapter
    finally:
        adapter.close()
# Fixture registry: one TaskSpec per ODS task, binding the task class to its
# API endpoint, the data path into the response body, and one sample record
# used by both online (FakeAPIClient) and offline (archive replay) tests.
TASK_SPECS: List[TaskSpec] = [
    TaskSpec(
        code="PRODUCTS",
        task_cls=ProductsTask,
        endpoint="/TenantGoods/QueryTenantGoods",
        data_path=("data", "tenantGoodsList"),
        sample_records=[
            {
                "siteGoodsId": 101,
                "tenantGoodsId": 101,
                "goodsName": "测试球杆",
                "goodsCategoryId": 201,
                "categoryName": "器材",
                "goodsCategorySecondId": 202,
                "goodsUnit": "",
                "costPrice": "100.00",
                "goodsPrice": "150.00",
                "goodsState": "ON",
                "supplierId": 20,
                "barcode": "PRD001",
                "isCombo": False,
                "createTime": BASE_TS,
                "updateTime": END_TS,
            }
        ],
    ),
    TaskSpec(
        code="TABLES",
        task_cls=TablesTask,
        endpoint="/Table/GetSiteTables",
        data_path=("data", "siteTables"),
        sample_records=[
            {
                "id": 301,
                "site_id": 30,
                "site_table_area_id": 40,
                "areaName": "大厅",
                "table_name": "1号桌",
                "table_price": "50.00",
                "table_status": "FREE",
                "tableStatusName": "空闲",
                "light_status": "OFF",
                "is_rest_area": False,
                "show_status": True,
                "virtual_table": False,
                "charge_free": False,
                "only_allow_groupon": False,
                "is_online_reservation": True,
                "createTime": BASE_TS,
            }
        ],
    ),
    TaskSpec(
        code="MEMBERS",
        task_cls=MembersTask,
        endpoint="/MemberProfile/GetTenantMemberList",
        data_path=("data", "tenantMemberInfos"),
        sample_records=[
            {
                "memberId": 401,
                "memberName": "张三",
                "phone": "13800000000",
                "balance": "88.88",
                "status": "ACTIVE",
                "registerTime": BASE_TS,
            }
        ],
    ),
    TaskSpec(
        code="ASSISTANTS",
        task_cls=AssistantsTask,
        endpoint="/PersonnelManagement/SearchAssistantInfo",
        data_path=("data", "assistantInfos"),
        sample_records=[
            {
                "id": 501,
                "assistant_no": "AS001",
                "nickname": "小李",
                "real_name": "李雷",
                "gender": "M",
                "mobile": "13900000000",
                "level": "A",
                "team_id": 10,
                "team_name": "先锋队",
                "assistant_status": "ON",
                "work_status": "BUSY",
                "entry_time": BASE_TS,
                "resign_time": END_TS,
                "start_time": BASE_TS,
                "end_time": END_TS,
                "create_time": BASE_TS,
                "update_time": END_TS,
                "system_role_id": 1,
                "online_status": "ONLINE",
                "allow_cx": True,
                "charge_way": "TIME",
                "pd_unit_price": "30.00",
                "cx_unit_price": "20.00",
                "is_guaranteed": True,
                "is_team_leader": False,
                "serial_number": "SN001",
                "show_sort": 1,
                "is_delete": False,
            }
        ],
    ),
    TaskSpec(
        code="PACKAGES_DEF",
        task_cls=PackagesDefTask,
        endpoint="/PackageCoupon/QueryPackageCouponList",
        data_path=("data", "packageCouponList"),
        sample_records=[
            {
                "id": 601,
                "package_id": "PKG001",
                "package_name": "白天特惠",
                "table_area_id": 70,
                "table_area_name": "大厅",
                "selling_price": "199.00",
                "duration": 120,
                "start_time": BASE_TS,
                "end_time": END_TS,
                "type": "Groupon",
                "is_enabled": True,
                "is_delete": False,
                "usable_count": 3,
                "creator_name": "系统",
                "date_type": "WEEKDAY",
                "group_type": "DINE_IN",
                "coupon_money": "30.00",
                "area_tag_type": "VIP",
                "system_group_type": "BASIC",
                "card_type_ids": "1,2,3",
            }
        ],
    ),
    TaskSpec(
        code="ORDERS",
        task_cls=OrdersTask,
        endpoint="/Site/GetAllOrderSettleList",
        data_path=("data", "settleList"),
        sample_records=[
            {
                "orderId": 701,
                "orderNo": "ORD001",
                "memberId": 401,
                "tableId": 301,
                "orderTime": BASE_TS,
                "endTime": END_TS,
                "totalAmount": "300.00",
                "discountAmount": "20.00",
                "finalAmount": "280.00",
                "payStatus": "PAID",
                "orderStatus": "CLOSED",
                "remark": "测试订单",
            }
        ],
    ),
    # NOTE: the next three endpoints return the record list directly under
    # "data" (no nested list key).
    TaskSpec(
        code="PAYMENTS",
        task_cls=PaymentsTask,
        endpoint="/PayLog/GetPayLogListPage",
        data_path=("data",),
        sample_records=[
            {
                "payId": 801,
                "orderId": 701,
                "payTime": END_TS,
                "payAmount": "280.00",
                "payType": "CARD",
                "payStatus": "SUCCESS",
                "remark": "测试支付",
            }
        ],
    ),
    TaskSpec(
        code="REFUNDS",
        task_cls=RefundsTask,
        endpoint="/Order/GetRefundPayLogList",
        data_path=("data",),
        sample_records=[
            {
                "id": 901,
                "site_id": 1,
                "tenant_id": 2,
                "pay_amount": "100.00",
                "pay_status": "SUCCESS",
                "pay_time": END_TS,
                "create_time": END_TS,
                "relate_type": "ORDER",
                "relate_id": 701,
                "payment_method": "CARD",
                "refund_amount": "20.00",
                "action_type": "PARTIAL",
                "pay_terminal": "POS",
                "operator_id": 11,
                "channel_pay_no": "CH001",
                "channel_fee": "1.00",
                "is_delete": False,
                "member_id": 401,
                "member_card_id": 501,
            }
        ],
    ),
    TaskSpec(
        code="COUPON_USAGE",
        task_cls=CouponUsageTask,
        endpoint="/Promotion/GetOfflineCouponConsumePageList",
        data_path=("data",),
        sample_records=[
            {
                "id": 1001,
                "coupon_code": "CP001",
                "coupon_channel": "MEITUAN",
                "coupon_name": "双人券",
                "sale_price": "50.00",
                "coupon_money": "30.00",
                "coupon_free_time": 60,
                "use_status": "USED",
                "create_time": BASE_TS,
                "consume_time": END_TS,
                "operator_id": 11,
                "operator_name": "操作员",
                "table_id": 301,
                "site_order_id": 701,
                "group_package_id": 601,
                "coupon_remark": "备注",
                "deal_id": "DEAL001",
                "certificate_id": "CERT001",
                "verify_id": "VERIFY001",
                "is_delete": False,
            }
        ],
    ),
    TaskSpec(
        code="INVENTORY_CHANGE",
        task_cls=InventoryChangeTask,
        endpoint="/GoodsStockManage/QueryGoodsOutboundReceipt",
        data_path=("data", "queryDeliveryRecordsList"),
        sample_records=[
            {
                "siteGoodsStockId": 1101,
                "siteGoodsId": 101,
                "stockType": "OUT",
                "goodsName": "测试球杆",
                "createTime": END_TS,
                "startNum": 10,
                "endNum": 8,
                "changeNum": -2,
                "unit": "",
                "price": "120.00",
                "operatorName": "仓管",
                "remark": "测试出库",
                "goodsCategoryId": 201,
                "goodsSecondCategoryId": 202,
            }
        ],
    ),
    TaskSpec(
        code="TOPUPS",
        task_cls=TopupsTask,
        endpoint="/Site/GetRechargeSettleList",
        data_path=("data", "settleList"),
        sample_records=[
            {
                "id": 1201,
                "memberId": 401,
                "memberName": "张三",
                "memberPhone": "13800000000",
                "tenantMemberCardId": 1301,
                "memberCardTypeName": "金卡",
                "payAmount": "500.00",
                "consumeMoney": "100.00",
                "settleStatus": "DONE",
                "settleType": "AUTO",
                "settleName": "日结",
                "settleRelateId": 1501,
                "payTime": BASE_TS,
                "createTime": END_TS,
                "operatorId": 11,
                "operatorName": "收银员",
                "paymentMethod": "CASH",
                "refundAmount": "0",
                "cashAmount": "500.00",
                "cardAmount": "0",
                "balanceAmount": "0",
                "onlineAmount": "0",
                "roundingAmount": "0",
                "adjustAmount": "0",
                "goodsMoney": "0",
                "tableChargeMoney": "0",
                "serviceMoney": "0",
                "couponAmount": "0",
                "orderRemark": "首次充值",
            }
        ],
    ),
    TaskSpec(
        code="TABLE_DISCOUNT",
        task_cls=TableDiscountTask,
        endpoint="/Site/GetTaiFeeAdjustList",
        data_path=("data", "taiFeeAdjustInfos"),
        sample_records=[
            {
                "id": 1301,
                "adjust_type": "DISCOUNT",
                "applicant_id": 11,
                "applicant_name": "店长",
                "operator_id": 22,
                "operator_name": "值班",
                "ledger_amount": "50.00",
                "ledger_count": 2,
                "ledger_name": "调价",
                "ledger_status": "APPROVED",
                "order_settle_id": 7010,
                "order_trade_no": 8001,
                "site_table_id": 301,
                "create_time": END_TS,
                "is_delete": False,
                # Nested table profile exercises flattening of sub-objects.
                "tableProfile": {
                    "id": 301,
                    "site_table_area_id": 40,
                    "site_table_area_name": "大厅",
                },
            }
        ],
    ),
    TaskSpec(
        code="ASSISTANT_ABOLISH",
        task_cls=AssistantAbolishTask,
        endpoint="/AssistantPerformance/GetAbolitionAssistant",
        data_path=("data", "abolitionAssistants"),
        sample_records=[
            {
                "id": 1401,
                "tableId": 301,
                "tableName": "1号桌",
                "tableAreaId": 40,
                "tableArea": "大厅",
                "assistantOn": "AS001",
                "assistantName": "小李",
                "pdChargeMinutes": 30,
                "assistantAbolishAmount": "15.00",
                "createTime": END_TS,
                "trashReason": "测试",
            }
        ],
    ),
    TaskSpec(
        code="LEDGER",
        task_cls=LedgerTask,
        endpoint="/AssistantPerformance/GetOrderAssistantDetails",
        data_path=("data", "orderAssistantDetails"),
        sample_records=[
            {
                "id": 1501,
                "assistantNo": "AS001",
                "assistantName": "小李",
                "nickname": "",
                "levelName": "L1",
                "tableName": "1号桌",
                "ledger_unit_price": "30.00",
                "ledger_count": 2,
                "ledger_amount": "60.00",
                "projected_income": "80.00",
                "service_money": "5.00",
                "member_discount_amount": "2.00",
                "manual_discount_amount": "1.00",
                "coupon_deduct_money": "3.00",
                "order_trade_no": 8001,
                "order_settle_id": 7010,
                "operator_id": 22,
                "operator_name": "值班",
                "assistant_team_id": 10,
                "assistant_level": "A",
                "site_table_id": 301,
                "order_assistant_id": 1601,
                "site_assistant_id": 501,
                "user_id": 5010,
                "ledger_start_time": BASE_TS,
                "ledger_end_time": END_TS,
                "start_use_time": BASE_TS,
                "last_use_time": END_TS,
                "income_seconds": 3600,
                "real_use_seconds": 3300,
                "is_trash": False,
                "trash_reason": "",
                "is_confirm": True,
                "ledger_status": "CLOSED",
                "create_time": END_TS,
            }
        ],
    ),
]

View File

@@ -0,0 +1,694 @@
# -*- coding: utf-8 -*-
"""
单元测试 — 文档对齐分析器 (doc_alignment_analyzer.py)
覆盖:
- scan_docs 文档来源识别
- extract_code_references 代码引用提取
- check_reference_validity 引用有效性检查
- find_undocumented_modules 缺失文档检测
- check_ddl_vs_dictionary DDL 与数据字典比对
- check_api_samples_vs_parsers API 样本与解析器比对
- render_alignment_report 报告渲染
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from scripts.audit import AlignmentIssue, DocMapping
from scripts.audit.doc_alignment_analyzer import (
_parse_ddl_tables,
_parse_dictionary_tables,
build_mappings,
check_api_samples_vs_parsers,
check_ddl_vs_dictionary,
check_reference_validity,
extract_code_references,
find_undocumented_modules,
render_alignment_report,
scan_docs,
)
# ---------------------------------------------------------------------------
# scan_docs
# ---------------------------------------------------------------------------
class TestScanDocs:
    """Tests for document-source discovery."""

    def test_finds_docs_dir_md(self, tmp_path: Path) -> None:
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "guide.md").write_text("# Guide", encoding="utf-8")
        assert "docs/guide.md" in scan_docs(tmp_path)

    def test_finds_root_readme(self, tmp_path: Path) -> None:
        (tmp_path / "README.md").write_text("# Readme", encoding="utf-8")
        assert "README.md" in scan_docs(tmp_path)

    def test_finds_dev_notes(self, tmp_path: Path) -> None:
        notes = tmp_path / "开发笔记"
        notes.mkdir()
        (notes / "记录.md").write_text("笔记", encoding="utf-8")
        assert "开发笔记/记录.md" in scan_docs(tmp_path)

    def test_finds_module_readme(self, tmp_path: Path) -> None:
        gui = tmp_path / "gui"
        gui.mkdir()
        (gui / "README.md").write_text("# GUI", encoding="utf-8")
        assert "gui/README.md" in scan_docs(tmp_path)

    def test_finds_steering_files(self, tmp_path: Path) -> None:
        steering_dir = tmp_path / ".kiro" / "steering"
        steering_dir.mkdir(parents=True)
        (steering_dir / "tech.md").write_text("# Tech", encoding="utf-8")
        assert ".kiro/steering/tech.md" in scan_docs(tmp_path)

    def test_finds_json_samples(self, tmp_path: Path) -> None:
        samples = tmp_path / "docs" / "test-json-doc"
        samples.mkdir(parents=True)
        (samples / "member.json").write_text("[]", encoding="utf-8")
        assert "docs/test-json-doc/member.json" in scan_docs(tmp_path)

    def test_empty_repo_returns_empty(self, tmp_path: Path) -> None:
        assert scan_docs(tmp_path) == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        docs = tmp_path / "docs"
        docs.mkdir()
        for stem in ("z", "a"):
            (docs / f"{stem}.md").write_text(stem, encoding="utf-8")
        (tmp_path / "README.md").write_text("r", encoding="utf-8")
        found = scan_docs(tmp_path)
        assert found == sorted(found)
# ---------------------------------------------------------------------------
# extract_code_references
# ---------------------------------------------------------------------------
class TestExtractCodeReferences:
    """Tests for extracting code references from documents."""

    def test_extracts_backtick_paths(self, tmp_path: Path) -> None:
        md = tmp_path / "doc.md"
        md.write_text("使用 `tasks/base_task.py` 作为基类", encoding="utf-8")
        assert "tasks/base_task.py" in extract_code_references(md)

    def test_extracts_class_names(self, tmp_path: Path) -> None:
        md = tmp_path / "doc.md"
        md.write_text("继承 `BaseTask` 类", encoding="utf-8")
        assert "BaseTask" in extract_code_references(md)

    def test_skips_single_char(self, tmp_path: Path) -> None:
        md = tmp_path / "doc.md"
        md.write_text("变量 `x` 和 `y`", encoding="utf-8")
        assert extract_code_references(md) == []

    def test_skips_pure_numbers(self, tmp_path: Path) -> None:
        md = tmp_path / "doc.md"
        md.write_text("版本 `2.0.0` 和 ID `12345`", encoding="utf-8")
        assert extract_code_references(md) == []

    def test_deduplicates(self, tmp_path: Path) -> None:
        md = tmp_path / "doc.md"
        md.write_text("`foo.py` 和 `foo.py` 重复", encoding="utf-8")
        assert extract_code_references(md).count("foo.py") == 1

    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        assert extract_code_references(tmp_path / "nonexistent.md") == []

    def test_normalizes_backslash(self, tmp_path: Path) -> None:
        md = tmp_path / "doc.md"
        md.write_text("路径 `tasks\\base_task.py`", encoding="utf-8")
        assert "tasks/base_task.py" in extract_code_references(md)
# ---------------------------------------------------------------------------
# check_reference_validity
# ---------------------------------------------------------------------------
class TestCheckReferenceValidity:
    """Tests for reference-validity checking."""

    def test_valid_file_path(self, tmp_path: Path) -> None:
        tasks = tmp_path / "tasks"
        tasks.mkdir()
        (tasks / "base.py").write_text("", encoding="utf-8")
        assert check_reference_validity("tasks/base.py", tmp_path) is True

    def test_invalid_file_path(self, tmp_path: Path) -> None:
        assert check_reference_validity("nonexistent/file.py", tmp_path) is False

    def test_strips_legacy_prefix(self, tmp_path: Path) -> None:
        """Tolerates the legacy package prefix (etl_billiards/) and the
        current repo-root prefix (FQ-ETL/)."""
        tasks = tmp_path / "tasks"
        tasks.mkdir()
        (tasks / "x.py").write_text("", encoding="utf-8")
        for prefixed in ("etl_billiards/tasks/x.py", "FQ-ETL/tasks/x.py"):
            assert check_reference_validity(prefixed, tmp_path) is True

    def test_directory_path(self, tmp_path: Path) -> None:
        (tmp_path / "loaders").mkdir()
        assert check_reference_validity("loaders", tmp_path) is True

    def test_dotted_module_path(self, tmp_path: Path) -> None:
        cfg = tmp_path / "config"
        cfg.mkdir()
        (cfg / "settings.py").write_text("", encoding="utf-8")
        assert check_reference_validity("config.settings", tmp_path) is True
# ---------------------------------------------------------------------------
# find_undocumented_modules
# ---------------------------------------------------------------------------
class TestFindUndocumentedModules:
    """Tests for detecting modules that lack documentation."""

    def test_finds_undocumented(self, tmp_path: Path) -> None:
        tasks = tmp_path / "tasks"
        tasks.mkdir()
        (tasks / "__init__.py").write_text("", encoding="utf-8")
        (tasks / "ods_task.py").write_text("", encoding="utf-8")
        assert "tasks/ods_task.py" in find_undocumented_modules(tmp_path, set())

    def test_excludes_init(self, tmp_path: Path) -> None:
        tasks = tmp_path / "tasks"
        tasks.mkdir()
        (tasks / "__init__.py").write_text("", encoding="utf-8")
        found = find_undocumented_modules(tmp_path, set())
        assert all("__init__" not in entry for entry in found)

    def test_documented_module_excluded(self, tmp_path: Path) -> None:
        tasks = tmp_path / "tasks"
        tasks.mkdir()
        (tasks / "ods_task.py").write_text("", encoding="utf-8")
        found = find_undocumented_modules(tmp_path, {"tasks/ods_task.py"})
        assert "tasks/ods_task.py" not in found

    def test_non_core_dirs_ignored(self, tmp_path: Path) -> None:
        """gui/ is not in the core-code directory list, so it must be skipped."""
        gui = tmp_path / "gui"
        gui.mkdir()
        (gui / "main.py").write_text("", encoding="utf-8")
        found = find_undocumented_modules(tmp_path, set())
        assert all("gui/" not in entry for entry in found)

    def test_results_sorted(self, tmp_path: Path) -> None:
        tasks = tmp_path / "tasks"
        tasks.mkdir()
        for name in ("z_task.py", "a_task.py"):
            (tasks / name).write_text("", encoding="utf-8")
        found = find_undocumented_modules(tmp_path, set())
        assert found == sorted(found)
# ---------------------------------------------------------------------------
# _parse_ddl_tables / _parse_dictionary_tables
# ---------------------------------------------------------------------------
class TestParseDdlTables:
    """Tests for DDL parsing."""

    def test_extracts_table_and_columns(self) -> None:
        ddl = """
CREATE TABLE IF NOT EXISTS dim_member (
member_id BIGINT,
nickname TEXT,
mobile TEXT,
PRIMARY KEY (member_id)
);
"""
        parsed = _parse_ddl_tables(ddl)
        assert "dim_member" in parsed
        for column in ("member_id", "nickname", "mobile"):
            assert column in parsed["dim_member"]

    def test_handles_schema_prefix(self) -> None:
        ddl = "CREATE TABLE billiards_dwd.dim_site (\n site_id BIGINT\n);"
        assert "dim_site" in _parse_ddl_tables(ddl)

    def test_excludes_sql_keywords(self) -> None:
        ddl = """
CREATE TABLE test_tbl (
id INTEGER,
PRIMARY KEY (id)
);
"""
        parsed = _parse_ddl_tables(ddl)
        # The PRIMARY KEY constraint line must not be mistaken for a column.
        assert "primary" not in parsed.get("test_tbl", set())
class TestParseDictionaryTables:
    """Tests for data-dictionary parsing."""

    def test_extracts_table_and_fields(self) -> None:
        md = """## dim_member
| 字段 | 类型 | 说明 |
|------|------|------|
| member_id | BIGINT | 会员ID |
| nickname | TEXT | 昵称 |
"""
        parsed = _parse_dictionary_tables(md)
        assert "dim_member" in parsed
        for field in ("member_id", "nickname"):
            assert field in parsed["dim_member"]

    def test_skips_header_row(self) -> None:
        md = """## dim_test
| 字段 | 类型 |
|------|------|
| col_a | INT |
"""
        parsed = _parse_dictionary_tables(md)
        # The markdown header row must not be recorded as a field.
        assert "字段" not in parsed.get("dim_test", set())

    def test_handles_backtick_table_name(self) -> None:
        md = "## `dim_goods`\n\n| 字段 |\n| goods_id |"
        assert "dim_goods" in _parse_dictionary_tables(md)
# ---------------------------------------------------------------------------
# check_ddl_vs_dictionary
# ---------------------------------------------------------------------------
class TestCheckDdlVsDictionary:
    """Tests comparing DDL against the data dictionary."""

    def test_detects_missing_table_in_dictionary(self, tmp_path: Path) -> None:
        # DDL defines a table that the dictionary never mentions.
        database = tmp_path / "database"
        database.mkdir()
        (database / "schema_test.sql").write_text(
            "CREATE TABLE dim_orphan (\n id BIGINT\n);",
            encoding="utf-8",
        )
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "dwd_main_tables_dictionary.md").write_text(
            "## dim_other\n\n| 字段 |\n| id |",
            encoding="utf-8",
        )
        found = check_ddl_vs_dictionary(tmp_path)
        missing_issues = [issue for issue in found if issue.issue_type == "missing"]
        assert any("dim_orphan" in issue.description for issue in missing_issues)

    def test_detects_column_mismatch(self, tmp_path: Path) -> None:
        database = tmp_path / "database"
        database.mkdir()
        (database / "schema_test.sql").write_text(
            "CREATE TABLE dim_x (\n id BIGINT,\n extra_col TEXT\n);",
            encoding="utf-8",
        )
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "dwd_main_tables_dictionary.md").write_text(
            "## dim_x\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
            encoding="utf-8",
        )
        found = check_ddl_vs_dictionary(tmp_path)
        conflicts = [issue for issue in found if issue.issue_type == "conflict"]
        assert any("extra_col" in issue.description for issue in conflicts)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        database = tmp_path / "database"
        database.mkdir()
        (database / "schema_test.sql").write_text(
            "CREATE TABLE dim_ok (\n id BIGINT\n);",
            encoding="utf-8",
        )
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "dwd_main_tables_dictionary.md").write_text(
            "## dim_ok\n\n| 字段 | 类型 |\n|---|---|\n| id | BIGINT |",
            encoding="utf-8",
        )
        assert len(check_ddl_vs_dictionary(tmp_path)) == 0
# ---------------------------------------------------------------------------
# check_api_samples_vs_parsers
# ---------------------------------------------------------------------------
class TestCheckApiSamplesVsParsers:
    """Tests comparing API samples against the ODS schema."""

    def test_detects_json_field_not_in_ods(self, tmp_path: Path) -> None:
        # The JSON sample carries extra_field, which the ODS table lacks.
        samples = tmp_path / "docs" / "test-json-doc"
        samples.mkdir(parents=True)
        (samples / "test_entity.json").write_text(
            json.dumps([{"id": 1, "name": "a", "extra_field": "x"}]),
            encoding="utf-8",
        )
        database = tmp_path / "database"
        database.mkdir()
        (database / "schema_ODS_doc.sql").write_text(
            "CREATE TABLE billiards_ods.test_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
            encoding="utf-8",
        )
        found = check_api_samples_vs_parsers(tmp_path)
        assert any("extra_field" in issue.description for issue in found)

    def test_no_issues_when_aligned(self, tmp_path: Path) -> None:
        samples = tmp_path / "docs" / "test-json-doc"
        samples.mkdir(parents=True)
        (samples / "aligned_entity.json").write_text(
            json.dumps([{"id": 1, "name": "a"}]),
            encoding="utf-8",
        )
        database = tmp_path / "database"
        database.mkdir()
        (database / "schema_ODS_doc.sql").write_text(
            "CREATE TABLE billiards_ods.aligned_entity (\n"
            " id BIGINT,\n name TEXT,\n"
            " content_hash TEXT,\n payload JSONB\n);",
            encoding="utf-8",
        )
        assert len(check_api_samples_vs_parsers(tmp_path)) == 0

    def test_skips_when_no_ods_table(self, tmp_path: Path) -> None:
        samples = tmp_path / "docs" / "test-json-doc"
        samples.mkdir(parents=True)
        (samples / "unknown.json").write_text(
            json.dumps([{"a": 1}]),
            encoding="utf-8",
        )
        database = tmp_path / "database"
        database.mkdir()
        (database / "schema_ODS_doc.sql").write_text("-- empty", encoding="utf-8")
        assert len(check_api_samples_vs_parsers(tmp_path)) == 0
# ---------------------------------------------------------------------------
# render_alignment_report
# ---------------------------------------------------------------------------
class TestRenderAlignmentReport:
    """Tests for report rendering."""

    def test_contains_all_sections(self) -> None:
        rendered = render_alignment_report([], [], "/repo")
        for heading in ("## 映射关系", "## 过期点", "## 冲突点", "## 缺失点", "## 统计摘要"):
            assert heading in rendered

    def test_contains_header_metadata(self) -> None:
        rendered = render_alignment_report([], [], "/repo")
        assert "生成时间" in rendered
        assert "`/repo`" in rendered

    def test_contains_iso_timestamp(self) -> None:
        import re
        rendered = render_alignment_report([], [], "/repo")
        # An ISO-format timestamp contains T and a trailing Z.
        assert re.search(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", rendered)

    def test_mapping_table_rendered(self) -> None:
        doc_maps = [
            DocMapping(
                doc_path="docs/guide.md",
                doc_topic="项目文档",
                related_code=["tasks/base.py"],
                status="aligned",
            )
        ]
        rendered = render_alignment_report(doc_maps, [], "/repo")
        assert "`docs/guide.md`" in rendered
        assert "`tasks/base.py`" in rendered
        assert "aligned" in rendered

    def test_stale_issues_rendered(self) -> None:
        found = [
            AlignmentIssue(
                doc_path="docs/old.md",
                issue_type="stale",
                description="引用了已删除的文件",
                related_code="tasks/deleted.py",
            )
        ]
        rendered = render_alignment_report([], found, "/repo")
        assert "引用了已删除的文件" in rendered
        assert "## 过期点" in rendered

    def test_conflict_issues_rendered(self) -> None:
        found = [
            AlignmentIssue(
                doc_path="docs/dict.md",
                issue_type="conflict",
                description="字段不一致",
                related_code="database/schema.sql",
            )
        ]
        assert "字段不一致" in render_alignment_report([], found, "/repo")

    def test_missing_issues_rendered(self) -> None:
        found = [
            AlignmentIssue(
                doc_path="docs/dict.md",
                issue_type="missing",
                description="缺少表定义",
                related_code="database/schema.sql",
            )
        ]
        assert "缺少表定义" in render_alignment_report([], found, "/repo")

    def test_summary_counts(self) -> None:
        found = [
            AlignmentIssue("a", "stale", "d1", "c1"),
            AlignmentIssue("b", "stale", "d2", "c2"),
            AlignmentIssue("c", "conflict", "d3", "c3"),
            AlignmentIssue("d", "missing", "d4", "c4"),
        ]
        doc_maps = [DocMapping("x", "t", [], "aligned")]
        rendered = render_alignment_report(doc_maps, found, "/repo")
        assert "过期点数量2" in rendered
        assert "冲突点数量1" in rendered
        assert "缺失点数量1" in rendered
        assert "文档总数1" in rendered

    def test_empty_report(self) -> None:
        rendered = render_alignment_report([], [], "/repo")
        assert "未发现过期点" in rendered
        assert "未发现冲突点" in rendered
        assert "未发现缺失点" in rendered
        assert "过期点数量0" in rendered
# ---------------------------------------------------------------------------
# 属性测试 — Property 11 / 12 / 16 (hypothesis)
# hypothesis 与 pytest 的 function-scoped fixture (tmp_path) 不兼容,
# 因此在测试内部使用 tempfile.mkdtemp 自行管理临时目录。
# ---------------------------------------------------------------------------
import shutil
import tempfile
from hypothesis import given, settings
from hypothesis import strategies as st
from scripts.audit.doc_alignment_analyzer import _CORE_CODE_DIRS
class TestPropertyStaleReferenceDetection:
    """Feature: repo-audit, Property 11: stale-reference detection.
    *For any* code reference extracted from documentation, if the referenced
    file path does not exist in the repository, check_reference_validity
    must return False.
    Validates: Requirements 3.3
    """
    _safe_name = st.from_regex(r"[a-z][a-z0-9_]{1,12}", fullmatch=True)
    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
        missing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_nonexistent_path_returns_false(
        self,
        existing_names: list[str],
        missing_names: list[str],
    ) -> None:
        """References to nonexistent file paths must return False."""
        tmp = Path(tempfile.mkdtemp())
        try:
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")
            existing_set = set(existing_names)
            # Only probe names that genuinely do not exist in the repo.
            truly_missing = [n for n in missing_names if n not in existing_set]
            for name in truly_missing:
                ref = f"nonexistent_dir/{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is False, (
                    f"引用 '{ref}' 指向不存在的文件,但返回了 True"
                )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
    @given(
        existing_names=st.lists(
            _safe_name, min_size=1, max_size=5, unique=True,
        ),
    )
    @settings(max_examples=100)
    def test_existing_path_returns_true(
        self,
        existing_names: list[str],
    ) -> None:
        """References to existing file paths must return True."""
        tmp = Path(tempfile.mkdtemp())
        try:
            for name in existing_names:
                (tmp / f"{name}.py").write_text("# ok", encoding="utf-8")
            for name in existing_names:
                ref = f"{name}.py"
                result = check_reference_validity(ref, tmp)
                assert result is True, (
                    f"引用 '{ref}' 指向存在的文件,但返回了 False"
                )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
class TestPropertyMissingDocDetection:
    """Feature: repo-audit, Property 12: missing-documentation detection.
    *For any* set of core code modules and set of documented modules, the
    missing list returned by find_undocumented_modules must equal exactly
    the set difference of the core-module set minus the documented set.
    Validates: Requirements 3.5
    """
    _core_dir = st.sampled_from(list(_CORE_CODE_DIRS))
    _module_name = st.from_regex(r"[a-z][a-z0-9_]{1,10}", fullmatch=True)
    @given(
        core_dir=_core_dir,
        module_names=st.lists(
            _module_name, min_size=2, max_size=6, unique=True,
        ),
        doc_fraction=st.floats(min_value=0.0, max_value=1.0),
    )
    @settings(max_examples=100)
    def test_undocumented_equals_difference(
        self,
        core_dir: str,
        module_names: list[str],
        doc_fraction: float,
    ) -> None:
        """The missing list must equal the core modules minus the documented set."""
        tmp = Path(tempfile.mkdtemp())
        try:
            code_dir = tmp / core_dir
            code_dir.mkdir(parents=True, exist_ok=True)
            all_modules: set[str] = set()
            for name in module_names:
                (code_dir / f"{name}.py").write_text("# module", encoding="utf-8")
                all_modules.add(f"{core_dir}/{name}.py")
            # Mark a doc_fraction-sized prefix of the modules as documented.
            split_idx = int(len(module_names) * doc_fraction)
            documented = {
                f"{core_dir}/{n}.py" for n in module_names[:split_idx]
            }
            result = find_undocumented_modules(tmp, documented)
            expected = sorted(all_modules - documented)
            assert result == expected, (
                f"期望缺失列表 {expected},实际得到 {result}"
            )
        finally:
            shutil.rmtree(tmp, ignore_errors=True)
class TestPropertyAlignmentReportSections:
    """Feature: repo-audit, Property 16: alignment-report section completeness.
    *For any* output of render_alignment_report, the Markdown text must
    contain all four section headings (mapping, stale, conflict, missing).
    Validates: Requirements 3.8
    """
    _issue_type = st.sampled_from(["stale", "conflict", "missing"])
    _text = st.text(
        alphabet=st.characters(
            whitelist_categories=("L", "N", "P"),
            blacklist_characters="\x00",
        ),
        min_size=1,
        max_size=30,
    )
    _mapping_st = st.builds(
        DocMapping,
        doc_path=_text,
        doc_topic=_text,
        related_code=st.lists(_text, max_size=3),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )
    _issue_st = st.builds(
        AlignmentIssue,
        doc_path=_text,
        issue_type=_issue_type,
        description=_text,
        related_code=_text,
    )
    @given(
        mappings=st.lists(_mapping_st, max_size=5),
        issues=st.lists(_issue_st, max_size=8),
    )
    @settings(max_examples=100)
    def test_report_contains_four_sections(
        self,
        mappings: list[DocMapping],
        issues: list[AlignmentIssue],
    ) -> None:
        """The report must contain all four section headings."""
        report = render_alignment_report(mappings, issues, "/test/repo")
        required_sections = ["## 映射关系", "## 过期点", "## 冲突点", "## 缺失点"]
        for section in required_sections:
            assert section in report, (
                f"报告中缺少分区标题 '{section}'"
            )

View File

@@ -0,0 +1,667 @@
# -*- coding: utf-8 -*-
"""
单元测试 — 流程树分析器 (flow_analyzer.py)
覆盖:
- parse_imports: import 语句解析、标准库/第三方排除、语法错误容错
- build_flow_tree: 递归构建、循环导入处理
- find_orphan_modules: 孤立模块检测
- render_flow_report: Markdown 渲染、Mermaid 图、统计摘要
- discover_entry_points: 入口点识别
- classify_task_type / classify_loader_type: 类型区分
"""
from __future__ import annotations
from pathlib import Path
import pytest
from scripts.audit import FileEntry, FlowNode
from scripts.audit.flow_analyzer import (
build_flow_tree,
classify_loader_type,
classify_task_type,
discover_entry_points,
find_orphan_modules,
parse_imports,
render_flow_report,
_path_to_module_name,
_parse_bat_python_target,
)
# ---------------------------------------------------------------------------
# parse_imports 单元测试
# ---------------------------------------------------------------------------
class TestParseImports:
    """Tests for import-statement parsing."""
    def test_absolute_import(self, tmp_path: Path) -> None:
        """Absolute imports of project-internal modules should be recognized."""
        f = tmp_path / "test.py"
        f.write_text("import cli.main\nimport config.settings\n", encoding="utf-8")
        result = parse_imports(f)
        assert "cli.main" in result
        assert "config.settings" in result
    def test_from_import(self, tmp_path: Path) -> None:
        """from ... import statements should be recognized."""
        f = tmp_path / "test.py"
        f.write_text("from tasks.base_task import BaseTask\n", encoding="utf-8")
        result = parse_imports(f)
        assert "tasks.base_task" in result
    def test_stdlib_excluded(self, tmp_path: Path) -> None:
        """Standard-library modules should be excluded."""
        f = tmp_path / "test.py"
        f.write_text("import os\nimport sys\nimport json\nfrom pathlib import Path\n", encoding="utf-8")
        result = parse_imports(f)
        assert result == []
    def test_third_party_excluded(self, tmp_path: Path) -> None:
        """Third-party packages should be excluded."""
        f = tmp_path / "test.py"
        f.write_text("import requests\nfrom psycopg2 import sql\nimport flask\n", encoding="utf-8")
        result = parse_imports(f)
        assert result == []
    def test_mixed_imports(self, tmp_path: Path) -> None:
        """Mixed imports should keep only project-internal modules."""
        f = tmp_path / "test.py"
        f.write_text(
            "import os\nimport cli.main\nimport requests\nfrom loaders.base_loader import BaseLoader\n",
            encoding="utf-8",
        )
        result = parse_imports(f)
        assert "cli.main" in result
        assert "loaders.base_loader" in result
        assert "os" not in result
        assert "requests" not in result
    def test_syntax_error_returns_empty(self, tmp_path: Path) -> None:
        """A file with a syntax error should yield an empty list."""
        f = tmp_path / "bad.py"
        f.write_text("def broken(\n", encoding="utf-8")
        result = parse_imports(f)
        assert result == []
    def test_nonexistent_file_returns_empty(self, tmp_path: Path) -> None:
        """A nonexistent file should yield an empty list."""
        result = parse_imports(tmp_path / "nonexistent.py")
        assert result == []
    def test_deduplication(self, tmp_path: Path) -> None:
        """Duplicate imports should be deduplicated."""
        f = tmp_path / "test.py"
        f.write_text("import cli.main\nimport cli.main\nfrom cli.main import main\n", encoding="utf-8")
        result = parse_imports(f)
        assert result.count("cli.main") == 1
    def test_empty_file(self, tmp_path: Path) -> None:
        """An empty file should yield an empty list."""
        f = tmp_path / "empty.py"
        f.write_text("", encoding="utf-8")
        result = parse_imports(f)
        assert result == []
# ---------------------------------------------------------------------------
# build_flow_tree 单元测试
# ---------------------------------------------------------------------------
class TestBuildFlowTree:
    """Tests for flow-tree construction."""
    def test_single_file_no_imports(self, tmp_path: Path) -> None:
        """A single file without imports should produce a leaf node."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert tree.source_file == "cli/main.py"
        assert tree.children == []
    def test_simple_import_chain(self, tmp_path: Path) -> None:
        """A simple import chain should produce the correct child nodes."""
        # cli/main.py → config/settings.py
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "from config.settings import AppConfig\n", encoding="utf-8"
        )
        config_dir = tmp_path / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        (config_dir / "settings.py").write_text("class AppConfig: pass\n", encoding="utf-8")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.name == "cli.main"
        assert len(tree.children) == 1
        assert tree.children[0].name == "config.settings"
    def test_circular_import_no_infinite_loop(self, tmp_path: Path) -> None:
        """Circular imports must not cause unbounded recursion."""
        pkg = tmp_path / "utils"
        pkg.mkdir()
        (pkg / "__init__.py").write_text("", encoding="utf-8")
        # a → b → a (import cycle)
        (pkg / "a.py").write_text("from utils.b import func_b\n", encoding="utf-8")
        (pkg / "b.py").write_text("from utils.a import func_a\n", encoding="utf-8")
        # Must not raise RecursionError.
        tree = build_flow_tree(tmp_path, "utils/a.py")
        assert tree.name == "utils.a"
    def test_entry_node_type(self, tmp_path: Path) -> None:
        """A CLI entry file should be tagged with node_type 'entry'."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        tree = build_flow_tree(tmp_path, "cli/main.py")
        assert tree.node_type == "entry"
# ---------------------------------------------------------------------------
# find_orphan_modules 单元测试
# ---------------------------------------------------------------------------
class TestFindOrphanModules:
    """Tests for orphan-module detection."""
    def test_all_reachable(self, tmp_path: Path) -> None:
        """When every module is reachable, the result should be empty."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("config/settings.py", False, 200, ".py", False),
        ]
        reachable = {"cli/main.py", "config/settings.py"}
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []
    def test_orphan_detected(self, tmp_path: Path) -> None:
        """Unreachable modules should be flagged as orphans."""
        entries = [
            FileEntry("cli/main.py", False, 100, ".py", False),
            FileEntry("utils/orphan.py", False, 50, ".py", False),
        ]
        reachable = {"cli/main.py"}
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert "utils/orphan.py" in orphans
    def test_init_files_excluded(self, tmp_path: Path) -> None:
        """__init__.py files should never be treated as orphan modules."""
        entries = [
            FileEntry("cli/__init__.py", False, 0, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert "cli/__init__.py" not in orphans
    def test_test_files_excluded(self, tmp_path: Path) -> None:
        """Test files should never be treated as orphan modules."""
        entries = [
            FileEntry("tests/unit/test_something.py", False, 100, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []
    def test_audit_scripts_excluded(self, tmp_path: Path) -> None:
        """The audit scripts themselves should never be treated as orphans."""
        entries = [
            FileEntry("scripts/audit/scanner.py", False, 100, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []
    def test_directories_excluded(self, tmp_path: Path) -> None:
        """Directory entries should never appear in the orphan list."""
        entries = [
            FileEntry("cli", True, 0, "", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == []
    def test_sorted_output(self, tmp_path: Path) -> None:
        """The orphan list should be sorted by path."""
        entries = [
            FileEntry("utils/z.py", False, 50, ".py", False),
            FileEntry("utils/a.py", False, 50, ".py", False),
            FileEntry("cli/orphan.py", False, 50, ".py", False),
        ]
        reachable: set[str] = set()
        orphans = find_orphan_modules(tmp_path, entries, reachable)
        assert orphans == sorted(orphans)
# ---------------------------------------------------------------------------
# render_flow_report 单元测试
# ---------------------------------------------------------------------------
class TestRenderFlowReport:
    """Tests for rendering the flow-tree Markdown report."""
    def test_header_contains_timestamp_and_path(self) -> None:
        """The report header should contain a timestamp and the repo path."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, [], "/repo")
        assert "生成时间:" in report
        assert "`/repo`" in report
    def test_contains_mermaid_block(self) -> None:
        """The report should contain a Mermaid code block."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, [], "/repo")
        assert "```mermaid" in report
        assert "graph TD" in report
    def test_contains_indented_text(self) -> None:
        """The report should contain the flow tree as indented text."""
        child = FlowNode("config.settings", "config/settings.py", "module", [])
        root = FlowNode("cli.main", "cli/main.py", "entry", [child])
        report = render_flow_report([root], [], "/repo")
        assert "`cli.main`" in report
        assert "`config.settings`" in report
    def test_orphan_section(self) -> None:
        """The report should list orphan modules."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        orphans = ["utils/orphan.py", "models/unused.py"]
        report = render_flow_report(trees, orphans, "/repo")
        assert "孤立模块" in report
        assert "`utils/orphan.py`" in report
        assert "`models/unused.py`" in report
    def test_no_orphans_message(self) -> None:
        """With no orphan modules, a notice message should be shown."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, [], "/repo")
        assert "未发现孤立模块" in report
    def test_statistics_summary(self) -> None:
        """The report should contain a statistics summary."""
        trees = [FlowNode("cli.main", "cli/main.py", "entry", [])]
        report = render_flow_report(trees, ["a.py"], "/repo")
        assert "统计摘要" in report
        assert "入口点" in report
        assert "任务" in report
        assert "加载器" in report
        assert "孤立模块" in report
    def test_task_type_annotation(self) -> None:
        """Task modules should carry a type annotation in the report."""
        task_node = FlowNode("tasks.ods_member", "tasks/ods_member.py", "module", [])
        root = FlowNode("cli.main", "cli/main.py", "entry", [task_node])
        report = render_flow_report([root], [], "/repo")
        assert "ODS" in report
    def test_loader_type_annotation(self) -> None:
        """Loader modules should carry a type annotation in the report."""
        loader_node = FlowNode(
            "loaders.dimensions.member", "loaders/dimensions/member.py", "module", []
        )
        root = FlowNode("cli.main", "cli/main.py", "entry", [loader_node])
        report = render_flow_report([root], [], "/repo")
        assert "维度" in report or "SCD2" in report
# ---------------------------------------------------------------------------
# discover_entry_points 单元测试
# ---------------------------------------------------------------------------
class TestDiscoverEntryPoints:
    """Tests for entry-point discovery."""
    def test_cli_entry(self, tmp_path: Path) -> None:
        """A CLI entry point should be discovered."""
        cli_dir = tmp_path / "cli"
        cli_dir.mkdir()
        (cli_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        entries = discover_entry_points(tmp_path)
        cli_entries = [e for e in entries if e["type"] == "CLI"]
        assert len(cli_entries) == 1
        assert cli_entries[0]["file"] == "cli/main.py"
    def test_gui_entry(self, tmp_path: Path) -> None:
        """A GUI entry point should be discovered."""
        gui_dir = tmp_path / "gui"
        gui_dir.mkdir()
        (gui_dir / "main.py").write_text("def main(): pass\n", encoding="utf-8")
        entries = discover_entry_points(tmp_path)
        gui_entries = [e for e in entries if e["type"] == "GUI"]
        assert len(gui_entries) == 1
    def test_bat_entry(self, tmp_path: Path) -> None:
        """A batch-file entry point should be discovered."""
        (tmp_path / "run_etl.bat").write_text(
            "@echo off\npython -m cli.main %*\n", encoding="utf-8"
        )
        entries = discover_entry_points(tmp_path)
        bat_entries = [e for e in entries if e["type"] == "批处理"]
        assert len(bat_entries) == 1
        assert "cli.main" in bat_entries[0]["description"]
    def test_script_entry(self, tmp_path: Path) -> None:
        """An operations-script entry point should be discovered."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        (scripts_dir / "rebuild_db.py").write_text(
            'if __name__ == "__main__": pass\n', encoding="utf-8"
        )
        entries = discover_entry_points(tmp_path)
        script_entries = [e for e in entries if e["type"] == "运维脚本"]
        assert len(script_entries) == 1
        assert script_entries[0]["file"] == "scripts/rebuild_db.py"
    def test_init_py_excluded_from_scripts(self, tmp_path: Path) -> None:
        """scripts/__init__.py should not be recognized as an entry point."""
        scripts_dir = tmp_path / "scripts"
        scripts_dir.mkdir()
        (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        entries = discover_entry_points(tmp_path)
        script_entries = [e for e in entries if e["type"] == "运维脚本"]
        assert all(e["file"] != "scripts/__init__.py" for e in script_entries)
# ---------------------------------------------------------------------------
# classify_task_type / classify_loader_type 单元测试
# ---------------------------------------------------------------------------
class TestClassifyTypes:
    """Classification of task and loader module types by path."""
    def test_ods_task(self) -> None:
        label = classify_task_type("tasks/ods_member.py")
        assert "ODS" in label
    def test_dwd_task(self) -> None:
        label = classify_task_type("tasks/dwd_load.py")
        assert "DWD" in label
    def test_dws_task(self) -> None:
        label = classify_task_type("tasks/dws/assistant_daily.py")
        assert "DWS" in label
    def test_verification_task(self) -> None:
        label = classify_task_type("tasks/verification/balance_check.py")
        assert "校验" in label
    def test_schema_init_task(self) -> None:
        label = classify_task_type("tasks/init_ods_schema.py")
        assert "Schema" in label
    def test_dimension_loader(self) -> None:
        label = classify_loader_type("loaders/dimensions/member.py")
        assert "维度" in label or "SCD2" in label
    def test_fact_loader(self) -> None:
        label = classify_loader_type("loaders/facts/order.py")
        assert "事实" in label
    def test_ods_loader(self) -> None:
        label = classify_loader_type("loaders/ods/generic.py")
        assert "ODS" in label
# ---------------------------------------------------------------------------
# _path_to_module_name 单元测试
# ---------------------------------------------------------------------------
class TestPathToModuleName:
    """Conversion from relative file paths to dotted module names."""
    def test_simple_file(self) -> None:
        result = _path_to_module_name("cli/main.py")
        assert result == "cli.main"
    def test_init_file(self) -> None:
        result = _path_to_module_name("cli/__init__.py")
        assert result == "cli"
    def test_nested_path(self) -> None:
        result = _path_to_module_name("tasks/dws/assistant.py")
        assert result == "tasks.dws.assistant"
# ---------------------------------------------------------------------------
# _parse_bat_python_target 单元测试
# ---------------------------------------------------------------------------
class TestParseBatPythonTarget:
    """Parsing of the python command target inside .bat files."""
    def test_module_invocation(self, tmp_path: Path) -> None:
        script = tmp_path / "run.bat"
        script.write_text("@echo off\npython -m cli.main %*\n", encoding="utf-8")
        assert _parse_bat_python_target(script) == "cli.main"
    def test_no_python_command(self, tmp_path: Path) -> None:
        script = tmp_path / "run.bat"
        script.write_text("@echo off\necho hello\n", encoding="utf-8")
        assert _parse_bat_python_target(script) is None
    def test_nonexistent_file(self, tmp_path: Path) -> None:
        missing = tmp_path / "missing.bat"
        assert _parse_bat_python_target(missing) is None
# ---------------------------------------------------------------------------
# 属性测试 — Property 9 & 10hypothesis
# ---------------------------------------------------------------------------
import os
import string
from hypothesis import given, settings, assume
from hypothesis import strategies as st
# ---------------------------------------------------------------------------
# Helper: project package names (kept in sync with _PROJECT_PACKAGES in
# flow_analyzer)
# ---------------------------------------------------------------------------
_PROJECT_PACKAGES_LIST = [
    "cli", "config", "api", "database", "tasks", "loaders",
    "scd", "orchestration", "quality", "models", "utils",
    "gui", "scripts",
]
# ---------------------------------------------------------------------------
# Property 9: 流程树节点 source_file 有效性
# Feature: repo-audit, Property 9: 流程树节点 source_file 有效性
# Validates: Requirements 2.7
#
# 策略:在临时目录中随机生成 1~5 个项目内部模块文件,
# 其中一个作为入口,其他文件通过 import 语句相互引用。
# 构建流程树后,遍历所有节点验证 source_file 非空且文件存在。
# ---------------------------------------------------------------------------
def _collect_all_nodes(node: FlowNode) -> list[FlowNode]:
"""递归收集流程树中所有节点。"""
result = [node]
for child in node.children:
result.extend(_collect_all_nodes(child))
return result
# Generate valid Python identifiers to use as module file names.
_module_name_st = st.from_regex(r"[a-z][a-z0-9_]{0,8}", fullmatch=True).filter(
    lambda s: s not in {"__init__", ""}
)
@st.composite
def project_layout(draw):
    """Generate a random project layout: a package name, module file names,
    and the import relations between those modules.
    Returns (package, module_names, imports_map):
    - package: project package name (e.g. "cli")
    - module_names: module file names (without the .py suffix); the first
      one is used as the entry point
    - imports_map: dict[str, list[str]] mapping each module to the other
      modules it imports
    """
    package = draw(st.sampled_from(_PROJECT_PACKAGES_LIST))
    n_modules = draw(st.integers(min_value=1, max_value=5))
    module_names = draw(
        st.lists(
            _module_name_st,
            min_size=n_modules,
            max_size=n_modules,
            unique=True,
        )
    )
    # Guarantee at least one module was generated.
    assume(len(module_names) >= 1)
    # For each module, randomly choose a subset of the other modules to import.
    # (Fix: the original used `enumerate` here but never used the index.)
    imports_map: dict[str, list[str]] = {}
    for mod in module_names:
        # A module may only import other modules from the generated list.
        others = [m for m in module_names if m != mod]
        if others:
            imported = draw(
                st.lists(st.sampled_from(others), max_size=len(others), unique=True)
            )
        else:
            imported = []
        imports_map[mod] = imported
    return package, module_names, imports_map
@given(layout=project_layout())
@settings(max_examples=100)
def test_property9_flow_tree_source_file_validity(layout, tmp_path_factory):
    """Property 9: every node's source_file in the flow tree is non-empty and
    the referenced file actually exists in the repository.
    **Feature: repo-audit, Property 9: flow-tree node source_file validity**
    **Validates: Requirements 2.7**
    """
    package, module_names, imports_map = layout
    tmp_path = tmp_path_factory.mktemp("prop9")
    # Create the package directory and its __init__.py.
    pkg_dir = tmp_path / package
    pkg_dir.mkdir(parents=True, exist_ok=True)
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")
    # Create each module file, writing out its import statements.
    for mod in module_names:
        lines = []
        for imp in imports_map[mod]:
            lines.append(f"from {package}.{imp} import *")
        lines.append("")  # ensure the file is non-empty
        (pkg_dir / f"{mod}.py").write_text("\n".join(lines), encoding="utf-8")
    # Build the flow tree with the first module as the entry point.
    entry_rel = f"{package}/{module_names[0]}.py"
    tree = build_flow_tree(tmp_path, entry_rel)
    # Walk every node and validate its source_file.
    all_nodes = _collect_all_nodes(tree)
    for node in all_nodes:
        # source_file must be a non-empty string.
        assert isinstance(node.source_file, str), (
            f"source_file 应为字符串,实际为 {type(node.source_file)}"
        )
        assert node.source_file != "", "source_file 不应为空字符串"
        # The referenced file must actually exist in the repository.
        full_path = tmp_path / node.source_file
        assert full_path.exists(), (
            f"source_file '{node.source_file}' 对应的文件不存在: {full_path}"
        )
# ---------------------------------------------------------------------------
# Property 10: 孤立模块检测正确性
# Feature: repo-audit, Property 10: 孤立模块检测正确性
# Validates: Requirements 2.8
#
# 策略:生成随机的 FileEntry 列表(模拟项目中的 .py 文件),
# 生成随机的 reachable 集合(是 FileEntry 路径的子集),
# 调用 find_orphan_modules 验证:
# 1. 返回的每个孤立模块都不在 reachable 集合中
# 2. reachable 集合中的每个模块都不在孤立列表中
#
# 注意find_orphan_modules 会排除 __init__.py、tests/、scripts/audit/ 下的文件,
# 以及不属于 _PROJECT_PACKAGES 的子目录文件。生成器需要考虑这些排除规则。
# ---------------------------------------------------------------------------
# Generate .py paths that belong to project packages (excluding paths ignored
# by find_orphan_modules).
_eligible_packages = [
    p for p in _PROJECT_PACKAGES_LIST
    if p not in ("scripts",)  # only scripts/audit/ is excluded, but drop scripts entirely for simplicity
]
@st.composite
def orphan_test_data(draw):
    """Generate (file_entries, reachable_set) for testing find_orphan_modules.
    Only "eligible" file entries are generated (inside project packages, not
    __init__.py, not under tests/ or scripts/audit/), so the mutual exclusion
    between reachable and orphan modules can be verified precisely.
    """
    # Generate 1-10 eligible .py file paths.
    n_files = draw(st.integers(min_value=1, max_value=10))
    paths: list[str] = []
    for _ in range(n_files):
        pkg = draw(st.sampled_from(_eligible_packages))
        fname = draw(_module_name_st)
        path = f"{pkg}/{fname}.py"
        paths.append(path)
    # Deduplicate while preserving insertion order.
    paths = list(dict.fromkeys(paths))
    assume(len(paths) >= 1)
    # Build the FileEntry list.
    entries = [
        FileEntry(rel_path=p, is_dir=False, size_bytes=100, extension=".py", is_empty_dir=False)
        for p in paths
    ]
    # Randomly choose a subset of the paths as the reachable set.
    reachable = set(draw(
        st.lists(st.sampled_from(paths), max_size=len(paths), unique=True)
    ))
    return entries, reachable
@given(data=orphan_test_data())
@settings(max_examples=100)
def test_property10_orphan_module_detection(data, tmp_path_factory):
    """Property 10: orphan and reachable modules are mutually exclusive —
    no module in the orphan list is in the reachable set, and no reachable
    module appears in the orphan list.
    **Feature: repo-audit, Property 10: orphan-module detection correctness**
    **Validates: Requirements 2.8**
    """
    entries, reachable = data
    tmp_path = tmp_path_factory.mktemp("prop10")
    orphans = find_orphan_modules(tmp_path, entries, reachable)
    orphan_set = set(orphans)
    # Check 1: no orphan module may appear in the reachable set.
    overlap = orphan_set & reachable
    assert overlap == set(), (
        f"孤立模块与可达集合存在交集: {overlap}"
    )
    # Check 2: no reachable module may appear in the orphan list.
    for r in reachable:
        assert r not in orphan_set, (
            f"可达模块 '{r}' 不应出现在孤立列表中"
        )
    # Check 3: the orphan list must be sorted.
    assert orphans == sorted(orphans), "孤立模块列表应按路径排序"

View File

@@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
"""
属性测试 — classify 完整性
Feature: repo-audit, Property 1: classify 完整性
Validates: Requirements 1.2, 1.3
对于任意 FileEntryclassify 函数返回的 InventoryItem 的 category 字段
应属于 Category 枚举disposition 字段应属于 Disposition 枚举,
且 description 字段为非空字符串。
"""
from __future__ import annotations
import string
from hypothesis import given, settings
from hypothesis import strategies as st
from scripts.audit import Category, Disposition, FileEntry, InventoryItem
from scripts.audit.inventory_analyzer import classify
# ---------------------------------------------------------------------------
# 生成器策略
# ---------------------------------------------------------------------------
# Common file extensions (the empty string represents "no extension").
_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])
# Path segments: alphanumerics plus common special characters.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."
_path_segment = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
)
# Generate relative paths 1-4 directory levels deep.
_rel_path = st.lists(
    _path_segment,
    min_size=1,
    max_size=4,
).map(lambda parts: "/".join(parts))
def _file_entry_strategy() -> st.SearchStrategy[FileEntry]:
    """Hypothesis strategy producing random FileEntry values.
    Covers a variety of extensions, directory depths, sizes, and
    boolean-flag combinations.
    """
    return st.builds(
        FileEntry,
        rel_path=_rel_path,
        is_dir=st.booleans(),
        size_bytes=st.integers(min_value=0, max_value=10_000_000),
        extension=_EXTENSIONS,
        is_empty_dir=st.booleans(),
    )
# ---------------------------------------------------------------------------
# Property 1: classify 完整性
# ---------------------------------------------------------------------------
@given(entry=_file_entry_strategy())
@settings(max_examples=100)
def test_classify_completeness(entry: FileEntry) -> None:
    """Property 1: classify completeness.
    Feature: repo-audit, Property 1: classify completeness
    Validates: Requirements 1.2, 1.3
    For any FileEntry, the InventoryItem returned by classify must satisfy:
    - category is a member of the Category enum
    - disposition is a member of the Disposition enum
    - description is a non-empty string
    """
    result = classify(entry)
    # The return type must be correct.
    assert isinstance(result, InventoryItem), (
        f"classify 应返回 InventoryItem实际返回 {type(result)}"
    )
    # category must be a member of the Category enum.
    assert isinstance(result.category, Category), (
        f"category 应为 Category 枚举成员,实际为 {result.category!r}"
    )
    # disposition must be a member of the Disposition enum.
    assert isinstance(result.disposition, Disposition), (
        f"disposition 应为 Disposition 枚举成员,实际为 {result.disposition!r}"
    )
    # description must be a non-empty string.
    assert isinstance(result.description, str) and len(result.description) > 0, (
        f"description 应为非空字符串,实际为 {result.description!r}"
    )
# ---------------------------------------------------------------------------
# Helper: high-priority directory prefixes (excluded in the lower-priority
# property tests below)
# ---------------------------------------------------------------------------
# NOTE(review): _HIGH_PRIORITY_PREFIXES appears unused in this module —
# confirm whether it should feed the strategies below.
_HIGH_PRIORITY_PREFIXES = ("tmp/", "logs/", "export/")
# Safe top-level directory names (they do not trigger high-priority rules).
_SAFE_TOP_DIRS = st.sampled_from([
    "src", "lib", "data", "misc", "vendor", "tools", "archive",
    "assets", "resources", "contrib", "extras",
])
# Extensions other than .lnk/.rar.
_SAFE_EXTENSIONS = st.sampled_from([
    "", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
    ".bat", ".sh", ".ps1", ".log", ".ini", ".cfg",
    ".toml", ".yaml", ".yml", ".html", ".css", ".js",
])
def _safe_rel_path() -> st.SearchStrategy[str]:
    """Generate relative paths that do not start with a high-priority directory."""
    return st.builds(
        lambda top, rest: f"{top}/{rest}" if rest else top,
        top=_SAFE_TOP_DIRS,
        rest=st.lists(_path_segment, min_size=0, max_size=3).map(
            lambda parts: "/".join(parts) if parts else ""
        ),
    )
# ---------------------------------------------------------------------------
# Property 3: 空目录标记为候选删除
# ---------------------------------------------------------------------------
@given(data=st.data())
@settings(max_examples=100)
def test_empty_dir_candidate_delete(data: st.DataObject) -> None:
    """Property 3: empty directories are marked candidate-delete.
    Feature: repo-audit, Property 3: empty directory marked candidate-delete
    Validates: Requirements 1.5
    For any FileEntry with is_empty_dir=True (excluding paths starting with
    tmp/, logs/ or export/ and the .lnk/.rar extensions), classify must
    return Disposition.CANDIDATE_DELETE.
    NOTE(review): the original docstring also listed reports/ as excluded,
    but _HIGH_PRIORITY_PREFIXES and _SAFE_TOP_DIRS do not mention it —
    confirm against the analyzer's rules.
    """
    rel_path = data.draw(_safe_rel_path())
    ext = data.draw(_SAFE_EXTENSIONS)
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=True,
        size_bytes=0,
        extension=ext,
        is_empty_dir=True,
    )
    result = classify(entry)
    assert result.disposition == Disposition.CANDIDATE_DELETE, (
        f"空目录 '{entry.rel_path}' 应标记为候选删除,"
        f"实际为 {result.disposition.value}"
    )
# ---------------------------------------------------------------------------
# Property 4: .lnk/.rar 文件标记为候选删除
# ---------------------------------------------------------------------------
@given(data=st.data())
@settings(max_examples=100)
def test_lnk_rar_candidate_delete(data: st.DataObject) -> None:
    """Property 4: .lnk/.rar files are marked candidate-delete.
    Feature: repo-audit, Property 4: .lnk/.rar files marked candidate-delete
    Validates: Requirements 1.6
    For any FileEntry with extension .lnk or .rar (excluding paths starting
    with tmp/, logs/ or export/, and with is_empty_dir=False), classify must
    return Disposition.CANDIDATE_DELETE.
    NOTE(review): the original docstring also listed reports/ as excluded,
    but _HIGH_PRIORITY_PREFIXES does not include it — confirm.
    """
    rel_path = data.draw(_safe_rel_path())
    ext = data.draw(st.sampled_from([".lnk", ".rar"]))
    entry = FileEntry(
        rel_path=rel_path,
        is_dir=False,
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=False,
    )
    result = classify(entry)
    assert result.disposition == Disposition.CANDIDATE_DELETE, (
        f"文件 '{entry.rel_path}' (ext={ext}) 应标记为候选删除,"
        f"实际为 {result.disposition.value}"
    )
# ---------------------------------------------------------------------------
# Property 5: tmp/ 下文件处置范围
# ---------------------------------------------------------------------------
_TMP_EXTENSIONS = st.sampled_from([
"", ".py", ".sql", ".md", ".txt", ".json", ".csv", ".xlsx",
".bat", ".sh", ".ps1", ".lnk", ".rar", ".log", ".ini", ".cfg",
".toml", ".yaml", ".yml", ".html", ".css", ".js", ".tmp", ".bak",
])
def _tmp_rel_path() -> st.SearchStrategy[str]:
    """Strategy producing relative paths rooted under ``tmp/``."""
    tail = st.lists(_path_segment, min_size=1, max_size=3).map("/".join)
    return tail.map(lambda rest: f"tmp/{rest}")
@given(data=st.data())
@settings(max_examples=100)
def test_tmp_disposition_range(data: st.DataObject) -> None:
    """Property 5: allowed dispositions for entries under tmp/.

    Feature: repo-audit, Property 5: tmp/ 下文件处置范围
    Validates: Requirements 1.7

    For any FileEntry whose rel_path starts with tmp/, classify() must
    return either CANDIDATE_DELETE or CANDIDATE_ARCHIVE.
    """
    path = data.draw(_tmp_rel_path())
    ext = data.draw(_TMP_EXTENSIONS)
    entry = FileEntry(
        rel_path=path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=data.draw(st.booleans()),
    )
    verdict = classify(entry)
    assert verdict.disposition in {
        Disposition.CANDIDATE_DELETE,
        Disposition.CANDIDATE_ARCHIVE,
    }, (
        f"tmp/ 下文件 '{entry.rel_path}' 的处置应为候选删除或候选归档,"
        f"实际为 {verdict.disposition.value}"
    )
# ---------------------------------------------------------------------------
# Property 6: 运行时产出目录标记为候选归档
# ---------------------------------------------------------------------------
# Top-level directories holding runtime artefacts (requirement 1.8 scope).
_RUNTIME_DIRS = st.sampled_from(["logs", "export"])
# Basenames other than __init__.py, which is exempt from archiving.
_NON_INIT_BASENAME = st.text(
    alphabet=_PATH_CHARS,
    min_size=1,
    max_size=20,
).filter(lambda name: name != "__init__.py")
def _runtime_output_rel_path() -> st.SearchStrategy[str]:
    """Strategy for paths under logs/ or export/ whose basename is not __init__.py."""

    def _join(top: str, mid: list[str], name: str) -> str:
        # "/".join over [top, *mid, name] collapses the mid-empty case naturally.
        return "/".join([top, *mid, name])

    return st.builds(
        _join,
        top=_RUNTIME_DIRS,
        mid=st.lists(_path_segment, min_size=0, max_size=2),
        name=_NON_INIT_BASENAME,
    )
@given(data=st.data())
@settings(max_examples=100)
def test_runtime_output_candidate_archive(data: st.DataObject) -> None:
    """Property 6: runtime-output directories are marked candidate-archive.

    Feature: repo-audit, Property 6: 运行时产出目录标记为候选归档
    Validates: Requirements 1.8

    For any FileEntry whose rel_path starts with logs/ or export/ and whose
    basename is not __init__.py, classify() must return CANDIDATE_ARCHIVE.
    Requirement 1.8 covers only logs/ and export/ (not reports/).
    """
    path = data.draw(_runtime_output_rel_path())
    ext = data.draw(_EXTENSIONS)
    entry = FileEntry(
        rel_path=path,
        is_dir=data.draw(st.booleans()),
        size_bytes=data.draw(st.integers(min_value=0, max_value=10_000_000)),
        extension=ext,
        is_empty_dir=data.draw(st.booleans()),
    )
    verdict = classify(entry)
    assert verdict.disposition == Disposition.CANDIDATE_ARCHIVE, (
        f"运行时产出 '{entry.rel_path}' 应标记为候选归档,"
        f"实际为 {verdict.disposition.value}"
    )

View File

@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
"""
属性测试 — 清单渲染完整性与分类分组
Feature: repo-audit
- Property 2: 清单渲染完整性
- Property 8: 清单按分类分组
Validates: Requirements 1.4, 1.10
"""
from __future__ import annotations
import string
from hypothesis import given, settings
from hypothesis import strategies as st
from scripts.audit import Category, Disposition, InventoryItem
from scripts.audit.inventory_analyzer import render_inventory_report
# ---------------------------------------------------------------------------
# 生成器策略
# ---------------------------------------------------------------------------
# Characters allowed in generated path segments.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."
_path_segment = st.text(alphabet=_PATH_CHARS, min_size=1, max_size=15)
# Random relative paths, one to three segments deep.
_rel_path = st.lists(_path_segment, min_size=1, max_size=3).map("/".join)
# Non-empty descriptions; pipes and newlines are excluded so the rendered
# rows remain parseable as Markdown table cells.
_description = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=40,
)
def _inventory_item_strategy() -> st.SearchStrategy[InventoryItem]:
    """Hypothesis strategy producing random InventoryItem instances."""
    return st.builds(
        InventoryItem,
        rel_path=_rel_path,
        description=_description,
        category=st.sampled_from(list(Category)),
        disposition=st.sampled_from(list(Disposition)),
    )


# Lists of zero to twenty inventory items.
_inventory_list = st.lists(_inventory_item_strategy(), min_size=0, max_size=20)
# ---------------------------------------------------------------------------
# Property 2: 清单渲染完整性
# ---------------------------------------------------------------------------
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_completeness(items: list[InventoryItem]) -> None:
    """Property 2: every item's fields appear in the rendered report.

    Feature: repo-audit, Property 2: 清单渲染完整性
    Validates: Requirements 1.4

    For any list of InventoryItem, the Markdown produced by
    render_inventory_report must contain each item's rel_path,
    category.value, disposition.value and description.
    """
    report = render_inventory_report(items, "/tmp/test-repo")
    for item in items:
        expectations = [
            # rel_path appears in a table row
            (item.rel_path, f"rel_path '{item.rel_path}' 未出现在报告中"),
            # category.value appears in a group heading
            (item.category.value,
             f"category '{item.category.value}' 未出现在报告中"),
            # disposition.value appears in a table row
            (item.disposition.value,
             f"disposition '{item.disposition.value}' 未出现在报告中"),
            # description appears in a table row
            (item.description, f"description '{item.description}' 未出现在报告中"),
        ]
        for needle, message in expectations:
            assert needle in report, message
# ---------------------------------------------------------------------------
# Property 8: 清单按分类分组
# ---------------------------------------------------------------------------
@given(items=_inventory_list)
@settings(max_examples=100)
def test_render_inventory_grouped_by_category(items: list[InventoryItem]) -> None:
    """Property 8: inventory rows are grouped by category.

    Feature: repo-audit, Property 8: 清单按分类分组
    Validates: Requirements 1.10

    For any list of InventoryItem, entries sharing a Category must appear
    contiguously in the Markdown from render_inventory_report (never
    interleaved with entries of another Category).
    """
    report = render_inventory_report(items, "/tmp/test-repo")
    if not items:
        return  # nothing to verify for an empty inventory
    # Replay the report line by line, recording which "## {category}" group
    # each data row falls under.
    # Data-row format:      | `{rel_path}` | {disposition} | {description} |
    # Group-heading format: ## {category.value}
    heading_to_category = {c.value: c for c in Category}
    row_categories: list[Category] = []
    active: Category | None = None
    for raw_line in report.split("\n"):
        text = raw_line.strip()
        # Group heading "## {category.value}" switches the active category.
        if text.startswith("## ") and text[3:] in heading_to_category:
            active = heading_to_category[text[3:]]
        # Data rows (header and separator rows are skipped).
        elif (
            active is not None
            and text.startswith("| `")
            and not text.startswith("| 相对路径")
            and not text.startswith("|---")
        ):
            row_categories.append(active)
    # Rows of the same Category must form a single contiguous run.
    finished: set[Category] = set()
    previous: Category | None = None
    for cat in row_categories:
        if cat == previous:
            continue
        assert cat not in finished, (
            f"Category '{cat.value}' 的条目不连续——"
            f"在其他分类条目之后再次出现"
        )
        finished.add(cat)
        previous = cat

View File

@@ -0,0 +1,485 @@
# -*- coding: utf-8 -*-
"""
属性测试 — 报告输出属性
Feature: repo-audit
- Property 13: 统计摘要一致性
- Property 14: 报告头部元信息
- Property 15: 写操作仅限 docs/audit/
Validates: Requirements 4.2, 4.5, 4.6, 4.7, 5.2
"""
from __future__ import annotations
import os
import re
import string
from pathlib import Path
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from scripts.audit import (
AlignmentIssue,
Category,
Disposition,
DocMapping,
FlowNode,
InventoryItem,
)
from scripts.audit.inventory_analyzer import render_inventory_report
from scripts.audit.flow_analyzer import render_flow_report
from scripts.audit.doc_alignment_analyzer import render_alignment_report
# ---------------------------------------------------------------------------
# 共享生成器策略
# ---------------------------------------------------------------------------
# Characters allowed in generated path segments.
_PATH_CHARS = string.ascii_letters + string.digits + "_-."
_path_segment = st.text(alphabet=_PATH_CHARS, min_size=1, max_size=12)
# Relative paths of one to three segments.
_rel_path = st.lists(_path_segment, min_size=1, max_size=3).map("/".join)
# Markdown-safe free text: pipes and newlines excluded.
_safe_text = st.text(
    alphabet=st.characters(
        whitelist_categories=("L", "N", "P", "S", "Z"),
        blacklist_characters="|\n\r",
    ),
    min_size=1,
    max_size=30,
)
# Absolute-looking repository root strings such as "/some/repo".
_repo_root_str = st.text(
    alphabet=string.ascii_letters + string.digits + "/_-.",
    min_size=3,
    max_size=40,
).map(lambda s: "/" + s.lstrip("/"))
# ---------------------------------------------------------------------------
# InventoryItem 生成器
# ---------------------------------------------------------------------------
def _inventory_item_st() -> st.SearchStrategy[InventoryItem]:
    """Strategy for random InventoryItem values."""
    return st.builds(
        InventoryItem,
        rel_path=_rel_path,
        description=_safe_text,
        category=st.sampled_from(list(Category)),
        disposition=st.sampled_from(list(Disposition)),
    )


# Zero to twenty inventory items per example.
_inventory_list = st.lists(_inventory_item_st(), min_size=0, max_size=20)
# ---------------------------------------------------------------------------
# FlowNode 生成器(限制深度和宽度)
# ---------------------------------------------------------------------------
def _flow_node_st(max_depth: int = 2) -> st.SearchStrategy[FlowNode]:
    """Strategy for random FlowNode trees, depth-capped to avoid blow-up."""
    node_type = st.sampled_from(["entry", "module", "class", "function"])
    if max_depth <= 0:
        children = st.just([])
    else:
        children = st.lists(
            _flow_node_st(max_depth - 1),
            min_size=0,
            max_size=3,
        )
    return st.builds(
        FlowNode,
        name=_path_segment,
        source_file=_rel_path,
        node_type=node_type,
        children=children,
    )


_flow_tree_list = st.lists(_flow_node_st(), min_size=0, max_size=5)
_orphan_list = st.lists(_rel_path, min_size=0, max_size=10)
# ---------------------------------------------------------------------------
# DocMapping / AlignmentIssue 生成器
# ---------------------------------------------------------------------------
# The three issue kinds the alignment analyzer can report.
_issue_type_st = st.sampled_from(["stale", "conflict", "missing"])


def _alignment_issue_st() -> st.SearchStrategy[AlignmentIssue]:
    """Strategy for random AlignmentIssue values."""
    return st.builds(
        AlignmentIssue,
        doc_path=_rel_path,
        related_code=_rel_path,
        issue_type=_issue_type_st,
        description=_safe_text,
    )


def _doc_mapping_st() -> st.SearchStrategy[DocMapping]:
    """Strategy for random DocMapping values."""
    return st.builds(
        DocMapping,
        doc_path=_rel_path,
        doc_topic=_safe_text,
        related_code=st.lists(_rel_path, min_size=0, max_size=5),
        status=st.sampled_from(["aligned", "stale", "conflict", "orphan"]),
    )


_mapping_list = st.lists(_doc_mapping_st(), min_size=0, max_size=15)
_issue_list = st.lists(_alignment_issue_st(), min_size=0, max_size=15)
# ===========================================================================
# Property 13: 统计摘要一致性
# ===========================================================================
class TestProperty13SummaryConsistency:
    """Property 13: summary-statistics consistency.

    Feature: repo-audit, Property 13: 统计摘要一致性
    Validates: Requirements 4.5, 4.6, 4.7

    For any report, the per-category / per-label counts in the summary
    must add up to the length of the corresponding item list.
    """

    # --- 13a: sum of category counts in render_inventory_report == len(items) ---
    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_category_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13: 统计摘要一致性
        Validates: Requirements 4.5

        The per-purpose-category counts in the inventory summary must sum
        to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")
        # Locate the "按用途分类" (by purpose) table and sum its count column.
        cat_sum = _extract_summary_total(report, "按用途分类")
        assert cat_sum == len(items), (
            f"分类计数之和 {cat_sum} != 条目总数 {len(items)}"
        )

    # --- 13b: sum of disposition-label counts == len(items) ---
    @given(items=_inventory_list)
    @settings(max_examples=100)
    def test_inventory_disposition_counts_sum(
        self, items: list[InventoryItem]
    ) -> None:
        """Feature: repo-audit, Property 13: 统计摘要一致性
        Validates: Requirements 4.5

        The per-disposition-label counts in the inventory summary must sum
        to the total number of items.
        """
        report = render_inventory_report(items, "/tmp/repo")
        disp_sum = _extract_summary_total(report, "按处置标签")
        assert disp_sum == len(items), (
            f"处置标签计数之和 {disp_sum} != 条目总数 {len(items)}"
        )

    # --- 13c: orphan-module count in render_flow_report == len(orphans) ---
    @given(trees=_flow_tree_list, orphans=_orphan_list)
    @settings(max_examples=100)
    def test_flow_orphan_count_matches(
        self, trees: list[FlowNode], orphans: list[str]
    ) -> None:
        """Feature: repo-audit, Property 13: 统计摘要一致性
        Validates: Requirements 4.6

        The orphan-module count in the flow-report summary must equal the
        length of the orphans list.
        """
        report = render_flow_report(trees, orphans, "/tmp/repo")
        # Pull the number from the "孤立模块" (orphan modules) summary row.
        orphan_count = _extract_flow_stat(report, "孤立模块")
        assert orphan_count == len(orphans), (
            f"报告中孤立模块数 {orphan_count} != orphans 列表长度 {len(orphans)}"
        )

    # --- 13d: issue-type counts in render_alignment_report are consistent ---
    @given(mappings=_mapping_list, issues=_issue_list)
    @settings(max_examples=100)
    def test_alignment_issue_counts_match(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue]
    ) -> None:
        """Feature: repo-audit, Property 13: 统计摘要一致性
        Validates: Requirements 4.7

        The stale/conflict/missing counts in the alignment-report summary
        must match the actual number of issues of each type.
        """
        report = render_alignment_report(mappings, issues, "/tmp/repo")
        expected_stale = sum(1 for i in issues if i.issue_type == "stale")
        expected_conflict = sum(1 for i in issues if i.issue_type == "conflict")
        expected_missing = sum(1 for i in issues if i.issue_type == "missing")
        actual_stale = _extract_alignment_stat(report, "过期点数量")
        actual_conflict = _extract_alignment_stat(report, "冲突点数量")
        actual_missing = _extract_alignment_stat(report, "缺失点数量")
        assert actual_stale == expected_stale, (
            f"过期点: 报告 {actual_stale} != 实际 {expected_stale}"
        )
        assert actual_conflict == expected_conflict, (
            f"冲突点: 报告 {actual_conflict} != 实际 {expected_conflict}"
        )
        assert actual_missing == expected_missing, (
            f"缺失点: 报告 {actual_missing} != 实际 {expected_missing}"
        )
# ===========================================================================
# Property 14: 报告头部元信息
# ===========================================================================
class TestProperty14ReportHeader:
    """Property 14: report header metadata.

    Feature: repo-audit, Property 14: 报告头部元信息
    Validates: Requirements 4.2

    Every report's header must contain an ISO-format timestamp string and
    the repository root path string.
    """

    # ISO-8601 UTC timestamp, e.g. 2025-01-01T00:00:00Z.
    _ISO_TS_RE = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")

    @given(items=_inventory_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_inventory_report_header(
        self, items: list[InventoryItem], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: 报告头部元信息
        Validates: Requirements 4.2

        The render_inventory_report header must contain an ISO timestamp
        and the repository path.
        """
        report = render_inventory_report(items, repo_root)
        # Only the first 500 characters count as the "header".
        header = report[:500]
        assert self._ISO_TS_RE.search(header), (
            "inventory 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"inventory 报告头部缺少仓库路径 '{repo_root}'"
        )

    @given(trees=_flow_tree_list, orphans=_orphan_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_flow_report_header(
        self, trees: list[FlowNode], orphans: list[str], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: 报告头部元信息
        Validates: Requirements 4.2

        The render_flow_report header must contain an ISO timestamp and
        the repository path.
        """
        report = render_flow_report(trees, orphans, repo_root)
        header = report[:500]
        assert self._ISO_TS_RE.search(header), (
            "flow 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"flow 报告头部缺少仓库路径 '{repo_root}'"
        )

    @given(mappings=_mapping_list, issues=_issue_list, repo_root=_repo_root_str)
    @settings(max_examples=100)
    def test_alignment_report_header(
        self, mappings: list[DocMapping], issues: list[AlignmentIssue], repo_root: str
    ) -> None:
        """Feature: repo-audit, Property 14: 报告头部元信息
        Validates: Requirements 4.2

        The render_alignment_report header must contain an ISO timestamp
        and the repository path.
        """
        report = render_alignment_report(mappings, issues, repo_root)
        header = report[:500]
        assert self._ISO_TS_RE.search(header), (
            "alignment 报告头部缺少 ISO 格式时间戳"
        )
        assert repo_root in header, (
            f"alignment 报告头部缺少仓库路径 '{repo_root}'"
        )
# ===========================================================================
# Property 15: 写操作仅限 docs/audit/
# ===========================================================================
class TestProperty15WritesOnlyDocsAudit:
    """Property 15: writes are confined to docs/audit/.

    Feature: repo-audit, Property 15: 写操作仅限 docs/audit/
    Validates: Requirements 5.2

    During any audit run, every file write must target a path prefixed
    with docs/audit/.  Uses few iterations because it exercises the real
    file system.
    """

    @staticmethod
    def _make_minimal_repo(base: Path, variant: int) -> Path:
        """Build a minimal repo layout; *variant* varies the extras for diversity."""
        repo = base / f"repo_{variant}"
        repo.mkdir()
        # Required CLI entry point.
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )
        # config package.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        # Add different extra files depending on the variant.
        if variant % 3 == 0:
            (repo / "README.md").write_text("# 项目\n", encoding="utf-8")
        if variant % 3 == 1:
            scripts_dir = repo / "scripts"
            scripts_dir.mkdir()
            (scripts_dir / "__init__.py").write_text("", encoding="utf-8")
        if variant % 3 == 2:
            (docs_dir / "notes.md").write_text("# 笔记\n", encoding="utf-8")
        return repo

    @staticmethod
    def _snapshot_files(repo: Path) -> dict[str, float]:
        """Snapshot the mtime of every file in the repo, excluding docs/audit/."""
        snap: dict[str, float] = {}
        for p in repo.rglob("*"):
            if p.is_file():
                rel = p.relative_to(repo).as_posix()
                if not rel.startswith("docs/audit"):
                    snap[rel] = p.stat().st_mtime
        return snap

    @given(variant=st.integers(min_value=0, max_value=9))
    @settings(max_examples=10)
    def test_writes_only_under_docs_audit(self, variant: int) -> None:
        """Feature: repo-audit, Property 15: 写操作仅限 docs/audit/
        Validates: Requirements 5.2

        After run_audit, no new file may appear outside docs/audit/, and
        docs/audit/ must contain at least one report file.
        """
        import tempfile
        from scripts.audit.run_audit import run_audit
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)
            repo = self._make_minimal_repo(tmp_path, variant)
            before_snap = self._snapshot_files(repo)
            run_audit(repo)
            # docs/audit/ must now exist and hold at least one report.
            audit_dir = repo / "docs" / "audit"
            assert audit_dir.is_dir(), "docs/audit/ 目录未创建"
            audit_files = list(audit_dir.iterdir())
            assert len(audit_files) > 0, "docs/audit/ 下无报告文件"
            # No new files may exist outside docs/audit/.
            for p in repo.rglob("*"):
                if p.is_file():
                    rel = p.relative_to(repo).as_posix()
                    if rel.startswith("docs/audit"):
                        continue
                    assert rel in before_snap, (
                        f"docs/audit/ 外出现了新文件: {rel}"
                    )
# ===========================================================================
# 辅助函数 — 从报告文本中提取统计数字
# ===========================================================================
def _extract_summary_total(report: str, section_name: str) -> int:
"""从 inventory 报告的统计摘要中提取指定分区的数字之和。
查找 "### {section_name}" 下的 Markdown 表格,
累加每行最后一列的数字(排除合计行)。
"""
lines = report.split("\n")
in_section = False
total = 0
for line in lines:
stripped = line.strip()
if stripped == f"### {section_name}":
in_section = True
continue
if in_section and stripped.startswith("###"):
# 进入下一个子节
break
if in_section and stripped.startswith("|") and "**合计**" not in stripped:
# 跳过表头和分隔行
if stripped.startswith("| 用途分类") or stripped.startswith("| 处置标签"):
continue
if stripped.startswith("|---"):
continue
# 提取最后一列的数字
cells = [c.strip() for c in stripped.split("|") if c.strip()]
if cells:
try:
total += int(cells[-1])
except ValueError:
pass
return total
def _extract_flow_stat(report: str, label: str) -> int:
"""从 flow 报告统计摘要表格中提取指定指标的数字。"""
# 匹配 "| 孤立模块 | 5 |" 格式
pattern = re.compile(rf"\|\s*{re.escape(label)}\s*\|\s*(\d+)\s*\|")
m = pattern.search(report)
return int(m.group(1)) if m else -1
def _extract_alignment_stat(report: str, label: str) -> int:
"""从 alignment 报告统计摘要中提取指定指标的数字。
匹配 "- 过期点数量3" 格式。
"""
# 兼容全角/半角冒号
pattern = re.compile(rf"{re.escape(label)}[:]\s*(\d+)")
m = pattern.search(report)
return int(m.group(1)) if m else -1

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
"""
run_audit 主入口的单元测试。
验证:
- docs/audit/ 目录自动创建
- 三份报告文件正确生成
- 报告头部包含时间戳和仓库路径
- 目录创建失败时抛出 RuntimeError
"""
from __future__ import annotations
import os
import re
from pathlib import Path
import pytest
class TestEnsureReportDir:
    """Unit tests for the _ensure_report_dir directory-creation helper."""

    def test_creates_dir_when_missing(self, tmp_path: Path):
        from scripts.audit.run_audit import _ensure_report_dir
        created = _ensure_report_dir(tmp_path)
        # The helper both returns and creates <repo>/docs/audit.
        assert created == tmp_path / "docs" / "audit"
        assert created.is_dir()

    def test_returns_existing_dir(self, tmp_path: Path):
        from scripts.audit.run_audit import _ensure_report_dir
        existing = tmp_path / "docs" / "audit"
        existing.mkdir(parents=True)
        # A pre-existing directory is returned as-is.
        assert _ensure_report_dir(tmp_path) == existing

    def test_raises_on_creation_failure(self, tmp_path: Path):
        from scripts.audit.run_audit import _ensure_report_dir
        # Place a regular file where docs/audit should go so mkdir fails.
        docs = tmp_path / "docs"
        docs.mkdir()
        (docs / "audit").write_text("block", encoding="utf-8")
        with pytest.raises(RuntimeError, match="无法创建报告输出目录"):
            _ensure_report_dir(tmp_path)
class TestInjectHeader:
    """Unit tests for the _inject_header fallback-injection logic."""

    def test_skips_when_header_present(self):
        from scripts.audit.run_audit import _inject_header
        report = "# 标题\n\n- 生成时间: 2025-01-01T00:00:00Z\n- 仓库路径: `/repo`\n"
        # A report that already carries a header must pass through untouched.
        assert _inject_header(report, "2025-06-01T00:00:00Z", "/other") == report

    def test_injects_when_header_missing(self):
        from scripts.audit.run_audit import _inject_header
        injected = _inject_header(
            "# 无头部报告\n\n内容...", "2025-06-01T00:00:00Z", "/repo"
        )
        # Both timestamp and repo path must be injected.
        assert "生成时间: 2025-06-01T00:00:00Z" in injected
        assert "仓库路径: `/repo`" in injected
class TestRunAudit:
    """End-to-end tests for run_audit against a minimal repository layout."""

    def _make_minimal_repo(self, tmp_path: Path) -> Path:
        """Build the smallest repository structure that run_audit can process."""
        repo = tmp_path / "repo"
        repo.mkdir()
        # Core code directory (CLI entry point).
        cli_dir = repo / "cli"
        cli_dir.mkdir()
        (cli_dir / "__init__.py").write_text("", encoding="utf-8")
        (cli_dir / "main.py").write_text(
            "# -*- coding: utf-8 -*-\ndef main(): pass\n",
            encoding="utf-8",
        )
        # config package.
        config_dir = repo / "config"
        config_dir.mkdir()
        (config_dir / "__init__.py").write_text("", encoding="utf-8")
        (config_dir / "defaults.py").write_text("DEFAULTS = {}\n", encoding="utf-8")
        # docs directory.
        docs_dir = repo / "docs"
        docs_dir.mkdir()
        (docs_dir / "README.md").write_text("# 文档\n", encoding="utf-8")
        # Root-level file.
        (repo / "README.md").write_text("# 项目\n", encoding="utf-8")
        return repo

    def test_creates_three_reports(self, tmp_path: Path):
        """All three report files are generated under docs/audit/."""
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)
        audit_dir = repo / "docs" / "audit"
        assert (audit_dir / "file_inventory.md").is_file()
        assert (audit_dir / "flow_tree.md").is_file()
        assert (audit_dir / "doc_alignment.md").is_file()

    def test_reports_contain_timestamp(self, tmp_path: Path):
        """Every report embeds an ISO-format timestamp."""
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)
        audit_dir = repo / "docs" / "audit"
        # ISO timestamp pattern, e.g. 2025-01-01T00:00:00Z.
        ts_pattern = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z")
        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert ts_pattern.search(content), f"{name} 缺少时间戳"

    def test_reports_contain_repo_path(self, tmp_path: Path):
        """Every report embeds the resolved repository path."""
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        run_audit(repo)
        audit_dir = repo / "docs" / "audit"
        repo_str = str(repo.resolve())
        for name in ("file_inventory.md", "flow_tree.md", "doc_alignment.md"):
            content = (audit_dir / name).read_text(encoding="utf-8")
            assert repo_str in content, f"{name} 缺少仓库路径"

    def test_writes_only_to_docs_audit(self, tmp_path: Path):
        """All write operations stay inside docs/audit/ (Property 15)."""
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        # Snapshot files before the run (excluding docs/audit/).
        before = set()
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if not rel.startswith("docs/audit"):
                before.add((rel, p.stat().st_mtime if p.is_file() else None))
        run_audit(repo)
        # After the run, nothing outside docs/audit/ may be new.
        for p in repo.rglob("*"):
            rel = p.relative_to(repo).as_posix()
            if rel.startswith("docs/audit"):
                continue
            if p.is_file():
                # The file must already exist in the pre-run snapshot.
                found = any(r == rel for r, _ in before)
                assert found, f"意外创建了 docs/audit/ 外的文件: {rel}"

    def test_auto_creates_docs_audit_dir(self, tmp_path: Path):
        """docs/audit/ is created automatically when it is missing."""
        from scripts.audit.run_audit import run_audit
        repo = self._make_minimal_repo(tmp_path)
        # Ensure docs/audit/ does not exist yet.
        audit_dir = repo / "docs" / "audit"
        assert not audit_dir.exists()
        run_audit(repo)
        assert audit_dir.is_dir()

View File

@@ -0,0 +1,428 @@
# -*- coding: utf-8 -*-
"""
单元测试 — 仓库扫描器 (scanner.py)
覆盖:
- 排除模式匹配逻辑
- 递归遍历与 FileEntry 构建
- 空目录检测
- 权限错误容错
"""
from __future__ import annotations
import os
from pathlib import Path
import pytest
from scripts.audit import FileEntry
from scripts.audit.scanner import EXCLUDED_PATTERNS, _is_excluded, scan_repo
# ---------------------------------------------------------------------------
# _is_excluded 单元测试
# ---------------------------------------------------------------------------
class TestIsExcluded:
    """Unit tests for the exclusion-pattern matcher ``_is_excluded``."""

    # Exact-name matches against the default EXCLUDED_PATTERNS.
    def test_exact_match_git(self) -> None:
        assert _is_excluded(".git", EXCLUDED_PATTERNS) is True

    def test_exact_match_pycache(self) -> None:
        assert _is_excluded("__pycache__", EXCLUDED_PATTERNS) is True

    def test_exact_match_pytest_cache(self) -> None:
        assert _is_excluded(".pytest_cache", EXCLUDED_PATTERNS) is True

    def test_exact_match_kiro(self) -> None:
        assert _is_excluded(".kiro", EXCLUDED_PATTERNS) is True

    # Wildcard patterns and negative cases.
    def test_wildcard_pyc(self) -> None:
        assert _is_excluded("module.pyc", EXCLUDED_PATTERNS) is True

    def test_normal_py_not_excluded(self) -> None:
        assert _is_excluded("main.py", EXCLUDED_PATTERNS) is False

    def test_normal_dir_not_excluded(self) -> None:
        assert _is_excluded("src", EXCLUDED_PATTERNS) is False

    # Edge cases: empty pattern list, caller-supplied pattern.
    def test_empty_patterns(self) -> None:
        assert _is_excluded(".git", []) is False

    def test_custom_pattern(self) -> None:
        assert _is_excluded("data.csv", ["*.csv"]) is True
# ---------------------------------------------------------------------------
# scan_repo 单元测试
# ---------------------------------------------------------------------------
class TestScanRepo:
    """scan_repo recursive-traversal tests."""

    def test_basic_structure(self, tmp_path: Path) -> None:
        """Basic files and directories are scanned correctly."""
        (tmp_path / "a.py").write_text("# code", encoding="utf-8")
        sub = tmp_path / "sub"
        sub.mkdir()
        (sub / "b.txt").write_text("hello", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "a.py" in paths
        assert "sub" in paths
        assert "sub/b.txt" in paths

    def test_file_entry_fields(self, tmp_path: Path) -> None:
        """All FileEntry fields are populated correctly for a file."""
        (tmp_path / "hello.md").write_text("# hi", encoding="utf-8")
        entries = scan_repo(tmp_path)
        md = next(e for e in entries if e.rel_path == "hello.md")
        assert md.is_dir is False
        assert md.size_bytes > 0
        assert md.extension == ".md"
        assert md.is_empty_dir is False

    def test_directory_entry_fields(self, tmp_path: Path) -> None:
        """Directory entries carry the expected field values."""
        sub = tmp_path / "mydir"
        sub.mkdir()
        (sub / "file.py").write_text("pass", encoding="utf-8")
        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "mydir")
        assert d.is_dir is True
        assert d.size_bytes == 0
        assert d.extension == ""
        assert d.is_empty_dir is False

    def test_excluded_git_dir(self, tmp_path: Path) -> None:
        """.git directories and their contents are excluded."""
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert ".git" not in paths
        assert ".git/config" not in paths

    def test_excluded_pycache(self, tmp_path: Path) -> None:
        """__pycache__ directories are excluded."""
        cache = tmp_path / "pkg" / "__pycache__"
        cache.mkdir(parents=True)
        (cache / "mod.cpython-310.pyc").write_bytes(b"\x00")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert not any("__pycache__" in p for p in paths)

    def test_excluded_pyc_files(self, tmp_path: Path) -> None:
        """*.pyc files are excluded while the .py source is kept."""
        (tmp_path / "mod.pyc").write_bytes(b"\x00")
        (tmp_path / "mod.py").write_text("pass", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "mod.pyc" not in paths
        assert "mod.py" in paths

    def test_empty_directory_detection(self, tmp_path: Path) -> None:
        """Empty directories are flagged with is_empty_dir=True."""
        (tmp_path / "empty").mkdir()
        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "empty")
        assert d.is_dir is True
        assert d.is_empty_dir is True

    def test_dir_with_only_excluded_children(self, tmp_path: Path) -> None:
        """A directory whose only children are excluded counts as empty."""
        sub = tmp_path / "pkg"
        sub.mkdir()
        cache = sub / "__pycache__"
        cache.mkdir()
        (cache / "x.pyc").write_bytes(b"\x00")
        entries = scan_repo(tmp_path)
        d = next(e for e in entries if e.rel_path == "pkg")
        assert d.is_empty_dir is True

    def test_custom_exclude_patterns(self, tmp_path: Path) -> None:
        """Caller-supplied exclusion patterns take effect."""
        (tmp_path / "keep.py").write_text("pass", encoding="utf-8")
        (tmp_path / "skip.log").write_text("log", encoding="utf-8")
        entries = scan_repo(tmp_path, exclude=["*.log"])
        paths = {e.rel_path for e in entries}
        assert "keep.py" in paths
        assert "skip.log" not in paths

    def test_empty_repo(self, tmp_path: Path) -> None:
        """An empty repository yields an empty list."""
        entries = scan_repo(tmp_path)
        assert entries == []

    def test_results_sorted(self, tmp_path: Path) -> None:
        """Results come back sorted by rel_path."""
        (tmp_path / "z.py").write_text("", encoding="utf-8")
        (tmp_path / "a.py").write_text("", encoding="utf-8")
        sub = tmp_path / "m"
        sub.mkdir()
        (sub / "b.py").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = [e.rel_path for e in entries]
        assert paths == sorted(paths)

    @pytest.mark.skipif(
        os.name == "nt",
        reason="Windows 上 chmod 行为不同,跳过权限测试",
    )
    def test_permission_error_skipped(self, tmp_path: Path) -> None:
        """Unreadable directories are skipped without aborting the scan."""
        ok_file = tmp_path / "ok.py"
        ok_file.write_text("pass", encoding="utf-8")
        no_access = tmp_path / "secret"
        no_access.mkdir()
        (no_access / "data.txt").write_text("x", encoding="utf-8")
        no_access.chmod(0o000)
        try:
            entries = scan_repo(tmp_path)
            paths = {e.rel_path for e in entries}
            # ok.py is still scanned normally.
            assert "ok.py" in paths
            # The secret dir itself is recorded (per the original note,
            # _walk records the directory before attempting iterdir),
            # but its children must not appear.
            assert "secret/data.txt" not in paths
        finally:
            no_access.chmod(0o755)

    def test_nested_directories(self, tmp_path: Path) -> None:
        """Deeply nested directories are traversed correctly."""
        deep = tmp_path / "a" / "b" / "c"
        deep.mkdir(parents=True)
        (deep / "leaf.py").write_text("pass", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "a" in paths
        assert "a/b" in paths
        assert "a/b/c" in paths
        assert "a/b/c/leaf.py" in paths

    def test_extension_lowercase(self, tmp_path: Path) -> None:
        """Extensions are normalised to lower case."""
        (tmp_path / "README.MD").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        md = next(e for e in entries if "README" in e.rel_path)
        assert md.extension == ".md"

    def test_no_extension(self, tmp_path: Path) -> None:
        """Files without an extension get an empty-string extension."""
        (tmp_path / "Makefile").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        f = next(e for e in entries if e.rel_path == "Makefile")
        assert f.extension == ""

    def test_root_not_in_entries(self, tmp_path: Path) -> None:
        """The root directory itself is not part of the results."""
        (tmp_path / "a.py").write_text("", encoding="utf-8")
        entries = scan_repo(tmp_path)
        paths = {e.rel_path for e in entries}
        assert "." not in paths
        assert "" not in paths
# ---------------------------------------------------------------------------
# 属性测试 — Property 7: 扫描器排除规则
# Feature: repo-audit, Property 7: 扫描器排除规则
# Validates: Requirements 1.1
# ---------------------------------------------------------------------------
import fnmatch
import string
import tempfile
from hypothesis import given, settings
from hypothesis import strategies as st
# --- 生成器策略 ---
# 合法的文件/目录名字符(排除路径分隔符和特殊字符)
# Characters legal in generated file/directory names (no path separators).
_SAFE_CHARS = string.ascii_lowercase + string.digits + "_-"
# Ordinary names that never collide with the exclusion patterns.
_safe_name = st.text(_SAFE_CHARS, min_size=1, max_size=8)
# Directory names covered by the default exclusion patterns.
_EXCLUDED_DIR_NAMES = [".git", "__pycache__", ".pytest_cache", ".kiro"]
# File extension covered by the default exclusion patterns.
_EXCLUDED_FILE_EXT = ".pyc"
# Draw one excluded directory name at random.
_excluded_dir_name = st.sampled_from(_EXCLUDED_DIR_NAMES)
def _build_tree(tmp: Path, normal_names: list[str], excluded_dirs: list[str],
include_pyc: bool) -> None:
"""在临时目录中构建包含正常文件和被排除条目的文件树。"""
# 创建正常文件
for name in normal_names:
safe = name or "f"
filepath = tmp / f"{safe}.txt"
if not filepath.exists():
filepath.write_text("ok", encoding="utf-8")
# 创建被排除的目录(含子文件)
for dirname in excluded_dirs:
d = tmp / dirname
d.mkdir(exist_ok=True)
(d / "inner.txt").write_text("hidden", encoding="utf-8")
# 可选:创建 .pyc 文件
if include_pyc:
(tmp / "module.pyc").write_bytes(b"\x00")
class TestProperty7ScannerExclusionRules:
    """
    Property 7: scanner exclusion rules.

    For any file tree, the FileEntry list returned by scan_repo must not
    contain entries whose rel_path matches an exclusion pattern
    (.git, __pycache__, .pytest_cache, ...).

    Feature: repo-audit, Property 7: scanner exclusion rules
    Validates: Requirements 1.1
    """

    @given(
        normal_names=st.lists(_safe_name, min_size=0, max_size=5),
        excluded_dirs=st.lists(_excluded_dir_name, min_size=1, max_size=3),
        include_pyc=st.booleans(),
    )
    @settings(max_examples=100)
    def test_excluded_entries_never_in_results(
        self,
        normal_names: list[str],
        excluded_dirs: list[str],
        include_pyc: bool,
    ) -> None:
        """Scan results must not contain any entry matching an exclusion pattern."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            _build_tree(tmp, normal_names, excluded_dirs, include_pyc)
            entries = scan_repo(tmp)
            for entry in entries:
                # Check every rel_path segment against each exclusion pattern.
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in EXCLUDED_PATTERNS:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )

    @given(
        excluded_dir=_excluded_dir_name,
        depth=st.integers(min_value=1, max_value=3),
    )
    @settings(max_examples=100)
    def test_excluded_dirs_at_any_depth(
        self,
        excluded_dir: str,
        depth: int,
    ) -> None:
        """An excluded directory must be dropped no matter how deeply it is nested."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            # Build a nested path: normal/normal/.../excluded_dir/file.txt
            current = tmp
            for i in range(depth):
                current = current / f"level{i}"
                current.mkdir(exist_ok=True)
                # Drop a regular file so every parent directory is non-empty.
                (current / "keep.txt").write_text("ok", encoding="utf-8")
            # Place the excluded directory at the deepest level.
            excluded = current / excluded_dir
            excluded.mkdir(exist_ok=True)
            (excluded / "secret.txt").write_text("hidden", encoding="utf-8")
            entries = scan_repo(tmp)
            for entry in entries:
                parts = entry.rel_path.split("/")
                assert excluded_dir not in parts, (
                    f"被排除目录 '{excluded_dir}' 不应出现在结果中,"
                    f"但发现 rel_path='{entry.rel_path}'"
                )

    @given(
        custom_patterns=st.lists(
            st.sampled_from(["*.log", "*.tmp", "*.bak", "node_modules", ".venv"]),
            min_size=1,
            max_size=3,
        ),
    )
    @settings(max_examples=100)
    def test_custom_exclude_patterns_respected(
        self,
        custom_patterns: list[str],
    ) -> None:
        """User-supplied exclusion patterns must be honored by scan_repo as well."""
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp = Path(tmpdir)
            # One regular file that should survive the scan.
            (tmp / "main.py").write_text("pass", encoding="utf-8")
            # Create a matching file or directory for each custom pattern.
            for pat in custom_patterns:
                if pat.startswith("*."):
                    # Wildcard pattern → create a matching file.
                    ext = pat[1:]  # e.g. ".log"
                    (tmp / f"data{ext}").write_text("x", encoding="utf-8")
                else:
                    # Exact-name pattern → create a directory.
                    d = tmp / pat
                    d.mkdir(exist_ok=True)
                    (d / "inner.txt").write_text("x", encoding="utf-8")
            entries = scan_repo(tmp, exclude=custom_patterns)
            for entry in entries:
                parts = entry.rel_path.split("/")
                for part in parts:
                    for pat in custom_patterns:
                        assert not fnmatch.fnmatch(part, pat), (
                            f"自定义排除模式 '{pat}' 不应出现在结果中,"
                            f"但发现 rel_path='{entry.rel_path}' 包含 '{part}'"
                        )

137
tests/unit/test_cli_args.py Normal file
View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
"""CLI 参数解析单元测试
验证 --data-source 新参数、--pipeline-flow 弃用映射、
--pipeline + --tasks 同时使用、以及 build_cli_overrides 集成行为。
需求: 3.1, 3.3, 3.5
"""
import warnings
from argparse import Namespace
from unittest.mock import patch
import pytest
from cli.main import parse_args, resolve_data_source, build_cli_overrides
# ---------------------------------------------------------------------------
# 1. --data-source 新参数解析
# ---------------------------------------------------------------------------
class TestDataSourceArg:
    """Parsing of the new --data-source CLI argument."""

    @pytest.mark.parametrize("value", ["online", "offline", "hybrid"])
    def test_data_source_valid_values(self, value):
        """Each accepted choice round-trips through parse_args()."""
        with patch("sys.argv", ["cli", "--data-source", value]):
            parsed = parse_args()
        assert parsed.data_source == value

    def test_data_source_default_is_none(self):
        """When the flag is absent the attribute stays None."""
        with patch("sys.argv", ["cli"]):
            parsed = parse_args()
        assert parsed.data_source is None
# ---------------------------------------------------------------------------
# 2. resolve_data_source() 弃用映射
# ---------------------------------------------------------------------------
class TestResolveDataSource:
    """resolve_data_source(): deprecated --pipeline-flow mapping."""

    def test_explicit_data_source_returns_directly(self):
        """An explicit data_source is returned untouched."""
        ns = Namespace(data_source="online", pipeline_flow=None)
        assert resolve_data_source(ns) == "online"

    def test_data_source_takes_priority_over_pipeline_flow(self):
        """--data-source wins when both arguments are supplied."""
        ns = Namespace(data_source="online", pipeline_flow="FULL")
        assert resolve_data_source(ns) == "online"

    @pytest.mark.parametrize(
        "flow, expected",
        [
            ("FULL", "hybrid"),
            ("FETCH_ONLY", "online"),
            ("INGEST_ONLY", "offline"),
        ],
    )
    def test_pipeline_flow_maps_with_deprecation_warning(self, flow, expected):
        """Legacy --pipeline-flow maps to the right data_source and warns."""
        ns = Namespace(data_source=None, pipeline_flow=flow)
        with pytest.warns(DeprecationWarning, match="--pipeline-flow 已弃用"):
            mapped = resolve_data_source(ns)
        assert mapped == expected

    def test_neither_arg_defaults_to_hybrid(self):
        """With neither argument supplied, the default is hybrid."""
        ns = Namespace(data_source=None, pipeline_flow=None)
        assert resolve_data_source(ns) == "hybrid"
# ---------------------------------------------------------------------------
# 3. build_cli_overrides() 集成
# ---------------------------------------------------------------------------
class TestBuildCliOverrides:
    """Integration tests for build_cli_overrides()."""

    def _make_args(self, **kwargs):
        """Build a minimal Namespace; unspecified options default to None/False."""
        defaults = dict(
            store_id=None, tasks=None, dry_run=False,
            pipeline=None, processing_mode="increment_only",
            fetch_before_verify=False, verify_tables=None,
            window_split="none", lookback_hours=24, overlap_seconds=3600,
            pg_dsn=None, pg_host=None, pg_port=None, pg_name=None,
            pg_user=None, pg_password=None,
            api_base=None, api_token=None, api_timeout=None,
            api_page_size=None, api_retry_max=None,
            window_start=None, window_end=None,
            force_window_override=False,
            window_split_unit=None, window_split_days=None,
            window_compensation_hours=None,
            export_root=None, log_root=None,
            data_source=None, pipeline_flow=None,
            fetch_root=None, ingest_source=None, write_pretty_json=False,
            idle_start=None, idle_end=None, allow_empty_advance=False,
        )
        defaults.update(kwargs)
        return Namespace(**defaults)

    def test_data_source_online_sets_run_key(self):
        """--data-source online lands in overrides['run']['data_source']."""
        args = self._make_args(data_source="online")
        overrides = build_cli_overrides(args)
        assert overrides["run"]["data_source"] == "online"

    def test_pipeline_flow_sets_both_keys(self):
        """The legacy flag writes both pipeline.flow and run.data_source."""
        args = self._make_args(pipeline_flow="FULL")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            overrides = build_cli_overrides(args)
        assert overrides["pipeline"]["flow"] == "FULL"
        assert overrides["run"]["data_source"] == "hybrid"

    def test_default_data_source_is_hybrid(self):
        """Without --data-source or --pipeline-flow, run.data_source is hybrid."""
        args = self._make_args()
        overrides = build_cli_overrides(args)
        assert overrides["run"]["data_source"] == "hybrid"
# ---------------------------------------------------------------------------
# 4. --pipeline + --tasks 同时使用
# ---------------------------------------------------------------------------
class TestPipelineAndTasks:
    """--pipeline and --tasks supplied on the same command line."""

    def test_pipeline_and_tasks_both_parsed(self):
        """Both flags are parsed independently of each other."""
        argv = [
            "cli",
            "--pipeline", "api_full",
            "--tasks", "ODS_MEMBER,ODS_ORDER",
        ]
        with patch("sys.argv", argv):
            parsed = parse_args()
        assert parsed.pipeline == "api_full"
        assert parsed.tasks == "ODS_MEMBER,ODS_ORDER"

24
tests/unit/test_config.py Normal file
View File

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""配置管理测试"""
import pytest
from config.settings import AppConfig
from config.defaults import DEFAULTS
def test_config_load():
    """Defaults survive a minimal override load."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    assert cfg.get("app.timezone") == DEFAULTS["app"]["timezone"]
def test_config_override():
    """Explicit overrides take precedence over defaults."""
    cfg = AppConfig.load({"app": {"store_id": 12345}})
    assert cfg.get("app.store_id") == 12345
def test_config_get_nested():
    """Dotted lookups resolve nested keys and honor the fallback value."""
    cfg = AppConfig.load({"app": {"store_id": 1}})
    assert cfg.get("db.batch_size") == 1000
    assert cfg.get("nonexistent.key", "default") == "default"

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
"""配置映射属性测试 — 使用 hypothesis 验证配置键兼容映射的通用正确性属性。"""
import os
import warnings
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from config.settings import AppConfig, _FLOW_TO_DATA_SOURCE
# ── Make sure tests never read a local .env file ─────────────
@pytest.fixture(autouse=True)
def skip_dotenv(monkeypatch):
    """Set ETL_SKIP_DOTENV=1 so configuration loading skips any .env file."""
    monkeypatch.setenv("ETL_SKIP_DOTENV", "1")
# ── Generation strategies ────────────────────────────────────
flow_st = st.sampled_from(["FULL", "FETCH_ONLY", "INGEST_ONLY"])
# ── Property 11: pipeline_flow → data_source mapping consistency ──
# Feature: scheduler-refactor, Property 11: pipeline_flow → data_source mapping consistency
# **Validates: Requirements 8.1, 8.2, 8.3, 5.2, 8.4**
#
# For any legacy pipeline_flow value (FULL/FETCH_ONLY/INGEST_ONLY),
# the mapped data_source must match the predefined table:
# FULL→hybrid, FETCH_ONLY→online, INGEST_ONLY→offline.
# Likewise, the config key pipeline.flow must auto-map to run.data_source.
class TestProperty11FlowToDataSourceMapping:
    """Property 11: pipeline_flow → data_source mapping consistency."""

    @given(flow=flow_st)
    @settings(max_examples=100)
    def test_pipeline_flow_maps_to_data_source(self, flow):
        """Loading a legacy pipeline.flow must map run.data_source per the table."""
        expected = _FLOW_TO_DATA_SOURCE[flow]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            cfg = AppConfig.load({
                "app": {"store_id": 1},
                "pipeline": {"flow": flow},
            })
            resolved = cfg.get("run.data_source")
        assert resolved == expected, (
            f"pipeline.flow={flow!r} 应映射为 run.data_source={expected!r}"
            f"实际为 {resolved!r}"
        )

View File

@@ -0,0 +1,472 @@
# -*- coding: utf-8 -*-
"""
DWS任务单元测试
测试内容:
- BaseDwsTask基类方法
- 时间计算方法
- 配置应用方法
- 排名计算方法
"""
import pytest
from datetime import date, datetime, timedelta
from decimal import Decimal
from unittest.mock import MagicMock, patch
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from tasks.dws.base_dws_task import (
BaseDwsTask,
TimeLayer,
TimeWindow,
CourseType,
TimeRange,
ConfigCache
)
from tasks.dws.finance_daily_task import FinanceDailyTask
from tasks.dws.assistant_monthly_task import AssistantMonthlyTask
class TestTimeLayerRange:
    """Time-layer range computation on BaseDwsTask."""

    def _layer_range(self, layer, anchor):
        """Compute the range for *layer* anchored at *anchor* on a fresh mock task."""
        return create_mock_task().get_time_layer_range(layer, anchor)

    def test_last_2_days(self):
        """Last 2 days."""
        rng = self._layer_range(TimeLayer.LAST_2_DAYS, date(2026, 2, 1))
        assert rng.start == date(2026, 1, 31)
        assert rng.end == date(2026, 2, 1)

    def test_last_1_month(self):
        """Last 1 month."""
        rng = self._layer_range(TimeLayer.LAST_1_MONTH, date(2026, 2, 1))
        assert rng.start == date(2026, 1, 2)
        assert rng.end == date(2026, 2, 1)

    def test_last_3_months(self):
        """Last 3 months."""
        rng = self._layer_range(TimeLayer.LAST_3_MONTHS, date(2026, 2, 1))
        assert rng.start == date(2025, 11, 3)
        assert rng.end == date(2026, 2, 1)
class TestTimeWindowRange:
    """Time-window range computation."""

    def test_this_week_monday_start(self):
        """This week (Monday-based)."""
        # 2026-02-01 is a Sunday.
        base_date = date(2026, 2, 1)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.THIS_WEEK, base_date)
        # Monday of that week is 2026-01-26.
        assert result.start == date(2026, 1, 26)
        assert result.end == date(2026, 2, 1)

    def test_last_week(self):
        """Last week."""
        base_date = date(2026, 2, 1)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_WEEK, base_date)
        # Last Monday is 2026-01-19, last Sunday is 2026-01-25.
        assert result.start == date(2026, 1, 19)
        assert result.end == date(2026, 1, 25)

    def test_this_month(self):
        """This month."""
        base_date = date(2026, 2, 15)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.THIS_MONTH, base_date)
        assert result.start == date(2026, 2, 1)
        assert result.end == date(2026, 2, 15)

    def test_last_month(self):
        """Last month."""
        base_date = date(2026, 2, 15)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_MONTH, base_date)
        assert result.start == date(2026, 1, 1)
        assert result.end == date(2026, 1, 31)

    def test_last_3_months_excl_current(self):
        """Previous 3 months, excluding the current month."""
        base_date = date(2026, 2, 15)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_3_MONTHS_EXCL_CURRENT, base_date)
        assert result.start == date(2025, 11, 1)
        assert result.end == date(2026, 1, 31)

    def test_last_3_months_incl_current(self):
        """Previous 3 months, including the current month."""
        base_date = date(2026, 2, 15)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_3_MONTHS_INCL_CURRENT, base_date)
        assert result.start == date(2025, 12, 1)
        assert result.end == date(2026, 2, 15)

    def test_this_quarter(self):
        """This quarter."""
        base_date = date(2026, 2, 15)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.THIS_QUARTER, base_date)
        assert result.start == date(2026, 1, 1)
        assert result.end == date(2026, 2, 15)

    def test_last_6_months(self):
        """Last six months (excluding the current month)."""
        base_date = date(2026, 2, 15)
        task = create_mock_task()
        result = task.get_time_window_range(TimeWindow.LAST_6_MONTHS, base_date)
        # Excludes the current month: 6 months back from the end of last month.
        assert result.end == date(2026, 1, 31)
        assert result.start == date(2025, 8, 1)
class TestComparisonRange:
    """Period-over-period comparison range computation."""

    def test_comparison_7_days(self):
        """A 7-day window maps to the immediately preceding 7 days."""
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 2, 7))
        previous = create_mock_task().get_comparison_range(current)
        # Preceding 7 days: Jan 25 through Jan 31.
        assert previous.start == date(2026, 1, 25)
        assert previous.end == date(2026, 1, 31)

    def test_comparison_30_days(self):
        """The comparison window spans the same number of days as the input."""
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 3, 2))
        previous = create_mock_task().get_comparison_range(current)
        assert (previous.end - previous.start).days == (current.end - current.start).days
class TestFinanceDailyRecord:
    """Daily finance record computation."""

    def test_groupbuy_and_cashflow(self):
        """Group-buy discount handling and the cash-flow aggregation rules."""
        task = create_finance_daily_task()
        stat_date = date(2026, 2, 1)
        # Settlement-side inputs: amounts as Decimals, counts as ints.
        settle = {
            'gross_amount': Decimal('1000'),
            'table_fee_amount': Decimal('1000'),
            'goods_amount': Decimal('0'),
            'assistant_pd_amount': Decimal('0'),
            'assistant_cx_amount': Decimal('0'),
            'cash_pay_amount': Decimal('300'),
            'card_pay_amount': Decimal('0'),
            'balance_pay_amount': Decimal('0'),
            'gift_card_pay_amount': Decimal('0'),
            'coupon_amount': Decimal('200'),
            'pl_coupon_sale_amount': Decimal('0'),
            'adjust_amount': Decimal('50'),
            'member_discount_amount': Decimal('10'),
            'rounding_amount': Decimal('0'),
            'order_count': 1,
            'member_order_count': 1,
            'guest_order_count': 0,
        }
        groupbuy = {'groupbuy_pay_total': Decimal('80')}
        recharge = {'recharge_cash': Decimal('20')}
        expense = {'expense_amount': Decimal('40')}
        platform = {
            'settlement_amount': Decimal('60'),
            'commission_amount': Decimal('5'),
            'service_fee': Decimal('5'),
        }
        big_customer = {'big_customer_amount': Decimal('20')}
        record = task._build_daily_record(
            stat_date, settle, groupbuy, recharge, expense, platform, big_customer, 1
        )
        # Expected aggregates derived from the inputs above by _build_daily_record
        # (exact formulas live in FinanceDailyTask — see that implementation).
        assert record['discount_groupbuy'] == Decimal('120')
        assert record['discount_other'] == Decimal('30')
        assert record['platform_settlement_amount'] == Decimal('60')
        assert record['platform_fee_amount'] == Decimal('10')
        assert record['cash_inflow_total'] == Decimal('380')
        assert record['cash_outflow_total'] == Decimal('50')
        assert record['cash_balance_change'] == Decimal('330')
class TestNewHireTier:
    """New-hire tier assignment rules."""

    def test_new_hire_tier_hours(self):
        """Daily-average × 30 conversion: 15h over 5 days → 90h."""
        task = create_assistant_monthly_task()
        effective_hours = Decimal('15')
        work_days = 5
        result = task._calc_new_hire_tier_hours(effective_hours, work_days)
        assert result == Decimal('90')

    def test_max_tier_level_cap(self):
        """A new hire's tier is capped by max_tier_level."""
        task = create_mock_task()
        now = datetime.now()
        # Four ascending tiers (0-100, 100-200, 200-300, 300+); none is a
        # dedicated new-hire tier.
        task._config_cache = ConfigCache(
            performance_tiers=[
                {'tier_id': 1, 'tier_level': 1, 'min_hours': 0, 'max_hours': 100, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
                {'tier_id': 2, 'tier_level': 2, 'min_hours': 100, 'max_hours': 200, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
                {'tier_id': 3, 'tier_level': 3, 'min_hours': 200, 'max_hours': 300, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
                {'tier_id': 4, 'tier_level': 4, 'min_hours': 300, 'max_hours': None, 'is_new_hire_tier': False, 'effective_from': date(2020, 1, 1), 'effective_to': date(2099, 1, 1)},
            ],
            level_prices=[],
            bonus_rules=[],
            area_categories={},
            skill_types={},
            loaded_at=now
        )
        # 350 hours would reach tier 4, but max_tier_level=3 caps the result.
        tier = task.get_performance_tier(
            Decimal('350'),
            is_new_hire=True,
            effective_date=date(2026, 2, 1),
            max_tier_level=3
        )
        assert tier['tier_level'] == 3
class TestNewHireCheck:
    """New-hire detection relative to the statistics month."""

    def test_new_hire_in_month(self):
        """Hired inside the month counts as a new hire."""
        task = create_mock_task()
        hire_date = date(2026, 2, 5)
        stat_month = date(2026, 2, 1)
        # Assert truthiness directly instead of `== True` (flake8/ruff E712).
        assert task.is_new_hire_in_month(hire_date, stat_month)

    def test_not_new_hire(self):
        """Hired before the month is not a new hire."""
        task = create_mock_task()
        hire_date = date(2026, 1, 15)
        stat_month = date(2026, 2, 1)
        assert not task.is_new_hire_in_month(hire_date, stat_month)

    def test_hire_on_first_day(self):
        """Hired on the 1st of the month counts as a new hire."""
        task = create_mock_task()
        hire_date = date(2026, 2, 1)
        stat_month = date(2026, 2, 1)
        assert task.is_new_hire_in_month(hire_date, stat_month)
class TestRankWithTies:
    """Rank computation that accounts for tied values."""

    def test_no_ties(self):
        """Distinct values rank 1, 2, 3."""
        ranked = create_mock_task().calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('90')),
            (3, Decimal('80')),
        ])
        assert ranked[0] == (1, 1, 1)  # rank 1
        assert ranked[1] == (2, 2, 2)  # rank 2
        assert ranked[2] == (3, 3, 3)  # rank 3

    def test_with_ties(self):
        """Two tied values share rank 1; the next value ranks 3."""
        ranked = create_mock_task().calculate_rank_with_ties([
            (1, Decimal('100')),
            (2, Decimal('100')),  # tied for first
            (3, Decimal('80')),
        ])
        assert ranked[0][1] == 1
        assert ranked[1][1] == 1
        assert ranked[2][1] == 3  # rank 2 is skipped

    def test_all_ties(self):
        """Everyone tied means everyone ranks first."""
        ranked = create_mock_task().calculate_rank_with_ties([
            (member, Decimal('100')) for member in (1, 2, 3)
        ])
        assert all(entry[1] == 1 for entry in ranked)
class TestGuestCheck:
    """Walk-in (guest) detection from member_id."""

    def test_guest_zero(self):
        """member_id == 0 means a walk-in guest."""
        task = create_mock_task()
        # Assert truthiness directly instead of `== True` (flake8/ruff E712).
        assert task.is_guest(0)

    def test_guest_none(self):
        """member_id is None means a walk-in guest."""
        task = create_mock_task()
        assert task.is_guest(None)

    def test_not_guest(self):
        """A regular member id is not a guest."""
        task = create_mock_task()
        assert not task.is_guest(12345)
class TestUtilityMethods:
    """Conversion helpers exposed by the base task."""

    def test_safe_decimal(self):
        """safe_decimal coerces values and falls back to 0 on bad input."""
        task = create_mock_task()
        for raw, expected in [
            (100, Decimal('100')),
            ('123.45', Decimal('123.45')),
            (None, Decimal('0')),
            ('invalid', Decimal('0')),
        ]:
            assert task.safe_decimal(raw) == expected

    def test_safe_int(self):
        """safe_int coerces values and falls back to 0 on bad input."""
        task = create_mock_task()
        for raw, expected in [(100, 100), ('123', 123), (None, 0), ('invalid', 0)]:
            assert task.safe_int(raw) == expected

    def test_seconds_to_hours(self):
        """seconds_to_hours converts seconds to decimal hours."""
        task = create_mock_task()
        for seconds, hours in [(3600, Decimal('1')), (5400, Decimal('1.5')), (0, Decimal('0'))]:
            assert task.seconds_to_hours(seconds) == hours

    def test_hours_to_seconds(self):
        """hours_to_seconds converts decimal hours to whole seconds."""
        task = create_mock_task()
        for hours, seconds in [(Decimal('1'), 3600), (Decimal('1.5'), 5400)]:
            assert task.hours_to_seconds(hours) == seconds
class TestCourseType:
    """CourseType enum values."""

    def test_base_course(self):
        """Base course enum value."""
        assert CourseType.BASE.value == 'BASE'

    def test_bonus_course(self):
        """Bonus (add-on) course enum value."""
        assert CourseType.BONUS.value == 'BONUS'
# =============================================================================
# Helper functions
# =============================================================================
def create_mock_task():
    """
    Create a mock BaseDwsTask instance for testing.

    Returns a concrete throwaway subclass wired to MagicMock dependencies,
    so the base class's pure computation helpers can be exercised without
    a database, API client, or real logger.
    """
    # Concrete subclass used only by the tests; abstract hooks return fixtures.
    class TestDwsTask(BaseDwsTask):
        def get_task_code(self):
            return "TEST_DWS_TASK"
        def get_target_table(self):
            return "test_table"
        def get_primary_keys(self):
            return ["id"]
        def extract(self, context):
            return {}
        def load(self, transformed, context):
            return {}
    # Mocked dependencies; config.get always returns None.
    mock_config = MagicMock()
    mock_config.get.return_value = None
    mock_db = MagicMock()
    mock_api = MagicMock()
    mock_logger = MagicMock()
    task = TestDwsTask(mock_config, mock_db, mock_api, mock_logger)
    return task
def create_finance_daily_task():
    """Build a FinanceDailyTask wired to mocked dependencies (tenant_id=1)."""
    config = MagicMock()
    config.get.side_effect = (
        lambda key, default=None: 1 if key == "app.tenant_id" else default
    )
    return FinanceDailyTask(config, MagicMock(), MagicMock(), MagicMock())
def create_assistant_monthly_task():
    """Build an AssistantMonthlyTask wired to mocked dependencies."""
    config = MagicMock()
    # config.get always falls through to the caller-supplied default.
    config.get.side_effect = lambda key, default=None: default
    return AssistantMonthlyTask(config, MagicMock(), MagicMock(), MagicMock())
# Allow running this test module directly with `python`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

222
tests/unit/test_e2e_flow.py Normal file
View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
"""端到端流程集成测试
验证 CLI → PipelineRunner → TaskExecutor 完整调用链。
使用 mock 依赖,不需要真实数据库。
需求: 9.4
"""
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
from orchestration.task_executor import TaskExecutor, DataSource
from orchestration.pipeline_runner import PipelineRunner
from orchestration.task_registry import TaskRegistry
# ---------------------------------------------------------------------------
# 辅助:构造最小可用的 mock config
# ---------------------------------------------------------------------------
def _make_config(**overrides):
"""构造一个行为类似 AppConfig 的 MagicMock。"""
store = {
"app.timezone": "Asia/Shanghai",
"app.store_id": 1,
"io.fetch_root": "/tmp/fetch",
"io.ingest_source_dir": "",
"io.write_pretty_json": False,
"io.export_root": "/tmp/export",
"io.log_root": "/tmp/logs",
"pipeline.fetch_root": None,
"pipeline.ingest_source_dir": None,
"run.ods_tasks": [],
"run.dws_tasks": [],
"run.index_tasks": [],
"run.data_source": "hybrid",
"verification.ods_use_local_json": False,
"verification.skip_ods_when_fetch_before_verify": True,
}
store.update(overrides)
config = MagicMock()
config.get = MagicMock(side_effect=lambda k, d=None: store.get(k, d))
config.__getitem__ = MagicMock(side_effect=lambda k: {
"io": {"export_root": "/tmp/export", "log_root": "/tmp/logs"},
}[k])
return config
# ---------------------------------------------------------------------------
# 辅助:构造一个可被 TaskRegistry 注册的假任务类
# ---------------------------------------------------------------------------
class _FakeTask:
"""最小假任务execute() 返回固定结果。"""
def __init__(self, config, db_ops, api_client, logger):
pass
def execute(self, cursor_data):
return {"status": "SUCCESS", "counts": {"fetched": 5, "inserted": 3}}
# ===========================================================================
# 测试 1传统模式 — TaskExecutor.run_tasks 端到端
# ===========================================================================
class TestTraditionalModeE2E:
    """Traditional mode: TaskExecutor.run_tasks end to end."""

    def test_run_tasks_executes_utility_task_and_returns_results(self):
        """Utility tasks take the _run_utility_task path, skipping cursor/run tracking."""
        config = _make_config()
        registry = TaskRegistry()
        registry.register(
            "FAKE_UTIL", _FakeTask,
            requires_db_config=False, task_type="utility",
        )
        cursor_mgr = MagicMock()
        run_tracker = MagicMock()
        executor = TaskExecutor(
            config=config,
            db_ops=MagicMock(),
            api_client=MagicMock(),
            cursor_mgr=cursor_mgr,
            run_tracker=run_tracker,
            task_registry=registry,
            logger=MagicMock(),
        )
        results = executor.run_tasks(["FAKE_UTIL"], data_source="hybrid")
        assert len(results) == 1
        # run_tasks wraps a successful utility task as the localized "成功" status.
        assert results[0]["status"] in ("成功", "完成", "SUCCESS")
        # Utility tasks must not touch cursors or run records.
        cursor_mgr.get_or_create.assert_not_called()
        run_tracker.create_run.assert_not_called()
# ===========================================================================
# 测试 2管道模式 — PipelineRunner → TaskExecutor 端到端
# ===========================================================================
class TestPipelineModeE2E:
    """Pipeline mode: PipelineRunner.run → TaskExecutor.run_tasks end to end."""

    def test_pipeline_delegates_to_executor_and_returns_structure(self):
        """PipelineRunner resolves layers → tasks, then delegates to TaskExecutor."""
        executor = MagicMock()
        executor.run_tasks.return_value = [
            {"task_code": "FAKE_ODS", "status": "成功", "counts": {"fetched": 10, "inserted": 8}},
        ]
        registry = TaskRegistry()
        registry.register("FAKE_ODS", _FakeTask, layer="ODS")
        config = _make_config()
        runner = PipelineRunner(
            config=config,
            task_executor=executor,
            task_registry=registry,
            db_conn=MagicMock(),
            api_client=MagicMock(),
            logger=MagicMock(),
        )
        result = runner.run(
            pipeline="api_ods",
            processing_mode="increment_only",
            data_source="hybrid",
        )
        # Structural checks on the run() result.
        assert result["status"] == "SUCCESS"
        assert result["pipeline"] == "api_ods"
        assert result["layers"] == ["ODS"]
        assert isinstance(result["results"], list)
        # TaskExecutor was invoked once, with the requested data_source.
        executor.run_tasks.assert_called_once()
        call_args = executor.run_tasks.call_args
        assert call_args[1]["data_source"] == "hybrid"

    def test_pipeline_verify_only_skips_increment(self):
        """verify_only mode skips the incremental ETL and only runs verification."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        registry = TaskRegistry()
        config = _make_config()
        runner = PipelineRunner(
            config=config,
            task_executor=executor,
            task_registry=registry,
            db_conn=MagicMock(),
            api_client=MagicMock(),
            logger=MagicMock(),
        )
        # The verification framework may be absent, so mock _run_verification.
        with patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}):
            result = runner.run(
                pipeline="api_ods",
                processing_mode="verify_only",
                data_source="hybrid",
            )
        assert result["status"] == "SUCCESS"
        # verify_only with fetch_before_verify=False must not call run_tasks.
        executor.run_tasks.assert_not_called()
# ===========================================================================
# 测试 3ETLScheduler 薄包装层委托验证
# ===========================================================================
class TestSchedulerThinWrapper:
    """ETLScheduler is a thin wrapper delegating to TaskExecutor / PipelineRunner."""

    def test_scheduler_delegates_run_tasks(self):
        """run_tasks() delegates to the internal task_executor."""
        from orchestration.scheduler import ETLScheduler
        mock_config = MagicMock()
        mock_config.__getitem__ = MagicMock(side_effect=lambda k: {
            "db": {
                "dsn": "postgresql://fake:5432/test",
                "session": {"timezone": "Asia/Shanghai"},
                "connect_timeout_sec": 5,
            },
            "api": {
                "base_url": "https://fake.api",
                "token": "fake-token",
                "timeout_sec": 30,
                "retries": {"max_attempts": 3},
            },
        }[k])
        mock_config.get = MagicMock(side_effect=lambda k, d=None: {
            "run.data_source": "hybrid",
            "run.tasks": ["FAKE"],
            "app.timezone": "Asia/Shanghai",
        }.get(k, d))
        # Mock out resource construction so no real connections are made.
        with patch("orchestration.scheduler.DatabaseConnection"), \
             patch("orchestration.scheduler.DatabaseOperations"), \
             patch("orchestration.scheduler.APIClient"), \
             patch("orchestration.scheduler.CursorManager"), \
             patch("orchestration.scheduler.RunTracker"), \
             patch("orchestration.scheduler.TaskExecutor") as MockTE, \
             patch("orchestration.scheduler.PipelineRunner") as MockPR:
            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                scheduler = ETLScheduler(mock_config, MagicMock())
            # run_tasks delegation.
            scheduler.run_tasks(["TEST_TASK"])
            scheduler.task_executor.run_tasks.assert_called_once()
            # run_pipeline_with_verification delegation.
            scheduler.run_pipeline_with_verification(pipeline="api_ods")
            scheduler.pipeline_runner.run.assert_called_once()

View File

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""Unit tests for recent/former endpoint routing."""
import sys
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from api.endpoint_routing import plan_calls, recent_boundary
TZ = ZoneInfo("Asia/Shanghai")
def _now():
    """Fixed reference 'now' (2025-12-18 10:00 +08:00) for deterministic routing tests."""
    return datetime(2025, 12, 18, 10, 0, 0, tzinfo=TZ)
def test_recent_boundary_month_start():
    """The recent/former boundary falls on a month start."""
    boundary = recent_boundary(_now())
    assert boundary.isoformat() == "2025-09-01T00:00:00+08:00"
def test_paylog_routes_to_former_when_old_window():
    """A window entirely before the boundary reroutes PayLog to the Former endpoint."""
    window = {"siteId": 1, "StartPayTime": "2025-08-01 00:00:00", "EndPayTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/PayLog/GetPayLogListPage", window, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/PayLog/GetFormerPayLogListPage"]
def test_coupon_usage_stays_same_path_even_when_old():
    """Coupon-usage queries keep their original path even for old windows."""
    window = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/Promotion/GetOfflineCouponConsumePageList", window, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/Promotion/GetOfflineCouponConsumePageList"]
def test_goods_outbound_routes_to_queryformer_when_old():
    """Old goods-outbound windows reroute to the QueryFormer endpoint."""
    window = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls("/GoodsStockManage/QueryGoodsOutboundReceipt", window, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == ["/GoodsStockManage/QueryFormerGoodsOutboundReceipt"]
def test_settlement_records_split_when_crossing_boundary():
    """A window straddling the boundary splits into a former call plus a recent call."""
    window = {"siteId": 1, "rangeStartTime": "2025-08-15 00:00:00", "rangeEndTime": "2025-09-10 00:00:00"}
    planned = plan_calls("/Site/GetAllOrderSettleList", window, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == [
        "/Site/GetFormerOrderSettleList",
        "/Site/GetAllOrderSettleList",
    ]
    former, recent = planned
    assert former.params["rangeEndTime"] == "2025-09-01 00:00:00"
    assert recent.params["rangeStartTime"] == "2025-09-01 00:00:00"
@pytest.mark.parametrize(
    "endpoint",
    [
        "/PayLog/GetFormerPayLogListPage",
        "/Site/GetFormerOrderSettleList",
        "/GoodsStockManage/QueryFormerGoodsOutboundReceipt",
    ],
)
def test_explicit_former_endpoint_not_rerouted(endpoint):
    """An endpoint that already targets Former must be passed through unchanged."""
    window = {"siteId": 1, "startTime": "2025-08-01 00:00:00", "endTime": "2025-08-02 00:00:00"}
    planned = plan_calls(endpoint, window, now=_now(), tz=TZ)
    assert [call.endpoint for call in planned] == [endpoint]

View File

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""filter_verify_tables 单元测试"""
import pytest
from tasks.verification.models import filter_verify_tables
class TestFilterVerifyTables:
    """Layer-based filtering of verification table names."""

    def test_none_input_returns_none(self):
        assert filter_verify_tables("DWD", None) is None

    def test_empty_list_returns_none(self):
        assert filter_verify_tables("DWD", []) is None

    def test_dwd_layer_filters_correctly(self):
        candidates = ["dwd_order", "dim_member", "fact_payment", "ods_raw", "dws_daily"]
        kept = filter_verify_tables("DWD", candidates)
        assert kept == ["dwd_order", "dim_member", "fact_payment"]

    def test_dws_layer_filters_correctly(self):
        kept = filter_verify_tables("DWS", ["dws_daily", "dwd_order", "dws_summary"])
        assert kept == ["dws_daily", "dws_summary"]

    def test_index_layer_filters_correctly(self):
        kept = filter_verify_tables("INDEX", ["v_score", "wbi_index", "dws_daily", "v_rank"])
        assert kept == ["v_score", "wbi_index", "v_rank"]

    def test_ods_layer_filters_correctly(self):
        kept = filter_verify_tables("ODS", ["ods_order", "dwd_order", "ods_member"])
        assert kept == ["ods_order", "ods_member"]

    def test_unknown_layer_returns_normalized(self):
        kept = filter_verify_tables("UNKNOWN", [" SomeTable ", "Another"])
        assert kept == ["sometable", "another"]

    def test_layer_case_insensitive(self):
        candidates = ["dwd_order", "ods_raw"]
        assert filter_verify_tables("dwd", candidates) == ["dwd_order"]
        assert filter_verify_tables("Dwd", candidates) == ["dwd_order"]

    def test_whitespace_and_empty_entries_stripped(self):
        kept = filter_verify_tables("DWD", [" dwd_order ", "", " ", None, "dim_member"])
        assert kept == ["dwd_order", "dim_member"]

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Unit tests for the new ODS ingestion tasks."""
import logging
import os
import sys
from pathlib import Path
# 确保在独立运行测试时能正确解析项目根目录
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
from tasks.ods.ods_tasks import ODS_TASK_CLASSES
from .task_test_utils import create_test_config, get_db_operations, FakeAPIClient
def _build_config(tmp_path):
    """Build an ONLINE-mode test config rooted at *tmp_path*."""
    return create_test_config(
        "ONLINE",
        tmp_path / "archive",
        tmp_path / "temp",
    )
def test_assistant_accounts_masters_ingest(tmp_path):
    """Ensure ODS_ASSISTANT_ACCOUNT stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    records = [{"id": 5001, "assistant_no": "A01", "nickname": "小张"}]
    fake_api = FakeAPIClient({"/PersonnelManagement/SearchAssistantInfo": records})
    task_cls = ODS_TASK_CLASSES["ODS_ASSISTANT_ACCOUNT"]
    with get_db_operations() as db_ops:
        logger = logging.getLogger("test_assistant_accounts_masters")
        result = task_cls(config, db_ops, fake_api, logger).execute()
        assert result["status"] == "SUCCESS"
        assert result["counts"]["fetched"] == 1
        assert db_ops.commits == 1
        first_row = db_ops.upserts[0]["rows"][0]
        assert first_row["id"] == 5001
        assert first_row["record_index"] == 0
        # In online mode no archive file backs the row, so source_file is empty.
        assert first_row["source_file"] is None or first_row["source_file"] == ""
        assert '"id": 5001' in first_row["payload"]
def test_goods_stock_movements_ingest(tmp_path):
    """Ensure ODS_INVENTORY_CHANGE stores raw payload with record_index dedup keys."""
    config = _build_config(tmp_path)
    records = [{"siteGoodsStockId": 123456, "stockType": 1, "goodsName": "测试商品"}]
    fake_api = FakeAPIClient({"/GoodsStockManage/QueryGoodsOutboundReceipt": records})
    task_cls = ODS_TASK_CLASSES["ODS_INVENTORY_CHANGE"]
    with get_db_operations() as db_ops:
        logger = logging.getLogger("test_goods_stock_movements")
        result = task_cls(config, db_ops, fake_api, logger).execute()
        assert result["status"] == "SUCCESS"
        assert result["counts"]["fetched"] == 1
        assert db_ops.commits == 1
        first_row = db_ops.upserts[0]["rows"][0]
        # The upserted column key is lower-cased, unlike the raw JSON payload.
        assert first_row["sitegoodsstockid"] == 123456
        assert first_row["record_index"] == 0
        assert '"siteGoodsStockId": 123456' in first_row["payload"]
def test_member_profiless_ingest(tmp_path):
    """Ensure ODS_MEMBER task stores tenantMemberInfos raw JSON.

    NOTE(review): the function name contains a typo ("profiless"); renaming it
    to test_member_profiles_ingest would be a safe cleanup since pytest only
    discovers tests by the "test_" prefix.
    """
    config = _build_config(tmp_path)
    # The API wraps member records under the "tenantMemberInfos" key.
    sample = [{"tenantMemberInfos": [{"id": 101, "mobile": "13800000000"}]}]
    api = FakeAPIClient({"/MemberProfile/GetTenantMemberList": sample})
    task_cls = ODS_TASK_CLASSES["ODS_MEMBER"]
    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_member"))
        result = task.execute()
        assert result["status"] == "SUCCESS"
        row = db_ops.upserts[0]["rows"][0]
        assert row["record_index"] == 0
        # The inner member record must be preserved in the stored payload.
        assert '"id": 101' in row["payload"]
def test_ods_payment_ingest(tmp_path):
    """Ensure ODS_PAYMENT task stores payment_transactions raw JSON."""
    config = _build_config(tmp_path)
    fake_api = FakeAPIClient(
        {"/PayLog/GetPayLogListPage": [{"payId": 901, "payAmount": "100.00"}]}
    )
    task_cls = ODS_TASK_CLASSES["ODS_PAYMENT"]
    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, fake_api, logging.getLogger("test_ods_payment"))
        outcome = task.execute()
        assert outcome["status"] == "SUCCESS"
        stored_row = db_ops.upserts[0]["rows"][0]
        assert stored_row["record_index"] == 0
        assert '"payId": 901' in stored_row["payload"]
def test_ods_settlement_records_ingest(tmp_path):
    """Ensure ODS_SETTLEMENT_RECORDS stores settleList raw JSON."""
    config = _build_config(tmp_path)
    fake_api = FakeAPIClient(
        {"/Site/GetAllOrderSettleList": [{"id": 701, "orderTradeNo": 8001}]}
    )
    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_RECORDS"]
    with get_db_operations() as db_ops:
        task = task_cls(config, db_ops, fake_api, logging.getLogger("test_settlement_records"))
        outcome = task.execute()
        assert outcome["status"] == "SUCCESS"
        stored_row = db_ops.upserts[0]["rows"][0]
        assert stored_row["record_index"] == 0
        assert '"orderTradeNo": 8001' in stored_row["payload"]
def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
    """Ensure settlement tickets are fetched per payment relate_id and skip existing ones."""
    config = _build_config(tmp_path)
    # The ticket endpoint returns a doubly nested data envelope.
    ticket_payload = {"data": {"data": {"orderSettleId": 9001, "orderSettleNumber": "T001"}}}
    api = FakeAPIClient({"/Order/GetOrderSettleTicketNew": [ticket_payload]})
    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"]
    with get_db_operations() as db_ops:
        # Canned query queue: first query returns already-stored ticket IDs,
        # second query returns the payment relate IDs.
        db_ops.query_results = [
            [{"order_settle_id": 9002}],
            [
                {"order_settle_id": 9001},
                {"order_settle_id": 9002},
                {"order_settle_id": None},
            ],
        ]
        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_settlement_ticket"))
        result = task.execute()
        assert result["status"] == "SUCCESS"
        counts = result["counts"]
        # Only 9001 is new (9002 already stored, None has no ID), so exactly
        # one ticket is fetched and inserted.
        assert counts["fetched"] == 1
        assert counts["inserted"] == 1
        assert counts["updated"] == 0
        assert counts["skipped"] == 0
        assert '"orderSettleId": 9001' in db_ops.upserts[0]["rows"][0]["payload"]
        # The API must have been hit with the missing relate ID.
        assert any(
            call["endpoint"] == "/Order/GetOrderSettleTicketNew"
            and call.get("params", {}).get("orderSettleId") == 9001
            for call in api.calls
        )

View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""解析器测试"""
import pytest
from decimal import Decimal
from datetime import datetime
from zoneinfo import ZoneInfo
from models.parsers import TypeParser
def test_parse_decimal():
    """Amount parsing: rounding to scale, plus None/garbage rejection."""
    assert TypeParser.parse_decimal("100.555", 2) == Decimal("100.56")
    for bad_value in (None, "invalid"):
        assert TypeParser.parse_decimal(bad_value) is None
def test_parse_int():
    """Integer parsing accepts str and int, rejects None and garbage."""
    assert TypeParser.parse_int("123") == 123
    assert TypeParser.parse_int(456) == 456
    for bad_value in (None, "abc"):
        assert TypeParser.parse_int(bad_value) is None
def test_parse_timestamp():
    """Timestamp strings parse into datetimes in the given timezone."""
    parsed = TypeParser.parse_timestamp("2025-01-15 10:30:00", ZoneInfo("Asia/Taipei"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (2025, 1, 15)
def test_parse_timestamp_zero_epoch():
    """0 must not be treated as empty; it should parse as the Unix epoch."""
    parsed = TypeParser.parse_timestamp(0, ZoneInfo("Asia/Taipei"))
    assert parsed is not None
    assert (parsed.year, parsed.month, parsed.day) == (1970, 1, 1)

View File

@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-
"""PipelineRunner 属性测试 - hypothesis 验证管道编排器的通用正确性属性。"""
import string
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from orchestration.pipeline_runner import PipelineRunner
# run() imports TaskLogger lazily, so the mock target is the source module path.
_TASK_LOGGER_PATH = "utils.task_logger.TaskLogger"
FILE_VERSION = "v1_shell"
# ── Strategy definitions ─────────────────────────────────────────
pipeline_name_st = st.sampled_from(list(PipelineRunner.PIPELINE_LAYERS.keys()))
processing_mode_st = st.sampled_from(["increment_only", "verify_only", "increment_verify"])
data_source_st = st.sampled_from(["online", "offline", "hybrid"])
_TASK_PREFIXES = ["ODS_", "DWD_", "DWS_", "INDEX_"]
# Task codes: a layer prefix followed by an uppercase/digit/underscore suffix.
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_TASK_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=12,
    ),
)
# Generator for a single task-result dict.
task_result_st = st.fixed_dictionaries({
    "task_code": task_code_st,
    "status": st.sampled_from(["SUCCESS", "FAIL", "SKIP"]),
    "counts": st.fixed_dictionaries({
        "fetched": st.integers(min_value=0, max_value=10000),
        "inserted": st.integers(min_value=0, max_value=10000),
        "updated": st.integers(min_value=0, max_value=10000),
        "skipped": st.integers(min_value=0, max_value=10000),
        "errors": st.integers(min_value=0, max_value=100),
    }),
    "dump_dir": st.none(),
})
task_results_st = st.lists(task_result_st, min_size=0, max_size=10)
# ── 辅助函数 ──────────────────────────────────────────────────────
def _make_config():
"""创建 mock 配置对象。"""
config = MagicMock()
config.get = MagicMock(side_effect=lambda key, default=None: {
"app.timezone": "Asia/Shanghai",
"verification.ods_use_local_json": False,
"verification.skip_ods_when_fetch_before_verify": True,
"run.ods_tasks": [],
"run.dws_tasks": [],
"run.index_tasks": [],
}.get(key, default))
return config
def _make_runner(task_executor=None, task_registry=None):
    """Create a PipelineRunner wired to mock collaborators.

    Callers may inject their own executor/registry; missing ones are replaced
    by MagicMocks with harmless defaults.
    """
    executor = task_executor
    if executor is None:
        executor = MagicMock()
        executor.run_tasks.return_value = []
    registry = task_registry
    if registry is None:
        registry = MagicMock()
        registry.get_tasks_by_layer.return_value = ["FAKE_TASK"]
    return PipelineRunner(
        config=_make_config(),
        task_executor=executor,
        task_registry=registry,
        db_conn=MagicMock(),
        api_client=MagicMock(),
        logger=MagicMock(),
    )
# ── Property 5: pipeline name → layer list mapping ───────────────
# Feature: scheduler-refactor, Property 5: pipeline name → layer list mapping
# **Validates: Requirements 2.1**
class TestProperty5PipelineNameToLayers:
    """For any valid pipeline name, the layer list resolved by PipelineRunner
    must match the definition in the PIPELINE_LAYERS dict exactly."""
    @given(pipeline=pipeline_name_st)
    @settings(max_examples=100)
    def test_layers_match_pipeline_definition(self, pipeline):
        """The `layers` field returned by run() equals PIPELINE_LAYERS[pipeline]."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        expected_layers = PipelineRunner.PIPELINE_LAYERS[pipeline]
        assert result["layers"] == expected_layers
    @given(pipeline=pipeline_name_st)
    @settings(max_examples=100)
    def test_resolve_tasks_called_with_correct_layers(self, pipeline):
        """_resolve_tasks receives a layer list identical to the PIPELINE_LAYERS entry."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with (
            patch(_TASK_LOGGER_PATH),
            # Spy wraps the real method so behavior is unchanged while calls are recorded.
            patch.object(runner, "_resolve_tasks", wraps=runner._resolve_tasks) as spy,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        expected_layers = PipelineRunner.PIPELINE_LAYERS[pipeline]
        spy.assert_called_once_with(expected_layers)
# ── Property 6: processing_mode controls the execution flow ──────
# Feature: scheduler-refactor, Property 6: processing_mode controls the execution flow
# **Validates: Requirements 2.3, 2.4**
class TestProperty6ProcessingModeControlsFlow:
    """For any processing_mode: incremental ETL runs iff the mode contains
    'increment'; verification runs iff the mode contains 'verify'."""
    @given(
        pipeline=pipeline_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_increment_executes_iff_mode_contains_increment(self, pipeline, mode, data_source):
        """Incremental ETL (task_executor.run_tasks) runs iff mode contains 'increment'."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}),
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )
        should_increment = "increment" in mode
        if should_increment:
            assert executor.run_tasks.called, (
                f"mode={mode} 包含 'increment',但 run_tasks 未被调用"
            )
        else:
            # verify_only with fetch_before_verify=False (the default): run_tasks
            # must not be invoked at all.
            assert not executor.run_tasks.called, (
                f"mode={mode} 不包含 'increment',但 run_tasks 被调用了"
            )
    @given(
        pipeline=pipeline_name_st,
        mode=processing_mode_st,
        data_source=data_source_st,
    )
    @settings(max_examples=100)
    def test_verification_executes_iff_mode_contains_verify(self, pipeline, mode, data_source):
        """Verification (_run_verification) runs iff mode contains 'verify'."""
        executor = MagicMock()
        executor.run_tasks.return_value = []
        runner = _make_runner(task_executor=executor)
        with (
            patch(_TASK_LOGGER_PATH),
            patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}) as mock_verify,
        ):
            runner.run(
                pipeline=pipeline,
                processing_mode=mode,
                data_source=data_source,
            )
        should_verify = "verify" in mode
        if should_verify:
            assert mock_verify.called, (
                f"mode={mode} 包含 'verify',但 _run_verification 未被调用"
            )
        else:
            assert not mock_verify.called, (
                f"mode={mode} 不包含 'verify',但 _run_verification 被调用了"
            )
# ── Property 7: pipeline result summary completeness ─────────────
# Feature: scheduler-refactor, Property 7: pipeline result summary completeness
# **Validates: Requirements 2.6**
class TestProperty7PipelineSummaryCompleteness:
    """For any set of task execution results, the summary dict returned by
    PipelineRunner must contain status/pipeline/layers/results, and the
    results list length must equal the number of tasks actually executed."""
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_summary_has_required_fields(self, pipeline, task_results):
        """The returned dict must contain status, pipeline, layers, results, verification_summary."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        required_keys = {"status", "pipeline", "layers", "results", "verification_summary"}
        assert required_keys.issubset(result.keys()), (
            f"缺少必要字段: {required_keys - result.keys()}"
        )
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_results_length_equals_executed_tasks(self, pipeline, task_results):
        """len(results) equals the number of tasks returned by task_executor.run_tasks."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        assert len(result["results"]) == len(task_results), (
            f"results 长度 {len(result['results'])} != 实际任务数 {len(task_results)}"
        )
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_pipeline_and_layers_match_input(self, pipeline, task_results):
        """The returned pipeline and layers fields echo the inputs."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        assert result["pipeline"] == pipeline
        assert result["layers"] == PipelineRunner.PIPELINE_LAYERS[pipeline]
    @given(
        pipeline=pipeline_name_st,
        task_results=task_results_st,
    )
    @settings(max_examples=100)
    def test_increment_only_has_no_verification(self, pipeline, task_results):
        """In increment_only mode verification_summary must be None."""
        executor = MagicMock()
        executor.run_tasks.return_value = task_results
        runner = _make_runner(task_executor=executor)
        with patch(_TASK_LOGGER_PATH):
            result = runner.run(
                pipeline=pipeline,
                processing_mode="increment_only",
                data_source="offline",
            )
        assert result["verification_summary"] is None

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""关系指数基础能力单测。"""
from __future__ import annotations
import logging
from datetime import date
from typing import Any, Dict, List, Optional
from tasks.dws.index.base_index_task import BaseIndexTask
from tasks.dws.index.ml_manual_import_task import MlManualImportTask
class _DummyConfig:
"""最小配置桩对象。"""
def __init__(self, values: Optional[Dict[str, Any]] = None):
self._values = values or {}
def get(self, key: str, default: Any = None) -> Any:
return self._values.get(key, default)
class _DummyDB:
"""最小数据库桩对象。"""
def __init__(self) -> None:
self.query_calls: List[tuple] = []
def query(self, sql: str, params=None):
self.query_calls.append((sql, params))
index_type = (params or [None])[0]
if index_type == "RS":
return [{"param_name": "lookback_days", "param_value": 60}]
if index_type == "MS":
return [{"param_name": "lookback_days", "param_value": 30}]
return []
class _DummyIndexTask(BaseIndexTask):
    """Minimal concrete subclass for exercising BaseIndexTask helpers."""
    # Default index type returned by get_index_type().
    INDEX_TYPE = "RS"
    def get_task_code(self) -> str:  # pragma: no cover - test stub
        return "DUMMY_INDEX"
    def get_target_table(self) -> str:  # pragma: no cover - test stub
        return "dummy_table"
    def get_primary_keys(self) -> List[str]:  # pragma: no cover - test stub
        return ["id"]
    def get_index_type(self) -> str:
        return self.INDEX_TYPE
    def extract(self, context):  # pragma: no cover - test stub
        return []
    def load(self, transformed, context):  # pragma: no cover - test stub
        return {}
def test_load_index_parameters_cache_isolated_by_index_type():
    """The parameter cache must be keyed by index_type so values never leak
    between RS and MS lookups on the same task instance."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_cache"),
    )
    rs_first = task.load_index_parameters(index_type="RS")
    ms_first = task.load_index_parameters(index_type="MS")
    rs_second = task.load_index_parameters(index_type="RS")
    # _DummyDB serves 60 for RS and 30 for MS; values arrive as floats.
    assert rs_first["lookback_days"] == 60.0
    assert ms_first["lookback_days"] == 30.0
    assert rs_second["lookback_days"] == 60.0
    # Exactly two DB queries expected: one per index type; the second RS
    # lookup must be served from the cache.
    assert len(task.db.query_calls) == 2
def test_batch_normalize_passes_index_type_to_smoothing_chain():
    """batch_normalize_to_display must forward index_type into the smoothing chain."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_smoothing"),
    )
    captured: Dict[str, Any] = {}
    def _fake_apply(site_id, current_p5, current_p95, alpha=None, index_type=None):
        # Capture the forwarded index_type; leave the percentile band untouched.
        captured["index_type"] = index_type
        return current_p5, current_p95
    # Replace the EWMA hook on this instance only.
    task._apply_ewma_smoothing = _fake_apply  # type: ignore[method-assign]
    result = task.batch_normalize_to_display(
        raw_scores=[("a", 1.0), ("b", 2.0), ("c", 3.0)],
        use_smoothing=True,
        site_id=1,
        index_type="ML",
    )
    assert result
    assert captured["index_type"] == "ML"
def test_ml_manual_import_scope_day_and_p30_boundary():
    """Dates within 30 days map to per-day scopes; older dates land in the
    fixed-epoch 30-day (P30) bucket."""
    today = date(2026, 2, 8)
    # Exactly 30 days before today → still a DAY scope covering that one date.
    on_boundary = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 9),
        today=today,
    )
    assert on_boundary.scope_type == "DAY"
    assert on_boundary.start_date == on_boundary.end_date == date(2026, 1, 9)
    # 31 days before today → falls past the boundary into a P30 bucket.
    past_boundary = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 8),
        today=today,
    )
    assert past_boundary.scope_type == "P30"
    # Fixed epoch 2026-01-01: the first bucket spans 2026-01-01 ~ 2026-01-30.
    assert (past_boundary.start_date, past_boundary.end_date) == (
        date(2026, 1, 1),
        date(2026, 1, 30),
    )

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""汇总与报告工具的单测。"""
from utils.reporting import summarize_counts, format_report
def test_summarize_counts_and_format():
    """summarize_counts aggregates per-task counters; format_report renders them."""
    task_results = [
        {"task_code": "ORDERS", "counts": {"fetched": 2, "inserted": 2, "updated": 0, "skipped": 0, "errors": 0}},
        {"task_code": "PAYMENTS", "counts": {"fetched": 3, "inserted": 2, "updated": 1, "skipped": 0, "errors": 0}},
    ]
    summary = summarize_counts(task_results)
    totals = summary["total"]
    assert (totals["fetched"], totals["inserted"], totals["updated"], totals["errors"]) == (5, 4, 1, 0)
    assert len(summary["details"]) == 2
    rendered = format_report(summary)
    for fragment in ("TOTAL fetched=5", "ORDERS:", "PAYMENTS:"):
        assert fragment in rendered

View File

@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
"""TaskExecutor 属性测试 - hypothesis 验证执行器的通用正确性属性。"""
import re
import string
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from orchestration.task_executor import TaskExecutor, DataSource
from orchestration.task_registry import TaskRegistry
FILE_VERSION = "v4_shell"
# Hypothesis strategies for executor inputs.
data_source_st = st.sampled_from(["online", "offline", "hybrid"])
# Prefixes deliberately exclude ODS_ so these codes hit the non-ODS path.
_NON_ODS_PREFIXES = ["DWD_", "DWS_", "TASK_", "ETL_", "TEST_"]
task_code_st = st.builds(
    lambda prefix, suffix: prefix + suffix,
    prefix=st.sampled_from(_NON_ODS_PREFIXES),
    suffix=st.text(
        alphabet=string.ascii_uppercase + string.digits + "_",
        min_size=1, max_size=15,
    ),
)
window_start_st = st.datetimes(min_value=datetime(2020, 1, 1), max_value=datetime(2030, 12, 31))
def _make_fake_class(name="FakeTask"):
return type(name, (), {"__init__": lambda self, *a, **kw: None})
def _make_config():
config = MagicMock()
config.get = MagicMock(side_effect=lambda key, default=None: {
"app.timezone": "Asia/Shanghai",
"io.fetch_root": "/tmp/fetch",
"io.ingest_source_dir": "/tmp/ingest",
"io.write_pretty_json": False,
"pipeline.fetch_root": None,
"pipeline.ingest_source_dir": None,
"integrity.auto_check": False,
"run.overlap_seconds": 600,
}.get(key, default))
config.__getitem__ = MagicMock(side_effect=lambda k: {
"io": {"export_root": "/tmp/export", "log_root": "/tmp/log"},
}[k])
return config
def _make_executor(registry):
    """Build a TaskExecutor where every collaborator except *registry* is a mock."""
    mocked = dict(
        db_ops=MagicMock(),
        api_client=MagicMock(),
        cursor_mgr=MagicMock(),
        run_tracker=MagicMock(),
        logger=MagicMock(),
    )
    return TaskExecutor(config=_make_config(), task_registry=registry, **mocked)
# Feature: scheduler-refactor, Property 1: data_source determines the execution path
# **Validates: Requirements 1.2**
class TestProperty1DataSourceDeterminesPath:
    """fetch runs for online/hybrid sources; ingest runs for offline/hybrid."""
    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_fetch(self, ds):
        """_flow_includes_fetch is True exactly for online and hybrid."""
        result = TaskExecutor._flow_includes_fetch(ds)
        assert result == (ds in {"online", "hybrid"})
    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_flow_includes_ingest(self, ds):
        """_flow_includes_ingest is True exactly for offline and hybrid."""
        result = TaskExecutor._flow_includes_ingest(ds)
        assert result == (ds in {"offline", "hybrid"})
    @given(ds=data_source_st)
    @settings(max_examples=100)
    def test_fetch_and_ingest_consistency(self, ds):
        """The fetch/ingest flags are jointly consistent for each data source."""
        fetch = TaskExecutor._flow_includes_fetch(ds)
        ingest = TaskExecutor._flow_includes_ingest(ds)
        if ds == "hybrid":
            assert fetch and ingest
        elif ds == "online":
            assert fetch and not ingest
        elif ds == "offline":
            assert not fetch and ingest
# Feature: scheduler-refactor, Property 2: successful tasks advance the cursor
# **Validates: Requirements 1.3**
class TestProperty2SuccessAdvancesCursor:
    """A SUCCESS result carrying a window must advance the cursor to that window."""
    @given(
        task_code=task_code_st,
        window_start=window_start_st,
        window_minutes=st.integers(min_value=1, max_value=1440),
    )
    @settings(max_examples=100)
    def test_success_with_window_advances_cursor(self, task_code, window_start, window_minutes):
        """cursor_mgr.advance is called exactly once with the task's window bounds."""
        window_end = window_start + timedelta(minutes=window_minutes)
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        task_result = {
            "status": "SUCCESS",
            "counts": {"fetched": 10, "inserted": 5},
            "window": {"start": window_start, "end": window_end, "minutes": window_minutes},
        }
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 100
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 42, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
            patch.object(TaskExecutor, "_maybe_run_integrity_check"),
        ):
            executor.run_single_task(task_code, "test-uuid", 1, "offline")
        executor.cursor_mgr.advance.assert_called_once()
        kw = executor.cursor_mgr.advance.call_args.kwargs
        assert kw["window_start"] == window_start
        assert kw["window_end"] == window_end
# Feature: scheduler-refactor, Property 3: failed tasks are marked FAIL and re-raised
# **Validates: Requirements 1.4**
class TestProperty3FailureMarksFailAndReraises:
    """An exception during execution records a FAIL run and propagates."""
    @given(
        task_code=task_code_st,
        error_msg=st.text(
            alphabet=string.ascii_letters + string.digits + " _-",
            min_size=1, max_size=80,
        ),
    )
    @settings(max_examples=100)
    def test_exception_marks_fail_and_reraises(self, task_code, error_msg):
        """run_single_task re-raises the original error and updates the run to FAIL."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWD")
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 200
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 99, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", side_effect=RuntimeError(error_msg)),
        ):
            # re.escape: the generated message may contain regex metacharacters.
            with pytest.raises(RuntimeError, match=re.escape(error_msg)):
                executor.run_single_task(task_code, "fail-uuid", 1, "offline")
        executor.run_tracker.update_run.assert_called()
        kw = executor.run_tracker.update_run.call_args.kwargs
        assert kw["status"] == "FAIL"
# Feature: scheduler-refactor, Property 4: utility tasks are determined by metadata
# **Validates: Requirements 1.6, 4.2**
class TestProperty4UtilityTaskDeterminedByMetadata:
    """Tasks with requires_db_config=False bypass cursor and run tracking."""
    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_utility_task_skips_cursor_and_run_tracker(self, task_code):
        """Utility tasks never touch cursor_mgr or run_tracker."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=False, task_type="utility")
        executor = _make_executor(registry)
        mock_task = MagicMock()
        mock_task.execute.return_value = {"status": "SUCCESS", "counts": {}}
        registry.create_task = MagicMock(return_value=mock_task)
        result = executor.run_single_task(task_code, "util-uuid", 1, "hybrid")
        executor.cursor_mgr.get_or_create.assert_not_called()
        executor.cursor_mgr.advance.assert_not_called()
        executor.run_tracker.create_run.assert_not_called()
        assert result.get("status") == "SUCCESS"
    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_non_utility_task_uses_cursor_and_run_tracker(self, task_code):
        """Regular ETL tasks obtain a cursor and create a run record exactly once."""
        registry = TaskRegistry()
        registry.register(task_code, _make_fake_class(), requires_db_config=True, layer="DWS")
        task_result = {"status": "SUCCESS", "counts": {"fetched": 1}}
        executor = _make_executor(registry)
        executor.cursor_mgr.get_or_create.return_value = {"cursor_id": 1, "last_end": None}
        executor.run_tracker.create_run.return_value = 300
        with (
            patch.object(TaskExecutor, "_load_task_config", return_value={
                "task_id": 77, "task_code": task_code, "store_id": 1, "enabled": True}),
            patch.object(TaskExecutor, "_resolve_ingest_source", return_value=Path("/tmp/src")),
            patch.object(TaskExecutor, "_execute_ingest", return_value=task_result),
        ):
            executor.run_single_task(task_code, "non-util-uuid", 1, "offline")
        executor.cursor_mgr.get_or_create.assert_called_once()
        executor.run_tracker.create_run.assert_called_once()

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""TaskRegistry 单元测试 — 验证 TaskMeta 元数据注册与查询"""
import pytest
from orchestration.task_registry import TaskRegistry, TaskMeta
# ── 辅助:用作注册的假任务类 ──────────────────────────────────
class _FakeTask:
"""占位任务类,用于测试注册"""
def __init__(self, config, db_connection, api_client, logger):
self.config = config
class _AnotherFakeTask:
def __init__(self, config, db_connection, api_client, logger):
pass
# ── fixtures ──────────────────────────────────────────────────
@pytest.fixture
def registry():
    """Provide a fresh TaskRegistry per test to avoid state leakage."""
    return TaskRegistry()
# ── register + get_metadata ───────────────────────────────────
class TestRegisterAndMetadata:
    """Registration and metadata lookup."""
    def test_register_with_defaults(self, registry):
        """With only task_code + task_class, metadata uses defaults (backward compatible)."""
        registry.register("MY_TASK", _FakeTask)
        meta = registry.get_metadata("MY_TASK")
        assert meta is not None
        assert meta.task_class is _FakeTask
        assert meta.requires_db_config is True
        assert meta.layer is None
        assert meta.task_type == "etl"
    def test_register_with_full_metadata(self, registry):
        """Fully specified metadata round-trips through register/get_metadata."""
        registry.register(
            "ODS_ORDERS", _FakeTask,
            requires_db_config=True, layer="ODS", task_type="etl",
        )
        meta = registry.get_metadata("ODS_ORDERS")
        assert meta.layer == "ODS"
        assert meta.task_type == "etl"
    def test_register_utility_task(self, registry):
        """Utility tasks: requires_db_config=False."""
        registry.register(
            "INIT_SCHEMA", _FakeTask,
            requires_db_config=False, task_type="utility",
        )
        meta = registry.get_metadata("INIT_SCHEMA")
        assert meta.requires_db_config is False
        assert meta.task_type == "utility"
    def test_case_insensitive_lookup(self, registry):
        """task_code lookup is case-insensitive."""
        registry.register("my_task", _FakeTask)
        assert registry.get_metadata("MY_TASK") is not None
        assert registry.get_metadata("my_task") is not None
    def test_get_metadata_unknown_returns_none(self, registry):
        """Querying an unregistered task returns None."""
        assert registry.get_metadata("NONEXISTENT") is None
# ── create_task (interface unchanged) ─────────────────────────
class TestCreateTask:
    """Factory behavior of create_task."""
    def test_create_task_returns_instance(self, registry):
        # The factory instantiates the registered class, passing config through.
        registry.register("MY_TASK", _FakeTask)
        task = registry.create_task("MY_TASK", {"k": "v"}, None, None, None)
        assert isinstance(task, _FakeTask)
        assert task.config == {"k": "v"}
    def test_create_task_unknown_raises(self, registry):
        # Unknown codes raise ValueError carrying the "unknown task type" message.
        with pytest.raises(ValueError, match="未知的任务类型"):
            registry.create_task("NOPE", None, None, None, None)
# ── get_tasks_by_layer ────────────────────────────────────────
class TestGetTasksByLayer:
    """Layer-based task lookup."""
    def test_returns_matching_tasks(self, registry):
        registry.register("A", _FakeTask, layer="ODS")
        registry.register("B", _AnotherFakeTask, layer="ODS")
        registry.register("C", _FakeTask, layer="DWD")
        result = registry.get_tasks_by_layer("ODS")
        assert set(result) == {"A", "B"}
    def test_case_insensitive_layer(self, registry):
        # Layer comparison ignores case.
        registry.register("X", _FakeTask, layer="dws")
        assert registry.get_tasks_by_layer("DWS") == ["X"]
    def test_no_match_returns_empty(self, registry):
        registry.register("A", _FakeTask, layer="ODS")
        assert registry.get_tasks_by_layer("INDEX") == []
    def test_none_layer_excluded(self, registry):
        """Tasks registered with layer=None are never returned by a layer query."""
        registry.register("UTIL", _FakeTask)  # layer defaults to None
        assert registry.get_tasks_by_layer("ODS") == []
# ── is_utility_task ───────────────────────────────────────────
class TestIsUtilityTask:
    """Utility-task classification is driven by requires_db_config."""
    def test_utility_task(self, registry):
        registry.register("INIT", _FakeTask, requires_db_config=False)
        assert registry.is_utility_task("INIT") is True
    def test_normal_task(self, registry):
        registry.register("ETL", _FakeTask, requires_db_config=True)
        assert registry.is_utility_task("ETL") is False
    def test_unknown_task(self, registry):
        # Unregistered codes are treated as non-utility.
        assert registry.is_utility_task("NOPE") is False
# ── get_all_task_codes (interface unchanged) ──────────────────
class TestGetAllTaskCodes:
    """Enumeration of every registered task code."""
    def test_returns_all_codes(self, registry):
        registry.register("A", _FakeTask)
        registry.register("B", _AnotherFakeTask)
        assert set(registry.get_all_task_codes()) == {"A", "B"}
    def test_empty_registry(self, registry):
        assert registry.get_all_task_codes() == []

View File

@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
"""TaskRegistry 属性测试 — 使用 hypothesis 验证注册表的通用正确性属性。"""
import string
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from orchestration.task_registry import TaskRegistry, TaskMeta
# ── 辅助:动态生成假任务类 ────────────────────────────────────
def _make_fake_class(name: str = "FakeTask") -> type:
"""创建一个最小化的假任务类,用于注册测试。"""
return type(name, (), {"__init__": lambda self, *a, **kw: None})
# ── Generation strategies ─────────────────────────────────────
# Valid task codes: uppercase letters + digits + underscore, length 1–30.
task_code_st = st.text(
    alphabet=string.ascii_uppercase + string.digits + "_",
    min_size=1,
    max_size=30,
)
requires_db_config_st = st.booleans()
layer_st = st.sampled_from([None, "ODS", "DWD", "DWS", "INDEX"])
task_type_st = st.sampled_from(["etl", "utility", "verification"])
# ── Property 8: TaskRegistry metadata round-trip ──────────────
# Feature: scheduler-refactor, Property 8: TaskRegistry metadata round-trip
# **Validates: Requirements 4.1**
#
# For any combination of task code, task class and metadata
# (requires_db_config, layer, task_type), querying get_metadata after
# registration must return the same metadata values.
class TestProperty8MetadataRoundTrip:
    """Property 8: metadata queried after registration must match exactly."""
    @given(
        task_code=task_code_st,
        requires_db=requires_db_config_st,
        layer=layer_st,
        task_type=task_type_st,
    )
    @settings(max_examples=100)
    def test_metadata_round_trip(
        self, task_code: str, requires_db: bool, layer: str | None, task_type: str
    ):
        """After registering any metadata combination, get_metadata returns the same values."""
        # Arrange — a fresh registry per iteration avoids state leakage
        registry = TaskRegistry()
        fake_cls = _make_fake_class()
        # Act — register, then query back
        registry.register(
            task_code,
            fake_cls,
            requires_db_config=requires_db,
            layer=layer,
            task_type=task_type,
        )
        meta = registry.get_metadata(task_code)
        # Assert — metadata round-trips intact
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is requires_db, (
            f"requires_db_config 应为 {requires_db},实际为 {meta.requires_db_config}"
        )
        assert meta.layer == layer, f"layer 应为 {layer!r},实际为 {meta.layer!r}"
        assert meta.task_type == task_type, (
            f"task_type 应为 {task_type!r},实际为 {meta.task_type!r}"
        )
# ── Property 9: TaskRegistry backward-compatible defaults ─────
# Feature: scheduler-refactor, Property 9: TaskRegistry backward-compatible defaults
# **Validates: Requirements 4.4**
#
# For any task registered through the legacy interface (task_code and
# task_class only), metadata must read requires_db_config=True,
# layer=None, task_type="etl".
class TestProperty9BackwardCompatibleDefaults:
    """Property 9: with only task_code + task_class, metadata uses the defaults."""
    @given(task_code=task_code_st)
    @settings(max_examples=100)
    def test_legacy_register_uses_defaults(self, task_code: str):
        """Registering via the legacy interface yields the default metadata values."""
        # Arrange
        registry = TaskRegistry()
        fake_cls = _make_fake_class()
        # Act — pass only task_code and task_class, no metadata kwargs
        registry.register(task_code, fake_cls)
        meta = registry.get_metadata(task_code)
        # Assert — the default-value contract
        assert meta is not None, f"注册后 get_metadata('{task_code}') 不应返回 None"
        assert meta.task_class is fake_cls, "task_class 应与注册时一致"
        assert meta.requires_db_config is True, (
            f"默认 requires_db_config 应为 True实际为 {meta.requires_db_config}"
        )
        assert meta.layer is None, (
            f"默认 layer 应为 None实际为 {meta.layer!r}"
        )
        assert meta.task_type == "etl", (
            f"默认 task_type 应为 'etl',实际为 {meta.task_type!r}"
        )
# ── Property 10: query tasks by layer ────────────────────────
# Feature: scheduler-refactor, Property 10: query tasks by layer
# **Validates: Requirements 4.3**
#
# For any set of tasks registered with layer metadata, the task-code set
# returned by get_tasks_by_layer(layer) must equal the set of all
# registered task codes whose layer matches.
# Strategy producing only non-None layer values, for query-side checks.
# NOTE(review): not referenced by the visible Property 10 test below (it
# iterates a literal layer list) — confirm whether it is used elsewhere.
non_none_layer_st = st.sampled_from(["ODS", "DWD", "DWS", "INDEX"])
class TestProperty10GetTasksByLayer:
    """Property 10: get_tasks_by_layer must agree with a manual filter over registrations."""

    @given(
        entries=st.lists(
            st.tuples(task_code_st, layer_st),
            min_size=1,
            max_size=20,
        ),
    )
    @settings(max_examples=100)
    def test_get_tasks_by_layer_matches_manual_filter(
        self, entries: list[tuple[str, str | None]],
    ):
        """After registering a batch of tasks, per-layer queries must match hand filtering."""
        # Arrange
        registry = TaskRegistry()
        # Track only the last registration per task_code, mirroring
        # register()'s overwrite semantics for duplicate codes.
        last_layer_by_code: dict[str, str | None] = {}
        for task_code, task_layer in entries:
            registry.register(
                task_code,
                _make_fake_class(f"Fake_{task_code}"),
                layer=task_layer,
            )
            # register() upper-cases the code internally; mirror that here.
            last_layer_by_code[task_code.upper()] = task_layer

        # Act & Assert — verify every non-None layer value.
        for query_layer in ("ODS", "DWD", "DWS", "INDEX"):
            actual = set(registry.get_tasks_by_layer(query_layer))
            expected = {
                code
                for code, task_layer in last_layer_by_code.items()
                if task_layer is not None
                and task_layer.upper() == query_layer.upper()
            }
            assert actual == expected, (
                f"查询 layer={query_layer!r} 时,"
                f"期望 {expected},实际 {actual}"
            )