340 lines
13 KiB
Python
340 lines
13 KiB
Python
# -*- coding: utf-8 -*-
"""Unit tests for the execution and queue routers.

Covers 8 endpoints: run / queue CRUD / cancel / history / logs, i.e.

    POST   /api/execution/run
    GET    /api/execution/queue
    POST   /api/execution/queue
    PUT    /api/execution/queue/reorder
    DELETE /api/execution/queue/{id}
    POST   /api/execution/{id}/cancel
    GET    /api/execution/history
    GET    /api/execution/{id}/logs

The database and service layer are bypassed with mocks so the tests focus on
verifying router logic.
"""
|
||
|
||
import os
|
||
|
||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||
|
||
from dataclasses import dataclass
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.main import app
|
||
from app.services.task_queue import QueuedTask
|
||
|
||
# Fixed test user handed out by the auth override. Its site_id (100) is the
# value the tests below expect the router to forward to the service layer.
_TEST_USER = CurrentUser(
    user_id=1,
    site_id=100,
)
|
||
|
||
|
||
def _override_auth():
    """Stand in for ``get_current_user``, bypassing real auth.

    Always returns the fixed ``_TEST_USER``.
    """
    user = _TEST_USER
    return user
|
||
|
||
|
||
@pytest.fixture(autouse=True)
def _authenticate_as_test_user():
    """Authenticate every test in this module as ``_TEST_USER``.

    ``app.dependency_overrides`` belongs to the shared ``app`` object.
    Assigning the override at module import time would leave it active for the
    rest of the pytest session, because pytest imports every test module during
    collection. Tests in other files that exercise real authentication would
    then silently run as ``_TEST_USER``. Installing it per test and restoring
    the prior state afterwards keeps the override scoped to this module.
    """
    missing = object()  # sentinel: distinguishes "no override" from any value
    previous = app.dependency_overrides.get(get_current_user, missing)
    app.dependency_overrides[get_current_user] = _override_auth
    yield
    if previous is missing:
        app.dependency_overrides.pop(get_current_user, None)
    else:
        app.dependency_overrides[get_current_user] = previous


# Shared synchronous client for the FastAPI app under test.
client = TestClient(app)
|
||
|
||
_NOW = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||
|
||
# 构造测试用的 TaskConfig payload
|
||
_VALID_CONFIG = {
|
||
"tasks": ["ODS_MEMBER"],
|
||
"flow": "api_ods",
|
||
}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# POST /api/execution/run
# ---------------------------------------------------------------------------

class TestRunTask:
    """Submitting a task config for immediate execution."""

    @patch("app.routers.execution.task_executor")
    def test_run_returns_execution_id(self, mock_executor):
        """A valid config is accepted and an execution id is returned."""
        mock_executor.execute = AsyncMock()

        response = client.post("/api/execution/run", json=_VALID_CONFIG)
        assert response.status_code == 200

        body = response.json()
        assert "execution_id" in body
        assert body["message"] == "任务已提交执行"

    @patch("app.routers.execution.task_executor")
    def test_run_injects_store_id(self, mock_executor):
        """store_id is supposed to come from the JWT, not the request body.

        NOTE(review): only the status code is asserted, so the override of
        the client-supplied value is not actually verified here. Assert on the
        config forwarded to ``task_executor.execute`` once its call signature
        is confirmed.
        """
        mock_executor.execute = AsyncMock()

        payload = dict(_VALID_CONFIG)
        payload["store_id"] = 999  # value sent by the frontend; should be overridden

        response = client.post("/api/execution/run", json=payload)
        assert response.status_code == 200

    def test_run_requires_auth(self):
        """Without the auth override the endpoint rejects the request."""
        app.dependency_overrides.pop(get_current_user, None)
        try:
            response = client.post("/api/execution/run", json=_VALID_CONFIG)
            assert response.status_code in {401, 403}
        finally:
            # Re-install the override so subsequent tests stay authenticated.
            app.dependency_overrides[get_current_user] = _override_auth

    def test_run_invalid_config_returns_422(self):
        """Omitting the required ``tasks`` field fails request validation."""
        response = client.post("/api/execution/run", json={"flow": "api_ods"})
        assert response.status_code == 422
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# GET /api/execution/queue
# ---------------------------------------------------------------------------

class TestGetQueue:
    """Listing the pending queue for the caller's site."""

    @patch("app.routers.execution.task_queue")
    def test_get_queue_returns_list(self, mock_queue):
        """Pending tasks from the queue service are serialized in the response."""
        pending = QueuedTask(
            id="task-1",
            site_id=100,
            config={"tasks": ["ODS_MEMBER"]},
            status="pending",
            position=1,
            created_at=_NOW,
        )
        mock_queue.list_pending.return_value = [pending]

        response = client.get("/api/execution/queue")
        assert response.status_code == 200

        items = response.json()
        assert len(items) == 1
        first = items[0]
        assert first["id"] == "task-1"
        assert first["status"] == "pending"

    @patch("app.routers.execution.task_queue")
    def test_get_queue_empty(self, mock_queue):
        """An empty queue yields an empty JSON list."""
        mock_queue.list_pending.return_value = []

        response = client.get("/api/execution/queue")

        assert response.status_code == 200
        assert response.json() == []

    @patch("app.routers.execution.task_queue")
    def test_get_queue_filters_by_site_id(self, mock_queue):
        """list_pending must be called with the authenticated user's site_id."""
        mock_queue.list_pending.return_value = []

        client.get("/api/execution/queue")

        mock_queue.list_pending.assert_called_once_with(100)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# POST /api/execution/queue
