# -*- coding: utf-8 -*-
"""
Base class for index-algorithm tasks.

Features:
- Half-life time-decay function
- Percentile calculation and percentile clipping (winsorize)
- Mapping of raw scores onto a 0-10 display scale
- Loading of algorithm parameters
- Percentile history records (used for EWMA smoothing)

Algorithm:
1. Time decay (half-life model): decay(d; h) = exp(-ln(2) * d / h)
   When d = h the weight has decayed to 0.5; more recent events weigh more.
2. 0-10 mapping pipeline:
   Raw Score -> Winsorize(P5, P95) -> [optional log compression] -> MinMax(0, 10)

Author: ETL team
Created: 2026-02-03
"""

from __future__ import annotations

import math
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple

from ..base_dws_task import BaseDwsTask, TaskContext


# =============================================================================
# Data classes
# =============================================================================

@dataclass
class IndexParameters:
    """Cached set of index-algorithm parameters and the time they were loaded."""
    params: Dict[str, float]
    loaded_at: datetime


@dataclass
class PercentileHistory:
    """One row of percentile history (raw and smoothed P5/P95)."""
    percentile_5: float
    percentile_95: float
    percentile_5_smoothed: float
    percentile_95_smoothed: float
    record_count: int
    calc_time: datetime


# =============================================================================
# Index task base class
# =============================================================================

