# -*- coding: utf-8 -*-
"""Unit tests for the task registry API.

Covers four endpoints: registry / dwd-tables / flows / validate.
Authentication is bypassed by overriding the ``get_current_user`` dependency
with a fixed test user (a JWT mock).
"""
import os

# Set before importing the app; the original file deliberately orders this
# first, presumably because the app reads the JWT secret at import time.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

import pytest
from fastapi.testclient import TestClient

from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_registry import (
    ALL_TASKS,
    DWD_TABLES,
    FLOW_LAYER_MAP,
    get_tasks_grouped_by_domain,
)

# Fixed test user. Its site_id (100) is the value the validate tests expect
# to see injected as the --store-id argument.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Dependency override that always returns the fixed test user."""
    return _TEST_USER


# Installed once at import time (kept for compatibility); the autouse fixture
# below re-applies it before every test in this module.
app.dependency_overrides[get_current_user] = _override_auth

client = TestClient(app)


@pytest.fixture(autouse=True)
def _ensure_auth_override():
    """Re-install the auth override before every test in this module.

    Other test files may ``clear()`` or ``pop()`` ``app.dependency_overrides``,
    which would silently drop the import-time override above. Re-applying it
    here protects every test class in this module, not just one of them.
    """
    app.dependency_overrides[get_current_user] = _override_auth


def _get_ok_json(path):
    """GET *path*, assert a 200 response, and return the decoded JSON body."""
    resp = client.get(path)
    # Include the body so a non-200 fails with context instead of a KeyError.
    assert resp.status_code == 200, resp.text
    return resp.json()


def _post_validate(config):
    """POST *config* to the validate endpoint, wrapped as ``{"config": ...}``."""
    return client.post("/api/tasks/validate", json={"config": config})


# ---------------------------------------------------------------------------
# GET /api/tasks/registry
# ---------------------------------------------------------------------------

class TestTaskRegistry:
    def test_registry_returns_grouped_tasks(self):
        data = _get_ok_json("/api/tasks/registry")
        assert "groups" in data

        # Every task must appear in some group, filed under its own domain.
        all_codes_in_response = set()
        for domain, tasks in data["groups"].items():
            for t in tasks:
                all_codes_in_response.add(t["code"])
                assert t["domain"] == domain

        expected_codes = {t.code for t in ALL_TASKS}
        assert all_codes_in_response == expected_codes

    def test_registry_task_fields_complete(self):
        """Every task entry contains all required fields."""
        data = _get_ok_json("/api/tasks/registry")
        required_fields = {"code", "name", "description", "domain", "layer",
                           "requires_window", "is_ods", "is_dimension",
                           "default_enabled"}
        for tasks in data["groups"].values():
            for t in tasks:
                missing = required_fields.difference(t)
                assert not missing, (t.get("code"), missing)

    def test_registry_requires_auth(self):
        """Without the auth override, the endpoint returns 401."""
        app.dependency_overrides.pop(get_current_user, None)
        try:
            resp = client.get("/api/tasks/registry")
            assert resp.status_code == 401
        finally:
            app.dependency_overrides[get_current_user] = _override_auth


# ---------------------------------------------------------------------------
# GET /api/tasks/dwd-tables
# ---------------------------------------------------------------------------

class TestDwdTables:
    def test_dwd_tables_returns_grouped(self):
        data = _get_ok_json("/api/tasks/dwd-tables")
        assert "groups" in data

        # Every DWD table must appear in some group, filed under its domain.
        all_tables_in_response = set()
        for domain, tables in data["groups"].items():
            for t in tables:
                all_tables_in_response.add(t["table_name"])
                assert t["domain"] == domain

        expected_tables = {t.table_name for t in DWD_TABLES}
        assert all_tables_in_response == expected_tables

    def test_dwd_tables_fields_complete(self):
        """Every DWD table entry contains all required fields."""
        data = _get_ok_json("/api/tasks/dwd-tables")
        required_fields = {"table_name", "display_name", "domain",
                           "ods_source", "is_dimension"}
        for tables in data["groups"].values():
            for t in tables:
                missing = required_fields.difference(t)
                assert not missing, (t.get("table_name"), missing)


# ---------------------------------------------------------------------------
# GET /api/tasks/flows
# ---------------------------------------------------------------------------

