初始提交:飞球 ETL 系统全量代码

This commit is contained in:
Neo
2026-02-13 08:05:34 +08:00
commit 3c51f5485d
441 changed files with 117631 additions and 0 deletions

View File

View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""游标管理器"""
from datetime import datetime
class CursorManager:
    """Watermark-cursor manager for ETL tasks.

    Wraps the etl_admin.etl_cursor table; there is one cursor row per
    (task_id, store_id) pair recording how far that task has progressed.
    """

    def __init__(self, db_connection):
        # db_connection must expose query / execute / commit.
        self.db = db_connection

    def get_or_create(self, task_id: int, store_id: int) -> dict:
        """Return the cursor row for (task_id, store_id), creating it if absent.

        Returns None only if the freshly inserted row cannot be re-read.
        """
        select_sql = "SELECT * FROM etl_admin.etl_cursor WHERE task_id=%s AND store_id=%s"
        key = (task_id, store_id)
        found = self.db.query(select_sql, key)
        if found:
            return found[0]
        # No cursor yet: insert an empty one, then read it back.
        self.db.execute(
            """
            INSERT INTO etl_admin.etl_cursor(task_id, store_id, last_start, last_end, last_id, extra)
            VALUES(%s, %s, NULL, NULL, NULL, '{}'::jsonb)
            """,
            key,
        )
        self.db.commit()
        found = self.db.query(select_sql, key)
        return found[0] if found else None

    def advance(self, task_id: int, store_id: int, window_start: datetime,
                window_end: datetime, run_id: int, last_id: int = None):
        """Advance the cursor watermark after a successful run.

        When last_id is provided it is merged monotonically via GREATEST so a
        replayed smaller id can never move the watermark backwards.
        """
        if last_id is None:
            sql = """
                UPDATE etl_admin.etl_cursor
                SET last_start = %s,
                    last_end = %s,
                    last_run_id = %s,
                    updated_at = now()
                WHERE task_id = %s AND store_id = %s
            """
            params = (window_start, window_end, run_id, task_id, store_id)
        else:
            sql = """
                UPDATE etl_admin.etl_cursor
                SET last_start = %s,
                    last_end = %s,
                    last_id = GREATEST(COALESCE(last_id, 0), %s),
                    last_run_id = %s,
                    updated_at = now()
                WHERE task_id = %s AND store_id = %s
            """
            params = (window_start, window_end, last_id, run_id, task_id, store_id)
        self.db.execute(sql, params)
        self.db.commit()

View File

@@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
"""管道运行器:管道定义、层→任务映射、校验编排。
从原 ETLScheduler 中提取管道编排逻辑,委托 TaskExecutor 执行具体任务。
所有依赖通过构造函数注入,不自行创建资源。
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from tasks.verification import filter_verify_tables
class PipelineRunner:
    """Pipeline orchestrator: executes the multi-layer ETL tasks of a named
    pipeline and optionally runs post-run verification.

    All collaborators (config, task executor, task registry, DB connection,
    API client, logger) are injected through the constructor; this class
    creates no resources of its own.
    """

    # Pipeline definitions: the layers each pipeline executes, in order
    # (migrated here from the module-level constant in scheduler.py).
    PIPELINE_LAYERS: dict[str, list[str]] = {
        "api_ods": ["ODS"],
        "api_ods_dwd": ["ODS", "DWD"],
        "api_full": ["ODS", "DWD", "DWS", "INDEX"],
        "ods_dwd": ["DWD"],
        "dwd_dws": ["DWS"],
        "dwd_dws_index": ["DWS", "INDEX"],
        "dwd_index": ["INDEX"],
    }

    def __init__(
        self,
        config,
        task_executor,
        task_registry,
        db_conn,
        api_client,
        logger: logging.Logger,
    ):
        """Store the injected collaborators; nothing is created here."""
        self.config = config
        self.task_executor = task_executor
        self.task_registry = task_registry
        self.db_conn = db_conn
        self.api_client = api_client
        self.logger = logger
        # Timezone used when defaulting the processing window to "now".
        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))

    def run(
        self,
        pipeline: str,
        processing_mode: str = "increment_only",
        data_source: str = "hybrid",
        window_start: datetime | None = None,
        window_end: datetime | None = None,
        window_split: str | None = None,
        task_codes: list[str] | None = None,
        fetch_before_verify: bool = False,
        verify_tables: list[str] | None = None,
    ) -> dict[str, Any]:
        """Execute the pipeline and return a summary result.

        Args:
            pipeline: Pipeline name (api_ods, api_ods_dwd, api_full, ods_dwd,
                dwd_dws, dwd_dws_index, dwd_index).
            processing_mode: increment_only / verify_only / increment_verify.
            data_source: Data-source mode (online / offline / hybrid).
            window_start: Start of the time window (defaults to end - 24h).
            window_end: End of the time window (defaults to "now" in self.tz).
            window_split: Window split unit (none / day / week / month).
            task_codes: Optional task-code filter applied within the pipeline.
            fetch_before_verify: Fetch API data before verifying (only
                meaningful in verify_only mode).
            verify_tables: Restrict verification to these table names
                (useful for single-table checks).

        Returns:
            Dict with status / pipeline / layers / results /
            verification_summary.

        Raises:
            ValueError: If ``pipeline`` is not a known pipeline name.
        """
        from utils.task_logger import TaskLogger
        if pipeline not in self.PIPELINE_LAYERS:
            raise ValueError(f"无效的管道名称: {pipeline}")
        # NOTE(review): run_uuid is generated but not used in this method —
        # confirm whether it was meant to be passed to the task executor.
        run_uuid = uuid.uuid4().hex
        pipeline_logger = TaskLogger(f"PIPELINE_{pipeline.upper()}", self.logger)
        pipeline_logger.start(f"开始执行管道: {pipeline}")
        layers = self.PIPELINE_LAYERS[pipeline]
        results: list[dict[str, Any]] = []
        verification_summary: dict[str, Any] | None = None
        # Maps ODS task_code -> local JSON dump dir produced by the fetch
        # phase; consumed by the local-JSON verification path below.
        ods_dump_dirs: dict[str, str] = {}
        use_local_json = bool(self.config.get("verification.ods_use_local_json", False))
        # Default time window: the last 24 hours ending now.
        if window_end is None:
            window_end = datetime.now(self.tz)
        if window_start is None:
            window_start = window_end - timedelta(hours=24)
        try:
            if processing_mode == "verify_only":
                # Verification-only mode.
                if fetch_before_verify:
                    self.logger.info("管道 %s: 校验模式(先获取 API 数据)", pipeline)
                    if task_codes:
                        # Only ODS_* tasks actually fetch from the API.
                        ods_tasks = [t for t in task_codes if t.startswith("ODS_")]
                        if ods_tasks:
                            self.logger.info("从 API 获取数据: %s", ods_tasks)
                            results = self.task_executor.run_tasks(ods_tasks, data_source=data_source)
                    else:
                        auto_tasks = self._resolve_tasks(["ODS"])
                        if auto_tasks:
                            self.logger.info("从 API 获取数据: %s", auto_tasks)
                            results = self.task_executor.run_tasks(auto_tasks, data_source=data_source)
                    # Remember where each ODS task dumped its JSON so the
                    # verifier can replay it without calling the API again.
                    ods_dump_dirs = {
                        r.get("task_code"): r.get("dump_dir")
                        for r in results
                        if r.get("task_code") and r.get("dump_dir")
                    }
                    self.logger.info("API 数据获取完成,开始校验并修复")
                else:
                    self.logger.info("管道 %s: 仅校验模式,跳过增量 ETL直接执行校验并修复", pipeline)
                verification_summary = self._run_verification(
                    layers=layers,
                    window_start=window_start,
                    window_end=window_end,
                    window_split=window_split,
                    fetch_from_api=fetch_before_verify,
                    ods_dump_dirs=ods_dump_dirs,
                    use_local_json=use_local_json,
                    verify_tables=verify_tables,
                )
                pipeline_logger.set_verification_result(verification_summary)
            else:
                # Incremental ETL (increment_only or increment_verify).
                self.logger.info("管道 %s: 执行增量 ETL层=%s", pipeline, layers)
                if task_codes:
                    results = self.task_executor.run_tasks(task_codes, data_source=data_source)
                else:
                    auto_tasks = self._resolve_tasks(layers)
                    results = self.task_executor.run_tasks(auto_tasks, data_source=data_source)
                # increment_verify: verify (and repair) after the incremental load.
                if processing_mode == "increment_verify":
                    self.logger.info("管道 %s: 开始校验并修复", pipeline)
                    verification_summary = self._run_verification(
                        layers=layers,
                        window_start=window_start,
                        window_end=window_end,
                        window_split=window_split,
                        ods_dump_dirs=ods_dump_dirs,
                        use_local_json=use_local_json,
                        verify_tables=verify_tables,
                    )
                    pipeline_logger.set_verification_result(verification_summary)
            # Aggregate counters across all task results.
            pipeline_logger.set_counts(
                fetched=sum(r.get("counts", {}).get("fetched", 0) for r in results),
                inserted=sum(r.get("counts", {}).get("inserted", 0) for r in results),
                updated=sum(r.get("counts", {}).get("updated", 0) for r in results),
                errors=sum(r.get("counts", {}).get("errors", 0) for r in results),
            )
            summary_text = pipeline_logger.end(status="成功")
            self.logger.info("\n%s", summary_text)
            return {
                "status": "SUCCESS",
                "pipeline": pipeline,
                "layers": layers,
                "results": results,
                "verification_summary": verification_summary,
            }
        except Exception as exc:
            summary_text = pipeline_logger.end(status="失败", error_message=str(exc))
            self.logger.error("\n%s", summary_text)
            raise

    def _resolve_tasks(self, layers: list[str]) -> list[str]:
        """Resolve the task codes to run for the given layers.

        Config-provided task lists take priority, falling back to
        task_registry.get_tasks_by_layer(); the DWD layer keeps its original
        behaviour (a single DWD_LOAD_FROM_ODS task).
        """
        tasks: list[str] = []
        for layer in layers:
            layer_upper = layer.upper()
            if layer_upper == "ODS":
                ods_tasks = self.config.get("run.ods_tasks", [])
                if ods_tasks:
                    tasks.extend(ods_tasks)
                else:
                    registry_tasks = self.task_registry.get_tasks_by_layer("ODS")
                    if registry_tasks:
                        tasks.extend(registry_tasks)
                    else:
                        # Hard-coded fallback (same as the original
                        # _get_tasks_for_layers).
                        tasks.extend([
                            "ODS_MEMBER", "ODS_ASSISTANT", "ODS_TABLE",
                            "ODS_ORDER", "ODS_PAYMENT", "ODS_GOODS",
                        ])
            elif layer_upper == "DWD":
                # DWD layer keeps its original single-task behaviour.
                tasks.append("DWD_LOAD_FROM_ODS")
            elif layer_upper == "DWS":
                dws_tasks = self.config.get("run.dws_tasks", [])
                if dws_tasks:
                    tasks.extend(dws_tasks)
                else:
                    registry_tasks = self.task_registry.get_tasks_by_layer("DWS")
                    if registry_tasks:
                        tasks.extend(registry_tasks)
                    else:
                        tasks.extend([
                            "DWS_BUILD_ORDER_SUMMARY",
                            "DWS_BUILD_MEMBER_SUMMARY",
                        ])
            elif layer_upper == "INDEX":
                index_tasks = self.config.get("run.index_tasks", [])
                if index_tasks:
                    tasks.extend(index_tasks)
                else:
                    registry_tasks = self.task_registry.get_tasks_by_layer("INDEX")
                    if registry_tasks:
                        tasks.extend(registry_tasks)
                    else:
                        tasks.extend([
                            "DWS_WINBACK_INDEX",
                            "DWS_NEWCONV_INDEX",
                            "DWS_RELATION_INDEX",
                        ])
        return tasks

    def _run_verification(
        self,
        layers: list[str],
        window_start: datetime,
        window_end: datetime,
        window_split: str | None = None,
        fetch_from_api: bool = False,
        ods_dump_dirs: dict[str, str] | None = None,
        use_local_json: bool = False,
        verify_tables: list[str] | None = None,
    ) -> dict[str, Any]:
        """Run post-load verification (and auto-backfill) for each layer.

        Migrated from the original _run_layer_verification. Per-layer errors
        are caught and recorded so one failing layer does not abort the rest.
        Returns an aggregate dict with per-layer results plus totals.
        """
        try:
            from tasks.verification import get_verifier_for_layer, build_window_segments
        except ImportError:
            self.logger.warning("校验框架未安装,跳过后置校验")
            return {"status": "SKIPPED", "message": "校验框架未安装"}
        total_tables = 0
        consistent_tables = 0
        total_backfilled = 0
        total_error_tables = 0
        layer_results: dict[str, Any] = {}
        skip_ods_on_fetch = bool(self.config.get("verification.skip_ods_when_fetch_before_verify", True))
        ods_dump_dirs = ods_dump_dirs or {}
        # NOTE(review): segments is computed but not used below — confirm
        # whether per-segment verification was intended here.
        segments = build_window_segments(window_start, window_end, window_split)
        for layer in layers:
            try:
                if layer.upper() == "ODS" and fetch_from_api and skip_ods_on_fetch:
                    # ODS was just loaded by the fetch phase; skip re-checking it.
                    self.logger.info("ODS 层在 fetch_before_verify 下已完成入库,跳过二次校验")
                    layer_results[layer] = {
                        "status": "SKIPPED",
                        "reason": "fetch_before_verify",
                    }
                    continue
                if layer.upper() == "ODS" and fetch_from_api:
                    if use_local_json:
                        # Replay the JSON dumped during the fetch phase
                        # instead of hitting the API a second time.
                        if not ods_dump_dirs:
                            self.logger.warning("ODS 校验配置为使用本地 JSON但未找到 dump 目录,跳过 ODS 校验")
                            layer_results[layer] = {
                                "status": "SKIPPED",
                                "reason": "local_json_missing",
                            }
                            continue
                        verifier = get_verifier_for_layer(
                            layer,
                            self.db_conn,
                            self.logger,
                            api_client=self.api_client,
                            fetch_from_api=True,
                            local_dump_dirs=ods_dump_dirs,
                            use_local_json=True,
                        )
                        self.logger.info("ODS 层使用本地 JSON 校验(不请求 API")
                    else:
                        verifier = get_verifier_for_layer(
                            layer,
                            self.db_conn,
                            self.logger,
                            api_client=self.api_client,
                            fetch_from_api=True,
                        )
                        self.logger.info("ODS 层启用 API 数据校验")
                else:
                    # Non-ODS layers (or ODS without API fetch): build
                    # layer-specific verifier kwargs.
                    verifier_kwargs: dict[str, Any] = {}
                    if layer.upper() == "INDEX":
                        try:
                            lookback_days = int(self.config.get("run.index_lookback_days", 60))
                        except (TypeError, ValueError):
                            lookback_days = 60
                        verifier_kwargs = {
                            "lookback_days": lookback_days,
                            "config": self.config,
                        }
                        self.logger.info("INDEX 层校验使用回溯天数: %s", lookback_days)
                    if layer.upper() == "DWD":
                        verifier_kwargs["config"] = self.config
                    verifier = get_verifier_for_layer(
                        layer,
                        self.db_conn,
                        self.logger,
                        **verifier_kwargs,
                    )
                # Use filter_verify_tables in place of the old inline
                # static method.
                layer_tables = filter_verify_tables(layer, verify_tables)
                if verify_tables and not layer_tables:
                    self.logger.info("%s 无匹配表,跳过校验", layer)
                    layer_results[layer] = {
                        "status": "SKIPPED",
                        "reason": "table_filter",
                    }
                    continue
                self.logger.info("开始校验层: %s,时间窗口: %s ~ %s", layer, window_start, window_end)
                layer_summary = verifier.verify_and_backfill(
                    window_start=window_start,
                    window_end=window_end,
                    auto_backfill=True,
                    split_unit=window_split or "month",
                    tables=layer_tables,
                )
                # Summaries are duck-typed; tolerate objects without the
                # expected attributes.
                layer_results[layer] = layer_summary.to_dict() if hasattr(layer_summary, 'to_dict') else {}
                if hasattr(layer_summary, 'total_tables'):
                    total_tables += layer_summary.total_tables
                    consistent_tables += layer_summary.consistent_tables
                    total_backfilled += layer_summary.total_backfilled
                    total_error_tables += getattr(layer_summary, 'error_tables', 0)
                self.logger.info(
                    "%s 校验完成: 表数=%d, 一致=%d, 错误=%d, 补齐=%d",
                    layer,
                    getattr(layer_summary, 'total_tables', 0),
                    getattr(layer_summary, 'consistent_tables', 0),
                    getattr(layer_summary, 'error_tables', 0),
                    getattr(layer_summary, 'total_backfilled', 0),
                )
            except Exception as exc:
                # Record the failure and keep verifying the remaining layers.
                self.logger.error("%s 校验失败: %s", layer, exc, exc_info=True)
                layer_results[layer] = {"status": "ERROR", "error": str(exc)}
        return {
            "status": "COMPLETED",
            "total_tables": total_tables,
            "consistent_tables": consistent_tables,
            "total_backfilled": total_backfilled,
            "error_tables": total_error_tables,
            "layers": layer_results,
        }

