# -*- coding: utf-8 -*-
"""Unit tests for TaskQueue.

Covers the core logic of enqueue, dequeue, reorder, delete and process_loop.
Database access is mocked so the tests focus on business-logic behavior.
"""

import asyncio
import json
import uuid
from unittest.mock import MagicMock, AsyncMock, patch, call

import pytest

from app.schemas.tasks import TaskConfigSchema
from app.services.task_queue import TaskQueue, QueuedTask


@pytest.fixture
def queue() -> TaskQueue:
    """Return a fresh TaskQueue instance."""
    return TaskQueue()


@pytest.fixture
def sample_config() -> TaskConfigSchema:
    """Return a two-task config used by the enqueue tests."""
    return TaskConfigSchema(
        tasks=["ODS_MEMBER", "ODS_PAYMENT"],
        pipeline="api_ods_dwd",
        store_id=42,
    )


def _mock_cursor(fetchone_val=None, fetchall_val=None, rowcount=1):
    """Build a mock DB cursor that also supports the context-manager protocol.

    Args:
        fetchone_val: value returned by every ``fetchone()`` call.
        fetchall_val: value returned by every ``fetchall()`` call; any falsy
            value (including None) becomes ``[]``.
        rowcount: value exposed as ``cursor.rowcount``.

    Returns:
        A MagicMock whose ``__enter__`` returns the cursor itself, so
        ``with conn.cursor() as cur:`` yields the same mock.
    """
    cur = MagicMock()
    cur.fetchone.return_value = fetchone_val
    cur.fetchall.return_value = fetchall_val or []
    cur.rowcount = rowcount
    cur.__enter__ = MagicMock(return_value=cur)
    cur.__exit__ = MagicMock(return_value=False)
    return cur


def _mock_conn(cursor):
    """Build a mock connection whose ``cursor()`` call returns *cursor*.

    ``commit``/``close`` are ordinary MagicMock attributes, so tests can
    assert on them directly.
    """
    conn = MagicMock()
    conn.cursor.return_value = cursor
    return conn


def _queue_row(task_id, config_json, site_id=42, status="pending", position=1):
    """Build a fake 10-column task_queue row as returned by the mocked cursor.

    The leading columns are presumably (id, site_id, config JSON, status,
    position); the remaining five are left as None.
    NOTE(review): column order is inferred from the assertions in this file
    — confirm against the SELECT column list used by TaskQueue.
    """
    return (
        task_id, site_id, config_json, status, position,
        None, None, None, None, None,
    )


# ---------------------------------------------------------------------------
# enqueue
# ---------------------------------------------------------------------------

