# -*- coding: utf-8 -*-
"""Task executor: wraps the full execution lifecycle of a single ETL task.

Execution layer extracted from the former ETLScheduler. Responsibilities:
- single-task execution (fetch / ingest / ODS record + load)
- cursor management (advance the watermark after a successful run)
- run records (create / update etl_admin.etl_run)

Design principles:
- data_source is passed explicitly instead of relying on global state
- utility tasks are identified through TaskRegistry metadata
- every dependency is injected through the constructor; no resources are
  created here
"""
from __future__ import annotations

import logging
import uuid
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List
from zoneinfo import ZoneInfo

from api.recording_client import RecordingAPIClient
from api.local_json_client import LocalJsonClient
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_registry import TaskRegistry


class DataSource(str, Enum):
    """Data-source mode; replaces the former global ``pipeline.flow`` state."""

    ONLINE = "online"    # online fetch only (formerly FETCH_ONLY)
    OFFLINE = "offline"  # local ingest only (formerly INGEST_ONLY)
    HYBRID = "hybrid"    # fetch + ingest (formerly FULL)


def _normalize_data_source(data_source: Any) -> str:
    """Return the lower-case string form of a data-source mode.

    ``DataSource`` mixes in ``str``, but ``str()`` on such a member yields
    ``"DataSource.ONLINE"`` rather than ``"online"``, so enum members are
    unwrapped to their ``.value`` before being lower-cased. Plain strings are
    simply lower-cased; unknown values pass through unchanged (lower-cased).
    """
    if isinstance(data_source, Enum):
        data_source = data_source.value
    return str(data_source).lower()