View File

@@ -0,0 +1,144 @@
# -*- coding: utf-8 -*-
"""运行记录追踪器"""
import json
from datetime import datetime
class RunTracker:
"""ETL运行记录管理"""
def __init__(self, db_connection):
self.db = db_connection
def create_run(self, task_id: int, store_id: int, run_uuid: str,
export_dir: str, log_path: str, status: str,
window_start: datetime = None, window_end: datetime = None,
window_minutes: int = None, overlap_seconds: int = None,
request_params: dict = None) -> int:
"""创建运行记录"""
sql = """
INSERT INTO etl_admin.etl_run(
run_uuid, task_id, store_id, status, started_at, window_start, window_end,
window_minutes, overlap_seconds, fetched_count, loaded_count, updated_count,
skipped_count, error_count, unknown_fields, export_dir, log_path,
request_params, manifest, error_message, extra
) VALUES (
%s, %s, %s, %s, now(), %s, %s, %s, %s, 0, 0, 0, 0, 0, 0, %s, %s, %s,
'{}'::jsonb, NULL, '{}'::jsonb
)
RETURNING run_id
"""
result = self.db.query(
sql,
(run_uuid, task_id, store_id, status, window_start, window_end,
window_minutes, overlap_seconds, export_dir, log_path,
json.dumps(request_params or {}, ensure_ascii=False))
)
run_id = result[0]["run_id"]
self.db.commit()
return run_id
def update_run(
self,
run_id: int,
counts: dict,
status: str,
ended_at: datetime = None,
manifest: dict = None,
error_message: str = None,
window: dict | None = None,
request_params: dict | None = None,
overlap_seconds: int | None = None,
):
"""更新运行记录"""
sql = """
UPDATE etl_admin.etl_run
SET fetched_count = %s,
loaded_count = %s,
updated_count = %s,
skipped_count = %s,
error_count = %s,
unknown_fields = %s,
status = %s,
ended_at = %s,
manifest = %s,
error_message = %s,
window_start = COALESCE(%s, window_start),
window_end = COALESCE(%s, window_end),
window_minutes = COALESCE(%s, window_minutes),
overlap_seconds = COALESCE(%s, overlap_seconds),
request_params = CASE WHEN %s IS NULL THEN request_params ELSE %s::jsonb END
WHERE run_id = %s
"""
def _count(v, default: int = 0) -> int:
if v is None:
return default
if isinstance(v, bool):
return int(v)
if isinstance(v, int):
return int(v)
if isinstance(v, str):
try:
return int(v)
except Exception:
return default
if isinstance(v, (list, tuple, set, dict)):
try:
return len(v)
except Exception:
return default
return default
safe_counts = counts or {}
window_start = None
window_end = None
window_minutes = None
if isinstance(window, dict):
window_start = window.get("start") or window.get("window_start")
window_end = window.get("end") or window.get("window_end")
window_minutes = window.get("minutes") or window.get("window_minutes")
request_json = None if request_params is None else json.dumps(request_params or {}, ensure_ascii=False)
self.db.execute(
sql,
(
_count(safe_counts.get("fetched", 0)),
_count(safe_counts.get("inserted", 0)),
_count(safe_counts.get("updated", 0)),
_count(safe_counts.get("skipped", 0)),
_count(safe_counts.get("errors", 0)),
_count(safe_counts.get("unknown_fields", 0)),
status,
ended_at,
json.dumps(manifest or {}, ensure_ascii=False),
error_message,
window_start,
window_end,
window_minutes,
overlap_seconds,
request_json,
request_json,
run_id,
),
)
self.db.commit()
@staticmethod
def map_run_status(status: str) -> str:
"""
将任务返回的状态转换为 etl_admin.run_status_enum
(SUCC / FAIL / PARTIAL)
"""
normalized = (status or "").upper()
if normalized in {"SUCCESS", "SUCC"}:
return "SUCC"
if normalized in {"FAIL", "FAILED", "ERROR"}:
return "FAIL"
if normalized in {"RUNNING", "PARTIAL", "PENDING", "IN_PROGRESS"}:
return "PARTIAL"
# 未知状态默认标记为 FAIL便于排查
return "FAIL"