class TestEnqueue:

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_returns_uuid(self, mock_get_conn, queue, sample_config):
        """enqueue returns a valid UUID string and commits/closes the connection."""
        cur = _mock_cursor(fetchone_val=(0,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        task_id = queue.enqueue(sample_config, site_id=42)

        # Raises ValueError if task_id is not a valid UUID.
        uuid.UUID(task_id)
        conn.commit.assert_called_once()
        conn.close.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_position_increments(self, mock_get_conn, queue, sample_config):
        """A new task's position is the current max position + 1."""
        cur = _mock_cursor(fetchone_val=(5,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.enqueue(sample_config, site_id=42)

        # Inspect the position parameter of the INSERT (second execute call).
        insert_call = cur.execute.call_args_list[1]
        args = insert_call[0][1]
        # args = (task_id, site_id, config_json, new_pos)
        assert args[3] == 6  # 5 + 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_empty_queue_position_is_one(self, mock_get_conn, queue, sample_config):
        """On an empty queue the new task gets position 1."""
        cur = _mock_cursor(fetchone_val=(0,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.enqueue(sample_config, site_id=42)

        insert_call = cur.execute.call_args_list[1]
        args = insert_call[0][1]
        assert args[3] == 1

    @patch("app.services.task_queue.get_connection")
    def test_enqueue_serializes_config(self, mock_get_conn, queue, sample_config):
        """The config is serialized into a JSON string for the INSERT."""
        cur = _mock_cursor(fetchone_val=(0,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.enqueue(sample_config, site_id=42)

        insert_call = cur.execute.call_args_list[1]
        config_json_str = insert_call[0][1][2]
        parsed = json.loads(config_json_str)
        assert parsed["tasks"] == ["ODS_MEMBER", "ODS_PAYMENT"]
        assert parsed["pipeline"] == "api_ods_dwd"


# ---------------------------------------------------------------------------
# dequeue
# ---------------------------------------------------------------------------

class TestDequeue:

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_none_when_empty(self, mock_get_conn, queue):
        """dequeue returns None (and still commits) when nothing is pending."""
        cur = _mock_cursor(fetchone_val=None)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        result = queue.dequeue(site_id=42)
        assert result is None
        conn.commit.assert_called()

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_returns_task(self, mock_get_conn, queue):
        """dequeue maps the fetched row onto a task whose status is running."""
        task_id = str(uuid.uuid4())
        config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
        row = _queue_row(task_id, json.dumps(config_dict))
        cur = _mock_cursor(fetchone_val=row)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        result = queue.dequeue(site_id=42)

        assert result is not None
        assert result.id == task_id
        assert result.site_id == 42
        assert result.status == "running"  # status becomes running after dequeue
        assert result.config["tasks"] == ["ODS_MEMBER"]

    @patch("app.services.task_queue.get_connection")
    def test_dequeue_updates_status_to_running(self, mock_get_conn, queue):
        """dequeue issues a status update to 'running' for the fetched task."""
        task_id = str(uuid.uuid4())
        config_dict = {"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"}
        row = _queue_row(task_id, json.dumps(config_dict))
        cur = _mock_cursor(fetchone_val=row)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.dequeue(site_id=42)

        # The second execute call should be the UPDATE setting status = 'running'.
        update_call = cur.execute.call_args_list[1]
        sql = update_call[0][0]
        assert "running" in sql
        assert task_id in update_call[0][1]


# ---------------------------------------------------------------------------
# reorder
# ---------------------------------------------------------------------------

class TestReorder:

    @patch("app.services.task_queue.get_connection")
    def test_reorder_moves_task(self, mock_get_conn, queue):
        """Moving the 3rd task to position 1 shifts the others down."""
        ids = [str(uuid.uuid4()) for _ in range(3)]
        rows = [(i,) for i in ids]
        cur = _mock_cursor(fetchall_val=rows)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.reorder(ids[2], new_position=1, site_id=42)

        # Expected order afterwards: [ids[2], ids[0], ids[1]]
        update_calls = cur.execute.call_args_list[1:]  # skip the initial SELECT
        positions = {}
        for c in update_calls:
            pos, tid = c[0][1]
            positions[tid] = pos

        assert positions[ids[2]] == 1
        assert positions[ids[0]] == 2
        assert positions[ids[1]] == 3

    @patch("app.services.task_queue.get_connection")
    def test_reorder_nonexistent_task_is_noop(self, mock_get_conn, queue):
        """Reordering an unknown task id does not raise."""
        rows = [(str(uuid.uuid4()),)]
        cur = _mock_cursor(fetchall_val=rows)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.reorder("nonexistent-id", new_position=1, site_id=42)

        # Only the SELECT ran; no UPDATE was issued.
        assert cur.execute.call_count == 1

    @patch("app.services.task_queue.get_connection")
    def test_reorder_clamps_position(self, mock_get_conn, queue):
        """An out-of-range position is clamped into the valid range."""
        ids = [str(uuid.uuid4()) for _ in range(2)]
        rows = [(i,) for i in ids]
        cur = _mock_cursor(fetchall_val=rows)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        # new_position=100 is out of range and should clamp to the end.
        queue.reorder(ids[0], new_position=100, site_id=42)

        update_calls = cur.execute.call_args_list[1:]
        positions = {}
        for c in update_calls:
            pos, tid = c[0][1]
            positions[tid] = pos

        # ids[0] moved to the end.
        assert positions[ids[1]] == 1
        assert positions[ids[0]] == 2


# ---------------------------------------------------------------------------
# delete
# ---------------------------------------------------------------------------

class TestDelete:

    @patch("app.services.task_queue.get_connection")
    def test_delete_pending_task(self, mock_get_conn, queue):
        """delete returns True and commits when a row was removed."""
        cur = _mock_cursor(rowcount=1)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        result = queue.delete("task-1", site_id=42)
        assert result is True
        conn.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_delete_nonexistent_returns_false(self, mock_get_conn, queue):
        """delete returns False when no row was affected."""
        cur = _mock_cursor(rowcount=0)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        result = queue.delete("nonexistent", site_id=42)
        assert result is False

    @patch("app.services.task_queue.get_connection")
    def test_delete_only_affects_pending(self, mock_get_conn, queue):
        """The DELETE SQL is restricted to status = 'pending'."""
        cur = _mock_cursor(rowcount=0)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        queue.delete("task-1", site_id=42)

        sql = cur.execute.call_args[0][0]
        assert "pending" in sql


# ---------------------------------------------------------------------------
# list_pending / has_running
# ---------------------------------------------------------------------------

class TestQuery:

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_empty(self, mock_get_conn, queue):
        """list_pending returns [] when no rows are fetched."""
        cur = _mock_cursor(fetchall_val=[])
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        result = queue.list_pending(site_id=42)
        assert result == []

    @patch("app.services.task_queue.get_connection")
    def test_list_pending_returns_tasks(self, mock_get_conn, queue):
        """list_pending maps each fetched row to a task."""
        tid = str(uuid.uuid4())
        config = json.dumps({"tasks": ["ODS_MEMBER"], "pipeline": "api_ods"})
        rows = [_queue_row(tid, config)]
        cur = _mock_cursor(fetchall_val=rows)
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        result = queue.list_pending(site_id=42)
        assert len(result) == 1
        assert result[0].id == tid

    @patch("app.services.task_queue.get_connection")
    def test_has_running_true(self, mock_get_conn, queue):
        """has_running returns True when the query yields True."""
        cur = _mock_cursor(fetchone_val=(True,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        assert queue.has_running(site_id=42) is True

    @patch("app.services.task_queue.get_connection")
    def test_has_running_false(self, mock_get_conn, queue):
        """has_running returns False when the query yields False."""
        cur = _mock_cursor(fetchone_val=(False,))
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        assert queue.has_running(site_id=42) is False


# ---------------------------------------------------------------------------
# process_loop / _process_once
# ---------------------------------------------------------------------------

class TestProcessLoop:

    @pytest.mark.asyncio
    @patch("app.services.task_queue.get_connection")
    async def test_process_once_skips_when_running(self, mock_get_conn, queue):
        """No dequeue happens while the site already has a running task."""
        # 1st connection: _get_pending_site_ids -> [42]
        # later connections: has_running(42) -> True
        call_count = 0

        def side_effect_conn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _get_pending_site_ids
                cur = _mock_cursor(fetchall_val=[(42,)])
                return _mock_conn(cur)
            else:
                # has_running
                cur = _mock_cursor(fetchone_val=(True,))
                return _mock_conn(cur)

        mock_get_conn.side_effect = side_effect_conn

        mock_executor = MagicMock()
        await queue._process_once(mock_executor)

        # execute must not be called.
        mock_executor.execute.assert_not_called()

    @pytest.mark.asyncio
    @patch("app.services.task_queue.get_connection")
    async def test_process_once_dequeues_and_executes(self, mock_get_conn, queue):
        """With nothing running, a pending task is dequeued and executed."""
        task_id = str(uuid.uuid4())
        config_dict = {
            "tasks": ["ODS_MEMBER"],
            "pipeline": "api_ods_dwd",
            "processing_mode": "increment_only",
            "dry_run": False,
            "window_mode": "lookback",
            "lookback_hours": 24,
            "overlap_seconds": 600,
            "fetch_before_verify": False,
            "skip_ods_when_fetch_before_verify": False,
            "ods_use_local_json": False,
            "extra_args": {},
        }
        config_json = json.dumps(config_dict)

        call_count = 0

        def side_effect_conn():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # _get_pending_site_ids
                cur = _mock_cursor(fetchall_val=[(42,)])
                return _mock_conn(cur)
            elif call_count == 2:
                # has_running -> False
                cur = _mock_cursor(fetchone_val=(False,))
                return _mock_conn(cur)
            else:
                # dequeue -> returns the task row
                row = _queue_row(task_id, config_json)
                cur = _mock_cursor(fetchone_val=row)
                return _mock_conn(cur)

        mock_get_conn.side_effect = side_effect_conn

        mock_executor = MagicMock()
        mock_executor.execute = AsyncMock()

        await queue._process_once(mock_executor)

        # Give a task spawned via create_task a moment to start.
        await asyncio.sleep(0.1)

        # Previously this test asserted nothing and could never fail:
        # the dequeued task must actually reach the executor.
        mock_executor.execute.assert_called_once()

    @pytest.mark.asyncio
    @patch("app.services.task_queue.get_connection")
    async def test_process_once_no_pending(self, mock_get_conn, queue):
        """Nothing is executed when no site has pending tasks."""
        cur = _mock_cursor(fetchall_val=[])
        conn = _mock_conn(cur)
        mock_get_conn.return_value = conn

        mock_executor = MagicMock()
        await queue._process_once(mock_executor)

        mock_executor.execute.assert_not_called()


# ---------------------------------------------------------------------------
# lifecycle
# ---------------------------------------------------------------------------

class TestLifecycle:

    @pytest.mark.asyncio
    async def test_stop_sets_running_false(self, queue):
        """stop() clears the running flag when there is no loop task."""
        queue._running = True
        queue._loop_task = None
        await queue.stop()
        assert queue._running is False

    def test_start_creates_task(self, queue):
        """Check the pre-start state that start() is expected to change.

        NOTE(review): despite the name, start() itself is not exercised here.
        It is expected to create an asyncio.Task and needs a running event
        loop. This test only pins the initial ``_running``/``_loop_task``
        values.
        """
        assert queue._running is False
        assert queue._loop_task is None
# ---------------------------------------------------------------------------
# _mark_failed / _update_queue_status_from_log
# ---------------------------------------------------------------------------

class TestInternalHelpers:
    """Tests for TaskQueue's private status-bookkeeping helpers."""

    @patch("app.services.task_queue.get_connection")
    def test_mark_failed(self, mock_get_conn, queue):
        """_mark_failed runs SQL mentioning 'failed', bound to (error, queue id)."""
        cursor = _mock_cursor()
        mock_get_conn.return_value = _mock_conn(cursor)

        queue._mark_failed("queue-1", "测试错误")

        executed = cursor.execute.call_args.args
        assert "failed" in executed[0]
        params = executed[1]
        # The error message is the first bound parameter, the queue id the second.
        assert (params[0], params[1]) == ("测试错误", "queue-1")

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_from_log(self, mock_get_conn, queue):
        """The execution_log status is synced into task_queue."""
        import datetime as dt

        # NOTE(review): presumably (status, finished_at, exit_code, error) from
        # the execution_log lookup — confirm against the helper's SELECT.
        log_row = ("success", dt.datetime.now(dt.timezone.utc), 0, None)
        cursor = _mock_cursor(fetchone_val=log_row)
        connection = _mock_conn(cursor)
        mock_get_conn.return_value = connection

        queue._update_queue_status_from_log("queue-1")

        # Two execute calls are expected: the SELECT plus the UPDATE.
        assert cursor.execute.call_count == 2
        connection.commit.assert_called_once()

    @patch("app.services.task_queue.get_connection")
    def test_update_queue_status_no_log(self, mock_get_conn, queue):
        """Nothing is updated when execution_log has no matching row."""
        cursor = _mock_cursor(fetchone_val=None)
        mock_get_conn.return_value = _mock_conn(cursor)

        queue._update_queue_status_from_log("queue-1")

        # Only the lookup ran; no follow-up UPDATE was issued.
        assert cursor.execute.call_count == 1