Files
Neo-ZQYY/apps/backend/app/services/task_queue.py

512 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
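"""Task queue service.

A FIFO queue backed by the PostgreSQL ``task_queue`` table. Supported operations:

- enqueue: add a task and automatically assign its position (current maximum + 1)
- dequeue: take the pending task with the smallest position
- reorder: change a task's position in the queue
- delete: remove a pending task
- process_loop: background coroutine that automatically takes and runs a task
  when the queue is non-empty and no task is running

Every operation filters by site_id, which isolates stores from one another.
"""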
from __future__ import annotations
import asyncio
import json
import logging
import platform
import uuid
from dataclasses import dataclass, field
from typing import Any
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
logger = logging.getLogger(__name__)
# CHANGE 2026-03-07 | Instance identifier: isolates tasks when several backend
# instances share the same database.
# Background: a backend on another machine (drive D of the host machine) was
# found consuming the same task_queue, so tasks were executed by the wrong
# instance. The enqueued_by column implements "whoever enqueues a task
# consumes it".
# platform.node() is the machine's network name (may be an empty string when
# it cannot be determined).
_INSTANCE_ID = platform.node()
# Polling interval of the background loop, in seconds.
POLL_INTERVAL_SECONDS = 2
@dataclass
class QueuedTask:
    """In-memory representation of one row of the ``task_queue`` table.

    Attributes:
        id: Queue row ID (UUID rendered as a string).
        site_id: Store ID the task belongs to.
        config: Task configuration decoded into a plain dict.
        status: Row status, e.g. ``"pending"`` or ``"running"``.
        position: Ordering key inside the site's queue; the smallest pending
            position is dequeued first.
        created_at: Creation timestamp as returned by the database driver.
        started_at: Start timestamp, or None when not started.
        finished_at: Finish timestamp, or None when not finished.
        exit_code: Exit code recorded for the task, if any.
        error_message: Error text recorded for the task, if any.
        schedule_id: ID of the associated schedule, if any.
    """

    # Required columns.
    id: str
    site_id: int
    config: dict[str, Any]
    status: str
    position: int

    # Optional columns; all default to None.
    created_at: Any = None
    started_at: Any = None
    finished_at: Any = None
    exit_code: int | None = None
    error_message: str | None = None
    schedule_id: str | None = None
class TaskQueue:
    """PostgreSQL-backed FIFO task queue with per-site isolation.

    Rows live in the ``task_queue`` table. Consumption (dequeue, listing,
    running checks and the background loop) is restricted to rows whose
    ``enqueued_by`` equals this instance's ID or is NULL, so that several
    backend instances sharing one database do not run each other's tasks.
    """

    def __init__(self) -> None:
        # Flag polled by process_loop; cleared by stop().
        self._running = False
        # Handle of the background polling coroutine created by start().
        self._loop_task: asyncio.Task | None = None
        # Strong references to in-flight execution tasks. The event loop only
        # keeps weak references to tasks, so a task nobody else references
        # can be garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _row_to_task(row: Any) -> QueuedTask:
        """Build a QueuedTask from a row in the canonical column order.

        Expected order: id, site_id, config, status, position, created_at,
        started_at, finished_at, exit_code, error_message, schedule_id.
        """
        raw_config = row[2]
        schedule_id = row[10]
        return QueuedTask(
            id=str(row[0]),
            site_id=row[1],
            # The driver may hand back an already-decoded dict or JSON text.
            config=raw_config if isinstance(raw_config, dict) else json.loads(raw_config),
            status=row[3],
            position=row[4],
            created_at=row[5],
            started_at=row[6],
            finished_at=row[7],
            exit_code=row[8],
            error_message=row[9],
            schedule_id=str(schedule_id) if schedule_id else None,
        )

    @staticmethod
    def _reordered_ids(task_ids: list[str], task_id: str, new_position: int) -> list[str]:
        """Return a copy of ``task_ids`` with ``task_id`` moved to ``new_position``.

        ``new_position`` is 1-based and is clamped to the valid range, so
        values below 1 move the task to the front and values past the end
        move it to the back. ``task_id`` must be present in ``task_ids``.
        """
        ordered = list(task_ids)
        ordered.remove(task_id)
        insert_idx = max(0, min(new_position - 1, len(ordered)))
        ordered.insert(insert_idx, task_id)
        return ordered

    # ------------------------------------------------------------------
    # Enqueue
    # ------------------------------------------------------------------
    def enqueue(self, config: TaskConfigSchema, site_id: int, schedule_id: str | None = None) -> str:
        """Append a task to the end of a site's pending queue.

        Args:
            config: Task configuration; stored as JSON in the ``config`` column.
            site_id: Store ID used for per-site isolation.
            schedule_id: ID of the associated schedule, if any.

        Returns:
            The ID of the new queue row (a UUID string).
        """
        task_id = str(uuid.uuid4())
        config_json = config.model_dump(mode="json")
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # The new task goes after the site's current last pending task.
                # NOTE(review): MAX(position) is read without a lock, so two
                # concurrent enqueues for the same site can get the same
                # position.
                cur.execute(
                    """
                    SELECT COALESCE(MAX(position), 0)
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    """,
                    (site_id,),
                )
                max_pos = cur.fetchone()[0]
                new_pos = max_pos + 1
                # CHANGE 2026-03-07 | Record enqueued_by for multi-instance
                # task isolation.
                cur.execute(
                    """
                    INSERT INTO task_queue (id, site_id, config, status, position, schedule_id, enqueued_by)
                    VALUES (%s, %s, %s, 'pending', %s, %s, %s)
                    """,
                    (task_id, site_id, json.dumps(config_json), new_pos, schedule_id, _INSTANCE_ID),
                )
            conn.commit()
        finally:
            conn.close()
        logger.info("任务入队 [%s] site_id=%s position=%s schedule_id=%s", task_id, site_id, new_pos, schedule_id)
        return task_id

    # ------------------------------------------------------------------
    # Dequeue
    # ------------------------------------------------------------------
    def dequeue(self, site_id: int) -> QueuedTask | None:
        """Take the pending task with the smallest position and mark it running.

        Args:
            site_id: Store ID.

        Returns:
            The dequeued task with ``status == "running"``, or None when this
            instance has no eligible pending task for the site.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only consume tasks enqueued by this
                # instance (or legacy rows with NULL enqueued_by), so that
                # instance A never runs tasks enqueued by instance B.
                # SKIP LOCKED: rows locked by another transaction are skipped
                # instead of waited on.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message, schedule_id
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    ORDER BY position ASC
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                    """,
                    (site_id, _INSTANCE_ID),
                )
                row = cur.fetchone()
                if row is None:
                    conn.commit()
                    return None
                task = self._row_to_task(row)
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'running', started_at = NOW()
                    WHERE id = %s
                    """,
                    (task.id,),
                )
            conn.commit()
        finally:
            conn.close()
        # Mirror the committed status change on the returned object.
        task.status = "running"
        logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
        return task

    # ------------------------------------------------------------------
    # Reorder
    # ------------------------------------------------------------------
    def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
        """Move a pending task to a new position in its site's queue.

        The target task is placed at ``new_position`` and all pending tasks
        of the site are renumbered 1..n, keeping the relative order of the
        others. Nothing is changed when the task is not a pending task of
        the site.

        Args:
            task_id: ID of the task to move.
            new_position: Target position (1-based); clamped to the queue bounds.
            site_id: Store ID.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT id FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                    ORDER BY position ASC
                    """,
                    (site_id,),
                )
                task_ids = [str(r[0]) for r in cur.fetchall()]
                if task_id not in task_ids:
                    conn.commit()
                    logger.warning(
                        "任务重排失败 [%s] site_id=%s(不存在或非 pending)",
                        task_id, site_id,
                    )
                    return
                ordered = self._reordered_ids(task_ids, task_id, new_position)
                # Renumber consecutively (1-based) in the new order.
                for idx, tid in enumerate(ordered, start=1):
                    cur.execute(
                        "UPDATE task_queue SET position = %s WHERE id = %s",
                        (idx, tid),
                    )
            conn.commit()
        finally:
            conn.close()
        logger.info(
            "任务重排 [%s] → position=%s site_id=%s",
            task_id, new_position, site_id,
        )

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------
    def delete(self, task_id: str, site_id: int) -> bool:
        """Delete a task that is still pending.

        Args:
            task_id: Task ID.
            site_id: Store ID.

        Returns:
            True when a row was deleted; False when the task does not exist
            for the site or is not pending.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM task_queue
                    WHERE id = %s AND site_id = %s AND status = 'pending'
                    """,
                    (task_id, site_id),
                )
                deleted = cur.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        if deleted:
            logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
        else:
            logger.warning(
                "任务删除失败 [%s] site_id=%s(不存在或非 pending",
                task_id, site_id,
            )
        return deleted

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def list_pending(self, site_id: int) -> list[QueuedTask]:
        """List a site's pending tasks consumable by this instance, by position."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only list pending tasks enqueued by this
                # instance. schedule_id is selected so listed tasks carry the
                # same fields as dequeued ones.
                cur.execute(
                    """
                    SELECT id, site_id, config, status, position,
                           created_at, started_at, finished_at,
                           exit_code, error_message, schedule_id
                    FROM task_queue
                    WHERE site_id = %s AND status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    ORDER BY position ASC
                    """,
                    (site_id, _INSTANCE_ID),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [self._row_to_task(r) for r in rows]

    def has_running(self, site_id: int) -> bool:
        """Return whether the site has a running task owned by this instance."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only consider this instance's running tasks.
                cur.execute(
                    """
                    SELECT EXISTS(
                        SELECT 1 FROM task_queue
                        WHERE site_id = %s AND status = 'running'
                          AND (enqueued_by = %s OR enqueued_by IS NULL)
                    )
                    """,
                    (site_id, _INSTANCE_ID),
                )
                result = cur.fetchone()[0]
            conn.commit()
        finally:
            conn.close()
        return result

    # ------------------------------------------------------------------
    # Background processing loop
    # ------------------------------------------------------------------
    async def process_loop(self) -> None:
        """Poll the queue and start tasks until stopped.

        Each iteration:
        1. finds every site_id that has pending tasks,
        2. for each site without a running task, dequeues and starts one,
        3. sleeps POLL_INTERVAL_SECONDS.

        An exception in one iteration is logged and does not end the loop.
        """
        # Imported lazily to avoid a circular dependency.
        from .task_executor import task_executor

        self._running = True
        logger.info(
            "TaskQueue process_loop 启动 (instance_id=%s,仅消费本实例入队的任务)",
            _INSTANCE_ID,
        )
        while self._running:
            try:
                await self._process_once(task_executor)
            except Exception:
                logger.exception("process_loop 迭代异常")
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
        logger.info("TaskQueue process_loop 停止")

    async def _process_once(self, executor: Any) -> None:
        """Run one scan: start at most one task per site that is idle.

        Args:
            executor: Object exposing an awaitable ``execute(...)`` method.
        """
        for site_id in self._get_pending_site_ids():
            if self.has_running(site_id):
                continue
            task = self.dequeue(site_id)
            if task is None:
                continue
            try:
                config = TaskConfigSchema(**task.config)
            except Exception as exc:
                # dequeue() already set the row to 'running'. Without marking
                # it failed here it would stay 'running' forever and
                # has_running() would block this site's queue; the raised
                # exception would also skip the remaining sites of this scan.
                logger.exception("队列任务配置无效 [%s] site_id=%s", task.id, site_id)
                self._mark_failed(task.id, f"任务配置无效: {exc}")
                continue
            execution_id = str(uuid.uuid4())
            logger.info(
                "process_loop 自动执行 [%s] queue_id=%s site_id=%s",
                execution_id, task.id, site_id,
            )
            # Start the execution without blocking the scan, and keep a
            # strong reference until it completes so it cannot be
            # garbage-collected mid-run.
            background = asyncio.create_task(
                self._execute_and_update(
                    executor, config, execution_id, task.id, site_id,
                    schedule_id=task.schedule_id,
                )
            )
            self._background_tasks.add(background)
            background.add_done_callback(self._background_tasks.discard)

    async def _execute_and_update(
        self,
        executor: Any,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str,
        site_id: int,
        schedule_id: str | None = None,
    ) -> None:
        """Run one task through the executor and sync the queue row afterwards.

        Any exception from the executor or from the status sync is logged and
        the queue row is marked failed.
        """
        try:
            await executor.execute(
                config=config,
                execution_id=execution_id,
                queue_id=queue_id,
                site_id=site_id,
                schedule_id=schedule_id,
            )
            # Copy the executor's recorded outcome onto the queue row.
            self._update_queue_status_from_log(queue_id)
        except Exception:
            logger.exception("队列任务执行异常 [%s]", queue_id)
            try:
                self._mark_failed(queue_id, "执行过程中发生未捕获异常")
            except Exception:
                # This coroutine runs as a detached task; log here so the
                # failure is not reported only as an unretrieved exception.
                logger.exception("标记任务失败时出错 [%s]", queue_id)

    def _get_pending_site_ids(self) -> list[int]:
        """Return the site_ids that have pending tasks consumable by this instance."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                # CHANGE 2026-03-07 | Only look at pending tasks enqueued by
                # this instance.
                cur.execute(
                    """
                    SELECT DISTINCT site_id FROM task_queue
                    WHERE status = 'pending'
                      AND (enqueued_by = %s OR enqueued_by IS NULL)
                    """,
                    (_INSTANCE_ID,),
                )
                rows = cur.fetchall()
            conn.commit()
        finally:
            conn.close()
        return [r[0] for r in rows]

    def _update_queue_status_from_log(self, queue_id: str) -> None:
        """Copy the latest ``task_execution_log`` result onto the queue row.

        The most recently started log row for ``queue_id`` supplies status,
        finished_at, exit_code and error text. When no log row exists the
        queue row is left unchanged and a warning is logged.
        """
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT status, finished_at, exit_code, error_log
                    FROM task_execution_log
                    WHERE queue_id = %s
                    ORDER BY started_at DESC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                row = cur.fetchone()
                if row:
                    cur.execute(
                        """
                        UPDATE task_queue
                        SET status = %s, finished_at = %s,
                            exit_code = %s, error_message = %s
                        WHERE id = %s
                        """,
                        (row[0], row[1], row[2], row[3], queue_id),
                    )
                else:
                    # The queue row keeps its current status in this case.
                    logger.warning("未找到执行日志 [%s],队列状态未同步", queue_id)
            conn.commit()
        finally:
            conn.close()

    def _mark_failed(self, queue_id: str, error_message: str) -> None:
        """Set a queue row to 'failed' with the given error message."""
        conn = get_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE task_queue
                    SET status = 'failed', finished_at = NOW(),
                        error_message = %s
                    WHERE id = %s
                    """,
                    (error_message, queue_id),
                )
            conn.commit()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Start the background loop (called from the FastAPI lifespan).

        Does nothing when a loop task already exists and has not finished.
        """
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self.process_loop())
            logger.info("TaskQueue 后台循环已启动")

    async def stop(self) -> None:
        """Stop the background loop and wait for its task to end.

        Executions already started by the loop are not cancelled.
        """
        self._running = False
        if self._loop_task and not self._loop_task.done():
            self._loop_task.cancel()
            try:
                await self._loop_task
            except asyncio.CancelledError:
                pass
        self._loop_task = None
        logger.info("TaskQueue 后台循环已停止")
# Global singleton instance.
task_queue = TaskQueue()