# -*- coding: utf-8 -*-
"""ETL scheduler supporting three modes: online fetch, offline cleanse/load, and the full pipeline.

Notes:
    To ease troubleshooting and auditing, the scheduler writes a log file for every run by default:
    `io.log_root/<run_uuid>.log`.

    - The file path is also written to the `etl_admin.etl_run.log_path` column (recorded by RunTracker).
    - File logging works by dynamically attaching a FileHandler to the root logger, so sub-modules
      that use `logging.getLogger(__name__)` also write to the same log file.
"""

from __future__ import annotations

import logging
import uuid
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo

from api.client import APIClient
from api.local_json_client import LocalJsonClient
from api.recording_client import RecordingAPIClient
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_registry import default_registry


class ETLScheduler:
    """Schedule multiple tasks and run the fetch / cleanse-load stages according to pipeline.flow."""

    def __init__(self, config, logger):
        self.config = config
        self.logger = logger
        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Taipei"))
        self.pipeline_flow = str(config.get("pipeline.flow", "FULL") or "FULL").upper()
        self.fetch_root = Path(config.get("pipeline.fetch_root") or config["io"]["export_root"])
        self.ingest_source_dir = config.get("pipeline.ingest_source_dir") or ""
        self.write_pretty_json = bool(config.get("io.write_pretty_json", False))

        # Components
        self.db_conn = DatabaseConnection(
            dsn=config["db"]["dsn"],
            session=config["db"].get("session"),
            connect_timeout=config["db"].get("connect_timeout_sec"),
        )
        self.db_ops = DatabaseOperations(self.db_conn)
        self.api_client = APIClient(
            base_url=config["api"]["base_url"],
            token=config["api"]["token"],
            timeout=config["api"]["timeout_sec"],
            retry_max=config["api"]["retries"]["max_attempts"],
            headers_extra=config["api"].get("headers_extra"),
        )
        self.cursor_mgr = CursorManager(self.db_conn)
        self.run_tracker = RunTracker(self.db_conn)
        self.task_registry = default_registry

    def _attach_run_file_logger(self, run_uuid: str) -> logging.Handler | None:
        """Attach a file log handler for this run_uuid.

        Returns:
            - On success: the FileHandler (the caller is responsible for removeHandler/close).
            - On failure: None (the main flow is not interrupted).
        """
        log_root = Path(self.config["io"]["log_root"])
        try:
            log_root.mkdir(parents=True, exist_ok=True)
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Failed to create log directory: %s (%s)", log_root, exc)
            return None

        log_path = log_root / f"{run_uuid}.log"
        try:
            handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Failed to create log file: %s (%s)", log_path, exc)
            return None

        fmt = logging.Formatter(
            fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(fmt)
        handler.setLevel(logging.INFO)

        # Attach to the root logger so every module's logger writes to the same file.
        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        return handler

    # ------------------------------------------------------------------ public
    def run_tasks(self, task_codes: list | None = None):
        """Run the given task list, or the list configured under run.tasks."""
        run_uuid = uuid.uuid4().hex
        store_id = self.config.get("app.store_id")
        if not task_codes:
            task_codes = self.config.get("run.tasks", [])

        file_handler = self._attach_run_file_logger(run_uuid)
        try:
            self.logger.info("Start running tasks: %s, run_uuid=%s", task_codes, run_uuid)
            for task_code in task_codes:
                try:
                    self._run_single_task(task_code, run_uuid, store_id)
                except Exception as exc:  # noqa: BLE001
                    self.logger.error("Task %s failed: %s", task_code, exc, exc_info=True)
                    continue
            self.logger.info("All tasks finished")
        finally:
            if file_handler is not None:
                try:
                    logging.getLogger().removeHandler(file_handler)
                except Exception:
                    pass
                try:
                    file_handler.close()
                except Exception:
                    pass

    # ------------------------------------------------------------------ internals
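    # Per-task run lifecycle (non-utility tasks): _run_single_task creates an
    # etl_admin.etl_run row up front (RUNNING maps to PARTIAL), updates it with the
    # final counts/status, and on SUCCESS advances the task cursor to the processed
    # window and optionally triggers an integrity check.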
    def _run_single_task(self, task_code: str, run_uuid: str, store_id: int):
        """Fetch/cleanse orchestration for a single task."""
        task_code_upper = task_code.upper()

        # Utility tasks: execute directly, without cursor/run bookkeeping
        if task_code_upper in self.NO_DB_CONFIG_TASKS:
            self._run_utility_task(task_code_upper, store_id)
            return

        task_cfg = self._load_task_config(task_code, store_id)
        if not task_cfg:
            self.logger.warning("Task %s is disabled or does not exist", task_code)
            return

        task_id = task_cfg["task_id"]
        cursor_data = self.cursor_mgr.get_or_create(task_id, store_id)

        # run record
        export_dir = Path(self.config["io"]["export_root"]) / datetime.now(self.tz).strftime("%Y%m%d")
        log_path = str(Path(self.config["io"]["log_root"]) / f"{run_uuid}.log")
        run_id = self.run_tracker.create_run(
            task_id=task_id,
            store_id=store_id,
            run_uuid=run_uuid,
            export_dir=str(export_dir),
            log_path=log_path,
            status=self._map_run_status("RUNNING"),
        )

        # Prepare the directory for the fetch stage
        fetch_dir = self._build_fetch_dir(task_code, run_id)
        fetch_stats = None
        try:
            # ODS_* tasks (except ODS_JSON_ARCHIVE) don't implement extract/transform/load stages in this repo
            # version, so we execute them as a single step with the appropriate API client.
            if self._is_ods_task(task_code):
                if self.pipeline_flow in {"FULL", "FETCH_ONLY"}:
                    result, _ = self._execute_ods_record_and_load(task_code, cursor_data, fetch_dir, run_id)
                else:
                    source_dir = self._resolve_ingest_source(fetch_dir, None)
                    result = self._execute_ingest(task_code, cursor_data, source_dir)
                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result.get("counts") or {},
                    status=self._map_run_status(result.get("status")),
                    ended_at=datetime.now(self.tz),
                    window=result.get("window"),
                    request_params=result.get("request_params"),
                    overlap_seconds=self.config.get("run.overlap_seconds"),
                )
                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if isinstance(window, dict):
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
                        self._maybe_run_integrity_check(task_code, window)
                return

            if self._flow_includes_fetch():
                fetch_stats = self._execute_fetch(task_code, cursor_data, fetch_dir, run_id)
                if self.pipeline_flow == "FETCH_ONLY":
                    counts = self._counts_from_fetch(fetch_stats)
                    self.run_tracker.update_run(
                        run_id=run_id,
                        counts=counts,
                        status=self._map_run_status("SUCCESS"),
                        ended_at=datetime.now(self.tz),
                    )
                    return

            if self._flow_includes_ingest():
                source_dir = self._resolve_ingest_source(fetch_dir, fetch_stats)
                result = self._execute_ingest(task_code, cursor_data, source_dir)
                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result["counts"],
                    status=self._map_run_status(result["status"]),
                    ended_at=datetime.now(self.tz),
                    window=result.get("window"),
                    request_params=result.get("request_params"),
                    overlap_seconds=self.config.get("run.overlap_seconds"),
                )
                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if window:
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
                        self._maybe_run_integrity_check(task_code, window)
        except Exception as exc:  # noqa: BLE001
            self.run_tracker.update_run(
                run_id=run_id,
                counts={},
                status=self._map_run_status("FAIL"),
                ended_at=datetime.now(self.tz),
                error_message=str(exc),
            )
            raise
    def _execute_fetch(self, task_code: str, cursor_data: dict | None, fetch_dir: Path, run_id: int):
        """Online fetch stage: pull via RecordingAPIClient and dump to disk; no transform/load."""
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(task_code, self.config, self.db_ops, recording_client, self.logger)
        context = task._build_context(cursor_data)  # type: ignore[attr-defined]
        self.logger.info("%s: fetch stage started, dir=%s", task_code, fetch_dir)
        extracted = task.extract(context)

        # Fetch is done; transform/load is not executed here.
        stats = recording_client.last_dump or {}
        extracted_count = 0
        if isinstance(extracted, dict):
            extracted_count = int(extracted.get("fetched") or 0) or len(extracted.get("records", []))
        fetched_count = stats.get("records") or extracted_count or 0
        self.logger.info(
            "%s: fetch finished, file=%s, records=%s",
            task_code,
            stats.get("file"),
            fetched_count,
        )
        return {"file": stats.get("file"), "records": fetched_count, "pages": stats.get("pages")}

    @staticmethod
    def _is_ods_task(task_code: str) -> bool:
        tc = str(task_code or "").upper()
        return tc.startswith("ODS_") and tc != "ODS_JSON_ARCHIVE"

    def _execute_ods_record_and_load(
        self,
        task_code: str,
        cursor_data: dict | None,
        fetch_dir: Path,
        run_id: int,
    ) -> tuple[dict, dict]:
        """
        Execute an ODS task with RecordingAPIClient so it fetches online and writes JSON dumps.
        (ODS tasks in this repo perform DB upsert inside execute(); there is no staged extract/load.)
        """
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(task_code, self.config, self.db_ops, recording_client, self.logger)
        self.logger.info("%s: ODS fetch+load start, dir=%s", task_code, fetch_dir)
        result = task.execute(cursor_data)
        return result, (recording_client.last_dump or {})

    def _execute_ingest(self, task_code: str, cursor_data: dict | None, source_dir: Path):
        """Local cleanse/load: replay JSON via LocalJsonClient and run the task's regular ETL."""
        local_client = LocalJsonClient(source_dir)
        task = self.task_registry.create_task(task_code, self.config, self.db_ops, local_client, self.logger)
        self.logger.info("%s: local cleanse/load started, source dir=%s", task_code, source_dir)
        return task.execute(cursor_data)

    def _build_fetch_dir(self, task_code: str, run_id: int) -> Path:
        ts = datetime.now(self.tz).strftime("%Y%m%d-%H%M%S")
        return Path(self.fetch_root) / f"{task_code.upper()}-{run_id}-{ts}"

    def _resolve_ingest_source(self, fetch_dir: Path, fetch_stats: dict | None) -> Path:
        if fetch_stats and fetch_dir.exists():
            return fetch_dir
        if self.ingest_source_dir:
            return Path(self.ingest_source_dir)
        raise FileNotFoundError("No JSON directory available for local cleanse/load")

    def _counts_from_fetch(self, stats: dict | None) -> dict:
        fetched = (stats or {}).get("records") or 0
        return {
            "fetched": fetched,
            "inserted": 0,
            "updated": 0,
            "skipped": 0,
            "errors": 0,
        }
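    # How pipeline.flow gates the stages for non-ODS tasks (ODS_* tasks are handled
    # as a single fetch+load step in _run_single_task):
    #   FETCH_ONLY  -> fetch only: RecordingAPIClient dumps JSON into fetch_dir
    #   INGEST_ONLY -> cleanse/load only: LocalJsonClient replays JSON from pipeline.ingest_source_dir
    #   FULL        -> fetch first, then cleanse/load from the freshly written fetch_dir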
    def _flow_includes_fetch(self) -> bool:
        return self.pipeline_flow in {"FETCH_ONLY", "FULL"}

    def _flow_includes_ingest(self) -> bool:
        return self.pipeline_flow in {"INGEST_ONLY", "FULL"}

    # Tasks that run without DB-side task configuration (utility / initialization tasks)
    NO_DB_CONFIG_TASKS = {
        # Schema initialization tasks
        "INIT_ODS_SCHEMA",
        "INIT_DWD_SCHEMA",
        "INIT_DWS_SCHEMA",
        # Quality check tasks
        "DATA_INTEGRITY_CHECK",
        "DWD_QUALITY_CHECK",
        # Utility tasks
        "CHECK_CUTOFF",
        "MANUAL_INGEST",
        "ODS_JSON_ARCHIVE",
        # DWS aggregation tasks
        "DWS_BUILD_ORDER_SUMMARY",
    }

    def _run_utility_task(self, task_code: str, store_id: int):
        """
        Execute a utility task directly (no cursor/run bookkeeping).
        These tasks need neither cursor management nor run tracking.
        """
        self.logger.info("%s: utility task started", task_code)
        try:
            # Create the task instance (no API client needed, pass None)
            task = self.task_registry.create_task(
                task_code, self.config, self.db_ops, None, self.logger
            )
            # Execute the task (utility tasks normally take no cursor_data)
            result = task.execute(None)

            status = (result.get("status") or "").upper() if isinstance(result, dict) else "SUCCESS"
            if status == "SUCCESS":
                self.logger.info("%s: utility task succeeded", task_code)
                if isinstance(result, dict):
                    counts = result.get("counts", {})
                    if counts:
                        self.logger.info("%s: result counts: %s", task_code, counts)
            else:
                self.logger.warning("%s: utility task result: %s", task_code, status)
        except Exception as exc:
            self.logger.error("%s: utility task failed: %s", task_code, exc, exc_info=True)
            raise

    def _load_task_config(self, task_code: str, store_id: int) -> dict | None:
        """Load the task configuration from the database."""
        sql = """
            SELECT task_id, task_code, store_id, enabled, cursor_field,
                   window_minutes_default, overlap_seconds, page_size, retry_max, params
            FROM etl_admin.etl_task
            WHERE store_id = %s AND task_code = %s AND enabled = TRUE
        """
        rows = self.db_conn.query(sql, (store_id, task_code))
        return rows[0] if rows else None

    def _maybe_run_integrity_check(self, task_code: str, window: dict | None) -> None:
        if not self.config.get("integrity.auto_check", False):
            return
        if str(task_code or "").upper() != "DWD_LOAD_FROM_ODS":
            return
        if not isinstance(window, dict):
            return
        window_start = window.get("start")
        window_end = window.get("end")
        if not window_start or not window_end:
            return
        try:
            from quality.integrity_checker import IntegrityWindow, run_integrity_window

            include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
            task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
            report = run_integrity_window(
                cfg=self.config,
                window=IntegrityWindow(
                    start=window_start,
                    end=window_end,
                    label="etl_window",
                    granularity="window",
                ),
                include_dimensions=include_dimensions,
                task_codes=task_codes,
                logger=self.logger,
                write_report=True,
            )
            self.logger.info(
                "Integrity check done: report=%s missing=%s errors=%s",
                report.get("report_path"),
                report.get("api_to_ods", {}).get("total_missing"),
                report.get("api_to_ods", {}).get("total_errors"),
            )
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("Integrity check failed: %s", exc, exc_info=True)

    def close(self):
        """Close the database connection."""
        self.db_conn.close()

    @staticmethod
    def _map_run_status(status: str) -> str:
        """
        Map a task's returned status to etl_admin.run_status_enum (SUCC / FAIL / PARTIAL).
        """
        normalized = (status or "").upper()
        if normalized in {"SUCCESS", "SUCC"}:
            return "SUCC"
        if normalized in {"FAIL", "FAILED", "ERROR"}:
            return "FAIL"
        if normalized in {"RUNNING", "PARTIAL", "PENDING", "IN_PROGRESS"}:
            return "PARTIAL"
        # Unknown statuses default to FAIL so they surface during troubleshooting
        return "FAIL"
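
# ------------------------------------------------------------------ usage sketch
# Minimal, hypothetical usage. How `config` is built is an assumption (this module
# does not define the entry point); it only needs a dict-like config that supports
# both .get("dotted.key", default) and ["section"]["key"] access, as used above.
#
#     logger = logging.getLogger("etl")
#     scheduler = ETLScheduler(config, logger)
#     try:
#         scheduler.run_tasks()      # falls back to run.tasks from the config
#     finally:
#         scheduler.close()          # always release the DB connection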