class TaskExecutor:
    """Task executor: wraps the full execution lifecycle of a single ETL task.

    All dependencies are injected through the constructor; this class does not
    create database connections or API clients itself. ``data_source`` is a
    method parameter and replaces the former ``self.pipeline_flow`` state.
    """

    def __init__(
        self,
        config,
        db_ops,
        api_client,
        cursor_mgr: CursorManager,
        run_tracker: RunTracker,
        task_registry: TaskRegistry,
        logger: logging.Logger,
    ):
        """Store the injected collaborators and read I/O settings from config.

        Args:
            config: Configuration object supporting dotted ``get`` lookups and
                ``config["io"][...]`` item access.
            db_ops: Database helper; ``query(sql, params)`` is used here.
            api_client: Base API client wrapped by ``RecordingAPIClient``.
            cursor_mgr: Cursor (watermark) manager.
            run_tracker: Run-record tracker.
            task_registry: Registry used to build tasks and query metadata.
            logger: Logger used for all messages emitted by this executor.
        """
        self.config = config
        self.db_ops = db_ops
        self.api_client = api_client
        self.cursor_mgr = cursor_mgr
        self.run_tracker = run_tracker
        self.task_registry = task_registry
        self.logger = logger

        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
        # First configured location wins; export_root is the last resort.
        self.fetch_root = Path(
            config.get("io.fetch_root")
            or config.get("pipeline.fetch_root")
            or config["io"]["export_root"]
        )
        self.ingest_source_dir = (
            config.get("io.ingest_source_dir")
            or config.get("pipeline.ingest_source_dir")
            or ""
        )
        self.write_pretty_json = bool(config.get("io.write_pretty_json", False))

    # ------------------------------------------------------------------ public API

    def run_tasks(
        self,
        task_codes: list[str],
        data_source: str | DataSource = "hybrid",
    ) -> list[dict[str, Any]]:
        """Run a list of tasks and return one result entry per task.

        A failing task is logged and reported with status ``"失败"``; the
        remaining tasks still run.

        Args:
            task_codes: Task codes to execute, in order.
            data_source: Data-source mode (online/offline/hybrid), either a
                string or a ``DataSource`` member.

        Returns:
            A list of dicts with ``task_code``, ``status`` and ``counts``, plus
            ``dump_dir`` / ``last_dump`` when the task reported them, or
            ``error`` when the task raised.
        """
        run_uuid = uuid.uuid4().hex
        store_id = self.config.get("app.store_id")
        results: list[dict[str, Any]] = []

        file_handler = self._attach_run_file_logger(run_uuid)
        try:
            self.logger.info("开始运行任务: %s, run_uuid=%s", task_codes, run_uuid)

            for task_code in task_codes:
                try:
                    task_result = self.run_single_task(
                        task_code, run_uuid, store_id, data_source=data_source,
                    )
                    result_entry: dict[str, Any] = {
                        "task_code": task_code,
                        "status": "成功" if task_result else "完成",
                        "counts": (
                            task_result.get("counts", {})
                            if isinstance(task_result, dict)
                            else {}
                        ),
                    }
                    if isinstance(task_result, dict):
                        if task_result.get("dump_dir"):
                            result_entry["dump_dir"] = task_result["dump_dir"]
                        if task_result.get("last_dump"):
                            result_entry["last_dump"] = task_result["last_dump"]
                    results.append(result_entry)
                except Exception as exc:  # noqa: BLE001
                    self.logger.error("任务 %s 失败: %s", task_code, exc, exc_info=True)
                    results.append({
                        "task_code": task_code,
                        "status": "失败",
                        "error": str(exc),
                        "counts": {},
                    })
                    continue

            self.logger.info("所有任务执行完成")
            return results
        finally:
            # Best-effort teardown of the per-run file handler: a failure here
            # must never mask the task results, so it is only logged.
            if file_handler is not None:
                try:
                    logging.getLogger().removeHandler(file_handler)
                except Exception:  # noqa: BLE001
                    self.logger.debug("failed to detach run log handler", exc_info=True)
                try:
                    file_handler.close()
                except Exception:  # noqa: BLE001
                    self.logger.debug("failed to close run log handler", exc_info=True)

    def run_single_task(
        self,
        task_code: str,
        run_uuid: str,
        store_id: int,
        data_source: str | DataSource = "hybrid",
    ) -> dict[str, Any]:
        """Run the full lifecycle of a single task.

        Args:
            task_code: Task code.
            run_uuid: Unique identifier of the current batch run.
            store_id: Store ID.
            data_source: Data-source mode (online/offline/hybrid), either a
                string (case-insensitive) or a ``DataSource`` member.

        Returns:
            The task result dict. ``{"status": "SKIP", ...}`` when the task is
            disabled or unknown; ``{"status": "COMPLETE", ...}`` when the
            data-source mode selects no phase.

        Raises:
            Exception: Any error raised by the task is re-raised after the run
                record has been marked as failed.
        """
        task_code_upper = task_code.upper()
        # Normalise once so every branch below agrees on the mode, whether the
        # caller passed "ONLINE", "online" or DataSource.ONLINE.
        mode = _normalize_data_source(data_source)

        # Utility tasks (per TaskRegistry metadata) bypass cursor and run records.
        if self.task_registry.is_utility_task(task_code_upper):
            return self._run_utility_task(task_code_upper, store_id)

        task_cfg = self._load_task_config(task_code, store_id)
        if not task_cfg:
            self.logger.warning("任务 %s 未启用或不存在", task_code)
            return {"status": "SKIP", "counts": {}}

        task_id = task_cfg["task_id"]
        cursor_data = self.cursor_mgr.get_or_create(task_id, store_id)

        # Create the run record.
        export_dir = (
            Path(self.config["io"]["export_root"])
            / datetime.now(self.tz).strftime("%Y%m%d")
        )
        log_path = str(Path(self.config["io"]["log_root"]) / f"{run_uuid}.log")

        run_id = self.run_tracker.create_run(
            task_id=task_id,
            store_id=store_id,
            run_uuid=run_uuid,
            export_dir=str(export_dir),
            log_path=log_path,
            status=RunTracker.map_run_status("RUNNING"),
        )

        fetch_dir = self._build_fetch_dir(task_code, run_id)
        fetch_stats = None

        try:
            # ODS tasks (except ODS_JSON_ARCHIVE) take a dedicated path.
            if self._is_ods_task(task_code):
                if self._flow_includes_fetch(mode):
                    result, last_dump = self._execute_ods_record_and_load(
                        task_code, cursor_data, fetch_dir, run_id,
                    )
                    if isinstance(result, dict):
                        result.setdefault("dump_dir", str(fetch_dir))
                        if last_dump:
                            result.setdefault("last_dump", last_dump)
                else:
                    source_dir = self._resolve_ingest_source(fetch_dir, None)
                    result = self._execute_ingest(task_code, cursor_data, source_dir)

                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result.get("counts") or {},
                    status=RunTracker.map_run_status(result.get("status")),
                    ended_at=datetime.now(self.tz),
                    window=result.get("window"),
                    request_params=result.get("request_params"),
                    overlap_seconds=self.config.get("run.overlap_seconds"),
                )

                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if isinstance(window, dict):
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
                        self._maybe_run_integrity_check(task_code, window)

                return result

            # Non-ODS tasks: data_source decides which phases run.
            if self._flow_includes_fetch(mode):
                fetch_stats = self._execute_fetch(task_code, cursor_data, fetch_dir, run_id)
                if mode == DataSource.ONLINE.value:
                    # Fetch-only run: record the fetch counts and stop here.
                    counts = self._counts_from_fetch(fetch_stats)
                    self.run_tracker.update_run(
                        run_id=run_id,
                        counts=counts,
                        status=RunTracker.map_run_status("SUCCESS"),
                        ended_at=datetime.now(self.tz),
                    )
                    return {"status": "SUCCESS", "counts": counts}

            if self._flow_includes_ingest(mode):
                source_dir = self._resolve_ingest_source(fetch_dir, fetch_stats)
                result = self._execute_ingest(task_code, cursor_data, source_dir)

                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result["counts"],
                    status=RunTracker.map_run_status(result["status"]),
                    ended_at=datetime.now(self.tz),
                    window=result.get("window"),
                    request_params=result.get("request_params"),
                    overlap_seconds=self.config.get("run.overlap_seconds"),
                )

                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if window:
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
                        self._maybe_run_integrity_check(task_code, window)

                return result

        except Exception as exc:
            # Mark the run as failed, then let the caller see the error.
            self.run_tracker.update_run(
                run_id=run_id,
                counts={},
                status=RunTracker.map_run_status("FAIL"),
                ended_at=datetime.now(self.tz),
                error_message=str(exc),
            )
            raise

        # Reached only when data_source selects neither fetch nor ingest.
        # NOTE(review): the run record created above is not updated on this
        # path; the return value is kept as-is for backward compatibility.
        self.logger.warning(
            "%s: unrecognized data_source=%r, no fetch/ingest phase executed",
            task_code, data_source,
        )
        return {"status": "COMPLETE", "counts": {}}

    # ------------------------------------------------------------------ internals

    def _execute_fetch(
        self,
        task_code: str,
        cursor_data: dict | None,
        fetch_dir: Path,
        run_id: int,
    ):
        """Online fetch phase: pull through RecordingAPIClient and dump to disk.

        Only the task's ``extract`` step is invoked; no transform/load happens.

        Returns:
            A dict with ``file``, ``records`` and ``pages`` taken from the
            recording client's last dump (``records`` falls back to the count
            derived from the extract result).
        """
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(
            task_code, self.config, self.db_ops, recording_client, self.logger,
        )
        context = task._build_context(cursor_data)  # type: ignore[attr-defined]
        self.logger.info("%s: 抓取阶段开始,目录=%s", task_code, fetch_dir)

        extracted = task.extract(context)
        stats = recording_client.last_dump or {}

        extracted_count = 0
        if isinstance(extracted, dict):
            # `or []` also covers an explicit `"records": None` entry.
            extracted_count = (
                int(extracted.get("fetched") or 0)
                or len(extracted.get("records") or [])
            )
        fetched_count = stats.get("records") or extracted_count or 0

        self.logger.info(
            "%s: 抓取完成,文件=%s,记录数=%s",
            task_code, stats.get("file"), fetched_count,
        )
        return {
            "file": stats.get("file"),
            "records": fetched_count,
            "pages": stats.get("pages"),
        }

    @staticmethod
    def _is_ods_task(task_code: str) -> bool:
        """Return True for ODS tasks, excluding ``ODS_JSON_ARCHIVE``."""
        tc = str(task_code or "").upper()
        return tc.startswith("ODS_") and tc != "ODS_JSON_ARCHIVE"

    def _execute_ods_record_and_load(
        self,
        task_code: str,
        cursor_data: dict | None,
        fetch_dir: Path,
        run_id: int,
    ) -> tuple[dict, dict]:
        """ODS task: online fetch plus load via the task's own ``execute()``.

        Returns:
            ``(result, last_dump)`` where ``last_dump`` is the recording
            client's last dump info, or an empty dict when there is none.
        """
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(
            task_code, self.config, self.db_ops, recording_client, self.logger,
        )
        self.logger.info("%s: ODS fetch+load start, dir=%s", task_code, fetch_dir)
        result = task.execute(cursor_data)
        return result, (recording_client.last_dump or {})

    def _execute_ingest(
        self,
        task_code: str,
        cursor_data: dict | None,
        source_dir: Path,
    ):
        """Local ingest: replay JSON through LocalJsonClient and run the task."""
        local_client = LocalJsonClient(source_dir)
        task = self.task_registry.create_task(
            task_code, self.config, self.db_ops, local_client, self.logger,
        )
        self.logger.info("%s: 本地清洗入库开始,源目录=%s", task_code, source_dir)
        return task.execute(cursor_data)

    def _build_fetch_dir(self, task_code: str, run_id: int) -> Path:
        """Build the fetch output path ``<fetch_root>/<TASK>/<TASK>-<run_id>-<ts>``."""
        ts = datetime.now(self.tz).strftime("%Y%m%d-%H%M%S")
        task_code = str(task_code or "").upper()
        return Path(self.fetch_root) / task_code / f"{task_code}-{run_id}-{ts}"

    def _resolve_ingest_source(self, fetch_dir: Path, fetch_stats: dict | None) -> Path:
        """Pick the JSON source directory for the local ingest phase.

        The directory just written by the fetch phase is preferred; otherwise
        the configured ``ingest_source_dir`` is used.

        Raises:
            FileNotFoundError: When neither source is available.
        """
        if fetch_stats and fetch_dir.exists():
            return fetch_dir
        if self.ingest_source_dir:
            return Path(self.ingest_source_dir)
        raise FileNotFoundError("未提供本地清洗入库所需的 JSON 目录")

    def _counts_from_fetch(self, stats: dict | None) -> dict:
        """Build a counts dict from fetch statistics (only ``fetched`` is set)."""
        fetched = (stats or {}).get("records") or 0
        return {
            "fetched": fetched,
            "inserted": 0,
            "updated": 0,
            "skipped": 0,
            "errors": 0,
        }

    @staticmethod
    def _flow_includes_fetch(data_source: str | DataSource) -> bool:
        """Return True when the data-source mode includes the fetch phase."""
        return _normalize_data_source(data_source) in {"online", "hybrid"}

    @staticmethod
    def _flow_includes_ingest(data_source: str | DataSource) -> bool:
        """Return True when the data-source mode includes the ingest phase."""
        return _normalize_data_source(data_source) in {"offline", "hybrid"}

    def _run_utility_task(self, task_code: str, store_id: int) -> Dict[str, Any]:
        """Run a utility task directly, without cursor or run records.

        ``ODS_JSON_ARCHIVE`` gets a ``RecordingAPIClient`` with a
        timestamp-based run id; every other utility task gets no API client.

        Returns:
            ``{"status": ..., "counts": ...}``; a non-dict task result is
            reported as ``SUCCESS`` with empty counts.

        Raises:
            Exception: Any task error is logged and re-raised.
        """
        self.logger.info("%s: 开始执行工具类任务", task_code)

        try:
            api_client = None
            if task_code == "ODS_JSON_ARCHIVE":
                run_id = int(datetime.now(self.tz).timestamp())
                fetch_dir = self._build_fetch_dir(task_code, run_id)
                api_client = RecordingAPIClient(
                    base_client=self.api_client,
                    output_dir=fetch_dir,
                    task_code=task_code,
                    run_id=run_id,
                    write_pretty=self.write_pretty_json,
                )

            task = self.task_registry.create_task(
                task_code, self.config, self.db_ops, api_client, self.logger,
            )
            result = task.execute(None)

            if isinstance(result, dict):
                status = (result.get("status") or "").upper()
                counts = result.get("counts", {})
            else:
                status = "SUCCESS"
                counts = {}

            if status == "SUCCESS":
                self.logger.info("%s: 工具类任务执行成功", task_code)
                if counts:
                    self.logger.info("%s: 结果统计: %s", task_code, counts)
            else:
                self.logger.warning("%s: 工具类任务执行结果: %s", task_code, status)

            return {"status": status, "counts": counts}

        except Exception as exc:
            self.logger.error("%s: 工具类任务执行失败: %s", task_code, exc, exc_info=True)
            raise

    def _load_task_config(self, task_code: str, store_id: int) -> dict | None:
        """Load the enabled task configuration row from the database.

        Returns:
            The first matching row, or ``None`` when the task is not enabled
            or does not exist for this store.
        """
        sql = """
            SELECT task_id, task_code, store_id, enabled, cursor_field,
                   window_minutes_default, overlap_seconds, page_size,
                   retry_max, params
            FROM etl_admin.etl_task
            WHERE store_id = %s AND task_code = %s AND enabled = TRUE
        """
        rows = self.db_ops.query(sql, (store_id, task_code))
        return rows[0] if rows else None

    def _maybe_run_integrity_check(self, task_code: str, window: dict | None) -> None:
        """Optionally run the integrity check after ``DWD_LOAD_FROM_ODS`` succeeds.

        Does nothing unless ``integrity.auto_check`` is enabled, the task is
        ``DWD_LOAD_FROM_ODS`` and the window has both ``start`` and ``end``.
        Failures of the check itself are logged as warnings and never raised.
        """
        if not self.config.get("integrity.auto_check", False):
            return
        if str(task_code or "").upper() != "DWD_LOAD_FROM_ODS":
            return
        if not isinstance(window, dict):
            return
        window_start = window.get("start")
        window_end = window.get("end")
        if not window_start or not window_end:
            return

        try:
            # Imported lazily so the checker is only loaded when it is used.
            from quality.integrity_checker import IntegrityWindow, run_integrity_window

            include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
            task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
            report = run_integrity_window(
                cfg=self.config,
                window=IntegrityWindow(
                    start=window_start,
                    end=window_end,
                    label="etl_window",
                    granularity="window",
                ),
                include_dimensions=include_dimensions,
                task_codes=task_codes,
                logger=self.logger,
                write_report=True,
            )
            self.logger.info(
                "Integrity check done: report=%s missing=%s errors=%s",
                report.get("report_path"),
                report.get("api_to_ods", {}).get("total_missing"),
                report.get("api_to_ods", {}).get("total_errors"),
            )
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Integrity check failed: %s", exc, exc_info=True)

    def _attach_run_file_logger(self, run_uuid: str) -> logging.Handler | None:
        """Attach a per-run file handler to the root logger.

        The log file is ``<io.log_root>/<run_uuid>.log``.

        Returns:
            The attached handler, or ``None`` when the log directory or file
            could not be created (a warning is logged in that case).
        """
        log_root = Path(self.config["io"]["log_root"])
        try:
            log_root.mkdir(parents=True, exist_ok=True)
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("创建日志目录失败:%s(%s)", log_root, exc)
            return None

        log_path = log_root / f"{run_uuid}.log"
        try:
            handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("创建文件日志失败:%s(%s)", log_path, exc)
            return None

        fmt = logging.Formatter(
            fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(fmt)
        handler.setLevel(logging.INFO)

        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        return handler