Files
Neo-ZQYY/apps/backend/tests/test_task_executor.py

374 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""TaskExecutor 单元测试
覆盖子进程启动、stdout/stderr 读取、日志广播、取消、数据库记录。
使用 asyncio 测试mock 子进程和数据库连接避免外部依赖。
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.task_executor import TaskExecutor
@pytest.fixture
def executor() -> TaskExecutor:
    """Provide a brand-new ``TaskExecutor`` to each test that requests it."""
    fresh_executor = TaskExecutor()
    return fresh_executor
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Provide a small task config: two ODS tasks on the ``api_ods_dwd`` flow for store 42."""
    task_names = ["ODS_MEMBER", "ODS_PAYMENT"]
    config = TaskConfigSchema(
        tasks=task_names,
        flow="api_ods_dwd",
        store_id=42,
    )
    return config
def _make_stream(lines: list[bytes]) -> AsyncMock:
"""构造一个模拟的 asyncio.StreamReader按行返回数据。"""
stream = AsyncMock()
# readline 依次返回每行,最后返回 b"" 表示 EOF
stream.readline = AsyncMock(side_effect=[*lines, b""])
return stream
# ---------------------------------------------------------------------------
# Subscribe / unsubscribe
# ---------------------------------------------------------------------------
class TestSubscription:
    """Bookkeeping of per-execution subscriber queues in ``_subscribers``."""

    def test_subscribe_returns_queue(self, executor: TaskExecutor):
        assert isinstance(executor.subscribe("exec-1"), asyncio.Queue)

    def test_subscribe_multiple(self, executor: TaskExecutor):
        first = executor.subscribe("exec-1")
        second = executor.subscribe("exec-1")
        # Every subscribe call hands out its own queue.
        assert first is not second
        assert len(executor._subscribers["exec-1"]) == 2

    def test_unsubscribe_removes_queue(self, executor: TaskExecutor):
        queue = executor.subscribe("exec-1")
        executor.unsubscribe("exec-1", queue)
        # Removing the last subscriber also drops the execution's key.
        assert "exec-1" not in executor._subscribers

    def test_unsubscribe_nonexistent_is_safe(self, executor: TaskExecutor):
        """Unsubscribing from an unknown execution_id must not raise."""
        orphan: asyncio.Queue = asyncio.Queue()
        executor.unsubscribe("nonexistent", orphan)
# ---------------------------------------------------------------------------
# Broadcast
# ---------------------------------------------------------------------------
class TestBroadcast:
    """Fan-out of messages via ``_broadcast`` and ``_broadcast_end``."""

    def test_broadcast_to_subscribers(self, executor: TaskExecutor):
        listeners = [executor.subscribe("exec-1") for _ in range(2)]
        executor._broadcast("exec-1", "hello")
        for listener in listeners:
            assert listener.get_nowait() == "hello"

    def test_broadcast_no_subscribers_is_safe(self, executor: TaskExecutor):
        """Broadcasting when nobody is subscribed must not raise."""
        executor._broadcast("nonexistent", "hello")

    def test_broadcast_end_sends_none(self, executor: TaskExecutor):
        listener = executor.subscribe("exec-1")
        executor._broadcast_end("exec-1")
        # None is the end-of-stream sentinel pushed to subscribers.
        assert listener.get_nowait() is None
# ---------------------------------------------------------------------------
# Log buffer
# ---------------------------------------------------------------------------
class TestLogBuffer:
    """``get_logs`` reads the per-execution ``_log_buffers`` entry."""

    def test_get_logs_empty(self, executor: TaskExecutor):
        assert executor.get_logs("nonexistent") == []

    def test_get_logs_returns_copy(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = ["line1", "line2"]

        snapshot = executor.get_logs("exec-1")
        assert snapshot == ["line1", "line2"]

        # Mutating the returned list must leave the executor's buffer alone.
        snapshot.append("line3")
        assert len(executor._log_buffers["exec-1"]) == 2
# ---------------------------------------------------------------------------
# Execution state queries
# ---------------------------------------------------------------------------
class TestRunningState:
    """``is_running`` / ``get_running_ids`` judged by each process's ``returncode``."""

    @staticmethod
    def _fake_process(returncode) -> MagicMock:
        """Return a stand-in process with the given ``returncode`` (None means still running)."""
        fake = MagicMock()
        fake.returncode = returncode
        return fake

    def test_is_running_false_when_no_process(self, executor: TaskExecutor):
        assert executor.is_running("nonexistent") is False

    def test_is_running_true_when_process_active(self, executor: TaskExecutor):
        executor._processes["exec-1"] = self._fake_process(None)
        assert executor.is_running("exec-1") is True

    def test_is_running_false_when_process_exited(self, executor: TaskExecutor):
        executor._processes["exec-1"] = self._fake_process(0)
        assert executor.is_running("exec-1") is False

    def test_get_running_ids(self, executor: TaskExecutor):
        executor._processes["a"] = self._fake_process(None)  # still running
        executor._processes["b"] = self._fake_process(0)  # already exited
        assert executor.get_running_ids() == ["a"]
