Files
Neo-ZQYY/apps/backend/tests/test_task_queue.py

483 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""TaskQueue 单元测试
覆盖enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
使用 mock 数据库操作,专注于业务逻辑验证。
"""
import asyncio
import json
import uuid
from unittest.mock import MagicMock, AsyncMock, patch, call
import pytest
from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask
@pytest.fixture
def queue() -> TaskQueue:
    """Provide a freshly constructed ``TaskQueue`` for each test."""
    fresh_queue = TaskQueue()
    return fresh_queue
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Config with two ODS tasks on the ``api_ods_dwd`` pipeline, store 42."""
    config_kwargs = {
        "tasks": ["ODS_MEMBER", "ODS_PAYMENT"],
        "pipeline": "api_ods_dwd",
        "store_id": 42,
    }
    return TaskConfigSchema(**config_kwargs)
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
"""构造 mock cursor支持 context manager 协议。"""
cur = MagicMock()
cur.fetchone.return_value = fetchone_val
cur.fetchall.return_value = fetchall_val or []
cur.rowcount = rowcount
cur.__enter__ = MagicMock(return_value=cur)
cur.__exit__ = MagicMock(return_value=False)
return cur
def _mock_conn(cursor):
"""构造 mock connection支持 cursor() context manager。"""
conn = MagicMock()
conn.cursor.return_value = cursor
return conn
# ---------------------------------------------------------------------------
# enqueue
# ---------------------------------------------------------------------------
class TestEnqueue:
    """Tests for ``TaskQueue.enqueue``: returned id, position, config JSON."""

    @staticmethod
    def _wire(mock_get_conn, max_position):
        """Make ``get_connection`` return a mock whose ``fetchone()`` is ``(max_position,)``.

        Returns ``(cursor, connection)`` so tests can inspect either one.
        """
        cursor = _mock_cursor(fetchone_val=(max_position,))
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection
        return cursor, connection

    @staticmethod
    def _insert_params(cursor):
        """Return the parameter tuple of the second ``execute`` call (the INSERT)."""
        second_call = cursor.execute.call_args_list[1]
        return second_call[0][1]

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
        """enqueue returns a valid UUID string and commits/closes exactly once."""
        _, connection = self._wire(mock_get_conn, 0)

        returned_id = queue.enqueue(sample_config, site_id=42)

        uuid.UUID(returned_id)  # raises ValueError if not a valid UUID
        connection.commit.assert_called_once()
        connection.close.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
        """A new task is placed at (current max position + 1)."""
        cursor, _ = self._wire(mock_get_conn, 5)

        queue.enqueue(sample_config, site_id=42)

        # INSERT params are (task_id, site_id, config_json, new_pos).
        assert self._insert_params(cursor)[3] == 5 + 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
        """With an empty queue the first task gets position 1."""
        cursor, _ = self._wire(mock_get_conn, 0)

        queue.enqueue(sample_config, site_id=42)

        assert self._insert_params(cursor)[3] == 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
        """The config is stored as a JSON string in the INSERT parameters."""
        cursor, _ = self._wire(mock_get_conn, 0)

        queue.enqueue(sample_config, site_id=42)

        stored_config = json.loads(self._insert_params(cursor)[2])
        assert stored_config["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
        assert stored_config["pipeline"] == "api_ods_dwd"
# ---------------------------------------------------------------------------
# dequeue
# ---------------------------------------------------------------------------
class TestDequeue:
    """Tests for ``TaskQueue.dequeue``."""

    @staticmethod
    def _pending_row(task_id):
        """Build the 10-column ``pending`` row (site 42, position 1) fed to dequeue."""
        config_json = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
        return (task_id, 42, config_json, "pending", 1) + (None,) * 5

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
        """An empty queue yields None, and the transaction is still committed."""
        connection = _mock_conn(_mock_cursor(fetchone_val=None))
        mock_get_conn.return_value = connection

        assert queue.dequeue(site_id=42) is None
        connection.commit.assert_called()

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_task(self, mock_get_conn, queue):
        """The fetched row is returned as a task whose status is now running."""
        expected_id = str(uuid.uuid4())
        cursor = _mock_cursor(fetchone_val=self._pending_row(expected_id))
        mock_get_conn.return_value = _mock_conn(cursor)

        task = queue.dequeue(site_id=42)

        assert task is not None
        assert task.id == expected_id
        assert task.site_id == 42
        assert task.status == "running"  # dequeue flips the status to running
        assert task.config["tasks"] == ["ODS_MEMBER"]

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
        """The second execute() is an UPDATE to 'running' for the dequeued id."""
        expected_id = str(uuid.uuid4())
        cursor = _mock_cursor(fetchone_val=self._pending_row(expected_id))
        mock_get_conn.return_value = _mock_conn(cursor)

        queue.dequeue(site_id=42)

        positional = cursor.execute.call_args_list[1][0]
        assert "running" in positional[0]
        assert expected_id in positional[1]
# ---------------------------------------------------------------------------
# reorder
# ---------------------------------------------------------------------------
class TestReorder:
    """Tests for ``TaskQueue.reorder``."""

    @staticmethod
    def _wire(mock_get_conn, task_ids):
        """Make the first query return one ``(task_id,)`` row per id; return the cursor."""
        cursor = _mock_cursor(fetchall_val=[(tid,) for tid in task_ids])
        mock_get_conn.return_value = _mock_conn(cursor)
        return cursor

    @staticmethod
    def _positions_after_select(cursor):
        """Map task_id -> position from every execute() after the first (SELECT).

        Each later call's params are expected to be exactly ``(position, task_id)``.
        """
        update_params = (c[0][1] for c in cursor.execute.call_args_list[1:])
        return {tid: pos for pos, tid in update_params}

    @patch("app.services.task_queue.get_connection")
    def test_reorder_moves_task(self, mock_get_conn, queue):
        """Moving the 3rd task to position 1 shifts the other two down."""
        first, second, third = (str(uuid.uuid4()) for _ in range(3))
        cursor = self._wire(mock_get_conn, [first, second, third])

        queue.reorder(third, new_position=1, site_id=42)

        positions = self._positions_after_select(cursor)
        # Expected order afterwards: third, first, second.
        assert positions[third] == 1
        assert positions[first] == 2
        assert positions[second] == 3

    @patch("app.services.task_queue.get_connection")
    def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
        """Reordering an unknown id does not raise and issues no UPDATE."""
        cursor = self._wire(mock_get_conn, [str(uuid.uuid4())])

        queue.reorder("nonexistent-id", new_position=1, site_id=42)

        # Only the initial SELECT ran.
        assert cursor.execute.call_count == 1

    @patch("app.services.task_queue.get_connection")
    def test_reorder_clamps_position(self, mock_get_conn, queue):
        """An out-of-range position is clamped to the end of the queue."""
        head, tail = (str(uuid.uuid4()) for _ in range(2))
        cursor = self._wire(mock_get_conn, [head, tail])

        queue.reorder(head, new_position=100, site_id=42)

        positions = self._positions_after_select(cursor)
        # head moves behind tail.
        assert positions[tail] == 1
        assert positions[head] == 2
# ---------------------------------------------------------------------------
# delete
# ---------------------------------------------------------------------------
class TestDelete:
    """Tests for ``TaskQueue.delete``."""

    @staticmethod
    def _wire(mock_get_conn, rowcount):
        """Make the cursor report *rowcount*; return ``(cursor, connection)``."""
        cursor = _mock_cursor(rowcount=rowcount)
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection
        return cursor, connection

    @patch("app.services.task_queue.get_connection")
    def test_delete_pending_task(self, mock_get_conn, queue):
        """Deleting one row returns True and commits once."""
        _, connection = self._wire(mock_get_conn, 1)

        assert queue.delete("task-1", site_id=42) is True
        connection.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
        """When no row matches, delete returns False."""
        self._wire(mock_get_conn, 0)

        assert queue.delete("nonexistent", site_id=42) is False

    @patch("app.services.task_queue.get_connection")
    def test_delete_only_affects_pending(self, mock_get_conn, queue):
        """The DELETE statement filters on the 'pending' status."""
        cursor, _ = self._wire(mock_get_conn, 0)

        queue.delete("task-1", site_id=42)

        last_sql = cursor.execute.call_args[0][0]
        assert "pending" in last_sql
# ---------------------------------------------------------------------------
# list_pending / has_running
# ---------------------------------------------------------------------------
class TestQuery:
    """Tests for ``list_pending`` and ``has_running``."""

    @staticmethod
    def _wire(mock_get_conn, **cursor_kwargs):
        """Point ``get_connection`` at a connection built from ``_mock_cursor(**cursor_kwargs)``."""
        mock_get_conn.return_value = _mock_conn(_mock_cursor(**cursor_kwargs))

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_empty(self, mock_get_conn, queue):
        """No rows means an empty list."""
        self._wire(mock_get_conn, fetchall_val=[])

        assert queue.list_pending(site_id=42) == []

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_returns_tasks(self, mock_get_conn, queue):
        """Each fetched row becomes one task carrying the row's id."""
        task_id = str(uuid.uuid4())
        config_json = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
        pending_row = (task_id, 42, config_json, "pending", 1, None, None, None, None, None)
        self._wire(mock_get_conn, fetchall_val=[pending_row])

        pending = queue.list_pending(site_id=42)

        assert len(pending) == 1
        assert pending[0].id == task_id

    @patch("app.services.task_queue.get_connection")
    def test_has_running_true(self, mock_get_conn, queue):
        """has_running returns True when the query yields (True,)."""
        self._wire(mock_get_conn, fetchone_val=(True,))

        assert queue.has_running(site_id=42) is True

    @patch("app.services.task_queue.get_connection")
    def test_has_running_false(self, mock_get_conn, queue):
        """has_running returns False when the query yields (False,)."""
        self._wire(mock_get_conn, fetchone_val=(False,))

        assert queue.has_running(site_id=42) is False
# ---------------------------------------------------------------------------
# process_loop / _process_once
# ---------------------------------------------------------------------------
class TestProcessLoop:
    """Tests for the scheduling decisions made by ``TaskQueue._process_once``."""

    @staticmethod
    def _connection_sequence(*cursor_specs):
        """Build a ``get_connection`` side effect that serves connections in order.

        Each call returns a fresh mock connection whose cursor is built from
        ``_mock_cursor(**spec)``, taking specs from *cursor_specs* in order. The
        last spec is reused for every call beyond the end of the list.
        """
        served = 0

        def _next_connection():
            nonlocal served
            spec = cursor_specs[min(served, len(cursor_specs) - 1)]
            served += 1
            return _mock_conn(_mock_cursor(**spec))

        return _next_connection

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_skips_when_running(self, mock_get_conn, queue):
        """A site that already has a running task is not dequeued."""
        mock_get_conn.side_effect = self._connection_sequence(
            {"fetchall_val": [(42,)]},  # 1st: _get_pending_site_ids -> [42]
            {"fetchone_val": (True,)},  # then: has_running(42) -> True
        )
        mock_executor = MagicMock()

        await queue._process_once(mock_executor)

        mock_executor.execute.assert_not_called()

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
        """With no running task, the next pending task is handed to the executor."""
        task_id = str(uuid.uuid4())
        config_dict = {
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods_dwd",
            "processing_mode": "increment_only",
            "dry_run": False,
            "window_mode": "lookback",
            "lookback_hours": 24,
            "overlap_seconds": 600,
            "fetch_before_verify": False,
            "skip_ods_when_fetch_before_verify": False,
            "ods_use_local_json": False,
            "extra_args": {},
        }
        pending_row = (
            task_id, 42, json.dumps(config_dict), "pending", 1,
            None, None, None, None, None,
        )
        mock_get_conn.side_effect = self._connection_sequence(
            {"fetchall_val": [(42,)]},      # _get_pending_site_ids -> [42]
            {"fetchone_val": (False,)},     # has_running(42) -> False
            {"fetchone_val": pending_row},  # dequeue (and later calls) -> the task
        )
        mock_executor = MagicMock()
        mock_executor.execute = AsyncMock()

        await queue._process_once(mock_executor)
        # Execution is presumably scheduled via asyncio.create_task; yield to the
        # event loop so the background task gets a chance to start.
        await asyncio.sleep(0.1)

        # Without this assertion the test could never fail: the dequeued task
        # must actually reach the executor, exactly once.
        mock_executor.execute.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_no_pending(self, mock_get_conn, queue):
        """With no pending sites, nothing is executed."""
        mock_get_conn.return_value = _mock_conn(_mock_cursor(fetchall_val=[]))
        mock_executor = MagicMock()

        await queue._process_once(mock_executor)

        mock_executor.execute.assert_not_called()
# ---------------------------------------------------------------------------
# 生命周期
# ---------------------------------------------------------------------------
class TestLifecycle:
    """Tests for the ``_running`` / ``_loop_task`` lifecycle flags."""

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, queue):
        """stop() clears ``_running``; the test starts with no loop task."""
        queue._running, queue._loop_task = True, None

        await queue.stop()

        assert queue._running is False

    def test_start_creates_task(self, queue):
        """Check the state before start(): not running and no loop task yet.

        NOTE(review): despite its name, this test never calls start() (which
        needs an event loop to create its asyncio.Task); it only verifies the
        initial state of a fresh queue.
        """
        assert queue._running is False
        assert queue._loop_task is None
