# -*- coding: utf-8 -*- """ETL 任务执行器 通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程, 逐行读取 stdout/stderr 并广播到 WebSocket 订阅者, 执行完成后将结果写入 task_execution_log 表。 设计要点: - 每个 execution_id 对应一个子进程,存储在 _processes 字典中 - 日志行存储在内存缓冲区 _log_buffers 中 - WebSocket 订阅者通过 asyncio.Queue 接收实时日志 - Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM """ from __future__ import annotations import asyncio import logging import subprocess import threading import time from datetime import datetime, timezone from typing import Any # CHANGE 2026-03-07 | 只保留模块引用,execute() 中实时读取属性值 # 禁止 from ..config import ETL_PROJECT_PATH(值拷贝,reload 后过期) from .. import config as _config_module from ..database import get_connection from psycopg2.extras import Json from ..schemas.tasks import TaskConfigSchema from ..services.cli_builder import cli_builder logger = logging.getLogger(__name__) # 实例标识:用于区分多后端实例写入同一 DB 的记录 import platform as _platform _INSTANCE_HOST = _platform.node() # hostname class TaskExecutor: """管理 ETL CLI 子进程的生命周期""" def __init__(self) -> None: # execution_id → subprocess.Popen self._processes: dict[str, subprocess.Popen] = {} # execution_id → list[str](stdout + stderr 混合日志) self._log_buffers: dict[str, list[str]] = {} # execution_id → set[asyncio.Queue](WebSocket 订阅者) self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {} # ------------------------------------------------------------------ # WebSocket 订阅管理 # ------------------------------------------------------------------ def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]: """注册一个 WebSocket 订阅者,返回用于读取日志行的 Queue。 Queue 中推送 str 表示日志行,None 表示执行结束。 """ if execution_id not in self._subscribers: self._subscribers[execution_id] = set() queue: asyncio.Queue[str | None] = asyncio.Queue() self._subscribers[execution_id].add(queue) return queue def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None: """移除一个 WebSocket 订阅者。""" subs = self._subscribers.get(execution_id) if subs: subs.discard(queue) if not subs: del self._subscribers[execution_id] def _broadcast(self, execution_id: str, line: str) -> None: """向所有订阅者广播一行日志。""" subs = self._subscribers.get(execution_id) if subs: for q in subs: q.put_nowait(line) def _broadcast_end(self, execution_id: str) -> None: """通知所有订阅者执行已结束(发送 None 哨兵)。""" subs = self._subscribers.get(execution_id) if subs: for q in subs: q.put_nowait(None) # ------------------------------------------------------------------ # 日志缓冲区 # ------------------------------------------------------------------ def get_logs(self, execution_id: str) -> list[str]: """获取指定执行的内存日志缓冲区(副本)。""" return list(self._log_buffers.get(execution_id, [])) # ------------------------------------------------------------------ # 执行状态查询 # ------------------------------------------------------------------ def is_running(self, execution_id: str) -> bool: """判断指定执行是否仍在运行。""" proc = self._processes.get(execution_id) if proc is None: return False return proc.poll() is None def get_running_ids(self) -> list[str]: """返回当前所有运行中的 execution_id 列表。""" return [eid for eid, p in self._processes.items() if p.returncode is None] # ------------------------------------------------------------------ # 核心执行 # ------------------------------------------------------------------ async def execute( self, config: TaskConfigSchema, execution_id: str, queue_id: str | None = None, site_id: int | None = None, schedule_id: str | None = None, ) -> None: """以子进程方式调用 ETL CLI。 使用 subprocess.Popen + 线程读取,兼容 Windows(避免 asyncio.create_subprocess_exec 在 Windows 上的 NotImplementedError)。 """ # CHANGE 2026-03-07 | 实时从 config 模块读取,避免 import 时复制的值过期 etl_path = _config_module.ETL_PROJECT_PATH etl_python = _config_module.ETL_PYTHON_EXECUTABLE cmd = cli_builder.build_command( config, etl_path, python_executable=etl_python ) command_str = " ".join(cmd) # CHANGE 2026-03-07 | 运行时防护:拒绝执行包含非预期路径的命令 # 检测两种异常: # 1. D 盘路径(junction 穿透) # 2. 多环境子目录(test/repo、prod/repo) _cmd_normalized = command_str.replace("/", "\\") _bad_patterns = [] if "D:\\" in command_str or "D:/" in command_str: _bad_patterns.append("D盘路径") if "\\test\\repo" in _cmd_normalized or "\\prod\\repo" in _cmd_normalized: _bad_patterns.append("多环境子目录(test/repo或prod/repo)") if _bad_patterns: _issues = " + ".join(_bad_patterns) logger.error( "路径防护触发:命令包含 %s,拒绝执行。" " command=%s | ETL_PY=%s | ETL_PATH=%s" " | NEOZQYY_ROOT=%s | config.__file__=%s", _issues, command_str, etl_python, etl_path, __import__('os').environ.get("NEOZQYY_ROOT", "<未设置>"), _config_module.__file__, ) raise RuntimeError( f"ETL 命令包含异常路径({_issues}),拒绝执行。" f" 请检查 .env 中 ETL_PYTHON_EXECUTABLE 和 ETL_PROJECT_PATH 配置。" f" 当前值: ETL_PY={etl_python}, ETL_PATH={etl_path}" ) effective_site_id = site_id or config.store_id # CHANGE 2026-03-07 | 在 command 前缀中注入实例标识, # 便于在多后端实例共享同一 DB 时区分记录来源 command_str_with_host = f"[{_INSTANCE_HOST}] {command_str}" logger.info( "启动 ETL 子进程 [%s]: %s (cwd=%s)", execution_id, command_str, etl_path, ) self._log_buffers[execution_id] = [] started_at = datetime.now(timezone.utc) t0 = time.monotonic() self._write_execution_log( execution_id=execution_id, queue_id=queue_id, site_id=effective_site_id, task_codes=config.tasks, status="running", started_at=started_at, command=command_str_with_host, schedule_id=schedule_id, config_json=config.model_dump(mode="json"), ) exit_code: int | None = None status = "running" stdout_lines: list[str] = [] stderr_lines: list[str] = [] try: # 构建额外环境变量(DWD 表过滤通过环境变量注入) extra_env: dict[str, str] = {} if config.dwd_only_tables: extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables) # 在线程池中运行子进程,兼容 Windows exit_code = await asyncio.get_event_loop().run_in_executor( None, self._run_subprocess, cmd, execution_id, stdout_lines, stderr_lines, extra_env or None, ) if exit_code == 0: status = "success" else: status = "failed" logger.info( "ETL 子进程 [%s] 退出,exit_code=%s, status=%s", execution_id, exit_code, status, ) except asyncio.CancelledError: status = "cancelled" logger.info("ETL 子进程 [%s] 已取消", execution_id) # 尝试终止子进程 proc = self._processes.get(execution_id) if proc and proc.poll() is None: proc.terminate() except Exception as exc: status = "failed" import traceback tb = traceback.format_exc() stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}") stderr_lines.append(tb) logger.exception("ETL 子进程 [%s] 执行异常", execution_id) finally: elapsed_ms = int((time.monotonic() - t0) * 1000) finished_at = datetime.now(timezone.utc) self._broadcast_end(execution_id) self._processes.pop(execution_id, None) self._update_execution_log( execution_id=execution_id, status=status, finished_at=finished_at, exit_code=exit_code, duration_ms=elapsed_ms, output_log="\n".join(stdout_lines), error_log="\n".join(stderr_lines), ) # CHANGE 2026-03-22 | 释放内存缓冲区,防止长期运行内存泄漏 self.cleanup(execution_id) def _run_subprocess( self, cmd: list[str], execution_id: str, stdout_lines: list[str], stderr_lines: list[str], extra_env: dict[str, str] | None = None, ) -> int: """在线程中运行子进程并逐行读取输出。""" import os env = os.environ.copy() # 强制子进程使用 UTF-8 输出,避免 Windows GBK 乱码 env["PYTHONIOENCODING"] = "utf-8" if extra_env: env.update(extra_env) proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=_config_module.ETL_PROJECT_PATH, env=env, text=True, encoding="utf-8", errors="replace", ) self._processes[execution_id] = proc def read_stream( stream, stream_name: str, collector: list[str], ) -> None: """逐行读取流并广播。""" for raw_line in stream: line = raw_line.rstrip("\n").rstrip("\r") tagged = f"[{stream_name}] {line}" buf = self._log_buffers.get(execution_id) if buf is not None: buf.append(tagged) collector.append(line) self._broadcast(execution_id, tagged) t_out = threading.Thread( target=read_stream, args=(proc.stdout, "stdout", stdout_lines), daemon=True, ) t_err = threading.Thread( target=read_stream, args=(proc.stderr, "stderr", stderr_lines), daemon=True, ) t_out.start() t_err.start() proc.wait() t_out.join(timeout=5) t_err.join(timeout=5) return proc.returncode # ------------------------------------------------------------------ # 取消 # ------------------------------------------------------------------ async def cancel(self, execution_id: str) -> bool: """向子进程发送终止信号。 如果进程仍在内存中,发送 terminate 信号; 如果进程已不在内存中(如后端重启后),但数据库中仍为 running, 则直接将数据库状态标记为 cancelled(幽灵记录兜底)。 Returns: True 表示成功取消,False 表示任务不存在或已完成。 """ proc = self._processes.get(execution_id) if proc is not None: # 进程仍在内存中 if proc.poll() is not None: return False logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid) try: proc.terminate() except ProcessLookupError: return False return True # 进程不在内存中(后端重启等场景),尝试兜底修正数据库幽灵记录 try: conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ UPDATE task_execution_log SET status = 'cancelled', finished_at = NOW(), error_log = COALESCE(error_log, '') || E'\n[cancel 兜底] 进程已不在内存中,标记为 cancelled' WHERE id = %s AND status = 'running' """, (execution_id,), ) updated = cur.rowcount conn.commit() finally: conn.close() if updated: logger.info( "兜底取消 execution_log [%s]:数据库状态从 running → cancelled", execution_id, ) return True except Exception: logger.exception("兜底取消 execution_log [%s] 失败", execution_id) return False # ------------------------------------------------------------------ # 数据库操作(同步,在线程池中执行也可,此处简单直连) # ------------------------------------------------------------------ @staticmethod def _write_execution_log( *, execution_id: str, queue_id: str | None, site_id: int | None, task_codes: list[str], status: str, started_at: datetime, command: str, schedule_id: str | None = None, config_json: dict | None = None, ) -> None: """插入一条执行日志记录(running 状态)。""" try: conn = get_connection() try: with conn.cursor() as cur: # 如果调用方未传 schedule_id,尝试从 task_queue 回查 effective_schedule_id = schedule_id if effective_schedule_id is None and queue_id is not None: cur.execute( "SELECT schedule_id FROM task_queue WHERE id = %s", (queue_id,), ) row = cur.fetchone() if row and row[0]: effective_schedule_id = str(row[0]) # CHANGE 2026-03-22 | 存储完整 TaskConfig JSON,供 rerun 还原原始参数 cur.execute( """ INSERT INTO task_execution_log (id, queue_id, site_id, task_codes, status, started_at, command, schedule_id, config) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) """, ( execution_id, queue_id, site_id or 0, task_codes, status, started_at, command, effective_schedule_id, Json(config_json) if config_json else None, ), ) conn.commit() finally: conn.close() except Exception: logger.exception("写入 execution_log 失败 [%s]", execution_id) @staticmethod def _update_execution_log( *, execution_id: str, status: str, finished_at: datetime, exit_code: int | None, duration_ms: int, output_log: str, error_log: str, ) -> None: """更新执行日志记录(完成状态)。""" try: conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ UPDATE task_execution_log SET status = %s, finished_at = %s, exit_code = %s, duration_ms = %s, output_log = %s, error_log = %s WHERE id = %s """, ( status, finished_at, exit_code, duration_ms, output_log, error_log, execution_id, ), ) conn.commit() finally: conn.close() except Exception: logger.exception("更新 execution_log 失败 [%s]", execution_id) # ------------------------------------------------------------------ # 清理 # ------------------------------------------------------------------ def cleanup(self, execution_id: str) -> None: """清理指定执行的内存资源(日志缓冲区和订阅者)。 通常在确认日志已持久化后调用。 """ self._log_buffers.pop(execution_id, None) self._subscribers.pop(execution_id, None) # ------------------------------------------------------------------ # 优雅关闭:终止所有子进程并回写状态 # ------------------------------------------------------------------ async def shutdown(self, timeout: float = 3.0) -> int: """优雅关闭:终止所有正在运行的子进程,等待回写完成。 Args: timeout: 等待子进程退出的超时秒数,超时后强制 kill。 Returns: 被终止的进程数量。 """ running_ids = list(self._processes.keys()) if not running_ids: return 0 logger.info( "优雅关闭:终止 %d 个运行中的子进程,超时 %.1fs", len(running_ids), timeout, ) # 先发 terminate 信号 for eid, proc in list(self._processes.items()): if proc.poll() is None: try: proc.terminate() logger.info("已发送 terminate 信号: %s (pid=%s)", eid, proc.pid) except ProcessLookupError: pass # 等待子进程退出(给 finally 块执行的机会) import time deadline = time.monotonic() + timeout for eid, proc in list(self._processes.items()): remaining = deadline - time.monotonic() if remaining > 0 and proc.poll() is None: try: proc.wait(timeout=remaining) except Exception: pass # 超时后强制 kill 仍存活的进程 for eid, proc in list(self._processes.items()): if proc.poll() is None: try: proc.kill() logger.warning("强制 kill: %s (pid=%s)", eid, proc.pid) except ProcessLookupError: pass # 注意:execute() 的 finally 块会在 run_in_executor 返回后执行, # 此处不需要手动回写——asyncio 事件循环关闭前会处理。 # 但如果 finally 来不及执行,recover_stale() 会在下次启动时兜底。 count = len(running_ids) logger.info("优雅关闭完成,已终止 %d 个子进程", count) return count # ------------------------------------------------------------------ # 启动时僵尸任务清理 # ------------------------------------------------------------------ def recover_stale(self) -> int: """启动时清理本机的僵尸任务(status=running 但进程已不存在)。 仅清理 command 中包含本机主机名标识 [hostname] 的记录。 Returns: 被标记为 interrupted 的记录数量。 """ # CHANGE 2026-03-22 | 启动时僵尸清理,仅限本机 host_tag = f"[{_INSTANCE_HOST}]" try: conn = get_connection() try: with conn.cursor() as cur: cur.execute( """ UPDATE task_execution_log SET status = 'interrupted', finished_at = NOW(), error_log = COALESCE(error_log, '') || E'\n[recover_stale] 后端重启,进程已丢失,标记为 interrupted' WHERE status = 'running' AND command LIKE %s RETURNING id """, (f"{host_tag}%",), ) rows = cur.fetchall() count = len(rows) conn.commit() finally: conn.close() if count > 0: ids = [str(r[0]) for r in rows] logger.warning( "启动清理:%d 条僵尸任务标记为 interrupted: %s", count, ", ".join(ids), ) else: logger.info("启动清理:无僵尸任务") return count except Exception: logger.exception("启动清理僵尸任务失败") return 0 # 全局单例 task_executor = TaskExecutor()