Files
Neo-ZQYY/apps/backend/app/services/task_executor.py

392 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
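# -*- coding: utf-8 -*-
"""ETL task executor.

Starts the ETL CLI as a child process with subprocess.Popen inside a
thread-pool worker, reads stdout/stderr line by line, broadcasts every line
to the WebSocket subscribers, and writes the result to the
task_execution_log table once the run has finished.

Design notes:
- Each execution_id maps to one child process, stored in the _processes dict.
- Log lines are kept in the in-memory buffer _log_buffers.
- WebSocket subscribers receive live log lines through an asyncio.Queue.
- Windows compatibility: cancellation uses process.terminate() rather than
  SIGTERM.
"""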
from __future__ import annotations
import asyncio
import logging
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from typing import Any
from ..config import ETL_PROJECT_PATH
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
from ..services.cli_builder import cli_builder
logger = logging.getLogger(__name__)
class TaskExecutor:
    """Manage the life cycle of ETL CLI child processes.

    Each execution_id maps to one child process.  Its output lines are kept
    in an in-memory buffer, pushed to subscriber queues, and the final
    outcome is written to the task_execution_log table.
    """

    def __init__(self) -> None:
        # execution_id -> subprocess.Popen handle
        self._processes: dict[str, subprocess.Popen] = {}
        # execution_id -> tagged log lines (stdout and stderr interleaved)
        self._log_buffers: dict[str, list[str]] = {}
        # execution_id -> queues of the WebSocket subscribers
        self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
        # Event loop that consumes the subscriber queues.  asyncio.Queue is
        # not thread-safe, so reader threads hand their lines to this loop.
        self._loop: asyncio.AbstractEventLoop | None = None

    # ------------------------------------------------------------------
    # WebSocket subscription management
    # ------------------------------------------------------------------

    def _remember_loop(self) -> None:
        """Record the currently running event loop, if there is one."""
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called outside of a running loop: keep whatever we had.
            pass

    def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
        """Register a subscriber and return the queue it should read from.

        The queue yields ``str`` items for log lines and ``None`` once the
        execution has ended.
        """
        self._remember_loop()
        queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._subscribers.setdefault(execution_id, set()).add(queue)
        return queue

    def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
        """Remove a subscriber; drop the entry when no subscriber is left."""
        subs = self._subscribers.get(execution_id)
        if subs:
            subs.discard(queue)
            if not subs:
                del self._subscribers[execution_id]

    def _deliver(self, execution_id: str, item: str | None) -> None:
        """Put *item* on every queue subscribed to *execution_id*."""
        subs = self._subscribers.get(execution_id)
        if subs:
            # Iterate over a snapshot so a concurrent (un)subscribe cannot
            # invalidate the iteration.
            for queue in tuple(subs):
                queue.put_nowait(item)

    def _dispatch(self, execution_id: str, item: str | None) -> None:
        """Deliver *item*, switching to the owning event loop if needed.

        When called from a thread other than the one running the recorded
        loop, delivery is scheduled with ``call_soon_threadsafe`` so that
        consumers awaiting ``queue.get()`` are woken up reliably.  When no
        usable loop is known, the item is delivered directly.
        """
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                current = asyncio.get_running_loop()
            except RuntimeError:
                current = None
            if current is not loop:
                try:
                    loop.call_soon_threadsafe(self._deliver, execution_id, item)
                    return
                except RuntimeError:
                    # The loop was closed in the meantime; fall back below.
                    pass
        self._deliver(execution_id, item)

    def _broadcast(self, execution_id: str, line: str) -> None:
        """Broadcast one log line to all subscribers of *execution_id*."""
        self._dispatch(execution_id, line)

    def _broadcast_end(self, execution_id: str) -> None:
        """Tell all subscribers the execution is over (``None`` sentinel)."""
        self._dispatch(execution_id, None)

    # ------------------------------------------------------------------
    # Log buffer
    # ------------------------------------------------------------------

    def get_logs(self, execution_id: str) -> list[str]:
        """Return a copy of the in-memory log buffer of *execution_id*."""
        return list(self._log_buffers.get(execution_id, []))

    # ------------------------------------------------------------------
    # Execution state queries
    # ------------------------------------------------------------------

    def is_running(self, execution_id: str) -> bool:
        """Return True while the child process of *execution_id* is alive."""
        proc = self._processes.get(execution_id)
        if proc is None:
            return False
        return proc.poll() is None

    def get_running_ids(self) -> list[str]:
        """Return the execution_ids whose child process is still running."""
        # Snapshot first: a worker thread may register a new process while we
        # iterate.  poll() refreshes the state, consistent with is_running().
        return [
            eid
            for eid, proc in list(self._processes.items())
            if proc.poll() is None
        ]

    # ------------------------------------------------------------------
    # Core execution
    # ------------------------------------------------------------------

    async def execute(
        self,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str | None = None,
        site_id: int | None = None,
    ) -> None:
        """Run the ETL CLI as a child process and record the outcome.

        The process is driven by subprocess.Popen plus reader threads inside
        a thread-pool worker, which avoids the NotImplementedError that
        asyncio.create_subprocess_exec raises on Windows.

        Args:
            config: Task configuration used to build the CLI command.
            execution_id: Identifier of this run; key of all internal maps
                and of the task_execution_log row.
            queue_id: Optional queue identifier stored with the log row.
            site_id: Optional site id; falls back to ``config.store_id``.
        """
        loop = asyncio.get_running_loop()
        # Reader threads hand their log lines back to this loop.
        self._loop = loop

        cmd = cli_builder.build_command(
            config, ETL_PROJECT_PATH, python_executable=sys.executable
        )
        command_str = " ".join(cmd)
        effective_site_id = site_id or config.store_id
        logger.info(
            "启动 ETL 子进程 [%s]: %s (cwd=%s)",
            execution_id, command_str, ETL_PROJECT_PATH,
        )

        self._log_buffers[execution_id] = []
        started_at = datetime.now(timezone.utc)
        t0 = time.monotonic()
        self._write_execution_log(
            execution_id=execution_id,
            queue_id=queue_id,
            site_id=effective_site_id,
            task_codes=config.tasks,
            status="running",
            started_at=started_at,
            command=command_str,
        )

        exit_code: int | None = None
        status = "running"
        stdout_lines: list[str] = []
        stderr_lines: list[str] = []
        try:
            # Extra environment variables: the DWD table filter is injected
            # through the environment.
            extra_env: dict[str, str] = {}
            if config.dwd_only_tables:
                extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)

            # Run the child process in the default thread pool (Windows safe).
            exit_code = await loop.run_in_executor(
                None,
                self._run_subprocess,
                cmd,
                execution_id,
                stdout_lines,
                stderr_lines,
                extra_env or None,
            )
            status = "success" if exit_code == 0 else "failed"
            logger.info(
                "ETL 子进程 [%s] 退出exit_code=%s, status=%s",
                execution_id, exit_code, status,
            )
        except asyncio.CancelledError:
            # NOTE(review): the cancellation is recorded but not re-raised,
            # exactly as before; confirm callers rely on that.
            status = "cancelled"
            logger.info("ETL 子进程 [%s] 已取消", execution_id)
            # Try to stop the child process if it is still alive.
            proc = self._processes.get(execution_id)
            if proc and proc.poll() is None:
                proc.terminate()
        except Exception as exc:
            status = "failed"
            import traceback

            tb = traceback.format_exc()
            stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
            stderr_lines.append(tb)
            logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
        finally:
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            finished_at = datetime.now(timezone.utc)
            self._broadcast_end(execution_id)
            self._processes.pop(execution_id, None)
            self._update_execution_log(
                execution_id=execution_id,
                status=status,
                finished_at=finished_at,
                exit_code=exit_code,
                duration_ms=elapsed_ms,
                output_log="\n".join(stdout_lines),
                error_log="\n".join(stderr_lines),
            )

    def _run_subprocess(
        self,
        cmd: list[str],
        execution_id: str,
        stdout_lines: list[str],
        stderr_lines: list[str],
        extra_env: dict[str, str] | None = None,
    ) -> int:
        """Run the child process in the calling thread and collect its output.

        Args:
            cmd: Command line to execute.
            execution_id: Key under which the process is registered.
            stdout_lines: List that receives the untagged stdout lines.
            stderr_lines: List that receives the untagged stderr lines.
            extra_env: Optional variables merged into the child environment.

        Returns:
            The exit code of the child process.
        """
        import os

        env = os.environ.copy()
        # Force UTF-8 output in the child to avoid GBK mojibake on Windows.
        env["PYTHONIOENCODING"] = "utf-8"
        if extra_env:
            env.update(extra_env)

        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=ETL_PROJECT_PATH,
            env=env,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        self._processes[execution_id] = proc

        def pump(stream, stream_name: str, collector: list[str]) -> None:
            """Read *stream* line by line, buffer and broadcast each line."""
            for raw_line in stream:
                line = raw_line.rstrip("\n").rstrip("\r")
                tagged = f"[{stream_name}] {line}"
                buf = self._log_buffers.get(execution_id)
                if buf is not None:
                    buf.append(tagged)
                collector.append(line)
                self._broadcast(execution_id, tagged)

        readers = [
            (
                threading.Thread(
                    target=pump,
                    args=(proc.stdout, "stdout", stdout_lines),
                    daemon=True,
                ),
                proc.stdout,
            ),
            (
                threading.Thread(
                    target=pump,
                    args=(proc.stderr, "stderr", stderr_lines),
                    daemon=True,
                ),
                proc.stderr,
            ),
        ]
        for thread, _stream in readers:
            thread.start()

        proc.wait()
        for thread, stream in readers:
            thread.join(timeout=5)
            # Close the pipe only when its reader is done; closing a stream
            # another thread is still blocked on could stall this worker.
            if not thread.is_alive():
                stream.close()
        return proc.returncode

    # ------------------------------------------------------------------
    # Cancellation
    # ------------------------------------------------------------------

    async def cancel(self, execution_id: str) -> bool:
        """Ask the child process of *execution_id* to terminate.

        Returns:
            True when the terminate request was sent, False when no process
            is registered or it has already exited.
        """
        proc = self._processes.get(execution_id)
        if proc is None:
            return False
        # poll() returning None means the process is still running.
        if proc.poll() is not None:
            return False
        logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
        try:
            proc.terminate()
        except ProcessLookupError:
            return False
        return True

    # ------------------------------------------------------------------
    # Database access (synchronous, direct connection for simplicity; could
    # also be run in the thread pool)
    # ------------------------------------------------------------------

    @staticmethod
    def _write_execution_log(
        *,
        execution_id: str,
        queue_id: str | None,
        site_id: int | None,
        task_codes: list[str],
        status: str,
        started_at: datetime,
        command: str,
    ) -> None:
        """Insert the task_execution_log row for a run that is starting.

        Failures are logged and swallowed so that a database problem does
        not prevent the run itself.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    # The SQL text is kept flush-left so the statement string
                    # stays exactly as it was.
                    cur.execute(
                        """
INSERT INTO task_execution_log
(id, queue_id, site_id, task_codes, status,
started_at, command)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
                        (
                            execution_id,
                            queue_id,
                            site_id or 0,
                            task_codes,
                            status,
                            started_at,
                            command,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("写入 execution_log 失败 [%s]", execution_id)

    @staticmethod
    def _update_execution_log(
        *,
        execution_id: str,
        status: str,
        finished_at: datetime,
        exit_code: int | None,
        duration_ms: int,
        output_log: str,
        error_log: str,
    ) -> None:
        """Update the task_execution_log row with the final outcome.

        Failures are logged and swallowed.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    # The SQL text is kept flush-left so the statement string
                    # stays exactly as it was.
                    cur.execute(
                        """
UPDATE task_execution_log
SET status = %s,
finished_at = %s,
exit_code = %s,
duration_ms = %s,
output_log = %s,
error_log = %s
WHERE id = %s
""",
                        (
                            status,
                            finished_at,
                            exit_code,
                            duration_ms,
                            output_log,
                            error_log,
                            execution_id,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("更新 execution_log 失败 [%s]", execution_id)

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    def cleanup(self, execution_id: str) -> None:
        """Release the in-memory log buffer and subscribers of a run.

        Usually called once the logs are known to have been persisted.
        """
        self._log_buffers.pop(execution_id, None)
        self._subscribers.pop(execution_id, None)
# Global singleton instance
task_executor = TaskExecutor()