# -*- coding: utf-8 -*-
"""
Base class for index-algorithm tasks.

Features:
- Half-life time-decay function
- Percentile calculation and percentile clipping (winsorizing)
- Mapping of scores onto a 0-10 scale
- Loading of algorithm parameters
- Percentile history records (used for EWMA smoothing)

Algorithm overview:
1. Time-decay function (half-life model): decay(d; h) = exp(-ln(2) * d / h)
   When d = h the weight has decayed to 0.5; the more recent the event,
   the larger its weight.

2. 0-10 mapping pipeline:
   Raw Score → Winsorize(P5, P95) → [optional log/asinh compression] → MinMax(0, 10)

Author: ETL team
Created: 2026-02-03
"""

from __future__ import annotations

import math
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple

from ..base_dws_task import BaseDwsTask, TaskContext


# =============================================================================
# Data class definitions
# =============================================================================

@dataclass
class IndexParameters:
    """Algorithm parameters loaded for one index type, plus their load time."""

    # Parameter name -> numeric parameter value.
    params: Dict[str, float]
    # Moment the parameters were loaded; compared against the cache TTL.
    loaded_at: datetime


@dataclass
class PercentileHistory:
    """One stored percentile-history record (raw and smoothed percentiles)."""

    # Raw 5th / 95th percentile of the scores at calculation time.
    percentile_5: float
    percentile_95: float
    # Smoothed counterparts; these feed the next EWMA smoothing step.
    percentile_5_smoothed: float
    percentile_95_smoothed: float
    # Number of records stored with the percentiles.
    record_count: int
    # Timestamp of the calculation that produced this record.
    calc_time: datetime


# =============================================================================
# Index task base class
# =============================================================================