# ---------------------------------------------------------------------------

class TestEnqueueTask:
    """Appending a task config to the caller's queue."""

    @staticmethod
    def _install_row(mock_get_conn, task_id):
        """Make the DB lookup that follows ``enqueue`` return a pending row.

        Wires ``mock_get_conn`` so that ``conn.cursor()`` works as a context
        manager yielding a cursor whose ``fetchone`` returns a row for
        *task_id*. The layout is presumably (id, site_id, config JSON, status,
        position, created_at) followed by four ``None`` columns; verify it
        against the router's SELECT. Returns the cursor mock.
        """
        cursor = MagicMock()
        cursor.fetchone.return_value = (
            task_id, 100, '{"tasks": ["ODS_MEMBER"]}',
            "pending", 1, _NOW, None, None, None, None,
        )
        conn = MagicMock()
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cursor)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn
        return cursor

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_queue")
    def test_enqueue_returns_201(self, mock_queue, mock_get_conn):
        """Enqueueing answers 201 with the stored row."""
        mock_queue.enqueue.return_value = "new-task-id"
        self._install_row(mock_get_conn, "new-task-id")

        resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
        assert resp.status_code == 201

        data = resp.json()
        assert data["id"] == "new-task-id"
        assert data["status"] == "pending"

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_queue")
    def test_enqueue_calls_with_site_id(self, mock_queue, mock_get_conn):
        """enqueue must receive the authenticated user's site_id (100)."""
        mock_queue.enqueue.return_value = "id-1"
        # The endpoint reads the new row back from the DB after enqueueing;
        # serve a valid one so the request completes normally.
        self._install_row(mock_get_conn, "id-1")

        resp = client.post("/api/execution/queue", json=_VALID_CONFIG)
        assert resp.status_code == 201

        # Fail with a clear message if enqueue was never called, instead of a
        # TypeError from indexing a ``None`` call_args.
        mock_queue.enqueue.assert_called_once()
        args, kwargs = mock_queue.enqueue.call_args
        # site_id is passed as the 2nd positional argument; accept the keyword
        # form too so the test does not depend on call style.
        site_id = kwargs["site_id"] if "site_id" in kwargs else args[1]
        assert site_id == 100
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# PUT /api/execution/queue/reorder
# ---------------------------------------------------------------------------

class TestReorderQueue:
    """Moving a queued task to a new position."""

    @patch("app.routers.execution.task_queue")
    def test_reorder_success(self, mock_queue):
        """The move is delegated to the queue service with the caller's site."""
        mock_queue.reorder.return_value = None
        body = {"task_id": "task-1", "new_position": 3}

        response = client.put("/api/execution/queue/reorder", json=body)

        assert response.status_code == 200
        # Arguments: task id, target position, then the test user's site_id.
        mock_queue.reorder.assert_called_once_with("task-1", 3, 100)

    def test_reorder_missing_fields_returns_422(self):
        """An empty body fails request validation."""
        response = client.put("/api/execution/queue/reorder", json={})
        assert response.status_code == 422
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# DELETE /api/execution/queue/{id}
# ---------------------------------------------------------------------------

class TestDeleteQueueTask:
    """Removing a task from the queue."""

    @patch("app.routers.execution.task_queue")
    def test_delete_success(self, mock_queue):
        """A successful delete is scoped to the caller's site_id."""
        mock_queue.delete.return_value = True

        response = client.delete("/api/execution/queue/task-1")

        assert response.status_code == 200
        mock_queue.delete.assert_called_once_with("task-1", 100)

    @patch("app.routers.execution.task_queue")
    def test_delete_nonexistent_returns_409(self, mock_queue):
        """If the queue service reports nothing deleted, the endpoint answers 409."""
        mock_queue.delete.return_value = False

        response = client.delete("/api/execution/queue/nonexistent")

        assert response.status_code == 409
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# POST /api/execution/{id}/cancel
# ---------------------------------------------------------------------------

