# -*- coding: utf-8 -*-
"""Unit tests for the execution and queue routers.

Covers 8 endpoints: run / queue CRUD / cancel / history / logs.
The database and the service layer are bypassed with mocks so that these
tests focus on routing logic only.
"""

import os

# Set before ``app`` is imported below; presumably the app reads the JWT
# secret at import time (TODO confirm), so this must stay above those imports.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from dataclasses import dataclass
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_queue import QueuedTask

# Fixed test user. site_id=100 is the value the assertions below expect the
# routers to forward to the service layer / SQL parameters.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Stand-in for ``get_current_user`` that always returns ``_TEST_USER``."""
    return _TEST_USER


# Installed once at import time for the whole module; ``test_run_requires_auth``
# removes it temporarily and restores it in a ``finally`` block.
app.dependency_overrides[get_current_user] = _override_auth

client = TestClient(app)

_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)

# TaskConfig payload used by the run / enqueue tests.
_VALID_CONFIG = {
    "tasks": ["ODS_MEMBER"],
    "pipeline": "api_ods",
}


def _install_db_mock(mock_get_conn, mock_cursor):
    """Wire a patched ``get_connection`` so it hands out *mock_cursor*.

    The connection returned by ``mock_get_conn()`` has a ``cursor()`` whose
    result works as a context manager: ``__enter__`` yields *mock_cursor*
    and ``__exit__`` returns False, so exceptions are not suppressed.

    Returns:
        The mock connection, in case a test wants to inspect it.
    """
    mock_conn = MagicMock()
    mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
    mock_get_conn.return_value = mock_conn
    return mock_conn


# ---------------------------------------------------------------------------
# POST /api/execution/run
# ---------------------------------------------------------------------------

class TestRunTask:

    @patch("app.routers.execution.task_executor")
    def test_run_returns_execution_id(self, mock_executor):
        mock_executor.execute = AsyncMock()
        resp = client.post("/api/execution/run", json=_VALID_CONFIG)
        assert resp.status_code == 200
        data = resp.json()
        assert "execution_id" in data
        assert data["message"] == "任务已提交执行"

    @patch("app.routers.execution.task_executor")
    def test_run_injects_store_id(self, mock_executor):
        """store_id should be injected from the JWT."""
        # NOTE(review): this only checks that a client-supplied store_id is
        # accepted; it does not verify that the value is actually replaced
        # by the JWT's value before reaching the executor.
        mock_executor.execute = AsyncMock()
        resp = client.post("/api/execution/run", json={
            **_VALID_CONFIG,
            "store_id": 999,  # the value sent by the frontend should be overridden
        })
        assert resp.status_code == 200

    def test_run_requires_auth(self):
        app.dependency_overrides.pop(get_current_user, None)
        try:
            resp = client.post("/api/execution/run", json=_VALID_CONFIG)
            assert resp.status_code in (401, 403)
        finally:
            app.dependency_overrides[get_current_user] = _override_auth

    def test_run_invalid_config_returns_422(self):
        """Omitting the required ``tasks`` field returns 422."""
        resp = client.post("/api/execution/run", json={"pipeline": "api_ods"})
        assert resp.status_code == 422


# ---------------------------------------------------------------------------
# GET /api/execution/queue
# ---------------------------------------------------------------------------

