Files
2025-11-30 07:19:05 +08:00

235 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""ETL 调度:支持在线抓取、离线清洗入库、全流程三种模式。"""
from __future__ import annotations
import uuid
from datetime import datetime
from pathlib import Path
from zoneinfo import ZoneInfo
from api.client import APIClient
from api.local_json_client import LocalJsonClient
from api.recording_client import RecordingAPIClient
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_registry import default_registry
class ETLScheduler:
    """Run ETL tasks one by one according to the configured ``pipeline.flow``.

    Three flows are recognized:

    * ``FETCH_ONLY``  - run the fetch stage and close the run right after it.
    * ``INGEST_ONLY`` - replay JSON from a local directory through the task.
    * ``FULL``        - fetch first, then ingest what was just fetched.
    """

    # Flows this scheduler knows how to run; anything else is a config error.
    _VALID_FLOWS = frozenset({"FETCH_ONLY", "INGEST_ONLY", "FULL"})

    def __init__(self, config, logger):
        """Read pipeline settings from ``config`` and build the collaborators.

        Args:
            config: Configuration object supporting both ``get("dotted.key")``
                and ``config["section"]["key"]`` access.
            logger: Logger used for all progress and error reporting.
        """
        self.config = config
        self.logger = logger
        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Taipei"))

        self.pipeline_flow = str(config.get("pipeline.flow", "FULL") or "FULL").upper()
        # Fetched files go to pipeline.fetch_root, falling back to io.export_root.
        self.fetch_root = Path(config.get("pipeline.fetch_root") or config["io"]["export_root"])
        self.ingest_source_dir = config.get("pipeline.ingest_source_dir") or ""
        self.write_pretty_json = bool(config.get("io.write_pretty_json", False))

        # Components
        self.db_conn = DatabaseConnection(
            dsn=config["db"]["dsn"],
            session=config["db"].get("session"),
            connect_timeout=config["db"].get("connect_timeout_sec"),
        )
        self.db_ops = DatabaseOperations(self.db_conn)
        self.api_client = APIClient(
            base_url=config["api"]["base_url"],
            token=config["api"]["token"],
            timeout=config["api"]["timeout_sec"],
            retry_max=config["api"]["retries"]["max_attempts"],
            headers_extra=config["api"].get("headers_extra"),
        )
        self.cursor_mgr = CursorManager(self.db_conn)
        self.run_tracker = RunTracker(self.db_conn)
        self.task_registry = default_registry

    # ------------------------------------------------------------------ public
    def run_tasks(self, task_codes: list | None = None):
        """Run the given tasks, or the ``run.tasks`` list from the config.

        All tasks of one call share a single ``run_uuid``. A failing task is
        logged with its traceback and does not stop the remaining tasks.

        Args:
            task_codes: Task codes to run; when empty or ``None`` the
                ``run.tasks`` configuration value is used instead.
        """
        run_uuid = uuid.uuid4().hex
        store_id = self.config.get("app.store_id")

        if not task_codes:
            # ``or []`` keeps the loop safe when the key is present but null.
            task_codes = self.config.get("run.tasks", []) or []

        self.logger.info("开始运行任务: %s, run_uuid=%s", task_codes, run_uuid)

        for task_code in task_codes:
            try:
                self._run_single_task(task_code, run_uuid, store_id)
            except Exception as exc:  # noqa: BLE001
                # Boundary handler: one broken task must not abort the batch.
                self.logger.error("任务 %s 失败: %s", task_code, exc, exc_info=True)

        self.logger.info("所有任务执行完成")

    # ------------------------------------------------------------------ internals
    def _run_single_task(self, task_code: str, run_uuid: str, store_id: int):
        """Orchestrate the fetch and/or ingest stages of a single task.

        Creates a run record, executes the stages selected by
        ``pipeline_flow``, stores the outcome on the run record and, after a
        successful ingest that reports a window, advances the cursor.

        Args:
            task_code: Code of the task to run.
            run_uuid: Identifier shared by all tasks of the current batch.
            store_id: Store the task is configured for.

        Raises:
            ValueError: If ``pipeline_flow`` is not a supported flow.
            Exception: Any error raised by a stage is re-raised after the run
                record has been marked as failed.
        """
        # Reject unknown flows up front: otherwise neither stage would run and
        # the run record created below would never be closed.
        if self.pipeline_flow not in self._VALID_FLOWS:
            raise ValueError(f"不支持的 pipeline.flow: {self.pipeline_flow}")

        task_cfg = self._load_task_config(task_code, store_id)
        if not task_cfg:
            self.logger.warning("任务 %s 未启用或不存在", task_code)
            return

        task_id = task_cfg["task_id"]
        cursor_data = self.cursor_mgr.get_or_create(task_id, store_id)

        # Run record
        export_dir = Path(self.config["io"]["export_root"]) / datetime.now(self.tz).strftime("%Y%m%d")
        log_path = str(Path(self.config["io"]["log_root"]) / f"{run_uuid}.log")
        run_id = self.run_tracker.create_run(
            task_id=task_id,
            store_id=store_id,
            run_uuid=run_uuid,
            export_dir=str(export_dir),
            log_path=log_path,
            status=self._map_run_status("RUNNING"),
        )

        # Directory used by the fetch stage
        fetch_dir = self._build_fetch_dir(task_code, run_id)
        fetch_stats = None

        try:
            if self._flow_includes_fetch():
                fetch_stats = self._execute_fetch(task_code, cursor_data, fetch_dir, run_id)

            if self.pipeline_flow == "FETCH_ONLY":
                # Nothing is loaded in this flow, so only the fetched count is set.
                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=self._counts_from_fetch(fetch_stats),
                    status=self._map_run_status("SUCCESS"),
                    ended_at=datetime.now(self.tz),
                )
                return

            if self._flow_includes_ingest():
                source_dir = self._resolve_ingest_source(fetch_dir, fetch_stats)
                result = self._execute_ingest(task_code, cursor_data, source_dir)

                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result["counts"],
                    status=self._map_run_status(result["status"]),
                    ended_at=datetime.now(self.tz),
                )

                # The cursor only moves after a fully successful ingest.
                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if window:
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
        except Exception as exc:  # noqa: BLE001
            try:
                self.run_tracker.update_run(
                    run_id=run_id,
                    counts={},
                    status=self._map_run_status("FAIL"),
                    ended_at=datetime.now(self.tz),
                    error_message=str(exc),
                )
            except Exception:  # noqa: BLE001
                # Failing to record the failure must not hide the root cause.
                self.logger.error("任务 %s: 更新运行记录失败", task_code, exc_info=True)
            raise

    def _execute_fetch(self, task_code: str, cursor_data: dict | None, fetch_dir: Path, run_id: int):
        """Fetch stage: run only the task's extract step through a recording client.

        The task is built with a ``RecordingAPIClient`` wrapping the real API
        client and targeting ``fetch_dir``; transform and load are not run.

        Returns:
            dict: ``file``, ``records`` and ``pages`` describing the fetch,
            taken from the recording client's ``last_dump``.
        """
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(task_code, self.config, self.db_ops, recording_client, self.logger)
        context = task._build_context(cursor_data)  # type: ignore[attr-defined]

        self.logger.info("%s: 抓取阶段开始,目录=%s", task_code, fetch_dir)
        extracted = task.extract(context)

        # Fetch is done; transform/load are intentionally skipped.
        stats = recording_client.last_dump or {}
        fetched_count = self._resolve_fetched_count(stats, extracted)
        self.logger.info(
            "%s: 抓取完成,文件=%s,记录数=%s",
            task_code,
            stats.get("file"),
            fetched_count,
        )
        return {"file": stats.get("file"), "records": fetched_count, "pages": stats.get("pages")}

    @staticmethod
    def _resolve_fetched_count(stats: dict, extracted) -> int:
        """Return the fetched record count, preferring the recorded stats.

        The count in ``stats`` wins whenever it is truthy, regardless of the
        type of ``extracted``. Otherwise the length of ``extracted["records"]``
        is used when ``extracted`` is a dict, and 0 in every other case.
        """
        recorded = stats.get("records")
        if recorded:
            return recorded
        if isinstance(extracted, dict):
            # ``or []`` covers a payload that carries ``records: None``.
            return len(extracted.get("records") or [])
        return 0

    def _execute_ingest(self, task_code: str, cursor_data: dict | None, source_dir: Path):
        """Ingest stage: run the task's ``execute`` against a ``LocalJsonClient``.

        Returns:
            Whatever ``task.execute(cursor_data)`` returns; the caller reads
            its ``counts``, ``status`` and ``window`` entries.
        """
        local_client = LocalJsonClient(source_dir)
        task = self.task_registry.create_task(task_code, self.config, self.db_ops, local_client, self.logger)
        self.logger.info("%s: 本地清洗入库开始,源目录=%s", task_code, source_dir)
        return task.execute(cursor_data)

    def _build_fetch_dir(self, task_code: str, run_id: int) -> Path:
        """Return ``<fetch_root>/<TASK_CODE>-<run_id>-<YYYYmmdd-HHMMSS>``."""
        ts = datetime.now(self.tz).strftime("%Y%m%d-%H%M%S")
        return Path(self.fetch_root) / f"{task_code.upper()}-{run_id}-{ts}"

    def _resolve_ingest_source(self, fetch_dir: Path, fetch_stats: dict | None) -> Path:
        """Pick the directory the ingest stage reads JSON from.

        The directory of the fetch that just ran is preferred; otherwise the
        configured ``pipeline.ingest_source_dir`` is used.

        Raises:
            FileNotFoundError: If neither source is available.
        """
        if fetch_stats and fetch_dir.exists():
            return fetch_dir
        if self.ingest_source_dir:
            return Path(self.ingest_source_dir)
        raise FileNotFoundError("未提供本地清洗入库所需的 JSON 目录")

    def _counts_from_fetch(self, stats: dict | None) -> dict:
        """Build a run ``counts`` dict in which only ``fetched`` is populated."""
        fetched = (stats or {}).get("records") or 0
        return {
            "fetched": fetched,
            "inserted": 0,
            "updated": 0,
            "skipped": 0,
            "errors": 0,
        }

    def _flow_includes_fetch(self) -> bool:
        """Return True when the configured flow runs the fetch stage."""
        return self.pipeline_flow in {"FETCH_ONLY", "FULL"}

    def _flow_includes_ingest(self) -> bool:
        """Return True when the configured flow runs the ingest stage."""
        return self.pipeline_flow in {"INGEST_ONLY", "FULL"}

    def _load_task_config(self, task_code: str, store_id: int) -> dict | None:
        """Load the enabled task row for ``store_id``/``task_code`` from the database.

        Returns:
            The first matching row, or ``None`` when no enabled task matches.
        """
        sql = """
            SELECT task_id, task_code, store_id, enabled, cursor_field,
                   window_minutes_default, overlap_seconds, page_size, retry_max, params
            FROM etl_admin.etl_task
            WHERE store_id = %s AND task_code = %s AND enabled = TRUE
        """
        rows = self.db_conn.query(sql, (store_id, task_code))
        return rows[0] if rows else None

    def close(self):
        """Close the database connection."""
        self.db_conn.close()

    @staticmethod
    def _map_run_status(status: str) -> str:
        """Translate a task status into an ``etl_admin.run_status_enum`` value.

        The result is always one of ``SUCC``, ``FAIL`` or ``PARTIAL``.
        Matching is case-insensitive; in-flight states such as ``RUNNING``
        are stored as ``PARTIAL``.
        """
        normalized = (status or "").upper()
        if normalized in {"SUCCESS", "SUCC"}:
            return "SUCC"
        if normalized in {"FAIL", "FAILED", "ERROR"}:
            return "FAIL"
        if normalized in {"RUNNING", "PARTIAL", "PENDING", "IN_PROGRESS"}:
            return "PARTIAL"
        # Unknown statuses are flagged as FAIL so they get investigated.
        return "FAIL"