235 lines
9.5 KiB
Python
235 lines
9.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""ETL 调度:支持在线抓取、离线清洗入库、全流程三种模式。"""
|
||
from __future__ import annotations
|
||
|
||
import uuid
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from zoneinfo import ZoneInfo
|
||
|
||
from api.client import APIClient
|
||
from api.local_json_client import LocalJsonClient
|
||
from api.recording_client import RecordingAPIClient
|
||
from database.connection import DatabaseConnection
|
||
from database.operations import DatabaseOperations
|
||
from orchestration.cursor_manager import CursorManager
|
||
from orchestration.run_tracker import RunTracker
|
||
from orchestration.task_registry import default_registry
|
||
|
||
|
||
class ETLScheduler:
|
||
"""调度多个任务,按 pipeline.flow 执行抓取/清洗入库。"""
|
||
|
||
def __init__(self, config, logger):
|
||
self.config = config
|
||
self.logger = logger
|
||
self.tz = ZoneInfo(config.get("app.timezone", "Asia/Taipei"))
|
||
|
||
self.pipeline_flow = str(config.get("pipeline.flow", "FULL") or "FULL").upper()
|
||
self.fetch_root = Path(config.get("pipeline.fetch_root") or config["io"]["export_root"])
|
||
self.ingest_source_dir = config.get("pipeline.ingest_source_dir") or ""
|
||
self.write_pretty_json = bool(config.get("io.write_pretty_json", False))
|
||
|
||
# 组件
|
||
self.db_conn = DatabaseConnection(
|
||
dsn=config["db"]["dsn"],
|
||
session=config["db"].get("session"),
|
||
connect_timeout=config["db"].get("connect_timeout_sec"),
|
||
)
|
||
self.db_ops = DatabaseOperations(self.db_conn)
|
||
|
||
self.api_client = APIClient(
|
||
base_url=config["api"]["base_url"],
|
||
token=config["api"]["token"],
|
||
timeout=config["api"]["timeout_sec"],
|
||
retry_max=config["api"]["retries"]["max_attempts"],
|
||
headers_extra=config["api"].get("headers_extra"),
|
||
)
|
||
|
||
self.cursor_mgr = CursorManager(self.db_conn)
|
||
self.run_tracker = RunTracker(self.db_conn)
|
||
self.task_registry = default_registry
|
||
|
||
# ------------------------------------------------------------------ public
|
||
def run_tasks(self, task_codes: list | None = None):
|
||
"""按配置或传入列表执行任务。"""
|
||
run_uuid = uuid.uuid4().hex
|
||
store_id = self.config.get("app.store_id")
|
||
|
||
if not task_codes:
|
||
task_codes = self.config.get("run.tasks", [])
|
||
|
||
self.logger.info("开始运行任务: %s, run_uuid=%s", task_codes, run_uuid)
|
||
|
||
for task_code in task_codes:
|
||
try:
|
||
self._run_single_task(task_code, run_uuid, store_id)
|
||
except Exception as exc: # noqa: BLE001
|
||
self.logger.error("任务 %s 失败: %s", task_code, exc, exc_info=True)
|
||
continue
|
||
|
||
self.logger.info("所有任务执行完成")
|
||
|
||
# ------------------------------------------------------------------ internals
|
||
def _run_single_task(self, task_code: str, run_uuid: str, store_id: int):
|
||
"""单个任务的抓取/清洗编排。"""
|
||
task_cfg = self._load_task_config(task_code, store_id)
|
||
if not task_cfg:
|
||
self.logger.warning("任务 %s 未启用或不存在", task_code)
|
||
return
|
||
|
||
task_id = task_cfg["task_id"]
|
||
cursor_data = self.cursor_mgr.get_or_create(task_id, store_id)
|
||
|
||
# run 记录
|
||
export_dir = Path(self.config["io"]["export_root"]) / datetime.now(self.tz).strftime("%Y%m%d")
|
||
log_path = str(Path(self.config["io"]["log_root"]) / f"{run_uuid}.log")
|
||
run_id = self.run_tracker.create_run(
|
||
task_id=task_id,
|
||
store_id=store_id,
|
||
run_uuid=run_uuid,
|
||
export_dir=str(export_dir),
|
||
log_path=log_path,
|
||
status=self._map_run_status("RUNNING"),
|
||
)
|
||
|
||
# 为抓取阶段准备目录
|
||
fetch_dir = self._build_fetch_dir(task_code, run_id)
|
||
fetch_stats = None
|
||
|
||
try:
|
||
if self._flow_includes_fetch():
|
||
fetch_stats = self._execute_fetch(task_code, cursor_data, fetch_dir, run_id)
|
||
if self.pipeline_flow == "FETCH_ONLY":
|
||
counts = self._counts_from_fetch(fetch_stats)
|
||
self.run_tracker.update_run(
|
||
run_id=run_id,
|
||
counts=counts,
|
||
status=self._map_run_status("SUCCESS"),
|
||
ended_at=datetime.now(self.tz),
|
||
)
|
||
return
|
||
|
||
if self._flow_includes_ingest():
|
||
source_dir = self._resolve_ingest_source(fetch_dir, fetch_stats)
|
||
result = self._execute_ingest(task_code, cursor_data, source_dir)
|
||
|
||
self.run_tracker.update_run(
|
||
run_id=run_id,
|
||
counts=result["counts"],
|
||
status=self._map_run_status(result["status"]),
|
||
ended_at=datetime.now(self.tz),
|
||
)
|
||
|
||
if (result.get("status") or "").upper() == "SUCCESS":
|
||
window = result.get("window")
|
||
if window:
|
||
self.cursor_mgr.advance(
|
||
task_id=task_id,
|
||
store_id=store_id,
|
||
window_start=window.get("start"),
|
||
window_end=window.get("end"),
|
||
run_id=run_id,
|
||
)
|
||
|
||
except Exception as exc: # noqa: BLE001
|
||
self.run_tracker.update_run(
|
||
run_id=run_id,
|
||
counts={},
|
||
status=self._map_run_status("FAIL"),
|
||
ended_at=datetime.now(self.tz),
|
||
error_message=str(exc),
|
||
)
|
||
raise
|
||
|
||
def _execute_fetch(self, task_code: str, cursor_data: dict | None, fetch_dir: Path, run_id: int):
|
||
"""在线抓取阶段:用 RecordingAPIClient 拉取并落盘,不做 Transform/Load。"""
|
||
recording_client = RecordingAPIClient(
|
||
base_client=self.api_client,
|
||
output_dir=fetch_dir,
|
||
task_code=task_code,
|
||
run_id=run_id,
|
||
write_pretty=self.write_pretty_json,
|
||
)
|
||
task = self.task_registry.create_task(task_code, self.config, self.db_ops, recording_client, self.logger)
|
||
context = task._build_context(cursor_data) # type: ignore[attr-defined]
|
||
self.logger.info("%s: 抓取阶段开始,目录=%s", task_code, fetch_dir)
|
||
|
||
extracted = task.extract(context)
|
||
# 抓取结束,不执行 transform/load
|
||
stats = recording_client.last_dump or {}
|
||
fetched_count = stats.get("records") or len(extracted.get("records", [])) if isinstance(extracted, dict) else 0
|
||
self.logger.info(
|
||
"%s: 抓取完成,文件=%s,记录数=%s",
|
||
task_code,
|
||
stats.get("file"),
|
||
fetched_count,
|
||
)
|
||
return {"file": stats.get("file"), "records": fetched_count, "pages": stats.get("pages")}
|
||
|
||
def _execute_ingest(self, task_code: str, cursor_data: dict | None, source_dir: Path):
|
||
"""本地清洗入库:使用 LocalJsonClient 回放 JSON,走原有任务 ETL。"""
|
||
local_client = LocalJsonClient(source_dir)
|
||
task = self.task_registry.create_task(task_code, self.config, self.db_ops, local_client, self.logger)
|
||
self.logger.info("%s: 本地清洗入库开始,源目录=%s", task_code, source_dir)
|
||
return task.execute(cursor_data)
|
||
|
||
def _build_fetch_dir(self, task_code: str, run_id: int) -> Path:
|
||
ts = datetime.now(self.tz).strftime("%Y%m%d-%H%M%S")
|
||
return Path(self.fetch_root) / f"{task_code.upper()}-{run_id}-{ts}"
|
||
|
||
def _resolve_ingest_source(self, fetch_dir: Path, fetch_stats: dict | None) -> Path:
|
||
if fetch_stats and fetch_dir.exists():
|
||
return fetch_dir
|
||
if self.ingest_source_dir:
|
||
return Path(self.ingest_source_dir)
|
||
raise FileNotFoundError("未提供本地清洗入库所需的 JSON 目录")
|
||
|
||
def _counts_from_fetch(self, stats: dict | None) -> dict:
|
||
fetched = (stats or {}).get("records") or 0
|
||
return {
|
||
"fetched": fetched,
|
||
"inserted": 0,
|
||
"updated": 0,
|
||
"skipped": 0,
|
||
"errors": 0,
|
||
}
|
||
|
||
def _flow_includes_fetch(self) -> bool:
|
||
return self.pipeline_flow in {"FETCH_ONLY", "FULL"}
|
||
|
||
def _flow_includes_ingest(self) -> bool:
|
||
return self.pipeline_flow in {"INGEST_ONLY", "FULL"}
|
||
|
||
def _load_task_config(self, task_code: str, store_id: int) -> dict | None:
|
||
"""从数据库加载任务配置。"""
|
||
sql = """
|
||
SELECT task_id, task_code, store_id, enabled, cursor_field,
|
||
window_minutes_default, overlap_seconds, page_size, retry_max, params
|
||
FROM etl_admin.etl_task
|
||
WHERE store_id = %s AND task_code = %s AND enabled = TRUE
|
||
"""
|
||
|
||
rows = self.db_conn.query(sql, (store_id, task_code))
|
||
return rows[0] if rows else None
|
||
|
||
def close(self):
|
||
"""关闭连接。"""
|
||
self.db_conn.close()
|
||
|
||
@staticmethod
|
||
def _map_run_status(status: str) -> str:
|
||
"""
|
||
将任务返回的状态转换为 etl_admin.run_status_enum
|
||
(SUCC / FAIL / PARTIAL)
|
||
"""
|
||
normalized = (status or "").upper()
|
||
if normalized in {"SUCCESS", "SUCC"}:
|
||
return "SUCC"
|
||
if normalized in {"FAIL", "FAILED", "ERROR"}:
|
||
return "FAIL"
|
||
if normalized in {"RUNNING", "PARTIAL", "PENDING", "IN_PROGRESS"}:
|
||
return "PARTIAL"
|
||
# 未知状态默认标记为 FAIL,便于排查
|
||
return "FAIL"
|