init: 项目初始提交 - NeoZQYY Monorepo 完整代码

This commit is contained in:
Neo
2026-02-15 14:58:14 +08:00
commit ded6dfb9d8
769 changed files with 182616 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 移除 RecallIndexTask/IntimacyIndexTask 导出,仅保留 WBI/NCI/ML/Relation
"""
指数算法任务模块
包含:
- WinbackIndexTask: 老客挽回指数 (WBI)
- NewconvIndexTask: 新客转化指数 (NCI)
- MlManualImportTask: ML 人工台账导入任务
- RelationIndexTask: 关系指数计算任务RS/OS/MS/ML
"""
from .winback_index_task import WinbackIndexTask
from .newconv_index_task import NewconvIndexTask
from .ml_manual_import_task import MlManualImportTask
from .relation_index_task import RelationIndexTask
__all__ = [
'WinbackIndexTask',
'NewconvIndexTask',
'MlManualImportTask',
'RelationIndexTask',
]

View File

@@ -0,0 +1,572 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 更新 docstring移除 RECALL/INTIMACY 引用反映当前指数体系WBI/NCI/RS/OS/MS/ML
"""
指数算法任务基类
功能说明:
- 提供半衰期时间衰减函数
- 提供分位数计算和分位截断
- 提供0-10映射方法
- 提供算法参数加载
- 提供分位点历史记录用于EWMA平滑
算法原理:
1. 时间衰减函数半衰期模型decay(d; h) = exp(-ln(2) * d / h)
当 d=h 时权重衰减到 0.5,越近权重越大
2. 0-10映射流程
Raw Score → Winsorize(P5, P95) → [可选Log/asinh压缩] → MinMax(0, 10)
作者ETL团队
创建日期2026-02-03
"""
from __future__ import annotations
import math
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from ..base_dws_task import BaseDwsTask, TaskContext
# =============================================================================
# 数据类定义
# =============================================================================
@dataclass
class IndexParameters:
"""指数算法参数数据类"""
params: Dict[str, float]
loaded_at: datetime
@dataclass
class PercentileHistory:
"""分位点历史记录"""
percentile_5: float
percentile_95: float
percentile_5_smoothed: float
percentile_95_smoothed: float
record_count: int
calc_time: datetime
# =============================================================================
# 指数任务基类
# =============================================================================
class BaseIndexTask(BaseDwsTask):
"""
指数算法任务基类
提供指数计算通用功能:
1. 半衰期时间衰减函数
2. 分位数计算与截断
3. 0-10归一化映射
4. 算法参数加载
5. 分位点历史管理EWMA平滑
"""
# 子类需要定义的指数类型
INDEX_TYPE: str = ""
# 参数缓存TTL
_index_params_ttl: int = 300
def __init__(self, config, db_connection, api_client, logger):
super().__init__(config, db_connection, api_client, logger)
# 参数缓存:按 index_type 隔离,避免单任务多指数串参
self._index_params_cache_by_type: Dict[str, IndexParameters] = {}
# 默认参数
DEFAULT_LOOKBACK_DAYS = 60
DEFAULT_PERCENTILE_LOWER = 5
DEFAULT_PERCENTILE_UPPER = 95
DEFAULT_EWMA_ALPHA = 0.2
# ==========================================================================
# 抽象方法(子类需实现)
# ==========================================================================
@abstractmethod
def get_index_type(self) -> str:
"""获取指数类型(如 WBI/NCI/RS/OS/MS/ML"""
raise NotImplementedError
# ==========================================================================
# 时间衰减函数
# ==========================================================================
def decay(self, days: float, halflife: float) -> float:
"""
半衰期衰减函数
公式: decay(d; h) = exp(-ln(2) * d / h)
解释:当 d=h 时权重衰减到 0.5;越近权重越大,符合"近期更重要"的直觉
Args:
days: 事件距今天数 (d >= 0)
halflife: 半衰期 (h > 0),单位:天
Returns:
衰减后的权重,范围 (0, 1]
Examples:
>>> decay(0, 7) # 今天,权重=1.0
1.0
>>> decay(7, 7) # 7天前半衰期=7权重=0.5
0.5
>>> decay(14, 7) # 14天前权重=0.25
0.25
"""
if halflife <= 0:
raise ValueError("半衰期必须大于0")
if days < 0:
days = 0
return math.exp(-math.log(2) * days / halflife)
# ==========================================================================
# 分位数计算
# ==========================================================================
def calculate_percentiles(
self,
scores: List[float],
lower: int = 5,
upper: int = 95
) -> Tuple[float, float]:
"""
计算分位点
Args:
scores: 分数列表
lower: 下分位点百分比默认5
upper: 上分位点百分比默认95
Returns:
(下分位值, 上分位值) 元组
"""
if not scores:
return 0.0, 0.0
sorted_scores = sorted(scores)
n = len(sorted_scores)
# 计算分位点索引
lower_idx = max(0, int(n * lower / 100) - 1)
upper_idx = min(n - 1, int(n * upper / 100))
return sorted_scores[lower_idx], sorted_scores[upper_idx]
def winsorize(self, value: float, lower: float, upper: float) -> float:
"""
分位截断Winsorize
将值限制在 [lower, upper] 范围内
Args:
value: 原始值
lower: 下限P5分位
upper: 上限P95分位
Returns:
截断后的值
"""
return min(max(value, lower), upper)
# ==========================================================================
# 0-10映射
# ==========================================================================
def normalize_to_display(
self,
value: float,
min_val: float,
max_val: float,
use_log: bool = False,
compression: Optional[str] = None,
epsilon: float = 1e-6
) -> float:
"""
归一化到0-10分
映射流程:
1. [可选] 压缩y = ln(1 + x) / asinh(x)
2. MinMax映射score = 10 * (y - min) / (max - min)
Args:
value: 原始值已Winsorize
min_val: 最小值通常为P5
max_val: 最大值通常为P95
use_log: 是否使用log1p压缩兼容历史参数
compression: 压缩方式none/log1p/asinh优先级高于use_log
epsilon: 防除零小量
Returns:
0-10范围的分数
"""
compression_mode = self._resolve_compression(compression, use_log)
if compression_mode == "log1p":
value = math.log1p(value)
min_val = math.log1p(min_val)
max_val = math.log1p(max_val)
elif compression_mode == "asinh":
value = math.asinh(value)
min_val = math.asinh(min_val)
max_val = math.asinh(max_val)
# 防止分母为0
range_val = max_val - min_val
if range_val < epsilon:
return 5.0 # 几乎全员相同时返回中间值
score = 10.0 * (value - min_val) / range_val
# 确保在0-10范围内
return max(0.0, min(10.0, score))
def batch_normalize_to_display(
self,
raw_scores: List[Tuple[Any, float]], # [(entity_id, raw_score), ...]
use_log: bool = False,
compression: Optional[str] = None,
percentile_lower: int = 5,
percentile_upper: int = 95,
use_smoothing: bool = False,
site_id: Optional[int] = None,
index_type: Optional[str] = None,
) -> List[Tuple[Any, float, float]]:
"""
批量归一化Raw Score到Display Score
流程:
1. 提取所有raw_score
2. 计算分位点可选EWMA平滑
3. Winsorize截断
4. MinMax映射到0-10
Args:
raw_scores: (entity_id, raw_score) 元组列表
use_log: 是否使用log1p压缩兼容历史参数
compression: 压缩方式none/log1p/asinh优先级高于use_log
percentile_lower: 下分位百分比
percentile_upper: 上分位百分比
use_smoothing: 是否使用EWMA平滑分位点
site_id: 门店ID平滑时需要
index_type: 指数类型(平滑时用于分位历史隔离)
Returns:
(entity_id, raw_score, display_score) 元组列表
"""
if not raw_scores:
return []
# 提取raw_score
scores = [s for _, s in raw_scores]
# 计算分位点
q_l, q_u = self.calculate_percentiles(scores, percentile_lower, percentile_upper)
# EWMA平滑
if use_smoothing and site_id is not None:
q_l, q_u = self._apply_ewma_smoothing(
site_id=site_id,
current_p5=q_l,
current_p95=q_u,
index_type=index_type,
)
# 映射
results = []
compression_mode = self._resolve_compression(compression, use_log)
for entity_id, raw_score in raw_scores:
clipped = self.winsorize(raw_score, q_l, q_u)
display = self.normalize_to_display(
clipped,
q_l,
q_u,
compression=compression_mode,
)
results.append((entity_id, raw_score, round(display, 2)))
return results
# ==========================================================================
# 算法参数加载
# ==========================================================================
def load_index_parameters(
self,
index_type: Optional[str] = None,
force_reload: bool = False
) -> Dict[str, float]:
"""
加载指数算法参数
Args:
index_type: 指数类型默认使用子类定义的INDEX_TYPE
force_reload: 是否强制重新加载
Returns:
参数名到参数值的字典
"""
if index_type is None:
index_type = self.get_index_type()
now = datetime.now(self.tz)
cache_key = str(index_type).upper()
cache_item = self._index_params_cache_by_type.get(cache_key)
# 检查缓存
if (
not force_reload
and cache_item is not None
and (now - cache_item.loaded_at).total_seconds() < self._index_params_ttl
):
return cache_item.params
self.logger.debug("加载指数算法参数: %s", index_type)
sql = """
SELECT param_name, param_value
FROM billiards_dws.cfg_index_parameters
WHERE index_type = %s
AND effective_from <= CURRENT_DATE
AND (effective_to IS NULL OR effective_to >= CURRENT_DATE)
ORDER BY effective_from DESC
"""
rows = self.db.query(sql, (index_type,))
params = {}
seen = set()
for row in (rows or []):
row_dict = dict(row)
name = row_dict['param_name']
if name not in seen:
params[name] = float(row_dict['param_value'])
seen.add(name)
self._index_params_cache_by_type[cache_key] = IndexParameters(
params=params,
loaded_at=now
)
return params
def get_param(
self,
name: str,
default: float = 0.0,
index_type: Optional[str] = None,
) -> float:
"""
获取单个参数值
Args:
name: 参数名
default: 默认值
Returns:
参数值
"""
params = self.load_index_parameters(index_type=index_type)
return params.get(name, default)
# ==========================================================================
# 分位点历史管理EWMA平滑
# ==========================================================================
def get_last_percentile_history(
self,
site_id: int,
index_type: Optional[str] = None
) -> Optional[PercentileHistory]:
"""
获取最近一次分位点历史
Args:
site_id: 门店ID
index_type: 指数类型
Returns:
PercentileHistory 或 None
"""
if index_type is None:
index_type = self.get_index_type()
sql = """
SELECT
percentile_5, percentile_95,
percentile_5_smoothed, percentile_95_smoothed,
record_count, calc_time
FROM billiards_dws.dws_index_percentile_history
WHERE site_id = %s AND index_type = %s
ORDER BY calc_time DESC
LIMIT 1
"""
rows = self.db.query(sql, (site_id, index_type))
if not rows:
return None
row = dict(rows[0])
return PercentileHistory(
percentile_5=float(row['percentile_5'] or 0),
percentile_95=float(row['percentile_95'] or 0),
percentile_5_smoothed=float(row['percentile_5_smoothed'] or 0),
percentile_95_smoothed=float(row['percentile_95_smoothed'] or 0),
record_count=int(row['record_count'] or 0),
calc_time=row['calc_time']
)
def save_percentile_history(
self,
site_id: int,
percentile_5: float,
percentile_95: float,
percentile_5_smoothed: float,
percentile_95_smoothed: float,
record_count: int,
min_raw: float,
max_raw: float,
avg_raw: float,
index_type: Optional[str] = None
) -> None:
"""
保存分位点历史
Args:
site_id: 门店ID
percentile_5: 原始5分位
percentile_95: 原始95分位
percentile_5_smoothed: 平滑后5分位
percentile_95_smoothed: 平滑后95分位
record_count: 记录数
min_raw: 最小Raw Score
max_raw: 最大Raw Score
avg_raw: 平均Raw Score
index_type: 指数类型
"""
if index_type is None:
index_type = self.get_index_type()
sql = """
INSERT INTO billiards_dws.dws_index_percentile_history (
site_id, index_type, calc_time,
percentile_5, percentile_95,
percentile_5_smoothed, percentile_95_smoothed,
record_count, min_raw_score, max_raw_score, avg_raw_score
) VALUES (%s, %s, NOW(), %s, %s, %s, %s, %s, %s, %s, %s)
"""
with self.db.conn.cursor() as cur:
cur.execute(sql, (
site_id, index_type,
percentile_5, percentile_95,
percentile_5_smoothed, percentile_95_smoothed,
record_count, min_raw, max_raw, avg_raw
))
self.db.conn.commit()
def _apply_ewma_smoothing(
self,
site_id: int,
current_p5: float,
current_p95: float,
alpha: Optional[float] = None,
index_type: Optional[str] = None,
) -> Tuple[float, float]:
"""
应用EWMA平滑到分位点
公式: Q_t = (1 - α) * Q_{t-1} + α * Q_now
Args:
site_id: 门店ID
current_p5: 当前5分位
current_p95: 当前95分位
alpha: 平滑系数默认0.2
index_type: 指数类型(用于参数和历史隔离)
Returns:
(平滑后的P5, 平滑后的P95)
"""
if index_type is None:
index_type = self.get_index_type()
if alpha is None:
alpha = self.get_param(
'ewma_alpha',
self.DEFAULT_EWMA_ALPHA,
index_type=index_type,
)
history = self.get_last_percentile_history(site_id, index_type=index_type)
if history is None:
# 首次计算,不平滑
return current_p5, current_p95
smoothed_p5 = (1 - alpha) * history.percentile_5_smoothed + alpha * current_p5
smoothed_p95 = (1 - alpha) * history.percentile_95_smoothed + alpha * current_p95
return smoothed_p5, smoothed_p95
# ==========================================================================
# 统计工具方法
# ==========================================================================
def calculate_median(self, values: List[float]) -> float:
"""计算中位数"""
if not values:
return 0.0
sorted_vals = sorted(values)
n = len(sorted_vals)
mid = n // 2
if n % 2 == 0:
return (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
return sorted_vals[mid]
def calculate_mad(self, values: List[float]) -> float:
"""
计算MAD中位绝对偏差
MAD = median(|x - median(x)|)
MAD是比标准差更稳健的离散度度量不受极端值影响
"""
if not values:
return 0.0
median_val = self.calculate_median(values)
deviations = [abs(v - median_val) for v in values]
return self.calculate_median(deviations)
def safe_log(self, value: float, default: float = 0.0) -> float:
"""安全的对数运算"""
if value <= 0:
return default
return math.log(value)
def safe_ln1p(self, value: float) -> float:
"""安全的ln(1+x)运算"""
if value < -1:
return 0.0
return math.log1p(value)
def _resolve_compression(self, compression: Optional[str], use_log: bool) -> str:
"""规范化压缩方式"""
if compression is None:
return "log1p" if use_log else "none"
compression_key = str(compression).strip().lower()
if compression_key in ("none", "log1p", "asinh"):
return compression_key
if hasattr(self, "logger"):
self.logger.warning("未知压缩方式: %s,已降级为 none", compression)
return "none"

View File

@@ -0,0 +1,461 @@
# -*- coding: utf-8 -*-
"""
会员层召回/转化指数共享逻辑
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext
@dataclass
class MemberActivityData:
"""Shared member activity features for WBI/NCI."""
member_id: int
site_id: int
tenant_id: int
member_create_time: Optional[datetime] = None
first_visit_time: Optional[datetime] = None
last_visit_time: Optional[datetime] = None
last_recharge_time: Optional[datetime] = None
t_v: float = 60.0
t_r: float = 60.0
t_a: float = 60.0
days_since_first_visit: Optional[int] = None
days_since_last_visit: Optional[int] = None
days_since_last_recharge: Optional[int] = None
visits_14d: int = 0
visits_60d: int = 0
visits_total: int = 0
spend_30d: float = 0.0
spend_180d: float = 0.0
sv_balance: float = 0.0
recharge_60d_amt: float = 0.0
interval_count: int = 0
intervals: List[float] = field(default_factory=list)
interval_ages_days: List[int] = field(default_factory=list)
recharge_unconsumed: int = 0
class MemberIndexBaseTask(BaseIndexTask):
"""Shared extraction and feature building for WBI/NCI."""
DEFAULT_VISIT_LOOKBACK_DAYS = 180
DEFAULT_RECENCY_LOOKBACK_DAYS = 60
CASH_CARD_TYPE_ID = 2793249295533893
def _get_site_id(self, context: Optional[TaskContext]) -> int:
"""获取门店ID"""
if context and hasattr(context, 'store_id') and context.store_id:
return context.store_id
site_id = self.config.get('app.default_site_id') or self.config.get('app.store_id')
if site_id is not None:
return int(site_id)
sql = "SELECT DISTINCT site_id FROM billiards_dwd.dwd_settlement_head WHERE site_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
value = dict(rows[0]).get('site_id')
if value is not None:
return int(value)
self.logger.warning("无法确定门店ID使用 0 继续执行")
return 0
def _get_tenant_id(self) -> int:
"""获取租户ID"""
tenant_id = self.config.get('app.tenant_id')
if tenant_id is not None:
return int(tenant_id)
sql = "SELECT DISTINCT tenant_id FROM billiards_dwd.dwd_settlement_head WHERE tenant_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
value = dict(rows[0]).get('tenant_id')
if value is not None:
return int(value)
self.logger.warning("无法确定租户ID使用 0 继续执行")
return 0
def _load_params(self) -> Dict[str, float]:
"""Load index parameters with defaults and runtime overrides."""
params = self.load_index_parameters()
result = dict(self.DEFAULT_PARAMS)
result.update(params)
# GUI/环境变量可通过 run.index_lookback_days 覆盖 recency 窗口
override_days = self.config.get('run.index_lookback_days')
if override_days is not None:
try:
override_days_int = int(override_days)
if override_days_int < 7 or override_days_int > 180:
self.logger.warning(
"%s: run.index_lookback_days=%s 超出建议范围[7,180],已自动截断",
self.get_task_code(),
override_days,
)
override_days_int = max(7, min(180, override_days_int))
result['lookback_days_recency'] = float(override_days_int)
self.logger.info(
"%s: 使用回溯天数覆盖 lookback_days_recency=%d",
self.get_task_code(),
override_days_int,
)
except (TypeError, ValueError):
self.logger.warning(
"%s: run.index_lookback_days=%s is invalid; ignore override and use parameter table value",
self.get_task_code(),
override_days,
)
return result
def _build_visit_condition_sql(self) -> str:
"""Build visit-scope condition SQL."""
return """
(
s.settle_type = 1
OR (
s.settle_type = 3
AND EXISTS (
SELECT 1
FROM billiards_dwd.dwd_assistant_service_log asl
JOIN billiards_dws.cfg_skill_type st
ON asl.skill_id = st.skill_id
AND st.course_type_code = 'BONUS'
AND st.is_active = TRUE
WHERE asl.order_settle_id = s.order_settle_id
AND asl.site_id = s.site_id
AND asl.tenant_member_id = s.member_id
AND asl.is_delete = 0
)
)
)
"""
def _extract_visit_day_rows(
self,
site_id: int,
start_date: date,
end_date: date,
) -> List[Dict[str, Any]]:
"""提取到店记录(按天去重)"""
condition_sql = self._build_visit_condition_sql()
sql = f"""
WITH visit_source AS (
SELECT
COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) AS canonical_member_id,
s.pay_time,
s.pay_amount
FROM billiards_dwd.dwd_settlement_head s
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON s.member_card_account_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = s.site_id
AND COALESCE(mca.is_delete, 0) = 0
WHERE s.site_id = %s
AND s.pay_time >= %s
AND s.pay_time < %s + INTERVAL '1 day'
AND {condition_sql}
)
SELECT
canonical_member_id AS member_id,
DATE(pay_time) AS visit_date,
MAX(pay_time) AS last_visit_time,
SUM(COALESCE(pay_amount, 0)) AS day_pay_amount
FROM visit_source
WHERE canonical_member_id > 0
GROUP BY canonical_member_id, DATE(pay_time)
ORDER BY canonical_member_id, visit_date
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
return [dict(row) for row in (rows or [])]
def _extract_recharge_rows(
self,
site_id: int,
start_date: date,
end_date: date,
) -> Dict[int, Dict[str, Any]]:
"""提取充值记录近60天"""
sql = """
WITH recharge_source AS (
SELECT
COALESCE(NULLIF(r.member_id, 0), mca.tenant_member_id) AS canonical_member_id,
r.pay_time,
r.pay_amount
FROM billiards_dwd.dwd_recharge_order r
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON r.tenant_member_card_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = r.site_id
AND COALESCE(mca.is_delete, 0) = 0
WHERE r.site_id = %s
AND r.settle_type = 5
AND r.pay_time >= %s
AND r.pay_time < %s + INTERVAL '1 day'
)
SELECT
canonical_member_id AS member_id,
MAX(pay_time) AS last_recharge_time,
SUM(COALESCE(pay_amount, 0)) AS recharge_60d_amt
FROM recharge_source
WHERE canonical_member_id > 0
GROUP BY canonical_member_id
"""
rows = self.db.query(sql, (site_id, start_date, end_date))
result: Dict[int, Dict[str, Any]] = {}
for row in (rows or []):
row_dict = dict(row)
result[int(row_dict['member_id'])] = row_dict
return result
def _extract_member_create_times(self, member_ids: List[int]) -> Dict[int, datetime]:
"""提取会员建档时间"""
if not member_ids:
return {}
member_ids_str = ','.join(str(m) for m in member_ids)
sql = f"""
SELECT
member_id,
create_time
FROM billiards_dwd.dim_member
WHERE member_id IN ({member_ids_str})
AND scd2_is_current = 1
"""
rows = self.db.query(sql)
result = {}
for row in (rows or []):
row_dict = dict(row)
member_id = int(row_dict['member_id'])
create_time = row_dict.get('create_time')
if create_time:
result[member_id] = create_time
return result
def _extract_first_visit_times(self, site_id: int, member_ids: List[int]) -> Dict[int, datetime]:
"""提取首次到店时间(全量)"""
if not member_ids:
return {}
member_ids_str = ','.join(str(m) for m in member_ids)
condition_sql = self._build_visit_condition_sql()
sql = f"""
WITH visit_source AS (
SELECT
COALESCE(NULLIF(s.member_id, 0), mca.tenant_member_id) AS canonical_member_id,
s.pay_time
FROM billiards_dwd.dwd_settlement_head s
LEFT JOIN billiards_dwd.dim_member_card_account mca
ON s.member_card_account_id = mca.member_card_id
AND mca.scd2_is_current = 1
AND mca.register_site_id = s.site_id
AND COALESCE(mca.is_delete, 0) = 0
WHERE s.site_id = %s
AND {condition_sql}
)
SELECT
canonical_member_id AS member_id,
MIN(pay_time) AS first_visit_time
FROM visit_source
WHERE canonical_member_id IN ({member_ids_str})
GROUP BY canonical_member_id
"""
rows = self.db.query(sql, (site_id,))
result = {}
for row in (rows or []):
row_dict = dict(row)
member_id = int(row_dict['member_id'])
first_visit_time = row_dict.get('first_visit_time')
if first_visit_time:
result[member_id] = first_visit_time
return result
def _extract_sv_balances(self, site_id: int, tenant_id: int, member_ids: List[int]) -> Dict[int, Decimal]:
"""Fetch member stored-value card balances."""
if not member_ids:
return {}
member_ids_str = ','.join(str(m) for m in member_ids)
sql = f"""
SELECT
tenant_member_id AS member_id,
SUM(CASE WHEN card_type_id = %s THEN balance ELSE 0 END) AS sv_balance
FROM billiards_dwd.dim_member_card_account
WHERE tenant_id = %s
AND register_site_id = %s
AND scd2_is_current = 1
AND COALESCE(is_delete, 0) = 0
AND tenant_member_id IN ({member_ids_str})
GROUP BY tenant_member_id
"""
rows = self.db.query(sql, (self.CASH_CARD_TYPE_ID, tenant_id, site_id))
result: Dict[int, Decimal] = {}
for row in (rows or []):
row_dict = dict(row)
member_id = int(row_dict['member_id'])
result[member_id] = row_dict.get('sv_balance') or Decimal('0')
return result
def _build_member_activity(
self,
site_id: int,
tenant_id: int,
params: Dict[str, float],
) -> Dict[int, MemberActivityData]:
"""构建会员活动特征"""
now = datetime.now(self.tz)
base_date = now.date()
visit_lookback_days = int(params.get('visit_lookback_days', self.DEFAULT_VISIT_LOOKBACK_DAYS))
recency_days = int(params.get('lookback_days_recency', self.DEFAULT_RECENCY_LOOKBACK_DAYS))
visit_start_date = base_date - timedelta(days=visit_lookback_days)
visit_rows = self._extract_visit_day_rows(site_id, visit_start_date, base_date)
member_day_rows: Dict[int, List[Dict[str, Any]]] = {}
for row in (visit_rows or []):
member_id = int(row['member_id'])
member_day_rows.setdefault(member_id, []).append(row)
recharge_start_date = base_date - timedelta(days=recency_days)
recharge_rows = self._extract_recharge_rows(site_id, recharge_start_date, base_date)
member_ids = set(member_day_rows.keys()) | set(recharge_rows.keys())
if not member_ids:
return {}
member_id_list = list(member_ids)
member_create_times = self._extract_member_create_times(member_id_list)
first_visit_times = self._extract_first_visit_times(site_id, member_id_list)
sv_balances = self._extract_sv_balances(site_id, tenant_id, member_id_list)
results: Dict[int, MemberActivityData] = {}
for member_id in member_ids:
data = MemberActivityData(
member_id=member_id,
site_id=site_id,
tenant_id=tenant_id,
)
day_rows = member_day_rows.get(member_id, [])
if day_rows:
day_rows_sorted = sorted(day_rows, key=lambda x: x['visit_date'])
data.visits_total = len(day_rows_sorted)
last_visit_time = max(r.get('last_visit_time') for r in day_rows_sorted)
data.last_visit_time = last_visit_time
# 近14/60天到店次数
days_14_ago = base_date - timedelta(days=14)
days_60_ago = base_date - timedelta(days=60)
for r in day_rows_sorted:
visit_date = r.get('visit_date')
if visit_date is None:
continue
if visit_date >= days_14_ago:
data.visits_14d += 1
if visit_date >= days_60_ago:
data.visits_60d += 1
# 消费金额
days_30_ago = base_date - timedelta(days=30)
for r in day_rows_sorted:
visit_date = r.get('visit_date')
day_pay = float(r.get('day_pay_amount') or 0)
data.spend_180d += day_pay
if visit_date and visit_date >= days_30_ago:
data.spend_30d += day_pay
# 计算到店间隔(按天)
visit_dates = [r.get('visit_date') for r in day_rows_sorted if r.get('visit_date')]
intervals: List[float] = []
interval_ages_days: List[int] = []
for i in range(1, len(visit_dates)):
interval = (visit_dates[i] - visit_dates[i - 1]).days
intervals.append(float(min(recency_days, interval)))
interval_ages_days.append(max(0, (base_date - visit_dates[i]).days))
data.intervals = intervals
data.interval_ages_days = interval_ages_days
data.interval_count = len(intervals)
recharge_info = recharge_rows.get(member_id)
if recharge_info:
data.last_recharge_time = recharge_info.get('last_recharge_time')
data.recharge_60d_amt = float(recharge_info.get('recharge_60d_amt') or 0)
data.member_create_time = member_create_times.get(member_id)
data.first_visit_time = first_visit_times.get(member_id)
sv_balance = sv_balances.get(member_id)
if sv_balance is not None:
data.sv_balance = float(sv_balance)
# 时间差计算
if data.first_visit_time:
data.days_since_first_visit = (base_date - data.first_visit_time.date()).days
if data.last_visit_time:
data.days_since_last_visit = (base_date - data.last_visit_time.date()).days
if data.last_recharge_time:
data.days_since_last_recharge = (base_date - data.last_recharge_time.date()).days
# tV/tR/tA
data.t_v = float(min(recency_days, data.days_since_last_visit)) if data.days_since_last_visit is not None else float(recency_days)
data.t_r = float(min(recency_days, data.days_since_last_recharge)) if data.days_since_last_recharge is not None else float(recency_days)
data.t_a = float(min(data.t_v, data.t_r))
# 充值是否未回访
if data.last_recharge_time and (data.last_visit_time is None or data.last_recharge_time > data.last_visit_time):
data.recharge_unconsumed = 1
results[member_id] = data
return results
def classify_segment(
self,
data: MemberActivityData,
params: Dict[str, float],
) -> Tuple[str, str, bool]:
"""Classify member into NEW/OLD/STOP buckets."""
recency_days = int(params.get('lookback_days_recency', self.DEFAULT_RECENCY_LOOKBACK_DAYS))
enable_stop_exception = int(params.get('enable_stop_high_balance_exception', 0)) == 1
high_balance_threshold = float(params.get('high_balance_threshold', 1000))
if data.t_a >= recency_days:
if enable_stop_exception and data.sv_balance >= high_balance_threshold:
return "STOP", "STOP_HIGH_BALANCE", True
return "STOP", "STOP", False
new_visit_threshold = int(params.get('new_visit_threshold', 2))
new_days_threshold = int(params.get('new_days_threshold', 30))
recharge_recent_days = int(params.get('recharge_recent_days', 14))
new_recharge_max_visits = int(params.get('new_recharge_max_visits', 10))
is_new_by_visits = data.visits_total <= new_visit_threshold
is_new_by_first_visit = data.days_since_first_visit is not None and data.days_since_first_visit <= new_days_threshold
is_new_by_recharge = (
data.recharge_unconsumed == 1
and data.days_since_last_recharge is not None
and data.days_since_last_recharge <= recharge_recent_days
and data.visits_total <= new_recharge_max_visits
)
if is_new_by_visits or is_new_by_first_visit or is_new_by_recharge:
return "NEW", "NEW", True
return "OLD", "OLD", True

