# -*- coding: utf-8 -*-
"""Task queue service.

A FIFO queue implemented on top of the PostgreSQL ``task_queue`` table.
Supported operations:

- enqueue: append a task, assigning ``position`` = current maximum + 1
- dequeue: take the pending task with the smallest ``position``
- reorder: move a task to a different position in the queue
- delete: remove a pending task
- process_loop: background coroutine that automatically dequeues and runs a
  task when the queue is non-empty and nothing is running

All operations filter by ``site_id`` so that stores are isolated from each
other.
"""

from __future__ import annotations

import asyncio
import json
import logging
import platform
import uuid
from dataclasses import dataclass, field
from typing import Any

from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema

logger = logging.getLogger(__name__)

# CHANGE 2026-03-07 | Instance identifier used to isolate tasks when several
# backend instances share the same database.
# Background: another machine (the backend on the host's D: drive) was found
# consuming the same task_queue, so tasks were executed by the wrong instance.
# The enqueued_by column implements "whoever enqueues it consumes it".
_INSTANCE_ID = platform.node()

# Polling interval of the background loop, in seconds.
POLL_INTERVAL_SECONDS = 2


@dataclass
class QueuedTask:
    """Data object for one row of the task queue."""

    id: str
    site_id: int
    config: dict[str, Any]
    status: str
    position: int
    created_at: Any = None
    started_at: Any = None
    finished_at: Any = None
    exit_code: int | None = None
    error_message: str | None = None
    schedule_id: str | None = None


def _row_to_task(row: Any) -> QueuedTask:
    """Build a QueuedTask from a task_queue row.

    The row must contain, in order: id, site_id, config, status, position,
    created_at, started_at, finished_at, exit_code, error_message,
    schedule_id.
    """
    # The config column is accepted either as an already-decoded dict or as a
    # JSON string (presumably depending on column type / driver settings).
    raw_config = row[2]
    config = raw_config if isinstance(raw_config, dict) else json.loads(raw_config)
    return QueuedTask(
        id=str(row[0]),
        site_id=row[1],
        config=config,
        status=row[3],
        position=row[4],
        created_at=row[5],
        started_at=row[6],
        finished_at=row[7],
        exit_code=row[8],
        error_message=row[9],
        schedule_id=str(row[10]) if row[10] else None,
    )


