# -*- coding: utf-8 -*-
"""Task queue service.

A FIFO queue backed by the PostgreSQL ``task_queue`` table. Supports:

- enqueue: append a task, assigning ``position`` = current max + 1
- dequeue: take the pending task with the smallest ``position`` and mark it
  ``running``
- reorder: move a task to another position in the queue
- delete: remove a pending task
- process_loop: background coroutine that, whenever a site has pending tasks
  and nothing running, dequeues and executes the next one

Every operation is filtered by ``site_id`` so that sites (stores) are isolated
from each other.
"""

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any

from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema

logger = logging.getLogger(__name__)

# Polling interval of the background loop, in seconds.
POLL_INTERVAL_SECONDS = 2


@dataclass
class QueuedTask:
    """In-memory representation of one ``task_queue`` row."""

    id: str
    site_id: int
    config: dict[str, Any]
    status: str
    position: int
    created_at: Any = None
    started_at: Any = None
    finished_at: Any = None
    exit_code: int | None = None
    error_message: str | None = None


class TaskQueue:
    """Task queue stored in PostgreSQL, processed by a background coroutine."""

    def __init__(self) -> None:
        self._running = False
        self._loop_task: asyncio.Task | None = None
        # Strong references to in-flight execution tasks. The event loop keeps
        # only weak references to tasks, so a task created with
        # asyncio.create_task() and not stored anywhere may be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Row conversion
    # ------------------------------------------------------------------

    @staticmethod
    def _row_to_task(row: Any) -> QueuedTask:
        """Build a QueuedTask from a row of the 10-column task SELECT.

        The expected column order is: id, site_id, config, status, position,
        created_at, started_at, finished_at, exit_code, error_message.
        ``config`` is decoded from JSON text unless the driver already returned
        a dict (presumably when the column is JSON/JSONB).
        """
        raw_config = row[2]
        config = raw_config if isinstance(raw_config, dict) else json.loads(raw_config)
        return QueuedTask(
            id=str(row[0]),
            site_id=row[1],
            config=config,
            status=row[3],
            position=row[4],
            created_at=row[5],
            started_at=row[6],
            finished_at=row[7],
            exit_code=row[8],
            error_message=row[9],
        )

    # ------------------------------------------------------------------
    # Enqueue
    # ------------------------------------------------------------------

    def enqueue(self, config: TaskConfigSchema, site_id: int) -> str:
        """Append a task to the end of a site's pending queue.

        Args:
            config: Task configuration; serialized to JSON for storage.
            site_id: Site (store) ID used for isolation.

        Returns:
            ID (UUID string) of the newly created queue task.
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # The new task goes after the site's current last pending task.
                # NOTE(review): two concurrent enqueues for the same site can
                # read the same MAX and end up with equal positions; add a
                # per-site lock if strict ordering matters.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                new_pos = cur.fetchone()[0] + 1

                cur.execute(
                    """
                    INSERT INTO task_queue (id, site_id, config, status, position)
                    VALUES (%s, %s, %s, 'pending', %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos),
                )
            conn.commit()
        finally:
            conn.close()

        logger.info("任务入队 [%s] site_id=%s position=%s", task_id, site_id, new_pos)
        return task_id

    # ------------------------------------------------------------------
    # Dequeue
    # ------------------------------------------------------------------

    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Take the site's lowest-position pending task and mark it running.

        Args:
            site_id: Site (store) ID.

        Returns:
            The dequeued QueuedTask (status ``running``), or None when the
            site has no pending task that is not locked by another transaction.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Lock the head row. SKIP LOCKED makes a concurrent dequeuer
                # move on instead of blocking on a row that is already taken.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id,),
                )
                row = cur.fetchone()
                if row is None:
                    conn.commit()
                    return None

                task = self._row_to_task(row)

                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
            conn.commit()
        finally:
            conn.close()

        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task

    # ------------------------------------------------------------------
    # Reorder
    # ------------------------------------------------------------------

    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a pending task to ``new_position`` within its site's queue.

        Only pending tasks can be reordered. The remaining pending tasks keep
        their relative order, and all of them are renumbered 1..n. Out-of-range
        positions are clamped to the start or end of the queue. Unknown or
        non-pending task IDs are silently ignored.

        Args:
            task_id: ID of the task to move.
            new_position: Target position (1-based).
            site_id: Site (store) ID.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                ordered_ids = [str(r[0]) for r in cur.fetchall()]

                if task_id not in ordered_ids:
                    conn.commit()
                    return

                # Take the task out, then re-insert it at the clamped 0-based
                # index that corresponds to the 1-based new_position.
                ordered_ids.remove(task_id)
                insert_idx = max(0, min(new_position - 1, len(ordered_ids)))
                ordered_ids.insert(insert_idx, task_id)

                # Renumber every pending task with consecutive 1-based positions.
                for pos, tid in enumerate(ordered_ids, start=1):
                    cur.execute(
                        "UPDATE task_queue SET position = %s WHERE id = %s",
                        (pos, tid),
                    )
            conn.commit()
        finally:
            conn.close()

        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete(self, task_id: str, site_id: int) -> bool:
        """Delete a pending task.

        Args:
            task_id: Task ID.
            site_id: Site (store) ID.

        Returns:
            True if a row was deleted. False if the task does not exist,
            belongs to another site, or is not pending.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM task_queue
                    WHERE id = %s AND site_id = %s AND status = 'pending'
                    """,
                    (task_id, site_id),
                )
                deleted = cur.rowcount > 0
            conn.commit()
        finally:
            conn.close()

        if deleted:
            logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
        else:
            logger.warning(
                "任务删除失败 [%s] site_id=%s(不存在或非 pending)",
                task_id, site_id,
            )
        return deleted

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """Return every pending task of a site, ordered by ascending position."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return [self._row_to_task(r) for r in rows]

    def has_running(self, site_id: int) -> bool:
        """Return True if the site has at least one task with status 'running'."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM task_queue
                        WHERE site_id = %s AND status = 'running'
                    )
                    """,
                    (site_id,),
                )
                result = cur.fetchone()[0]
            conn.commit()
        finally:
            conn.close()
        return bool(result)

    # ------------------------------------------------------------------
    # Background processing loop
    # ------------------------------------------------------------------

    async def process_loop(self) -> None:
        """Background coroutine that starts queued tasks while the loop is running.

        Each iteration:
        1. Finds every site_id that has pending tasks.
        2. For each such site with no running task, dequeues and starts one.
        3. Sleeps POLL_INTERVAL_SECONDS and repeats.

        Exceptions raised by an iteration are logged and the loop continues.
        """
        # Imported lazily to avoid a circular import.
        from .task_executor import task_executor

        self._running = True
        logger.info("TaskQueue process_loop 启动")

        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                logger.exception("process_loop 迭代异常")
            await asyncio.sleep(POLL_INTERVAL_SECONDS)

        logger.info("TaskQueue process_loop 停止")

    async def _process_once(self, executor: Any) -> None:
        """Run one scan: start the next task for every idle site with pending work."""
        # NOTE: the DB helpers used here are synchronous and run on the
        # event-loop thread.
        for site_id in self._get_pending_site_ids():
            if self.has_running(site_id):
                continue

            task = self.dequeue(site_id)
            if task is None:
                continue

            try:
                config = TaskConfigSchema(**task.config)
            except Exception:
                # dequeue() has already marked the row 'running'. If it were
                # left that way, has_running() would block this site forever.
                logger.exception("队列任务配置解析失败 [%s] site_id=%s", task.id, site_id)
                self._mark_failed(task.id, "任务配置解析失败")
                continue

            execution_id = str(uuid.uuid4())
            logger.info(
                "process_loop 自动执行 [%s] queue_id=%s site_id=%s",
                execution_id, task.id, site_id,
            )

            # Run the task without blocking this loop. Keep a strong reference
            # until it completes so it cannot be garbage-collected mid-run.
            runner = asyncio.create_task(
                self._execute_and_update(
                    executor, config, execution_id, task.id, site_id,
                )
            )
            self._background_tasks.add(runner)
            runner.add_done_callback(self._background_tasks.discard)

    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
    ) -> None:
        """Execute one queued task, then sync its outcome back into task_queue.

        Any exception from the executor or the status sync is logged, and the
        queue row is marked failed.
        """
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
            )
            # Copy the outcome recorded in task_execution_log into task_queue.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            self._mark_failed(queue_id, "执行过程中发生未捕获异常")

    def _get_pending_site_ids(self) -> list[int]:
        """Return the distinct site_ids that currently have pending tasks."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT DISTINCT site_id FROM task_queue
                    WHERE status = 'pending'
                    """
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [r[0] for r in rows]

    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest task_execution_log result for queue_id into task_queue.

        If no log row exists and the queue row is still 'running', the row is
        marked failed. Otherwise it would stay 'running' forever and block
        every later task of its site.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row is None:
                    # Guard on status = 'running' so a row that something else
                    # (possibly the executor) already finalized is left alone.
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = 'failed', finished_at = NOW(),
                            error_message = %s
                        WHERE id = %s AND status = 'running'
                        """,
                        ("执行完成但未找到执行日志", queue_id),
                    )
                    if cur.rowcount > 0:
                        logger.warning("未找到执行日志,队列任务标记为 failed [%s]", queue_id)
                else:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
            conn.commit()
        finally:
            conn.close()

    def _mark_failed(self, queue_id: str, error_message: str) -> None:
        """Mark a queue task as failed, stamping finished_at and the message."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s
                    """,
                    (error_message, queue_id),
                )
            conn.commit()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> None:
        """Start the background loop; call from within a running event loop.

        Intended to be called from the FastAPI lifespan. Does nothing if the
        loop task already exists and has not finished.
        """
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")

    async def stop(self) -> None:
        """Stop the background loop and wait for its cancellation to finish.

        Execution tasks that were already started are not cancelled.
        """
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")


# Module-level singleton.
task_queue = TaskQueue()