Files
Neo-ZQYY/apps/backend/app/services/task_executor.py
2026-03-15 10:15:02 +08:00

481 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""ETL 任务执行器
在线程池中通过 subprocess.Popen 启动 ETL CLI 子进程(兼容 Windows),
由后台线程逐行读取 stdout/stderr 并广播到 WebSocket 订阅者;
启动时向 task_execution_log 表写入一条 running 记录,执行完成后更新为最终结果。
设计要点:
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
- 日志行存储在内存缓冲区 _log_buffers 中
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
- Windows 兼容:取消时使用 process.terminate() 而非直接发送 SIGTERM
"""
from __future__ import annotations

import asyncio
import logging
import os
import platform as _platform
import subprocess
import threading
import time
import traceback
from datetime import datetime, timezone
from typing import Any

# CHANGE 2026-03-07 | Keep only a module reference and read its attributes at
# call time inside execute(). Do NOT use `from ..config import ETL_PROJECT_PATH`
# (that copies the value, which goes stale after the config is reloaded).
from .. import config as _config_module
from ..database import get_connection
from ..schemas.tasks import TaskConfigSchema
from ..services.cli_builder import cli_builder
logger = logging.getLogger(__name__)

# Instance identifier: this machine's hostname (platform.node()). It is
# prefixed to the stored command so that rows written by several backend
# instances into the same database can be told apart.
_INSTANCE_HOST = _platform.node()
class TaskExecutor:
    """Manage the lifecycle of ETL CLI subprocesses.

    Each ``execution_id`` maps to one ``subprocess.Popen`` process. Output
    lines are kept in an in-memory buffer and pushed to WebSocket subscribers
    through ``asyncio.Queue`` objects. The start and the final state of every
    run are recorded in the ``task_execution_log`` table.
    """

    def __init__(self) -> None:
        # execution_id -> subprocess.Popen
        self._processes: dict[str, subprocess.Popen] = {}
        # execution_id -> interleaved stdout + stderr lines, each tagged with
        # "[stdout] " or "[stderr] "
        self._log_buffers: dict[str, list[str]] = {}
        # execution_id -> set of asyncio.Queue (one per WebSocket subscriber)
        self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
        # execution_ids to which cancel() sent terminate(); lets execute()
        # record the resulting non-zero exit as "cancelled", not "failed".
        self._cancel_requested: set[str] = set()

    # ------------------------------------------------------------------
    # WebSocket subscription management
    # ------------------------------------------------------------------
    def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
        """Register a WebSocket subscriber and return its log queue.

        The queue receives ``str`` items for log lines and a single ``None``
        sentinel once the execution has finished.
        """
        queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._subscribers.setdefault(execution_id, set()).add(queue)
        return queue

    def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
        """Remove a subscriber queue; drop the entry when no subscribers remain."""
        subs = self._subscribers.get(execution_id)
        if subs:
            subs.discard(queue)
            if not subs:
                del self._subscribers[execution_id]

    def _broadcast(self, execution_id: str, line: str) -> None:
        """Push one log line to every subscriber of *execution_id*.

        Must run on the event-loop thread, because ``asyncio.Queue`` is not
        thread-safe. Reader threads reach it through
        ``loop.call_soon_threadsafe`` (see ``_run_subprocess``).
        """
        # Iterate a snapshot so a concurrent (un)subscribe cannot raise
        # "Set changed size during iteration".
        for queue in tuple(self._subscribers.get(execution_id, ())):
            queue.put_nowait(line)

    def _broadcast_end(self, execution_id: str) -> None:
        """Send the ``None`` end-of-execution sentinel to every subscriber."""
        for queue in tuple(self._subscribers.get(execution_id, ())):
            queue.put_nowait(None)

    # ------------------------------------------------------------------
    # Log buffer
    # ------------------------------------------------------------------
    def get_logs(self, execution_id: str) -> list[str]:
        """Return a copy of the in-memory log buffer (empty if unknown)."""
        return list(self._log_buffers.get(execution_id, []))

    # ------------------------------------------------------------------
    # Execution state queries
    # ------------------------------------------------------------------
    def is_running(self, execution_id: str) -> bool:
        """Return True if the subprocess for *execution_id* has not exited."""
        proc = self._processes.get(execution_id)
        if proc is None:
            return False
        return proc.poll() is None

    def get_running_ids(self) -> list[str]:
        """Return the execution_ids whose subprocess is still running."""
        # Snapshot the items: _run_subprocess inserts into _processes from a
        # worker thread. poll() is used so the result agrees with is_running().
        return [
            eid for eid, proc in list(self._processes.items())
            if proc.poll() is None
        ]

    # ------------------------------------------------------------------
    # Path guard
    # ------------------------------------------------------------------
    @staticmethod
    def _find_bad_path_patterns(command_str: str) -> list[str]:
        """Return descriptions of forbidden path patterns in *command_str*.

        Forward slashes are normalized to backslashes first. Two patterns are
        rejected; both matches are case-sensitive:

        1. a ``D:`` drive path (junction traversal);
        2. a per-environment sub-directory, ``test/repo`` or ``prod/repo``.

        An empty list means the command passed the guard.
        """
        normalized = command_str.replace("/", "\\")
        found: list[str] = []
        if "D:\\" in normalized:
            found.append("D盘路径")
        if "\\test\\repo" in normalized or "\\prod\\repo" in normalized:
            found.append("多环境子目录(test/repo或prod/repo)")
        return found

    # ------------------------------------------------------------------
    # Core execution
    # ------------------------------------------------------------------
    async def execute(
        self,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str | None = None,
        site_id: int | None = None,
        schedule_id: str | None = None,
    ) -> None:
        """Run the ETL CLI for *config* as a subprocess and record the result.

        The subprocess is started with ``subprocess.Popen`` inside the default
        thread-pool executor. This avoids ``asyncio.create_subprocess_exec``,
        which raises ``NotImplementedError`` on Windows in this deployment.
        A ``running`` row is inserted into ``task_execution_log`` before the
        subprocess starts, and is updated with the final status afterwards.

        Args:
            config: task configuration passed to ``cli_builder.build_command``.
            execution_id: id of the ``task_execution_log`` row for this run.
            queue_id: optional ``task_queue`` id. It is also used to look up
                the schedule when *schedule_id* is not given.
            site_id: site to record; falls back to ``config.store_id``.
            schedule_id: optional schedule id to record.

        Raises:
            RuntimeError: the built command contains a forbidden path (see
                ``_find_bad_path_patterns``). Nothing is executed or written
                to the database in that case.
            asyncio.CancelledError: re-raised after the subprocess has been
                sent terminate() and the row marked ``cancelled``.
        """
        # CHANGE 2026-03-07 | read config attributes at call time so that
        # values reloaded into the config module are picked up.
        etl_path = _config_module.ETL_PROJECT_PATH
        etl_python = _config_module.ETL_PYTHON_EXECUTABLE
        cmd = cli_builder.build_command(
            config, etl_path, python_executable=etl_python
        )
        command_str = " ".join(cmd)

        # CHANGE 2026-03-07 | runtime guard: refuse commands containing
        # unexpected paths.
        bad_patterns = self._find_bad_path_patterns(command_str)
        if bad_patterns:
            issues = " + ".join(bad_patterns)
            logger.error(
                "路径防护触发:命令包含 %s,拒绝执行。"
                " command=%s | ETL_PY=%s | ETL_PATH=%s"
                " | NEOZQYY_ROOT=%s | config.__file__=%s",
                issues, command_str, etl_python, etl_path,
                os.environ.get("NEOZQYY_ROOT", "<未设置>"),
                _config_module.__file__,
            )
            raise RuntimeError(
                f"ETL 命令包含异常路径({issues}),拒绝执行。"
                f" 请检查 .env 中 ETL_PYTHON_EXECUTABLE 和 ETL_PROJECT_PATH 配置。"
                f" 当前值: ETL_PY={etl_python}, ETL_PATH={etl_path}"
            )

        effective_site_id = site_id or config.store_id
        # CHANGE 2026-03-07 | prefix the stored command with the host name so
        # rows from several backend instances sharing one DB can be told apart.
        command_str_with_host = f"[{_INSTANCE_HOST}] {command_str}"
        logger.info(
            "启动 ETL 子进程 [%s]: %s (cwd=%s)",
            execution_id, command_str, etl_path,
        )
        self._log_buffers[execution_id] = []
        started_at = datetime.now(timezone.utc)
        t0 = time.monotonic()
        # NOTE: synchronous DB call; it blocks the event loop while it runs.
        self._write_execution_log(
            execution_id=execution_id,
            queue_id=queue_id,
            site_id=effective_site_id,
            task_codes=config.tasks,
            status="running",
            started_at=started_at,
            command=command_str_with_host,
            schedule_id=schedule_id,
        )
        exit_code: int | None = None
        status = "running"
        stdout_lines: list[str] = []
        stderr_lines: list[str] = []
        loop = asyncio.get_running_loop()
        try:
            # DWD table filtering is passed to the CLI via an env variable.
            extra_env: dict[str, str] = {}
            if config.dwd_only_tables:
                extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)
            # Run the blocking subprocess in the thread pool (Windows-safe);
            # pass the loop so output can be broadcast thread-safely.
            exit_code = await loop.run_in_executor(
                None,
                self._run_subprocess,
                cmd,
                execution_id,
                stdout_lines,
                stderr_lines,
                extra_env or None,
                loop,
            )
            if exit_code == 0:
                status = "success"
            elif execution_id in self._cancel_requested:
                # Non-zero exit caused by cancel() -> terminate().
                status = "cancelled"
            else:
                status = "failed"
            logger.info(
                "ETL 子进程 [%s] 退出exit_code=%s, status=%s",
                execution_id, exit_code, status,
            )
        except asyncio.CancelledError:
            status = "cancelled"
            logger.info("ETL 子进程 [%s] 已取消", execution_id)
            # The worker thread keeps waiting on the child; stop the child.
            proc = self._processes.get(execution_id)
            if proc is not None and proc.poll() is None:
                proc.terminate()
            # Propagate cancellation so the awaiting task ends as cancelled.
            raise
        except Exception as exc:
            status = "failed"
            stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
            stderr_lines.append(traceback.format_exc())
            logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
        finally:
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            finished_at = datetime.now(timezone.utc)
            self._broadcast_end(execution_id)
            self._processes.pop(execution_id, None)
            self._cancel_requested.discard(execution_id)
            self._update_execution_log(
                execution_id=execution_id,
                status=status,
                finished_at=finished_at,
                exit_code=exit_code,
                duration_ms=elapsed_ms,
                output_log="\n".join(stdout_lines),
                error_log="\n".join(stderr_lines),
            )

    def _run_subprocess(
        self,
        cmd: list[str],
        execution_id: str,
        stdout_lines: list[str],
        stderr_lines: list[str],
        extra_env: dict[str, str] | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> int:
        """Run *cmd* to completion, streaming its output line by line.

        Meant to be called from a worker thread. stdout and stderr are each
        read by a daemon thread. Every line is:

        * appended untagged to the matching collector list;
        * appended tagged (``"[stdout] ..."`` / ``"[stderr] ..."``) to the
          execution's log buffer, if that buffer exists;
        * broadcast to subscribers.

        Args:
            cmd: argv list passed to ``subprocess.Popen``.
            execution_id: key under which the process is registered.
            stdout_lines: list appended to in place with stdout lines.
            stderr_lines: list appended to in place with stderr lines.
            extra_env: variables overlaid on a copy of ``os.environ``.
            loop: event loop that owns the subscriber queues. When given,
                broadcasts are scheduled onto it with
                ``call_soon_threadsafe``. When ``None``, ``_broadcast`` is
                called directly from the reader thread (legacy behavior).

        Returns:
            The subprocess exit code.
        """
        env = os.environ.copy()
        # Force UTF-8 output from the child to avoid GBK mojibake on Windows.
        env["PYTHONIOENCODING"] = "utf-8"
        if extra_env:
            env.update(extra_env)
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=_config_module.ETL_PROJECT_PATH,
            env=env,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        self._processes[execution_id] = proc

        def publish(tagged: str) -> None:
            """Forward one tagged line to subscribers from a reader thread."""
            if loop is None:
                self._broadcast(execution_id, tagged)
                return
            try:
                loop.call_soon_threadsafe(self._broadcast, execution_id, tagged)
            except RuntimeError:
                # Event loop already closed (e.g. during shutdown): drop the
                # live update; the line is still kept in the buffers.
                pass

        def read_stream(stream: Any, stream_name: str, collector: list[str]) -> None:
            """Read *stream* until EOF, recording and publishing each line."""
            for raw_line in stream:
                line = raw_line.rstrip("\n").rstrip("\r")
                tagged = f"[{stream_name}] {line}"
                buf = self._log_buffers.get(execution_id)
                if buf is not None:
                    buf.append(tagged)
                collector.append(line)
                publish(tagged)

        t_out = threading.Thread(
            target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
            daemon=True,
        )
        t_err = threading.Thread(
            target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
            daemon=True,
        )
        t_out.start()
        t_err.start()
        proc.wait()
        # Bounded join: a grandchild still holding the pipes open must not
        # hang this worker forever.
        t_out.join(timeout=5)
        t_err.join(timeout=5)
        return proc.returncode

    # ------------------------------------------------------------------
    # Cancellation
    # ------------------------------------------------------------------
    async def cancel(self, execution_id: str) -> bool:
        """Request termination of the subprocess for *execution_id*.

        If the process is still tracked in memory, it is sent terminate() and
        later recorded as ``cancelled`` by ``execute``. If it is not tracked
        (e.g. after a backend restart) but the DB row is still ``running``,
        the row is marked ``cancelled`` directly as a fallback for orphaned
        ("ghost") records.

        Returns:
            True if a cancellation was performed; False if the task does not
            exist, has already finished, or the fallback update failed.
        """
        proc = self._processes.get(execution_id)
        if proc is not None:
            # Process is still tracked in memory.
            if proc.poll() is not None:
                return False
            logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
            self._cancel_requested.add(execution_id)
            try:
                proc.terminate()
            except ProcessLookupError:
                self._cancel_requested.discard(execution_id)
                return False
            return True
        # Not tracked in memory: fix up a ghost "running" row in the DB.
        # NOTE(review): this matches on id only, so it would also mark a row
        # that another backend instance is still executing — confirm whether
        # it should be restricted to this instance's command prefix.
        # NOTE: synchronous DB call inside a coroutine; it blocks the loop.
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE task_execution_log
                        SET status = 'cancelled',
                            finished_at = NOW(),
                            error_log = COALESCE(error_log, '')
                                || E'\n[cancel 兜底] 进程已不在内存中,标记为 cancelled'
                        WHERE id = %s AND status = 'running'
                        """,
                        (execution_id,),
                    )
                    updated = cur.rowcount
                conn.commit()
            finally:
                conn.close()
            if updated:
                logger.info(
                    "兜底取消 execution_log [%s]:数据库状态从 running → cancelled",
                    execution_id,
                )
                return True
        except Exception:
            logger.exception("兜底取消 execution_log [%s] 失败", execution_id)
        return False

    # ------------------------------------------------------------------
    # Database operations (synchronous direct connections; callers on the
    # event loop block while these run)
    # ------------------------------------------------------------------
    @staticmethod
    def _write_execution_log(
        *,
        execution_id: str,
        queue_id: str | None,
        site_id: int | None,
        task_codes: list[str],
        status: str,
        started_at: datetime,
        command: str,
        schedule_id: str | None = None,
    ) -> None:
        """Insert the initial ``task_execution_log`` row for a run.

        Best effort: any exception is logged and swallowed. A missing
        *site_id* is stored as 0. When *schedule_id* is None and *queue_id*
        is given, the schedule id is looked up in ``task_queue``.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    # Caller did not pass schedule_id: try task_queue.
                    effective_schedule_id = schedule_id
                    if effective_schedule_id is None and queue_id is not None:
                        cur.execute(
                            "SELECT schedule_id FROM task_queue WHERE id = %s",
                            (queue_id,),
                        )
                        row = cur.fetchone()
                        if row and row[0]:
                            effective_schedule_id = str(row[0])
                    cur.execute(
                        """
                        INSERT INTO task_execution_log
                            (id, queue_id, site_id, task_codes, status,
                             started_at, command, schedule_id)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                        """,
                        (
                            execution_id,
                            queue_id,
                            site_id or 0,
                            task_codes,
                            status,
                            started_at,
                            command,
                            effective_schedule_id,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("写入 execution_log 失败 [%s]", execution_id)

    @staticmethod
    def _update_execution_log(
        *,
        execution_id: str,
        status: str,
        finished_at: datetime,
        exit_code: int | None,
        duration_ms: int,
        output_log: str,
        error_log: str,
    ) -> None:
        """Write the final status and output of a run to its log row.

        Best effort: any exception is logged and swallowed.
        """
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE task_execution_log
                        SET status = %s,
                            finished_at = %s,
                            exit_code = %s,
                            duration_ms = %s,
                            output_log = %s,
                            error_log = %s
                        WHERE id = %s
                        """,
                        (
                            status,
                            finished_at,
                            exit_code,
                            duration_ms,
                            output_log,
                            error_log,
                            execution_id,
                        ),
                    )
                conn.commit()
            finally:
                conn.close()
        except Exception:
            logger.exception("更新 execution_log 失败 [%s]", execution_id)

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------
    def cleanup(self, execution_id: str) -> None:
        """Drop the in-memory log buffer and subscribers for *execution_id*.

        Typically called once the logs have been persisted.
        """
        self._log_buffers.pop(execution_id, None)
        self._subscribers.pop(execution_id, None)
# Global singleton: the one TaskExecutor instance shared by every importer of
# this module.
task_executor: TaskExecutor = TaskExecutor()