# -*- coding: utf-8 -*-
"""Unit tests for the task registry API.

Covers 4 endpoints: registry / dwd-tables / flows / validate.

Authentication is bypassed by replacing the JWT auth dependency with a mock.
"""

import os
|
||
|
||
# Provide a fallback JWT secret without clobbering one that is already set.
# NOTE(review): this runs before the `app` imports, presumably because the
# application reads JWT_SECRET_KEY at import time -- confirm before moving it.
if "JWT_SECRET_KEY" not in os.environ:
    os.environ["JWT_SECRET_KEY"] = "test-secret-key-for-unit-tests"
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.main import app
|
||
from app.services.task_registry import (
|
||
ALL_TASKS,
|
||
DWD_TABLES,
|
||
FLOW_LAYER_MAP,
|
||
get_tasks_grouped_by_domain,
|
||
)
|
||
|
||
# Fixed test user that stands in for the authenticated caller.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Return the fixed test user; installed in place of ``get_current_user``."""
    return _TEST_USER


# One shared client for every test in this module.
client = TestClient(app)

# Swap the real auth dependency for the stub so requests count as authenticated.
app.dependency_overrides[get_current_user] = _override_auth


# ---------------------------------------------------------------------------
# GET /api/tasks/registry
# ---------------------------------------------------------------------------

class TestTaskRegistry:
    """Tests for ``GET /api/tasks/registry``."""

    def setup_method(self):
        """Re-install the auth override before every test method.

        Prevents state leakage when another test file clears or pops
        ``app.dependency_overrides``.
        """
        app.dependency_overrides[get_current_user] = _override_auth

    def test_registry_returns_grouped_tasks(self):
        """Every task appears in some group, listed under its own domain."""
        response = client.get("/api/tasks/registry")
        assert response.status_code == 200
        payload = response.json()
        assert "groups" in payload
        groups = payload["groups"]

        # Each task must report the same domain as the group that lists it.
        misplaced = [
            item["code"]
            for domain_name, items in groups.items()
            for item in items
            if item["domain"] != domain_name
        ]
        assert misplaced == []

        # Every known task should appear in some group, and nothing extra.
        returned_codes = {item["code"] for items in groups.values() for item in items}
        assert returned_codes == {task.code for task in ALL_TASKS}

    def test_registry_task_fields_complete(self):
        """Every task item contains all required fields."""
        payload = client.get("/api/tasks/registry").json()
        required = {
            "code", "name", "description", "domain", "layer",
            "requires_window", "is_ods", "is_dimension", "default_enabled",
        }
        for items in payload["groups"].values():
            for item in items:
                missing = required.difference(item)
                assert not missing

    def test_registry_requires_auth(self):
        """Returns 401 when the request is unauthenticated."""
        # Drop the stub so the real auth dependency runs for this one request.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            status = client.get("/api/tasks/registry").status_code
            assert status == 401
        finally:
            # Always restore the stub, even if the assertion above fails.
            app.dependency_overrides[get_current_user] = _override_auth


# ---------------------------------------------------------------------------
# GET /api/tasks/dwd-tables
# ---------------------------------------------------------------------------

class TestDwdTables:
    """Tests for ``GET /api/tasks/dwd-tables``."""

    def setup_method(self):
        """Re-install the auth override before every test method.

        Guards against another test having cleared or popped
        ``app.dependency_overrides``, which would otherwise make these
        requests fail authentication.
        """
        app.dependency_overrides[get_current_user] = _override_auth

    def test_dwd_tables_returns_grouped(self):
        """Every DWD table is returned, listed under its own domain."""
        resp = client.get("/api/tasks/dwd-tables")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data

        all_tables_in_response = set()
        for domain, tables in data["groups"].items():
            for t in tables:
                all_tables_in_response.add(t["table_name"])
                # A table must report the same domain as the group listing it.
                assert t["domain"] == domain

        expected_tables = {t.table_name for t in DWD_TABLES}
        assert all_tables_in_response == expected_tables

    def test_dwd_tables_fields_complete(self):
        """Every table item contains all required fields."""
        resp = client.get("/api/tasks/dwd-tables")
        # Check the status first so an auth/server error is reported as such
        # rather than as a KeyError on the missing "groups" key.
        assert resp.status_code == 200
        data = resp.json()
        required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
        for tables in data["groups"].values():
            for t in tables:
                assert required_fields.issubset(t.keys())


# ---------------------------------------------------------------------------
# GET /api/tasks/flows
# ---------------------------------------------------------------------------