class TaskQueue:
    """Task queue backed by PostgreSQL."""

    def __init__(self) -> None:
        self._running = False
        self._loop_task: asyncio.Task | None = None
        # Strong references to in-flight execution tasks. The event loop only
        # keeps weak references to tasks, so without this set a running task
        # could be garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Enqueue
    # ------------------------------------------------------------------

    def enqueue(self, config: TaskConfigSchema, site_id: int, schedule_id: str | None = None) -> str:
        """Add a task configuration to the queue, assigning its position.

        Args:
            config: Task configuration.
            site_id: Store ID (store isolation).
            schedule_id: ID of the associated scheduled task (optional).

        Returns:
            ID of the newly created queue task (UUID string).
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Largest position among this store's pending tasks; the new
                # task goes to the end of the queue.
                # NOTE(review): SELECT MAX + INSERT is not atomic, so two
                # concurrent enqueues for one store may get the same position.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                max_pos = cur.fetchone()[0]
                new_pos = max_pos + 1

                # CHANGE 2026-03-07 | Write enqueued_by for multi-instance
                # task isolation.
                cur.execute(
                    """
                    INSERT INTO task_queue
                        (id, site_id, config, status, position, schedule_id, enqueued_by)
                    VALUES (%s, %s, %s, 'pending', %s, %s, %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos, schedule_id, _INSTANCE_ID),
                )
            conn.commit()
        finally:
            conn.close()

        logger.info("任务入队 [%s] site_id=%s position=%s schedule_id=%s", task_id, site_id, new_pos, schedule_id)
        return task_id

    # ------------------------------------------------------------------
    # Dequeue
    # ------------------------------------------------------------------

    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Take the pending task with the smallest position and mark it running.

        Args:
            site_id: Store ID.

        Returns:
            The dequeued QueuedTask, or None when the queue is empty.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only consume tasks enqueued by this
                # instance (enqueued_by matches, or is NULL).
                # Background: when several backend instances share one DB,
                # instance A must not consume tasks enqueued by instance B.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message, schedule_id
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    ORDER BY position ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id, _INSTANCE_ID),
                )
                row = cur.fetchone()
                if row is None:
                    conn.commit()
                    return None

                task = _row_to_task(row)

                # Flip the row to running inside the same transaction that
                # holds the row lock.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
            conn.commit()
        finally:
            conn.close()

        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task

    # ------------------------------------------------------------------
    # Reorder
    # ------------------------------------------------------------------

    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a task to a new position in the queue.

        Only pending tasks can be reordered. The target task is moved to
        new_position and the remaining pending tasks are renumbered while
        keeping their relative order. Nothing happens if the task is not a
        pending task of the given store.

        Args:
            task_id: ID of the task to move.
            new_position: Target position (1-based).
            site_id: Store ID.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # All pending tasks of this store, ordered by position.
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                rows = cur.fetchall()
                task_ids = [str(r[0]) for r in rows]

                if task_id not in task_ids:
                    conn.commit()
                    return

                # Remove the target from the list, then insert it at the new
                # place.
                task_ids.remove(task_id)
                # new_position is 1-based: convert to a 0-based index and
                # clamp it into the valid range.
                insert_idx = max(0, min(new_position - 1, len(task_ids)))
                task_ids.insert(insert_idx, task_id)

                # Reassign positions in the new order (1-based, contiguous).
                for idx, tid in enumerate(task_ids, start=1):
                    cur.execute(
                        "UPDATE task_queue SET position = %s WHERE id = %s",
                        (idx, tid),
                    )
            conn.commit()
        finally:
            conn.close()

        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete(self, task_id: str, site_id: int) -> bool:
        """Delete a task that is in pending state.

        Args:
            task_id: Task ID.
            site_id: Store ID.

        Returns:
            True if a row was deleted, False if the task does not exist or is
            not pending.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM task_queue
                    WHERE id = %s AND site_id = %s AND status = 'pending'
                    """,
                    (task_id, site_id),
                )
                deleted = cur.rowcount > 0
            conn.commit()
        finally:
            conn.close()

        if deleted:
            logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
        else:
            logger.warning(
                "任务删除失败 [%s] site_id=%s(不存在或非 pending)",
                task_id, site_id,
            )
        return deleted

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """List the store's pending tasks enqueued by this instance.

        Results are ordered by position ascending. schedule_id is included so
        the returned objects carry the same fields as those from dequeue().
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only list pending tasks enqueued by this
                # instance.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message, schedule_id
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    ORDER BY position ASC
                    """,
                    (site_id, _INSTANCE_ID),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()

        return [_row_to_task(r) for r in rows]

    def has_running(self, site_id: int) -> bool:
        """Return whether the store has a running task owned by this instance."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only check this instance's running tasks.
                cur.execute(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM task_queue
                        WHERE site_id = %s AND status = 'running'
                          AND (enqueued_by = %s OR enqueued_by IS NULL)
                    )
                    """,
                    (site_id, _INSTANCE_ID),
                )
                result = cur.fetchone()[0]
            conn.commit()
        finally:
            conn.close()
        return result

    # ------------------------------------------------------------------
    # Background processing loop
    # ------------------------------------------------------------------

    async def process_loop(self) -> None:
        """Background coroutine that dequeues and runs tasks automatically.

        Loop logic:
        1. Find every site_id that has pending tasks.
        2. For each site_id with no running task, dequeue one task and run it.
        3. Sleep POLL_INTERVAL_SECONDS and repeat.
        """
        # Imported lazily to avoid a circular dependency.
        from .task_executor import task_executor

        self._running = True
        logger.info(
            "TaskQueue process_loop 启动 (instance_id=%s,仅消费本实例入队的任务)",
            _INSTANCE_ID,
        )

        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                # Top-level boundary: one failed iteration must not kill the
                # loop.
                logger.exception("process_loop 迭代异常")

            await asyncio.sleep(POLL_INTERVAL_SECONDS)

        logger.info("TaskQueue process_loop 停止")

    async def _process_once(self, executor: Any) -> None:
        """Run one polling pass over every store with pending tasks.

        NOTE(review): the DB helpers called here are synchronous, so each
        pass blocks the event loop for the duration of its queries.
        """
        # CHANGE 2026-03-09 | Reclaim zombie running tasks before each pass.
        self._recover_zombie_tasks()

        for site_id in self._get_pending_site_ids():
            if self.has_running(site_id):
                continue

            task = self.dequeue(site_id)
            if task is None:
                continue

            try:
                config = TaskConfigSchema(**task.config)
            except Exception as exc:
                # The row is already 'running'. Without this it would stay
                # stuck until zombie recovery and the exception would abort
                # the pass for the remaining stores.
                logger.exception(
                    "队列任务配置无效 [%s] site_id=%s", task.id, site_id,
                )
                self._mark_failed(task.id, f"任务配置无效: {exc}")
                continue

            execution_id = str(uuid.uuid4())

            logger.info(
                "process_loop 自动执行 [%s] queue_id=%s site_id=%s",
                execution_id, task.id, site_id,
            )

            # Start execution asynchronously so the loop is not blocked, and
            # keep a strong reference until the task completes.
            bg_task = asyncio.create_task(
                self._execute_and_update(
                    executor, config, execution_id, task.id, site_id,
                    schedule_id=task.schedule_id,
                )
            )
            self._background_tasks.add(bg_task)
            bg_task.add_done_callback(self._background_tasks.discard)

    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
        schedule_id: str | None = None,
    ) -> None:
        """Execute a task and synchronise the queue row with the outcome."""
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
                schedule_id=schedule_id,
            )
            # After execution, copy the executor's result into task_queue.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            self._mark_failed(queue_id, "执行过程中发生未捕获异常")
        finally:
            # CHANGE 2026-03-09 | Safety net: make sure task_queue never stays
            # in running.
            # Background: when _update_execution_log fails internally (e.g.
            # duration_ms integer overflow) and the error is swallowed,
            # _update_queue_status_from_log still reads 'running' from
            # execution_log, so task_queue stays stuck forever and every
            # later task is left waiting.
            self._ensure_not_stuck_running(queue_id)

    def _get_pending_site_ids(self) -> list[int]:
        """Return the site_ids that have pending tasks enqueued by this instance."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only look at pending tasks enqueued by
                # this instance.
                cur.execute(
                    """
                    SELECT DISTINCT site_id FROM task_queue
                    WHERE status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    """,
                    (_INSTANCE_ID,),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [r[0] for r in rows]

    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest task_execution_log result onto the task_queue row.

        Does nothing if no execution log row exists for queue_id.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
            conn.commit()
        finally:
            conn.close()

    def _mark_failed(self, queue_id: str, error_message: str) -> None:
        """Mark a queue task as failed with the given error message."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s
                    """,
                    (error_message, queue_id),
                )
            conn.commit()
        finally:
            conn.close()

    def _ensure_not_stuck_running(self, queue_id: str) -> None:
        """Safety net: force a task that is still running to failed.

        CHANGE 2026-03-09 | Prevents an internal _update_execution_log failure
        from leaving task_queue stuck in running forever. Database errors are
        logged and swallowed here.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT status FROM task_queue WHERE id = %s",
                    (queue_id,),
                )
                row = cur.fetchone()
                if row and row[0] == "running":
                    logger.warning(
                        "兜底修正:task_queue [%s] 执行完毕但仍为 running,"
                        "强制标记 failed",
                        queue_id,
                    )
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = 'failed', finished_at = NOW(),
                            error_message = %s
                        WHERE id = %s AND status = 'running'
                        """,
                        (
                            "[兜底修正] 执行流程结束但状态未同步,"
                            "可能因 execution_log 更新失败",
                            queue_id,
                        ),
                    )
            conn.commit()
        except Exception:
            logger.exception("_ensure_not_stuck_running 异常 [%s]", queue_id)
        finally:
            conn.close()

    def _recover_zombie_tasks(self, max_running_minutes: int = 180) -> None:
        """Force tasks that have been running longer than the threshold to failed.

        CHANGE 2026-03-09 | Called on every process_loop pass as the last line
        of defence. Scenario: after the backend process crashes or restarts,
        tasks that were running would otherwise never be updated.

        Args:
            max_running_minutes: Age in minutes after which a running task of
                this instance is treated as a zombie.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # The placeholder must not sit inside a quoted SQL literal, so
                # the threshold is passed as a number and multiplied by a
                # one-minute interval.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE status = 'running'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                      AND started_at < NOW() - %s * INTERVAL '1 minute'
                    RETURNING id
                    """,
                    (
                        f"[僵尸回收] running 超过 {max_running_minutes} 分钟,"
                        "自动标记 failed",
                        _INSTANCE_ID,
                        max_running_minutes,
                    ),
                )
                recovered = cur.fetchall()
                if recovered:
                    ids = [r[0] for r in recovered]
                    logger.warning(
                        "僵尸回收:%d 个 running 任务超时,已标记 failed: %s",
                        len(ids), ids,
                    )
            conn.commit()
        except Exception:
            logger.exception("_recover_zombie_tasks 异常")
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> None:
        """Start the background processing loop (called from the FastAPI lifespan).

        Does nothing if the loop task already exists and has not finished.
        """
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")

    async def stop(self) -> None:
        """Stop the background processing loop and wait for it to be cancelled."""
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")


# Global singleton
task_queue = TaskQueue()