Files
Neo-ZQYY/apps/backend/tests/test_tasks_router.py

275 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""任务注册表 API 单元测试
覆盖 4 个端点registry / dwd-tables / flows / validate
通过 JWT mock 绕过认证依赖。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_registry import (
ALL_TASKS,
DWD_TABLES,
FLOW_LAYER_MAP,
get_tasks_grouped_by_domain,
)
# Fixed test user injected in place of the real JWT-authenticated user.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Dependency override for ``get_current_user``: always return ``_TEST_USER``."""
    return _TEST_USER


# Installed once at import time so the override is active as soon as the module loads.
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)


@pytest.fixture(autouse=True)
def _reinstall_auth_override():
    """Re-apply the auth override before every test in this module.

    Other test files may ``clear()``/``pop()`` ``app.dependency_overrides``.
    The import-time assignment above runs only once, so without this fixture
    tests here could receive 401 responses depending on test execution order.
    """
    app.dependency_overrides[get_current_user] = _override_auth
# ---------------------------------------------------------------------------
# GET /api/tasks/registry
# ---------------------------------------------------------------------------
class TestTaskRegistry:
    """Tests for ``GET /api/tasks/registry``."""

    def setup_method(self):
        """Re-install the auth override before each test method.

        Guards against state leaking from other test files that
        ``clear()``/``pop()`` ``app.dependency_overrides``.
        """
        app.dependency_overrides[get_current_user] = _override_auth

    def test_registry_returns_grouped_tasks(self):
        """Every registered task appears in the group named by its own domain."""
        resp = client.get("/api/tasks/registry")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data
        # Every task must show up in some group, and only under its own domain.
        all_codes_in_response = set()
        for domain, tasks in data["groups"].items():
            for t in tasks:
                all_codes_in_response.add(t["code"])
                assert t["domain"] == domain
        expected_codes = {t.code for t in ALL_TASKS}
        assert all_codes_in_response == expected_codes

    def test_registry_task_fields_complete(self):
        """Each task item carries every required field."""
        resp = client.get("/api/tasks/registry")
        # Check the status first so an error response fails clearly instead of via KeyError.
        assert resp.status_code == 200
        data = resp.json()
        required_fields = {"code", "name", "description", "domain", "layer",
                           "requires_window", "is_ods", "is_dimension", "default_enabled"}
        for tasks in data["groups"].values():
            for t in tasks:
                missing = required_fields.difference(t)
                assert not missing, (
                    f"task {t.get('code')!r} is missing fields: {sorted(missing)}"
                )

    def test_registry_requires_auth(self):
        """Without the auth override the endpoint returns 401."""
        app.dependency_overrides.pop(get_current_user, None)
        try:
            resp = client.get("/api/tasks/registry")
            assert resp.status_code == 401
        finally:
            # Restore the override so later tests stay authenticated.
            app.dependency_overrides[get_current_user] = _override_auth
# ---------------------------------------------------------------------------
# GET /api/tasks/dwd-tables
# ---------------------------------------------------------------------------
class TestDwdTables:
    """Tests for ``GET /api/tasks/dwd-tables``."""

    def test_dwd_tables_returns_grouped(self):
        """Every DWD table appears in the group named by its own domain."""
        resp = client.get("/api/tasks/dwd-tables")
        assert resp.status_code == 200
        data = resp.json()
        assert "groups" in data
        all_tables_in_response = set()
        for domain, tables in data["groups"].items():
            for t in tables:
                all_tables_in_response.add(t["table_name"])
                assert t["domain"] == domain
        expected_tables = {t.table_name for t in DWD_TABLES}
        assert all_tables_in_response == expected_tables

    def test_dwd_tables_fields_complete(self):
        """Each DWD table item carries every required field."""
        resp = client.get("/api/tasks/dwd-tables")
        # Check the status first so an error response fails clearly instead of via KeyError.
        assert resp.status_code == 200
        data = resp.json()
        required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
        for tables in data["groups"].values():
            for t in tables:
                missing = required_fields.difference(t)
                assert not missing, (
                    f"table {t.get('table_name')!r} is missing fields: {sorted(missing)}"
                )
# ---------------------------------------------------------------------------
# GET /api/tasks/flows
# ---------------------------------------------------------------------------
class TestFlows:
    """Tests for ``GET /api/tasks/flows``."""

    def test_flows_returns_seven_flows(self):
        """The endpoint lists exactly seven execution flows."""
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["flows"]) == 7

    def test_flows_returns_three_processing_modes(self):
        """The endpoint lists exactly three processing modes."""
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["processing_modes"]) == 3

    def test_flow_ids_match_registry(self):
        """Flow IDs match the keys of FLOW_LAYER_MAP."""
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        data = resp.json()
        flow_ids = {f["id"] for f in data["flows"]}
        assert flow_ids == set(FLOW_LAYER_MAP.keys())

    def test_flow_layers_non_empty(self):
        """Every flow declares at least one layer."""
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        data = resp.json()
        for flow in data["flows"]:
            assert len(flow["layers"]) > 0, f"flow {flow.get('id')!r} has no layers"

    def test_processing_mode_ids(self):
        """Processing mode IDs are exactly the three supported modes."""
        resp = client.get("/api/tasks/flows")
        assert resp.status_code == 200
        data = resp.json()
        mode_ids = {m["id"] for m in data["processing_modes"]}
        assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