View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""ETL 调度器(薄包装层)
已弃用:请直接使用 TaskExecutor 和 PipelineRunner。
保留此类以兼容 GUI 层、run_update.py 等现有调用方。
"""
from __future__ import annotations
import logging
import warnings
from typing import Any, Dict, List, Optional
from api.client import APIClient
from database.connection import DatabaseConnection
from database.operations import DatabaseOperations
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_registry import default_registry
from orchestration.task_executor import TaskExecutor
from orchestration.pipeline_runner import PipelineRunner
# 保留模块级常量以兼容外部引用
PIPELINE_LAYERS = PipelineRunner.PIPELINE_LAYERS
class ETLScheduler:
    """Deprecated thin facade over TaskExecutor / PipelineRunner.

    Kept only so that existing callers (run_update.py, the GUI layer, ...)
    keep working; new code should use TaskExecutor and PipelineRunner
    directly.
    """

    def __init__(self, config, logger):
        warnings.warn(
            "ETLScheduler 已弃用,请直接使用 TaskExecutor 和 PipelineRunner",
            DeprecationWarning,
            stacklevel=2,
        )
        self.config = config
        self.logger = logger
        # Own the shared resources, exactly as the legacy implementation did.
        db_cfg = config["db"]
        api_cfg = config["api"]
        self.db_conn = DatabaseConnection(
            dsn=db_cfg["dsn"],
            session=db_cfg.get("session"),
            connect_timeout=db_cfg.get("connect_timeout_sec"),
        )
        self.db_ops = DatabaseOperations(self.db_conn)
        self.api_client = APIClient(
            base_url=api_cfg["base_url"],
            token=api_cfg["token"],
            timeout=api_cfg["timeout_sec"],
            retry_max=api_cfg["retries"]["max_attempts"],
            headers_extra=api_cfg.get("headers_extra"),
        )
        self.task_registry = default_registry
        # Internal components the facade delegates to.
        self.task_executor = TaskExecutor(
            config,
            self.db_ops,
            self.api_client,
            CursorManager(self.db_conn),
            RunTracker(self.db_conn),
            self.task_registry,
            logger,
        )
        self.pipeline_runner = PipelineRunner(
            config,
            self.task_executor,
            self.task_registry,
            self.db_conn,
            self.api_client,
            logger,
        )

    def run_tasks(self, task_codes=None) -> list:
        """Run a list of task codes (delegates to TaskExecutor)."""
        codes = task_codes or self.config.get("run.tasks", [])
        source = str(self.config.get("run.data_source", "hybrid") or "hybrid")
        return self.task_executor.run_tasks(codes, data_source=source)

    def run_pipeline_with_verification(self, **kwargs) -> dict:
        """Run a pipeline (delegates to PipelineRunner)."""
        # Fall back to the configured data_source when the caller omits it.
        if "data_source" not in kwargs:
            kwargs["data_source"] = str(
                self.config.get("run.data_source", "hybrid") or "hybrid"
            )
        return self.pipeline_runner.run(**kwargs)

    def close(self):
        """Close the owned database connection."""
        self.db_conn.close()

View File

@@ -0,0 +1,497 @@
# -*- coding: utf-8 -*-
"""任务执行器:封装单个 ETL 任务的完整执行生命周期。
从原 ETLScheduler 中提取的执行层,负责:
- 单任务执行(抓取/入库/ODS 录制+加载)
- 游标管理(成功后推进水位)
- 运行记录(创建/更新 etl_admin.etl_run
设计原则:
- data_source 作为显式参数传入,不依赖全局状态
- 工具类任务判断通过 TaskRegistry 元数据查询
- 所有依赖通过构造函数注入,不自行创建资源
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List
from zoneinfo import ZoneInfo
from api.recording_client import RecordingAPIClient
from api.local_json_client import LocalJsonClient
from orchestration.cursor_manager import CursorManager
from orchestration.run_tracker import RunTracker
from orchestration.task_registry import TaskRegistry
class DataSource(str, Enum):
    """Data-source mode, replacing the old pipeline.flow global state.

    ONLINE  - fetch from the API only (was FETCH_ONLY)
    OFFLINE - ingest from local files only (was INGEST_ONLY)
    HYBRID  - fetch then ingest (was FULL)
    """

    ONLINE = "online"
    OFFLINE = "offline"
    HYBRID = "hybrid"
class TaskExecutor:
    """Task executor: owns the full execution lifecycle of a single ETL task.

    Responsibilities (extracted from the original ETLScheduler):
      - run a single task (fetch / ingest / ODS record+load)
      - cursor management (advance the watermark after success)
      - run bookkeeping (create/update etl_admin.etl_run)

    Design notes:
      - ``data_source`` is passed explicitly per call, not held as global state
      - utility-task detection goes through TaskRegistry metadata
      - all dependencies are constructor-injected; no resources are created here
    """

    def __init__(
        self,
        config,
        db_ops,
        api_client,
        cursor_mgr: CursorManager,
        run_tracker: RunTracker,
        task_registry: TaskRegistry,
        logger: logging.Logger,
    ):
        """Store injected collaborators and derive I/O paths from config."""
        self.config = config
        self.db_ops = db_ops
        self.api_client = api_client
        self.cursor_mgr = cursor_mgr
        self.run_tracker = run_tracker
        self.task_registry = task_registry
        self.logger = logger
        self.tz = ZoneInfo(config.get("app.timezone", "Asia/Shanghai"))
        # Root directory for fetched JSON dumps; falls back to export_root.
        self.fetch_root = Path(
            config.get("io.fetch_root")
            or config.get("pipeline.fetch_root")
            or config["io"]["export_root"]
        )
        # Optional fixed source directory for offline ingestion.
        self.ingest_source_dir = (
            config.get("io.ingest_source_dir")
            or config.get("pipeline.ingest_source_dir")
            or ""
        )
        self.write_pretty_json = bool(config.get("io.write_pretty_json", False))

    # ------------------------------------------------------------------ public interface
    def run_tasks(
        self,
        task_codes: list[str],
        data_source: str = "hybrid",
    ) -> list[dict[str, Any]]:
        """Run each task in order and return one result dict per task.

        A per-run file logger is attached for the duration of the batch;
        individual task failures are logged and recorded but do not abort
        the remaining tasks.
        """
        run_uuid = uuid.uuid4().hex
        store_id = self.config.get("app.store_id")
        results: list[dict[str, Any]] = []
        file_handler = self._attach_run_file_logger(run_uuid)
        try:
            self.logger.info("开始运行任务: %s, run_uuid=%s", task_codes, run_uuid)
            for task_code in task_codes:
                try:
                    task_result = self.run_single_task(
                        task_code, run_uuid, store_id, data_source=data_source,
                    )
                    result_entry: dict[str, Any] = {
                        "task_code": task_code,
                        "status": "成功" if task_result else "完成",
                        "counts": task_result.get("counts", {}) if isinstance(task_result, dict) else {},
                    }
                    # Propagate dump locations so callers (e.g. local-JSON
                    # verification) can find the fetched files.
                    if isinstance(task_result, dict):
                        if task_result.get("dump_dir"):
                            result_entry["dump_dir"] = task_result["dump_dir"]
                        if task_result.get("last_dump"):
                            result_entry["last_dump"] = task_result["last_dump"]
                    results.append(result_entry)
                except Exception as exc:  # noqa: BLE001
                    # Record the failure and continue with the next task.
                    self.logger.error("任务 %s 失败: %s", task_code, exc, exc_info=True)
                    results.append({
                        "task_code": task_code,
                        "status": "失败",
                        "error": str(exc),
                        "counts": {},
                    })
                    continue
            self.logger.info("所有任务执行完成")
            return results
        finally:
            # Always detach and close the per-run file handler, even on error.
            if file_handler is not None:
                try:
                    logging.getLogger().removeHandler(file_handler)
                except Exception:
                    pass
                try:
                    file_handler.close()
                except Exception:
                    pass

    def run_single_task(
        self,
        task_code: str,
        run_uuid: str,
        store_id: int,
        data_source: str = "hybrid",
    ) -> dict[str, Any]:
        """Execute the full lifecycle of a single task.

        Args:
            task_code: Task code to run.
            run_uuid: Unique identifier of this batch run.
            store_id: Store ID the task runs for.
            data_source: Data-source mode (online/offline/hybrid).

        Returns:
            The task's result dict (always contains at least status/counts).
        """
        task_code_upper = task_code.upper()
        # Utility tasks (per TaskRegistry metadata) skip cursor and run
        # bookkeeping entirely.
        if self.task_registry.is_utility_task(task_code_upper):
            return self._run_utility_task(task_code_upper, store_id)
        task_cfg = self._load_task_config(task_code, store_id)
        if not task_cfg:
            self.logger.warning("任务 %s 未启用或不存在", task_code)
            return {"status": "SKIP", "counts": {}}
        task_id = task_cfg["task_id"]
        cursor_data = self.cursor_mgr.get_or_create(task_id, store_id)
        # Create the etl_run bookkeeping row up front (status RUNNING).
        export_dir = Path(self.config["io"]["export_root"]) / datetime.now(self.tz).strftime("%Y%m%d")
        log_path = str(Path(self.config["io"]["log_root"]) / f"{run_uuid}.log")
        run_id = self.run_tracker.create_run(
            task_id=task_id,
            store_id=store_id,
            run_uuid=run_uuid,
            export_dir=str(export_dir),
            log_path=log_path,
            status=RunTracker.map_run_status("RUNNING"),
        )
        fetch_dir = self._build_fetch_dir(task_code, run_id)
        fetch_stats = None
        try:
            # ODS tasks (except ODS_JSON_ARCHIVE) take a special path: they
            # record the API payload and upsert into the DB in one pass.
            if self._is_ods_task(task_code):
                if self._flow_includes_fetch(data_source):
                    result, last_dump = self._execute_ods_record_and_load(
                        task_code, cursor_data, fetch_dir, run_id,
                    )
                    if isinstance(result, dict):
                        result.setdefault("dump_dir", str(fetch_dir))
                        if last_dump:
                            result.setdefault("last_dump", last_dump)
                else:
                    # Offline: replay previously dumped JSON.
                    source_dir = self._resolve_ingest_source(fetch_dir, None)
                    result = self._execute_ingest(task_code, cursor_data, source_dir)
                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result.get("counts") or {},
                    status=RunTracker.map_run_status(result.get("status")),
                    ended_at=datetime.now(self.tz),
                    window=result.get("window"),
                    request_params=result.get("request_params"),
                    overlap_seconds=self.config.get("run.overlap_seconds"),
                )
                # Only advance the watermark after a fully successful run.
                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if isinstance(window, dict):
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
                        self._maybe_run_integrity_check(task_code, window)
                return result
            # Non-ODS tasks: data_source decides fetch and/or ingest phases.
            if self._flow_includes_fetch(data_source):
                fetch_stats = self._execute_fetch(task_code, cursor_data, fetch_dir, run_id)
                # Online-only: stop after the fetch phase and close the run.
                if data_source == DataSource.ONLINE or data_source == "online":
                    counts = self._counts_from_fetch(fetch_stats)
                    self.run_tracker.update_run(
                        run_id=run_id,
                        counts=counts,
                        status=RunTracker.map_run_status("SUCCESS"),
                        ended_at=datetime.now(self.tz),
                    )
                    return {"status": "SUCCESS", "counts": counts}
            if self._flow_includes_ingest(data_source):
                source_dir = self._resolve_ingest_source(fetch_dir, fetch_stats)
                result = self._execute_ingest(task_code, cursor_data, source_dir)
                self.run_tracker.update_run(
                    run_id=run_id,
                    counts=result["counts"],
                    status=RunTracker.map_run_status(result["status"]),
                    ended_at=datetime.now(self.tz),
                    window=result.get("window"),
                    request_params=result.get("request_params"),
                    overlap_seconds=self.config.get("run.overlap_seconds"),
                )
                if (result.get("status") or "").upper() == "SUCCESS":
                    window = result.get("window")
                    if window:
                        self.cursor_mgr.advance(
                            task_id=task_id,
                            store_id=store_id,
                            window_start=window.get("start"),
                            window_end=window.get("end"),
                            run_id=run_id,
                        )
                        self._maybe_run_integrity_check(task_code, window)
                return result
        except Exception as exc:
            # Mark the run FAILED before propagating.
            self.run_tracker.update_run(
                run_id=run_id,
                counts={},
                status=RunTracker.map_run_status("FAIL"),
                ended_at=datetime.now(self.tz),
                error_message=str(exc),
            )
            raise
        # Reached only when data_source enables neither phase.
        return {"status": "COMPLETE", "counts": {}}

    # ------------------------------------------------------------------ internal helpers
    def _execute_fetch(
        self,
        task_code: str,
        cursor_data: dict | None,
        fetch_dir: Path,
        run_id: int,
    ):
        """Online fetch phase: pull via RecordingAPIClient and dump to disk.

        No transform/load is performed here; returns a small stats dict
        (file / records / pages) describing the dump.
        """
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(
            task_code, self.config, self.db_ops, recording_client, self.logger,
        )
        # NOTE(review): reaches into the task's private _build_context —
        # consider exposing a public hook on the task base class.
        context = task._build_context(cursor_data)  # type: ignore[attr-defined]
        self.logger.info("%s: 抓取阶段开始,目录=%s", task_code, fetch_dir)
        extracted = task.extract(context)
        stats = recording_client.last_dump or {}
        extracted_count = 0
        if isinstance(extracted, dict):
            extracted_count = int(extracted.get("fetched") or 0) or len(extracted.get("records", []))
        # Prefer the recorder's count; fall back to what extract() reported.
        fetched_count = stats.get("records") or extracted_count or 0
        self.logger.info(
            "%s: 抓取完成,文件=%s,记录数=%s",
            task_code,
            stats.get("file"),
            fetched_count,
        )
        return {"file": stats.get("file"), "records": fetched_count, "pages": stats.get("pages")}

    @staticmethod
    def _is_ods_task(task_code: str) -> bool:
        """Return True for ODS_* tasks, except the ODS_JSON_ARCHIVE utility."""
        tc = str(task_code or "").upper()
        return tc.startswith("ODS_") and tc != "ODS_JSON_ARCHIVE"

    def _execute_ods_record_and_load(
        self,
        task_code: str,
        cursor_data: dict | None,
        fetch_dir: Path,
        run_id: int,
    ) -> tuple[dict, dict]:
        """ODS task: fetch online and load directly.

        ODS tasks perform their DB upsert inside execute(); this wraps the
        API client in a recorder so the raw payload is also dumped to disk.
        Returns (task result, recorder's last_dump stats).
        """
        recording_client = RecordingAPIClient(
            base_client=self.api_client,
            output_dir=fetch_dir,
            task_code=task_code,
            run_id=run_id,
            write_pretty=self.write_pretty_json,
        )
        task = self.task_registry.create_task(
            task_code, self.config, self.db_ops, recording_client, self.logger,
        )
        self.logger.info("%s: ODS fetch+load start, dir=%s", task_code, fetch_dir)
        result = task.execute(cursor_data)
        return result, (recording_client.last_dump or {})

    def _execute_ingest(
        self,
        task_code: str,
        cursor_data: dict | None,
        source_dir: Path,
    ):
        """Local ingest phase: replay dumped JSON through the normal task ETL
        using a LocalJsonClient in place of the live API client.
        """
        local_client = LocalJsonClient(source_dir)
        task = self.task_registry.create_task(
            task_code, self.config, self.db_ops, local_client, self.logger,
        )
        self.logger.info("%s: 本地清洗入库开始,源目录=%s", task_code, source_dir)
        return task.execute(cursor_data)

    def _build_fetch_dir(self, task_code: str, run_id: int) -> Path:
        """Build the output directory path for a fetch run
        (<fetch_root>/<TASK>/<TASK>-<run_id>-<timestamp>).
        """
        ts = datetime.now(self.tz).strftime("%Y%m%d-%H%M%S")
        task_code = str(task_code or "").upper()
        return Path(self.fetch_root) / task_code / f"{task_code}-{run_id}-{ts}"

    def _resolve_ingest_source(self, fetch_dir: Path, fetch_stats: dict | None) -> Path:
        """Pick the JSON source directory for ingestion.

        Prefers the directory just written by the fetch phase; otherwise
        falls back to the configured ingest_source_dir.

        Raises:
            FileNotFoundError: If neither source is available.
        """
        if fetch_stats and fetch_dir.exists():
            return fetch_dir
        if self.ingest_source_dir:
            return Path(self.ingest_source_dir)
        raise FileNotFoundError("未提供本地清洗入库所需的 JSON 目录")

    def _counts_from_fetch(self, stats: dict | None) -> dict:
        """Build a counts dict from fetch stats (only 'fetched' is non-zero)."""
        fetched = (stats or {}).get("records") or 0
        return {
            "fetched": fetched,
            "inserted": 0,
            "updated": 0,
            "skipped": 0,
            "errors": 0,
        }

    @staticmethod
    def _flow_includes_fetch(data_source: str) -> bool:
        """Return True when data_source includes the online fetch phase."""
        ds = str(data_source).lower()
        return ds in {"online", "hybrid"}

    @staticmethod
    def _flow_includes_ingest(data_source: str) -> bool:
        """Return True when data_source includes the local ingest phase."""
        ds = str(data_source).lower()
        return ds in {"offline", "hybrid"}

    def _run_utility_task(self, task_code: str, store_id: int) -> Dict[str, Any]:
        """Run a utility task directly, without cursor or run bookkeeping."""
        self.logger.info("%s: 开始执行工具类任务", task_code)
        try:
            api_client = None
            # ODS_JSON_ARCHIVE still needs a recording client so its output
            # lands in a timestamped fetch directory.
            if task_code == "ODS_JSON_ARCHIVE":
                run_id = int(datetime.now(self.tz).timestamp())
                fetch_dir = self._build_fetch_dir(task_code, run_id)
                api_client = RecordingAPIClient(
                    base_client=self.api_client,
                    output_dir=fetch_dir,
                    task_code=task_code,
                    run_id=run_id,
                    write_pretty=self.write_pretty_json,
                )
            task = self.task_registry.create_task(
                task_code, self.config, self.db_ops, api_client, self.logger,
            )
            result = task.execute(None)
            # Non-dict results are treated as success with empty counts.
            status = (result.get("status") or "").upper() if isinstance(result, dict) else "SUCCESS"
            counts = result.get("counts", {}) if isinstance(result, dict) else {}
            if status == "SUCCESS":
                self.logger.info("%s: 工具类任务执行成功", task_code)
                if counts:
                    self.logger.info("%s: 结果统计: %s", task_code, counts)
            else:
                self.logger.warning("%s: 工具类任务执行结果: %s", task_code, status)
            return {"status": status, "counts": counts}
        except Exception as exc:
            self.logger.error("%s: 工具类任务执行失败: %s", task_code, exc, exc_info=True)
            raise

    def _load_task_config(self, task_code: str, store_id: int) -> dict | None:
        """Load the enabled task's config row from etl_admin.etl_task,
        or None when the task is missing or disabled.
        """
        sql = """
            SELECT task_id, task_code, store_id, enabled, cursor_field,
                   window_minutes_default, overlap_seconds, page_size, retry_max, params
            FROM etl_admin.etl_task
            WHERE store_id = %s AND task_code = %s AND enabled = TRUE
        """
        rows = self.db_ops.query(sql, (store_id, task_code))
        return rows[0] if rows else None

    def _maybe_run_integrity_check(self, task_code: str, window: dict | None) -> None:
        """Optionally run the integrity check after a successful
        DWD_LOAD_FROM_ODS run; gated by the integrity.auto_check config flag.
        Failures are logged as warnings and never propagate.
        """
        if not self.config.get("integrity.auto_check", False):
            return
        if str(task_code or "").upper() != "DWD_LOAD_FROM_ODS":
            return
        if not isinstance(window, dict):
            return
        window_start = window.get("start")
        window_end = window.get("end")
        if not window_start or not window_end:
            return
        try:
            from quality.integrity_checker import IntegrityWindow, run_integrity_window
            include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
            task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
            report = run_integrity_window(
                cfg=self.config,
                window=IntegrityWindow(
                    start=window_start,
                    end=window_end,
                    label="etl_window",
                    granularity="window",
                ),
                include_dimensions=include_dimensions,
                task_codes=task_codes,
                logger=self.logger,
                write_report=True,
            )
            self.logger.info(
                "Integrity check done: report=%s missing=%s errors=%s",
                report.get("report_path"),
                report.get("api_to_ods", {}).get("total_missing"),
                report.get("api_to_ods", {}).get("total_errors"),
            )
        except Exception as exc:  # noqa: BLE001
            # Integrity checking is best-effort: never fail the ETL run.
            self.logger.warning("Integrity check failed: %s", exc, exc_info=True)

    def _attach_run_file_logger(self, run_uuid: str) -> logging.Handler | None:
        """Attach a per-run file handler to the root logger.

        Returns the handler (caller must remove/close it) or None when the
        log directory or file cannot be created; logging setup is
        best-effort and never aborts the run.
        """
        log_root = Path(self.config["io"]["log_root"])
        try:
            log_root.mkdir(parents=True, exist_ok=True)
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("创建日志目录失败:%s%s", log_root, exc)
            return None
        log_path = log_root / f"{run_uuid}.log"
        try:
            handler: logging.Handler = logging.FileHandler(log_path, encoding="utf-8")
        except Exception as exc:  # noqa: BLE001
            self.logger.warning("创建文件日志失败:%s%s", log_path, exc)
            return None
        fmt = logging.Formatter(
            fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(fmt)
        handler.setLevel(logging.INFO)
        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        return handler

View File

@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
"""任务注册表"""
from dataclasses import dataclass
# ODS 层任务
from tasks.ods.orders_task import OrdersTask
from tasks.ods.payments_task import PaymentsTask
from tasks.ods.members_task import MembersTask
from tasks.ods.products_task import ProductsTask
from tasks.ods.tables_task import TablesTask
from tasks.ods.assistants_task import AssistantsTask
from tasks.ods.packages_task import PackagesDefTask
from tasks.ods.refunds_task import RefundsTask
from tasks.ods.coupon_usage_task import CouponUsageTask
from tasks.ods.inventory_change_task import InventoryChangeTask
from tasks.ods.topups_task import TopupsTask
from tasks.ods.table_discount_task import TableDiscountTask
from tasks.ods.assistant_abolish_task import AssistantAbolishTask
from tasks.ods.ledger_task import LedgerTask
from tasks.ods.ods_tasks import ODS_TASK_CLASSES
from tasks.ods.ods_json_archive_task import OdsJsonArchiveTask
# DWD-layer tasks
from tasks.dwd.payments_dwd_task import PaymentsDwdTask
from tasks.dwd.members_dwd_task import MembersDwdTask
from tasks.dwd.dwd_load_task import DwdLoadTask
from tasks.dwd.ticket_dwd_task import TicketDwdTask
from tasks.dwd.dwd_quality_task import DwdQualityTask
# Utility tasks
from tasks.utility.manual_ingest_task import ManualIngestTask
from tasks.utility.init_schema_task import InitOdsSchemaTask
from tasks.utility.init_dwd_schema_task import InitDwdSchemaTask
from tasks.utility.init_dws_schema_task import InitDwsSchemaTask
from tasks.utility.check_cutoff_task import CheckCutoffTask
from tasks.utility.dws_build_order_summary_task import DwsBuildOrderSummaryTask
from tasks.utility.data_integrity_task import DataIntegrityTask
from tasks.utility.seed_dws_config_task import SeedDwsConfigTask
# DWS-layer task imports
from tasks.dws import (
AssistantDailyTask,
AssistantMonthlyTask,
AssistantCustomerTask,
AssistantSalaryTask,
AssistantFinanceTask,
MemberConsumptionTask,
MemberVisitTask,
FinanceDailyTask,
FinanceRechargeTask,
FinanceIncomeStructureTask,
FinanceDiscountDetailTask,
DwsRetentionCleanupTask,
DwsMvRefreshFinanceDailyTask,
DwsMvRefreshAssistantDailyTask,
    # index/score algorithm tasks
RecallIndexTask,
IntimacyIndexTask,
WinbackIndexTask,
NewconvIndexTask,
MlManualImportTask,
RelationIndexTask,
)
@dataclass
class TaskMeta:
"""任务元数据"""
task_class: type
requires_db_config: bool = True
layer: str | None = None # "ODS" / "DWD" / "DWS" / "INDEX" / None
task_type: str = "etl" # "etl" / "utility" / "verification"
class TaskRegistry:
"""任务注册和工厂"""
def __init__(self):
self._tasks: dict[str, TaskMeta] = {}
def register(
self,
task_code: str,
task_class: type,
requires_db_config: bool = True,
layer: str | None = None,
task_type: str = "etl",
):
"""注册任务类及其元数据。向后兼容:仅传 task_code 和 task_class 时使用默认值。"""
self._tasks[task_code.upper()] = TaskMeta(
task_class=task_class,
requires_db_config=requires_db_config,
layer=layer,
task_type=task_type,
)
def create_task(self, task_code: str, config, db_connection, api_client, logger):
"""创建任务实例"""
task_code = task_code.upper()
if task_code not in self._tasks:
raise ValueError(f"未知的任务类型: {task_code}")
task_class = self._tasks[task_code].task_class
return task_class(config, db_connection, api_client, logger)
def get_metadata(self, task_code: str) -> TaskMeta | None:
"""查询任务元数据。"""
return self._tasks.get(task_code.upper())
def get_tasks_by_layer(self, layer: str) -> list[str]:
"""获取指定层的所有任务代码。"""
return [
code for code, meta in self._tasks.items()
if meta.layer and meta.layer.upper() == layer.upper()
]
def is_utility_task(self, task_code: str) -> bool:
"""判断是否为工具类任务(不需要游标/运行记录)。"""
meta = self.get_metadata(task_code)
return meta is not None and not meta.requires_db_config
def get_all_task_codes(self) -> list[str]:
"""获取所有已注册的任务代码"""
return list(self._tasks.keys())
# Module-level default registry; populated below at import time, so importing
# this module makes every task code immediately resolvable.
default_registry = TaskRegistry()
# ── ODS layer: base fetch tasks ──────────────────────────────
default_registry.register("PRODUCTS", ProductsTask, layer="ODS")
default_registry.register("TABLES", TablesTask, layer="ODS")
default_registry.register("MEMBERS", MembersTask, layer="ODS")
default_registry.register("ASSISTANTS", AssistantsTask, layer="ODS")
default_registry.register("PACKAGES_DEF", PackagesDefTask, layer="ODS")
default_registry.register("ORDERS", OrdersTask, layer="ODS")
default_registry.register("PAYMENTS", PaymentsTask, layer="ODS")
default_registry.register("REFUNDS", RefundsTask, layer="ODS")
default_registry.register("COUPON_USAGE", CouponUsageTask, layer="ODS")
default_registry.register("INVENTORY_CHANGE", InventoryChangeTask, layer="ODS")
default_registry.register("TOPUPS", TopupsTask, layer="ODS")
default_registry.register("TABLE_DISCOUNT", TableDiscountTask, layer="ODS")
default_registry.register("ASSISTANT_ABOLISH", AssistantAbolishTask, layer="ODS")
default_registry.register("LEDGER", LedgerTask, layer="ODS")
# ── DWD layer tasks ──────────────────────────────────────────
default_registry.register("TICKET_DWD", TicketDwdTask, layer="DWD")
default_registry.register("PAYMENTS_DWD", PaymentsDwdTask, layer="DWD")
default_registry.register("MEMBERS_DWD", MembersDwdTask, layer="DWD")
default_registry.register("DWD_LOAD_FROM_ODS", DwdLoadTask, layer="DWD")
default_registry.register("DWD_QUALITY_CHECK", DwdQualityTask, requires_db_config=False, layer="DWD", task_type="verification")
# ── Utility tasks (no cursor/run bookkeeping needed) ─────────
default_registry.register("MANUAL_INGEST", ManualIngestTask, requires_db_config=False, task_type="utility")
default_registry.register("INIT_ODS_SCHEMA", InitOdsSchemaTask, requires_db_config=False, task_type="utility")
default_registry.register("INIT_DWD_SCHEMA", InitDwdSchemaTask, requires_db_config=False, task_type="utility")
default_registry.register("INIT_DWS_SCHEMA", InitDwsSchemaTask, requires_db_config=False, task_type="utility")
default_registry.register("ODS_JSON_ARCHIVE", OdsJsonArchiveTask, requires_db_config=False, task_type="utility")
default_registry.register("CHECK_CUTOFF", CheckCutoffTask, requires_db_config=False, task_type="utility")
default_registry.register("SEED_DWS_CONFIG", SeedDwsConfigTask, task_type="utility")
# ── Verification tasks ───────────────────────────────────────
default_registry.register("DATA_INTEGRITY_CHECK", DataIntegrityTask, requires_db_config=False, task_type="verification")
# ── DWS layer business tasks ─────────────────────────────────
default_registry.register("DWS_BUILD_ORDER_SUMMARY", DwsBuildOrderSummaryTask, requires_db_config=False, layer="DWS")
default_registry.register("DWS_ASSISTANT_DAILY", AssistantDailyTask, layer="DWS")
default_registry.register("DWS_ASSISTANT_MONTHLY", AssistantMonthlyTask, layer="DWS")
default_registry.register("DWS_ASSISTANT_CUSTOMER", AssistantCustomerTask, layer="DWS")
default_registry.register("DWS_ASSISTANT_SALARY", AssistantSalaryTask, layer="DWS")
default_registry.register("DWS_ASSISTANT_FINANCE", AssistantFinanceTask, layer="DWS")
default_registry.register("DWS_MEMBER_CONSUMPTION", MemberConsumptionTask, layer="DWS")
default_registry.register("DWS_MEMBER_VISIT", MemberVisitTask, layer="DWS")
default_registry.register("DWS_FINANCE_DAILY", FinanceDailyTask, layer="DWS")
default_registry.register("DWS_FINANCE_RECHARGE", FinanceRechargeTask, layer="DWS")
default_registry.register("DWS_FINANCE_INCOME_STRUCTURE", FinanceIncomeStructureTask, layer="DWS")
default_registry.register("DWS_FINANCE_DISCOUNT_DETAIL", FinanceDiscountDetailTask, layer="DWS")
default_registry.register("DWS_RETENTION_CLEANUP", DwsRetentionCleanupTask, layer="DWS")
default_registry.register("DWS_MV_REFRESH_FINANCE_DAILY", DwsMvRefreshFinanceDailyTask, layer="DWS")
default_registry.register("DWS_MV_REFRESH_ASSISTANT_DAILY", DwsMvRefreshAssistantDailyTask, layer="DWS")
# ── INDEX layer: index/score algorithm tasks ─────────────────
default_registry.register("DWS_RECALL_INDEX", RecallIndexTask, layer="INDEX")
default_registry.register("DWS_WINBACK_INDEX", WinbackIndexTask, requires_db_config=False, layer="INDEX")
default_registry.register("DWS_NEWCONV_INDEX", NewconvIndexTask, requires_db_config=False, layer="INDEX")
default_registry.register("DWS_INTIMACY_INDEX", IntimacyIndexTask, requires_db_config=False, layer="INDEX")
default_registry.register("DWS_ML_MANUAL_IMPORT", MlManualImportTask, requires_db_config=False, layer="INDEX")
default_registry.register("DWS_RELATION_INDEX", RelationIndexTask, requires_db_config=False, layer="INDEX")
# ── ODS layer: generic ODS tasks (generated from ODS_TASK_CLASSES) ──
for code, task_cls in ODS_TASK_CLASSES.items():
    default_registry.register(code, task_cls, layer="ODS")