Files
Neo-ZQYY/apps/backend/app/services/task_executor.py
Neo 6f8f12314f feat: 累积功能变更 — 聊天集成、租户管理、小程序更新、ETL 增强、迁移脚本
包含多个会话的累积代码变更:
- backend: AI 聊天服务、触发器调度、认证增强、WebSocket、调度器最小间隔
- admin-web: ETL 状态页、任务管理、调度配置、登录优化
- miniprogram: 看板页面、聊天集成、UI 组件、导航更新
- etl: DWS 新任务(finance_area_daily/board_cache)、连接器增强
- tenant-admin: 项目初始化
- db: 19 个迁移脚本(etl_feiqiu 11 + zqyy_app 8)
- packages/shared: 枚举和工具函数更新
- tools: 数据库工具、报表生成、健康检查
- docs: PRD/架构/部署/合约文档更新

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-06 00:03:48 +08:00

598 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""ETL 任务执行器
通过 asyncio.create_subprocess_exec 启动 ETL CLI 子进程,
逐行读取 stdout/stderr 并广播到 WebSocket 订阅者,
执行完成后将结果写入 task_execution_log 表。
设计要点:
- 每个 execution_id 对应一个子进程,存储在 _processes 字典中
- 日志行存储在内存缓冲区 _log_buffers 中
- WebSocket 订阅者通过 asyncio.Queue 接收实时日志
- Windows 兼容:取消时使用 process.terminate() 而非 SIGTERM
"""
from __future__ import annotations
import asyncio
import logging
import subprocess
import threading
import time
from datetime import datetime, timezone
from typing import Any
# CHANGE 2026-03-07 | 只保留模块引用,execute() 中实时读取属性值
# 禁止 from ..config import ETL_PROJECT_PATH(值拷贝,reload 后过期)
from .. import config as _config_module
from ..database import get_connection
from psycopg2.extras import Json
from ..schemas.tasks import TaskConfigSchema
from ..services.cli_builder import cli_builder
logger = logging.getLogger(__name__)
# Instance identity: used to tell apart records written to the same DB
# by multiple backend instances (the host tag is prefixed onto `command`
# and matched again by recover_stale()).
import platform as _platform
_INSTANCE_HOST = _platform.node()  # hostname
class TaskExecutor:
    """Manages the lifecycle of ETL CLI subprocesses.

    Tracks one subprocess per execution_id, buffers its output lines in
    memory, and fans them out to WebSocket subscribers via asyncio queues.
    """

    def __init__(self) -> None:
        # execution_id → subprocess.Popen (live or recently exited child)
        self._processes: dict[str, subprocess.Popen] = {}
        # execution_id → list[str] (stdout + stderr interleaved log lines)
        self._log_buffers: dict[str, list[str]] = {}
        # execution_id → set[asyncio.Queue] (one queue per WebSocket subscriber)
        self._subscribers: dict[str, set[asyncio.Queue[str | None]]] = {}
# ------------------------------------------------------------------
# WebSocket 订阅管理
# ------------------------------------------------------------------
def subscribe(self, execution_id: str) -> asyncio.Queue[str | None]:
"""注册一个 WebSocket 订阅者,返回用于读取日志行的 Queue。
Queue 中推送 str 表示日志行None 表示执行结束。
"""
if execution_id not in self._subscribers:
self._subscribers[execution_id] = set()
queue: asyncio.Queue[str | None] = asyncio.Queue()
self._subscribers[execution_id].add(queue)
return queue
def unsubscribe(self, execution_id: str, queue: asyncio.Queue[str | None]) -> None:
"""移除一个 WebSocket 订阅者。"""
subs = self._subscribers.get(execution_id)
if subs:
subs.discard(queue)
if not subs:
del self._subscribers[execution_id]
def _broadcast(self, execution_id: str, line: str) -> None:
"""向所有订阅者广播一行日志。"""
subs = self._subscribers.get(execution_id)
if subs:
for q in subs:
q.put_nowait(line)
def _broadcast_end(self, execution_id: str) -> None:
"""通知所有订阅者执行已结束(发送 None 哨兵)。"""
subs = self._subscribers.get(execution_id)
if subs:
for q in subs:
q.put_nowait(None)
# ------------------------------------------------------------------
# 日志缓冲区
# ------------------------------------------------------------------
def get_logs(self, execution_id: str) -> list[str]:
"""获取指定执行的内存日志缓冲区(副本)。"""
return list(self._log_buffers.get(execution_id, []))
# ------------------------------------------------------------------
# 执行状态查询
# ------------------------------------------------------------------
def is_running(self, execution_id: str) -> bool:
"""判断指定执行是否仍在运行。"""
proc = self._processes.get(execution_id)
if proc is None:
return False
return proc.poll() is None
def get_running_ids(self) -> list[str]:
"""返回当前所有运行中的 execution_id 列表。"""
return [eid for eid, p in self._processes.items() if p.returncode is None]
# ------------------------------------------------------------------
# 核心执行
# ------------------------------------------------------------------
    async def execute(
        self,
        config: TaskConfigSchema,
        execution_id: str,
        queue_id: str | None = None,
        site_id: int | None = None,
        schedule_id: str | None = None,
    ) -> None:
        """Run the ETL CLI as a child process and persist the result.

        Uses subprocess.Popen plus reader threads instead of
        asyncio.create_subprocess_exec for Windows compatibility (the
        latter raises NotImplementedError on some Windows event loops).

        Args:
            config: Validated task configuration used to build the CLI command.
            execution_id: Primary key of the task_execution_log row.
            queue_id: Originating task_queue row id, if any.
            site_id: Site override; falls back to config.store_id.
            schedule_id: Originating schedule id, if any.
        """
        # CHANGE 2026-03-07 | Read the config module attributes at call time;
        # values copied at import time would go stale after a config reload.
        etl_path = _config_module.ETL_PROJECT_PATH
        etl_python = _config_module.ETL_PYTHON_EXECUTABLE
        cmd = cli_builder.build_command(
            config, etl_path, python_executable=etl_python
        )
        command_str = " ".join(cmd)
        # CHANGE 2026-03-07 | Runtime guard: refuse commands containing
        # unexpected paths. Two anomalies are detected:
        #   1. D-drive paths (junction traversal)
        #   2. multi-environment sub-directories (test/repo, prod/repo)
        _cmd_normalized = command_str.replace("/", "\\")
        _bad_patterns = []
        if "D:\\" in command_str or "D:/" in command_str:
            _bad_patterns.append("D盘路径")
        if "\\test\\repo" in _cmd_normalized or "\\prod\\repo" in _cmd_normalized:
            _bad_patterns.append("多环境子目录(test/repo或prod/repo)")
        if _bad_patterns:
            _issues = " + ".join(_bad_patterns)
            logger.error(
                "路径防护触发:命令包含 %s,拒绝执行。"
                " command=%s | ETL_PY=%s | ETL_PATH=%s"
                " | NEOZQYY_ROOT=%s | config.__file__=%s",
                _issues, command_str, etl_python, etl_path,
                __import__('os').environ.get("NEOZQYY_ROOT", "<未设置>"),
                _config_module.__file__,
            )
            raise RuntimeError(
                f"ETL 命令包含异常路径({_issues}),拒绝执行。"
                f" 请检查 .env 中 ETL_PYTHON_EXECUTABLE 和 ETL_PROJECT_PATH 配置。"
                f" 当前值: ETL_PY={etl_python}, ETL_PATH={etl_path}"
            )
        effective_site_id = site_id or config.store_id
        # CHANGE 2026-03-07 | Prefix the stored command with this instance's
        # host tag so records from multiple backend instances sharing one DB
        # can be told apart (recover_stale matches on this tag).
        command_str_with_host = f"[{_INSTANCE_HOST}] {command_str}"
        logger.info(
            "启动 ETL 子进程 [%s]: %s (cwd=%s)",
            execution_id, command_str, etl_path,
        )
        self._log_buffers[execution_id] = []
        started_at = datetime.now(timezone.utc)
        # Monotonic clock for duration measurement (immune to wall-clock jumps).
        t0 = time.monotonic()
        # Insert the 'running' row first so clients can see the execution
        # immediately.
        self._write_execution_log(
            execution_id=execution_id,
            queue_id=queue_id,
            site_id=effective_site_id,
            task_codes=config.tasks,
            status="running",
            started_at=started_at,
            command=command_str_with_host,
            schedule_id=schedule_id,
            config_json=config.model_dump(mode="json"),
        )
        exit_code: int | None = None
        status = "running"
        stdout_lines: list[str] = []
        stderr_lines: list[str] = []
        try:
            # Extra environment variables (the DWD table filter is injected
            # into the child through the environment).
            extra_env: dict[str, str] = {}
            if config.dwd_only_tables:
                extra_env["DWD_ONLY_TABLES"] = ",".join(config.dwd_only_tables)
            # Run the blocking subprocess in the default thread pool —
            # Windows-safe alternative to asyncio subprocess support.
            exit_code = await asyncio.get_event_loop().run_in_executor(
                None,
                self._run_subprocess,
                cmd,
                execution_id,
                stdout_lines,
                stderr_lines,
                extra_env or None,
            )
            if exit_code == 0:
                status = "success"
            else:
                status = "failed"
            logger.info(
                "ETL 子进程 [%s] 退出exit_code=%s, status=%s",
                execution_id, exit_code, status,
            )
        except asyncio.CancelledError:
            # NOTE(review): the cancellation is swallowed (not re-raised);
            # the coroutine finishes normally after the cleanup below.
            status = "cancelled"
            logger.info("ETL 子进程 [%s] 已取消", execution_id)
            # Best-effort terminate of the still-running child.
            proc = self._processes.get(execution_id)
            if proc and proc.poll() is None:
                proc.terminate()
        except Exception as exc:
            status = "failed"
            import traceback
            tb = traceback.format_exc()
            stderr_lines.append(f"[task_executor] 子进程启动/执行异常: {exc}")
            stderr_lines.append(tb)
            logger.exception("ETL 子进程 [%s] 执行异常", execution_id)
        finally:
            elapsed_ms = int((time.monotonic() - t0) * 1000)
            finished_at = datetime.now(timezone.utc)
            # Wake all WebSocket subscribers with the end sentinel before
            # tearing down in-memory state.
            self._broadcast_end(execution_id)
            self._processes.pop(execution_id, None)
            self._update_execution_log(
                execution_id=execution_id,
                status=status,
                finished_at=finished_at,
                exit_code=exit_code,
                duration_ms=elapsed_ms,
                output_log="\n".join(stdout_lines),
                error_log="\n".join(stderr_lines),
            )
            # CHANGE 2026-03-22 | Release the in-memory buffers to avoid a
            # slow leak in long-running backends.
            self.cleanup(execution_id)
    def _run_subprocess(
        self,
        cmd: list[str],
        execution_id: str,
        stdout_lines: list[str],
        stderr_lines: list[str],
        extra_env: dict[str, str] | None = None,
    ) -> int:
        """Run the child process in a worker thread, streaming its output.

        Blocks until the process exits; intended to be called through
        run_in_executor. stdout and stderr are drained by two daemon
        threads that append to the caller-owned collectors and broadcast
        each tagged line to WebSocket subscribers.

        Args:
            cmd: Full argv for the ETL CLI.
            execution_id: Key under which the Popen handle is registered.
            stdout_lines: Caller-owned list collecting raw stdout lines.
            stderr_lines: Caller-owned list collecting raw stderr lines.
            extra_env: Extra environment variables merged over os.environ.

        Returns:
            The child process's exit code.
        """
        import os
        env = os.environ.copy()
        # Force UTF-8 in the child; otherwise Windows defaults to GBK and
        # the captured output gets garbled.
        env["PYTHONIOENCODING"] = "utf-8"
        if extra_env:
            env.update(extra_env)
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=_config_module.ETL_PROJECT_PATH,
            env=env,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        # Register the handle so cancel()/is_running()/shutdown() can see it.
        self._processes[execution_id] = proc

        def read_stream(
            stream, stream_name: str, collector: list[str],
        ) -> None:
            """Read one pipe line by line; tag, buffer, and broadcast each line."""
            for raw_line in stream:
                line = raw_line.rstrip("\n").rstrip("\r")
                tagged = f"[{stream_name}] {line}"
                buf = self._log_buffers.get(execution_id)
                if buf is not None:
                    buf.append(tagged)
                collector.append(line)
                self._broadcast(execution_id, tagged)

        t_out = threading.Thread(
            target=read_stream, args=(proc.stdout, "stdout", stdout_lines),
            daemon=True,
        )
        t_err = threading.Thread(
            target=read_stream, args=(proc.stderr, "stderr", stderr_lines),
            daemon=True,
        )
        t_out.start()
        t_err.start()
        proc.wait()
        # Give the reader threads a moment to drain the pipe tails; they are
        # daemon threads, so a stuck pipe cannot block interpreter exit.
        t_out.join(timeout=5)
        t_err.join(timeout=5)
        return proc.returncode
# ------------------------------------------------------------------
# 取消
# ------------------------------------------------------------------
    async def cancel(self, execution_id: str) -> bool:
        """Send a termination signal to the subprocess.

        If the process is still tracked in memory, terminate() it.
        If it is not (e.g. after a backend restart) but the database row
        is still 'running', flip that ghost record to 'cancelled'.

        Returns:
            True on successful cancellation, False when the task does not
            exist or has already finished.
        """
        proc = self._processes.get(execution_id)
        if proc is not None:
            # Process still tracked in memory.
            if proc.poll() is not None:
                # Already exited — nothing to cancel.
                return False
            logger.info("取消 ETL 子进程 [%s], pid=%s", execution_id, proc.pid)
            try:
                proc.terminate()
            except ProcessLookupError:
                return False
            return True
        # Process not in memory (backend restart etc.) — repair a possible
        # ghost 'running' record directly in the database.
        try:
            conn = get_connection()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        """
                        UPDATE task_execution_log
                        SET status = 'cancelled',
                            finished_at = NOW(),
                            error_log = COALESCE(error_log, '')
                                || E'\n[cancel 兜底] 进程已不在内存中,标记为 cancelled'
                        WHERE id = %s AND status = 'running'
                        """,
                        (execution_id,),
                    )
                    updated = cur.rowcount
                    conn.commit()
            finally:
                conn.close()
            if updated:
                logger.info(
                    "兜底取消 execution_log [%s]:数据库状态从 running → cancelled",
                    execution_id,
                )
                return True
        except Exception:
            logger.exception("兜底取消 execution_log [%s] 失败", execution_id)
        return False
# ------------------------------------------------------------------
# 数据库操作(同步,在线程池中执行也可,此处简单直连)
# ------------------------------------------------------------------
@staticmethod
def _write_execution_log(
*,
execution_id: str,
queue_id: str | None,
site_id: int | None,
task_codes: list[str],
status: str,
started_at: datetime,
command: str,
schedule_id: str | None = None,
config_json: dict | None = None,
) -> None:
"""插入一条执行日志记录running 状态)。"""
try:
conn = get_connection()
try:
with conn.cursor() as cur:
# 如果调用方未传 schedule_id尝试从 task_queue 回查
effective_schedule_id = schedule_id
if effective_schedule_id is None and queue_id is not None:
cur.execute(
"SELECT schedule_id FROM task_queue WHERE id = %s",
(queue_id,),
)
row = cur.fetchone()
if row and row[0]:
effective_schedule_id = str(row[0])
# CHANGE 2026-03-22 | 存储完整 TaskConfig JSON供 rerun 还原原始参数
cur.execute(
"""
INSERT INTO task_execution_log
(id, queue_id, site_id, task_codes, status,
started_at, command, schedule_id, config)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(
execution_id,
queue_id,
site_id or 0,
task_codes,
status,
started_at,
command,
effective_schedule_id,
Json(config_json) if config_json else None,
),
)
conn.commit()
finally:
conn.close()
except Exception:
logger.exception("写入 execution_log 失败 [%s]", execution_id)
@staticmethod
def _update_execution_log(
*,
execution_id: str,
status: str,
finished_at: datetime,
exit_code: int | None,
duration_ms: int,
output_log: str,
error_log: str,
) -> None:
"""更新执行日志记录(完成状态)。"""
try:
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE task_execution_log
SET status = %s,
finished_at = %s,
exit_code = %s,
duration_ms = %s,
output_log = %s,
error_log = %s
WHERE id = %s
""",
(
status,
finished_at,
exit_code,
duration_ms,
output_log,
error_log,
execution_id,
),
)
conn.commit()
finally:
conn.close()
except Exception:
logger.exception("更新 execution_log 失败 [%s]", execution_id)
# ------------------------------------------------------------------
# 清理
# ------------------------------------------------------------------
def cleanup(self, execution_id: str) -> None:
"""清理指定执行的内存资源(日志缓冲区和订阅者)。
通常在确认日志已持久化后调用。
"""
self._log_buffers.pop(execution_id, None)
self._subscribers.pop(execution_id, None)
# ------------------------------------------------------------------
# 优雅关闭:终止所有子进程并回写状态
# ------------------------------------------------------------------
async def shutdown(self, timeout: float = 3.0) -> int:
"""优雅关闭:终止所有正在运行的子进程,等待回写完成。
Args:
timeout: 等待子进程退出的超时秒数,超时后强制 kill。
Returns:
被终止的进程数量。
"""
running_ids = list(self._processes.keys())
if not running_ids:
return 0
logger.info(
"优雅关闭:终止 %d 个运行中的子进程,超时 %.1fs",
len(running_ids), timeout,
)
# 先发 terminate 信号
for eid, proc in list(self._processes.items()):
if proc.poll() is None:
try:
proc.terminate()
logger.info("已发送 terminate 信号: %s (pid=%s)", eid, proc.pid)
except ProcessLookupError:
pass
# 等待子进程退出(给 finally 块执行的机会)
import time
deadline = time.monotonic() + timeout
for eid, proc in list(self._processes.items()):
remaining = deadline - time.monotonic()
if remaining > 0 and proc.poll() is None:
try:
proc.wait(timeout=remaining)
except Exception:
pass
# 超时后强制 kill 仍存活的进程
for eid, proc in list(self._processes.items()):
if proc.poll() is None:
try:
proc.kill()
logger.warning("强制 kill: %s (pid=%s)", eid, proc.pid)
except ProcessLookupError:
pass
# 注意execute() 的 finally 块会在 run_in_executor 返回后执行,
# 此处不需要手动回写——asyncio 事件循环关闭前会处理。
# 但如果 finally 来不及执行recover_stale() 会在下次启动时兜底。
count = len(running_ids)
logger.info("优雅关闭完成,已终止 %d 个子进程", count)
return count
# ------------------------------------------------------------------
# 启动时僵尸任务清理
# ------------------------------------------------------------------
def recover_stale(self) -> int:
"""启动时清理本机的僵尸任务status=running 但进程已不存在)。
仅清理 command 中包含本机主机名标识 [hostname] 的记录。
Returns:
被标记为 interrupted 的记录数量。
"""
# CHANGE 2026-03-22 | 启动时僵尸清理,仅限本机
host_tag = f"[{_INSTANCE_HOST}]"
try:
conn = get_connection()
try:
with conn.cursor() as cur:
cur.execute(
"""
UPDATE task_execution_log
SET status = 'interrupted',
finished_at = NOW(),
error_log = COALESCE(error_log, '')
|| E'\n[recover_stale] 后端重启,进程已丢失,标记为 interrupted'
WHERE status = 'running'
AND command LIKE %s
RETURNING id
""",
(f"{host_tag}%",),
)
rows = cur.fetchall()
count = len(rows)
conn.commit()
finally:
conn.close()
if count > 0:
ids = [str(r[0]) for r in rows]
logger.warning(
"启动清理:%d 条僵尸任务标记为 interrupted: %s",
count, ", ".join(ids),
)
else:
logger.info("启动清理:无僵尸任务")
return count
except Exception:
logger.exception("启动清理僵尸任务失败")
return 0
# Module-level singleton: the one executor instance shared by the backend.
task_executor = TaskExecutor()