# -*- coding: utf-8 -*- """ 关系指数任务(RS/OS/MS/ML)。 设计说明: 1. 单任务一次产出 RS / OS / MS / ML,写入统一关系表; 2. RS/MS 复用服务日志 + 会话合并口径; 3. ML 以人工台账窄表为唯一真源,last-touch 仅保留备用路径(默认关闭); 4. RS/MS/ML 的 display 映射按 index_type 隔离分位历史。 """ from __future__ import annotations import math from dataclasses import dataclass, field from datetime import datetime, timedelta from decimal import Decimal from typing import Any, Dict, List, Optional, Tuple from .base_index_task import BaseIndexTask from ..base_dws_task import CourseType, TaskContext @dataclass class ServiceSession: """合并后的服务会话。""" session_start: datetime session_end: datetime total_duration_minutes: int course_weight: float is_incentive: bool @dataclass class RelationPairMetrics: """单个 member-assistant 关系对的计算指标。""" site_id: int tenant_id: int member_id: int assistant_id: int sessions: List[ServiceSession] = field(default_factory=list) days_since_last_session: Optional[int] = None session_count: int = 0 total_duration_minutes: int = 0 basic_session_count: int = 0 incentive_session_count: int = 0 rs_f: float = 0.0 rs_d: float = 0.0 rs_r: float = 0.0 rs_raw: float = 0.0 rs_display: float = 0.0 ms_f_short: float = 0.0 ms_f_long: float = 0.0 ms_raw: float = 0.0 ms_display: float = 0.0 ml_raw: float = 0.0 ml_display: float = 0.0 ml_order_count: int = 0 ml_allocated_amount: float = 0.0 os_share: float = 0.0 os_label: str = "POOL" os_rank: Optional[int] = None class RelationIndexTask(BaseIndexTask): """关系指数任务:单任务产出 RS / OS / MS / ML。""" INDEX_TYPE = "RS" DEFAULT_PARAMS_RS: Dict[str, float] = { "lookback_days": 60, "session_merge_hours": 4, "incentive_weight": 1.5, "halflife_session": 14.0, "halflife_last": 10.0, "weight_f": 1.0, "weight_d": 0.7, "gate_alpha": 0.6, "percentile_lower": 5.0, "percentile_upper": 95.0, "compression_mode": 1.0, "use_smoothing": 1.0, "ewma_alpha": 0.2, } DEFAULT_PARAMS_OS: Dict[str, float] = { "min_rs_raw_for_ownership": 0.05, "min_total_rs_raw": 0.10, "ownership_main_threshold": 0.60, "ownership_comanage_threshold": 0.35, "ownership_gap_threshold": 0.15, "eps": 1e-6, } DEFAULT_PARAMS_MS: Dict[str, float] = { "lookback_days": 60, "session_merge_hours": 4, "incentive_weight": 1.5, "halflife_short": 7.0, "halflife_long": 30.0, "eps": 1e-6, "percentile_lower": 5.0, "percentile_upper": 95.0, "compression_mode": 1.0, "use_smoothing": 1.0, "ewma_alpha": 0.2, } DEFAULT_PARAMS_ML: Dict[str, float] = { "lookback_days": 60, "source_mode": 0.0, # 0=manual_only, 1=last_touch_fallback "recharge_attribute_hours": 1.0, "amount_base": 500.0, "halflife_recharge": 21.0, "percentile_lower": 5.0, "percentile_upper": 95.0, "compression_mode": 1.0, "use_smoothing": 1.0, "ewma_alpha": 0.2, } def get_task_code(self) -> str: return "DWS_RELATION_INDEX" def get_target_table(self) -> str: return "dws_member_assistant_relation_index" def get_primary_keys(self) -> List[str]: return ["site_id", "member_id", "assistant_id"] def get_index_type(self) -> str: # 多指数任务保留一个默认 index_type,调用处应显式传 RS/MS/ML return self.INDEX_TYPE def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]: self.logger.info("开始计算关系指数(RS/OS/MS/ML)") site_id = self._get_site_id(context) tenant_id = self._get_tenant_id() now = datetime.now(self.tz) params_rs = self._load_params("RS", self.DEFAULT_PARAMS_RS) params_os = self._load_params("OS", self.DEFAULT_PARAMS_OS) params_ms = self._load_params("MS", self.DEFAULT_PARAMS_MS) params_ml = self._load_params("ML", self.DEFAULT_PARAMS_ML) service_lookback_days = max( int(params_rs.get("lookback_days", 60)), int(params_ms.get("lookback_days", 60)), ) service_start = now - timedelta(days=service_lookback_days) merge_hours = max( int(params_rs.get("session_merge_hours", 4)), int(params_ms.get("session_merge_hours", 4)), ) raw_services = self._extract_service_records(site_id, service_start, now) pair_map = self._group_and_merge_sessions( raw_services=raw_services, merge_hours=merge_hours, incentive_weight=max( float(params_rs.get("incentive_weight", 1.5)), float(params_ms.get("incentive_weight", 1.5)), ), now=now, site_id=site_id, tenant_id=tenant_id, ) self.logger.info("服务关系对数量: %d", len(pair_map)) self._calculate_rs(pair_map, params_rs, now) self._calculate_ms(pair_map, params_ms, now) self._calculate_ml(pair_map, params_ml, site_id, now) self._calculate_os(pair_map, params_os) self._apply_display_scores(pair_map, params_rs, params_ms, params_ml, site_id) inserted = self._save_relation_rows(site_id, list(pair_map.values())) self.logger.info("关系指数计算完成,写入 %d 条记录", inserted) return { "status": "SUCCESS", "records_inserted": inserted, "pair_count": len(pair_map), } def _load_params(self, index_type: str, defaults: Dict[str, float]) -> Dict[str, float]: params = dict(defaults) params.update(self.load_index_parameters(index_type=index_type)) return params def _extract_service_records( self, site_id: int, start_datetime: datetime, end_datetime: datetime, ) -> List[Dict[str, Any]]: """提取服务记录。""" sql = """ SELECT s.tenant_member_id AS member_id, d.assistant_id AS assistant_id, s.start_use_time AS start_time, s.last_use_time AS end_time, COALESCE(s.income_seconds, 0) / 60 AS duration_minutes, s.skill_id FROM billiards_dwd.dwd_assistant_service_log s JOIN billiards_dwd.dim_assistant d ON s.user_id = d.user_id AND d.scd2_is_current = 1 AND COALESCE(d.is_delete, 0) = 0 WHERE s.site_id = %s AND s.tenant_member_id > 0 AND s.user_id > 0 AND s.is_delete = 0 AND s.last_use_time >= %s AND s.last_use_time < %s ORDER BY s.tenant_member_id, d.assistant_id, s.start_use_time """ rows = self.db.query(sql, (site_id, start_datetime, end_datetime)) return [dict(row) for row in (rows or [])] def _group_and_merge_sessions( self, *, raw_services: List[Dict[str, Any]], merge_hours: int, incentive_weight: float, now: datetime, site_id: int, tenant_id: int, ) -> Dict[Tuple[int, int], RelationPairMetrics]: """按 (member_id, assistant_id) 分组并合并会话。""" result: Dict[Tuple[int, int], RelationPairMetrics] = {} if not raw_services: return result merge_threshold = timedelta(hours=max(0, merge_hours)) grouped: Dict[Tuple[int, int], List[Dict[str, Any]]] = {} for row in raw_services: member_id = int(row["member_id"]) assistant_id = int(row["assistant_id"]) grouped.setdefault((member_id, assistant_id), []).append(row) for (member_id, assistant_id), records in grouped.items(): metrics = RelationPairMetrics( site_id=site_id, tenant_id=tenant_id, member_id=member_id, assistant_id=assistant_id, ) sorted_records = sorted(records, key=lambda r: r["start_time"]) current: Optional[ServiceSession] = None for svc in sorted_records: start_time = svc["start_time"] end_time = svc["end_time"] duration = int(svc.get("duration_minutes") or 0) skill_id = int(svc.get("skill_id") or 0) course_type = self.get_course_type(skill_id) is_incentive = course_type == CourseType.BONUS weight = incentive_weight if is_incentive else 1.0 if current is None: current = ServiceSession( session_start=start_time, session_end=end_time, total_duration_minutes=duration, course_weight=weight, is_incentive=is_incentive, ) continue if start_time - current.session_end <= merge_threshold: current.session_end = max(current.session_end, end_time) current.total_duration_minutes += duration current.course_weight = max(current.course_weight, weight) current.is_incentive = current.is_incentive or is_incentive else: metrics.sessions.append(current) current = ServiceSession( session_start=start_time, session_end=end_time, total_duration_minutes=duration, course_weight=weight, is_incentive=is_incentive, ) if current is not None: metrics.sessions.append(current) metrics.session_count = len(metrics.sessions) metrics.total_duration_minutes = sum(s.total_duration_minutes for s in metrics.sessions) metrics.basic_session_count = sum(1 for s in metrics.sessions if not s.is_incentive) metrics.incentive_session_count = sum(1 for s in metrics.sessions if s.is_incentive) if metrics.sessions: last_session = max(metrics.sessions, key=lambda s: s.session_end) metrics.days_since_last_session = (now - last_session.session_end).days result[(member_id, assistant_id)] = metrics return result def _calculate_rs( self, pair_map: Dict[Tuple[int, int], RelationPairMetrics], params: Dict[str, float], now: datetime, ) -> None: lookback_days = int(params.get("lookback_days", 60)) halflife_session = float(params.get("halflife_session", 14.0)) halflife_last = float(params.get("halflife_last", 10.0)) weight_f = float(params.get("weight_f", 1.0)) weight_d = float(params.get("weight_d", 0.7)) gate_alpha = max(0.0, float(params.get("gate_alpha", 0.6))) for metrics in pair_map.values(): f_score = 0.0 d_score = 0.0 for session in metrics.sessions: days_ago = min( lookback_days, max(0.0, (now - session.session_end).total_seconds() / 86400.0), ) decay_factor = self.decay(days_ago, halflife_session) f_score += session.course_weight * decay_factor d_score += ( math.sqrt(max(session.total_duration_minutes, 0) / 60.0) * session.course_weight * decay_factor ) if metrics.days_since_last_session is None: r_score = 0.0 else: r_score = self.decay(min(lookback_days, metrics.days_since_last_session), halflife_last) base = weight_f * f_score + weight_d * d_score gate = math.pow(r_score, gate_alpha) if r_score > 0 else 0.0 metrics.rs_f = f_score metrics.rs_d = d_score metrics.rs_r = r_score metrics.rs_raw = max(0.0, base * gate) def _calculate_ms( self, pair_map: Dict[Tuple[int, int], RelationPairMetrics], params: Dict[str, float], now: datetime, ) -> None: lookback_days = int(params.get("lookback_days", 60)) halflife_short = float(params.get("halflife_short", 7.0)) halflife_long = float(params.get("halflife_long", 30.0)) eps = float(params.get("eps", 1e-6)) for metrics in pair_map.values(): f_short = 0.0 f_long = 0.0 for session in metrics.sessions: days_ago = min( lookback_days, max(0.0, (now - session.session_end).total_seconds() / 86400.0), ) f_short += session.course_weight * self.decay(days_ago, halflife_short) f_long += session.course_weight * self.decay(days_ago, halflife_long) ratio = (f_short + eps) / (f_long + eps) metrics.ms_f_short = f_short metrics.ms_f_long = f_long metrics.ms_raw = max(0.0, self.safe_log(ratio, 0.0)) def _calculate_ml( self, pair_map: Dict[Tuple[int, int], RelationPairMetrics], params: Dict[str, float], site_id: int, now: datetime, ) -> None: lookback_days = int(params.get("lookback_days", 60)) source_mode = int(params.get("source_mode", 0)) amount_base = float(params.get("amount_base", 500.0)) halflife_recharge = float(params.get("halflife_recharge", 21.0)) start_time = now - timedelta(days=lookback_days) manual_rows = self._extract_manual_alloc(site_id, start_time, now) for row in manual_rows: member_id = int(row["member_id"]) assistant_id = int(row["assistant_id"]) key = (member_id, assistant_id) if key not in pair_map: pair_map[key] = RelationPairMetrics( site_id=site_id, tenant_id=pair_map[next(iter(pair_map))].tenant_id if pair_map else self._get_tenant_id(), member_id=member_id, assistant_id=assistant_id, ) metrics = pair_map[key] amount = float(row.get("allocated_amount") or 0.0) pay_time = row.get("pay_time") if amount <= 0 or pay_time is None: continue days_ago = min(lookback_days, max(0.0, (now - pay_time).total_seconds() / 86400.0)) metrics.ml_raw += math.log1p(amount / max(amount_base, 1e-6)) * self.decay( days_ago, halflife_recharge, ) metrics.ml_order_count += 1 metrics.ml_allocated_amount += amount # 备用路径:仅在明确打开且人工台账为空时使用 last-touch。 if source_mode == 1 and not manual_rows: self.logger.warning("ML source_mode=1 且人工台账为空,启用 last-touch 备用归因") self._apply_last_touch_ml(pair_map, params, site_id, now) def _extract_manual_alloc( self, site_id: int, start_time: datetime, end_time: datetime, ) -> List[Dict[str, Any]]: sql = """ SELECT member_id, assistant_id, pay_time, allocated_amount FROM billiards_dws.dws_ml_manual_order_alloc WHERE site_id = %s AND pay_time >= %s AND pay_time < %s """ rows = self.db.query(sql, (site_id, start_time, end_time)) return [dict(row) for row in (rows or [])] def _apply_last_touch_ml( self, pair_map: Dict[Tuple[int, int], RelationPairMetrics], params: Dict[str, float], site_id: int, now: datetime, ) -> None: lookback_days = int(params.get("lookback_days", 60)) attribution_hours = int(params.get("recharge_attribute_hours", 1)) amount_base = float(params.get("amount_base", 500.0)) halflife_recharge = float(params.get("halflife_recharge", 21.0)) start_time = now - timedelta(days=lookback_days) end_time = now # 为 last-touch 建立 member -> sessions 索引 member_sessions: Dict[int, List[Tuple[datetime, int]]] = {} for metrics in pair_map.values(): for session in metrics.sessions: member_sessions.setdefault(metrics.member_id, []).append( (session.session_end, metrics.assistant_id) ) for sessions in member_sessions.values(): sessions.sort(key=lambda item: item[0]) sql = """ SELECT member_id, pay_time, pay_amount FROM billiards_dwd.dwd_recharge_order WHERE site_id = %s AND settle_type = 5 AND COALESCE(is_delete, 0) = 0 AND member_id > 0 AND pay_time >= %s AND pay_time < %s """ rows = self.db.query(sql, (site_id, start_time, end_time)) for row in (rows or []): row_dict = dict(row) member_id = int(row_dict.get("member_id") or 0) pay_time = row_dict.get("pay_time") pay_amount = float(row_dict.get("pay_amount") or 0.0) if member_id <= 0 or pay_time is None or pay_amount <= 0: continue candidates = member_sessions.get(member_id, []) selected_assistant: Optional[int] = None selected_end: Optional[datetime] = None for end_time_candidate, assistant_id in candidates: if end_time_candidate > pay_time: continue if pay_time - end_time_candidate > timedelta(hours=attribution_hours): continue if selected_end is None or end_time_candidate > selected_end: selected_end = end_time_candidate selected_assistant = assistant_id if selected_assistant is None: continue key = (member_id, selected_assistant) if key not in pair_map: pair_map[key] = RelationPairMetrics( site_id=site_id, tenant_id=pair_map[next(iter(pair_map))].tenant_id if pair_map else self._get_tenant_id(), member_id=member_id, assistant_id=selected_assistant, ) metrics = pair_map[key] days_ago = min(lookback_days, max(0.0, (now - pay_time).total_seconds() / 86400.0)) metrics.ml_raw += math.log1p(pay_amount / max(amount_base, 1e-6)) * self.decay( days_ago, halflife_recharge, ) metrics.ml_order_count += 1 metrics.ml_allocated_amount += pay_amount def _calculate_os( self, pair_map: Dict[Tuple[int, int], RelationPairMetrics], params: Dict[str, float], ) -> None: min_rs = float(params.get("min_rs_raw_for_ownership", 0.05)) min_total = float(params.get("min_total_rs_raw", 0.10)) main_threshold = float(params.get("ownership_main_threshold", 0.60)) comanage_threshold = float(params.get("ownership_comanage_threshold", 0.35)) gap_threshold = float(params.get("ownership_gap_threshold", 0.15)) member_groups: Dict[int, List[RelationPairMetrics]] = {} for metrics in pair_map.values(): member_groups.setdefault(metrics.member_id, []).append(metrics) for _, rows in member_groups.items(): eligible = [row for row in rows if row.rs_raw >= min_rs] sum_rs = sum(row.rs_raw for row in eligible) if sum_rs < min_total: for row in rows: row.os_share = 0.0 row.os_label = "UNASSIGNED" row.os_rank = None continue for row in rows: if row.rs_raw >= min_rs: row.os_share = row.rs_raw / sum_rs else: row.os_share = 0.0 sorted_eligible = sorted( eligible, key=lambda item: ( -item.os_share, -item.rs_raw, item.days_since_last_session if item.days_since_last_session is not None else 10**9, item.assistant_id, ), ) for idx, row in enumerate(sorted_eligible, start=1): row.os_rank = idx top1 = sorted_eligible[0] top2_share = sorted_eligible[1].os_share if len(sorted_eligible) > 1 else 0.0 gap = top1.os_share - top2_share has_main = top1.os_share >= main_threshold and gap >= gap_threshold if has_main: for row in rows: if row is top1: row.os_label = "MAIN" elif row.os_share >= comanage_threshold: row.os_label = "COMANAGE" else: row.os_label = "POOL" else: for row in rows: if row.os_share >= comanage_threshold and row.rs_raw >= min_rs: row.os_label = "COMANAGE" else: row.os_label = "POOL" # 非 eligible 不赋 rank for row in rows: if row.rs_raw < min_rs: row.os_rank = None def _apply_display_scores( self, pair_map: Dict[Tuple[int, int], RelationPairMetrics], params_rs: Dict[str, float], params_ms: Dict[str, float], params_ml: Dict[str, float], site_id: int, ) -> None: pair_items = list(pair_map.items()) rs_map = self._normalize_and_record( raw_pairs=[(key, item.rs_raw) for key, item in pair_items], params=params_rs, index_type="RS", site_id=site_id, ) ms_map = self._normalize_and_record( raw_pairs=[(key, item.ms_raw) for key, item in pair_items], params=params_ms, index_type="MS", site_id=site_id, ) ml_map = self._normalize_and_record( raw_pairs=[(key, item.ml_raw) for key, item in pair_items], params=params_ml, index_type="ML", site_id=site_id, ) for key, item in pair_items: item.rs_display = rs_map.get(key, 0.0) item.ms_display = ms_map.get(key, 0.0) item.ml_display = ml_map.get(key, 0.0) def _normalize_and_record( self, *, raw_pairs: List[Tuple[Any, float]], params: Dict[str, float], index_type: str, site_id: int, ) -> Dict[Any, float]: if not raw_pairs: return {} if all(abs(score) <= 1e-9 for _, score in raw_pairs): return {entity: 0.0 for entity, _ in raw_pairs} percentile_lower = int(params.get("percentile_lower", 5)) percentile_upper = int(params.get("percentile_upper", 95)) use_smoothing = int(params.get("use_smoothing", 1)) == 1 compression = self._map_compression(params) normalized = self.batch_normalize_to_display( raw_scores=raw_pairs, compression=compression, percentile_lower=percentile_lower, percentile_upper=percentile_upper, use_smoothing=use_smoothing, site_id=site_id, index_type=index_type, ) display_map = {entity: display for entity, _, display in normalized} raw_values = [float(score) for _, score in raw_pairs] q_l, q_u = self.calculate_percentiles(raw_values, percentile_lower, percentile_upper) if use_smoothing: smoothed_l, smoothed_u = self._apply_ewma_smoothing( site_id=site_id, current_p5=q_l, current_p95=q_u, index_type=index_type, ) else: smoothed_l, smoothed_u = q_l, q_u self.save_percentile_history( site_id=site_id, percentile_5=q_l, percentile_95=q_u, percentile_5_smoothed=smoothed_l, percentile_95_smoothed=smoothed_u, record_count=len(raw_values), min_raw=min(raw_values), max_raw=max(raw_values), avg_raw=sum(raw_values) / len(raw_values), index_type=index_type, ) return display_map @staticmethod def _map_compression(params: Dict[str, float]) -> str: mode = int(params.get("compression_mode", 0)) if mode == 1: return "log1p" if mode == 2: return "asinh" return "none" def _save_relation_rows(self, site_id: int, rows: List[RelationPairMetrics]) -> int: with self.db.conn.cursor() as cur: cur.execute( "DELETE FROM billiards_dws.dws_member_assistant_relation_index WHERE site_id = %s", (site_id,), ) if not rows: self.db.conn.commit() return 0 insert_sql = """ INSERT INTO billiards_dws.dws_member_assistant_relation_index ( site_id, tenant_id, member_id, assistant_id, session_count, total_duration_minutes, basic_session_count, incentive_session_count, days_since_last_session, rs_f, rs_d, rs_r, rs_raw, rs_display, os_share, os_label, os_rank, ms_f_short, ms_f_long, ms_raw, ms_display, ml_order_count, ml_allocated_amount, ml_raw, ml_display, calc_time, created_at, updated_at ) VALUES ( %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW(), NOW() ) """ inserted = 0 for row in rows: cur.execute( insert_sql, ( row.site_id, row.tenant_id, row.member_id, row.assistant_id, row.session_count, row.total_duration_minutes, row.basic_session_count, row.incentive_session_count, row.days_since_last_session, row.rs_f, row.rs_d, row.rs_r, row.rs_raw, row.rs_display, row.os_share, row.os_label, row.os_rank, row.ms_f_short, row.ms_f_long, row.ms_raw, row.ms_display, row.ml_order_count, row.ml_allocated_amount, row.ml_raw, row.ml_display, ), ) inserted += max(cur.rowcount, 0) self.db.conn.commit() return inserted def _get_site_id(self, context: Optional[TaskContext]) -> int: if context and getattr(context, "store_id", None): return int(context.store_id) site_id = self.config.get("app.default_site_id") or self.config.get("app.store_id") if site_id is not None: return int(site_id) sql = "SELECT DISTINCT site_id FROM billiards_dwd.dwd_assistant_service_log WHERE site_id IS NOT NULL LIMIT 1" rows = self.db.query(sql) if rows: return int(dict(rows[0]).get("site_id") or 0) self.logger.warning("无法确定门店ID,使用 0 继续执行") return 0 def _get_tenant_id(self) -> int: tenant_id = self.config.get("app.tenant_id") if tenant_id is not None: return int(tenant_id) sql = "SELECT DISTINCT tenant_id FROM billiards_dwd.dwd_assistant_service_log WHERE tenant_id IS NOT NULL LIMIT 1" rows = self.db.query(sql) if rows: return int(dict(rows[0]).get("tenant_id") or 0) self.logger.warning("无法确定租户ID,使用 0 继续执行") return 0 __all__ = ["RelationIndexTask", "RelationPairMetrics", "ServiceSession"]