600 lines
21 KiB
Python
600 lines
21 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""任务队列服务
|
||
|
||
基于 PostgreSQL task_queue 表实现 FIFO 队列,支持:
|
||
- enqueue:入队,自动分配 position(当前最大 + 1)
|
||
- dequeue:取出 position 最小的 pending 任务
|
||
- reorder:调整任务在队列中的位置
|
||
- delete:删除 pending 任务
|
||
- process_loop:后台协程,队列非空且无运行中任务时自动取出执行
|
||
|
||
所有操作按 site_id 过滤,实现门店隔离。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import platform
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
|
||
from ..database import get_connection
|
||
from ..schemas.tasks import TaskConfigSchema
|
||
|
||
logger = logging.getLogger(name=__name__)

# Instance identity used to isolate tasks when several backend instances share
# one database. Background (CHANGE 2026-03-07): a backend on another machine was
# found consuming the same task_queue table and executing tasks it did not own.
# The value is stored in the enqueued_by column so that whoever enqueues a task
# is also the one that consumes it.
_INSTANCE_ID: str = platform.node()

# Sleep between iterations of the background processing loop, in seconds.
POLL_INTERVAL_SECONDS: int = 2
|
||
|
||
|
||
@dataclass
class QueuedTask:
    """In-memory view of one ``task_queue`` row.

    The field order defines the positional constructor order; keep it stable.
    """

    # Queue task ID (UUID rendered as a string).
    id: str
    # Store the task belongs to (per-store isolation key).
    site_id: int
    # Decoded task configuration (a JSON object).
    config: dict[str, Any]
    # Queue status, e.g. 'pending', 'running' or 'failed'.
    status: str
    # Ordering key inside the store's pending queue; smaller values run first.
    position: int
    # Timestamps exactly as returned by the DB driver; None while unset.
    created_at: Any = field(default=None)
    started_at: Any = field(default=None)
    finished_at: Any = field(default=None)
    # Exit code / error text of the execution, once known.
    exit_code: int | None = field(default=None)
    error_message: str | None = field(default=None)
    # ID of the schedule that enqueued this task, if any.
    schedule_id: str | None = field(default=None)
