# -*- coding: utf-8 -*-
"""Unified pipeline configuration dataclass.

Values are resolved through a three-level fallback:

    pipeline.<task_code>.*  ->  pipeline.*  ->  hard-coded field default
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .settings import AppConfig


@dataclass(frozen=True)
class PipelineConfig:
    """Unified pipeline configuration: global defaults plus per-task overrides.

    Instances are immutable (``frozen=True``) and validated on construction.

    Raises:
        ValueError: if ``workers``, ``queue_size`` or ``batch_size`` is less
            than 1, or if ``rate_min`` is greater than ``rate_max``.
    """

    workers: int = 2  # ProcessingPool worker count
    queue_size: int = 100  # processing queue capacity
    batch_size: int = 100  # WriteWorker batch-write threshold
    batch_timeout: float = 5.0  # WriteWorker wait timeout (seconds)
    rate_min: float = 0.1  # RateLimiter minimum interval (seconds)
    rate_max: float = 2.0  # RateLimiter maximum interval (seconds)
    max_consecutive_failures: int = 10  # consecutive-failure abort threshold

    def __post_init__(self) -> None:
        """Validate field values; raise ``ValueError`` on the first violation."""
        if self.workers < 1:
            raise ValueError(f"workers 必须 >= 1,当前值: {self.workers}")
        if self.queue_size < 1:
            raise ValueError(f"queue_size 必须 >= 1,当前值: {self.queue_size}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size 必须 >= 1,当前值: {self.batch_size}")
        if self.rate_min > self.rate_max:
            raise ValueError(
                f"rate_min({self.rate_min}) 不能大于 rate_max({self.rate_max})"
            )

    @classmethod
    def from_app_config(
        cls,
        config: AppConfig,
        task_code: str | None = None,
    ) -> PipelineConfig:
        """Build a config from ``config``, honouring per-task overrides.

        Lookup priority for every field ``<key>``:

        1. ``pipeline.<task_code>.<key>`` (task level; only consulted when
           ``task_code`` is non-empty; ``task_code`` is lower-cased)
        2. ``pipeline.<key>`` (global level)
        3. the field's hard-coded default declared on the class

        Args:
            config: object exposing ``get(key)``; a result of ``None`` is
                treated as "not configured".
            task_code: optional task identifier selecting the override
                namespace.

        Returns:
            A validated ``PipelineConfig``.

        Raises:
            ValueError: if a configured value cannot be converted to the
                field's type, or if the resulting values fail validation.
        """
        # Most specific namespace first; the task-level prefix is only
        # present when a task code was supplied.
        prefixes = []
        if task_code:
            prefixes.append(f"pipeline.{task_code.lower()}.")
        prefixes.append("pipeline.")

        def _resolve(key: str) -> int | float:
            # Dataclass field defaults stay readable as class attributes, so
            # the fallback can never drift from the declared field default.
            default = getattr(cls, key)
            for prefix in prefixes:
                raw = config.get(f"{prefix}{key}")
                if raw is not None:
                    # Coerce to the default's type (e.g. "4" -> 4).
                    return type(default)(raw)
            return default

        return cls(
            workers=_resolve("workers"),
            queue_size=_resolve("queue_size"),
            batch_size=_resolve("batch_size"),
            batch_timeout=_resolve("batch_timeout"),
            rate_min=_resolve("rate_min"),
            rate_max=_resolve("rate_max"),
            max_consecutive_failures=_resolve("max_consecutive_failures"),
        )