# ---------------------------------------------------------------------------
# _mark_failed / _update_queue_status_from_log
# ---------------------------------------------------------------------------
class TestInternalHelpers:
    """Tests for ``_mark_failed`` and ``_update_queue_status_from_log``."""

    @patch("app.services.task_queue.get_connection")
    def test_mark_failed(self, mock_get_conn, queue):
        """_mark_failed issues a 'failed' UPDATE with params (error, queue_id)."""
        cursor = _mock_cursor()
        mock_get_conn.return_value = _mock_conn(cursor)

        queue._mark_failed("queue-1", "测试错误")

        last_positional = cursor.execute.call_args[0]
        assert "failed" in last_positional[0]
        params = last_positional[1]
        assert params[0] == "测试错误"
        assert params[1] == "queue-1"

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_from_log(self, mock_get_conn, queue):
        """A found execution_log row is synced to the queue: SELECT + UPDATE + commit."""
        from datetime import datetime, timezone

        # Presumably (status, finished_at, exit_code, error) -- confirm against
        # the SELECT in _update_queue_status_from_log.
        log_row = ("success", datetime.now(timezone.utc), 0, None)
        cursor = _mock_cursor(fetchone_val=log_row)
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection

        queue._update_queue_status_from_log("queue-1")

        assert cursor.execute.call_count == 2
        connection.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_no_log(self, mock_get_conn, queue):
        """With no execution_log row, only the SELECT runs (no UPDATE)."""
        cursor = _mock_cursor(fetchone_val=None)
        mock_get_conn.return_value = _mock_conn(cursor)

        queue._update_queue_status_from_log("queue-1")

        assert cursor.execute.call_count == 1