在准备环境前提交次全部更改。
This commit is contained in:
391
apps/backend/app/services/task_executor.py
Normal file
391
apps/backend/app/services/task_executor.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL 任务执行器
|
||||
|
||||
通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程,
|
||||
逐行读取 stdout/stderr 并广播到 WebSocket 订阅者,
|
||||
执行完成后将结果写入 task_execution_log 表。
|
||||
|
||||
设计要点:
|
||||
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
|
||||
- 日志行存储在内存缓冲区 _log_buffers 中
|
||||
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
|
||||
- Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..config import ETL_PROJECT_PATH
|
||||
from ..database import get_connection
|
||||
from ..schemas.tasks import TaskConfigSchema
|
||||
from ..services.cli_builder import cli_builder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskExecutor:
|
||||
"""管理 ETL CLI 子进程的生命周期"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# execution_id → subprocess.Popen
|
||||
self._processes: dict[str, subprocess.Popen] = {}
|
||||
# execution_id → list[str](stdout + stderr 混合日志)
|
||||
self._log_buffers: dict[str, list[str]] = {}
|
||||
# execution_id → set[asyncio.Queue](WebSocket 订阅者)
|
||||
self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket 订阅管理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
|
||||
"""注册一个 WebSocket 订阅者,返回用于读取日志行的 Queue。
|
||||
|
||||
Queue 中推送 str 表示日志行,None 表示执行结束。
|
||||
"""
|
||||
if execution_id not in self._subscribers:
|
||||
self._subscribers[execution_id] = set()
|
||||
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._subscribers[execution_id].add(queue)
|
||||
return queue
|
||||
|
||||
def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
|
||||
"""移除一个 WebSocket 订阅者。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
subs.discard(queue)
|
||||
if not subs:
|
||||
del self._subscribers[execution_id]
|
||||
|
||||
def _broadcast(self, execution_id: str, line: str) -> None:
|
||||
"""向所有订阅者广播一行日志。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
for q in subs:
|
||||
q.put_nowait(line)
|
||||
|
||||
def _broadcast_end(self, execution_id: str) -> None:
|
||||
"""通知所有订阅者执行已结束(发送 None 哨兵)。"""
|
||||
subs = self._subscribers.get(execution_id)
|
||||
if subs:
|
||||
for q in subs:
|
||||
q.put_nowait(None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 日志缓冲区
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_logs(self, execution_id: str) -> list[str]:
|
||||
"""获取指定执行的内存日志缓冲区(副本)。"""
|
||||
return list(self._log_buffers.get(execution_id, []))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 执行状态查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_running(self, execution_id: str) -> bool:
|
||||
"""判断指定执行是否仍在运行。"""
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc is None:
|
||||
return False
|
||||
return proc.poll() is None
|
||||
|
||||
def get_running_ids(self) -> list[str]:
|
||||
"""返回当前所有运行中的 execution_id 列表。"""
|
||||
return [eid for eid, p in self._processes.items() if p.returncode is None]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心执行
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
config: TaskConfigSchema,
|
||||
execution_id: str,
|
||||
queue_id: str | None = None,
|
||||
site_id: int | None = None,
|
||||
) -> None:
|
||||
"""以子进程方式调用 ETL CLI。
|
||||
|
||||
使用 subprocess.Popen + 线程读取,兼容 Windows(避免
|
||||
asyncio.create_subprocess_exec 在 Windows 上的 NotImplementedError)。
|
||||
"""
|
||||
cmd = cli_builder.build_command(
|
||||
config, ETL_PROJECT_PATH, python_executable=sys.executable
|
||||
)
|
||||
command_str = " ".join(cmd)
|
||||
effective_site_id = site_id or config.store_id
|
||||
|
||||
logger.info(
|
||||
"启动 ETL 子进程 [%s]: %s (cwd=%s)",
|
||||
execution_id, command_str, ETL_PROJECT_PATH,
|
||||
)
|
||||
|
||||
self._log_buffers[execution_id] = []
|
||||
started_at = datetime.now(timezone.utc)
|
||||
t0 = time.monotonic()
|
||||
|
||||
self._write_execution_log(
|
||||
execution_id=execution_id,
|
||||
queue_id=queue_id,
|
||||
site_id=effective_site_id,
|
||||
task_codes=config.tasks,
|
||||
status="running",
|
||||
started_at=started_at,
|
||||
command=command_str,
|
||||
)
|
||||
|
||||
exit_code: int | None = None
|
||||
status = "running"
|
||||
stdout_lines: list[str] = []
|
||||
stderr_lines: list[str] = []
|
||||
|
||||
try:
|
||||
# 构建额外环境变量(DWD 表过滤通过环境变量注入)
|
||||
extra_env: dict[str, str] = {}
|
||||
if config.dwd_only_tables:
|
||||
extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)
|
||||
|
||||
# 在线程池中运行子进程,兼容 Windows
|
||||
exit_code = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._run_subprocess,
|
||||
cmd,
|
||||
execution_id,
|
||||
stdout_lines,
|
||||
stderr_lines,
|
||||
extra_env or None,
|
||||
)
|
||||
|
||||
if exit_code == 0:
|
||||
status = "success"
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
logger.info(
|
||||
"ETL 子进程 [%s] 退出,exit_code=%s, status=%s",
|
||||
execution_id, exit_code, status,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
status = "cancelled"
|
||||
logger.info("ETL 子进程 [%s] 已取消", execution_id)
|
||||
# 尝试终止子进程
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc and proc.poll() is None:
|
||||
proc.terminate()
|
||||
except Exception as exc:
|
||||
status = "failed"
|
||||
import traceback
|
||||
tb = traceback.format_exc()
|
||||
stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
|
||||
stderr_lines.append(tb)
|
||||
logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
|
||||
finally:
|
||||
elapsed_ms = int((time.monotonic() - t0) * 1000)
|
||||
finished_at = datetime.now(timezone.utc)
|
||||
|
||||
self._broadcast_end(execution_id)
|
||||
self._processes.pop(execution_id, None)
|
||||
|
||||
self._update_execution_log(
|
||||
execution_id=execution_id,
|
||||
status=status,
|
||||
finished_at=finished_at,
|
||||
exit_code=exit_code,
|
||||
duration_ms=elapsed_ms,
|
||||
output_log="\n".join(stdout_lines),
|
||||
error_log="\n".join(stderr_lines),
|
||||
)
|
||||
|
||||
def _run_subprocess(
|
||||
self,
|
||||
cmd: list[str],
|
||||
execution_id: str,
|
||||
stdout_lines: list[str],
|
||||
stderr_lines: list[str],
|
||||
extra_env: dict[str, str] | None = None,
|
||||
) -> int:
|
||||
"""在线程中运行子进程并逐行读取输出。"""
|
||||
import os
|
||||
env = os.environ.copy()
|
||||
# 强制子进程使用 UTF-8 输出,避免 Windows GBK 乱码
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
if extra_env:
|
||||
env.update(extra_env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=ETL_PROJECT_PATH,
|
||||
env=env,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
)
|
||||
self._processes[execution_id] = proc
|
||||
|
||||
def read_stream(
|
||||
stream, stream_name: str, collector: list[str],
|
||||
) -> None:
|
||||
"""逐行读取流并广播。"""
|
||||
for raw_line in stream:
|
||||
line = raw_line.rstrip("\n").rstrip("\r")
|
||||
tagged = f"[{stream_name}] {line}"
|
||||
buf = self._log_buffers.get(execution_id)
|
||||
if buf is not None:
|
||||
buf.append(tagged)
|
||||
collector.append(line)
|
||||
self._broadcast(execution_id, tagged)
|
||||
|
||||
t_out = threading.Thread(
|
||||
target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
|
||||
daemon=True,
|
||||
)
|
||||
t_err = threading.Thread(
|
||||
target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
|
||||
daemon=True,
|
||||
)
|
||||
t_out.start()
|
||||
t_err.start()
|
||||
|
||||
proc.wait()
|
||||
t_out.join(timeout=5)
|
||||
t_err.join(timeout=5)
|
||||
|
||||
return proc.returncode
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 取消
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def cancel(self, execution_id: str) -> bool:
|
||||
"""向子进程发送终止信号。
|
||||
|
||||
Returns:
|
||||
True 表示成功发送终止信号,False 表示进程不存在或已退出。
|
||||
"""
|
||||
proc = self._processes.get(execution_id)
|
||||
if proc is None:
|
||||
return False
|
||||
# subprocess.Popen: poll() 返回 None 表示仍在运行
|
||||
if proc.poll() is not None:
|
||||
return False
|
||||
|
||||
logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
|
||||
try:
|
||||
proc.terminate()
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 数据库操作(同步,在线程池中执行也可,此处简单直连)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _write_execution_log(
|
||||
*,
|
||||
execution_id: str,
|
||||
queue_id: str | None,
|
||||
site_id: int | None,
|
||||
task_codes: list[str],
|
||||
status: str,
|
||||
started_at: datetime,
|
||||
command: str,
|
||||
) -> None:
|
||||
"""插入一条执行日志记录(running 状态)。"""
|
||||
try:
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO task_execution_log
|
||||
(id, queue_id, site_id, task_codes, status,
|
||||
started_at, command)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
execution_id,
|
||||
queue_id,
|
||||
site_id or 0,
|
||||
task_codes,
|
||||
status,
|
||||
started_at,
|
||||
command,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.exception("写入 execution_log 失败 [%s]", execution_id)
|
||||
|
||||
@staticmethod
|
||||
def _update_execution_log(
|
||||
*,
|
||||
execution_id: str,
|
||||
status: str,
|
||||
finished_at: datetime,
|
||||
exit_code: int | None,
|
||||
duration_ms: int,
|
||||
output_log: str,
|
||||
error_log: str,
|
||||
) -> None:
|
||||
"""更新执行日志记录(完成状态)。"""
|
||||
try:
|
||||
conn = get_connection()
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE task_execution_log
|
||||
SET status = %s,
|
||||
finished_at = %s,
|
||||
exit_code = %s,
|
||||
duration_ms = %s,
|
||||
output_log = %s,
|
||||
error_log = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(
|
||||
status,
|
||||
finished_at,
|
||||
exit_code,
|
||||
duration_ms,
|
||||
output_log,
|
||||
error_log,
|
||||
execution_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.exception("更新 execution_log 失败 [%s]", execution_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 清理
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def cleanup(self, execution_id: str) -> None:
|
||||
"""清理指定执行的内存资源(日志缓冲区和订阅者)。
|
||||
|
||||
通常在确认日志已持久化后调用。
|
||||
"""
|
||||
self._log_buffers.pop(execution_id, None)
|
||||
self._subscribers.pop(execution_id, None)
|
||||
|
||||
|
||||
# 全局单例
|
||||
task_executor = TaskExecutor()
|
||||
Reference in New Issue
Block a user