502 lines
17 KiB
Python
502 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""TaskQueue 单元测试
|
||
|
||
覆盖:enqueue、dequeue、reorder、delete、process_loop 的核心逻辑。
|
||
使用 mock 数据库操作,专注于业务逻辑验证。
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import uuid
|
||
from unittest.mock import MagicMock, AsyncMock, patch, call
|
||
|
||
import pytest
|
||
|
||
from app.schemas.tasks import TaskConfigSchema
|
||
from app.services.task_queue import TaskQueue, QueuedTask
|
||
|
||
|
||
@pytest.fixture
def queue() -> TaskQueue:
    """Provide a brand-new TaskQueue instance to each test."""
    fresh_queue = TaskQueue()
    return fresh_queue
|
||
|
||
|
||
@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Provide a config with two ODS tasks, flow ``api_ods_dwd`` and store_id 42."""
    config_kwargs = dict(
        tasks=["ODS_MEMBER", "ODS_PAYMENT"],
        flow="api_ods_dwd",
        store_id=42,
    )
    return TaskConfigSchema(**config_kwargs)
|
||
|
||
|
||
def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
    """Build a MagicMock cursor that can also be used as a context manager.

    Args:
        fetchone_val: value returned by ``cursor.fetchone()``.
        fetchall_val: value returned by ``cursor.fetchall()``; any falsy
            value (including ``None``) is replaced by an empty list.
        rowcount: value exposed as ``cursor.rowcount``.

    Returns:
        A MagicMock whose ``__enter__`` returns the mock itself, so
        ``with conn.cursor() as cur:`` yields this same object.
    """
    mock_cur = MagicMock()
    mock_cur.rowcount = rowcount
    mock_cur.fetchone.return_value = fetchone_val
    mock_cur.fetchall.return_value = [] if not fetchall_val else fetchall_val
    # Entering the context hands back the cursor itself; exiting returns
    # False so exceptions raised inside the ``with`` body still propagate.
    mock_cur.__enter__ = MagicMock(return_value=mock_cur)
    mock_cur.__exit__ = MagicMock(return_value=False)
    return mock_cur
|
||
|
||
|
||
def _mock_conn(cursor):
    """Build a MagicMock connection whose ``cursor()`` call returns *cursor*.

    Args:
        cursor: the object every ``conn.cursor()`` call should return,
            typically a mock produced by ``_mock_cursor``.

    Returns:
        The configured MagicMock connection.
    """
    # Dotted keyword names configure child mocks at construction time.
    return MagicMock(**{"cursor.return_value": cursor})
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# enqueue
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestEnqueue:
    """Tests for ``TaskQueue.enqueue``: returned id, queue position and config JSON."""

    @staticmethod
    def _insert_params(cursor):
        """Return the parameter tuple of the second execute() call (the INSERT)."""
        return cursor.execute.call_args_list[1].args[1]

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_returns_uuid(self, get_conn_mock, queue, sample_config):
        """enqueue returns a parseable UUID string and commits/closes the connection."""
        fake_cur = _mock_cursor(fetchone_val=(0,))
        fake_conn = _mock_conn(fake_cur)
        get_conn_mock.return_value = fake_conn

        new_id = queue.enqueue(sample_config, site_id=42)

        # uuid.UUID raises ValueError when new_id is not a valid UUID.
        uuid.UUID(new_id)
        fake_conn.commit.assert_called_once()
        fake_conn.close.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_position_increments(self, get_conn_mock, queue, sample_config):
        """A new task is placed at the current max position + 1."""
        fake_cur = _mock_cursor(fetchone_val=(5,))
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.enqueue(sample_config, site_id=42)

        # INSERT params are (task_id, site_id, config_json, new_pos).
        assert self._insert_params(fake_cur)[3] == 5 + 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_empty_queue_position_is_one(self, get_conn_mock, queue, sample_config):
        """With an empty queue (max position 0) the new task gets position 1."""
        fake_cur = _mock_cursor(fetchone_val=(0,))
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.enqueue(sample_config, site_id=42)

        assert self._insert_params(fake_cur)[3] == 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_serializes_config(self, get_conn_mock, queue, sample_config):
        """The config is stored as a JSON string that round-trips its fields."""
        fake_cur = _mock_cursor(fetchone_val=(0,))
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.enqueue(sample_config, site_id=42)

        payload = json.loads(self._insert_params(fake_cur)[2])
        assert payload["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
        assert payload["flow"] == "api_ods_dwd"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# dequeue
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestDequeue:
    """Tests for ``TaskQueue.dequeue``."""

    @staticmethod
    def _pending_row(task_id):
        """Build the 11-column pending row mocked for dequeue (site 42, position 1)."""
        config_json = json.dumps({"tasks": ["ODS_MEMBER"], "flow": "api_ods"})
        return (task_id, 42, config_json, "pending", 1) + (None,) * 6

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_none_when_empty(self, get_conn_mock, queue):
        """With no pending row, dequeue returns None but still commits."""
        fake_conn = _mock_conn(_mock_cursor(fetchone_val=None))
        get_conn_mock.return_value = fake_conn

        assert queue.dequeue(site_id=42) is None
        fake_conn.commit.assert_called()

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_task(self, get_conn_mock, queue):
        """The fetched row is returned as a task whose status is now running."""
        expected_id = str(uuid.uuid4())
        fake_cur = _mock_cursor(fetchone_val=self._pending_row(expected_id))
        get_conn_mock.return_value = _mock_conn(fake_cur)

        task = queue.dequeue(site_id=42)

        assert task is not None
        assert task.id == expected_id
        assert task.site_id == 42
        # dequeue flips the returned task's status to running.
        assert task.status == "running"
        assert task.config["tasks"] == ["ODS_MEMBER"]

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_updates_status_to_running(self, get_conn_mock, queue):
        """The second statement executed is the UPDATE that marks the task running."""
        expected_id = str(uuid.uuid4())
        fake_cur = _mock_cursor(fetchone_val=self._pending_row(expected_id))
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.dequeue(site_id=42)

        update_call = fake_cur.execute.call_args_list[1]
        update_sql = update_call.args[0]
        update_params = update_call.args[1]
        assert "running" in update_sql
        assert expected_id in update_params
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# reorder
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestReorder:
    """Tests for ``TaskQueue.reorder``."""

    @staticmethod
    def _positions_written(cursor):
        """Collect ``{task_id: position}`` from every execute() after the initial SELECT.

        Each UPDATE is expected to receive a ``(position, task_id)`` param tuple.
        """
        return {
            task_id: position
            for position, task_id in (
                recorded.args[1] for recorded in cursor.execute.call_args_list[1:]
            )
        }

    @patch("app.services.task_queue.get_connection")
    def test_reorder_moves_task(self, get_conn_mock, queue):
        """Moving the third task to position 1 shifts the other two down by one."""
        first, second, third = (str(uuid.uuid4()) for _ in range(3))
        fake_cur = _mock_cursor(fetchall_val=[(first,), (second,), (third,)])
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.reorder(third, new_position=1, site_id=42)

        # Expected order afterwards: [third, first, second].
        written = self._positions_written(fake_cur)
        assert written[third] == 1
        assert written[first] == 2
        assert written[second] == 3

    @patch("app.services.task_queue.get_connection")
    def test_reorder_nonexistent_task_is_noop(self, get_conn_mock, queue):
        """Reordering an unknown task id does not raise and issues no UPDATE."""
        fake_cur = _mock_cursor(fetchall_val=[(str(uuid.uuid4()),)])
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.reorder("nonexistent-id", new_position=1, site_id=42)

        # Only the SELECT ran.
        assert fake_cur.execute.call_count == 1

    @patch("app.services.task_queue.get_connection")
    def test_reorder_clamps_position(self, get_conn_mock, queue):
        """An out-of-range position is clamped to the last slot."""
        head, tail = (str(uuid.uuid4()) for _ in range(2))
        fake_cur = _mock_cursor(fetchall_val=[(head,), (tail,)])
        get_conn_mock.return_value = _mock_conn(fake_cur)

        # 100 is past the end of a two-item queue.
        queue.reorder(head, new_position=100, site_id=42)

        written = self._positions_written(fake_cur)
        assert written[tail] == 1
        assert written[head] == 2
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# delete
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestDelete:
    """Tests for ``TaskQueue.delete``."""

    @patch("app.services.task_queue.get_connection")
    def test_delete_pending_task(self, get_conn_mock, queue):
        """A delete that affects one row reports True and commits exactly once."""
        fake_conn = _mock_conn(_mock_cursor(rowcount=1))
        get_conn_mock.return_value = fake_conn

        deleted = queue.delete("task-1", site_id=42)

        assert deleted is True
        fake_conn.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_delete_nonexistent_returns_false(self, get_conn_mock, queue):
        """A delete that affects no rows reports False."""
        get_conn_mock.return_value = _mock_conn(_mock_cursor(rowcount=0))

        assert queue.delete("nonexistent", site_id=42) is False

    @patch("app.services.task_queue.get_connection")
    def test_delete_only_affects_pending(self, get_conn_mock, queue):
        """The DELETE statement carries a 'pending' status condition."""
        fake_cur = _mock_cursor(rowcount=0)
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue.delete("task-1", site_id=42)

        last_sql = fake_cur.execute.call_args.args[0]
        assert "pending" in last_sql
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# list_pending / has_running
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestQuery:
    """Tests for the read-only queries ``list_pending`` and ``has_running``."""

    @staticmethod
    def _wire(get_conn_mock, **cursor_kwargs):
        """Make the patched get_connection return a connection over a fresh mock cursor."""
        get_conn_mock.return_value = _mock_conn(_mock_cursor(**cursor_kwargs))

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_empty(self, get_conn_mock, queue):
        """No pending rows yields an empty list."""
        self._wire(get_conn_mock, fetchall_val=[])

        assert queue.list_pending(site_id=42) == []

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_returns_tasks(self, get_conn_mock, queue):
        """Each fetched row becomes one task carrying the row's id."""
        pending_id = str(uuid.uuid4())
        config_json = json.dumps({"tasks": ["ODS_MEMBER"], "flow": "api_ods"})
        # 10-column row mocked for the list_pending query.
        row = (pending_id, 42, config_json, "pending", 1) + (None,) * 5
        self._wire(get_conn_mock, fetchall_val=[row])

        pending = queue.list_pending(site_id=42)

        assert len(pending) == 1
        assert pending[0].id == pending_id

    @patch("app.services.task_queue.get_connection")
    def test_has_running_true(self, get_conn_mock, queue):
        """has_running reports True when the query returns True."""
        self._wire(get_conn_mock, fetchone_val=(True,))

        assert queue.has_running(site_id=42) is True

    @patch("app.services.task_queue.get_connection")
    def test_has_running_false(self, get_conn_mock, queue):
        """has_running reports False when the query returns False."""
        self._wire(get_conn_mock, fetchone_val=(False,))

        assert queue.has_running(site_id=42) is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# process_loop / _process_once
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestProcessLoop:
    """Tests for ``TaskQueue._process_once``, one iteration of the processing loop.

    Each test scripts ``get_connection`` so that successive calls serve, in
    order: ``_recover_zombie_tasks`` -> ``_get_pending_site_ids`` ->
    ``has_running`` -> ``dequeue``.
    """

    @staticmethod
    def _connection_sequence(*cursor_specs):
        """Build a ``get_connection`` side_effect that scripts successive calls.

        Args:
            *cursor_specs: keyword dicts for ``_mock_cursor``. Call *n* gets a
                fresh mock connection built from spec *n*.

        Returns:
            A zero-argument callable. Once the specs are exhausted, the last
            spec is reused for every further call.
        """
        calls_made = 0

        def next_connection():
            nonlocal calls_made
            spec = cursor_specs[min(calls_made, len(cursor_specs) - 1)]
            calls_made += 1
            return _mock_conn(_mock_cursor(**spec))

        return next_connection

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_skips_when_running(self, mock_get_conn, queue):
        """If the site already has a running task, nothing is dequeued or executed."""
        mock_get_conn.side_effect = self._connection_sequence(
            {},  # _recover_zombie_tasks: no zombie tasks
            {"fetchall_val": [(42,)]},  # _get_pending_site_ids
            {"fetchone_val": (True,)},  # has_running -> True (and any later call)
        )

        mock_executor = MagicMock()
        await queue._process_once(mock_executor)

        mock_executor.execute.assert_not_called()

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
        """With no running task, the next pending task is dequeued and executed."""
        task_id = str(uuid.uuid4())
        config_dict = {
            "tasks": ["ODS_MEMBER"],
            "flow": "api_ods_dwd",
            "processing_mode": "increment_only",
            "dry_run": False,
            "window_mode": "lookback",
            "lookback_hours": 24,
            "overlap_seconds": 600,
            "fetch_before_verify": False,
            "skip_ods_when_fetch_before_verify": False,
            "ods_use_local_json": False,
            "extra_args": {},
        }
        dequeued_row = (
            (task_id, 42, json.dumps(config_dict), "pending", 1) + (None,) * 6
        )
        mock_get_conn.side_effect = self._connection_sequence(
            {},  # _recover_zombie_tasks: no zombie tasks
            {"fetchall_val": [(42,)]},  # _get_pending_site_ids
            {"fetchone_val": (False,)},  # has_running -> False
            # dequeue (and any later connection) returns the pending task row
            {"fetchone_val": dequeued_row},
        )

        mock_executor = MagicMock()
        mock_executor.execute = AsyncMock()

        await queue._process_once(mock_executor)

        # Execution is scheduled via create_task; give it a moment to start.
        await asyncio.sleep(0.1)

        # Without this the test could never fail: the dequeued task must
        # actually be handed to the executor.
        mock_executor.execute.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    @pytest.mark.asyncio
    async def test_process_once_no_pending(self, mock_get_conn, queue):
        """With no pending sites, nothing is executed."""
        mock_get_conn.side_effect = self._connection_sequence(
            {},  # _recover_zombie_tasks: no zombie tasks
            {"fetchall_val": []},  # _get_pending_site_ids -> none (and any later call)
        )

        mock_executor = MagicMock()
        await queue._process_once(mock_executor)

        mock_executor.execute.assert_not_called()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 生命周期
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestLifecycle:
    """Lifecycle tests for ``TaskQueue.start`` / ``TaskQueue.stop``."""

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, queue):
        """stop() clears the running flag when no loop task was started."""
        queue._running = True
        queue._loop_task = None

        await queue.stop()

        assert queue._running is False

    def test_start_creates_task(self, queue):
        """Check the initial state that start() builds on.

        NOTE: despite its name, this test does not call start(), because
        start() needs an event loop. It only asserts that a fresh queue is not
        running and has no loop task yet.
        TODO: add an async test that calls start() and checks that _loop_task
        becomes an asyncio.Task.
        """
        assert queue._running is False
        assert queue._loop_task is None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# _mark_failed / _update_queue_status_from_log
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestInternalHelpers:
    """Tests for ``_mark_failed`` and ``_update_queue_status_from_log``."""

    @patch("app.services.task_queue.get_connection")
    def test_mark_failed(self, get_conn_mock, queue):
        """_mark_failed issues a 'failed' UPDATE with params (error_message, queue_id)."""
        fake_cur = _mock_cursor()
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue._mark_failed("queue-1", "测试错误")

        last_call = fake_cur.execute.call_args
        assert "failed" in last_call.args[0]
        assert last_call.args[1][0] == "测试错误"
        assert last_call.args[1][1] == "queue-1"

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_from_log(self, get_conn_mock, queue):
        """A found execution_log row leads to one SELECT, one UPDATE and one commit."""
        from datetime import datetime, timezone

        # fetchone returns the execution_log row
        # (assumed layout: status, finished time, exit code, error).
        log_row = ("success", datetime.now(timezone.utc), 0, None)
        fake_cur = _mock_cursor(fetchone_val=log_row)
        fake_conn = _mock_conn(fake_cur)
        get_conn_mock.return_value = fake_conn

        queue._update_queue_status_from_log("queue-1")

        # SELECT from execution_log, then UPDATE task_queue.
        assert fake_cur.execute.call_count == 2
        fake_conn.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_no_log(self, get_conn_mock, queue):
        """Without an execution_log row, nothing is updated."""
        fake_cur = _mock_cursor(fetchone_val=None)
        get_conn_mock.return_value = _mock_conn(fake_cur)

        queue._update_queue_status_from_log("queue-1")

        # Only the SELECT ran.
        assert fake_cur.execute.call_count == 1
|