在准备环境前提交次全部更改。
This commit is contained in:
482
apps/backend/tests/test_task_queue.py
Normal file
482
apps/backend/tests/test_task_queue.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskQueue 单元测试
|
||||
|
||||
覆盖:enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
|
||||
使用 mock 数据库操作,专注于业务逻辑验证。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from app.schemas.tasks import TaskConfigSchema
|
||||
from app.services.task_queue import TaskQueue, QueuedTask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue() -> TaskQueue:
|
||||
return TaskQueue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> TaskConfigSchema:
|
||||
return TaskConfigSchema(
|
||||
tasks=["ODS_MEMBER", "ODS_PAYMENT"],
|
||||
pipeline="api_ods_dwd",
|
||||
store_id=42,
|
||||
)
|
||||
|
||||
|
||||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
|
||||
"""构造 mock cursor,支持 context manager 协议。"""
|
||||
cur = MagicMock()
|
||||
cur.fetchone.return_value = fetchone_val
|
||||
cur.fetchall.return_value = fetchall_val or []
|
||||
cur.rowcount = rowcount
|
||||
cur.__enter__ = MagicMock(return_value=cur)
|
||||
cur.__exit__ = MagicMock(return_value=False)
|
||||
return cur
|
||||
|
||||
|
||||
def _mock_conn(cursor):
|
||||
"""构造 mock connection,支持 cursor() context manager。"""
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursor
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnqueue:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
task_id = queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
# 返回有效 UUID
|
||||
uuid.UUID(task_id)
|
||||
conn.commit.assert_called_once()
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
|
||||
"""新任务 position = 当前最大 position + 1"""
|
||||
cur = _mock_cursor(fetchone_val=(5,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
# 检查 INSERT 调用中的 position 参数
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
args = insert_call[0][1]
|
||||
# args = (task_id, site_id, config_json, new_pos)
|
||||
assert args[3] == 6 # 5 + 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
|
||||
"""空队列时 position = 1"""
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
args = insert_call[0][1]
|
||||
assert args[3] == 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
|
||||
"""config 被序列化为 JSON 字符串"""
|
||||
cur = _mock_cursor(fetchone_val=(0,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.enqueue(sample_config, site_id=42)
|
||||
|
||||
insert_call = cur.execute.call_args_list[1]
|
||||
config_json_str = insert_call[0][1][2]
|
||||
parsed = json.loads(config_json_str)
|
||||
assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
|
||||
assert parsed["pipeline"] == "api_ods_dwd"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dequeue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDequeue:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=None)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.dequeue(site_id=42)
|
||||
|
||||
assert result is None
|
||||
conn.commit.assert_called()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_returns_task(self, mock_get_conn, queue):
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
|
||||
row = (
|
||||
task_id, 42, json.dumps(config_dict), "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.dequeue(site_id=42)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == task_id
|
||||
assert result.site_id == 42
|
||||
assert result.status == "running" # dequeue 后状态变为 running
|
||||
assert result.config["tasks"] == ["ODS_MEMBER"]
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
|
||||
row = (
|
||||
task_id, 42, json.dumps(config_dict), "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.dequeue(site_id=42)
|
||||
|
||||
# 第二次 execute 调用应该是 UPDATE status = 'running'
|
||||
update_call = cur.execute.call_args_list[1]
|
||||
sql = update_call[0][0]
|
||||
assert "running" in sql
|
||||
assert task_id in update_call[0][1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reorder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReorder:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_moves_task(self, mock_get_conn, queue):
|
||||
"""将第 3 个任务移到第 1 位"""
|
||||
ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
rows = [(i,) for i in ids]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.reorder(ids[2], new_position=1, site_id=42)
|
||||
|
||||
# 重排后顺序应为 [ids[2], ids[0], ids[1]]
|
||||
update_calls = cur.execute.call_args_list[1:] # 跳过 SELECT
|
||||
positions = {}
|
||||
for c in update_calls:
|
||||
pos, tid = c[0][1]
|
||||
positions[tid] = pos
|
||||
assert positions[ids[2]] == 1
|
||||
assert positions[ids[0]] == 2
|
||||
assert positions[ids[1]] == 3
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
|
||||
"""重排不存在的任务不报错"""
|
||||
rows = [(str(uuid.uuid4()),)]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.reorder("nonexistent-id", new_position=1, site_id=42)
|
||||
|
||||
# 只有 SELECT,没有 UPDATE
|
||||
assert cur.execute.call_count == 1
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_reorder_clamps_position(self, mock_get_conn, queue):
|
||||
"""position 超出范围时 clamp 到有效范围"""
|
||||
ids = [str(uuid.uuid4()) for _ in range(2)]
|
||||
rows = [(i,) for i in ids]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
# new_position=100 超出范围,应 clamp 到末尾
|
||||
queue.reorder(ids[0], new_position=100, site_id=42)
|
||||
|
||||
update_calls = cur.execute.call_args_list[1:]
|
||||
positions = {}
|
||||
for c in update_calls:
|
||||
pos, tid = c[0][1]
|
||||
positions[tid] = pos
|
||||
# ids[0] 移到末尾
|
||||
assert positions[ids[1]] == 1
|
||||
assert positions[ids[0]] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDelete:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_pending_task(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(rowcount=1)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.delete("task-1", site_id=42)
|
||||
|
||||
assert result is True
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(rowcount=0)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.delete("nonexistent", site_id=42)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_delete_only_affects_pending(self, mock_get_conn, queue):
|
||||
"""DELETE SQL 包含 status = 'pending' 条件"""
|
||||
cur = _mock_cursor(rowcount=0)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue.delete("task-1", site_id=42)
|
||||
|
||||
sql = cur.execute.call_args[0][0]
|
||||
assert "pending" in sql
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_pending / has_running
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuery:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_list_pending_empty(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.list_pending(site_id=42)
|
||||
|
||||
assert result == []
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_list_pending_returns_tasks(self, mock_get_conn, queue):
|
||||
tid = str(uuid.uuid4())
|
||||
config = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
|
||||
rows = [(tid, 42, config, "pending", 1, None, None, None, None, None)]
|
||||
cur = _mock_cursor(fetchall_val=rows)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
result = queue.list_pending(site_id=42)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == tid
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_has_running_true(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=(True,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
assert queue.has_running(site_id=42) is True
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_has_running_false(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor(fetchone_val=(False,))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
assert queue.has_running(site_id=42) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_loop / _process_once
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessLoop:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_skips_when_running(self, mock_get_conn, queue):
|
||||
"""有 running 任务时不 dequeue"""
|
||||
# _get_pending_site_ids 返回 [42]
|
||||
# has_running(42) 返回 True
|
||||
call_count = 0
|
||||
|
||||
def side_effect_conn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# _get_pending_site_ids
|
||||
cur = _mock_cursor(fetchall_val=[(42,)])
|
||||
return _mock_conn(cur)
|
||||
else:
|
||||
# has_running
|
||||
cur = _mock_cursor(fetchone_val=(True,))
|
||||
return _mock_conn(cur)
|
||||
|
||||
mock_get_conn.side_effect = side_effect_conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
# 不应调用 execute
|
||||
mock_executor.execute.assert_not_called()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
|
||||
"""无 running 任务时 dequeue 并执行"""
|
||||
task_id = str(uuid.uuid4())
|
||||
config_dict = {
|
||||
"tasks": ["ODS_MEMBER"],
|
||||
"pipeline": "api_ods_dwd",
|
||||
"processing_mode": "increment_only",
|
||||
"dry_run": False,
|
||||
"window_mode": "lookback",
|
||||
"lookback_hours": 24,
|
||||
"overlap_seconds": 600,
|
||||
"fetch_before_verify": False,
|
||||
"skip_ods_when_fetch_before_verify": False,
|
||||
"ods_use_local_json": False,
|
||||
"extra_args": {},
|
||||
}
|
||||
config_json = json.dumps(config_dict)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def side_effect_conn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# _get_pending_site_ids
|
||||
cur = _mock_cursor(fetchall_val=[(42,)])
|
||||
return _mock_conn(cur)
|
||||
elif call_count == 2:
|
||||
# has_running → False
|
||||
cur = _mock_cursor(fetchone_val=(False,))
|
||||
return _mock_conn(cur)
|
||||
else:
|
||||
# dequeue → 返回任务
|
||||
row = (
|
||||
task_id, 42, config_json, "pending", 1,
|
||||
None, None, None, None, None,
|
||||
)
|
||||
cur = _mock_cursor(fetchone_val=row)
|
||||
return _mock_conn(cur)
|
||||
|
||||
mock_get_conn.side_effect = side_effect_conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
mock_executor.execute = AsyncMock()
|
||||
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
# 给 create_task 一点时间启动
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_once_no_pending(self, mock_get_conn, queue):
|
||||
"""无 pending 任务时什么都不做"""
|
||||
cur = _mock_cursor(fetchall_val=[])
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
mock_executor = MagicMock()
|
||||
await queue._process_once(mock_executor)
|
||||
|
||||
mock_executor.execute.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sets_running_false(self, queue):
|
||||
queue._running = True
|
||||
queue._loop_task = None
|
||||
|
||||
await queue.stop()
|
||||
|
||||
assert queue._running is False
|
||||
|
||||
def test_start_creates_task(self, queue):
|
||||
"""start() 应创建 asyncio.Task(需要事件循环)"""
|
||||
# 仅验证 _running 初始状态
|
||||
assert queue._running is False
|
||||
assert queue._loop_task is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_failed / _update_queue_status_from_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInternalHelpers:
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_mark_failed(self, mock_get_conn, queue):
|
||||
cur = _mock_cursor()
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._mark_failed("queue-1", "测试错误")
|
||||
|
||||
sql = cur.execute.call_args[0][0]
|
||||
assert "failed" in sql
|
||||
args = cur.execute.call_args[0][1]
|
||||
assert args[0] == "测试错误"
|
||||
assert args[1] == "queue-1"
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_update_queue_status_from_log(self, mock_get_conn, queue):
|
||||
"""从 execution_log 同步状态到 task_queue"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
finished = datetime.now(timezone.utc)
|
||||
# 第一次 fetchone 返回 execution_log 行
|
||||
cur = _mock_cursor(fetchone_val=("success", finished, 0, None))
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._update_queue_status_from_log("queue-1")
|
||||
|
||||
# 应有 SELECT + UPDATE 两次 execute
|
||||
assert cur.execute.call_count == 2
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
@patch("app.services.task_queue.get_connection")
|
||||
def test_update_queue_status_no_log(self, mock_get_conn, queue):
|
||||
"""execution_log 无记录时不更新"""
|
||||
cur = _mock_cursor(fetchone_val=None)
|
||||
conn = _mock_conn(cur)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
queue._update_queue_status_from_log("queue-1")
|
||||
|
||||
# 只有 SELECT,没有 UPDATE
|
||||
assert cur.execute.call_count == 1
|
||||
Reference in New Issue
Block a user