class BaseIndexTask(BaseDwsTask):
    """
    Base class for index-algorithm tasks.

    Shared building blocks for index calculation:
    1. Half-life time-decay function
    2. Percentile calculation and clipping
    3. Normalization onto a 0-10 scale
    4. Loading of algorithm parameters
    5. Percentile history management (EWMA smoothing)
    """

    # Index type that subclasses are expected to define.
    INDEX_TYPE: str = ""

    # Parameter cache TTL, in seconds.
    _index_params_ttl: int = 300

    # Default parameters.
    DEFAULT_LOOKBACK_DAYS = 60
    DEFAULT_PERCENTILE_LOWER = 5
    DEFAULT_PERCENTILE_UPPER = 95
    DEFAULT_EWMA_ALPHA = 0.2

    def __init__(self, config, db_connection, api_client, logger):
        """Initialise the base task and create an empty parameter cache."""
        super().__init__(config, db_connection, api_client, logger)
        # Parameter cache, isolated per index type so that a single task
        # handling several indexes never mixes up their parameters.
        self._index_params_cache_by_type: Dict[str, IndexParameters] = {}

    # ==========================================================================
    # Abstract methods (implemented by subclasses)
    # ==========================================================================

    @abstractmethod
    def get_index_type(self) -> str:
        """Return the index type handled by this task (RECALL/INTIMACY)."""
        raise NotImplementedError

    # ==========================================================================
    # Time-decay function
    # ==========================================================================

    def decay(self, days: float, halflife: float) -> float:
        """
        Half-life decay weight.

        Formula: decay(d; h) = exp(-ln(2) * d / h)

        When d == h the weight has decayed to 0.5; the more recent the
        event, the larger its weight.

        Args:
            days: Age of the event in days. Negative values are treated as 0.
            halflife: Half-life in days; must be greater than 0.

        Returns:
            The decayed weight, nominally in (0, 1].

        Raises:
            ValueError: If ``halflife`` is not greater than 0.

        Examples (``task`` being any instance; values are approximate):
            task.decay(0, 7)   -> 1.0   (today)
            task.decay(7, 7)   -> 0.5   (one half-life ago)
            task.decay(14, 7)  -> 0.25  (two half-lives ago)
        """
        if halflife <= 0:
            raise ValueError("半衰期必须大于0")
        if days < 0:
            # Future-dated events are weighted like events happening today.
            days = 0
        return math.exp(-math.log(2) * days / halflife)

    # ==========================================================================
    # Percentile calculation
    # ==========================================================================

    def calculate_percentiles(
        self,
        scores: List[float],
        lower: int = 5,
        upper: int = 95
    ) -> Tuple[float, float]:
        """
        Compute the lower and upper percentile values of a score list.

        The values are picked directly from the sorted list (no
        interpolation); the indexes are clamped to the valid range.

        Args:
            scores: List of scores.
            lower: Lower percentile, in percent (default 5).
            upper: Upper percentile, in percent (default 95).

        Returns:
            ``(lower_value, upper_value)``; ``(0.0, 0.0)`` for an empty list.
        """
        if not scores:
            return 0.0, 0.0

        sorted_scores = sorted(scores)
        n = len(sorted_scores)

        # Percentile positions inside the sorted list.
        lower_idx = max(0, int(n * lower / 100) - 1)
        upper_idx = min(n - 1, int(n * upper / 100))

        return sorted_scores[lower_idx], sorted_scores[upper_idx]

    def winsorize(self, value: float, lower: float, upper: float) -> float:
        """
        Clip a value into ``[lower, upper]`` (winsorizing).

        Args:
            value: Raw value.
            lower: Lower bound (the P5 percentile).
            upper: Upper bound (the P95 percentile).

        Returns:
            The clipped value.
        """
        return min(max(value, lower), upper)

    # ==========================================================================
    # 0-10 mapping
    # ==========================================================================

    def normalize_to_display(
        self,
        value: float,
        min_val: float,
        max_val: float,
        use_log: bool = False,
        compression: Optional[str] = None,
        epsilon: float = 1e-6
    ) -> float:
        """
        Normalize a value onto the 0-10 scale.

        Pipeline:
        1. [optional] compression: y = ln(1 + x) or y = asinh(x), applied to
           the value and to both bounds
        2. MinMax mapping: score = 10 * (y - min) / (max - min)

        Args:
            value: Raw value (already winsorized).
            min_val: Minimum (usually P5).
            max_val: Maximum (usually P95).
            use_log: Whether to use log1p compression (kept for compatibility
                with historical parameters).
            compression: Compression mode (none/log1p/asinh); takes
                precedence over ``use_log``.
            epsilon: Small threshold guarding against division by zero.

        Returns:
            Score clamped to the range 0-10. Returns 5.0 when
            ``max - min`` (after compression) is smaller than ``epsilon``.

        Raises:
            ValueError: From ``math.log1p`` when log1p compression is active
                and ``value``, ``min_val`` or ``max_val`` is <= -1.
        """
        compression_mode = self._resolve_compression(compression, use_log)
        if compression_mode == "log1p":
            value = math.log1p(value)
            min_val = math.log1p(min_val)
            max_val = math.log1p(max_val)
        elif compression_mode == "asinh":
            value = math.asinh(value)
            min_val = math.asinh(min_val)
            max_val = math.asinh(max_val)

        # Guard against a zero (or negative) denominator.
        range_val = max_val - min_val
        if range_val < epsilon:
            return 5.0  # nearly everyone identical: return the midpoint

        score = 10.0 * (value - min_val) / range_val

        # Make sure the result stays within 0-10.
        return max(0.0, min(10.0, score))

    def batch_normalize_to_display(
        self,
        raw_scores: List[Tuple[Any, float]],  # [(entity_id, raw_score), ...]
        use_log: bool = False,
        compression: Optional[str] = None,
        percentile_lower: int = 5,
        percentile_upper: int = 95,
        use_smoothing: bool = False,
        site_id: Optional[int] = None,
        index_type: Optional[str] = None,
    ) -> List[Tuple[Any, float, float]]:
        """
        Normalize a batch of raw scores to display scores.

        Steps:
        1. Extract all raw scores
        2. Compute the percentiles (optionally EWMA-smoothed)
        3. Winsorize each score
        4. MinMax-map onto 0-10

        Args:
            raw_scores: List of ``(entity_id, raw_score)`` tuples.
            use_log: Whether to use log1p compression (kept for compatibility
                with historical parameters).
            compression: Compression mode (none/log1p/asinh); takes
                precedence over ``use_log``.
            percentile_lower: Lower percentile, in percent.
            percentile_upper: Upper percentile, in percent.
            use_smoothing: Whether to EWMA-smooth the percentiles. Smoothing
                is only applied when ``site_id`` is also given.
            site_id: Site (store) ID, required for smoothing.
            index_type: Index type, used to isolate percentile history when
                smoothing.

        Returns:
            List of ``(entity_id, raw_score, display_score)`` tuples, with the
            display score rounded to 2 decimals. Empty input gives ``[]``.
        """
        if not raw_scores:
            return []

        # Extract the raw scores.
        scores = [s for _, s in raw_scores]

        # Compute the percentiles.
        q_l, q_u = self.calculate_percentiles(scores, percentile_lower, percentile_upper)

        # EWMA smoothing.
        if use_smoothing and site_id is not None:
            q_l, q_u = self._apply_ewma_smoothing(
                site_id=site_id,
                current_p5=q_l,
                current_p95=q_u,
                index_type=index_type,
            )

        # Mapping. The compression mode is resolved once up front, so an
        # unknown mode is reported a single time rather than once per entity.
        results = []
        compression_mode = self._resolve_compression(compression, use_log)
        for entity_id, raw_score in raw_scores:
            clipped = self.winsorize(raw_score, q_l, q_u)
            display = self.normalize_to_display(
                clipped,
                q_l,
                q_u,
                compression=compression_mode,
            )
            results.append((entity_id, raw_score, round(display, 2)))

        return results

    # ==========================================================================
    # Algorithm parameter loading
    # ==========================================================================

    def load_index_parameters(
        self,
        index_type: Optional[str] = None,
        force_reload: bool = False
    ) -> Dict[str, float]:
        """
        Load the algorithm parameters of an index type.

        Results are cached per index type for ``_index_params_ttl`` seconds.

        Args:
            index_type: Index type (defaults to ``self.get_index_type()``).
            force_reload: Whether to bypass the cache and reload.

        Returns:
            Dict mapping parameter name to parameter value.
        """
        if index_type is None:
            index_type = self.get_index_type()

        now = datetime.now(self.tz)
        # Key the cache by exactly the string used in the query below, so two
        # spellings that the query treats differently never share one entry.
        cache_key = str(index_type)
        cache_item = self._index_params_cache_by_type.get(cache_key)

        # Serve from the cache while the entry is still fresh.
        if (
            not force_reload
            and cache_item is not None
            and (now - cache_item.loaded_at).total_seconds() < self._index_params_ttl
        ):
            return cache_item.params

        self.logger.debug("加载指数算法参数: %s", index_type)

        sql = """
            SELECT param_name, param_value
            FROM billiards_dws.cfg_index_parameters
            WHERE index_type = %s
              AND effective_from <= CURRENT_DATE
              AND (effective_to IS NULL OR effective_to >= CURRENT_DATE)
            ORDER BY effective_from DESC
        """

        rows = self.db.query(sql, (index_type,))

        params: Dict[str, float] = {}
        for row in (rows or []):
            row_dict = dict(row)
            name = row_dict['param_name']
            # Rows are ordered by effective_from descending, so the first
            # value seen for a name is the most recently effective one.
            if name not in params:
                params[name] = float(row_dict['param_value'])

        self._index_params_cache_by_type[cache_key] = IndexParameters(
            params=params,
            loaded_at=now
        )

        return params

    def get_param(
        self,
        name: str,
        default: float = 0.0,
        index_type: Optional[str] = None,
    ) -> float:
        """
        Get a single parameter value.

        Args:
            name: Parameter name.
            default: Value returned when the parameter is not defined.
            index_type: Index type (defaults to ``self.get_index_type()``).

        Returns:
            The parameter value, or ``default``.
        """
        params = self.load_index_parameters(index_type=index_type)
        return params.get(name, default)

    # ==========================================================================
    # Percentile history management (EWMA smoothing)
    # ==========================================================================

    def get_last_percentile_history(
        self,
        site_id: int,
        index_type: Optional[str] = None
    ) -> Optional[PercentileHistory]:
        """
        Fetch the most recent percentile-history record.

        Args:
            site_id: Site (store) ID.
            index_type: Index type (defaults to ``self.get_index_type()``).

        Returns:
            A ``PercentileHistory``, or ``None`` when no record exists.
            Falsy numeric columns (e.g. NULL) are read as 0.
        """
        if index_type is None:
            index_type = self.get_index_type()

        sql = """
            SELECT
                percentile_5, percentile_95,
                percentile_5_smoothed, percentile_95_smoothed,
                record_count, calc_time
            FROM billiards_dws.dws_index_percentile_history
            WHERE site_id = %s AND index_type = %s
            ORDER BY calc_time DESC
            LIMIT 1
        """

        rows = self.db.query(sql, (site_id, index_type))

        if not rows:
            return None

        row = dict(rows[0])
        return PercentileHistory(
            percentile_5=float(row['percentile_5'] or 0),
            percentile_95=float(row['percentile_95'] or 0),
            percentile_5_smoothed=float(row['percentile_5_smoothed'] or 0),
            percentile_95_smoothed=float(row['percentile_95_smoothed'] or 0),
            record_count=int(row['record_count'] or 0),
            calc_time=row['calc_time']
        )

    def save_percentile_history(
        self,
        site_id: int,
        percentile_5: float,
        percentile_95: float,
        percentile_5_smoothed: float,
        percentile_95_smoothed: float,
        record_count: int,
        min_raw: float,
        max_raw: float,
        avg_raw: float,
        index_type: Optional[str] = None
    ) -> None:
        """
        Insert a percentile-history record and commit.

        Args:
            site_id: Site (store) ID.
            percentile_5: Raw 5th percentile.
            percentile_95: Raw 95th percentile.
            percentile_5_smoothed: Smoothed 5th percentile.
            percentile_95_smoothed: Smoothed 95th percentile.
            record_count: Number of records.
            min_raw: Minimum raw score.
            max_raw: Maximum raw score.
            avg_raw: Average raw score.
            index_type: Index type (defaults to ``self.get_index_type()``).
        """
        if index_type is None:
            index_type = self.get_index_type()

        # calc_time is filled in by the database via NOW().
        sql = """
            INSERT INTO billiards_dws.dws_index_percentile_history (
                site_id, index_type, calc_time,
                percentile_5, percentile_95,
                percentile_5_smoothed, percentile_95_smoothed,
                record_count, min_raw_score, max_raw_score, avg_raw_score
            ) VALUES (%s, %s, NOW(), %s, %s, %s, %s, %s, %s, %s, %s)
        """

        with self.db.conn.cursor() as cur:
            cur.execute(sql, (
                site_id, index_type,
                percentile_5, percentile_95,
                percentile_5_smoothed, percentile_95_smoothed,
                record_count, min_raw, max_raw, avg_raw
            ))
        self.db.conn.commit()

    def _apply_ewma_smoothing(
        self,
        site_id: int,
        current_p5: float,
        current_p95: float,
        alpha: Optional[float] = None,
        index_type: Optional[str] = None,
    ) -> Tuple[float, float]:
        """
        Apply EWMA smoothing to the percentiles.

        Formula: Q_t = (1 - α) * Q_{t-1} + α * Q_now

        Args:
            site_id: Site (store) ID.
            current_p5: Current 5th percentile.
            current_p95: Current 95th percentile.
            alpha: Smoothing factor; when omitted, the ``ewma_alpha``
                parameter is used (default ``DEFAULT_EWMA_ALPHA``).
            index_type: Index type, used to isolate parameters and history.

        Returns:
            ``(smoothed_p5, smoothed_p95)``. When no history exists the
            current values are returned unchanged.
        """
        if index_type is None:
            index_type = self.get_index_type()

        if alpha is None:
            alpha = self.get_param(
                'ewma_alpha',
                self.DEFAULT_EWMA_ALPHA,
                index_type=index_type,
            )

        history = self.get_last_percentile_history(site_id, index_type=index_type)

        if history is None:
            # First calculation: nothing to smooth against.
            return current_p5, current_p95

        smoothed_p5 = (1 - alpha) * history.percentile_5_smoothed + alpha * current_p5
        smoothed_p95 = (1 - alpha) * history.percentile_95_smoothed + alpha * current_p95

        return smoothed_p5, smoothed_p95

    # ==========================================================================
    # Statistical helpers
    # ==========================================================================

    def calculate_median(self, values: List[float]) -> float:
        """Return the median of ``values`` (0.0 for an empty list)."""
        if not values:
            return 0.0
        sorted_vals = sorted(values)
        n = len(sorted_vals)
        mid = n // 2
        if n % 2 == 0:
            # Even count: average the two middle elements.
            return (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
        return sorted_vals[mid]

    def calculate_mad(self, values: List[float]) -> float:
        """
        Compute the MAD (median absolute deviation).

        MAD = median(|x - median(x)|)

        MAD is a more robust measure of spread than the standard deviation
        because it is far less sensitive to extreme values.
        Returns 0.0 for an empty list.
        """
        if not values:
            return 0.0
        median_val = self.calculate_median(values)
        deviations = [abs(v - median_val) for v in values]
        return self.calculate_median(deviations)

    def safe_log(self, value: float, default: float = 0.0) -> float:
        """Natural log that returns ``default`` for non-positive input."""
        if value <= 0:
            return default
        return math.log(value)

    def safe_ln1p(self, value: float) -> float:
        """ln(1 + x) that returns 0.0 when x is outside the domain (x <= -1)."""
        # math.log1p raises ValueError at exactly -1 as well, so the guard
        # must include the boundary.
        if value <= -1:
            return 0.0
        return math.log1p(value)

    def _resolve_compression(self, compression: Optional[str], use_log: bool) -> str:
        """
        Normalize the compression mode to one of ``none``/``log1p``/``asinh``.

        ``compression`` takes precedence; when it is ``None`` the legacy
        ``use_log`` flag decides between ``log1p`` and ``none``. Unknown
        modes are logged (when a logger is available) and fall back to
        ``none``.
        """
        if compression is None:
            return "log1p" if use_log else "none"
        compression_key = str(compression).strip().lower()
        if compression_key in ("none", "log1p", "asinh"):
            return compression_key
        if hasattr(self, "logger"):
            self.logger.warning("未知压缩方式: %s,已降级为 none", compression)
        return "none"