# ---------------------------------------------------------------------------
# POST /api/tasks/validate
# ---------------------------------------------------------------------------
class TestValidate:
    """Tests for ``POST /api/tasks/validate``."""

    def test_validate_success(self):
        """A valid config is accepted and the store id comes from the JWT user."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
                "pipeline": "api_ods",
            }
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert data["errors"] == []
        assert len(data["command_args"]) > 0
        assert "--store-id" in data["command"]
        # store_id must be injected from the JWT (test user site_id=100).
        # Pin the argument that follows the flag; a bare substring check on
        # the command string would also match "100" appearing anywhere else.
        args = data["command_args"]
        assert "--store-id" in args
        assert args[args.index("--store-id") + 1] == "100"

    def test_validate_injects_store_id(self):
        """Even if the client sends store_id, the backend overrides it with the JWT value."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": ["DWD_LOAD_FROM_ODS"],
                "pipeline": "ods_dwd",
                "store_id": 999,
            }
        })
        assert resp.status_code == 200
        data = resp.json()
        # The command must carry the JWT site_id=100, not the client-supplied 999.
        assert "--store-id" in data["command"]
        idx = data["command_args"].index("--store-id")
        assert data["command_args"][idx + 1] == "100"

    def test_validate_invalid_flow(self):
        """An unknown pipeline is reported as a validation error (not an HTTP error)."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": ["ODS_MEMBER"],
                "pipeline": "nonexistent_flow",
            }
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("无效的执行流程" in e for e in data["errors"])

    def test_validate_empty_tasks(self):
        """An empty task list is reported as a validation error."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": [],
                "pipeline": "api_ods",
            }
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("任务列表不能为空" in e for e in data["errors"])

    def test_validate_custom_window(self):
        """A custom time window adds --window-start / --window-end to the command."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": ["ODS_MEMBER"],
                "pipeline": "api_ods",
                "window_mode": "custom",
                "window_start": "2024-01-01",
                "window_end": "2024-01-31",
            }
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert "--window-start" in data["command"]
        assert "--window-end" in data["command"]

    def test_validate_window_end_before_start_rejected(self):
        """window_end earlier than window_start fails request validation -> 422."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": ["ODS_MEMBER"],
                "pipeline": "api_ods",
                "window_mode": "custom",
                "window_start": "2024-12-31",
                "window_end": "2024-01-01",
            }
        })
        assert resp.status_code == 422

    def test_validate_dry_run_flag(self):
        """dry_run=True adds --dry-run to the command."""
        resp = client.post("/api/tasks/validate", json={
            "config": {
                "tasks": ["ODS_MEMBER"],
                "pipeline": "api_ods",
                "dry_run": True,
            }
        })
        assert resp.status_code == 200
        data = resp.json()
        assert "--dry-run" in data["command"]
# ---------------------------------------------------------------------------
# task_registry 服务层测试
# ---------------------------------------------------------------------------
class TestTaskRegistryService:
    """Service-layer checks on the task registry data, with no HTTP round-trip."""

    def test_all_tasks_have_unique_codes(self):
        """No two registered tasks share the same code."""
        seen_codes = set()
        for task in ALL_TASKS:
            assert task.code not in seen_codes
            seen_codes.add(task.code)

    def test_grouped_tasks_cover_all(self):
        """Grouping by domain neither drops nor invents any task."""
        buckets = get_tasks_grouped_by_domain()
        grouped_codes = {task.code for bucket in buckets.values() for task in bucket}
        registry_codes = {task.code for task in ALL_TASKS}
        assert grouped_codes == registry_codes

    def test_ods_tasks_marked_is_ods(self):
        """Every task on the ODS layer has is_ods set to True."""
        ods_layer_tasks = [task for task in ALL_TASKS if task.layer == "ODS"]
        for task in ods_layer_tasks:
            assert task.is_ods is True

    def test_flow_layer_map_covers_all_flows(self):
        """FLOW_LAYER_MAP defines exactly the seven supported flows."""
        supported = {
            "api_ods", "api_ods_dwd", "api_full", "ods_dwd",
            "dwd_dws", "dwd_dws_index", "dwd_index",
        }
        assert set(FLOW_LAYER_MAP.keys()) == supported