class TestCancelExecution:
    """Cancelling an execution through the task executor."""

    @patch("app.routers.execution.task_executor")
    def test_cancel_success(self, mock_executor):
        """A successful cancel answers 200 with a message mentioning cancellation."""
        mock_executor.cancel = AsyncMock(return_value=True)

        response = client.post("/api/execution/some-id/cancel")

        assert response.status_code == 200
        message = response.json()["message"]
        assert "取消" in message

    @patch("app.routers.execution.task_executor")
    def test_cancel_not_found(self, mock_executor):
        """If the executor reports nothing cancelled, the endpoint answers 404."""
        mock_executor.cancel = AsyncMock(return_value=False)

        response = client.post("/api/execution/nonexistent/cancel")

        assert response.status_code == 404
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# GET /api/execution/history
# ---------------------------------------------------------------------------

class TestExecutionHistory:
    """Execution history, read from the database via ``get_connection``."""

    @staticmethod
    def _install_cursor(mock_get_conn, rows):
        """Wire ``mock_get_conn`` so the router's cursor ``fetchall`` returns *rows*.

        ``conn.cursor()`` is configured as a context manager yielding the
        cursor mock, which is returned so tests can inspect ``execute`` calls.
        """
        cursor = MagicMock()
        cursor.fetchall.return_value = rows
        conn = MagicMock()
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cursor)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn
        return cursor

    @patch("app.routers.execution.get_connection")
    def test_history_returns_list(self, mock_get_conn):
        """Rows from the DB are serialized as history entries."""
        row = (
            "exec-1", 100, ["ODS_MEMBER"], "success",
            _NOW, _NOW, 0, 1234, "python -m cli.main", None,
        )
        self._install_cursor(mock_get_conn, [row])

        resp = client.get("/api/execution/history")
        assert resp.status_code == 200

        data = resp.json()
        assert len(data) == 1
        assert data[0]["id"] == "exec-1"
        assert data[0]["status"] == "success"

    @patch("app.routers.execution.get_connection")
    def test_history_respects_limit(self, mock_get_conn):
        """The ``limit`` query parameter is bound into the SQL parameters."""
        cursor = self._install_cursor(mock_get_conn, [])

        resp = client.get("/api/execution/history?limit=10")
        assert resp.status_code == 200

        # Guard first so a missing query fails clearly rather than with a
        # TypeError from indexing ``None``.
        call = cursor.execute.call_args
        assert call is not None, "history query was never executed"
        assert call[0][1] == (100, 10)  # (site_id, limit)

    def test_history_limit_validation(self):
        """Out-of-range ``limit`` values are rejected with 422."""
        for bad_limit in (0, 999):
            resp = client.get(f"/api/execution/history?limit={bad_limit}")
            assert resp.status_code == 422, f"limit={bad_limit}"

    @patch("app.routers.execution.get_connection")
    def test_history_empty(self, mock_get_conn):
        """No rows yields an empty JSON list."""
        self._install_cursor(mock_get_conn, [])

        resp = client.get("/api/execution/history")

        assert resp.status_code == 200
        assert resp.json() == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# GET /api/execution/{id}/logs
# ---------------------------------------------------------------------------

class TestExecutionLogs:
    """Log retrieval: in-memory buffer while running, database otherwise."""

    @staticmethod
    def _install_cursor(mock_get_conn, row):
        """Wire ``mock_get_conn`` so the router's cursor ``fetchone`` returns *row*.

        *row* is ``(output_log, error_log)`` for a stored execution, or
        ``None`` to simulate an unknown id. ``conn.cursor()`` is configured
        as a context manager yielding the cursor mock, which is returned.
        """
        cursor = MagicMock()
        cursor.fetchone.return_value = row
        conn = MagicMock()
        conn.cursor.return_value.__enter__ = MagicMock(return_value=cursor)
        conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
        mock_get_conn.return_value = conn
        return cursor

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_from_db(self, mock_executor, mock_get_conn):
        """Logs of a finished execution are read from the database."""
        mock_executor.is_running.return_value = False
        self._install_cursor(mock_get_conn, ("stdout output", "stderr output"))

        resp = client.get("/api/execution/exec-1/logs")
        assert resp.status_code == 200

        data = resp.json()
        assert data["execution_id"] == "exec-1"
        assert data["output_log"] == "stdout output"
        assert data["error_log"] == "stderr output"

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_from_memory(self, mock_executor, mock_get_conn):
        """Logs of a running execution come from the in-memory buffer."""
        mock_executor.is_running.return_value = True
        mock_executor.get_logs.return_value = ["line1", "line2"]

        resp = client.get("/api/execution/running-id/logs")
        assert resp.status_code == 200

        data = resp.json()
        assert data["execution_id"] == "running-id"
        assert "line1" in data["output_log"]
        assert "line2" in data["output_log"]
        # Patching keeps the test hermetic; the running path must not hit the DB.
        mock_get_conn.assert_not_called()

    @patch("app.routers.execution.get_connection")
    @patch("app.routers.execution.task_executor")
    def test_logs_not_found(self, mock_executor, mock_get_conn):
        """An id that is neither running nor stored yields 404."""
        mock_executor.is_running.return_value = False
        self._install_cursor(mock_get_conn, None)

        resp = client.get("/api/execution/nonexistent/logs")

        assert resp.status_code == 404
|