class TestFlows:
    def test_flows_returns_seven_flows(self):
        data = _get_ok_json("/api/tasks/flows")
        assert len(data["flows"]) == 7

    def test_flows_returns_three_processing_modes(self):
        data = _get_ok_json("/api/tasks/flows")
        assert len(data["processing_modes"]) == 3

    def test_flow_ids_match_registry(self):
        """Flow IDs returned by the API match the keys of FLOW_LAYER_MAP."""
        data = _get_ok_json("/api/tasks/flows")
        flow_ids = {f["id"] for f in data["flows"]}
        assert flow_ids == set(FLOW_LAYER_MAP.keys())

    def test_flow_layers_non_empty(self):
        data = _get_ok_json("/api/tasks/flows")
        for f in data["flows"]:
            assert len(f["layers"]) > 0, f["id"]

    def test_processing_mode_ids(self):
        data = _get_ok_json("/api/tasks/flows")
        mode_ids = {m["id"] for m in data["processing_modes"]}
        assert mode_ids == {"increment_only", "verify_only", "increment_verify"}


# ---------------------------------------------------------------------------
# POST /api/tasks/validate
# ---------------------------------------------------------------------------

class TestValidate:
    def test_validate_success(self):
        resp = _post_validate({
            "tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
            "flow": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert data["errors"] == []
        assert len(data["command_args"]) > 0
        assert "--store-id" in data["command"]
        # store_id is injected from the JWT (test user site_id=100). Check the
        # argument value itself, not a "100" substring anywhere in the command.
        idx = data["command_args"].index("--store-id")
        assert data["command_args"][idx + 1] == "100"

    def test_validate_injects_store_id(self):
        """A client-supplied store_id is overridden by the JWT's site_id."""
        resp = _post_validate({
            "tasks": ["DWD_LOAD_FROM_ODS"],
            "flow": "ods_dwd",
            "store_id": 999,
        })
        assert resp.status_code == 200
        data = resp.json()
        # The command must carry the JWT site_id=100, not the client's 999.
        assert "--store-id" in data["command"]
        idx = data["command_args"].index("--store-id")
        assert data["command_args"][idx + 1] == "100"

    def test_validate_invalid_flow(self):
        resp = _post_validate({
            "tasks": ["ODS_MEMBER"],
            "flow": "nonexistent_flow",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("无效的执行流程" in e for e in data["errors"])

    def test_validate_empty_tasks(self):
        resp = _post_validate({
            "tasks": [],
            "flow": "api_ods",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is False
        assert any("任务列表不能为空" in e for e in data["errors"])

    def test_validate_custom_window(self):
        resp = _post_validate({
            "tasks": ["ODS_MEMBER"],
            "flow": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-01-01",
            "window_end": "2024-01-31",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["valid"] is True
        assert "--window-start" in data["command"]
        assert "--window-end" in data["command"]

    def test_validate_window_end_before_start_rejected(self):
        """window_end earlier than window_start fails request validation (422)."""
        resp = _post_validate({
            "tasks": ["ODS_MEMBER"],
            "flow": "api_ods",
            "window_mode": "custom",
            "window_start": "2024-12-31",
            "window_end": "2024-01-01",
        })
        assert resp.status_code == 422

    def test_validate_dry_run_flag(self):
        resp = _post_validate({
            "tasks": ["ODS_MEMBER"],
            "flow": "api_ods",
            "dry_run": True,
        })
        assert resp.status_code == 200
        data = resp.json()
        assert "--dry-run" in data["command"]


# ---------------------------------------------------------------------------
# task_registry service-layer tests
# ---------------------------------------------------------------------------

class TestTaskRegistryService:
    def test_all_tasks_have_unique_codes(self):
        codes = [t.code for t in ALL_TASKS]
        assert len(codes) == len(set(codes))

    def test_grouped_tasks_cover_all(self):
        grouped = get_tasks_grouped_by_domain()
        all_codes = {t.code for tasks in grouped.values() for t in tasks}
        assert all_codes == {t.code for t in ALL_TASKS}

    def test_ods_tasks_marked_is_ods(self):
        for t in ALL_TASKS:
            if t.layer == "ODS":
                assert t.is_ods is True, t.code

    def test_flow_layer_map_covers_all_flows(self):
        expected_flows = {"api_ods", "api_ods_dwd", "api_full", "ods_dwd",
                          "dwd_dws", "dwd_dws_index", "dwd_index"}
        assert set(FLOW_LAYER_MAP.keys()) == expected_flows