在准备环境前提交次全部更改。

This commit is contained in:
Neo
2026-02-19 08:35:13 +08:00
parent ded6dfb9d8
commit 4eac07da47
1387 changed files with 6107191 additions and 33002 deletions

View File

@@ -0,0 +1,274 @@
# -*- coding: utf-8 -*-
"""任务注册表 API 单元测试
覆盖 4 个端点registry / dwd-tables / flows / validate
通过 JWT mock 绕过认证依赖。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_registry import (
ALL_TASKS,
DWD_TABLES,
FLOW_LAYER_MAP,
get_tasks_grouped_by_domain,
)
# 固定测试用户
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# ---------------------------------------------------------------------------
# GET /api/tasks/registry
# ---------------------------------------------------------------------------
class TestTaskRegistry:
def setup_method(self):
"""每个测试方法前重新设置 auth 覆盖,防止其他测试文件的 clear/pop 导致状态泄漏"""
app.dependency_overrides[get_current_user] = _override_auth
def test_registry_returns_grouped_tasks(self):
resp = client.get("/api/tasks/registry")
assert resp.status_code == 200
data = resp.json()
assert "groups" in data
# 所有任务都应出现在某个分组中
all_codes_in_response = set()
for domain, tasks in data["groups"].items():
for t in tasks:
all_codes_in_response.add(t["code"])
assert t["domain"] == domain
expected_codes = {t.code for t in ALL_TASKS}
assert all_codes_in_response == expected_codes
def test_registry_task_fields_complete(self):
"""每个任务项包含所有必要字段"""
resp = client.get("/api/tasks/registry")
data = resp.json()
required_fields = {"code", "name", "description", "domain", "layer",
"requires_window", "is_ods", "is_dimension", "default_enabled"}
for tasks in data["groups"].values():
for t in tasks:
assert required_fields.issubset(t.keys())
def test_registry_requires_auth(self):
"""未认证时返回 401"""
app.dependency_overrides.pop(get_current_user, None)
try:
resp = client.get("/api/tasks/registry")
assert resp.status_code == 401
finally:
app.dependency_overrides[get_current_user] = _override_auth
# ---------------------------------------------------------------------------
# GET /api/tasks/dwd-tables
# ---------------------------------------------------------------------------
class TestDwdTables:
def test_dwd_tables_returns_grouped(self):
resp = client.get("/api/tasks/dwd-tables")
assert resp.status_code == 200
data = resp.json()
assert "groups" in data
all_tables_in_response = set()
for domain, tables in data["groups"].items():
for t in tables:
all_tables_in_response.add(t["table_name"])
assert t["domain"] == domain
expected_tables = {t.table_name for t in DWD_TABLES}
assert all_tables_in_response == expected_tables
def test_dwd_tables_fields_complete(self):
resp = client.get("/api/tasks/dwd-tables")
data = resp.json()
required_fields = {"table_name", "display_name", "domain", "ods_source", "is_dimension"}
for tables in data["groups"].values():
for t in tables:
assert required_fields.issubset(t.keys())
# ---------------------------------------------------------------------------
# GET /api/tasks/flows
# ---------------------------------------------------------------------------
class TestFlows:
def test_flows_returns_seven_flows(self):
resp = client.get("/api/tasks/flows")
assert resp.status_code == 200
data = resp.json()
assert len(data["flows"]) == 7
def test_flows_returns_three_processing_modes(self):
resp = client.get("/api/tasks/flows")
data = resp.json()
assert len(data["processing_modes"]) == 3
def test_flow_ids_match_registry(self):
"""Flow ID 与 FLOW_LAYER_MAP 一致"""
resp = client.get("/api/tasks/flows")
data = resp.json()
flow_ids = {f["id"] for f in data["flows"]}
assert flow_ids == set(FLOW_LAYER_MAP.keys())
def test_flow_layers_non_empty(self):
resp = client.get("/api/tasks/flows")
data = resp.json()
for f in data["flows"]:
assert len(f["layers"]) > 0
def test_processing_mode_ids(self):
resp = client.get("/api/tasks/flows")
data = resp.json()
mode_ids = {m["id"] for m in data["processing_modes"]}
assert mode_ids == {"increment_only", "verify_only", "increment_verify"}
# ---------------------------------------------------------------------------
# POST /api/tasks/validate
# ---------------------------------------------------------------------------
class TestValidate:
def test_validate_success(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
"pipeline": "api_ods",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
assert data["errors"] == []
assert len(data["command_args"]) > 0
assert "--store-id" in data["command"]
# store_id 应从 JWT 注入(测试用户 site_id=100
assert "100" in data["command"]
def test_validate_injects_store_id(self):
"""即使前端传了 store_id后端也用 JWT 中的值覆盖"""
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["DWD_LOAD_FROM_ODS"],
"pipeline": "ods_dwd",
"store_id": 999,
}
})
assert resp.status_code == 200
data = resp.json()
# 命令中应包含 JWT 的 site_id=100而非前端传的 999
assert "--store-id" in data["command"]
idx = data["command_args"].index("--store-id")
assert data["command_args"][idx + 1] == "100"
def test_validate_invalid_flow(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "nonexistent_flow",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is False
assert any("无效的执行流程" in e for e in data["errors"])
def test_validate_empty_tasks(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": [],
"pipeline": "api_ods",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is False
assert any("任务列表不能为空" in e for e in data["errors"])
def test_validate_custom_window(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
"window_mode": "custom",
"window_start": "2024-01-01",
"window_end": "2024-01-31",
}
})
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
assert "--window-start" in data["command"]
assert "--window-end" in data["command"]
def test_validate_window_end_before_start_rejected(self):
"""window_end 早于 window_start 时 Pydantic 验证失败 → 422"""
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
"window_mode": "custom",
"window_start": "2024-12-31",
"window_end": "2024-01-01",
}
})
assert resp.status_code == 422
def test_validate_dry_run_flag(self):
resp = client.post("/api/tasks/validate", json={
"config": {
"tasks": ["ODS_MEMBER"],
"pipeline": "api_ods",
"dry_run": True,
}
})
assert resp.status_code == 200
data = resp.json()
assert "--dry-run" in data["command"]
# ---------------------------------------------------------------------------
# task_registry 服务层测试
# ---------------------------------------------------------------------------
class TestTaskRegistryService:
def test_all_tasks_have_unique_codes(self):
codes = [t.code for t in ALL_TASKS]
assert len(codes) == len(set(codes))
def test_grouped_tasks_cover_all(self):
grouped = get_tasks_grouped_by_domain()
all_codes = set()
for tasks in grouped.values():
for t in tasks:
all_codes.add(t.code)
assert all_codes == {t.code for t in ALL_TASKS}
def test_ods_tasks_marked_is_ods(self):
for t in ALL_TASKS:
if t.layer == "ODS":
assert t.is_ods is True
def test_flow_layer_map_covers_all_flows(self):
expected_flows = {"api_ods", "api_ods_dwd", "api_full", "ods_dwd",
"dwd_dws", "dwd_dws_index", "dwd_index"}
assert set(FLOW_LAYER_MAP.keys()) == expected_flows