在准备环境前提交次全部更改。
This commit is contained in:
486
apps/backend/app/services/task_queue.py
Normal file
486
apps/backend/app/services/task_queue.py
Normal file
@@ -0,0 +1,486 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务队列服务
|
||||
|
||||
基于 PostgreSQL task_queue 表实现 FIFO 队列,支持:
|
||||
- enqueue:入队,自动分配 position(当前最大 + 1)
|
||||
- dequeue:取出 position 最小的 pending 任务
|
||||
- reorder:调整任务在队列中的位置
|
||||
- delete:删除 pending 任务
|
||||
- process_loop:后台协程,队列非空且无运行中任务时自动取出执行
|
||||
|
||||
所有操作按 site_id 过滤,实现门店隔离。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..database import get_connection
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 后台循环轮询间隔(秒)
|
||||
POLL_INTERVAL_SECONDS = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueuedTask:
|
||||
"""队列任务数据对象"""
|
||||
|
||||
id: str
|
||||
site_id: int
|
||||
config: dict[str, Any]
|
||||
status: str
|
||||
position: int
|
||||
created_at: Any = None
|
||||
started_at: Any = None
|
||||
finished_at: Any = None
|
||||
exit_code: int | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""基于 PostgreSQL 的任务队列"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._running = False
|
||||
self._loop_task: asyncio.Task | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 入队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def enqueue(self, config: TaskConfigSchema, site_id: int) -> str:
|
||||
"""将任务配置入队,自动分配 position。
|
||||
|
||||
Args:
|
||||
config: 任务配置
|
||||
site_id: 门店 ID(门店隔离)
|
||||
|
||||
Returns:
|
||||
新创建的队列任务 ID(UUID 字符串)
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
config_json = config.model_dump(mode="json")
|
||||
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 取当前该门店 pending 任务的最大 position,新任务排在末尾
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT COALESCE(MAX(position), 0)
|
||||
FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
max_pos = cur.fetchone()[0]
|
||||
new_pos = max_pos + 1
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO task_queue (id, site_id, config, status, position)
|
||||
VALUES (%s, %s, %s, 'pending', %s)
|
||||
""",
|
||||
(task_id, site_id, json.dumps(config_json), new_pos),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
logger.info("任务入队 [%s] site_id=%s position=%s", task_id, site_id, new_pos)
|
||||
return task_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 出队
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def dequeue(self, site_id: int) -> QueuedTask | None:
|
||||
"""取出 position 最小的 pending 任务,将其状态改为 running。
|
||||
|
||||
Args:
|
||||
site_id: 门店 ID
|
||||
|
||||
Returns:
|
||||
QueuedTask 或 None(队列为空时)
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 选取 position 最小的 pending 任务并锁定
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, config, status, position,
|
||||
created_at, started_at, finished_at,
|
||||
exit_code, error_message
|
||||
FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
ORDER BY position ASC
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
conn.commit()
|
||||
return None
|
||||
|
||||
task = QueuedTask(
|
||||
id=str(row[0]),
|
||||
site_id=row[1],
|
||||
config=row[2] if isinstance(row[2], dict) else json.loads(row[2]),
|
||||
status=row[3],
|
||||
position=row[4],
|
||||
created_at=row[5],
|
||||
started_at=row[6],
|
||||
finished_at=row[7],
|
||||
exit_code=row[8],
|
||||
error_message=row[9],
|
||||
)
|
||||
|
||||
# 更新状态为 running
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = 'running', started_at = NOW()
|
||||
WHERE id = %s
|
||||
""",
|
||||
(task.id,),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
task.status = "running"
|
||||
logger.info("任务出队 [%s] site_id=%s", task.id, site_id)
|
||||
return task
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 重排
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def reorder(self, task_id: str, new_position: int, site_id: int) -> None:
|
||||
"""调整任务在队列中的位置。
|
||||
|
||||
仅允许对 pending 状态的任务重排。将目标任务移到 new_position,
|
||||
其余 pending 任务按原有相对顺序重新编号。
|
||||
|
||||
Args:
|
||||
task_id: 要移动的任务 ID
|
||||
new_position: 目标位置(1-based)
|
||||
site_id: 门店 ID
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
# 获取该门店所有 pending 任务,按 position 排序
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
ORDER BY position ASC
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
task_ids = [str(r[0]) for r in rows]
|
||||
|
||||
if task_id not in task_ids:
|
||||
conn.commit()
|
||||
return
|
||||
|
||||
# 从列表中移除目标任务,再插入到新位置
|
||||
task_ids.remove(task_id)
|
||||
# new_position 是 1-based,转为 0-based 索引并 clamp
|
||||
insert_idx = max(0, min(new_position - 1, len(task_ids)))
|
||||
task_ids.insert(insert_idx, task_id)
|
||||
|
||||
# 按新顺序重新分配 position(1-based 连续编号)
|
||||
for idx, tid in enumerate(task_ids, start=1):
|
||||
cur.execute(
|
||||
"UPDATE task_queue SET position = %s WHERE id = %s",
|
||||
(idx, tid),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
logger.info(
|
||||
"任务重排 [%s] → position=%s site_id=%s",
|
||||
task_id, new_position, site_id,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 删除
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def delete(self, task_id: str, site_id: int) -> bool:
|
||||
"""删除 pending 状态的任务。
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
site_id: 门店 ID
|
||||
|
||||
Returns:
|
||||
True 表示成功删除,False 表示任务不存在或非 pending 状态。
|
||||
"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM task_queue
|
||||
WHERE id = %s AND site_id = %s AND status = 'pending'
|
||||
""",
|
||||
(task_id, site_id),
|
||||
)
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted:
|
||||
logger.info("任务删除 [%s] site_id=%s", task_id, site_id)
|
||||
else:
|
||||
logger.warning(
|
||||
"任务删除失败 [%s] site_id=%s(不存在或非 pending)",
|
||||
task_id, site_id,
|
||||
)
|
||||
return deleted
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_pending(self, site_id: int) -> list[QueuedTask]:
|
||||
"""列出指定门店的所有 pending 任务,按 position 升序。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, site_id, config, status, position,
|
||||
created_at, started_at, finished_at,
|
||||
exit_code, error_message
|
||||
FROM task_queue
|
||||
WHERE site_id = %s AND status = 'pending'
|
||||
ORDER BY position ASC
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
return [
|
||||
QueuedTask(
|
||||
id=str(r[0]),
|
||||
site_id=r[1],
|
||||
config=r[2] if isinstance(r[2], dict) else json.loads(r[2]),
|
||||
status=r[3],
|
||||
position=r[4],
|
||||
created_at=r[5],
|
||||
started_at=r[6],
|
||||
finished_at=r[7],
|
||||
exit_code=r[8],
|
||||
error_message=r[9],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
def has_running(self, site_id: int) -> bool:
|
||||
"""检查指定门店是否有 running 状态的任务。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM task_queue
|
||||
WHERE site_id = %s AND status = 'running'
|
||||
)
|
||||
""",
|
||||
(site_id,),
|
||||
)
|
||||
result = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 后台处理循环
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def process_loop(self) -> None:
|
||||
"""后台协程:队列非空且无运行中任务时,自动取出并执行。
|
||||
|
||||
循环逻辑:
|
||||
1. 查询所有有 pending 任务的 site_id
|
||||
2. 对每个 site_id,若无 running 任务则 dequeue 并执行
|
||||
3. 等待 POLL_INTERVAL_SECONDS 后重复
|
||||
"""
|
||||
# 延迟导入避免循环依赖
|
||||
from .task_executor import task_executor
|
||||
|
||||
self._running = True
|
||||
logger.info("TaskQueue process_loop 启动")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._process_once(task_executor)
|
||||
except Exception:
|
||||
logger.exception("process_loop 迭代异常")
|
||||
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
logger.info("TaskQueue process_loop 停止")
|
||||
|
||||
async def _process_once(self, executor: Any) -> None:
|
||||
"""单次处理:扫描所有门店的 pending 队列并执行。"""
|
||||
site_ids = self._get_pending_site_ids()
|
||||
|
||||
for site_id in site_ids:
|
||||
if self.has_running(site_id):
|
||||
continue
|
||||
|
||||
task = self.dequeue(site_id)
|
||||
if task is None:
|
||||
continue
|
||||
|
||||
config = TaskConfigSchema(**task.config)
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
logger.info(
|
||||
"process_loop 自动执行 [%s] queue_id=%s site_id=%s",
|
||||
execution_id, task.id, site_id,
|
||||
)
|
||||
|
||||
# 异步启动执行(不阻塞循环)
|
||||
asyncio.create_task(
|
||||
self._execute_and_update(
|
||||
executor, config, execution_id, task.id, site_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def _execute_and_update(
|
||||
self,
|
||||
executor: Any,
|
||||
config: TaskConfigSchema,
|
||||
execution_id: str,
|
||||
queue_id: str,
|
||||
site_id: int,
|
||||
) -> None:
|
||||
"""执行任务并更新队列状态。"""
|
||||
try:
|
||||
await executor.execute(
|
||||
config=config,
|
||||
execution_id=execution_id,
|
||||
queue_id=queue_id,
|
||||
site_id=site_id,
|
||||
)
|
||||
# 执行完成后根据 executor 的结果更新 task_queue 状态
|
||||
self._update_queue_status_from_log(queue_id)
|
||||
except Exception:
|
||||
logger.exception("队列任务执行异常 [%s]", queue_id)
|
||||
self._mark_failed(queue_id, "执行过程中发生未捕获异常")
|
||||
|
||||
def _get_pending_site_ids(self) -> list[int]:
|
||||
"""获取所有有 pending 任务的 site_id 列表。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT DISTINCT site_id FROM task_queue
|
||||
WHERE status = 'pending'
|
||||
"""
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def _update_queue_status_from_log(self, queue_id: str) -> None:
|
||||
"""从 task_execution_log 读取执行结果,同步到 task_queue 记录。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT status, finished_at, exit_code, error_log
|
||||
FROM task_execution_log
|
||||
WHERE queue_id = %s
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = %s, finished_at = %s,
|
||||
exit_code = %s, error_message = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(row[0], row[1], row[2], row[3], queue_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _mark_failed(self, queue_id: str, error_message: str) -> None:
|
||||
"""将队列任务标记为 failed。"""
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_queue
|
||||
SET status = 'failed', finished_at = NOW(),
|
||||
error_message = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(error_message, queue_id),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 生命周期
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self) -> None:
|
||||
"""启动后台处理循环(在 FastAPI lifespan 中调用)。"""
|
||||
if self._loop_task is None or self._loop_task.done():
|
||||
self._loop_task = asyncio.create_task(self.process_loop())
|
||||
logger.info("TaskQueue 后台循环已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止后台处理循环。"""
|
||||
self._running = False
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
self._loop_task.cancel()
|
||||
try:
|
||||
await self._loop_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._loop_task = None
|
||||
logger.info("TaskQueue 后台循环已停止")
|
||||
|
||||
|
||||
# 全局单例
|
||||
task_queue = TaskQueue()
|
||||
Reference in New Issue
Block a user