|
||
|
||
|
||
class TaskQueue:
    """PostgreSQL-backed FIFO task queue.

    Every operation is scoped by ``site_id`` (per-store isolation). Consumption
    (``dequeue``, ``list_pending``, ``has_running`` and the background loop) is
    further limited to rows whose ``enqueued_by`` equals this instance's
    ``_INSTANCE_ID`` or is NULL, so several backend instances can share one
    database without executing each other's tasks.

    NOTE(review): the DB helpers are synchronous and are called directly from
    the ``process_loop`` coroutine, so every poll performs blocking database I/O
    on the event-loop thread.
    """

    def __init__(self) -> None:
        self._running = False
        self._loop_task: asyncio.Task | None = None
        # Strong references to fire-and-forget execution tasks. The event loop
        # keeps only weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._inflight_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Enqueue
    # ------------------------------------------------------------------

    def enqueue(self, config: TaskConfigSchema, site_id: int, schedule_id: str | None = None) -> str:
        """Append a task to the end of a store's pending queue.

        Args:
            config: task configuration, stored as ``model_dump(mode="json")``.
            site_id: store ID (per-store isolation).
            schedule_id: optional ID of the schedule that produced the task.

        Returns:
            The new queue task ID (UUID string).
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")

        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Place the task after the largest pending position of this store.
                # Two concurrent enqueues can read the same MAX and receive equal
                # positions; readers break such ties by created_at.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                new_pos = cur.fetchone()[0] + 1

                # enqueued_by records which instance owns (and will consume) the task.
                cur.execute(
                    """
                    INSERT INTO task_queue (id, site_id, config, status, position, schedule_id, enqueued_by)
                    VALUES (%s, %s, %s, 'pending', %s, %s, %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos, schedule_id, _INSTANCE_ID),
                )
                conn.commit()
        finally:
            conn.close()

        logger.info("任务入队 [%s] site_id=%s position=%s schedule_id=%s", task_id, site_id, new_pos, schedule_id)
        return task_id

    # ------------------------------------------------------------------
    # Row mapping
    # ------------------------------------------------------------------

    @staticmethod
    def _row_to_task(row: Any) -> QueuedTask:
        """Build a QueuedTask from a row of the 11 columns selected by dequeue/list_pending.

        Expected column order: id, site_id, config, status, position,
        created_at, started_at, finished_at, exit_code, error_message, schedule_id.
        """
        raw_config = row[2]
        return QueuedTask(
            id=str(row[0]),
            site_id=row[1],
            # The driver may already decode the config column to a dict; otherwise parse JSON text.
            config=raw_config if isinstance(raw_config, dict) else json.loads(raw_config),
            status=row[3],
            position=row[4],
            created_at=row[5],
            started_at=row[6],
            finished_at=row[7],
            exit_code=row[8],
            error_message=row[9],
            schedule_id=str(row[10]) if row[10] else None,
        )

    # ------------------------------------------------------------------
    # Dequeue
    # ------------------------------------------------------------------

    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Claim the pending task with the smallest position and mark it running.

        Args:
            site_id: store ID.

        Returns:
            The claimed QueuedTask (status "running"), or None when this
            instance has no pending task for the store.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Only consume tasks enqueued by this instance, or rows whose
                # enqueued_by is NULL (presumably enqueued before the column
                # existed). SKIP LOCKED passes over rows another transaction
                # is already claiming.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message, schedule_id
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    ORDER BY position ASC, created_at ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id, _INSTANCE_ID),
                )
                row = cur.fetchone()
                if row is None:
                    conn.commit()
                    return None

                task = self._row_to_task(row)

                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
                conn.commit()
        finally:
            conn.close()

        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task

    # ------------------------------------------------------------------
    # Reorder
    # ------------------------------------------------------------------

    @staticmethod
    def _reordered_ids(task_ids: list[str], task_id: str, new_position: int) -> list[str] | None:
        """Return a copy of *task_ids* with *task_id* moved to 1-based *new_position*.

        Positions outside 1..len are clamped to the head or tail. Returns None
        when *task_id* is not in the list. The input list is not modified.
        """
        if task_id not in task_ids:
            return None
        remaining = [tid for tid in task_ids if tid != task_id]
        insert_idx = max(0, min(new_position - 1, len(remaining)))
        remaining.insert(insert_idx, task_id)
        return remaining

    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a pending task to a new position in its store's queue.

        Only pending tasks can be moved. The target task is placed at
        *new_position*; all other pending tasks keep their relative order and
        every pending task of the store is renumbered 1..n.

        Args:
            task_id: ID of the task to move.
            new_position: target position (1-based, clamped to the valid range).
            site_id: store ID.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # NOTE(review): unlike dequeue/list_pending this is not filtered
                # by enqueued_by, so other instances' pending tasks of the same
                # store are renumbered too — confirm this is intended.
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC, created_at ASC
                    """,
                    (site_id,),
                )
                current = [str(r[0]) for r in cur.fetchall()]

                new_order = self._reordered_ids(current, task_id, new_position)
                if new_order is None:
                    # Not pending or not in this store: nothing to do.
                    conn.commit()
                    return

                # Contiguous 1-based renumbering in the new order.
                cur.executemany(
                    "UPDATE task_queue SET position = %s WHERE id = %s",
                    list(enumerate(new_order, start=1)),
                )
                conn.commit()
        finally:
            conn.close()

        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete(self, task_id: str, site_id: int) -> bool:
        """Delete a pending task.

        Args:
            task_id: task ID.
            site_id: store ID.

        Returns:
            True if a row was deleted; False if the task does not exist, belongs
            to another store, or is not pending.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM task_queue
                    WHERE id = %s AND site_id = %s AND status = 'pending'
                    """,
                    (task_id, site_id),
                )
                deleted = cur.rowcount > 0
                conn.commit()
        finally:
            conn.close()

        if deleted:
            logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
        else:
            logger.warning(
                "任务删除失败 [%s] site_id=%s(不存在或非 pending)",
                task_id, site_id,
            )
        return deleted

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """List a store's pending tasks owned by this instance, ordered by position.

        Rows with a NULL enqueued_by are included as well.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # schedule_id is selected so listed tasks carry it like dequeue results do.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message, schedule_id
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    ORDER BY position ASC, created_at ASC
                    """,
                    (site_id, _INSTANCE_ID),
                )
                rows = cur.fetchall()
                conn.commit()
        finally:
            conn.close()

        return [self._row_to_task(r) for r in rows]

    def has_running(self, site_id: int) -> bool:
        """Return True if this instance has a running task for the store."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM task_queue
                        WHERE site_id = %s AND status = 'running'
                          AND (enqueued_by = %s OR enqueued_by IS NULL)
                    )
                    """,
                    (site_id, _INSTANCE_ID),
                )
                result = cur.fetchone()[0]
                conn.commit()
        finally:
            conn.close()
        return result

    # ------------------------------------------------------------------
    # Background processing loop
    # ------------------------------------------------------------------

    async def process_loop(self) -> None:
        """Background coroutine that starts queued tasks automatically.

        Each iteration:
        1. find every site_id with pending tasks owned by this instance;
        2. for each such store without a running task, dequeue one task and
           start executing it;
        3. sleep POLL_INTERVAL_SECONDS and repeat until ``stop()``.
        """
        # Deferred import to avoid a circular dependency.
        from .task_executor import task_executor

        self._running = True
        logger.info(
            "TaskQueue process_loop 启动 (instance_id=%s,仅消费本实例入队的任务)",
            _INSTANCE_ID,
        )

        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                # One failed iteration must not kill the loop.
                logger.exception("process_loop 迭代异常")

            await asyncio.sleep(POLL_INTERVAL_SECONDS)

        logger.info("TaskQueue process_loop 停止")

    def _spawn(self, coro: Any) -> asyncio.Task:
        """Schedule *coro* as a task and hold a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        self._inflight_tasks.add(task)
        task.add_done_callback(self._inflight_tasks.discard)
        return task

    async def _process_once(self, executor: Any) -> None:
        """Run one pass: scan every store's pending queue and start eligible tasks."""
        # Reclaim zombie running tasks first so they do not block their store.
        self._recover_zombie_tasks()

        for site_id in self._get_pending_site_ids():
            # At most one running task per store for this instance.
            if self.has_running(site_id):
                continue

            task = self.dequeue(site_id)
            if task is None:
                continue

            try:
                config = TaskConfigSchema(**task.config)
            except Exception as exc:
                # dequeue already set the row to running; without this the store's
                # queue would stay blocked until zombie recovery fails the task.
                logger.exception("队列任务配置无效,已标记 failed [%s] site_id=%s", task.id, site_id)
                self._mark_failed(task.id, f"任务配置无效: {exc}")
                continue

            execution_id = str(uuid.uuid4())
            logger.info(
                "process_loop 自动执行 [%s] queue_id=%s site_id=%s",
                execution_id, task.id, site_id,
            )

            # Fire-and-forget so the loop is not blocked by the execution.
            self._spawn(
                self._execute_and_update(
                    executor, config, execution_id, task.id, site_id,
                    schedule_id=task.schedule_id,
                )
            )

    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
        schedule_id: str | None = None,
    ) -> None:
        """Run the task through *executor*, then sync its final status into task_queue."""
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
                schedule_id=schedule_id,
            )
            # Copy the executor's result from task_execution_log into task_queue.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            try:
                self._mark_failed(queue_id, "执行过程中发生未捕获异常")
            except Exception:
                # Keep the exception inside this fire-and-forget task; the
                # finally-block fallback below still runs.
                logger.exception("_mark_failed 异常 [%s]", queue_id)
        finally:
            # Fallback (CHANGE 2026-03-09): guarantee the row never stays running.
            # If updating task_execution_log failed internally (e.g. a
            # duration_ms integer overflow that was swallowed), the log still
            # says running and the sync above would keep task_queue stuck,
            # blocking every later task of the store.
            self._ensure_not_stuck_running(queue_id)

    def _get_pending_site_ids(self) -> list[int]:
        """Return the distinct site_ids with pending tasks owned by this instance."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT DISTINCT site_id FROM task_queue
                    WHERE status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    """,
                    (_INSTANCE_ID,),
                )
                rows = cur.fetchall()
                conn.commit()
        finally:
            conn.close()
        return [r[0] for r in rows]

    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest task_execution_log result for *queue_id* into task_queue.

        Does nothing when no execution log row exists for the queue task.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
                conn.commit()
        finally:
            conn.close()

    def _mark_failed(self, queue_id: str, error_message: str) -> None:
        """Mark a queue task as failed with *error_message* and finished_at = NOW()."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s
                    """,
                    (error_message, queue_id),
                )
                conn.commit()
        finally:
            conn.close()

    def _ensure_not_stuck_running(self, queue_id: str) -> None:
        """Fallback: force a task that is still 'running' after execution to 'failed'.

        CHANGE 2026-03-09: prevents an internal failure while updating the
        execution log from leaving task_queue stuck in 'running'. Errors are
        logged and swallowed because this runs in a finally block.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # Atomic check-and-set: only a row that is still running is touched.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s AND status = 'running'
                    RETURNING id
                    """,
                    (
                        "[兜底修正] 执行流程结束但状态未同步,"
                        "可能因 execution_log 更新失败",
                        queue_id,
                    ),
                )
                if cur.fetchone() is not None:
                    logger.warning(
                        "兜底修正:task_queue [%s] 执行完毕但仍为 running,"
                        "强制标记 failed",
                        queue_id,
                    )
                conn.commit()
        except Exception:
            logger.exception("_ensure_not_stuck_running 异常 [%s]", queue_id)
        finally:
            conn.close()

    def _recover_zombie_tasks(self, max_running_minutes: int = 180) -> None:
        """Fail this instance's tasks that have been running longer than a threshold.

        CHANGE 2026-03-09: called on every poll as the last line of defence.
        After a backend crash or restart, the previous running tasks are never
        updated otherwise. Errors are logged and swallowed.

        Args:
            max_running_minutes: age threshold in minutes since started_at.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # The threshold is a real bind parameter multiplied by a one-minute
                # interval. A placeholder inside a quoted literal ('%s minutes')
                # only works with client-side interpolation.
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE status = 'running'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                      AND started_at < NOW() - %s * INTERVAL '1 minute'
                    RETURNING id
                    """,
                    (
                        f"[僵尸回收] running 超过 {max_running_minutes} 分钟,"
                        "自动标记 failed",
                        _INSTANCE_ID,
                        max_running_minutes,
                    ),
                )
                recovered = cur.fetchall()
                if recovered:
                    ids = [r[0] for r in recovered]
                    logger.warning(
                        "僵尸回收:%d 个 running 任务超时,已标记 failed: %s",
                        len(ids), ids,
                    )
                conn.commit()
        except Exception:
            logger.exception("_recover_zombie_tasks 异常")
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> None:
        """Start the background loop (called from the FastAPI lifespan); no-op if already running."""
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")

    async def stop(self) -> None:
        """Stop the background loop and wait for it to finish.

        Executions already spawned by the loop are not cancelled.
        """
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")
|
||
|
||
|
||
# Module-wide shared queue instance.
task_queue: TaskQueue = TaskQueue()
|