View File

@@ -0,0 +1,623 @@
# -*- coding: utf-8 -*-
"""
ML 人工台账导入任务。
设计目标:
1. 人工台账作为 ML 唯一真源;
2. 同一订单支持多助教归因,默认均分;
3. 覆盖策略:
- 近 30 天:按 site_id + biz_date 日覆盖;
- 超过 30 天按固定纪元2026-01-01切 30 天批次覆盖。
"""
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext
@dataclass(frozen=True)
class ImportScope:
"""导入覆盖范围定义。"""
site_id: int
scope_type: str # DAY / P30
start_date: date
end_date: date
@property
def scope_key(self) -> str:
if self.scope_type == "DAY":
return f"DAY:{self.site_id}:{self.start_date.isoformat()}"
return (
f"P30:{self.site_id}:{self.start_date.isoformat()}:{self.end_date.isoformat()}"
)
class MlManualImportTask(BaseIndexTask):
"""导入并拆分 ML 人工台账(订单宽表 + 助教分摊窄表)。"""
INDEX_TYPE = "ML"
EPOCH_ANCHOR = date(2026, 1, 1)
HISTORICAL_BUCKET_DAYS = 30
ASSISTANT_SLOT_COUNT = 5
# Excel 模板字段(按列顺序)
TEMPLATE_COLUMNS = [
"site_id",
"biz_date",
"external_id",
"member_id",
"pay_time",
"order_amount",
"currency",
"assistant_id_1",
"assistant_name_1",
"assistant_id_2",
"assistant_name_2",
"assistant_id_3",
"assistant_name_3",
"assistant_id_4",
"assistant_name_4",
"assistant_id_5",
"assistant_name_5",
"remark",
]
def get_task_code(self) -> str:
return "DWS_ML_MANUAL_IMPORT"
def get_target_table(self) -> str:
return "dws_ml_manual_order_source"
def get_primary_keys(self) -> List[str]:
return ["site_id", "external_id", "import_scope_key", "row_no"]
def get_index_type(self) -> str:
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
"""
执行导入。
说明:该任务按“文件”运行,不依赖时间窗口。调度器会以工具任务方式直接触发。
"""
file_path = self._resolve_file_path()
if not file_path:
raise ValueError(
"未找到 ML 台账文件,请通过环境变量 ML_MANUAL_LEDGER_FILE 或配置 run.ml_manual_ledger_file 指定"
)
rows = self._read_excel_rows(file_path)
if not rows:
self.logger.warning("台账文件为空:%s", file_path)
return {
"status": "SUCCESS",
"counts": {
"source_rows": 0,
"alloc_rows": 0,
"deleted_source_rows": 0,
"deleted_alloc_rows": 0,
"scopes": 0,
},
}
now = datetime.now(self.tz)
today = now.date()
import_batch_no = self._build_import_batch_no(now)
import_file_name = Path(file_path).name
import_user = self._resolve_import_user()
source_rows: List[Dict[str, Any]] = []
alloc_rows: List[Dict[str, Any]] = []
scope_set: Dict[Tuple[int, str, date, date], ImportScope] = {}
for idx, raw in enumerate(rows, start=2):
normalized = self._normalize_row(raw, row_no=idx, file_path=file_path)
row_scope = self.resolve_scope(
site_id=normalized["site_id"],
biz_date=normalized["biz_date"],
today=today,
)
scope_set[(row_scope.site_id, row_scope.scope_type, row_scope.start_date, row_scope.end_date)] = row_scope
source_row = self._build_source_row(
normalized=normalized,
scope=row_scope,
import_batch_no=import_batch_no,
import_file_name=import_file_name,
import_user=import_user,
import_time=now,
)
source_rows.append(source_row)
alloc_rows.extend(
self._build_alloc_rows(
normalized=normalized,
scope=row_scope,
import_batch_no=import_batch_no,
import_file_name=import_file_name,
import_user=import_user,
import_time=now,
)
)
scopes = list(scope_set.values())
deleted_source_rows, deleted_alloc_rows = self._delete_by_scopes(scopes)
inserted_source = self._insert_source_rows(source_rows)
upserted_alloc = self._upsert_alloc_rows(alloc_rows)
self.db.conn.commit()
self.logger.info(
"ML 人工台账导入完成: file=%s source=%d alloc=%d scopes=%d",
file_path,
inserted_source,
upserted_alloc,
len(scopes),
)
return {
"status": "SUCCESS",
"counts": {
"source_rows": inserted_source,
"alloc_rows": upserted_alloc,
"deleted_source_rows": deleted_source_rows,
"deleted_alloc_rows": deleted_alloc_rows,
"scopes": len(scopes),
},
}
def _resolve_file_path(self) -> Optional[str]:
"""解析台账文件路径。"""
raw_path = (
self.config.get("run.ml_manual_ledger_file")
or self.config.get("run.ml_manual_file")
or os.getenv("ML_MANUAL_LEDGER_FILE")
)
if not raw_path:
return None
candidate = Path(str(raw_path)).expanduser()
if not candidate.is_absolute():
candidate = Path.cwd() / candidate
if not candidate.exists():
raise FileNotFoundError(f"台账文件不存在: {candidate}")
return str(candidate)
def _read_excel_rows(self, file_path: str) -> List[Dict[str, Any]]:
"""读取 Excel 为行字典列表。"""
try:
from openpyxl import load_workbook
except Exception as exc: # noqa: BLE001
raise RuntimeError(
"缺少 openpyxl 依赖,无法读取 Excel请先安装 openpyxl"
) from exc
wb = load_workbook(file_path, data_only=True)
ws = wb.active
header_row = next(ws.iter_rows(min_row=1, max_row=1, values_only=True), None)
if not header_row:
return []
headers = [str(col).strip() if col is not None else "" for col in header_row]
if not headers:
return []
rows: List[Dict[str, Any]] = []
for values in ws.iter_rows(min_row=2, values_only=True):
if values is None:
continue
row_dict = {headers[i]: values[i] for i in range(min(len(headers), len(values)))}
if self._is_empty_row(row_dict):
continue
rows.append(row_dict)
return rows
@staticmethod
def _is_empty_row(row: Dict[str, Any]) -> bool:
for value in row.values():
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
return False
return True
def _normalize_row(
self,
raw: Dict[str, Any],
row_no: int,
file_path: str,
) -> Dict[str, Any]:
"""规范化单行字段。"""
site_id = self._to_int(raw.get("site_id"), fallback=self.config.get("app.store_id"))
biz_date = self._to_date(raw.get("biz_date"))
pay_time = self._to_datetime(raw.get("pay_time"), fallback_date=biz_date)
external_id = str(raw.get("external_id") or "").strip()
if not external_id:
raise ValueError(f"台账行 {row_no} 缺少 external_id订单ID: {file_path}")
member_id = self._to_int(raw.get("member_id"), fallback=0)
order_amount = self._to_decimal(raw.get("order_amount"))
currency = str(raw.get("currency") or "CNY").strip().upper() or "CNY"
remark = str(raw.get("remark") or "").strip()
assistants: List[Tuple[int, str]] = []
for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
aid = self._to_int(raw.get(f"assistant_id_{idx}"), fallback=None)
name = str(raw.get(f"assistant_name_{idx}") or "").strip()
if aid is None:
continue
assistants.append((aid, name))
return {
"site_id": site_id,
"biz_date": biz_date,
"external_id": external_id,
"member_id": member_id,
"pay_time": pay_time,
"order_amount": order_amount,
"currency": currency,
"assistants": assistants,
"remark": remark,
"row_no": row_no,
}
def _build_source_row(
self,
*,
normalized: Dict[str, Any],
scope: ImportScope,
import_batch_no: str,
import_file_name: str,
import_user: str,
import_time: datetime,
) -> Dict[str, Any]:
"""构造宽表入库行。"""
assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
row = {
"site_id": normalized["site_id"],
"biz_date": normalized["biz_date"],
"external_id": normalized["external_id"],
"member_id": normalized["member_id"],
"pay_time": normalized["pay_time"],
"order_amount": normalized["order_amount"],
"currency": normalized["currency"],
"import_batch_no": import_batch_no,
"import_file_name": import_file_name,
"import_scope_key": scope.scope_key,
"import_time": import_time,
"import_user": import_user,
"row_no": normalized["row_no"],
"remark": normalized["remark"],
}
for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
aid, aname = (assistants[idx - 1] if idx - 1 < len(assistants) else (None, None))
row[f"assistant_id_{idx}"] = aid
row[f"assistant_name_{idx}"] = aname
return row
def _build_alloc_rows(
self,
*,
normalized: Dict[str, Any],
scope: ImportScope,
import_batch_no: str,
import_file_name: str,
import_user: str,
import_time: datetime,
) -> List[Dict[str, Any]]:
"""构造窄表分摊行。"""
assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
if not assistants:
return []
n = Decimal(str(len(assistants)))
share_ratio = Decimal("1") / n
rows: List[Dict[str, Any]] = []
for assistant_id, assistant_name in assistants:
allocated_amount = normalized["order_amount"] * share_ratio
rows.append(
{
"site_id": normalized["site_id"],
"biz_date": normalized["biz_date"],
"external_id": normalized["external_id"],
"member_id": normalized["member_id"],
"pay_time": normalized["pay_time"],
"order_amount": normalized["order_amount"],
"assistant_id": assistant_id,
"assistant_name": assistant_name,
"share_ratio": share_ratio,
"allocated_amount": allocated_amount,
"currency": normalized["currency"],
"import_scope_key": scope.scope_key,
"import_batch_no": import_batch_no,
"import_file_name": import_file_name,
"import_time": import_time,
"import_user": import_user,
}
)
return rows
@classmethod
def resolve_scope(cls, site_id: int, biz_date: date, today: date) -> ImportScope:
"""按规则解析覆盖范围。"""
day_diff = (today - biz_date).days
if day_diff <= cls.HISTORICAL_BUCKET_DAYS:
return ImportScope(
site_id=site_id,
scope_type="DAY",
start_date=biz_date,
end_date=biz_date,
)
bucket_start, bucket_end = cls.resolve_p30_bucket(biz_date)
return ImportScope(
site_id=site_id,
scope_type="P30",
start_date=bucket_start,
end_date=bucket_end,
)
@classmethod
def resolve_p30_bucket(cls, biz_date: date) -> Tuple[date, date]:
"""固定纪元 30 天分桶。"""
delta_days = (biz_date - cls.EPOCH_ANCHOR).days
bucket_index = delta_days // cls.HISTORICAL_BUCKET_DAYS
bucket_start = cls.EPOCH_ANCHOR + timedelta(days=bucket_index * cls.HISTORICAL_BUCKET_DAYS)
bucket_end = bucket_start + timedelta(days=cls.HISTORICAL_BUCKET_DAYS - 1)
return bucket_start, bucket_end
def _delete_by_scopes(self, scopes: Iterable[ImportScope]) -> Tuple[int, int]:
"""按 scope 先删后写,保证整批覆盖。"""
deleted_source = 0
deleted_alloc = 0
with self.db.conn.cursor() as cur:
for scope in scopes:
if scope.scope_type == "DAY":
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_source
WHERE site_id = %s AND biz_date = %s
""",
(scope.site_id, scope.start_date),
)
deleted_source += max(cur.rowcount, 0)
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_alloc
WHERE site_id = %s AND biz_date = %s
""",
(scope.site_id, scope.start_date),
)
deleted_alloc += max(cur.rowcount, 0)
else:
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_source
WHERE site_id = %s AND biz_date >= %s AND biz_date <= %s
""",
(scope.site_id, scope.start_date, scope.end_date),
)
deleted_source += max(cur.rowcount, 0)
cur.execute(
"""
DELETE FROM billiards_dws.dws_ml_manual_order_alloc
WHERE site_id = %s AND biz_date >= %s AND biz_date <= %s
""",
(scope.site_id, scope.start_date, scope.end_date),
)
deleted_alloc += max(cur.rowcount, 0)
return deleted_source, deleted_alloc
def _insert_source_rows(self, rows: List[Dict[str, Any]]) -> int:
if not rows:
return 0
columns = [
"site_id",
"biz_date",
"external_id",
"member_id",
"pay_time",
"order_amount",
"currency",
"assistant_id_1",
"assistant_name_1",
"assistant_id_2",
"assistant_name_2",
"assistant_id_3",
"assistant_name_3",
"assistant_id_4",
"assistant_name_4",
"assistant_id_5",
"assistant_name_5",
"import_batch_no",
"import_file_name",
"import_scope_key",
"import_time",
"import_user",
"row_no",
"remark",
"created_at",
"updated_at",
]
sql = f"""
INSERT INTO billiards_dws.dws_ml_manual_order_source ({", ".join(columns)})
VALUES ({", ".join(["%s"] * len(columns))})
"""
inserted = 0
with self.db.conn.cursor() as cur:
for row in rows:
values = [
row.get("site_id"),
row.get("biz_date"),
row.get("external_id"),
row.get("member_id"),
row.get("pay_time"),
row.get("order_amount"),
row.get("currency"),
row.get("assistant_id_1"),
row.get("assistant_name_1"),
row.get("assistant_id_2"),
row.get("assistant_name_2"),
row.get("assistant_id_3"),
row.get("assistant_name_3"),
row.get("assistant_id_4"),
row.get("assistant_name_4"),
row.get("assistant_id_5"),
row.get("assistant_name_5"),
row.get("import_batch_no"),
row.get("import_file_name"),
row.get("import_scope_key"),
row.get("import_time"),
row.get("import_user"),
row.get("row_no"),
row.get("remark"),
row.get("import_time"),
row.get("import_time"),
]
cur.execute(sql, values)
inserted += max(cur.rowcount, 0)
return inserted
def _upsert_alloc_rows(self, rows: List[Dict[str, Any]]) -> int:
if not rows:
return 0
columns = [
"site_id",
"biz_date",
"external_id",
"member_id",
"pay_time",
"order_amount",
"assistant_id",
"assistant_name",
"share_ratio",
"allocated_amount",
"currency",
"import_scope_key",
"import_batch_no",
"import_file_name",
"import_time",
"import_user",
"created_at",
"updated_at",
]
sql = f"""
INSERT INTO billiards_dws.dws_ml_manual_order_alloc ({", ".join(columns)})
VALUES ({", ".join(["%s"] * len(columns))})
ON CONFLICT (site_id, external_id, assistant_id)
DO UPDATE SET
biz_date = EXCLUDED.biz_date,
member_id = EXCLUDED.member_id,
pay_time = EXCLUDED.pay_time,
order_amount = EXCLUDED.order_amount,
assistant_name = EXCLUDED.assistant_name,
share_ratio = EXCLUDED.share_ratio,
allocated_amount = EXCLUDED.allocated_amount,
currency = EXCLUDED.currency,
import_scope_key = EXCLUDED.import_scope_key,
import_batch_no = EXCLUDED.import_batch_no,
import_file_name = EXCLUDED.import_file_name,
import_time = EXCLUDED.import_time,
import_user = EXCLUDED.import_user,
updated_at = NOW()
"""
affected = 0
with self.db.conn.cursor() as cur:
for row in rows:
values = [
row.get("site_id"),
row.get("biz_date"),
row.get("external_id"),
row.get("member_id"),
row.get("pay_time"),
row.get("order_amount"),
row.get("assistant_id"),
row.get("assistant_name"),
row.get("share_ratio"),
row.get("allocated_amount"),
row.get("currency"),
row.get("import_scope_key"),
row.get("import_batch_no"),
row.get("import_file_name"),
row.get("import_time"),
row.get("import_user"),
row.get("import_time"),
row.get("import_time"),
]
cur.execute(sql, values)
affected += max(cur.rowcount, 0)
return affected
@staticmethod
def _to_int(value: Any, fallback: Optional[int] = None) -> Optional[int]:
if value is None:
return fallback
if isinstance(value, str) and not value.strip():
return fallback
try:
return int(value)
except Exception: # noqa: BLE001
return fallback
@staticmethod
def _to_decimal(value: Any) -> Decimal:
if value is None or value == "":
return Decimal("0")
return Decimal(str(value))
@staticmethod
def _to_date(value: Any) -> date:
if isinstance(value, datetime):
return value.date()
if isinstance(value, date):
return value
if isinstance(value, str):
text = value.strip()
if not text:
raise ValueError("biz_date 不能为空")
if len(text) >= 10:
return datetime.fromisoformat(text[:10]).date()
return datetime.fromisoformat(text).date()
raise ValueError(f"无法解析 biz_date: {value}")
@staticmethod
def _to_datetime(value: Any, fallback_date: date) -> datetime:
if isinstance(value, datetime):
return value
if isinstance(value, date):
return datetime.combine(value, datetime.min.time())
if isinstance(value, str):
text = value.strip()
if text:
text = text.replace("/", "-")
try:
return datetime.fromisoformat(text)
except Exception: # noqa: BLE001
if len(text) >= 19:
return datetime.strptime(text[:19], "%Y-%m-%d %H:%M:%S")
return datetime.fromisoformat(text[:10])
return datetime.combine(fallback_date, datetime.min.time())
@staticmethod
def _build_import_batch_no(now: datetime) -> str:
return f"MLM_{now.strftime('%Y%m%d%H%M%S')}_{str(uuid.uuid4())[:8]}"
@staticmethod
def _resolve_import_user() -> str:
return (
os.getenv("ETL_OPERATOR")
or os.getenv("USERNAME")
or os.getenv("USER")
or "system"
)
__all__ = ["MlManualImportTask", "ImportScope"]

View File

@@ -0,0 +1,381 @@
# -*- coding: utf-8 -*-
"""
新客转化指数NCI计算任务。"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from .member_index_base import MemberActivityData, MemberIndexBaseTask
from ..base_dws_task import TaskContext
@dataclass
class MemberNewconvData:
activity: MemberActivityData
status: str
segment: str
need_new: float = 0.0
salvage_new: float = 0.0
recharge_new: float = 0.0
value_new: float = 0.0
welcome_new: float = 0.0
raw_score_welcome: Optional[float] = None
raw_score_convert: Optional[float] = None
raw_score: Optional[float] = None
display_score_welcome: Optional[float] = None
display_score_convert: Optional[float] = None
display_score: Optional[float] = None
class NewconvIndexTask(MemberIndexBaseTask):
"""新客转化指数NCI计算任务。"""
INDEX_TYPE = "NCI"
DEFAULT_PARAMS = {
# 通用参数
'lookback_days_recency': 60,
'visit_lookback_days': 180,
'percentile_lower': 5,
'percentile_upper': 95,
'compression_mode': 0,
'use_smoothing': 1,
'ewma_alpha': 0.2,
# 分流参数
'new_visit_threshold': 2,
'new_days_threshold': 30,
'recharge_recent_days': 14,
'new_recharge_max_visits': 10,
# NCI参数
'no_touch_days_new': 3,
't2_target_days': 7,
'salvage_start': 30,
'salvage_end': 60,
'welcome_window_days': 3,
'active_new_visit_threshold_14d': 2,
'active_new_recency_days': 7,
'active_new_penalty': 0.2,
'h_recharge': 7,
'amount_base_M0': 300,
'balance_base_B0': 500,
'value_w_spend': 1.0,
'value_w_bal': 0.8,
'w_welcome': 1.0,
'w_need': 1.6,
'w_re': 0.8,
'w_value': 1.0,
# STOP高余额例外默认关闭
'enable_stop_high_balance_exception': 0,
'high_balance_threshold': 1000,
}
def get_task_code(self) -> str:
return "DWS_NEWCONV_INDEX"
def get_target_table(self) -> str:
return "dws_member_newconv_index"
def get_primary_keys(self) -> List[str]:
return ['site_id', 'member_id']
def get_index_type(self) -> str:
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
"""执行 NCI 计算"""
self.logger.info("开始计算新客转化指数(NCI)")
site_id = self._get_site_id(context)
tenant_id = self._get_tenant_id()
params = self._load_params()
activity_map = self._build_member_activity(site_id, tenant_id, params)
if not activity_map:
self.logger.warning("No member activity data available; skip calculation")
return {'status': 'skipped', 'reason': 'no_data'}
newconv_list: List[MemberNewconvData] = []
for activity in activity_map.values():
segment, status, in_scope = self.classify_segment(activity, params)
if not in_scope:
continue
if segment != "NEW":
continue
data = MemberNewconvData(activity=activity, status=status, segment=segment)
self._calculate_nci_scores(data, params)
newconv_list.append(data)
if not newconv_list:
self.logger.warning("No new-member rows to calculate")
return {'status': 'skipped', 'reason': 'no_new_members'}
# 归一化 Display Score
raw_scores = [
(d.activity.member_id, d.raw_score)
for d in newconv_list
if d.raw_score is not None
]
if raw_scores:
use_smoothing = int(params.get('use_smoothing', 1)) == 1
total_score_map = self._normalize_score_pairs(
raw_scores,
params=params,
site_id=site_id,
use_smoothing=use_smoothing,
)
for data in newconv_list:
if data.activity.member_id in total_score_map:
data.display_score = total_score_map[data.activity.member_id]
raw_scores_welcome = [
(d.activity.member_id, d.raw_score_welcome)
for d in newconv_list
if d.raw_score_welcome is not None
]
welcome_score_map = self._normalize_score_pairs(
raw_scores_welcome,
params=params,
site_id=site_id,
use_smoothing=False,
)
for data in newconv_list:
if data.activity.member_id in welcome_score_map:
data.display_score_welcome = welcome_score_map[data.activity.member_id]
raw_scores_convert = [
(d.activity.member_id, d.raw_score_convert)
for d in newconv_list
if d.raw_score_convert is not None
]
convert_score_map = self._normalize_score_pairs(
raw_scores_convert,
params=params,
site_id=site_id,
use_smoothing=False,
)
for data in newconv_list:
if data.activity.member_id in convert_score_map:
data.display_score_convert = convert_score_map[data.activity.member_id]
# 保存分位点历史
all_raw = [float(score) for _, score in raw_scores]
q_l, q_u = self.calculate_percentiles(
all_raw,
int(params['percentile_lower']),
int(params['percentile_upper'])
)
if use_smoothing:
smoothed_l, smoothed_u = self._apply_ewma_smoothing(site_id, q_l, q_u)
else:
smoothed_l, smoothed_u = q_l, q_u
self.save_percentile_history(
site_id=site_id,
percentile_5=q_l,
percentile_95=q_u,
percentile_5_smoothed=smoothed_l,
percentile_95_smoothed=smoothed_u,
record_count=len(all_raw),
min_raw=min(all_raw),
max_raw=max(all_raw),
avg_raw=sum(all_raw) / len(all_raw)
)
inserted = self._save_newconv_data(newconv_list)
self.logger.info("NCI calculation finished, inserted %d rows", inserted)
return {
'status': 'success',
'member_count': len(newconv_list),
'records_inserted': inserted
}
def _calculate_nci_scores(self, data: MemberNewconvData, params: Dict[str, float]) -> None:
"""计算 NCI 分项与 Raw Score"""
activity = data.activity
# 1) 紧迫度
no_touch_days = float(params['no_touch_days_new'])
t2_target_days = float(params['t2_target_days'])
t2_max_days = t2_target_days * 2.0
if t2_max_days <= no_touch_days:
data.need_new = 0.0
else:
data.need_new = self._clip(
(activity.t_v - no_touch_days) / (t2_max_days - no_touch_days),
0.0, 1.0
)
# 2) Salvage30-60天线性衰减
salvage_start = float(params['salvage_start'])
salvage_end = float(params['salvage_end'])
if salvage_end <= salvage_start:
data.salvage_new = 0.0
elif activity.t_a <= salvage_start:
data.salvage_new = 1.0
elif activity.t_a >= salvage_end:
data.salvage_new = 0.0
else:
data.salvage_new = (salvage_end - activity.t_a) / (salvage_end - salvage_start)
# 3) 充值未回访压力
if activity.recharge_unconsumed == 1:
data.recharge_new = self.decay(activity.t_r, params['h_recharge'])
else:
data.recharge_new = 0.0
# 4) 价值分
m0 = float(params['amount_base_M0'])
b0 = float(params['balance_base_B0'])
spend_score = math.log1p(activity.spend_180d / m0) if m0 > 0 else 0.0
bal_score = math.log1p(activity.sv_balance / b0) if b0 > 0 else 0.0
data.value_new = float(params['value_w_spend']) * spend_score + float(params['value_w_bal']) * bal_score
# 5) 欢迎建联分:优先首访后立即触达
welcome_window_days = float(params.get('welcome_window_days', 3))
data.welcome_new = 0.0
if welcome_window_days > 0 and activity.visits_total <= 1 and activity.t_v <= welcome_window_days:
data.welcome_new = self._clip(1.0 - (activity.t_v / welcome_window_days), 0.0, 1.0)
# 6) 抑制高活跃新客在转化召回排名中的权重
active_visit_threshold = int(params.get('active_new_visit_threshold_14d', 2))
active_recency_days = float(params.get('active_new_recency_days', 7))
active_penalty = float(params.get('active_new_penalty', 0.2))
if activity.visits_14d >= active_visit_threshold and activity.t_v <= active_recency_days:
active_multiplier = self._clip(active_penalty, 0.0, 1.0)
else:
active_multiplier = 1.0
# 7) 价值/充值分主要在进入免打扰窗口后生效
if no_touch_days > 0:
touch_multiplier = self._clip(activity.t_v / no_touch_days, 0.0, 1.0)
else:
touch_multiplier = 1.0
data.raw_score_welcome = float(params.get('w_welcome', 1.0)) * data.welcome_new
data.raw_score_convert = active_multiplier * (
float(params['w_need']) * (data.need_new * data.salvage_new)
+ float(params['w_re']) * data.recharge_new * touch_multiplier
+ float(params['w_value']) * data.value_new * touch_multiplier
)
data.raw_score_welcome = max(0.0, data.raw_score_welcome)
data.raw_score_convert = max(0.0, data.raw_score_convert)
data.raw_score = data.raw_score_welcome + data.raw_score_convert
if data.raw_score < 0:
data.raw_score = 0.0
def _save_newconv_data(self, data_list: List[MemberNewconvData]) -> int:
"""保存 NCI 数据"""
if not data_list:
return 0
site_id = data_list[0].activity.site_id
# 按门店全量刷新,避免因分群变化导致过期数据残留。
delete_sql = """
DELETE FROM billiards_dws.dws_member_newconv_index
WHERE site_id = %s
"""
with self.db.conn.cursor() as cur:
cur.execute(delete_sql, (site_id,))
insert_sql = """
INSERT INTO billiards_dws.dws_member_newconv_index (
site_id, tenant_id, member_id,
status, segment,
member_create_time, first_visit_time, last_visit_time, last_recharge_time,
t_v, t_r, t_a,
visits_14d, visits_60d, visits_total,
spend_30d, spend_180d, sv_balance, recharge_60d_amt,
interval_count,
need_new, salvage_new, recharge_new, value_new,
welcome_new,
raw_score_welcome, raw_score_convert, raw_score,
display_score_welcome, display_score_convert, display_score,
last_wechat_touch_time,
calc_time, created_at, updated_at
) VALUES (
%s, %s, %s,
%s, %s,
%s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s,
%s, %s, %s, %s,
%s,
%s, %s, %s,
%s, %s, %s,
%s,
NOW(), NOW(), NOW()
)
"""
inserted = 0
with self.db.conn.cursor() as cur:
for data in data_list:
activity = data.activity
cur.execute(insert_sql, (
activity.site_id, activity.tenant_id, activity.member_id,
data.status, data.segment,
activity.member_create_time, activity.first_visit_time, activity.last_visit_time, activity.last_recharge_time,
activity.t_v, activity.t_r, activity.t_a,
activity.visits_14d, activity.visits_60d, activity.visits_total,
activity.spend_30d, activity.spend_180d, activity.sv_balance, activity.recharge_60d_amt,
activity.interval_count,
data.need_new, data.salvage_new, data.recharge_new, data.value_new,
data.welcome_new,
data.raw_score_welcome, data.raw_score_convert, data.raw_score,
data.display_score_welcome, data.display_score_convert, data.display_score,
None,
))
inserted += cur.rowcount
self.db.conn.commit()
return inserted
def _clip(self, value: float, low: float, high: float) -> float:
return max(low, min(high, value))
def _map_compression(self, params: Dict[str, float]) -> str:
mode = int(params.get('compression_mode', 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
def _normalize_score_pairs(
self,
raw_scores: List[tuple[int, Optional[float]]],
params: Dict[str, float],
site_id: int,
use_smoothing: bool,
) -> Dict[int, float]:
valid_scores = [(member_id, float(score)) for member_id, score in raw_scores if score is not None]
if not valid_scores:
return {}
# 全为0时直接返回避免 MinMax 归一化退化
if all(abs(score) <= 1e-9 for _, score in valid_scores):
return {member_id: 0.0 for member_id, _ in valid_scores}
compression = self._map_compression(params)
normalized = self.batch_normalize_to_display(
valid_scores,
compression=compression,
percentile_lower=int(params['percentile_lower']),
percentile_upper=int(params['percentile_upper']),
use_smoothing=use_smoothing,
site_id=site_id
)
return {member_id: display for member_id, _, display in normalized}
__all__ = ['NewconvIndexTask']

View File

@@ -0,0 +1,695 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 删除 _apply_last_touch_ml 方法及 source_mode/recharge_attribute_hours 参数;
# 更新 docstring 移除 last-touch 备用路径描述;
# Prompt: "ML 只用人工台账,删除所有 last-touch 备用路径"
"""
关系指数任务RS/OS/MS/ML
设计说明:
1. 单任务一次产出 RS / OS / MS / ML写入统一关系表
2. RS/MS 复用服务日志 + 会话合并口径;
3. ML 以人工台账窄表为唯一真源;
4. RS/MS/ML 的 display 映射按 index_type 隔离分位历史。
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import CourseType, TaskContext
@dataclass
class ServiceSession:
"""合并后的服务会话。"""
session_start: datetime
session_end: datetime
total_duration_minutes: int
course_weight: float
is_incentive: bool
@dataclass
class RelationPairMetrics:
"""单个 member-assistant 关系对的计算指标。"""
site_id: int
tenant_id: int
member_id: int
assistant_id: int
sessions: List[ServiceSession] = field(default_factory=list)
days_since_last_session: Optional[int] = None
session_count: int = 0
total_duration_minutes: int = 0
basic_session_count: int = 0
incentive_session_count: int = 0
rs_f: float = 0.0
rs_d: float = 0.0
rs_r: float = 0.0
rs_raw: float = 0.0
rs_display: float = 0.0
ms_f_short: float = 0.0
ms_f_long: float = 0.0
ms_raw: float = 0.0
ms_display: float = 0.0
ml_raw: float = 0.0
ml_display: float = 0.0
ml_order_count: int = 0
ml_allocated_amount: float = 0.0
os_share: float = 0.0
os_label: str = "POOL"
os_rank: Optional[int] = None
class RelationIndexTask(BaseIndexTask):
"""关系指数任务:单任务产出 RS / OS / MS / ML。"""
INDEX_TYPE = "RS"
DEFAULT_PARAMS_RS: Dict[str, float] = {
"lookback_days": 60,
"session_merge_hours": 4,
"incentive_weight": 1.5,
"halflife_session": 14.0,
"halflife_last": 10.0,
"weight_f": 1.0,
"weight_d": 0.7,
"gate_alpha": 0.6,
"percentile_lower": 5.0,
"percentile_upper": 95.0,
"compression_mode": 1.0,
"use_smoothing": 1.0,
"ewma_alpha": 0.2,
}
DEFAULT_PARAMS_OS: Dict[str, float] = {
"min_rs_raw_for_ownership": 0.05,
"min_total_rs_raw": 0.10,
"ownership_main_threshold": 0.60,
"ownership_comanage_threshold": 0.35,
"ownership_gap_threshold": 0.15,
"eps": 1e-6,
}
DEFAULT_PARAMS_MS: Dict[str, float] = {
"lookback_days": 60,
"session_merge_hours": 4,
"incentive_weight": 1.5,
"halflife_short": 7.0,
"halflife_long": 30.0,
"eps": 1e-6,
"percentile_lower": 5.0,
"percentile_upper": 95.0,
"compression_mode": 1.0,
"use_smoothing": 1.0,
"ewma_alpha": 0.2,
}
# CHANGE 2026-02-13 | intent: ML 仅使用人工台账,移除 source_mode / recharge_attribute_hours
DEFAULT_PARAMS_ML: Dict[str, float] = {
"lookback_days": 60,
"amount_base": 500.0,
"halflife_recharge": 21.0,
"percentile_lower": 5.0,
"percentile_upper": 95.0,
"compression_mode": 1.0,
"use_smoothing": 1.0,
"ewma_alpha": 0.2,
}
def get_task_code(self) -> str:
return "DWS_RELATION_INDEX"
def get_target_table(self) -> str:
return "dws_member_assistant_relation_index"
def get_primary_keys(self) -> List[str]:
return ["site_id", "member_id", "assistant_id"]
def get_index_type(self) -> str:
# 多指数任务保留一个默认 index_type调用处应显式传 RS/MS/ML
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
self.logger.info("开始计算关系指数RS/OS/MS/ML")
site_id = self._get_site_id(context)
tenant_id = self._get_tenant_id()
now = datetime.now(self.tz)
params_rs = self._load_params("RS", self.DEFAULT_PARAMS_RS)
params_os = self._load_params("OS", self.DEFAULT_PARAMS_OS)
params_ms = self._load_params("MS", self.DEFAULT_PARAMS_MS)
params_ml = self._load_params("ML", self.DEFAULT_PARAMS_ML)
service_lookback_days = max(
int(params_rs.get("lookback_days", 60)),
int(params_ms.get("lookback_days", 60)),
)
service_start = now - timedelta(days=service_lookback_days)
merge_hours = max(
int(params_rs.get("session_merge_hours", 4)),
int(params_ms.get("session_merge_hours", 4)),
)
raw_services = self._extract_service_records(site_id, service_start, now)
pair_map = self._group_and_merge_sessions(
raw_services=raw_services,
merge_hours=merge_hours,
incentive_weight=max(
float(params_rs.get("incentive_weight", 1.5)),
float(params_ms.get("incentive_weight", 1.5)),
),
now=now,
site_id=site_id,
tenant_id=tenant_id,
)
self.logger.info("服务关系对数量: %d", len(pair_map))
self._calculate_rs(pair_map, params_rs, now)
self._calculate_ms(pair_map, params_ms, now)
self._calculate_ml(pair_map, params_ml, site_id, now)
self._calculate_os(pair_map, params_os)
self._apply_display_scores(pair_map, params_rs, params_ms, params_ml, site_id)
inserted = self._save_relation_rows(site_id, list(pair_map.values()))
self.logger.info("关系指数计算完成,写入 %d 条记录", inserted)
return {
"status": "SUCCESS",
"records_inserted": inserted,
"pair_count": len(pair_map),
}
def _load_params(self, index_type: str, defaults: Dict[str, float]) -> Dict[str, float]:
params = dict(defaults)
params.update(self.load_index_parameters(index_type=index_type))
return params
def _extract_service_records(
self,
site_id: int,
start_datetime: datetime,
end_datetime: datetime,
) -> List[Dict[str, Any]]:
"""提取服务记录。"""
sql = """
SELECT
s.tenant_member_id AS member_id,
d.assistant_id AS assistant_id,
s.start_use_time AS start_time,
s.last_use_time AS end_time,
COALESCE(s.income_seconds, 0) / 60 AS duration_minutes,
s.skill_id
FROM billiards_dwd.dwd_assistant_service_log s
JOIN billiards_dwd.dim_assistant d
ON s.user_id = d.user_id
AND d.scd2_is_current = 1
AND COALESCE(d.is_delete, 0) = 0
WHERE s.site_id = %s
AND s.tenant_member_id > 0
AND s.user_id > 0
AND s.is_delete = 0
AND s.last_use_time >= %s
AND s.last_use_time < %s
ORDER BY s.tenant_member_id, d.assistant_id, s.start_use_time
"""
rows = self.db.query(sql, (site_id, start_datetime, end_datetime))
return [dict(row) for row in (rows or [])]
def _group_and_merge_sessions(
self,
*,
raw_services: List[Dict[str, Any]],
merge_hours: int,
incentive_weight: float,
now: datetime,
site_id: int,
tenant_id: int,
) -> Dict[Tuple[int, int], RelationPairMetrics]:
"""按 (member_id, assistant_id) 分组并合并会话。"""
result: Dict[Tuple[int, int], RelationPairMetrics] = {}
if not raw_services:
return result
merge_threshold = timedelta(hours=max(0, merge_hours))
grouped: Dict[Tuple[int, int], List[Dict[str, Any]]] = {}
for row in raw_services:
member_id = int(row["member_id"])
assistant_id = int(row["assistant_id"])
grouped.setdefault((member_id, assistant_id), []).append(row)
for (member_id, assistant_id), records in grouped.items():
metrics = RelationPairMetrics(
site_id=site_id,
tenant_id=tenant_id,
member_id=member_id,
assistant_id=assistant_id,
)
sorted_records = sorted(records, key=lambda r: r["start_time"])
current: Optional[ServiceSession] = None
for svc in sorted_records:
start_time = svc["start_time"]
end_time = svc["end_time"]
duration = int(svc.get("duration_minutes") or 0)
skill_id = int(svc.get("skill_id") or 0)
course_type = self.get_course_type(skill_id)
is_incentive = course_type == CourseType.BONUS
weight = incentive_weight if is_incentive else 1.0
if current is None:
current = ServiceSession(
session_start=start_time,
session_end=end_time,
total_duration_minutes=duration,
course_weight=weight,
is_incentive=is_incentive,
)
continue
if start_time - current.session_end <= merge_threshold:
current.session_end = max(current.session_end, end_time)
current.total_duration_minutes += duration
current.course_weight = max(current.course_weight, weight)
current.is_incentive = current.is_incentive or is_incentive
else:
metrics.sessions.append(current)
current = ServiceSession(
session_start=start_time,
session_end=end_time,
total_duration_minutes=duration,
course_weight=weight,
is_incentive=is_incentive,
)
if current is not None:
metrics.sessions.append(current)
metrics.session_count = len(metrics.sessions)
metrics.total_duration_minutes = sum(s.total_duration_minutes for s in metrics.sessions)
metrics.basic_session_count = sum(1 for s in metrics.sessions if not s.is_incentive)
metrics.incentive_session_count = sum(1 for s in metrics.sessions if s.is_incentive)
if metrics.sessions:
last_session = max(metrics.sessions, key=lambda s: s.session_end)
metrics.days_since_last_session = (now - last_session.session_end).days
result[(member_id, assistant_id)] = metrics
return result
def _calculate_rs(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
now: datetime,
) -> None:
lookback_days = int(params.get("lookback_days", 60))
halflife_session = float(params.get("halflife_session", 14.0))
halflife_last = float(params.get("halflife_last", 10.0))
weight_f = float(params.get("weight_f", 1.0))
weight_d = float(params.get("weight_d", 0.7))
gate_alpha = max(0.0, float(params.get("gate_alpha", 0.6)))
for metrics in pair_map.values():
f_score = 0.0
d_score = 0.0
for session in metrics.sessions:
days_ago = min(
lookback_days,
max(0.0, (now - session.session_end).total_seconds() / 86400.0),
)
decay_factor = self.decay(days_ago, halflife_session)
f_score += session.course_weight * decay_factor
d_score += (
math.sqrt(max(session.total_duration_minutes, 0) / 60.0)
* session.course_weight
* decay_factor
)
if metrics.days_since_last_session is None:
r_score = 0.0
else:
r_score = self.decay(min(lookback_days, metrics.days_since_last_session), halflife_last)
base = weight_f * f_score + weight_d * d_score
gate = math.pow(r_score, gate_alpha) if r_score > 0 else 0.0
metrics.rs_f = f_score
metrics.rs_d = d_score
metrics.rs_r = r_score
metrics.rs_raw = max(0.0, base * gate)
def _calculate_ms(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
now: datetime,
) -> None:
lookback_days = int(params.get("lookback_days", 60))
halflife_short = float(params.get("halflife_short", 7.0))
halflife_long = float(params.get("halflife_long", 30.0))
eps = float(params.get("eps", 1e-6))
for metrics in pair_map.values():
f_short = 0.0
f_long = 0.0
for session in metrics.sessions:
days_ago = min(
lookback_days,
max(0.0, (now - session.session_end).total_seconds() / 86400.0),
)
f_short += session.course_weight * self.decay(days_ago, halflife_short)
f_long += session.course_weight * self.decay(days_ago, halflife_long)
ratio = (f_short + eps) / (f_long + eps)
metrics.ms_f_short = f_short
metrics.ms_f_long = f_long
metrics.ms_raw = max(0.0, self.safe_log(ratio, 0.0))
def _calculate_ml(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
site_id: int,
now: datetime,
) -> None:
lookback_days = int(params.get("lookback_days", 60))
amount_base = float(params.get("amount_base", 500.0))
halflife_recharge = float(params.get("halflife_recharge", 21.0))
start_time = now - timedelta(days=lookback_days)
# CHANGE 2026-02-13 | intent: ML 仅使用人工台账,移除 last-touch 备用路径
manual_rows = self._extract_manual_alloc(site_id, start_time, now)
for row in manual_rows:
member_id = int(row["member_id"])
assistant_id = int(row["assistant_id"])
key = (member_id, assistant_id)
if key not in pair_map:
pair_map[key] = RelationPairMetrics(
site_id=site_id,
tenant_id=pair_map[next(iter(pair_map))].tenant_id if pair_map else self._get_tenant_id(),
member_id=member_id,
assistant_id=assistant_id,
)
metrics = pair_map[key]
amount = float(row.get("allocated_amount") or 0.0)
pay_time = row.get("pay_time")
if amount <= 0 or pay_time is None:
continue
days_ago = min(lookback_days, max(0.0, (now - pay_time).total_seconds() / 86400.0))
metrics.ml_raw += math.log1p(amount / max(amount_base, 1e-6)) * self.decay(
days_ago,
halflife_recharge,
)
metrics.ml_order_count += 1
metrics.ml_allocated_amount += amount
def _extract_manual_alloc(
self,
site_id: int,
start_time: datetime,
end_time: datetime,
) -> List[Dict[str, Any]]:
sql = """
SELECT
member_id,
assistant_id,
pay_time,
allocated_amount
FROM billiards_dws.dws_ml_manual_order_alloc
WHERE site_id = %s
AND pay_time >= %s
AND pay_time < %s
"""
rows = self.db.query(sql, (site_id, start_time, end_time))
return [dict(row) for row in (rows or [])]
def _calculate_os(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params: Dict[str, float],
) -> None:
min_rs = float(params.get("min_rs_raw_for_ownership", 0.05))
min_total = float(params.get("min_total_rs_raw", 0.10))
main_threshold = float(params.get("ownership_main_threshold", 0.60))
comanage_threshold = float(params.get("ownership_comanage_threshold", 0.35))
gap_threshold = float(params.get("ownership_gap_threshold", 0.15))
member_groups: Dict[int, List[RelationPairMetrics]] = {}
for metrics in pair_map.values():
member_groups.setdefault(metrics.member_id, []).append(metrics)
for _, rows in member_groups.items():
eligible = [row for row in rows if row.rs_raw >= min_rs]
sum_rs = sum(row.rs_raw for row in eligible)
if sum_rs < min_total:
for row in rows:
row.os_share = 0.0
row.os_label = "UNASSIGNED"
row.os_rank = None
continue
for row in rows:
if row.rs_raw >= min_rs:
row.os_share = row.rs_raw / sum_rs
else:
row.os_share = 0.0
sorted_eligible = sorted(
eligible,
key=lambda item: (
-item.os_share,
-item.rs_raw,
item.days_since_last_session if item.days_since_last_session is not None else 10**9,
item.assistant_id,
),
)
for idx, row in enumerate(sorted_eligible, start=1):
row.os_rank = idx
top1 = sorted_eligible[0]
top2_share = sorted_eligible[1].os_share if len(sorted_eligible) > 1 else 0.0
gap = top1.os_share - top2_share
has_main = top1.os_share >= main_threshold and gap >= gap_threshold
if has_main:
for row in rows:
if row is top1:
row.os_label = "MAIN"
elif row.os_share >= comanage_threshold:
row.os_label = "COMANAGE"
else:
row.os_label = "POOL"
else:
for row in rows:
if row.os_share >= comanage_threshold and row.rs_raw >= min_rs:
row.os_label = "COMANAGE"
else:
row.os_label = "POOL"
# 非 eligible 不赋 rank
for row in rows:
if row.rs_raw < min_rs:
row.os_rank = None
def _apply_display_scores(
self,
pair_map: Dict[Tuple[int, int], RelationPairMetrics],
params_rs: Dict[str, float],
params_ms: Dict[str, float],
params_ml: Dict[str, float],
site_id: int,
) -> None:
pair_items = list(pair_map.items())
rs_map = self._normalize_and_record(
raw_pairs=[(key, item.rs_raw) for key, item in pair_items],
params=params_rs,
index_type="RS",
site_id=site_id,
)
ms_map = self._normalize_and_record(
raw_pairs=[(key, item.ms_raw) for key, item in pair_items],
params=params_ms,
index_type="MS",
site_id=site_id,
)
ml_map = self._normalize_and_record(
raw_pairs=[(key, item.ml_raw) for key, item in pair_items],
params=params_ml,
index_type="ML",
site_id=site_id,
)
for key, item in pair_items:
item.rs_display = rs_map.get(key, 0.0)
item.ms_display = ms_map.get(key, 0.0)
item.ml_display = ml_map.get(key, 0.0)
def _normalize_and_record(
self,
*,
raw_pairs: List[Tuple[Any, float]],
params: Dict[str, float],
index_type: str,
site_id: int,
) -> Dict[Any, float]:
if not raw_pairs:
return {}
if all(abs(score) <= 1e-9 for _, score in raw_pairs):
return {entity: 0.0 for entity, _ in raw_pairs}
percentile_lower = int(params.get("percentile_lower", 5))
percentile_upper = int(params.get("percentile_upper", 95))
use_smoothing = int(params.get("use_smoothing", 1)) == 1
compression = self._map_compression(params)
normalized = self.batch_normalize_to_display(
raw_scores=raw_pairs,
compression=compression,
percentile_lower=percentile_lower,
percentile_upper=percentile_upper,
use_smoothing=use_smoothing,
site_id=site_id,
index_type=index_type,
)
display_map = {entity: display for entity, _, display in normalized}
raw_values = [float(score) for _, score in raw_pairs]
q_l, q_u = self.calculate_percentiles(raw_values, percentile_lower, percentile_upper)
if use_smoothing:
smoothed_l, smoothed_u = self._apply_ewma_smoothing(
site_id=site_id,
current_p5=q_l,
current_p95=q_u,
index_type=index_type,
)
else:
smoothed_l, smoothed_u = q_l, q_u
self.save_percentile_history(
site_id=site_id,
percentile_5=q_l,
percentile_95=q_u,
percentile_5_smoothed=smoothed_l,
percentile_95_smoothed=smoothed_u,
record_count=len(raw_values),
min_raw=min(raw_values),
max_raw=max(raw_values),
avg_raw=sum(raw_values) / len(raw_values),
index_type=index_type,
)
return display_map
@staticmethod
def _map_compression(params: Dict[str, float]) -> str:
mode = int(params.get("compression_mode", 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
def _save_relation_rows(self, site_id: int, rows: List[RelationPairMetrics]) -> int:
with self.db.conn.cursor() as cur:
cur.execute(
"DELETE FROM billiards_dws.dws_member_assistant_relation_index WHERE site_id = %s",
(site_id,),
)
if not rows:
self.db.conn.commit()
return 0
insert_sql = """
INSERT INTO billiards_dws.dws_member_assistant_relation_index (
site_id, tenant_id, member_id, assistant_id,
session_count, total_duration_minutes, basic_session_count, incentive_session_count,
days_since_last_session,
rs_f, rs_d, rs_r, rs_raw, rs_display,
os_share, os_label, os_rank,
ms_f_short, ms_f_long, ms_raw, ms_display,
ml_order_count, ml_allocated_amount, ml_raw, ml_display,
calc_time, created_at, updated_at
) VALUES (
%s, %s, %s, %s,
%s, %s, %s, %s,
%s,
%s, %s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s, %s, %s, %s,
NOW(), NOW(), NOW()
)
"""
inserted = 0
for row in rows:
cur.execute(
insert_sql,
(
row.site_id,
row.tenant_id,
row.member_id,
row.assistant_id,
row.session_count,
row.total_duration_minutes,
row.basic_session_count,
row.incentive_session_count,
row.days_since_last_session,
row.rs_f,
row.rs_d,
row.rs_r,
row.rs_raw,
row.rs_display,
row.os_share,
row.os_label,
row.os_rank,
row.ms_f_short,
row.ms_f_long,
row.ms_raw,
row.ms_display,
row.ml_order_count,
row.ml_allocated_amount,
row.ml_raw,
row.ml_display,
),
)
inserted += max(cur.rowcount, 0)
self.db.conn.commit()
return inserted
def _get_site_id(self, context: Optional[TaskContext]) -> int:
if context and getattr(context, "store_id", None):
return int(context.store_id)
site_id = self.config.get("app.default_site_id") or self.config.get("app.store_id")
if site_id is not None:
return int(site_id)
sql = "SELECT DISTINCT site_id FROM billiards_dwd.dwd_assistant_service_log WHERE site_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
return int(dict(rows[0]).get("site_id") or 0)
self.logger.warning("无法确定门店ID使用 0 继续执行")
return 0
def _get_tenant_id(self) -> int:
tenant_id = self.config.get("app.tenant_id")
if tenant_id is not None:
return int(tenant_id)
sql = "SELECT DISTINCT tenant_id FROM billiards_dwd.dwd_assistant_service_log WHERE tenant_id IS NOT NULL LIMIT 1"
rows = self.db.query(sql)
if rows:
return int(dict(rows[0]).get("tenant_id") or 0)
self.logger.warning("无法确定租户ID使用 0 继续执行")
return 0
__all__ = ["RelationIndexTask", "RelationPairMetrics", "ServiceSession"]

View File

@@ -0,0 +1,408 @@
# -*- coding: utf-8 -*-
# AI_CHANGELOG [2026-02-13] 修复 STOP_HIGH_BALANCE 会员不参与评分的逻辑缺陷;
# Prompt: "STOP_HIGH_BALANCE 应该参与 WBI 评分"
"""
老客挽回指数WBI计算任务。"""
from __future__ import annotations
import math
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Any, Dict, List, Optional, Tuple
from .member_index_base import MemberActivityData, MemberIndexBaseTask
from ..base_dws_task import TaskContext
@dataclass
class MemberWinbackData:
activity: MemberActivityData
status: str
segment: str
overdue_old: float = 0.0
overdue_cdf_p: float = 0.0
drop_old: float = 0.0
recharge_old: float = 0.0
value_old: float = 0.0
ideal_interval_days: Optional[float] = None
ideal_next_visit_date: Optional[date] = None
raw_score: Optional[float] = None
display_score: Optional[float] = None
class WinbackIndexTask(MemberIndexBaseTask):
"""老客挽回指数WBI计算任务。"""
INDEX_TYPE = "WBI"
DEFAULT_PARAMS = {
# 通用参数
'lookback_days_recency': 60,
'visit_lookback_days': 180,
'percentile_lower': 5,
'percentile_upper': 95,
'compression_mode': 0,
'use_smoothing': 1,
'ewma_alpha': 0.2,
# 分流参数
'new_visit_threshold': 2,
'new_days_threshold': 30,
'recharge_recent_days': 14,
'new_recharge_max_visits': 10,
'recency_hard_floor_days': 14,
'recency_gate_days': 14,
'recency_gate_slope_days': 3,
# WBI参数
'overdue_alpha': 2.0,
'overdue_weight_halflife_days': 30,
'overdue_weight_blend_min_samples': 8,
'h_recharge': 7,
'amount_base_M0': 300,
'balance_base_B0': 500,
'value_w_spend': 1.0,
'value_w_bal': 1.0,
'w_over': 2.0,
'w_drop': 1.0,
'w_re': 0.4,
'w_value': 1.2,
# STOP高余额例外默认关闭
'enable_stop_high_balance_exception': 0,
'high_balance_threshold': 1000,
}
def get_task_code(self) -> str:
return "DWS_WINBACK_INDEX"
def get_target_table(self) -> str:
return "dws_member_winback_index"
def get_primary_keys(self) -> List[str]:
return ['site_id', 'member_id']
def get_index_type(self) -> str:
return self.INDEX_TYPE
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
"""执行 WBI 计算"""
self.logger.info("开始计算老客挽回指数 (WBI)")
site_id = self._get_site_id(context)
tenant_id = self._get_tenant_id()
params = self._load_params()
activity_map = self._build_member_activity(site_id, tenant_id, params)
if not activity_map:
self.logger.warning("No member activity data available; skip calculation")
return {'status': 'skipped', 'reason': 'no_data'}
winback_list: List[MemberWinbackData] = []
for activity in activity_map.values():
segment, status, in_scope = self.classify_segment(activity, params)
if not in_scope:
continue
if segment != "OLD" and status != "STOP_HIGH_BALANCE":
continue
data = MemberWinbackData(activity=activity, status=status, segment=segment)
# CHANGE 2026-02-13 | intent: STOP_HIGH_BALANCE 也参与评分
# 原先只对 segment=="OLD" 评分STOP_HIGH_BALANCE (segment=="STOP")
# 进入范围却不评分raw_score=NULL 无运营价值。
# 这些会员具备完整特征数据,应同等评分。
if segment == "OLD" or status == "STOP_HIGH_BALANCE":
self._calculate_wbi_scores(data, params)
winback_list.append(data)
if not winback_list:
self.logger.warning("No old-member rows to calculate")
return {'status': 'skipped', 'reason': 'no_old_members'}
# 归一化 Display Score
raw_scores = [
(d.activity.member_id, d.raw_score)
for d in winback_list
if d.raw_score is not None
]
if raw_scores:
compression = self._map_compression(params)
use_smoothing = int(params.get('use_smoothing', 1)) == 1
normalized = self.batch_normalize_to_display(
raw_scores,
compression=compression,
percentile_lower=int(params['percentile_lower']),
percentile_upper=int(params['percentile_upper']),
use_smoothing=use_smoothing,
site_id=site_id
)
score_map = {member_id: display for member_id, _, display in normalized}
for data in winback_list:
if data.activity.member_id in score_map:
data.display_score = score_map[data.activity.member_id]
# 保存分位点历史
all_raw = [float(score) for _, score in raw_scores]
q_l, q_u = self.calculate_percentiles(
all_raw,
int(params['percentile_lower']),
int(params['percentile_upper'])
)
if use_smoothing:
smoothed_l, smoothed_u = self._apply_ewma_smoothing(site_id, q_l, q_u)
else:
smoothed_l, smoothed_u = q_l, q_u
self.save_percentile_history(
site_id=site_id,
percentile_5=q_l,
percentile_95=q_u,
percentile_5_smoothed=smoothed_l,
percentile_95_smoothed=smoothed_u,
record_count=len(all_raw),
min_raw=min(all_raw),
max_raw=max(all_raw),
avg_raw=sum(all_raw) / len(all_raw)
)
inserted = self._save_winback_data(winback_list)
self.logger.info("WBI calculation finished, inserted %d rows", inserted)
return {
'status': 'success',
'member_count': len(winback_list),
'records_inserted': inserted
}
def _weighted_cdf(
self,
samples: List[Tuple[float, int]],
t_v: float,
halflife_days: float,
blend_min_samples: int,
) -> float:
if not samples:
return 0.5
if halflife_days <= 0:
p_equal = sum(1.0 for interval, _ in samples if interval <= t_v) / len(samples)
return self._clip(p_equal, 0.0, 1.0)
ln2 = math.log(2.0)
weighted_hit = 0.0
weight_sum = 0.0
equal_hit = 0.0
for interval, age_days in samples:
weight = math.exp(-ln2 * float(age_days) / halflife_days)
indicator = 1.0 if interval <= t_v else 0.0
weighted_hit += weight * indicator
weight_sum += weight
equal_hit += indicator
p_weighted = 0.5 if weight_sum <= 0 else (weighted_hit / weight_sum)
p_equal = equal_hit / len(samples)
lam = min(1.0, float(len(samples)) / float(max(1, blend_min_samples)))
p_final = lam * p_weighted + (1.0 - lam) * p_equal
return self._clip(p_final, 0.0, 1.0)
def _weighted_quantile(
self,
samples: List[Tuple[float, int]],
quantile: float,
halflife_days: float,
blend_min_samples: int,
) -> Optional[float]:
if not samples:
return None
q = self._clip(quantile, 0.0, 1.0)
equal_weight = 1.0 / float(len(samples))
if halflife_days <= 0:
weighted = [(interval, equal_weight) for interval, _ in samples]
else:
ln2 = math.log(2.0)
raw_weighted: List[Tuple[float, float]] = []
total = 0.0
for interval, age_days in samples:
w = math.exp(-ln2 * float(age_days) / halflife_days)
raw_weighted.append((interval, w))
total += w
if total <= 0:
weighted = [(interval, equal_weight) for interval, _ in samples]
else:
weighted = [(interval, w / total) for interval, w in raw_weighted]
# 对小样本混合加权分布与等权分布。
lam = min(1.0, float(len(samples)) / float(max(1, blend_min_samples)))
blended: List[Tuple[float, float]] = []
for (interval_w, w), (interval_e, _) in zip(weighted, samples):
_ = interval_e # keep tuple alignment explicit
blended_weight = lam * w + (1.0 - lam) * equal_weight
blended.append((interval_w, blended_weight))
blended.sort(key=lambda item: item[0])
cumulative = 0.0
for interval, weight in blended:
cumulative += weight
if cumulative >= q:
return float(interval)
return float(blended[-1][0])
def _calculate_wbi_scores(self, data: MemberWinbackData, params: Dict[str, float]) -> None:
"""计算 WBI 分项与 Raw Score"""
activity = data.activity
# 1) 超期紧急性基于近期加权经验CDF
overdue_alpha = float(params['overdue_alpha'])
half_life_days = float(params.get('overdue_weight_halflife_days', 30))
blend_min_samples = int(params.get('overdue_weight_blend_min_samples', 8))
if activity.interval_count <= 0:
p = 0.5
ideal_interval = None
else:
if len(activity.interval_ages_days) == activity.interval_count:
samples = list(zip(activity.intervals, activity.interval_ages_days))
else:
samples = [(interval, 0) for interval in activity.intervals]
p = self._weighted_cdf(
samples=samples,
t_v=activity.t_v,
halflife_days=half_life_days,
blend_min_samples=blend_min_samples,
)
ideal_interval = self._weighted_quantile(
samples=samples,
quantile=0.5,
halflife_days=half_life_days,
blend_min_samples=blend_min_samples,
)
data.overdue_cdf_p = p
data.overdue_old = math.pow(p, overdue_alpha)
data.ideal_interval_days = ideal_interval
if ideal_interval is not None and activity.last_visit_time is not None:
ideal_days = max(0, int(round(ideal_interval)))
data.ideal_next_visit_date = activity.last_visit_time.date() + timedelta(days=ideal_days)
else:
data.ideal_next_visit_date = None
# 2) 降频分
expected14 = activity.visits_60d * 14.0 / 60.0
data.drop_old = self._clip((expected14 - activity.visits_14d) / (expected14 + 1), 0.0, 1.0)
# 3) 充值未回访压力
if activity.recharge_unconsumed == 1:
data.recharge_old = self.decay(activity.t_r, params['h_recharge'])
else:
data.recharge_old = 0.0
# 4) 价值分
m0 = float(params['amount_base_M0'])
b0 = float(params['balance_base_B0'])
spend_score = math.log1p(activity.spend_180d / m0) if m0 > 0 else 0.0
bal_score = math.log1p(activity.sv_balance / b0) if b0 > 0 else 0.0
data.value_old = float(params['value_w_spend']) * spend_score + float(params['value_w_bal']) * bal_score
data.raw_score = (
float(params['w_over']) * data.overdue_old
+ float(params['w_drop']) * data.drop_old
+ float(params['w_re']) * data.recharge_old
+ float(params['w_value']) * data.value_old
)
hard_floor_days = float(params.get('recency_hard_floor_days', 0))
gate_days = float(params.get('recency_gate_days', 14))
slope_days = float(params.get('recency_gate_slope_days', 3))
if hard_floor_days > 0 and activity.t_v < hard_floor_days:
suppression = 0.0
elif slope_days <= 0:
suppression = 1.0 if activity.t_v >= gate_days else 0.0
else:
x = (activity.t_v - gate_days) / slope_days
x = self._clip(x, -60.0, 60.0)
suppression = 1.0 / (1.0 + math.exp(-x))
data.raw_score *= suppression
# 限制在 0 以上
if data.raw_score < 0:
data.raw_score = 0.0
def _save_winback_data(self, data_list: List[MemberWinbackData]) -> int:
"""保存 WBI 数据"""
if not data_list:
return 0
site_id = data_list[0].activity.site_id
# 按门店全量刷新,避免因分群变化导致过期数据残留。
delete_sql = """
DELETE FROM billiards_dws.dws_member_winback_index
WHERE site_id = %s
"""
with self.db.conn.cursor() as cur:
cur.execute(delete_sql, (site_id,))
insert_sql = """
INSERT INTO billiards_dws.dws_member_winback_index (
site_id, tenant_id, member_id,
status, segment,
member_create_time, first_visit_time, last_visit_time, last_recharge_time,
t_v, t_r, t_a,
visits_14d, visits_60d, visits_total,
spend_30d, spend_180d, sv_balance, recharge_60d_amt,
interval_count,
overdue_old, overdue_cdf_p, drop_old, recharge_old, value_old,
ideal_interval_days, ideal_next_visit_date,
raw_score, display_score,
last_wechat_touch_time,
calc_time, created_at, updated_at
) VALUES (
%s, %s, %s,
%s, %s,
%s, %s, %s, %s,
%s, %s, %s,
%s, %s, %s,
%s, %s, %s, %s,
%s,
%s, %s, %s, %s, %s,
%s, %s,
%s, %s,
%s,
NOW(), NOW(), NOW()
)
"""
inserted = 0
with self.db.conn.cursor() as cur:
for data in data_list:
activity = data.activity
cur.execute(insert_sql, (
activity.site_id, activity.tenant_id, activity.member_id,
data.status, data.segment,
activity.member_create_time, activity.first_visit_time, activity.last_visit_time, activity.last_recharge_time,
activity.t_v, activity.t_r, activity.t_a,
activity.visits_14d, activity.visits_60d, activity.visits_total,
activity.spend_30d, activity.spend_180d, activity.sv_balance, activity.recharge_60d_amt,
activity.interval_count,
data.overdue_old, data.overdue_cdf_p, data.drop_old, data.recharge_old, data.value_old,
data.ideal_interval_days, data.ideal_next_visit_date,
data.raw_score, data.display_score,
None,
))
inserted += cur.rowcount
self.db.conn.commit()
return inserted
def _clip(self, value: float, low: float, high: float) -> float:
return max(low, min(high, value))
def _map_compression(self, params: Dict[str, float]) -> str:
mode = int(params.get('compression_mode', 0))
if mode == 1:
return "log1p"
if mode == 2:
return "asinh"
return "none"
__all__ = ['WinbackIndexTask']