Files
feiqiu-ETL/etl_billiards/tasks/base_task.py
2026-01-27 22:14:01 +08:00

235 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""Base class for ETL tasks (introduces the Extract/Transform/Load template method)."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from utils.windowing import build_window_segments, calc_window_minutes
@dataclass(frozen=True)
class TaskContext:
    """Immutable run-time information handed unchanged to Extract/Transform/Load."""

    # Store identifier the task runs for.
    store_id: int
    # Lower and upper bound of the time window being processed.
    window_start: datetime
    window_end: datetime
    # Length of the window, expressed in minutes.
    window_minutes: int
    # Cursor data exactly as supplied by the caller; None when there is none.
    cursor: dict | None = None
class BaseTask:
    """Base class for ETL tasks built on an Extract/Transform/Load template.

    Subclasses implement ``get_task_code``, ``extract`` and ``load`` (and may
    override ``transform``). ``execute`` drives the three hooks once per
    window segment and commits or rolls back the database connection.
    """

    def __init__(self, config, db_connection, api_client, logger):
        """Store the collaborators and resolve the task time zone.

        Args:
            config: Settings object exposing ``get(key, default)`` with
                dotted keys.
            db_connection: Object whose ``commit()`` / ``rollback()`` are
                called by ``execute``.
            api_client: API client kept on ``self.api`` for subclasses; this
                base class does not call it.
            logger: Logger used for progress and failure messages.
        """
        self.config = config
        self.db = db_connection
        self.api = api_client
        self.logger = logger
        # Falls back to Asia/Taipei when ``app.timezone`` is not configured.
        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Taipei"))

    # ------------------------------------------------------------------ basic info
    def get_task_code(self) -> str:
        """Return the task code; must be implemented by subclasses."""
        raise NotImplementedError("子类需实现 get_task_code 方法")

    # ------------------------------------------------------------------ E/T/L hooks
    def extract(self, context: TaskContext):
        """Extract the raw data for ``context``; must be implemented by subclasses."""
        raise NotImplementedError("子类需实现 extract 方法")

    def transform(self, extracted, context: TaskContext):
        """Transform the extracted data; the default returns it unchanged."""
        return extracted

    def load(self, transformed, context: TaskContext) -> dict:
        """Load the data and return a statistics dict; must be implemented by subclasses."""
        raise NotImplementedError("子类需实现 load 方法")

    # ------------------------------------------------------------------ main flow
    def execute(self, cursor_data: dict | None = None) -> dict:
        """Run Extract -> Transform -> Load for every window segment.

        Each segment is committed on success. On any exception the connection
        is rolled back, the failure is logged, and the exception is re-raised,
        so segments committed earlier stay committed.

        Args:
            cursor_data: Optional cursor dict forwarded to the time-window
                computation and to every ``TaskContext``.

        Returns:
            A dict with ``status`` ("SUCCESS"), the accumulated ``counts``,
            the overall ``window`` and, when the window was split into more
            than one segment, a ``segments`` list with per-segment details.
        """
        base_context = self._build_context(cursor_data)
        task_code = self.get_task_code()

        segments = build_window_segments(
            self.config,
            base_context.window_start,
            base_context.window_end,
            tz=self.tz,
            override_only=True,
        )
        if not segments:
            # No split produced: process the base window as a single segment.
            segments = [(base_context.window_start, base_context.window_end)]

        total_segments = len(segments)
        if total_segments > 1:
            self.logger.info("%s: 窗口拆分为 %s", task_code, total_segments)

        total_counts: dict = {}
        segment_results: list[dict] = []

        for idx, (window_start, window_end) in enumerate(segments, start=1):
            context = self._build_context_for_window(window_start, window_end, cursor_data)
            self.logger.info(
                "%s: 开始执行(%s/%s),窗口[%s ~ %s]",
                task_code,
                idx,
                total_segments,
                context.window_start,
                context.window_end,
            )

            try:
                extracted = self.extract(context)
                transformed = self.transform(extracted, context)
                counts = self.load(transformed, context) or {}
                self.db.commit()
            except Exception:
                self.db.rollback()
                self.logger.error("%s: 执行失败", task_code, exc_info=True)
                raise

            self._accumulate_counts(total_counts, counts)
            # Per-segment details are only reported when a split happened.
            if total_segments > 1:
                segment_results.append(
                    {
                        "window": {
                            "start": context.window_start,
                            "end": context.window_end,
                            "minutes": context.window_minutes,
                        },
                        "counts": counts,
                    }
                )

        overall_start = segments[0][0]
        overall_end = segments[-1][1]
        result = self._build_result("SUCCESS", total_counts)
        result["window"] = {
            "start": overall_start,
            "end": overall_end,
            "minutes": calc_window_minutes(overall_start, overall_end),
        }
        if segment_results:
            result["segments"] = segment_results

        self.logger.info("%s: 完成,统计=%s", task_code, result["counts"])
        return result

    def _build_context(self, cursor_data: dict | None) -> TaskContext:
        """Build the context for the window computed by ``_get_time_window``."""
        window_start, window_end, window_minutes = self._get_time_window(cursor_data)
        return TaskContext(
            store_id=self.config.get("app.store_id"),
            window_start=window_start,
            window_end=window_end,
            window_minutes=window_minutes,
            cursor=cursor_data,
        )

    def _build_context_for_window(
        self,
        window_start: datetime,
        window_end: datetime,
        cursor_data: dict | None,
    ) -> TaskContext:
        """Build the context for an explicit ``[window_start, window_end]`` pair."""
        return TaskContext(
            store_id=self.config.get("app.store_id"),
            window_start=window_start,
            window_end=window_end,
            window_minutes=calc_window_minutes(window_start, window_end),
            cursor=cursor_data,
        )

    @staticmethod
    def _accumulate_counts(total: dict, current: dict) -> dict:
        """Fold ``current`` into ``total`` in place and return ``total``.

        ``int``/``float`` values (which includes ``bool``) are summed per key;
        any other value is kept from the first segment that supplied it.
        """
        for key, value in (current or {}).items():
            if isinstance(value, (int, float)):
                total[key] = (total.get(key) or 0) + value
            else:
                total.setdefault(key, value)
        return total

    def _coerce_override_datetime(self, value):
        """Normalise one override bound to a datetime expressed in ``self.tz``.

        Strings are parsed with ``dateutil``; naive datetimes are taken to be
        in ``self.tz``; aware datetimes are converted to it. Any other value is
        returned unchanged so the caller can reject it.
        """
        if isinstance(value, str):
            value = dtparser.parse(value)
        if not isinstance(value, datetime):
            return value
        if value.tzinfo is None:
            return value.replace(tzinfo=self.tz)
        return value.astimezone(self.tz)

    def _get_time_window(self, cursor_data: dict | None = None) -> tuple[datetime, datetime, int]:
        """Compute ``(window_start, window_end, window_minutes)`` for this run.

        When ``run.window_override.start`` / ``.end`` are configured they define
        the window. Otherwise the window ends now and starts either at the
        cursor's ``last_end`` minus ``run.overlap_seconds`` or, without a
        cursor, ``window_minutes`` before now. In the non-override case
        ``window_minutes`` is the configured idle/busy default, not the actual
        span between start and end.

        Raises:
            ValueError: If only one override bound is set, a bound cannot be
                turned into a datetime, or the override end is not after the
                start.
        """
        now = datetime.now(self.tz)

        override_start = self.config.get("run.window_override.start")
        override_end = self.config.get("run.window_override.end")
        if override_start or override_end:
            if not (override_start and override_end):
                raise ValueError("run.window_override.start/end 需要同时提供")

            window_start = self._coerce_override_datetime(override_start)
            window_end = self._coerce_override_datetime(override_end)
            if not isinstance(window_start, datetime) or not isinstance(window_end, datetime):
                raise ValueError("run.window_override.start/end 解析失败")
            if window_end <= window_start:
                raise ValueError("run.window_override.end 必须大于 start")

            # Whole minutes, never less than one.
            window_minutes = max(1, int((window_end - window_start).total_seconds() // 60))
            return window_start, window_end, window_minutes

        idle_start = self.config.get("run.idle_window.start", "04:00")
        idle_end = self.config.get("run.idle_window.end", "16:00")
        if self._is_in_idle_window(now, idle_start, idle_end):
            window_minutes = self.config.get("run.window_minutes.default_idle", 180)
        else:
            window_minutes = self.config.get("run.window_minutes.default_busy", 30)

        overlap_seconds = self.config.get("run.overlap_seconds", 120)
        if cursor_data and cursor_data.get("last_end"):
            # Re-read a small overlap before the previous end.
            window_start = cursor_data["last_end"] - timedelta(seconds=overlap_seconds)
        else:
            window_start = now - timedelta(minutes=window_minutes)

        window_end = now
        return window_start, window_end, window_minutes

    @staticmethod
    def _clock_to_minutes(value) -> int:
        """Convert an ``"H:MM"`` / ``"HH:MM[:SS]"`` string to minutes since midnight.

        Raises:
            ValueError: If ``value`` has no ``:`` separator or non-numeric parts.
        """
        hours, minutes = str(value).strip().split(":")[:2]
        return int(hours) * 60 + int(minutes)

    def _is_in_idle_window(self, dt: datetime, start_time: str, end_time: str) -> bool:
        """Return whether ``dt``'s time of day lies in the idle window (inclusive).

        Bounds are compared as minutes since midnight, so non-zero-padded
        values such as ``"4:00"`` work, and a window whose start is later than
        its end (e.g. ``"22:00"`` to ``"06:00"``) is treated as spanning
        midnight.
        """
        try:
            start = self._clock_to_minutes(start_time)
            end = self._clock_to_minutes(end_time)
        except ValueError:
            # Unparseable bounds: keep the historical lexical comparison.
            return start_time <= dt.strftime("%H:%M") <= end_time

        current = dt.hour * 60 + dt.minute
        if start <= end:
            return start <= current <= end
        # Window wraps past midnight.
        return current >= start or current <= end

    def _merge_common_params(self, base: dict) -> dict:
        """Merge shared and task-scoped API parameters beneath ``base``.

        Lets configuration override or append filter conditions in one place.
        Precedence, lowest to highest:

        - generic keys under ``api.params``
        - task-level keys under ``api.params.<task_code_lower>``
        - the caller-supplied ``base`` (``None`` is treated as empty)

        Returns:
            A new dict; the inputs are not modified.
        """
        merged: dict = {}
        common = self.config.get("api.params", {}) or {}
        if isinstance(common, dict):
            merged.update(common)

        task_key = f"api.params.{self.get_task_code().lower()}"
        scoped = self.config.get(task_key, {}) or {}
        if isinstance(scoped, dict):
            merged.update(scoped)

        if base:
            merged.update(base)
        return merged

    def _build_result(self, status: str, counts: dict) -> dict:
        """Build the result dict holding ``status`` and ``counts``."""
        return {"status": status, "counts": counts}