初始提交:飞球 ETL 系统全量代码
This commit is contained in:
252
tasks/base_task.py
Normal file
252
tasks/base_task.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ETL任务基类(引入 Extract/Transform/Load 模板方法)"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
from dateutil import parser as dtparser
|
||||
|
||||
from utils.windowing import build_window_segments, calc_window_minutes, calc_window_days, format_window_days
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskContext:
|
||||
"""统一透传给 Extract/Transform/Load 的运行期信息。"""
|
||||
|
||||
store_id: int
|
||||
window_start: datetime
|
||||
window_end: datetime
|
||||
window_minutes: int
|
||||
cursor: dict | None = None
|
||||
|
||||
|
||||
class BaseTask:
|
||||
"""提供 E/T/L 模板的任务基类。"""
|
||||
|
||||
def __init__(self, config, db_connection, api_client, logger):
|
||||
self.config = config
|
||||
self.db = db_connection
|
||||
self.api = api_client
|
||||
self.logger = logger
|
||||
self.tz = ZoneInfo(config.get("app.timezone", "Asia/Taipei"))
|
||||
|
||||
# ------------------------------------------------------------------ 基本信息
|
||||
def get_task_code(self) -> str:
|
||||
"""获取任务代码"""
|
||||
raise NotImplementedError("子类需实现 get_task_code 方法")
|
||||
|
||||
# ------------------------------------------------------------------ E/T/L 钩子
|
||||
def extract(self, context: TaskContext):
|
||||
"""提取数据"""
|
||||
raise NotImplementedError("子类需实现 extract 方法")
|
||||
|
||||
def transform(self, extracted, context: TaskContext):
|
||||
"""转换数据"""
|
||||
return extracted
|
||||
|
||||
def load(self, transformed, context: TaskContext) -> dict:
|
||||
"""加载数据并返回统计信息"""
|
||||
raise NotImplementedError("子类需实现 load 方法")
|
||||
|
||||
# ------------------------------------------------------------------ 主流程
|
||||
def execute(self, cursor_data: dict | None = None) -> dict:
|
||||
"""统一 orchestrate Extract → Transform → Load"""
|
||||
base_context = self._build_context(cursor_data)
|
||||
task_code = self.get_task_code()
|
||||
segments = build_window_segments(
|
||||
self.config,
|
||||
base_context.window_start,
|
||||
base_context.window_end,
|
||||
tz=self.tz,
|
||||
override_only=True,
|
||||
)
|
||||
if not segments:
|
||||
segments = [(base_context.window_start, base_context.window_end)]
|
||||
|
||||
total_segments = len(segments)
|
||||
total_days = sum(calc_window_days(s, e) for s, e in segments) if segments else 0.0
|
||||
processed_days = 0.0
|
||||
if total_segments > 1:
|
||||
self.logger.info(
|
||||
"%s: 窗口拆分为 %s 段(共 %s 天)",
|
||||
task_code,
|
||||
total_segments,
|
||||
format_window_days(total_days),
|
||||
)
|
||||
|
||||
total_counts: dict = {}
|
||||
segment_results: list[dict] = []
|
||||
|
||||
for idx, (window_start, window_end) in enumerate(segments, start=1):
|
||||
context = self._build_context_for_window(window_start, window_end, cursor_data)
|
||||
self.logger.info(
|
||||
"%s: 开始执行(%s/%s),窗口[%s ~ %s]",
|
||||
task_code,
|
||||
idx,
|
||||
total_segments,
|
||||
context.window_start,
|
||||
context.window_end,
|
||||
)
|
||||
|
||||
try:
|
||||
extracted = self.extract(context)
|
||||
transformed = self.transform(extracted, context)
|
||||
counts = self.load(transformed, context) or {}
|
||||
self.db.commit()
|
||||
except Exception:
|
||||
self.db.rollback()
|
||||
self.logger.error("%s: 执行失败", task_code, exc_info=True)
|
||||
raise
|
||||
|
||||
self._accumulate_counts(total_counts, counts)
|
||||
segment_days = calc_window_days(context.window_start, context.window_end)
|
||||
processed_days += segment_days
|
||||
if total_segments > 1:
|
||||
self.logger.info(
|
||||
"%s: 完成(%s/%s),已处理 %s/%s 天",
|
||||
task_code,
|
||||
idx,
|
||||
total_segments,
|
||||
format_window_days(processed_days),
|
||||
format_window_days(total_days),
|
||||
)
|
||||
if total_segments > 1:
|
||||
segment_results.append(
|
||||
{
|
||||
"window": {
|
||||
"start": context.window_start,
|
||||
"end": context.window_end,
|
||||
"minutes": context.window_minutes,
|
||||
},
|
||||
"counts": counts,
|
||||
}
|
||||
)
|
||||
|
||||
overall_start = segments[0][0]
|
||||
overall_end = segments[-1][1]
|
||||
result = self._build_result("SUCCESS", total_counts)
|
||||
result["window"] = {
|
||||
"start": overall_start,
|
||||
"end": overall_end,
|
||||
"minutes": calc_window_minutes(overall_start, overall_end),
|
||||
}
|
||||
if segment_results:
|
||||
result["segments"] = segment_results
|
||||
self.logger.info("%s: 完成,统计=%s", task_code, result["counts"])
|
||||
return result
|
||||
|
||||
def _build_context(self, cursor_data: dict | None) -> TaskContext:
|
||||
window_start, window_end, window_minutes = self._get_time_window(cursor_data)
|
||||
return TaskContext(
|
||||
store_id=self.config.get("app.store_id"),
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
window_minutes=window_minutes,
|
||||
cursor=cursor_data,
|
||||
)
|
||||
|
||||
def _build_context_for_window(
|
||||
self,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
cursor_data: dict | None,
|
||||
) -> TaskContext:
|
||||
return TaskContext(
|
||||
store_id=self.config.get("app.store_id"),
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
window_minutes=calc_window_minutes(window_start, window_end),
|
||||
cursor=cursor_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _accumulate_counts(total: dict, current: dict) -> dict:
|
||||
for key, value in (current or {}).items():
|
||||
if isinstance(value, (int, float)):
|
||||
total[key] = (total.get(key) or 0) + value
|
||||
else:
|
||||
total.setdefault(key, value)
|
||||
return total
|
||||
|
||||
def _get_time_window(self, cursor_data: dict = None) -> tuple:
|
||||
"""计算时间窗口"""
|
||||
now = datetime.now(self.tz)
|
||||
|
||||
override_start = self.config.get("run.window_override.start")
|
||||
override_end = self.config.get("run.window_override.end")
|
||||
if override_start or override_end:
|
||||
if not (override_start and override_end):
|
||||
raise ValueError("run.window_override.start/end 需要同时提供")
|
||||
|
||||
window_start = override_start
|
||||
if isinstance(window_start, str):
|
||||
window_start = dtparser.parse(window_start)
|
||||
if isinstance(window_start, datetime) and window_start.tzinfo is None:
|
||||
window_start = window_start.replace(tzinfo=self.tz)
|
||||
elif isinstance(window_start, datetime):
|
||||
window_start = window_start.astimezone(self.tz)
|
||||
|
||||
window_end = override_end
|
||||
if isinstance(window_end, str):
|
||||
window_end = dtparser.parse(window_end)
|
||||
if isinstance(window_end, datetime) and window_end.tzinfo is None:
|
||||
window_end = window_end.replace(tzinfo=self.tz)
|
||||
elif isinstance(window_end, datetime):
|
||||
window_end = window_end.astimezone(self.tz)
|
||||
|
||||
if not isinstance(window_start, datetime) or not isinstance(window_end, datetime):
|
||||
raise ValueError("run.window_override.start/end 解析失败")
|
||||
if window_end <= window_start:
|
||||
raise ValueError("run.window_override.end 必须大于 start")
|
||||
|
||||
window_minutes = max(1, int((window_end - window_start).total_seconds() // 60))
|
||||
return window_start, window_end, window_minutes
|
||||
|
||||
idle_start = self.config.get("run.idle_window.start", "04:00")
|
||||
idle_end = self.config.get("run.idle_window.end", "16:00")
|
||||
is_idle = self._is_in_idle_window(now, idle_start, idle_end)
|
||||
|
||||
if is_idle:
|
||||
window_minutes = self.config.get("run.window_minutes.default_idle", 180)
|
||||
else:
|
||||
window_minutes = self.config.get("run.window_minutes.default_busy", 30)
|
||||
|
||||
overlap_seconds = self.config.get("run.overlap_seconds", 600)
|
||||
|
||||
if cursor_data and cursor_data.get("last_end"):
|
||||
window_start = cursor_data["last_end"] - timedelta(seconds=overlap_seconds)
|
||||
else:
|
||||
window_start = now - timedelta(minutes=window_minutes)
|
||||
|
||||
window_end = now
|
||||
return window_start, window_end, window_minutes
|
||||
|
||||
def _is_in_idle_window(self, dt: datetime, start_time: str, end_time: str) -> bool:
|
||||
"""判断是否在闲时窗口"""
|
||||
current_time = dt.strftime("%H:%M")
|
||||
return start_time <= current_time <= end_time
|
||||
|
||||
def _merge_common_params(self, base: dict) -> dict:
|
||||
"""
|
||||
合并全局/任务级参数池,便于在配置中统一覆<E4B880>?/追加过滤条件。
|
||||
支持:
|
||||
- api.params 下的通用键<E794A8>?
|
||||
- api.params.<task_code_lower> 下的任务级键<E7BAA7>?
|
||||
"""
|
||||
merged: dict = {}
|
||||
common = self.config.get("api.params", {}) or {}
|
||||
if isinstance(common, dict):
|
||||
merged.update(common)
|
||||
|
||||
task_key = f"api.params.{self.get_task_code().lower()}"
|
||||
scoped = self.config.get(task_key, {}) or {}
|
||||
if isinstance(scoped, dict):
|
||||
merged.update(scoped)
|
||||
|
||||
merged.update(base)
|
||||
return merged
|
||||
|
||||
def _build_result(self, status: str, counts: dict) -> dict:
|
||||
"""构建结果字典"""
|
||||
return {"status": status, "counts": counts}
|
||||
Reference in New Issue
Block a user