# -*- coding: utf-8 -*-
"""Helpers for reading, grouping, validating and writing a .env file."""

import logging
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

_logger = logging.getLogger(__name__)

# Matches the two escape sequences that ConfigHelper._format_line emits inside
# double-quoted values: an escaped backslash and an escaped double quote.
_DQ_ESCAPE_RE = re.compile(r'\\(["\\])')

# Known environment variables, organised into groups. Group order and key
# order determine the layout written by ConfigHelper.save_env; "sensitive"
# lists the keys that ConfigHelper.get_grouped_env flags as sensitive.
ENV_GROUPS = {
    "database": {
        "title": "数据库配置",
        "keys": [
            "PG_DSN",
            "PG_HOST",
            "PG_PORT",
            "PG_NAME",
            "PG_USER",
            "PG_PASSWORD",
            "PG_CONNECT_TIMEOUT",
        ],
        "sensitive": ["PG_PASSWORD"],
    },
    "api": {
        "title": "API 配置",
        "keys": [
            "API_BASE",
            "API_TOKEN",
            "FICOO_TOKEN",
            "API_TIMEOUT",
            "API_PAGE_SIZE",
            "API_RETRY_MAX",
        ],
        "sensitive": ["API_TOKEN", "FICOO_TOKEN"],
    },
    "store": {
        "title": "门店配置",
        "keys": ["STORE_ID", "TIMEZONE", "SCHEMA_OLTP", "SCHEMA_ETL"],
        "sensitive": [],
    },
    "paths": {
        "title": "路径配置",
        "keys": [
            "EXPORT_ROOT",
            "LOG_ROOT",
            "FETCH_ROOT",
            "INGEST_SOURCE_DIR",
            "JSON_FETCH_ROOT",
            "JSON_SOURCE_DIR",
        ],
        "sensitive": [],
    },
    "pipeline": {
        "title": "流水线配置",
        "keys": ["PIPELINE_FLOW", "RUN_TASKS", "OVERLAP_SECONDS"],
        "sensitive": [],
    },
    "window": {
        "title": "时间窗口配置",
        "keys": [
            "WINDOW_START",
            "WINDOW_END",
            "WINDOW_BUSY_MIN",
            "WINDOW_IDLE_MIN",
            "IDLE_START",
            "IDLE_END",
        ],
        "sensitive": [],
    },
    "integrity": {
        "title": "数据完整性配置",
        "keys": [
            "INTEGRITY_MODE",
            "INTEGRITY_HISTORY_START",
            "INTEGRITY_HISTORY_END",
            "INTEGRITY_INCLUDE_DIMENSIONS",
            "INTEGRITY_AUTO_CHECK",
            "INTEGRITY_ODS_TASK_CODES",
        ],
        "sensitive": [],
    },
}


