# Source: ZQYY.FQ-ETL/orchestration/task_executor.py
# -*- coding: utf-8 -*-
"""任务执行器:封装单个 ETL 任务的完整执行生命周期。
从原 ETLScheduler 中提取的执行层,负责:
- 单任务执行(抓取/入库/ODS 录制+加载)
- 游标管理(成功后推进水位)
- 运行记录(创建/更新 etl_admin.etl_run)
设计原则:
- data_source 作为显式参数传入,不依赖全局状态
- 工具类任务判断通过 TaskRegistry 元数据查询
- 所有依赖通过构造函数注入,不自行创建资源
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List
from zoneinfo import ZoneInfo
from api.recording_client import RecordingAPIClient
from api.local_json_client import LocalJsonClient
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_registry import TaskRegistry
class DataSource(str, Enum):
    """Data-source mode for a run; replaces the old ``pipeline.flow`` global state.

    Because this is a ``str`` subclass, members compare equal to their plain
    string values (e.g. ``DataSource.ONLINE == "online"``).
    """

    # Online fetch only (formerly FETCH_ONLY).
    ONLINE = "online"
    # Local ingest only (formerly INGEST_ONLY).
    OFFLINE = "offline"
    # Fetch followed by ingest (formerly FULL).
    HYBRID = "hybrid"
class TaskExecutor:
    """Task executor: wraps the full execution lifecycle of a single ETL task.

    All dependencies are injected through the constructor; this class never
    creates a DatabaseConnection or APIClient itself.  ``data_source`` is
    passed as a method argument, replacing the former ``self.pipeline_flow``
    global state.
    """
def __init__(
self,
config,
db_ops,
api_client,
cursor_mgr: CursorManager,
run_tracker: RunTracker,
task_registry: TaskRegistry,
logger: logging.Logger,
):
self.config = config
self.db_ops = db_ops
self.api_client = api_client
self.cursor_mgr = cursor_mgr
self.run_tracker = run_tracker
self.task_registry = task_registry
self.logger = logger
self.tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
self.fetch_root = Path(
config.get("io.fetch_root")
or config.get("pipeline.fetch_root")
or config["io"]["export_root"]
)
self.ingest_source_dir = (
config.get("io.ingest_source_dir")
or config.get("pipeline.ingest_source_dir")
or ""
)
self.write_pretty_json = bool(config.get("io.write_pretty_json", False))
# ------------------------------------------------------------------ 公共接口
def run_tasks(
self,
task_codes: list[str],
data_source: str = "hybrid",
) -> list[dict[str, Any]]:
"""批量执行任务列表,返回每个任务的结果。"""
run_uuid = uuid.uuid4().hex
store_id = self.config.get("app.store_id")
results: list[dict[str, Any]] = []
file_handler = self._attach_run_file_logger(run_uuid)
try:
self.logger.info("开始运行任务: %s, run_uuid=%s", task_codes, run_uuid)
for task_code in task_codes:
try:
task_result = self.run_single_task(
task_code, run_uuid, store_id, data_source=data_source,
)
result_entry: dict[str, Any] = {
"task_code": task_code,
"status": "成功" if task_result else "完成",
"counts": task_result.get("counts", {}) if isinstance(task_result, dict) else {},
}
if isinstance(task_result, dict):
if task_result.get("dump_dir"):
result_entry["dump_dir"] = task_result["dump_dir"]
if task_result.get("last_dump"):
result_entry["last_dump"] = task_result["last_dump"]
results.append(result_entry)
except Exception as exc: # noqa: BLE001
self.logger.error("任务 %s 失败: %s", task_code, exc, exc_info=True)
results.append({
"task_code": task_code,
"status": "失败",
"error": str(exc),
"counts": {},
})
continue
self.logger.info("所有任务执行完成")
return results
finally:
if file_handler is not None:
try:
logging.getLogger().removeHandler(file_handler)
except Exception:
pass
try:
file_handler.close()
except Exception:
pass
def run_single_task(
    self,
    task_code: str,
    run_uuid: str,
    store_id: int,
    data_source: str = "hybrid",
) -> dict[str, Any]:
    """Run one task through its full lifecycle.

    Lifecycle: load task config -> get/create cursor -> create run record ->
    execute (fetch / ingest / ODS record+load, depending on *data_source*) ->
    update run record -> advance the cursor watermark on success.

    Args:
        task_code: Task code (matched case-insensitively where relevant).
        run_uuid: Unique identifier of this batch run.
        store_id: Store ID.
        data_source: Source mode: online / offline / hybrid
            (compared case-insensitively).

    Returns:
        The task result dict with at least ``status`` and ``counts``.
    """
    task_code_upper = task_code.upper()
    # Utility tasks are identified via TaskRegistry metadata and bypass
    # cursor management and run tracking entirely.
    if self.task_registry.is_utility_task(task_code_upper):
        return self._run_utility_task(task_code_upper, store_id)
    task_cfg = self._load_task_config(task_code, store_id)
    if not task_cfg:
        self.logger.warning("任务 %s 未启用或不存在", task_code)
        return {"status": "SKIP", "counts": {}}
    task_id = task_cfg["task_id"]
    cursor_data = self.cursor_mgr.get_or_create(task_id, store_id)
    # Create the run record up front so even early failures are traceable.
    export_dir = Path(self.config["io"]["export_root"]) / datetime.now(self.tz).strftime("%Y%m%d")
    log_path = str(Path(self.config["io"]["log_root"]) / f"{run_uuid}.log")
    run_id = self.run_tracker.create_run(
        task_id=task_id,
        store_id=store_id,
        run_uuid=run_uuid,
        export_dir=str(export_dir),
        log_path=log_path,
        status=RunTracker.map_run_status("RUNNING"),
    )
    fetch_dir = self._build_fetch_dir(task_code, run_id)
    fetch_stats = None
    try:
        # ODS tasks (except ODS_JSON_ARCHIVE) take a dedicated path that
        # fetches and loads in a single step.
        if self._is_ods_task(task_code):
            if self._flow_includes_fetch(data_source):
                result, last_dump = self._execute_ods_record_and_load(
                    task_code, cursor_data, fetch_dir, run_id,
                )
                if isinstance(result, dict):
                    result.setdefault("dump_dir", str(fetch_dir))
                    if last_dump:
                        result.setdefault("last_dump", last_dump)
            else:
                source_dir = self._resolve_ingest_source(fetch_dir, None)
                result = self._execute_ingest(task_code, cursor_data, source_dir)
            self.run_tracker.update_run(
                run_id=run_id,
                counts=result.get("counts") or {},
                status=RunTracker.map_run_status(result.get("status")),
                ended_at=datetime.now(self.tz),
                window=result.get("window"),
                request_params=result.get("request_params"),
                overlap_seconds=self.config.get("run.overlap_seconds"),
            )
            if (result.get("status") or "").upper() == "SUCCESS":
                window = result.get("window")
                if isinstance(window, dict):
                    self.cursor_mgr.advance(
                        task_id=task_id,
                        store_id=store_id,
                        window_start=window.get("start"),
                        window_end=window.get("end"),
                        run_id=run_id,
                    )
                    self._maybe_run_integrity_check(task_code, window)
            return result
        # Non-ODS tasks: data_source decides the fetch/ingest stages.
        if self._flow_includes_fetch(data_source):
            fetch_stats = self._execute_fetch(task_code, cursor_data, fetch_dir, run_id)
        # BUGFIX: compare case-insensitively so "ONLINE" behaves like
        # "online" (previously it passed _flow_includes_fetch but missed
        # this check, leaving the run record stuck in RUNNING).
        if str(data_source).lower() == DataSource.ONLINE.value:
            counts = self._counts_from_fetch(fetch_stats)
            self.run_tracker.update_run(
                run_id=run_id,
                counts=counts,
                status=RunTracker.map_run_status("SUCCESS"),
                ended_at=datetime.now(self.tz),
            )
            return {"status": "SUCCESS", "counts": counts}
        if self._flow_includes_ingest(data_source):
            source_dir = self._resolve_ingest_source(fetch_dir, fetch_stats)
            result = self._execute_ingest(task_code, cursor_data, source_dir)
            self.run_tracker.update_run(
                run_id=run_id,
                counts=result["counts"],
                status=RunTracker.map_run_status(result["status"]),
                ended_at=datetime.now(self.tz),
                window=result.get("window"),
                request_params=result.get("request_params"),
                overlap_seconds=self.config.get("run.overlap_seconds"),
            )
            if (result.get("status") or "").upper() == "SUCCESS":
                window = result.get("window")
                # BUGFIX: require a dict (consistent with the ODS branch);
                # a truthy non-dict window previously raised AttributeError.
                if isinstance(window, dict):
                    self.cursor_mgr.advance(
                        task_id=task_id,
                        store_id=store_id,
                        window_start=window.get("start"),
                        window_end=window.get("end"),
                        run_id=run_id,
                    )
                    self._maybe_run_integrity_check(task_code, window)
            return result
    except Exception as exc:
        # Mark the run as failed before propagating the error.
        self.run_tracker.update_run(
            run_id=run_id,
            counts={},
            status=RunTracker.map_run_status("FAIL"),
            ended_at=datetime.now(self.tz),
            error_message=str(exc),
        )
        raise
    # data_source matched neither an online-only nor an ingest stage.
    return {"status": "COMPLETE", "counts": {}}
# ------------------------------------------------------------------ 内部方法
def _execute_fetch(
    self,
    task_code: str,
    cursor_data: dict | None,
    fetch_dir: Path,
    run_id: int,
):
    """Online fetch stage: pull via RecordingAPIClient and dump to disk.

    Only Extract runs here -- no Transform/Load.  Returns a small stats
    dict with the dump file, record count and page count.
    """
    recorder = RecordingAPIClient(
        base_client=self.api_client,
        output_dir=fetch_dir,
        task_code=task_code,
        run_id=run_id,
        write_pretty=self.write_pretty_json,
    )
    task = self.task_registry.create_task(
        task_code, self.config, self.db_ops, recorder, self.logger,
    )
    # Reuses the task's own context builder (deliberate private access).
    context = task._build_context(cursor_data)  # type: ignore[attr-defined]
    self.logger.info("%s: 抓取阶段开始,目录=%s", task_code, fetch_dir)
    extracted = task.extract(context)
    dump_stats = recorder.last_dump or {}
    # Prefer the recorder's count; fall back to what extract() reported.
    extracted_total = 0
    if isinstance(extracted, dict):
        extracted_total = int(extracted.get("fetched") or 0) or len(extracted.get("records", []))
    record_total = dump_stats.get("records") or extracted_total or 0
    self.logger.info(
        "%s: 抓取完成,文件=%s,记录数=%s",
        task_code,
        dump_stats.get("file"),
        record_total,
    )
    return {"file": dump_stats.get("file"), "records": record_total, "pages": dump_stats.get("pages")}
@staticmethod
def _is_ods_task(task_code: str) -> bool:
"""判断是否为 ODS 任务ODS_JSON_ARCHIVE 除外)。"""
tc = str(task_code or "").upper()
return tc.startswith("ODS_") and tc != "ODS_JSON_ARCHIVE"
def _execute_ods_record_and_load(
    self,
    task_code: str,
    cursor_data: dict | None,
    fetch_dir: Path,
    run_id: int,
) -> tuple[dict, dict]:
    """ODS task path: online fetch plus direct load in one step.

    ODS tasks perform their DB upsert inside ``execute()``, so this method
    only wires in a RecordingAPIClient (to capture the raw JSON) and runs
    the task.  Returns ``(task_result, last_dump_stats)``.
    """
    recorder = RecordingAPIClient(
        base_client=self.api_client,
        output_dir=fetch_dir,
        task_code=task_code,
        run_id=run_id,
        write_pretty=self.write_pretty_json,
    )
    task = self.task_registry.create_task(
        task_code, self.config, self.db_ops, recorder, self.logger,
    )
    self.logger.info("%s: ODS fetch+load start, dir=%s", task_code, fetch_dir)
    outcome = task.execute(cursor_data)
    return outcome, (recorder.last_dump or {})
def _execute_ingest(
    self,
    task_code: str,
    cursor_data: dict | None,
    source_dir: Path,
):
    """Offline clean-and-load stage: replay local JSON through the normal ETL.

    A LocalJsonClient standing in for the real API client makes the task
    read from *source_dir* instead of the network.
    """
    replay_client = LocalJsonClient(source_dir)
    task = self.task_registry.create_task(
        task_code, self.config, self.db_ops, replay_client, self.logger,
    )
    self.logger.info("%s: 本地清洗入库开始,源目录=%s", task_code, source_dir)
    return task.execute(cursor_data)
def _build_fetch_dir(self, task_code: str, run_id: int) -> Path:
"""构建抓取输出目录路径。"""
ts = datetime.now(self.tz).strftime("%Y%m%d-%H%M%S")
task_code = str(task_code or "").upper()
return Path(self.fetch_root) / task_code / f"{task_code}-{run_id}-{ts}"
def _resolve_ingest_source(self, fetch_dir: Path, fetch_stats: dict | None) -> Path:
"""确定本地清洗入库的 JSON 源目录。"""
if fetch_stats and fetch_dir.exists():
return fetch_dir
if self.ingest_source_dir:
return Path(self.ingest_source_dir)
raise FileNotFoundError("未提供本地清洗入库所需的 JSON 目录")
def _counts_from_fetch(self, stats: dict | None) -> dict:
"""从抓取统计中构建计数字典。"""
fetched = (stats or {}).get("records") or 0
return {
"fetched": fetched,
"inserted": 0,
"updated": 0,
"skipped": 0,
"errors": 0,
}
@staticmethod
def _flow_includes_fetch(data_source: str) -> bool:
"""判断当前 data_source 是否包含抓取阶段。"""
ds = str(data_source).lower()
return ds in {"online", "hybrid"}
@staticmethod
def _flow_includes_ingest(data_source: str) -> bool:
"""判断当前 data_source 是否包含入库阶段。"""
ds = str(data_source).lower()
return ds in {"offline", "hybrid"}
def _run_utility_task(self, task_code: str, store_id: int) -> Dict[str, Any]:
"""执行工具类任务(不记录 cursor/run直接执行"""
self.logger.info("%s: 开始执行工具类任务", task_code)
try:
api_client = None
if task_code == "ODS_JSON_ARCHIVE":
run_id = int(datetime.now(self.tz).timestamp())
fetch_dir = self._build_fetch_dir(task_code, run_id)
api_client = RecordingAPIClient(
base_client=self.api_client,
output_dir=fetch_dir,
task_code=task_code,
run_id=run_id,
write_pretty=self.write_pretty_json,
)
task = self.task_registry.create_task(
task_code, self.config, self.db_ops, api_client, self.logger,
)
result = task.execute(None)
status = (result.get("status") or "").upper() if isinstance(result, dict) else "SUCCESS"
counts = result.get("counts", {}) if isinstance(result, dict) else {}
if status == "SUCCESS":
self.logger.info("%s: 工具类任务执行成功", task_code)
if counts:
self.logger.info("%s: 结果统计: %s", task_code, counts)
else:
self.logger.warning("%s: 工具类任务执行结果: %s", task_code, status)
return {"status": status, "counts": counts}
except Exception as exc:
self.logger.error("%s: 工具类任务执行失败: %s", task_code, exc, exc_info=True)
raise
def _load_task_config(self, task_code: str, store_id: int) -> dict | None:
"""从数据库加载任务配置。"""
sql = """
SELECT task_id, task_code, store_id, enabled, cursor_field,
window_minutes_default, overlap_seconds, page_size, retry_max, params
FROM etl_admin.etl_task
WHERE store_id = %s AND task_code = %s AND enabled = TRUE
"""
rows = self.db_ops.query(sql, (store_id, task_code))
return rows[0] if rows else None
def _maybe_run_integrity_check(self, task_code: str, window: dict | None) -> None:
"""在 DWD_LOAD_FROM_ODS 成功后可选执行完整性校验。"""
if not self.config.get("integrity.auto_check", False):
return
if str(task_code or "").upper() != "DWD_LOAD_FROM_ODS":
return
if not isinstance(window, dict):
return
window_start = window.get("start")
window_end = window.get("end")
if not window_start or not window_end:
return
try:
from quality.integrity_checker import IntegrityWindow, run_integrity_window
include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
report = run_integrity_window(
cfg=self.config,
window=IntegrityWindow(
start=window_start,
end=window_end,
label="etl_window",
granularity="window",
),
include_dimensions=include_dimensions,
task_codes=task_codes,
logger=self.logger,
write_report=True,
)
self.logger.info(
"Integrity check done: report=%s missing=%s errors=%s",
report.get("report_path"),
report.get("api_to_ods", {}).get("total_missing"),
report.get("api_to_ods", {}).get("total_errors"),
)
except Exception as exc: # noqa: BLE001
self.logger.warning("Integrity check failed: %s", exc, exc_info=True)
def _attach_run_file_logger(self, run_uuid: str) -> logging.Handler | None:
"""为本次 run_uuid 动态挂载文件日志处理器。"""
log_root = Path(self.config["io"]["log_root"])
try:
log_root.mkdir(parents=True, exist_ok=True)
except Exception as exc: # noqa: BLE001
self.logger.warning("创建日志目录失败:%s%s", log_root, exc)
return None
log_path = log_root / f"{run_uuid}.log"
try:
handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
except Exception as exc: # noqa: BLE001
self.logger.warning("创建文件日志失败:%s%s", log_path, exc)
return None
fmt = logging.Formatter(
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(fmt)
handler.setLevel(logging.INFO)
root_logger = logging.getLogger()
root_logger.addHandler(handler)
return handler