class BaseIndexTask(BaseDwsTask):
    """
    Base class for index-algorithm tasks.

    Provides functionality shared by index calculations:
    1. Half-life time-decay function
    2. Percentile calculation and clipping
    3. Normalization onto a 0-10 scale
    4. Algorithm parameter loading
    5. Percentile history management (EWMA smoothing)
    """

    # Index type; subclasses are expected to define it.
    INDEX_TYPE: str = ""

    # Parameter cache: the most recently loaded parameter set, the index type
    # it belongs to, and the cache lifetime in seconds.
    _index_params_cache: Optional[IndexParameters] = None
    _index_params_cache_type: Optional[str] = None
    _index_params_ttl: int = 300

    # Default parameters
    DEFAULT_LOOKBACK_DAYS = 60
    DEFAULT_PERCENTILE_LOWER = 5
    DEFAULT_PERCENTILE_UPPER = 95
    DEFAULT_EWMA_ALPHA = 0.2

    # ==========================================================================
    # Abstract methods (to be implemented by subclasses)
    # ==========================================================================

    @abstractmethod
    def get_index_type(self) -> str:
        """Return the index type (RECALL/INTIMACY)."""
        raise NotImplementedError

    # ==========================================================================
    # Time-decay function
    # ==========================================================================

    def decay(self, days: float, halflife: float) -> float:
        """
        Half-life decay function.

        Formula: decay(d; h) = exp(-ln(2) * d / h)
        When d = h the weight has decayed to 0.5; the more recent the event,
        the larger its weight.

        Args:
            days: Age of the event in days. Negative values are treated as 0.
            halflife: Half-life in days; must be > 0.

        Returns:
            The decayed weight, in the range (0, 1].

        Raises:
            ValueError: If ``halflife`` is not positive.

        Examples (values are approximate floating-point results):
            decay(0, 7)   -> 1.0   (today)
            decay(7, 7)   -> 0.5   (one half-life ago)
            decay(14, 7)  -> 0.25  (two half-lives ago)
        """
        if halflife <= 0:
            raise ValueError("半衰期必须大于0")
        if days < 0:
            # Events dated in the future get full weight rather than > 1.
            days = 0
        return math.exp(-math.log(2) * days / halflife)

    # ==========================================================================
    # Percentile calculation
    # ==========================================================================

    def calculate_percentiles(
        self,
        scores: List[float],
        lower: int = 5,
        upper: int = 95
    ) -> Tuple[float, float]:
        """
        Compute the lower and upper percentile values of ``scores``.

        Uses a simple index-based lookup on the sorted list (no interpolation).

        Args:
            scores: List of scores.
            lower: Lower percentile, in percent (default 5).
            upper: Upper percentile, in percent (default 95).

        Returns:
            ``(lower_value, upper_value)``; ``(0.0, 0.0)`` for empty input.
        """
        if not scores:
            return 0.0, 0.0

        sorted_scores = sorted(scores)
        n = len(sorted_scores)

        # Percentile indices, clamped into the valid range [0, n - 1].
        lower_idx = max(0, int(n * lower / 100) - 1)
        upper_idx = min(n - 1, int(n * upper / 100))

        return sorted_scores[lower_idx], sorted_scores[upper_idx]

    def winsorize(self, value: float, lower: float, upper: float) -> float:
        """
        Clip ``value`` into ``[lower, upper]`` (winsorize).

        Args:
            value: Raw value.
            lower: Lower bound (typically the P5 percentile).
            upper: Upper bound (typically the P95 percentile).

        Returns:
            The clipped value.
        """
        return min(max(value, lower), upper)

    # ==========================================================================
    # 0-10 mapping
    # ==========================================================================

    def normalize_to_display(
        self,
        value: float,
        min_val: float,
        max_val: float,
        use_log: bool = False,
        epsilon: float = 1e-6
    ) -> float:
        """
        Normalize a value onto the 0-10 scale.

        Pipeline:
        1. [optional] log compression: y = ln(1 + x), applied to the value
           and to both bounds
        2. Min-max mapping: score = 10 * (y - min) / (max - min)

        Args:
            value: Raw value (expected to be already winsorized).
            min_val: Minimum (typically P5).
            max_val: Maximum (typically P95).
            use_log: Whether to apply log compression.
            epsilon: Threshold below which the range is treated as zero.

        Returns:
            A score in [0, 10]; 5.0 when the range is smaller than ``epsilon``.

        Raises:
            ValueError: If ``use_log`` is true and any of the inputs is <= -1
                (raised by ``math.log1p``).
        """
        if use_log:
            value = math.log1p(value)
            min_val = math.log1p(min_val)
            max_val = math.log1p(max_val)

        # Guard against a zero denominator.
        range_val = max_val - min_val
        if range_val < epsilon:
            return 5.0  # Nearly everyone is identical: return the midpoint.

        score = 10.0 * (value - min_val) / range_val

        # Make sure the result stays within 0-10.
        return max(0.0, min(10.0, score))

    def batch_normalize_to_display(
        self,
        raw_scores: List[Tuple[Any, float]],  # [(entity_id, raw_score), ...]
        use_log: bool = False,
        percentile_lower: int = 5,
        percentile_upper: int = 95,
        use_smoothing: bool = False,
        site_id: Optional[int] = None
    ) -> List[Tuple[Any, float, float]]:
        """
        Normalize a batch of raw scores to display scores.

        Pipeline:
        1. Extract all raw scores
        2. Compute the percentile bounds (optionally EWMA-smoothed)
        3. Winsorize each score
        4. Min-max map onto 0-10

        Args:
            raw_scores: List of ``(entity_id, raw_score)`` tuples.
            use_log: Whether to apply log compression.
            percentile_lower: Lower percentile, in percent.
            percentile_upper: Upper percentile, in percent.
            use_smoothing: Whether to EWMA-smooth the percentile bounds.
            site_id: Site ID; smoothing is only applied when it is not None.

        Returns:
            List of ``(entity_id, raw_score, display_score)`` tuples, with the
            display score rounded to 2 decimals. Empty list for empty input.
        """
        if not raw_scores:
            return []

        # Extract the raw scores.
        scores = [s for _, s in raw_scores]

        # Compute the percentile bounds.
        q_l, q_u = self.calculate_percentiles(scores, percentile_lower, percentile_upper)

        # EWMA smoothing.
        if use_smoothing and site_id is not None:
            q_l, q_u = self._apply_ewma_smoothing(site_id, q_l, q_u)

        # Map each score.
        results = []
        for entity_id, raw_score in raw_scores:
            clipped = self.winsorize(raw_score, q_l, q_u)
            display = self.normalize_to_display(clipped, q_l, q_u, use_log)
            results.append((entity_id, raw_score, round(display, 2)))

        return results

    # ==========================================================================
    # Algorithm parameter loading
    # ==========================================================================

    def load_index_parameters(
        self,
        index_type: Optional[str] = None,
        force_reload: bool = False
    ) -> Dict[str, float]:
        """
        Load the index-algorithm parameters.

        Results are cached for ``_index_params_ttl`` seconds. The cache holds
        one parameter set at a time and is only reused when it was loaded for
        the same ``index_type``; otherwise the parameters are reloaded so that
        one index type never receives another type's parameters.

        Args:
            index_type: Index type (defaults to ``self.get_index_type()``).
            force_reload: Whether to bypass the cache.

        Returns:
            Mapping from parameter name to parameter value.
        """
        if index_type is None:
            index_type = self.get_index_type()

        now = datetime.now(self.tz)

        # Check the cache: it must exist, match the requested type and be fresh.
        cached = self._index_params_cache
        if (
            not force_reload
            and cached is not None
            and self._index_params_cache_type == index_type
            and (now - cached.loaded_at).total_seconds() < self._index_params_ttl
        ):
            return cached.params

        self.logger.debug("加载指数算法参数: %s", index_type)

        sql = """
            SELECT param_name, param_value
            FROM billiards_dws.cfg_index_parameters
            WHERE index_type = %s
              AND effective_from <= CURRENT_DATE
              AND (effective_to IS NULL OR effective_to >= CURRENT_DATE)
            ORDER BY effective_from DESC
        """

        rows = self.db.query(sql, (index_type,))

        # Rows are ordered newest-first, so the first occurrence of each
        # parameter name wins.
        params = {}
        seen = set()
        for row in (rows or []):
            row_dict = dict(row)
            name = row_dict['param_name']
            if name not in seen:
                params[name] = float(row_dict['param_value'])
                seen.add(name)

        self._index_params_cache = IndexParameters(
            params=params,
            loaded_at=now
        )
        self._index_params_cache_type = index_type

        return params

    def get_param(self, name: str, default: float = 0.0) -> float:
        """
        Return a single parameter value for this task's index type.

        Args:
            name: Parameter name.
            default: Value returned when the parameter is not configured.

        Returns:
            The parameter value, or ``default``.
        """
        params = self.load_index_parameters()
        return params.get(name, default)

    # ==========================================================================
    # Percentile history management (EWMA smoothing)
    # ==========================================================================

    def get_last_percentile_history(
        self,
        site_id: int,
        index_type: Optional[str] = None
    ) -> Optional[PercentileHistory]:
        """
        Fetch the most recent percentile-history record.

        Args:
            site_id: Site ID.
            index_type: Index type (defaults to ``self.get_index_type()``).

        Returns:
            A ``PercentileHistory``, or None when no record exists. NULL/falsy
            numeric columns are read as 0.
        """
        if index_type is None:
            index_type = self.get_index_type()

        sql = """
            SELECT
                percentile_5, percentile_95,
                percentile_5_smoothed, percentile_95_smoothed,
                record_count, calc_time
            FROM billiards_dws.dws_index_percentile_history
            WHERE site_id = %s AND index_type = %s
            ORDER BY calc_time DESC
            LIMIT 1
        """

        rows = self.db.query(sql, (site_id, index_type))

        if not rows:
            return None

        row = dict(rows[0])
        return PercentileHistory(
            percentile_5=float(row['percentile_5'] or 0),
            percentile_95=float(row['percentile_95'] or 0),
            percentile_5_smoothed=float(row['percentile_5_smoothed'] or 0),
            percentile_95_smoothed=float(row['percentile_95_smoothed'] or 0),
            record_count=int(row['record_count'] or 0),
            calc_time=row['calc_time']
        )

    def save_percentile_history(
        self,
        site_id: int,
        percentile_5: float,
        percentile_95: float,
        percentile_5_smoothed: float,
        percentile_95_smoothed: float,
        record_count: int,
        min_raw: float,
        max_raw: float,
        avg_raw: float,
        index_type: Optional[str] = None
    ) -> None:
        """
        Insert a percentile-history record and commit.

        Args:
            site_id: Site ID.
            percentile_5: Raw 5th percentile.
            percentile_95: Raw 95th percentile.
            percentile_5_smoothed: Smoothed 5th percentile.
            percentile_95_smoothed: Smoothed 95th percentile.
            record_count: Number of records.
            min_raw: Minimum raw score.
            max_raw: Maximum raw score.
            avg_raw: Average raw score.
            index_type: Index type (defaults to ``self.get_index_type()``).
        """
        if index_type is None:
            index_type = self.get_index_type()

        sql = """
            INSERT INTO billiards_dws.dws_index_percentile_history (
                site_id, index_type, calc_time,
                percentile_5, percentile_95,
                percentile_5_smoothed, percentile_95_smoothed,
                record_count, min_raw_score, max_raw_score, avg_raw_score
            ) VALUES (%s, %s, NOW(), %s, %s, %s, %s, %s, %s, %s, %s)
        """

        # NOTE(review): there is no rollback here if execute() fails; presumably
        # the caller or the db wrapper handles that - confirm.
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (
                site_id, index_type,
                percentile_5, percentile_95,
                percentile_5_smoothed, percentile_95_smoothed,
                record_count, min_raw, max_raw, avg_raw
            ))

        self.db.conn.commit()

    def _apply_ewma_smoothing(
        self,
        site_id: int,
        current_p5: float,
        current_p95: float,
        alpha: Optional[float] = None
    ) -> Tuple[float, float]:
        """
        Apply EWMA smoothing to the percentile bounds.

        Formula: Q_t = (1 - alpha) * Q_{t-1} + alpha * Q_now

        Args:
            site_id: Site ID.
            current_p5: Current 5th percentile.
            current_p95: Current 95th percentile.
            alpha: Smoothing factor; when None it is read from the
                ``ewma_alpha`` parameter, falling back to DEFAULT_EWMA_ALPHA.

        Returns:
            ``(smoothed_p5, smoothed_p95)``; the current values unchanged when
            no history exists for the site.
        """
        if alpha is None:
            alpha = self.get_param('ewma_alpha', self.DEFAULT_EWMA_ALPHA)

        history = self.get_last_percentile_history(site_id)

        if history is None:
            # First calculation: nothing to smooth against.
            return current_p5, current_p95

        smoothed_p5 = (1 - alpha) * history.percentile_5_smoothed + alpha * current_p5
        smoothed_p95 = (1 - alpha) * history.percentile_95_smoothed + alpha * current_p95

        return smoothed_p5, smoothed_p95

    # ==========================================================================
    # Statistics helpers
    # ==========================================================================

    def calculate_median(self, values: List[float]) -> float:
        """Return the median of ``values`` (0.0 for empty input)."""
        if not values:
            return 0.0
        sorted_vals = sorted(values)
        n = len(sorted_vals)
        mid = n // 2
        if n % 2 == 0:
            # Even count: average the two middle elements.
            return (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
        return sorted_vals[mid]

    def calculate_mad(self, values: List[float]) -> float:
        """
        Return the MAD (median absolute deviation) of ``values``.

        MAD = median(|x - median(x)|)
        MAD is a more robust dispersion measure than the standard deviation
        because it is far less sensitive to extreme values. Returns 0.0 for
        empty input.
        """
        if not values:
            return 0.0
        median_val = self.calculate_median(values)
        deviations = [abs(v - median_val) for v in values]
        return self.calculate_median(deviations)

    def safe_log(self, value: float, default: float = 0.0) -> float:
        """Return ln(value), or ``default`` when ``value`` is not positive."""
        if value <= 0:
            return default
        return math.log(value)

    def safe_ln1p(self, value: float) -> float:
        """Return ln(1 + value), or 0.0 when ``value`` <= -1 (outside the domain)."""
        # log1p is undefined at -1 as well as below it, so the guard must be
        # inclusive; otherwise value == -1 raises ValueError.
        if value <= -1:
            return 0.0
        return math.log1p(value)