class TestGetQueue:

    @patch("app.routers.execution.task_queue")
    def test_get_queue_returns_list(self, mock_queue):
        mock_queue.list_pending.return_value = [
            QueuedTask(
                id="task-1",
                site_id=100,
                config={"tasks": ["ODS_MEMBER"]},
                status="pending",
                position=1,
                created_at=_NOW,
            ),
        ]
        resp = client.get("/api/execution/queue")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "task-1"
        assert data[0]["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_get_queue_empty(self, mock_queue):
        mock_queue.list_pending.return_value = []
        resp = client.get("/api/execution/queue")
        assert resp.status_code == 200
        assert resp.json() == []

    @patch("app.routers.execution.task_queue")
    def test_get_queue_filters_by_site_id(self, mock_queue):
        """list_pending must be called with the JWT user's site_id."""
        mock_queue.list_pending.return_value = []
        client.get("/api/execution/queue")
        mock_queue.list_pending.assert_called_once_with(100)


# ---------------------------------------------------------------------------
# POST /api/execution/queue
# ---------------------------------------------------------------------------

class TestEnqueueTask:

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_queue")
    def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
        mock_queue.enqueue.return_value = "new-task-id"
        # Mocked database query result.
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = (
            "new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}', "pending", 1,
            _NOW, None, None, None, None,
        )
        _install_db_mock(mock_get_conn, mock_cursor)

        resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
        assert resp.status_code == 201
        data = resp.json()
        assert data["id"] == "new-task-id"
        assert data["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_enqueue_calls_with_site_id(self, mock_queue):
        """enqueue must receive the site_id taken from the JWT."""
        mock_queue.enqueue.return_value = "id-1"
        # Give the follow-up DB query a valid row so the request can finish
        # normally; only the enqueue call itself is asserted in this test.
        with patch("app.routers.execution.get_connection") as mock_get_conn:
            mock_cursor = MagicMock()
            mock_cursor.fetchone.return_value = (
                "id-1", 100, '{"tasks": []}', "pending", 1,
                _NOW, None, None, None, None,
            )
            _install_db_mock(mock_get_conn, mock_cursor)
            client.post("/api/execution/queue", json=_VALID_CONFIG)

        # The second positional argument of enqueue is site_id=100.
        call_args = mock_queue.enqueue.call_args
        assert call_args[0][1] == 100  # site_id


# ---------------------------------------------------------------------------
# PUT /api/execution/queue/reorder
# ---------------------------------------------------------------------------

class TestReorderQueue:

    @patch("app.routers.execution.task_queue")
    def test_reorder_success(self, mock_queue):
        mock_queue.reorder.return_value = None
        resp = client.put("/api/execution/queue/reorder", json={
            "task_id": "task-1",
            "new_position": 3,
        })
        assert resp.status_code == 200
        mock_queue.reorder.assert_called_once_with("task-1", 3, 100)

    def test_reorder_missing_fields_returns_422(self):
        resp = client.put("/api/execution/queue/reorder", json={})
        assert resp.status_code == 422


# ---------------------------------------------------------------------------
# DELETE /api/execution/queue/{id}
# ---------------------------------------------------------------------------

class TestDeleteQueueTask:

    @patch("app.routers.execution.task_queue")
    def test_delete_success(self, mock_queue):
        mock_queue.delete.return_value = True
        resp = client.delete("/api/execution/queue/task-1")
        assert resp.status_code == 200
        mock_queue.delete.assert_called_once_with("task-1", 100)

    @patch("app.routers.execution.task_queue")
    def test_delete_nonexistent_returns_409(self, mock_queue):
        mock_queue.delete.return_value = False
        resp = client.delete("/api/execution/queue/nonexistent")
        assert resp.status_code == 409


# ---------------------------------------------------------------------------
# POST /api/execution/{id}/cancel
# ---------------------------------------------------------------------------

class TestCancelExecution:

    @patch("app.routers.execution.task_executor")
    def test_cancel_success(self, mock_executor):
        mock_executor.cancel = AsyncMock(return_value=True)
        resp = client.post("/api/execution/some-id/cancel")
        assert resp.status_code == 200
        assert "取消" in resp.json()["message"]

    @patch("app.routers.execution.task_executor")
    def test_cancel_not_found(self, mock_executor):
        mock_executor.cancel = AsyncMock(return_value=False)
        resp = client.post("/api/execution/nonexistent/cancel")
        assert resp.status_code == 404


# ---------------------------------------------------------------------------
# GET /api/execution/history
# ---------------------------------------------------------------------------

class TestExecutionHistory:

    @patch("app.routers.execution.get_connection")
    def test_history_returns_list(self, mock_get_conn):
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = [
            (
                "exec-1", 100, ["ODS_MEMBER"], "success",
                _NOW, _NOW, 0, 1234, "python -m cli.main", None,
            ),
        ]
        _install_db_mock(mock_get_conn, mock_cursor)

        resp = client.get("/api/execution/history")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "exec-1"
        assert data[0]["status"] == "success"

    @patch("app.routers.execution.get_connection")
    def test_history_respects_limit(self, mock_get_conn):
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = []
        _install_db_mock(mock_get_conn, mock_cursor)

        resp = client.get("/api/execution/history?limit=10")
        assert resp.status_code == 200
        # The SQL parameters must carry limit=10.
        call_args = mock_cursor.execute.call_args
        assert call_args[0][1] == (100, 10)  # (site_id, limit)

    def test_history_limit_validation(self):
        """An out-of-range limit returns 422."""
        resp = client.get("/api/execution/history?limit=0")
        assert resp.status_code == 422
        resp = client.get("/api/execution/history?limit=999")
        assert resp.status_code == 422

    @patch("app.routers.execution.get_connection")
    def test_history_empty(self, mock_get_conn):
        mock_cursor = MagicMock()
        mock_cursor.fetchall.return_value = []
        _install_db_mock(mock_get_conn, mock_cursor)

        resp = client.get("/api/execution/history")
        assert resp.status_code == 200
        assert resp.json() == []


# ---------------------------------------------------------------------------
# GET /api/execution/{id}/logs
# ---------------------------------------------------------------------------

class TestExecutionLogs:

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_from_db(self, mock_executor, mock_get_conn):
        """Logs of a finished execution are read from the database."""
        mock_executor.is_running.return_value = False
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = ("stdout output", "stderr output")
        _install_db_mock(mock_get_conn, mock_cursor)

        resp = client.get("/api/execution/exec-1/logs")
        assert resp.status_code == 200
        data = resp.json()
        assert data["execution_id"] == "exec-1"
        assert data["output_log"] == "stdout output"
        assert data["error_log"] == "stderr output"

    @patch("app.routers.execution.task_executor")
    def test_logs_from_memory(self, mock_executor):
        """Logs of a running execution come from the executor's in-memory buffer."""
        mock_executor.is_running.return_value = True
        mock_executor.get_logs.return_value = ["line1", "line2"]

        resp = client.get("/api/execution/running-id/logs")
        assert resp.status_code == 200
        data = resp.json()
        assert data["execution_id"] == "running-id"
        assert "line1" in data["output_log"]
        assert "line2" in data["output_log"]

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_not_found(self, mock_executor, mock_get_conn):
        mock_executor.is_running.return_value = False
        mock_cursor = MagicMock()
        mock_cursor.fetchone.return_value = None
        _install_db_mock(mock_get_conn, mock_cursor)

        resp = client.get("/api/execution/nonexistent/logs")
        assert resp.status_code == 404