# -*- coding: utf-8 -*-
"""Main configuration management class."""
|
||
import warnings
from copy import deepcopy
from urllib.parse import quote

from .defaults import DEFAULTS
from .env_parser import load_env_overrides
|
||
|
||
# Translation table from the legacy ``pipeline.flow`` setting to the value
# used by its replacement key ``run.data_source``.
_FLOW_TO_DATA_SOURCE = {
    "FULL": "hybrid",
    "FETCH_ONLY": "online",
    "INGEST_ONLY": "offline",
}
|
||
|
||
class AppConfig:
    """Application configuration manager.

    Wraps a nested configuration ``dict`` and offers dotted-path lookups
    (``cfg.get("db.host")``) as well as plain top-level item access
    (``cfg["db"]``). Use :meth:`load` to build a merged, normalized and
    validated instance.
    """

    def __init__(self, config_dict: dict):
        """Wrap *config_dict* as-is (no copy, no normalization)."""
        self.config = config_dict

    @classmethod
    def load(cls, cli_overrides: dict = None):
        """Build a configuration with precedence DEFAULTS < ENV < CLI.

        Args:
            cli_overrides: Optional nested dict of command-line overrides that
                is deep-merged on top of the defaults/environment layer.

        Returns:
            A new ``AppConfig`` holding the merged, normalized configuration.

        Raises:
            SystemExit: If a value cannot be normalized or a required
                setting is missing.
        """
        # load_env_overrides is defined elsewhere, so whether it copies its
        # argument is not visible here. Copying up front guarantees that the
        # in-place merge and normalization below never mutate DEFAULTS.
        cfg = load_env_overrides(deepcopy(DEFAULTS))

        if cli_overrides:
            cls._deep_merge(cfg, cli_overrides)

        cls._normalize(cfg)
        cls._validate(cfg)

        return cls(cfg)

    @staticmethod
    def _deep_merge(dst, src):
        """Recursively merge *src* into *dst* in place.

        When both sides hold a dict under the same key the merge recurses;
        in every other case the value from *src* replaces the one in *dst*.
        """
        for key, value in src.items():
            if isinstance(value, dict) and isinstance(dst.get(key), dict):
                AppConfig._deep_merge(dst[key], value)
            else:
                dst[key] = value

    @staticmethod
    def _build_dsn(db):
        """Assemble a ``postgresql://`` URI from the ``db`` settings.

        ``user``, ``password`` and ``name`` are percent-encoded so that
        reserved characters (``@``, ``:``, ``/``, ``?``, ``#`` ...) cannot
        break the URI structure. Values made only of unreserved characters
        come out unchanged. ``host`` and ``port`` are inserted verbatim.
        """
        user = quote(str(db["user"]), safe="")
        password = quote(str(db["password"]), safe="")
        name = quote(str(db["name"]), safe="")
        return f"postgresql://{user}:{password}@{db['host']}:{db['port']}/{name}"

    @staticmethod
    def _normalize(cfg):
        """Normalize *cfg* in place.

        Steps, in order:
          * coerce ``app.store_id`` to ``int``;
          * build ``db.dsn`` from its parts when it is empty;
          * coerce ``db.connect_timeout_sec`` to ``int`` (default 5) and clamp
            it to the range 1..20;
          * ensure ``db.session`` exists, default its ``timezone`` to
            ``app.timezone`` and coerce the timeout settings to ``int``;
          * map the legacy ``pipeline.*`` keys onto ``run.*`` / ``io.*``.

        Raises:
            SystemExit: If one of the integer settings cannot be converted.
        """
        # store_id may arrive as a string (possibly padded with whitespace).
        try:
            cfg["app"]["store_id"] = int(str(cfg["app"]["store_id"]).strip())
        except (KeyError, TypeError, ValueError) as exc:
            raise SystemExit("app.store_id 必须为整数") from exc

        db = cfg["db"]

        # Only assemble a DSN when none was supplied explicitly.
        if not db.get("dsn"):
            db["dsn"] = AppConfig._build_dsn(db)

        # connect_timeout: falsy values fall back to 5, result clamped to 1-20 s.
        try:
            timeout_sec = int(db.get("connect_timeout_sec") or 5)
        except (TypeError, ValueError, OverflowError) as exc:
            raise SystemExit("db.connect_timeout_sec 必须为整数") from exc
        db["connect_timeout_sec"] = max(1, min(timeout_sec, 20))

        # Session parameters.
        sess = db.setdefault("session", {})
        sess.setdefault("timezone", cfg["app"]["timezone"])

        for key in ("statement_timeout_ms", "lock_timeout_ms", "idle_in_tx_timeout_ms"):
            if sess.get(key) is not None:
                try:
                    sess[key] = int(sess[key])
                except (TypeError, ValueError, OverflowError) as exc:
                    raise SystemExit(f"db.session.{key} 需为整数毫秒") from exc

        # -- Legacy key -> new key compatibility mapping --
        # ``or {}`` also covers a "pipeline" key that is present but None.
        pipeline = cfg.get("pipeline") or {}
        run = cfg.setdefault("run", {})
        io_cfg = cfg.setdefault("io", {})

        # 1. pipeline.flow -> run.data_source
        # The legacy key only wins while the new key is absent or still holds
        # "hybrid"; a non-hybrid run.data_source is left untouched.
        old_flow = str(pipeline.get("flow", "")).upper()
        if old_flow in _FLOW_TO_DATA_SOURCE:
            mapped = _FLOW_TO_DATA_SOURCE[old_flow]
            if run.get("data_source", "hybrid") == "hybrid" and mapped != "hybrid":
                run["data_source"] = mapped
                warnings.warn(
                    f"配置键 pipeline.flow={old_flow} 已弃用,"
                    f"已映射为 run.data_source={mapped}",
                    DeprecationWarning,
                    stacklevel=2,
                )

        # 2. pipeline.fetch_root -> io.fetch_root (new key takes precedence)
        if pipeline.get("fetch_root") and not io_cfg.get("fetch_root"):
            io_cfg["fetch_root"] = pipeline["fetch_root"]

        # 3. pipeline.ingest_source_dir -> io.ingest_source_dir (new key takes precedence)
        if pipeline.get("ingest_source_dir") and not io_cfg.get("ingest_source_dir"):
            io_cfg["ingest_source_dir"] = pipeline["ingest_source_dir"]

    @staticmethod
    def _validate(cfg):
        """Check required settings.

        Raises:
            SystemExit: Listing every required key whose value is falsy
                (currently only ``app.store_id``, so a store id of 0 is
                rejected as well).
        """
        missing = []
        if not cfg["app"]["store_id"]:
            missing.append("app.store_id")
        if missing:
            raise SystemExit("缺少必需配置: " + ", ".join(missing))

    def get(self, key: str, default=None):
        """Look up a value by dotted path, e.g. ``"db.session.timezone"``.

        Returns *default* when a path segment is missing, when an
        intermediate value is not a dict, or when the stored value is
        ``None``.
        """
        node = self.config
        for part in key.split("."):
            if not isinstance(node, dict):
                return default
            node = node.get(part)
        return default if node is None else node

    def __getitem__(self, key):
        """Return the top-level entry *key*; raises ``KeyError`` if absent."""
        return self.config[key]
|