Files
Neo-ZQYY/apps/backend/tests/test_execution_router.py

340 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
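"""Unit tests for the execution and queue routes.

Covers 8 endpoints: run / queue CRUD / cancel / history / logs.
The database and service layer are bypassed with mocks so the tests focus on
verifying the routing logic.
"""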
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from dataclasses import dataclass
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
from app.services.task_queue import QueuedTask
# Fixed test user; `_override_auth` below returns it in place of real authentication.
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
    """Stand-in for ``get_current_user`` that always yields the fixed test user."""
    return _TEST_USER
# Swap the real auth dependency for the fixed test user on every request.
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
# Fixed timezone-aware (UTC) timestamp reused in mocked rows and queued tasks.
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
# Request body the tests submit as a valid TaskConfig payload.
_VALID_CONFIG = {
    "tasks": ["ODS_MEMBER"],
    "pipeline": "api_ods",
}
# ---------------------------------------------------------------------------
# POST /api/execution/run
# ---------------------------------------------------------------------------
class TestRunTask:
    """Route tests for ``POST /api/execution/run``."""

    @patch("app.routers.execution.task_executor")
    def test_run_returns_execution_id(self, mock_executor):
        """A valid config is accepted and the response carries an execution id."""
        mock_executor.execute = AsyncMock()

        response = client.post("/api/execution/run", json=_VALID_CONFIG)

        assert response.status_code == 200
        body = response.json()
        assert "execution_id" in body
        assert body["message"] == "任务已提交执行"

    @patch("app.routers.execution.task_executor")
    def test_run_injects_store_id(self, mock_executor):
        """store_id should be injected from the JWT rather than taken from the body."""
        mock_executor.execute = AsyncMock()
        # The client-supplied store_id is expected to be overwritten server-side.
        payload = dict(_VALID_CONFIG)
        payload["store_id"] = 999

        response = client.post("/api/execution/run", json=payload)

        # NOTE(review): only the status code is checked here; the injected
        # store_id itself is not asserted against the executor call.
        assert response.status_code == 200

    def test_run_requires_auth(self):
        """Without the auth override the endpoint rejects the request."""
        # Temporarily remove the override so the real auth dependency runs.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            response = client.post("/api/execution/run", json=_VALID_CONFIG)
            assert response.status_code in (401, 403)
        finally:
            # Always restore the override so later tests stay authenticated.
            app.dependency_overrides[get_current_user] = _override_auth

    def test_run_invalid_config_returns_422(self):
        """Omitting the required ``tasks`` field yields 422."""
        incomplete = {"pipeline": "api_ods"}
        response = client.post("/api/execution/run", json=incomplete)
        assert response.status_code == 422
# ---------------------------------------------------------------------------
# GET /api/execution/queue
# ---------------------------------------------------------------------------
class TestGetQueue:
    """Route tests for ``GET /api/execution/queue``."""

    @patch("app.routers.execution.task_queue")
    def test_get_queue_returns_list(self, mock_queue):
        """A single pending task is serialised into a one-element list."""
        pending_task = QueuedTask(
            id="task-1",
            site_id=100,
            config={"tasks": ["ODS_MEMBER"]},
            status="pending",
            position=1,
            created_at=_NOW,
        )
        mock_queue.list_pending.return_value = [pending_task]

        response = client.get("/api/execution/queue")

        assert response.status_code == 200
        items = response.json()
        assert len(items) == 1
        first = items[0]
        assert first["id"] == "task-1"
        assert first["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_get_queue_empty(self, mock_queue):
        """An empty queue is returned as an empty JSON list."""
        mock_queue.list_pending.return_value = []

        response = client.get("/api/execution/queue")

        assert response.status_code == 200
        assert response.json() == []

    @patch("app.routers.execution.task_queue")
    def test_get_queue_filters_by_site_id(self, mock_queue):
        """``list_pending`` is called with the test user's site_id."""
        mock_queue.list_pending.return_value = []

        client.get("/api/execution/queue")

        mock_queue.list_pending.assert_called_once_with(100)
# ---------------------------------------------------------------------------
# POST /api/execution/queue
# ---------------------------------------------------------------------------
class TestEnqueueTask:
    """Route tests for ``POST /api/execution/queue``."""

    @staticmethod
    def _connection_returning(row):
        """Build a mock DB connection for the post-enqueue lookup.

        The returned connection's ``cursor()`` works as a context manager and
        yields a cursor whose ``fetchone()`` returns ``row``.
        """
        cursor = MagicMock()
        cursor.fetchone.return_value = row
        connection = MagicMock()
        # MagicMock pre-configures __enter__/__exit__, so only the return
        # values need to be set.
        connection.cursor.return_value.__enter__.return_value = cursor
        connection.cursor.return_value.__exit__.return_value = False
        return connection

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_queue")
    def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
        """Enqueueing a valid config returns 201 with the stored task."""
        mock_queue.enqueue.return_value = "new-task-id"
        # Row handed back by the mocked database query.
        mock_get_conn.return_value = self._connection_returning((
            "new-task-id", 100, '{"tasks": ["ODS_MEMBER"]}',
            "pending", 1, _NOW, None, None, None, None,
        ))

        resp = client.post("/api/execution/queue", json=_VALID_CONFIG)

        assert resp.status_code == 201
        data = resp.json()
        assert data["id"] == "new-task-id"
        assert data["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_enqueue_calls_with_site_id(self, mock_queue):
        """``enqueue`` receives the site_id taken from the JWT."""
        mock_queue.enqueue.return_value = "id-1"
        # The DB lookup is mocked to succeed so the request can complete; this
        # test only inspects the arguments passed to enqueue().
        with patch("app.routers.execution.get_connection") as mock_get_conn:
            mock_get_conn.return_value = self._connection_returning((
                "id-1", 100, '{"tasks": []}', "pending", 1,
                _NOW, None, None, None, None,
            ))
            client.post("/api/execution/queue", json=_VALID_CONFIG)

        # Fail with a clear assertion (not a TypeError on None) if the route
        # never reached enqueue().
        mock_queue.enqueue.assert_called()
        # The second positional argument of enqueue() is the site_id.
        assert mock_queue.enqueue.call_args.args[1] == 100
# ---------------------------------------------------------------------------
# PUT /api/execution/queue/reorder
# ---------------------------------------------------------------------------
class TestReorderQueue:
    """Route tests for ``PUT /api/execution/queue/reorder``."""

    @patch("app.routers.execution.task_queue")
    def test_reorder_success(self, mock_queue):
        """A valid reorder request is forwarded with task id, position and site_id."""
        mock_queue.reorder.return_value = None
        body = {"task_id": "task-1", "new_position": 3}

        response = client.put("/api/execution/queue/reorder", json=body)

        assert response.status_code == 200
        mock_queue.reorder.assert_called_once_with("task-1", 3, 100)

    def test_reorder_missing_fields_returns_422(self):
        """An empty body fails request validation."""
        response = client.put("/api/execution/queue/reorder", json={})
        assert response.status_code == 422
# ---------------------------------------------------------------------------
# DELETE /api/execution/queue/{id}
# ---------------------------------------------------------------------------
class TestDeleteQueueTask:
    """Route tests for ``DELETE /api/execution/queue/{id}``."""

    @patch("app.routers.execution.task_queue")
    def test_delete_success(self, mock_queue):
        """When ``delete`` reports True the route answers 200."""
        mock_queue.delete.return_value = True

        response = client.delete("/api/execution/queue/task-1")

        assert response.status_code == 200
        mock_queue.delete.assert_called_once_with("task-1", 100)

    @patch("app.routers.execution.task_queue")
    def test_delete_nonexistent_returns_409(self, mock_queue):
        """When ``delete`` reports False the route answers 409."""
        mock_queue.delete.return_value = False

        response = client.delete("/api/execution/queue/nonexistent")

        assert response.status_code == 409
# ---------------------------------------------------------------------------
# POST /api/execution/{id}/cancel
# ---------------------------------------------------------------------------
class TestCancelExecution:
    """Route tests for ``POST /api/execution/{id}/cancel``."""

    @patch("app.routers.execution.task_executor")
    def test_cancel_success(self, mock_executor):
        """A successful cancel returns 200 and a message mentioning cancellation."""
        mock_executor.cancel = AsyncMock(return_value=True)

        response = client.post("/api/execution/some-id/cancel")

        assert response.status_code == 200
        message = response.json()["message"]
        assert "取消" in message

    @patch("app.routers.execution.task_executor")
    def test_cancel_not_found(self, mock_executor):
        """When the executor reports False the route answers 404."""
        mock_executor.cancel = AsyncMock(return_value=False)

        response = client.post("/api/execution/nonexistent/cancel")

        assert response.status_code == 404
# ---------------------------------------------------------------------------
# GET /api/execution/history
# ---------------------------------------------------------------------------
class TestExecutionHistory:
    """Route tests for ``GET /api/execution/history``."""

    @staticmethod
    def _install_rows(mock_get_conn, rows):
        """Wire the patched ``get_connection`` to a mock connection.

        The connection's ``cursor()`` works as a context manager and yields a
        cursor whose ``fetchall()`` returns ``rows``. The cursor mock is
        returned so tests can inspect the ``execute`` call.
        """
        cursor = MagicMock()
        cursor.fetchall.return_value = rows
        connection = MagicMock()
        # MagicMock pre-configures __enter__/__exit__, so only the return
        # values need to be set.
        connection.cursor.return_value.__enter__.return_value = cursor
        connection.cursor.return_value.__exit__.return_value = False
        mock_get_conn.return_value = connection
        return cursor

    @patch("app.routers.execution.get_connection")
    def test_history_returns_list(self, mock_get_conn):
        """One database row is serialised into a one-element list."""
        self._install_rows(mock_get_conn, [
            (
                "exec-1", 100, ["ODS_MEMBER"], "success",
                _NOW, _NOW, 0, 1234, "python -m cli.main", None,
            ),
        ])

        resp = client.get("/api/execution/history")

        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "exec-1"
        assert data[0]["status"] == "success"

    @patch("app.routers.execution.get_connection")
    def test_history_respects_limit(self, mock_get_conn):
        """The ``limit`` query parameter is passed through to the SQL parameters."""
        mock_cursor = self._install_rows(mock_get_conn, [])

        resp = client.get("/api/execution/history?limit=10")

        assert resp.status_code == 200
        # Fail with a clear assertion (not a TypeError on None) if no query
        # was executed at all.
        mock_cursor.execute.assert_called()
        # The SQL parameters must be (site_id, limit).
        assert mock_cursor.execute.call_args.args[1] == (100, 10)

    def test_history_limit_validation(self):
        """Out-of-range ``limit`` values yield 422."""
        resp = client.get("/api/execution/history?limit=0")
        assert resp.status_code == 422
        resp = client.get("/api/execution/history?limit=999")
        assert resp.status_code == 422

    @patch("app.routers.execution.get_connection")
    def test_history_empty(self, mock_get_conn):
        """No rows in the database yields an empty JSON list."""
        self._install_rows(mock_get_conn, [])

        resp = client.get("/api/execution/history")

        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/execution/{id}/logs
# ---------------------------------------------------------------------------
class TestExecutionLogs:
    """Route tests for ``GET /api/execution/{id}/logs``."""

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_from_db(self, mock_executor, mock_get_conn):
        """Logs of a task that is not running are read from the database."""
        mock_executor.is_running.return_value = False
        db_cursor = MagicMock()
        db_cursor.fetchone.return_value = ("stdout output", "stderr output")
        db_conn = MagicMock()
        cursor_ctx = db_conn.cursor.return_value
        cursor_ctx.__enter__.return_value = db_cursor
        cursor_ctx.__exit__.return_value = False
        mock_get_conn.return_value = db_conn

        response = client.get("/api/execution/exec-1/logs")

        assert response.status_code == 200
        payload = response.json()
        assert payload["execution_id"] == "exec-1"
        assert payload["output_log"] == "stdout output"
        assert payload["error_log"] == "stderr output"

    @patch("app.routers.execution.task_executor")
    def test_logs_from_memory(self, mock_executor):
        """Logs of a running task are read from the executor's in-memory buffer."""
        mock_executor.is_running.return_value = True
        mock_executor.get_logs.return_value = ["line1", "line2"]

        response = client.get("/api/execution/running-id/logs")

        assert response.status_code == 200
        payload = response.json()
        assert payload["execution_id"] == "running-id"
        output = payload["output_log"]
        assert "line1" in output
        assert "line2" in output

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_not_found(self, mock_executor, mock_get_conn):
        """A task that is not running and has no database row yields 404."""
        mock_executor.is_running.return_value = False
        db_cursor = MagicMock()
        db_cursor.fetchone.return_value = None
        db_conn = MagicMock()
        cursor_ctx = db_conn.cursor.return_value
        cursor_ctx.__enter__.return_value = db_cursor
        cursor_ctx.__exit__.return_value = False
        mock_get_conn.return_value = db_conn

        response = client.get("/api/execution/nonexistent/logs")

        assert response.status_code == 404