class TestFlows:
    """Tests for ``GET /api/tasks/flows``."""

    def setup_method(self):
        """Re-install the auth override before every test method.

        Guards against another test having cleared or popped
        ``app.dependency_overrides``, which would otherwise make these
        requests fail authentication.
        """
        app.dependency_overrides[get_current_user] = _override_auth

    def _get_flows(self):
        """Fetch the flows payload, failing fast on any non-200 response."""
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        return resp.json()

    def test_flows_returns_seven_flows(self):
        """The endpoint lists exactly seven flows."""
        data = self._get_flows()
        assert len(data["flows"]) == 7

    def test_flows_returns_three_processing_modes(self):
        """The endpoint lists exactly three processing modes."""
        data = self._get_flows()
        assert len(data["processing_modes"]) == 3

    def test_flow_ids_match_registry(self):
        """Flow IDs match the keys of FLOW_LAYER_MAP."""
        data = self._get_flows()
        flow_ids = {f["id"] for f in data["flows"]}
        assert flow_ids == set(FLOW_LAYER_MAP.keys())

    def test_flow_layers_non_empty(self):
        """Every flow declares at least one layer."""
        data = self._get_flows()
        for f in data["flows"]:
            assert len(f["layers"]) > 0

    def test_processing_mode_ids(self):
        """The processing mode IDs are exactly the three expected ones."""
        data = self._get_flows()
        mode_ids = {m["id"] for m in data["processing_modes"]}
        assert mode_ids == {"increment_only", "verify_only", "increment_verify"}


# ---------------------------------------------------------------------------
# POST /api/tasks/validate
# ---------------------------------------------------------------------------

class TestValidate:
    """Tests for ``POST /api/tasks/validate``."""

    def setup_method(self):
        """Re-install the auth override before every test method.

        Guards against another test having cleared or popped
        ``app.dependency_overrides``, which would otherwise make these
        requests fail authentication.
        """
        app.dependency_overrides[get_current_user] = _override_auth

    def _post_validate(self, config):
        """POST ``config`` wrapped as ``{"config": ...}`` and return the response."""
        return client.post("/api/tasks/validate", json={"config": config})

    def test_validate_success(self):
        """A valid config yields a command carrying the JWT-derived store id."""
        resp = self._post_validate({
            "tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
            "pipeline": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert data["errors"] == []
        assert len(data["command_args"]) > 0
        assert "--store-id" in data["command"]
        # store_id should be injected from the JWT (test user site_id=100).
        assert "100" in data["command"]
        # The substring check above would also match an unrelated "100", so
        # additionally pin the value that directly follows the flag.
        idx = data["command_args"].index("--store-id")
        assert data["command_args"][idx + 1] == "100"

    def test_validate_injects_store_id(self):
        """Even if the frontend sends store_id, the backend overrides it with the JWT value."""
        resp = self._post_validate({
            "tasks": ["DWD_LOAD_FROM_ODS"],
            "pipeline": "ods_dwd",
            "store_id": 999,
        })
        assert resp.status_code == 200
        data = resp.json()
        # The command must carry the JWT's site_id=100, not the 999 sent by the frontend.
        assert "--store-id" in data["command"]
        idx = data["command_args"].index("--store-id")
        assert data["command_args"][idx + 1] == "100"

    def test_validate_invalid_flow(self):
        """An unknown pipeline is reported as an invalid-flow error."""
        resp = self._post_validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "nonexistent_flow",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("无效的执行流程" in e for e in data["errors"])

    def test_validate_empty_tasks(self):
        """An empty task list is reported as an error."""
        resp = self._post_validate({
            "tasks": [],
            "pipeline": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("任务列表不能为空" in e for e in data["errors"])

    def test_validate_custom_window(self):
        """A custom window adds both window flags to the command."""
        resp = self._post_validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-01-01",
            "window_end": "2024-01-31",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert "--window-start" in data["command"]
        assert "--window-end" in data["command"]

    def test_validate_window_end_before_start_rejected(self):
        """window_end earlier than window_start fails Pydantic validation -> 422."""
        resp = self._post_validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-12-31",
            "window_end": "2024-01-01",
        })
        assert resp.status_code == 422

    def test_validate_dry_run_flag(self):
        """dry_run=True adds the --dry-run flag to the command."""
        resp = self._post_validate({
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods",
            "dry_run": True,
        })
        assert resp.status_code == 200
        data = resp.json()
        assert "--dry-run" in data["command"]


# ---------------------------------------------------------------------------
# task_registry service-layer tests
# ---------------------------------------------------------------------------

class TestTaskRegistryService:
    """Tests that exercise the ``task_registry`` module directly, without HTTP."""

    def test_all_tasks_have_unique_codes(self):
        """No two entries in ALL_TASKS share a code."""
        seen = set()
        for task in ALL_TASKS:
            assert task.code not in seen
            seen.add(task.code)

    def test_grouped_tasks_cover_all(self):
        """Grouping by domain neither loses nor invents tasks."""
        by_domain = get_tasks_grouped_by_domain()
        grouped_codes = {task.code for members in by_domain.values() for task in members}
        known_codes = {task.code for task in ALL_TASKS}
        assert grouped_codes == known_codes

    def test_ods_tasks_marked_is_ods(self):
        """Every task in the ODS layer has ``is_ods`` set to True."""
        ods_tasks = [task for task in ALL_TASKS if task.layer == "ODS"]
        assert all(task.is_ods is True for task in ods_tasks)

    def test_flow_layer_map_covers_all_flows(self):
        """FLOW_LAYER_MAP has exactly the seven expected flow keys."""
        expected_flows = {
            "api_ods",
            "api_ods_dwd",
            "api_full",
            "ods_dwd",
            "dwd_dws",
            "dwd_dws_index",
            "dwd_index",
        }
        actual_flows = set(FLOW_LAYER_MAP.keys())
        assert actual_flows == expected_flows