Files
ZQYY.FQ-ETL/orchestration/pipeline_runner.py

380 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""管道运行器:管道定义、层→任务映射、校验编排。
从原 ETLScheduler 中提取管道编排逻辑,委托 TaskExecutor 执行具体任务。
所有依赖通过构造函数注入,不自行创建资源。
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from tasks.verification import filter_verify_tables
class PipelineRunner:
    """Pipeline orchestrator: runs the layered ETL tasks of a pipeline and
    optionally performs post-run verification/backfill.

    Orchestration logic extracted from the original ETLScheduler; concrete
    task execution is delegated to the injected TaskExecutor.  All
    collaborators are supplied via the constructor — this class creates no
    resources of its own.
    """

    # Pipeline definitions: the layers each pipeline runs, in order
    # (migrated here from module-level constants in scheduler.py).
    PIPELINE_LAYERS: dict[str, list[str]] = {
        "api_ods": ["ODS"],
        "api_ods_dwd": ["ODS", "DWD"],
        "api_full": ["ODS", "DWD", "DWS", "INDEX"],
        "ods_dwd": ["DWD"],
        "dwd_dws": ["DWS"],
        "dwd_dws_index": ["DWS", "INDEX"],
        "dwd_index": ["INDEX"],
    }

    # Hard-coded per-layer fallback task lists, used only when neither the
    # config nor the task registry yields any tasks (kept identical to the
    # original _get_tasks_for_layers behaviour).
    _FALLBACK_TASKS: dict[str, list[str]] = {
        "ODS": [
            "ODS_MEMBER", "ODS_ASSISTANT", "ODS_TABLE",
            "ODS_ORDER", "ODS_PAYMENT", "ODS_GOODS",
        ],
        "DWS": [
            "DWS_BUILD_ORDER_SUMMARY",
            "DWS_BUILD_MEMBER_SUMMARY",
        ],
        "INDEX": [
            "DWS_WINBACK_INDEX",
            "DWS_NEWCONV_INDEX",
            "DWS_RELATION_INDEX",
        ],
    }

    def __init__(
        self,
        config,
        task_executor,
        task_registry,
        db_conn,
        api_client,
        logger: logging.Logger,
    ):
        """Store the injected collaborators.

        Args:
            config: configuration object exposing ``get(key, default)``.
            task_executor: executes concrete ETL tasks (``run_tasks``).
            task_registry: resolves task codes per layer
                (``get_tasks_by_layer``).
            db_conn: database connection handed to layer verifiers.
            api_client: API client handed to ODS verifiers.
            logger: destination for all pipeline logging.
        """
        self.config = config
        self.task_executor = task_executor
        self.task_registry = task_registry
        self.db_conn = db_conn
        self.api_client = api_client
        self.logger = logger
        # Timezone used for the default time-window calculation in run().
        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))

    def run(
        self,
        pipeline: str,
        processing_mode: str = "increment_only",
        data_source: str = "hybrid",
        window_start: datetime | None = None,
        window_end: datetime | None = None,
        window_split: str | None = None,
        task_codes: list[str] | None = None,
        fetch_before_verify: bool = False,
        verify_tables: list[str] | None = None,
    ) -> dict[str, Any]:
        """Execute a pipeline and return an aggregated result.

        Args:
            pipeline: pipeline name (api_ods, api_ods_dwd, api_full, ods_dwd,
                dwd_dws, dwd_dws_index, dwd_index).
            processing_mode: increment_only / verify_only / increment_verify.
            data_source: online / offline / hybrid.
            window_start: start of the time window (defaults to
                ``window_end - 24h``).
            window_end: end of the time window (defaults to "now" in the
                configured timezone).
            window_split: window split unit (none / day / week / month).
            task_codes: explicit task codes to run (acts as a task filter
                within the pipeline).
            fetch_before_verify: fetch API data before verifying (only
                meaningful in verify_only mode).
            verify_tables: restrict verification to these table names
                (useful for single-table checks).

        Returns:
            Result dict with keys status / pipeline / layers / results /
            verification_summary.

        Raises:
            ValueError: if ``pipeline`` is not a known pipeline name.
        """
        # Imported lazily, matching the original code (presumably to avoid a
        # circular import at module load time — TODO confirm).
        from utils.task_logger import TaskLogger

        if pipeline not in self.PIPELINE_LAYERS:
            raise ValueError(f"无效的管道名称: {pipeline}")

        pipeline_logger = TaskLogger(f"PIPELINE_{pipeline.upper()}", self.logger)
        pipeline_logger.start(f"开始执行管道: {pipeline}")

        layers = self.PIPELINE_LAYERS[pipeline]
        results: list[dict[str, Any]] = []
        verification_summary: dict[str, Any] | None = None
        # task_code -> dump directory of fetched API JSON; filled only when
        # fetch_before_verify triggers an API fetch below.
        ods_dump_dirs: dict[str, str] = {}
        use_local_json = bool(self.config.get("verification.ods_use_local_json", False))

        # Default time window: the 24 hours ending now.
        if window_end is None:
            window_end = datetime.now(self.tz)
        if window_start is None:
            window_start = window_end - timedelta(hours=24)

        try:
            if processing_mode == "verify_only":
                # Verification-only mode.
                if fetch_before_verify:
                    self.logger.info("管道 %s: 校验模式(先获取 API 数据)", pipeline)
                    if task_codes:
                        # Only ODS_* tasks pull data from the API.
                        ods_tasks = [t for t in task_codes if t.startswith("ODS_")]
                        if ods_tasks:
                            self.logger.info("从 API 获取数据: %s", ods_tasks)
                            results = self.task_executor.run_tasks(ods_tasks, data_source=data_source)
                    else:
                        auto_tasks = self._resolve_tasks(["ODS"])
                        if auto_tasks:
                            self.logger.info("从 API 获取数据: %s", auto_tasks)
                            results = self.task_executor.run_tasks(auto_tasks, data_source=data_source)
                    # Remember where each task dumped its fetched JSON so the
                    # ODS verifier can reuse the local files.
                    ods_dump_dirs = {
                        r.get("task_code"): r.get("dump_dir")
                        for r in results
                        if r.get("task_code") and r.get("dump_dir")
                    }
                    self.logger.info("API 数据获取完成,开始校验并修复")
                else:
                    self.logger.info("管道 %s: 仅校验模式,跳过增量 ETL直接执行校验并修复", pipeline)
                verification_summary = self._run_verification(
                    layers=layers,
                    window_start=window_start,
                    window_end=window_end,
                    window_split=window_split,
                    fetch_from_api=fetch_before_verify,
                    ods_dump_dirs=ods_dump_dirs,
                    use_local_json=use_local_json,
                    verify_tables=verify_tables,
                )
                pipeline_logger.set_verification_result(verification_summary)
            else:
                # Incremental ETL (increment_only or increment_verify).
                self.logger.info("管道 %s: 执行增量 ETL层=%s", pipeline, layers)
                if task_codes:
                    results = self.task_executor.run_tasks(task_codes, data_source=data_source)
                else:
                    auto_tasks = self._resolve_tasks(layers)
                    results = self.task_executor.run_tasks(auto_tasks, data_source=data_source)
                # increment_verify: verify right after the incremental run.
                if processing_mode == "increment_verify":
                    self.logger.info("管道 %s: 开始校验并修复", pipeline)
                    verification_summary = self._run_verification(
                        layers=layers,
                        window_start=window_start,
                        window_end=window_end,
                        window_split=window_split,
                        ods_dump_dirs=ods_dump_dirs,
                        use_local_json=use_local_json,
                        verify_tables=verify_tables,
                    )
                    pipeline_logger.set_verification_result(verification_summary)

            # Aggregate per-task counters into the pipeline summary.
            pipeline_logger.set_counts(
                fetched=sum(r.get("counts", {}).get("fetched", 0) for r in results),
                inserted=sum(r.get("counts", {}).get("inserted", 0) for r in results),
                updated=sum(r.get("counts", {}).get("updated", 0) for r in results),
                errors=sum(r.get("counts", {}).get("errors", 0) for r in results),
            )
            summary_text = pipeline_logger.end(status="成功")
            self.logger.info("\n%s", summary_text)
            return {
                "status": "SUCCESS",
                "pipeline": pipeline,
                "layers": layers,
                "results": results,
                "verification_summary": verification_summary,
            }
        except Exception as exc:
            summary_text = pipeline_logger.end(status="失败", error_message=str(exc))
            self.logger.error("\n%s", summary_text)
            raise

    def _layer_tasks(self, config_key: str, layer: str) -> list[str]:
        """Resolve the task codes for one layer via a three-step fallback.

        Order: explicit config list -> task registry lookup -> hard-coded
        default (``_FALLBACK_TASKS``).  Factors out the chain that was
        duplicated for ODS / DWS / INDEX in the original code.
        """
        configured = self.config.get(config_key, [])
        if configured:
            return list(configured)
        registry_tasks = self.task_registry.get_tasks_by_layer(layer)
        if registry_tasks:
            return list(registry_tasks)
        return list(self._FALLBACK_TASKS[layer])

    def _resolve_tasks(self, layers: list[str]) -> list[str]:
        """Map layer names to concrete task codes.

        Prefers the task lists from config, falling back to
        ``task_registry.get_tasks_by_layer()`` and finally to hard-coded
        defaults.  The DWD layer keeps its original behaviour (a single
        DWD_LOAD_FROM_ODS task).  Unknown layer names are silently ignored,
        as before.
        """
        tasks: list[str] = []
        for layer in layers:
            layer_upper = layer.upper()
            if layer_upper == "ODS":
                tasks.extend(self._layer_tasks("run.ods_tasks", "ODS"))
            elif layer_upper == "DWD":
                # DWD layer keeps its original fixed task.
                tasks.append("DWD_LOAD_FROM_ODS")
            elif layer_upper == "DWS":
                tasks.extend(self._layer_tasks("run.dws_tasks", "DWS"))
            elif layer_upper == "INDEX":
                tasks.extend(self._layer_tasks("run.index_tasks", "INDEX"))
        return tasks

    def _run_verification(
        self,
        layers: list[str],
        window_start: datetime,
        window_end: datetime,
        window_split: str | None = None,
        fetch_from_api: bool = False,
        ods_dump_dirs: dict[str, str] | None = None,
        use_local_json: bool = False,
        verify_tables: list[str] | None = None,
    ) -> dict[str, Any]:
        """Run post-ETL verification/backfill for the given layers.

        (Migrated from the original ``_run_layer_verification``.)  A failure
        in one layer is logged and recorded but never aborts the remaining
        layers.

        Returns:
            Summary dict with overall table/backfill counters plus a
            per-layer result map under ``layers``.
        """
        try:
            from tasks.verification import get_verifier_for_layer, build_window_segments
        except ImportError:
            self.logger.warning("校验框架未安装,跳过后置校验")
            return {"status": "SKIPPED", "message": "校验框架未安装"}

        total_tables = 0
        consistent_tables = 0
        total_backfilled = 0
        total_error_tables = 0
        layer_results: dict[str, Any] = {}
        skip_ods_on_fetch = bool(self.config.get("verification.skip_ods_when_fetch_before_verify", True))
        ods_dump_dirs = ods_dump_dirs or {}
        # The original code bound this result to an unused `segments` local;
        # the call is kept for any window_split validation it may perform.
        # NOTE(review): confirm whether this call is still needed at all.
        build_window_segments(window_start, window_end, window_split)

        for layer in layers:
            layer_upper = layer.upper()
            try:
                if layer_upper == "ODS" and fetch_from_api and skip_ods_on_fetch:
                    # Data was just ingested by the pre-verify fetch; a second
                    # ODS check would be redundant.
                    self.logger.info("ODS 层在 fetch_before_verify 下已完成入库,跳过二次校验")
                    layer_results[layer] = {
                        "status": "SKIPPED",
                        "reason": "fetch_before_verify",
                    }
                    continue
                if layer_upper == "ODS" and fetch_from_api:
                    if use_local_json:
                        if not ods_dump_dirs:
                            self.logger.warning("ODS 校验配置为使用本地 JSON但未找到 dump 目录,跳过 ODS 校验")
                            layer_results[layer] = {
                                "status": "SKIPPED",
                                "reason": "local_json_missing",
                            }
                            continue
                        # Verify against the JSON dumped during the fetch step
                        # instead of hitting the API a second time.
                        verifier = get_verifier_for_layer(
                            layer,
                            self.db_conn,
                            self.logger,
                            api_client=self.api_client,
                            fetch_from_api=True,
                            local_dump_dirs=ods_dump_dirs,
                            use_local_json=True,
                        )
                        self.logger.info("ODS 层使用本地 JSON 校验(不请求 API")
                    else:
                        verifier = get_verifier_for_layer(
                            layer,
                            self.db_conn,
                            self.logger,
                            api_client=self.api_client,
                            fetch_from_api=True,
                        )
                        self.logger.info("ODS 层启用 API 数据校验")
                else:
                    # Non-ODS layers (and ODS without pre-fetch): build the
                    # layer-specific verifier kwargs.
                    verifier_kwargs: dict[str, Any] = {}
                    if layer_upper == "INDEX":
                        try:
                            lookback_days = int(self.config.get("run.index_lookback_days", 60))
                        except (TypeError, ValueError):
                            # Bad/missing config value: fall back to 60 days.
                            lookback_days = 60
                        verifier_kwargs = {
                            "lookback_days": lookback_days,
                            "config": self.config,
                        }
                        self.logger.info("INDEX 层校验使用回溯天数: %s", lookback_days)
                    if layer_upper == "DWD":
                        verifier_kwargs["config"] = self.config
                    verifier = get_verifier_for_layer(
                        layer,
                        self.db_conn,
                        self.logger,
                        **verifier_kwargs,
                    )

                # Narrow the table set when an explicit filter was requested
                # (filter_verify_tables replaces the former inline static
                # method).
                layer_tables = filter_verify_tables(layer, verify_tables)
                if verify_tables and not layer_tables:
                    self.logger.info("%s 无匹配表,跳过校验", layer)
                    layer_results[layer] = {
                        "status": "SKIPPED",
                        "reason": "table_filter",
                    }
                    continue

                self.logger.info("开始校验层: %s,时间窗口: %s ~ %s", layer, window_start, window_end)
                layer_summary = verifier.verify_and_backfill(
                    window_start=window_start,
                    window_end=window_end,
                    auto_backfill=True,
                    split_unit=window_split or "month",
                    tables=layer_tables,
                )
                # Summary objects are duck-typed: tolerate ones that lack
                # to_dict() or the counter attributes.
                layer_results[layer] = layer_summary.to_dict() if hasattr(layer_summary, 'to_dict') else {}
                if hasattr(layer_summary, 'total_tables'):
                    total_tables += layer_summary.total_tables
                    consistent_tables += layer_summary.consistent_tables
                    total_backfilled += layer_summary.total_backfilled
                    total_error_tables += getattr(layer_summary, 'error_tables', 0)
                self.logger.info(
                    "%s 校验完成: 表数=%d, 一致=%d, 错误=%d, 补齐=%d",
                    layer,
                    getattr(layer_summary, 'total_tables', 0),
                    getattr(layer_summary, 'consistent_tables', 0),
                    getattr(layer_summary, 'error_tables', 0),
                    getattr(layer_summary, 'total_backfilled', 0),
                )
            except Exception as exc:
                # One broken layer must not abort the remaining layers.
                self.logger.error("%s 校验失败: %s", layer, exc, exc_info=True)
                layer_results[layer] = {"status": "ERROR", "error": str(exc)}
        return {
            "status": "COMPLETED",
            "total_tables": total_tables,
            "consistent_tables": consistent_tables,
            "total_backfilled": total_backfilled,
            "error_tables": total_error_tables,
            "layers": layer_results,
        }