class ConfigHelper:
    """Load, group, validate and save the key/value pairs of a .env file."""

    def __init__(self, env_path: Optional[Path] = None):
        """
        Initialise the helper.

        Args:
            env_path: Path of the .env file. When omitted, the path is taken
                from ``app_settings.env_file_path``; if that is empty, the
                ``.env`` file under ``parents[2]`` of this source file is used.
        """
        if env_path is not None:
            self.env_path = Path(env_path)
            return

        # Imported here rather than at module top: it is only needed when no
        # explicit path is supplied.
        from .app_settings import app_settings

        settings_path = app_settings.env_file_path
        if settings_path:
            self.env_path = Path(settings_path)
        else:
            # Fall back to the source directory.
            self.env_path = Path(__file__).resolve().parents[2] / ".env"

    def load_env(self) -> Dict[str, str]:
        """
        Read the .env file into a dictionary.

        A missing file yields an empty dict. Any other read failure is logged
        and also yields an empty dict. Undecodable bytes are ignored.

        Returns:
            Mapping of variable name to (unquoted) value.
        """
        env_vars: Dict[str, str] = {}
        try:
            # "utf-8-sig" drops a leading BOM (as written by some Windows
            # editors), which would otherwise become part of the first key.
            content = self.env_path.read_text(encoding="utf-8-sig", errors="ignore")
        except FileNotFoundError:
            return env_vars
        except (OSError, ValueError) as exc:
            _logger.warning("Could not read %s: %s", self.env_path, exc)
            return env_vars

        for line in content.splitlines():
            parsed = self._parse_line(line)
            if parsed:
                key, value = parsed
                env_vars[key] = value
        return env_vars

    def save_env(self, env_vars: Dict[str, str]) -> bool:
        """
        Write the variables to the .env file, replacing its content.

        Variables belonging to ENV_GROUPS are written group by group under a
        ``# <title>`` header; all remaining variables follow under a final
        header. Groups without any present variable are omitted.

        Args:
            env_vars: Mapping of variable name to value.

        Returns:
            True on success, False if formatting or writing failed.
        """
        try:
            lines: List[str] = []
            written_keys = set()

            # Grouped variables first, in ENV_GROUPS order.
            for group_info in ENV_GROUPS.values():
                group_lines = []
                for key in group_info["keys"]:
                    if key in env_vars:
                        group_lines.append(self._format_line(key, env_vars[key]))
                        written_keys.add(key)
                if group_lines:
                    # The leading "\n" yields a blank line between sections.
                    lines.append(f"\n# {group_info['title']}")
                    lines.extend(group_lines)

            # Then everything that belongs to no group.
            other_lines = [
                self._format_line(key, value)
                for key, value in env_vars.items()
                if key not in written_keys
            ]
            if other_lines:
                lines.append("\n# 其他配置")
                lines.extend(other_lines)

            content = "\n".join(lines).strip() + "\n"
            self.env_path.write_text(content, encoding="utf-8")
            return True
        except Exception:
            # Callers only get a bool back, so keep this a catch-all boundary,
            # but record the cause instead of discarding it.
            _logger.exception("Failed to save %s", self.env_path)
            return False

    def get_grouped_env(self) -> Dict[str, List[Tuple[str, str, bool]]]:
        """
        Load the .env file and arrange its variables by group.

        Every key listed in ENV_GROUPS is returned, with an empty string as
        value when it is absent from the file. Variables found in the file
        that belong to no group are returned under the ``"other"`` key, which
        is only present when there is at least one such variable.

        Returns:
            ``{group_id: [(key, value, is_sensitive), ...]}``
        """
        env_vars = self.load_env()
        result: Dict[str, List[Tuple[str, str, bool]]] = {}
        grouped_keys = set()

        for group_id, group_info in ENV_GROUPS.items():
            sensitive = set(group_info.get("sensitive", []))
            result[group_id] = [
                (key, env_vars.get(key, ""), key in sensitive)
                for key in group_info["keys"]
            ]
            grouped_keys.update(group_info["keys"])

        other_items = [
            (key, value, False)
            for key, value in env_vars.items()
            if key not in grouped_keys
        ]
        if other_items:
            result["other"] = other_items

        return result

    def validate_env(self, env_vars: Dict[str, str]) -> List[str]:
        """
        Check a set of variables for obvious mistakes.

        Only non-empty values are checked: the PG_DSN prefix, the PG_PORT
        range, STORE_ID being an integer, and whether the root directories
        are absolute paths.

        Args:
            env_vars: Mapping of variable name to value.

        Returns:
            List of error messages; empty when nothing was found.
        """
        errors = []

        # PG_DSN must use the postgresql:// URI form.
        pg_dsn = env_vars.get("PG_DSN", "")
        if pg_dsn and not pg_dsn.startswith("postgresql://"):
            errors.append("PG_DSN 应以 'postgresql://' 开头")

        # PG_PORT must be an integer in the valid port range.
        pg_port = env_vars.get("PG_PORT", "")
        if pg_port:
            try:
                port = int(pg_port)
            except ValueError:
                errors.append("PG_PORT 应为数字")
            else:
                if port < 1 or port > 65535:
                    errors.append("PG_PORT 应在 1-65535 范围内")

        # STORE_ID must be an integer.
        store_id = env_vars.get("STORE_ID", "")
        if store_id:
            try:
                int(store_id)
            except ValueError:
                errors.append("STORE_ID 应为数字")

        # Root directories should be absolute paths. Only absoluteness is
        # checked here; the paths are not required to exist.
        for key in ["EXPORT_ROOT", "LOG_ROOT", "FETCH_ROOT"]:
            path = env_vars.get(key, "")
            if path and not os.path.isabs(path):
                errors.append(f"{key} 建议使用绝对路径")

        return errors

    def mask_sensitive(self, value: str, visible_chars: int = 4) -> str:
        """
        Mask a sensitive value for display.

        Args:
            value: Original value.
            visible_chars: Number of leading characters left readable.
                Negative numbers are treated as 0.

        Returns:
            The value with everything after the visible prefix replaced by
            ``*``. Values no longer than ``visible_chars`` are masked
            completely; an empty value yields an empty string.
        """
        if not value:
            return ""
        # A negative count would slice from the end and leak most of the value.
        visible = max(visible_chars, 0)
        if len(value) <= visible:
            return "*" * len(value)
        return value[:visible] + "*" * (len(value) - visible)

    def _parse_line(self, line: str) -> Optional[Tuple[str, str]]:
        """
        Parse one line of a .env file.

        Returns:
            ``(key, value)``, or None for blank lines, comment lines, lines
            without ``=`` and lines whose key is empty.
        """
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            return None

        # Accept shell-style "export KEY=value".
        if stripped.startswith("export "):
            stripped = stripped[7:].strip()

        key, separator, raw_value = stripped.partition("=")
        if not separator:
            return None

        key = key.strip()
        if not key:
            # "=value" carries no variable name; there is nothing to store.
            return None

        return key, self._unquote_value(raw_value)

    def _unquote_value(self, value: str) -> str:
        """
        Strip the inline comment, a trailing comma and surrounding quotes.

        Double-quoted values additionally have ``\\\\`` and ``\\"`` unescaped,
        which reverses the escaping applied by ``_format_line``. Single-quoted
        values and ``r"..."`` / ``r'...'`` values are taken literally.
        """
        value = self._strip_inline_comment(value)
        value = value.rstrip(",").strip()
        if not value:
            return value

        # Plain quoted value.
        if len(value) >= 2 and value[0] in ("'", '"') and value[-1] == value[0]:
            inner = value[1:-1]
            if value[0] == '"':
                # Without this, every save/load cycle would add another layer
                # of backslashes to values that needed quoting.
                inner = _DQ_ESCAPE_RE.sub(r"\1", inner)
            return inner

        # Python-style raw string: r"..." or R'...'.
        if (
            len(value) >= 3
            and value[0] in ("r", "R")
            and value[1] in ("'", '"')
            and value[-1] == value[1]
        ):
            return value[2:-1]

        return value

    def _strip_inline_comment(self, value: str) -> str:
        """
        Cut the value at the first ``#`` that is outside quotes.

        A backslash makes the following character literal, so an escaped
        quote does not open or close a quoted section. Trailing whitespace of
        the result is removed.
        """
        result = []
        in_quote = False
        quote_char = ""
        escape = False

        for ch in value:
            if escape:
                # Character following a backslash: keep it verbatim.
                result.append(ch)
                escape = False
                continue
            if ch == "\\":
                escape = True
                result.append(ch)
                continue
            if ch in ("'", '"'):
                if not in_quote:
                    in_quote = True
                    quote_char = ch
                elif quote_char == ch:
                    # Only the same quote character closes the section.
                    in_quote = False
                    quote_char = ""
                result.append(ch)
                continue
            if ch == "#" and not in_quote:
                break
            result.append(ch)

        return "".join(result).rstrip()

    def _format_line(self, key: str, value: str) -> str:
        """
        Format one ``KEY=value`` line for the .env file.

        Values containing a space, a quote, ``#`` or a line break are wrapped
        in double quotes, with backslashes and double quotes escaped.
        """
        # NOTE: line breaks inside a value are written verbatim; since the
        # file is parsed line by line, such a value does not survive a reload.
        if any(c in value for c in [' ', '"', "'", '#', '\n', '\r']):
            # Escape backslashes first so the backslashes added for the
            # quotes are not doubled.
            escaped = value.replace('\\', '\\\\').replace('"', '\\"')
            return f'{key}="{escaped}"'
        return f"{key}={value}"

    @staticmethod
    def get_group_title(group_id: str) -> str:
        """Return the title of a group, or the generic title for unknown ids."""
        if group_id in ENV_GROUPS:
            return ENV_GROUPS[group_id]["title"]
        return "其他配置"


# Module-level shared instance, using the default .env location.
config_helper = ConfigHelper()