# ---------------------------------------------------------------------------
# _read_stream
# ---------------------------------------------------------------------------
class TestReadStream:
    """``_read_stream`` feeds decoded lines to the collector, log buffer and subscribers."""

    @pytest.mark.asyncio
    async def test_read_stdout_lines(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        fake_stdout = _make_stream([b"line1\n", b"line2\n"])
        collected: list[str] = []

        await executor._read_stream("exec-1", fake_stdout, "stdout", collected)

        # The collector receives the bare lines (trailing newline removed) ...
        assert collected == ["line1", "line2"]
        # ... while the log buffer receives them tagged with the stream name.
        assert executor._log_buffers["exec-1"] == [
            "[stdout] line1",
            "[stdout] line2",
        ]

    @pytest.mark.asyncio
    async def test_read_stderr_lines(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        fake_stderr = _make_stream([b"err1\n"])
        collected: list[str] = []

        await executor._read_stream("exec-1", fake_stderr, "stderr", collected)

        assert collected == ["err1"]
        assert executor._log_buffers["exec-1"] == ["[stderr] err1"]

    @pytest.mark.asyncio
    async def test_read_stream_none_is_safe(self, executor: TaskExecutor):
        """A ``None`` stream must not raise and must collect nothing."""
        collected: list[str] = []
        await executor._read_stream("exec-1", None, "stdout", collected)
        assert collected == []

    @pytest.mark.asyncio
    async def test_broadcast_during_read(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = []
        listener = executor.subscribe("exec-1")
        fake_stdout = _make_stream([b"hello\n"])
        collected: list[str] = []

        await executor._read_stream("exec-1", fake_stdout, "stdout", collected)

        # Tagged lines are also pushed to subscribers.
        assert listener.get_nowait() == "[stdout] hello"
# ---------------------------------------------------------------------------
# execute (integration-level: the subprocess and database are mocked)
# ---------------------------------------------------------------------------
class TestExecute:
    """Full ``execute`` runs against a fake subprocess and patched log writers.

    ``asyncio.create_subprocess_exec`` is patched to return a fake process, and
    ``_write_execution_log`` / ``_update_execution_log`` are patched so each
    test can inspect the keyword arguments that would have been persisted.
    """

    @staticmethod
    def _make_process(
        stdout_lines: list[bytes],
        stderr_lines: list[bytes],
        exit_code: int,
    ) -> AsyncMock:
        """Build a fake subprocess whose ``wait()`` finishes with *exit_code*.

        ``returncode`` starts as ``None`` and is only set to *exit_code* once
        ``wait()`` is awaited, like a real ``asyncio.subprocess.Process``.
        """
        proc = AsyncMock()
        proc.returncode = None
        proc.stdout = _make_stream(stdout_lines)
        proc.stderr = _make_stream(stderr_lines)

        async def _wait() -> int:
            proc.returncode = exit_code
            return exit_code

        proc.wait = _wait
        return proc

    @staticmethod
    def _last_update_kwargs(mock_update_log) -> dict:
        """Return the kwargs of the latest ``_update_execution_log`` call.

        Asserts the mock was called first, so a missing update fails with a
        readable assertion instead of ``TypeError`` from indexing ``None``.
        """
        mock_update_log.assert_called()
        return mock_update_log.call_args[1]

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_successful_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        mock_create.return_value = self._make_process(
            [b"processing...\n", b"done\n"], [], exit_code=0
        )

        await executor.execute(sample_config, "exec-1", site_id=42)

        # A single "running" record is written for this execution.
        mock_write_log.assert_called_once()
        call_kwargs = mock_write_log.call_args[1]
        assert call_kwargs["status"] == "running"
        assert call_kwargs["execution_id"] == "exec-1"

        # That record is then updated once to "success" with the captured stdout.
        mock_update_log.assert_called_once()
        update_kwargs = mock_update_log.call_args[1]
        assert update_kwargs["status"] == "success"
        assert update_kwargs["exit_code"] == 0
        assert "processing..." in update_kwargs["output_log"]
        assert "done" in update_kwargs["output_log"]

        # The finished process is dropped from the tracking table.
        assert "exec-1" not in executor._processes

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_failed_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        mock_create.return_value = self._make_process(
            [], [b"error occurred\n"], exit_code=1
        )

        await executor.execute(sample_config, "exec-2", site_id=42)

        update_kwargs = self._last_update_kwargs(mock_update_log)
        assert update_kwargs["status"] == "failed"
        assert update_kwargs["exit_code"] == 1
        assert "error occurred" in update_kwargs["error_log"]

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_exception_during_execution(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        """A failure to spawn the subprocess must be recorded as "failed"."""
        mock_create.side_effect = OSError("command not found")

        await executor.execute(sample_config, "exec-3", site_id=42)

        update_kwargs = self._last_update_kwargs(mock_update_log)
        assert update_kwargs["status"] == "failed"

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_subscribers_notified_on_completion(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        mock_create.return_value = self._make_process([b"line\n"], [], exit_code=0)
        q = executor.subscribe("exec-4")

        await executor.execute(sample_config, "exec-4", site_id=42)

        # Expect the log line plus the None end-of-stream sentinel.
        messages = []
        while not q.empty():
            messages.append(q.get_nowait())
        assert "[stdout] line" in messages
        assert None in messages

    @pytest.mark.asyncio
    @patch("app.services.task_executor.TaskExecutor._update_execution_log")
    @patch("app.services.task_executor.TaskExecutor._write_execution_log")
    @patch("asyncio.create_subprocess_exec")
    async def test_duration_ms_recorded(
        self, mock_create, mock_write_log, mock_update_log,
        executor: TaskExecutor, sample_config: TaskConfigSchema,
    ):
        mock_create.return_value = self._make_process([], [], exit_code=0)

        await executor.execute(sample_config, "exec-5", site_id=42)

        update_kwargs = self._last_update_kwargs(mock_update_log)
        assert isinstance(update_kwargs["duration_ms"], int)
        assert update_kwargs["duration_ms"] >= 0
# ---------------------------------------------------------------------------
# cancel
# ---------------------------------------------------------------------------
class TestCancel:
    """``cancel`` returns True only when it actually terminated a live process."""

    @pytest.mark.asyncio
    async def test_cancel_running_process(self, executor: TaskExecutor):
        live = MagicMock()
        live.returncode = None
        live.terminate = MagicMock()
        executor._processes["exec-1"] = live

        outcome = await executor.cancel("exec-1")

        assert outcome is True
        live.terminate.assert_called_once()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_returns_false(self, executor: TaskExecutor):
        assert await executor.cancel("nonexistent") is False

    @pytest.mark.asyncio
    async def test_cancel_already_exited_returns_false(self, executor: TaskExecutor):
        finished = MagicMock()
        finished.returncode = 0
        executor._processes["exec-1"] = finished

        assert await executor.cancel("exec-1") is False

    @pytest.mark.asyncio
    async def test_cancel_process_lookup_error(self, executor: TaskExecutor):
        """If the process vanished, terminate raises ProcessLookupError and cancel returns False."""
        vanished = MagicMock()
        vanished.returncode = None
        vanished.terminate = MagicMock(side_effect=ProcessLookupError)
        executor._processes["exec-1"] = vanished

        assert await executor.cancel("exec-1") is False
# ---------------------------------------------------------------------------
# cleanup
# ---------------------------------------------------------------------------
class TestCleanup:
    """``cleanup`` forgets every per-execution record kept by the executor."""

    def test_cleanup_removes_buffers_and_subscribers(self, executor: TaskExecutor):
        executor._log_buffers["exec-1"] = ["line"]
        executor.subscribe("exec-1")

        executor.cleanup("exec-1")

        for registry in (executor._log_buffers, executor._subscribers):
            assert "exec-1" not in registry

    def test_cleanup_nonexistent_is_safe(self, executor: TaskExecutor):
        """Cleaning up an unknown execution_id must not raise."""
        executor.cleanup("nonexistent")