# -*- coding: utf-8 -*-
# AI_CHANGELOG
# - 2026-02-14 | bugfix: get_performance_tier new-hire cap fallback + safe_decimal exception handling
#   prompt (translated): "Continue. When finished, check whether all tasks are fully covered."
#   Root cause: unit tests exposed two existing bugs: (1) after filtering by max_tier_level,
#     hours above every remaining tier's upper bound returned None; (2) safe_decimal did not
#     catch InvalidOperation.
#   Change: get_performance_tier() gained a best_fallback; safe_decimal() also catches
#     InvalidOperation; InvalidOperation added to the decimal import.
#   Risk: the fallback only applies when max_tier_level is not None; the normal matching
#     path is unaffected.
#   Verification: pytest tests/unit -x (449 passed)
"""
DWS-layer task base class.

Features:
- Standard helpers for reading data from the DWD layer
- Time-layer ranges (last 2 days / last 1 month / last 3 months / last 6 months / all)
- Config-table loading helpers
- Idempotent updates (delete-before-insert)
- SCD2 dimension as-of lookups
- Rolling-window statistics

Time conventions:
- Weeks start on Monday
- Months/quarters start on their first day at 00:00
- Period-over-period comparison uses the immediately preceding interval of equal length
- "Previous 3 months": with or without the current month (used for finance filters)
- "Last half year": excludes the current month

Refresh cadence:
- Daily tables: refreshed daily
- Real-time tables: refreshed hourly
- Monthly tables: the current month is refreshed daily

Author: ETL team
Created: 2026-02-01
"""

from __future__ import annotations

import calendar
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import Decimal, InvalidOperation
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar

from ..base_task import BaseTask, TaskContext


# =============================================================================
# Type definitions
# =============================================================================

T = TypeVar('T')


class TimeLayer(Enum):
    """Time layers used for data filtering."""

    LAST_2_DAYS = "LAST_2_DAYS"  # last 2 days
    LAST_1_MONTH = "LAST_1_MONTH"  # last month (computed as 30 days)
    LAST_3_MONTHS = "LAST_3_MONTHS"  # last 3 months (computed as 90 days)
    LAST_6_MONTHS = "LAST_6_MONTHS"  # last 6 full months, excluding the current month
    ALL = "ALL"  # everything


class TimeWindow(Enum):
    """Calendar windows used by finance reports."""

    THIS_WEEK = "THIS_WEEK"  # current week (starting Monday)
    LAST_WEEK = "LAST_WEEK"  # previous week
    THIS_MONTH = "THIS_MONTH"  # current month
    LAST_MONTH = "LAST_MONTH"  # previous month
    LAST_3_MONTHS_EXCL_CURRENT = "LAST_3_MONTHS_EXCL_CURRENT"  # previous 3 months, excluding current
    LAST_3_MONTHS_INCL_CURRENT = "LAST_3_MONTHS_INCL_CURRENT"  # 3 months including current
    THIS_QUARTER = "THIS_QUARTER"  # current quarter
    LAST_QUARTER = "LAST_QUARTER"  # previous quarter
    LAST_6_MONTHS = "LAST_6_MONTHS"  # last half year, excluding the current month


class CourseType(Enum):
    """Course type."""

    BASE = "BASE"  # base course / accompanied play
    BONUS = "BONUS"  # add-on course / overtime
    ROOM = "ROOM"  # private-room course


class DiscountType(Enum):
    """Discount type."""

    GROUPBUY = "GROUPBUY"  # group-buy discount
    VIP = "VIP"  # member discount
    GIFT_CARD = "GIFT_CARD"  # gift-card deduction
    MANUAL = "MANUAL"  # manual adjustment
    ROUNDING = "ROUNDING"  # rounding write-off
    BIG_CUSTOMER = "BIG_CUSTOMER"  # key-account discount
    OTHER = "OTHER"  # other discounts


@dataclass
class TimeRange:
    """Inclusive date range."""

    start: date
    end: date


@dataclass
class ConfigCache:
    """Snapshot of the DWS configuration tables."""

    performance_tiers: List[Dict[str, Any]]  # performance tier configuration
    level_prices: List[Dict[str, Any]]  # per-level pricing configuration
    bonus_rules: List[Dict[str, Any]]  # bonus rule configuration
    area_categories: Dict[str, Dict[str, Any]]  # source area name -> category row
    skill_types: Dict[int, Dict[str, Any]]  # skill_id -> skill type row
    loaded_at: datetime  # time the snapshot was loaded


# =============================================================================
# DWS task base class
# =============================================================================

class BaseDwsTask(BaseTask):
    """
    Base class for DWS-layer tasks.

    Shared functionality:
    1. Reading DWD data
    2. Time-layer and time-window calculations
    3. Config-table caching and lookup
    4. SCD2 dimension as-of lookups
    5. Idempotent updates
    6. Rolling-window statistics
    """

    # Config cache default. NOTE: load_config_cache() assigns through `self`, which
    # creates an instance attribute, so the cache is effectively per instance.
    _config_cache: Optional[ConfigCache] = None
    _config_cache_ttl: int = 300  # cache lifetime in seconds

    # Schema names
    DWS_SCHEMA = "billiards_dws"
    DWD_SCHEMA = "billiards_dwd"

    # Default rolling-window sizes in days
    ROLLING_WINDOWS = [7, 10, 15, 30, 60, 90]

    # ==========================================================================
    # Abstract methods (subclasses must implement)
    # ==========================================================================

    @abstractmethod
    def get_target_table(self) -> str:
        """
        Return the target table name (without schema).

        Returns:
            Target table name, e.g. 'dws_assistant_daily_detail'.
        """
        raise NotImplementedError("子类需实现 get_target_table 方法")

    @abstractmethod
    def get_primary_keys(self) -> List[str]:
        """
        Return the primary-key columns (used for idempotent updates / ON CONFLICT).

        Returns:
            Column names, e.g. ['site_id', 'assistant_id', 'stat_date'].
        """
        raise NotImplementedError("子类需实现 get_primary_keys 方法")

    # ==========================================================================
    # Time calculations
    # ==========================================================================

    def get_time_layer_range(
        self,
        layer: TimeLayer,
        base_date: Optional[date] = None
    ) -> TimeRange:
        """
        Return the inclusive date range of a time layer.

        Args:
            layer: time layer.
            base_date: reference date; defaults to today.

        Returns:
            TimeRange with start and end dates.
        """
        if base_date is None:
            base_date = date.today()

        if layer == TimeLayer.LAST_2_DAYS:
            return TimeRange(start=base_date - timedelta(days=1), end=base_date)
        elif layer == TimeLayer.LAST_1_MONTH:
            return TimeRange(start=base_date - timedelta(days=30), end=base_date)
        elif layer == TimeLayer.LAST_3_MONTHS:
            return TimeRange(start=base_date - timedelta(days=90), end=base_date)
        elif layer == TimeLayer.LAST_6_MONTHS:
            # Excludes the current month: six full months ending at the end of last month.
            month_start = self.get_month_first_day(base_date)
            end = month_start - timedelta(days=1)
            start = self.get_month_first_day(self._shift_months(month_start, -6))
            return TimeRange(start=start, end=end)
        else:  # ALL
            return TimeRange(start=date(2000, 1, 1), end=base_date)

    def get_time_window_range(
        self,
        window: TimeWindow,
        base_date: Optional[date] = None
    ) -> TimeRange:
        """
        Return the inclusive date range of a finance-report window.

        Conventions:
        - weeks start on Monday
        - months/quarters start on their first day

        Args:
            window: time window.
            base_date: reference date; defaults to today.

        Returns:
            TimeRange for the window.

        Raises:
            ValueError: if the window type is not supported.
        """
        if base_date is None:
            base_date = date.today()

        if window == TimeWindow.THIS_WEEK:
            # Current week, Monday through base_date.
            start = base_date - timedelta(days=base_date.weekday())
            return TimeRange(start=start, end=base_date)

        elif window == TimeWindow.LAST_WEEK:
            this_monday = base_date - timedelta(days=base_date.weekday())
            end = this_monday - timedelta(days=1)  # last Sunday
            start = end - timedelta(days=6)  # last Monday
            return TimeRange(start=start, end=end)

        elif window == TimeWindow.THIS_MONTH:
            return TimeRange(start=base_date.replace(day=1), end=base_date)

        elif window == TimeWindow.LAST_MONTH:
            end = base_date.replace(day=1) - timedelta(days=1)
            return TimeRange(start=end.replace(day=1), end=end)

        elif window == TimeWindow.LAST_3_MONTHS_EXCL_CURRENT:
            # From the first day three months back to the end of last month.
            current_month_start = self.get_month_first_day(base_date)
            end = current_month_start - timedelta(days=1)
            start = self.get_month_first_day(self._shift_months(current_month_start, -3))
            return TimeRange(start=start, end=end)

        elif window == TimeWindow.LAST_3_MONTHS_INCL_CURRENT:
            # From the first day two months back through base_date.
            current_month_start = self.get_month_first_day(base_date)
            start = self.get_month_first_day(self._shift_months(current_month_start, -2))
            return TimeRange(start=start, end=base_date)

        elif window == TimeWindow.THIS_QUARTER:
            start_month = ((base_date.month - 1) // 3) * 3 + 1
            start = base_date.replace(month=start_month, day=1)
            return TimeRange(start=start, end=base_date)

        elif window == TimeWindow.LAST_QUARTER:
            start_month = ((base_date.month - 1) // 3) * 3 + 1
            this_quarter_start = base_date.replace(month=start_month, day=1)
            end = this_quarter_start - timedelta(days=1)
            prev_start_month = ((end.month - 1) // 3) * 3 + 1
            start = end.replace(month=prev_start_month, day=1)
            return TimeRange(start=start, end=end)

        elif window == TimeWindow.LAST_6_MONTHS:
            # Last half year, excluding the current month.
            month_start = self.get_month_first_day(base_date)
            end = month_start - timedelta(days=1)
            start = self.get_month_first_day(self._shift_months(month_start, -6))
            return TimeRange(start=start, end=end)

        raise ValueError(f"不支持的时间窗口类型: {window}")

    def get_comparison_range(self, time_range: TimeRange) -> TimeRange:
        """
        Return the comparison range: the equal-length interval that ends the day
        before `time_range.start`.

        Args:
            time_range: current range.

        Returns:
            The preceding range of the same length (in days).
        """
        duration = (time_range.end - time_range.start).days + 1
        prev_end = time_range.start - timedelta(days=1)
        prev_start = prev_end - timedelta(days=duration - 1)
        return TimeRange(start=prev_start, end=prev_end)

    def get_month_first_day(self, dt: date) -> date:
        """Return the first day of dt's month."""
        return dt.replace(day=1)

    def get_month_last_day(self, dt: date) -> date:
        """Return the last day of dt's month."""
        last_day = calendar.monthrange(dt.year, dt.month)[1]
        return dt.replace(day=last_day)

    def _shift_months(self, base_date: date, months: int) -> date:
        """
        Shift a date by `months` calendar months, clamping the day to the
        target month's length (e.g. Mar 31 - 1 month -> Feb 28/29).
        """
        total_months = base_date.year * 12 + (base_date.month - 1) + months
        year = total_months // 12
        month = total_months % 12 + 1
        last_day = calendar.monthrange(year, month)[1]
        return date(year, month, min(base_date.day, last_day))

    def is_new_hire_in_month(self, hire_date: date, stat_month: date) -> bool:
        """
        Decide whether an assistant counts as a new hire for a month.

        Rule: anyone hired on or after 00:00 of the month's first day counts as a
        new hire. Implemented as `hire_date >= first day of stat_month`, so hire
        dates after the stat month also return True.

        Args:
            hire_date: hire date.
            stat_month: any date in the statistics month.

        Returns:
            True if the assistant is a new hire for that month.
        """
        return hire_date >= self.get_month_first_day(stat_month)

    # ==========================================================================
    # Config-table loading
    # ==========================================================================

    def load_config_cache(self, force_reload: bool = False) -> ConfigCache:
        """
        Load (or return the cached) snapshot of all DWS config tables.

        The snapshot is reused for `_config_cache_ttl` seconds. It is stored via
        `self`, so each task instance keeps its own copy.
        NOTE(review): the class-level default suggests sharing was intended; confirm.

        Args:
            force_reload: bypass the cache and reload from the database.

        Returns:
            ConfigCache snapshot.
        """
        now = datetime.now(self.tz)

        cached = self._config_cache
        if (
            not force_reload
            and cached is not None
            and (now - cached.loaded_at).total_seconds() < self._config_cache_ttl
        ):
            return cached

        self.logger.debug("重新加载DWS配置表缓存")

        self._config_cache = ConfigCache(
            performance_tiers=self._load_performance_tiers(),
            level_prices=self._load_level_prices(),
            bonus_rules=self._load_bonus_rules(),
            area_categories=self._load_area_categories(),
            skill_types=self._load_skill_types(),
            loaded_at=now,
        )
        return self._config_cache

    def _load_performance_tiers(self) -> List[Dict[str, Any]]:
        """
        Load performance tier configuration, ordered by tier_level then effective_from.

        Field notes (from the DWS requirements doc, DWS数据库处理需求.md):
        - base_deduction: per-hour amount the venue deducts from base-course hours
        - bonus_deduction_ratio: share the venue deducts from add-on course revenue
        - vacation_days: leave days available next month
        - vacation_unlimited: unlimited-leave flag (TRUE for the top tier)
        """
        sql = """
            SELECT tier_id, tier_code, tier_name, tier_level,
                   min_hours, max_hours, base_deduction, bonus_deduction_ratio,
                   vacation_days, vacation_unlimited, is_new_hire_tier,
                   effective_from, effective_to
            FROM billiards_dws.cfg_performance_tier
            ORDER BY tier_level ASC, effective_from ASC
        """
        rows = self.db.query(sql)
        return [dict(row) for row in rows] if rows else []

    def _load_level_prices(self) -> List[Dict[str, Any]]:
        """Load per-level pricing, newest effective_from first within each level."""
        sql = """
            SELECT price_id, level_code, level_name,
                   base_course_price, bonus_course_price,
                   effective_from, effective_to
            FROM billiards_dws.cfg_assistant_level_price
            ORDER BY level_code ASC, effective_from DESC
        """
        rows = self.db.query(sql)
        return [dict(row) for row in rows] if rows else []

    def _load_bonus_rules(self) -> List[Dict[str, Any]]:
        """Load bonus rules, ordered by rule_type, then priority and effective_from descending."""
        sql = """
            SELECT rule_id, rule_type, rule_code, rule_name,
                   threshold_hours, rank_position, bonus_amount,
                   is_cumulative, priority, effective_from, effective_to
            FROM billiards_dws.cfg_bonus_rules
            ORDER BY rule_type, priority DESC, effective_from DESC
        """
        rows = self.db.query(sql)
        return [dict(row) for row in rows] if rows else []

    def _load_area_categories(self) -> Dict[str, Dict[str, Any]]:
        """
        Load active area-category mappings keyed by source_area_name.

        Insertion order follows match_priority ascending; a later row with the
        same source_area_name overwrites an earlier one.
        """
        sql = """
            SELECT source_area_name, category_code, category_name,
                   match_type, match_priority
            FROM billiards_dws.cfg_area_category
            WHERE is_active = TRUE
            ORDER BY match_priority ASC
        """
        rows = self.db.query(sql)
        if not rows:
            return {}
        result: Dict[str, Dict[str, Any]] = {}
        for row in rows:
            row_dict = dict(row)
            result[row_dict['source_area_name']] = row_dict
        return result

    def _load_skill_types(self) -> Dict[int, Dict[str, Any]]:
        """Load active skill-type mappings keyed by int(skill_id)."""
        sql = """
            SELECT skill_id, skill_name, course_type_code, course_type_name
            FROM billiards_dws.cfg_skill_type
            WHERE is_active = TRUE
        """
        rows = self.db.query(sql)
        if not rows:
            return {}
        result: Dict[int, Dict[str, Any]] = {}
        for row in rows:
            row_dict = dict(row)
            result[int(row_dict['skill_id'])] = row_dict
        return result

    # ==========================================================================
    # Applying configuration
    # ==========================================================================

    @staticmethod
    def _to_plain_date(value: Any) -> Any:
        """Convert a datetime to its date part; other values pass through unchanged."""
        # datetime is a subclass of date, and date < datetime raises TypeError,
        # so both sides of an effective-period comparison are reduced to date.
        return value.date() if isinstance(value, datetime) else value

    def _filter_by_effective_date(
        self,
        items: List[Dict[str, Any]],
        effective_date: Optional[date]
    ) -> List[Dict[str, Any]]:
        """
        Keep config rows whose [effective_from, effective_to] period contains
        the reference date (defaults to today). A missing bound is treated as open.
        """
        ref_date = self._to_plain_date(effective_date or date.today())
        results: List[Dict[str, Any]] = []
        for item in items:
            eff_from = self._to_plain_date(item.get('effective_from'))
            eff_to = self._to_plain_date(item.get('effective_to'))
            if eff_from and ref_date < eff_from:
                continue
            if eff_to and ref_date > eff_to:
                continue
            results.append(item)
        return results

    def get_performance_tier(
        self,
        effective_hours: Decimal,
        is_new_hire: bool,
        effective_date: Optional[date] = None,
        max_tier_level: Optional[int] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Match a performance tier by effective performance hours.

        A tier matches when `min_hours <= effective_hours < max_hours`
        (a missing max_hours means no upper bound). Tiers flagged
        `is_new_hire_tier` are never matched here.

        Args:
            effective_hours: effective performance hours.
            is_new_hire: kept for caller compatibility; the new-hire cap is
                applied by the monthly task via `max_tier_level`.
            effective_date: date used to select effective config (historical months).
            max_tier_level: when given, only tiers with tier_level <= this value
                (or with no tier_level) are considered. If the hours exceed every
                remaining tier's upper bound, the highest such tier is returned.

        Returns:
            The matched tier row, or None.
        """
        _ = is_new_hire  # kept for interface compatibility
        config = self.load_config_cache()
        tiers = self._filter_by_effective_date(config.performance_tiers, effective_date)
        if max_tier_level is not None:
            tiers = [
                t for t in tiers
                if t.get('tier_level') is None or int(t.get('tier_level')) <= max_tier_level
            ]

        # CHANGE [2026-02-14] intent: new-hire cap fallback - when hours exceed every
        # available tier's upper bound, return the highest available tier.
        # assumptions: max_tier_level is only passed for new hires; the normal path is unaffected.
        best_fallback: Optional[Dict[str, Any]] = None
        for tier in tiers:
            if tier.get('is_new_hire_tier'):
                continue
            # safe_decimal/safe_int: NULL config values must not raise
            # (tier_level may be NULL - the filter above keeps such rows).
            min_hours = self.safe_decimal(tier.get('min_hours'))
            raw_max = tier.get('max_hours')
            max_hours = None if raw_max is None else Decimal(str(raw_max))

            if effective_hours < min_hours:
                continue
            if max_hours is None or effective_hours < max_hours:
                return tier
            # Hours are above this tier's upper bound: keep the highest such tier.
            if (
                best_fallback is None
                or self.safe_int(tier.get('tier_level')) > self.safe_int(best_fallback.get('tier_level'))
            ):
                best_fallback = tier

        if best_fallback is not None and max_tier_level is not None:
            return best_fallback
        return None

    def get_performance_tier_by_id(
        self,
        tier_id: Optional[int],
        effective_date: Optional[date] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the tier with the given tier_id that is effective on effective_date, or None."""
        if not tier_id:
            return None
        config = self.load_config_cache()
        tiers = self._filter_by_effective_date(config.performance_tiers, effective_date)
        return next((t for t in tiers if t.get('tier_id') == tier_id), None)

    def get_level_price(
        self,
        level_code: int,
        effective_date: Optional[date] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Return the pricing row for an assistant level effective on a date.

        Rows are loaded newest-effective first, so the first match is returned.

        Args:
            level_code: level code.
            effective_date: reference date; defaults to today.

        Returns:
            Pricing row, or None.
        """
        config = self.load_config_cache()
        prices = self._filter_by_effective_date(config.level_prices, effective_date)
        return next((p for p in prices if p.get('level_code') == level_code), None)

    def get_course_type(self, skill_id: int) -> CourseType:
        """
        Map a skill_id to its course type.

        Unknown skills, and codes other than 'BONUS'/'ROOM', map to BASE.
        """
        config = self.load_config_cache()
        skill_config = config.skill_types.get(skill_id)
        if not skill_config:
            return CourseType.BASE
        code = skill_config.get('course_type_code', 'BASE')
        return {'BONUS': CourseType.BONUS, 'ROOM': CourseType.ROOM}.get(code, CourseType.BASE)

    def get_area_category(self, area_name: Optional[str]) -> Dict[str, str]:
        """
        Classify an area name.

        Order: exact match (match_type 'EXACT') -> substring match for
        match_type 'LIKE' rows (all '%' removed from the key, in loaded order)
        -> the 'DEFAULT' row -> a hard-coded OTHER category.

        Args:
            area_name: raw area name.

        Returns:
            Dict with 'category_code' and 'category_name'.
        """
        config = self.load_config_cache()
        if not area_name:
            return {'category_code': 'OTHER', 'category_name': '其他区域'}

        def _pick(cat: Dict[str, Any]) -> Dict[str, str]:
            return {
                'category_code': cat['category_code'],
                'category_name': cat['category_name'],
            }

        # 1. Exact match
        exact = config.area_categories.get(area_name)
        if exact is not None and exact.get('match_type') == 'EXACT':
            return _pick(exact)

        # 2. Fuzzy match (priority order). NOTE: '%' anywhere in the key is simply
        # removed, so 'A%B' matches only the literal substring 'AB'.
        for key, cat in config.area_categories.items():
            if cat.get('match_type') == 'LIKE':
                pattern = key.replace('%', '')
                if pattern and pattern in area_name:
                    return _pick(cat)

        # 3. Fallback
        if 'DEFAULT' in config.area_categories:
            return _pick(config.area_categories['DEFAULT'])

        return {'category_code': 'OTHER', 'category_name': '其他区域'}

    def calculate_sprint_bonus(
        self,
        effective_hours: Decimal,
        effective_date: Optional[date] = None
    ) -> Decimal:
        """
        Compute the sprint bonus (non-cumulative; highest qualifying rule wins).

        Rules come from cfg_bonus_rules with rule_type 'SPRINT', filtered by
        effective date and tried in descending priority; the first whose
        threshold is reached determines the amount.

        Args:
            effective_hours: effective performance hours.
            effective_date: reference date.

        Returns:
            Bonus amount (Decimal('0') if nothing qualifies).
        """
        config = self.load_config_cache()
        bonus_rules = self._filter_by_effective_date(config.bonus_rules, effective_date)

        sprint_rules = [r for r in bonus_rules if r.get('rule_type') == 'SPRINT']
        # NULL priority sorts as 0 (sorting None against ints would raise).
        sprint_rules.sort(key=lambda x: x.get('priority') or 0, reverse=True)

        for rule in sprint_rules:
            raw_threshold = rule.get('threshold_hours', 0)
            if raw_threshold is None:
                continue  # a rule without a threshold cannot qualify
            if effective_hours >= Decimal(str(raw_threshold)):
                return self.safe_decimal(rule.get('bonus_amount'))

        return Decimal('0')

    def calculate_top_rank_bonus(
        self,
        rank: int,
        effective_date: Optional[date] = None
    ) -> Decimal:
        """
        Compute the Top-rank bonus.

        Ranks above 3 get nothing. Otherwise the first effective 'TOP_RANK' rule
        whose rank_position equals `rank` gives the amount. Ties share a rank,
        so tied assistants all receive it.

        Args:
            rank: rank with ties taken into account.
            effective_date: reference date.

        Returns:
            Bonus amount (Decimal('0') if no rule matches).
        """
        config = self.load_config_cache()
        bonus_rules = self._filter_by_effective_date(config.bonus_rules, effective_date)

        if rank > 3:
            return Decimal('0')

        for rule in bonus_rules:
            if rule.get('rule_type') == 'TOP_RANK' and rule.get('rank_position') == rank:
                return self.safe_decimal(rule.get('bonus_amount'))

        return Decimal('0')

    # ==========================================================================
    # DWD reads
    # ==========================================================================

    def iter_dwd_rows(
        self,
        table_name: str,
        columns: List[str],
        start_date: date,
        end_date: date,
        date_col: str = "created_at",
        where_clause: str = "",
        order_by: str = "",
        batch_size: int = 1000
    ) -> Iterator[List[Dict[str, Any]]]:
        """
        Read a DWD table in batches using LIMIT/OFFSET pagination.

        NOTE: pages are only stable if `order_by` is a unique ordering; the
        default (date_col ASC) may skip or repeat rows across pages when
        date_col has duplicates. `table_name`, `columns`, `where_clause` and
        `order_by` are interpolated into SQL and must be trusted.

        Args:
            table_name: DWD table name (without schema).
            columns: columns to select.
            start_date: start date (inclusive).
            end_date: end date (inclusive).
            date_col: column used for the date filter.
            where_clause: extra WHERE condition (without the WHERE keyword).
            order_by: ORDER BY expression (without the keywords).
            batch_size: rows per batch.

        Yields:
            One list of row dicts per batch.
        """
        cols_str = ", ".join(columns)

        where_parts = [f"DATE({date_col}) >= %s", f"DATE({date_col}) <= %s"]
        params: List[Any] = [start_date, end_date]
        if where_clause:
            where_parts.append(f"({where_clause})")
        where_str = " AND ".join(where_parts)

        order_str = f"ORDER BY {order_by}" if order_by else f"ORDER BY {date_col} ASC"

        sql = f"""
            SELECT {cols_str}
            FROM {self.DWD_SCHEMA}.{table_name}
            WHERE {where_str}
            {order_str}
            LIMIT %s OFFSET %s
        """

        offset = 0
        while True:
            rows = self.db.query(sql, (*params, batch_size, offset))
            if not rows:
                break
            yield [dict(row) for row in rows]
            if len(rows) < batch_size:
                break
            offset += batch_size

    def query_dwd(
        self,
        sql: str,
        params: Optional[Tuple] = None
    ) -> List[Dict[str, Any]]:
        """
        Run an arbitrary query and return the rows as dicts.

        Args:
            sql: SQL statement.
            params: parameter tuple.

        Returns:
            List of row dicts (empty if no rows).
        """
        rows = self.db.query(sql, params)
        return [dict(row) for row in rows] if rows else []

    # ==========================================================================
    # SCD2 as-of lookups
    # ==========================================================================

    def get_assistant_level_asof(
        self,
        assistant_id: int,
        asof_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Return an assistant's level as of a date (SCD2 as-of lookup).

        Assistant level is an SCD2 dimension; historical months must not use the
        current level but the version whose validity window contains asof_date.

        Args:
            assistant_id: assistant ID.
            asof_date: lookup date.

        Returns:
            Row with level_code and level_name, or None if no version matches.
        """
        sql = """
            SELECT
                assistant_id,
                nickname,
                level AS level_code,
                CASE level
                    WHEN 8 THEN '助教管理'
                    WHEN 10 THEN '初级'
                    WHEN 20 THEN '中级'
                    WHEN 30 THEN '高级'
                    WHEN 40 THEN '星级'
                    ELSE '未知'
                END AS level_name,
                scd2_start_time,
                scd2_end_time
            FROM billiards_dwd.dim_assistant
            WHERE assistant_id = %s
              AND scd2_start_time <= %s
              AND (scd2_end_time IS NULL OR scd2_end_time > %s)
            ORDER BY scd2_start_time DESC
            LIMIT 1
        """
        rows = self.db.query(sql, (assistant_id, asof_date, asof_date))
        return dict(rows[0]) if rows else None

    def get_member_card_balance_asof(
        self,
        member_id: int,
        asof_date: date
    ) -> Dict[str, Decimal]:
        """
        Return a member's card balances as of a date (SCD2 as-of lookup).

        Balances of the cash card type go to cash_balance; the listed gift card
        types go to gift_balance; other card types are ignored. A NULL balance
        counts as 0.

        Args:
            member_id: member ID (tenant_member_id).
            asof_date: lookup date.

        Returns:
            Dict with cash_balance, gift_balance and total_balance.
        """
        sql = """
            SELECT card_type_id, balance
            FROM billiards_dwd.dim_member_card_account
            WHERE tenant_member_id = %s
              AND scd2_start_time <= %s
              AND (scd2_end_time IS NULL OR scd2_end_time > %s)
              AND COALESCE(is_delete, 0) = 0
        """
        rows = self.db.query(sql, (member_id, asof_date, asof_date))

        # Card type ID mapping
        CASH_CARD_TYPE_ID = 2793249295533893
        GIFT_CARD_TYPE_IDS = (
            2791990152417157,  # table-fee card
            2793266846533445,  # promotional voucher
            2794699703437125,  # drinks card
        )

        cash_balance = Decimal('0')
        gift_balance = Decimal('0')

        for row in (rows or []):
            row_dict = dict(row)
            card_type_id = row_dict.get('card_type_id')
            balance = self.safe_decimal(row_dict.get('balance'))
            if card_type_id == CASH_CARD_TYPE_ID:
                cash_balance += balance
            elif card_type_id in GIFT_CARD_TYPE_IDS:
                gift_balance += balance

        return {
            'cash_balance': cash_balance,
            'gift_balance': gift_balance,
            'total_balance': cash_balance + gift_balance,
        }

    # ==========================================================================
    # Idempotent updates
    # ==========================================================================

    def delete_existing_data(
        self,
        context: TaskContext,
        date_col: str = "stat_date",
        extra_conditions: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Delete target-table rows for this site and window (delete-before-insert).

        Deletes rows where site_id = context.store_id and date_col lies within
        [window_start, window_end] (both reduced to dates when possible), plus
        any `extra_conditions` as column equality filters.

        Args:
            context: task context.
            date_col: date column name.
            extra_conditions: extra `column = value` filters.

        Returns:
            Number of deleted rows (cursor rowcount).
        """
        full_table = f"{self.DWS_SCHEMA}.{self.get_target_table()}"

        where_parts = ["site_id = %s"]
        params: List[Any] = [context.store_id]

        start_date = (
            context.window_start.date()
            if hasattr(context.window_start, 'date') else context.window_start
        )
        end_date = (
            context.window_end.date()
            if hasattr(context.window_end, 'date') else context.window_end
        )
        where_parts.append(f"{date_col} >= %s")
        params.append(start_date)
        where_parts.append(f"{date_col} <= %s")
        params.append(end_date)

        if extra_conditions:
            for col, val in extra_conditions.items():
                where_parts.append(f"{col} = %s")
                params.append(val)

        where_str = " AND ".join(where_parts)
        sql = f"DELETE FROM {full_table} WHERE {where_str}"

        with self.db.conn.cursor() as cur:
            cur.execute(sql, params)
            deleted = cur.rowcount

        self.logger.debug(
            "%s: 删除已存在数据 %d 行,条件: %s",
            self.get_task_code(), deleted, where_str
        )
        return deleted

    def bulk_insert(
        self,
        rows: List[Dict[str, Any]],
        columns: Optional[List[str]] = None
    ) -> int:
        """
        Insert rows into the target table one statement per row.

        Args:
            rows: row dicts.
            columns: columns to insert; defaults to the first row's keys.
                Missing keys are inserted as NULL.

        Returns:
            Sum of cursor rowcounts.
        """
        if not rows:
            return 0

        full_table = f"{self.DWS_SCHEMA}.{self.get_target_table()}"
        if columns is None:
            columns = list(rows[0].keys())

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {full_table} ({cols_str}) VALUES ({placeholders})"

        inserted = 0
        with self.db.conn.cursor() as cur:
            for row in rows:
                cur.execute(sql, [row.get(col) for col in columns])
                inserted += cur.rowcount

        return inserted

    def upsert(
        self,
        rows: List[Dict[str, Any]],
        columns: Optional[List[str]] = None,
        update_columns: Optional[List[str]] = None
    ) -> Tuple[int, int]:
        """
        Insert rows, updating on primary-key conflict (INSERT ... ON CONFLICT).

        `updated_at` is always set to NOW() on conflict, so it is excluded from
        the copied columns (assigning it twice makes PostgreSQL reject the statement).

        Args:
            rows: row dicts.
            columns: all columns; defaults to the first row's keys.
            update_columns: columns copied from EXCLUDED on conflict; defaults to
                all non-key columns except created_at/updated_at.

        Returns:
            (inserted, updated). Inserts and updates cannot be told apart from
            rowcount, so every affected row is counted as inserted and `updated`
            is always 0.
        """
        if not rows:
            return 0, 0

        full_table = f"{self.DWS_SCHEMA}.{self.get_target_table()}"
        primary_keys = self.get_primary_keys()

        if columns is None:
            columns = list(rows[0].keys())
        if update_columns is None:
            update_columns = [
                c for c in columns
                if c not in primary_keys and c not in ('created_at', 'updated_at')
            ]
        else:
            update_columns = [c for c in update_columns if c != 'updated_at']

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        conflict_cols = ", ".join(primary_keys)

        update_parts = [f"{col} = EXCLUDED.{col}" for col in update_columns]
        update_parts.append("updated_at = NOW()")
        update_str = ", ".join(update_parts)

        sql = f"""
            INSERT INTO {full_table} ({cols_str})
            VALUES ({placeholders})
            ON CONFLICT ({conflict_cols})
            DO UPDATE SET {update_str}
        """

        inserted = 0
        updated = 0
        with self.db.conn.cursor() as cur:
            for row in rows:
                cur.execute(sql, [row.get(col) for col in columns])
                # rowcount is 1 for both insert and update; counted as inserted.
                if cur.rowcount > 0:
                    inserted += 1

        return inserted, updated

    # ==========================================================================
    # Rolling-window statistics
    # ==========================================================================

    def calculate_rolling_stats(
        self,
        base_date: date,
        entity_id: int,
        entity_type: str,
        stat_sql: str,
        windows: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """
        Run a statistics query once per rolling window.

        `stat_sql` is filled with str.format using {window_days}, {entity_id},
        {start_date} and {end_date}; dates render as 'YYYY-MM-DD' without quotes,
        so the template must quote them. The values are not escaped - the
        template must be trusted. Each column of the first result row is stored
        as '<column>_<days>d'.

        Args:
            base_date: window end date (inclusive).
            entity_id: entity ID (e.g. assistant_id or member_id).
            entity_type: entity type (currently unused).
            stat_sql: SQL template.
            windows: window sizes in days; defaults to ROLLING_WINDOWS.

        Returns:
            Flat dict of per-window results.
        """
        if windows is None:
            windows = self.ROLLING_WINDOWS

        results: Dict[str, Any] = {}
        for days in windows:
            start_date = base_date - timedelta(days=days - 1)
            sql = stat_sql.format(
                window_days=days,
                entity_id=entity_id,
                start_date=start_date,
                end_date=base_date,
            )
            rows = self.db.query(sql)
            if rows:
                for key, value in dict(rows[0]).items():
                    results[f"{key}_{days}d"] = value

        return results

    # ==========================================================================
    # Ranking
    # ==========================================================================

    def calculate_rank_with_ties(
        self,
        values: List[Tuple[int, Decimal]]
    ) -> List[Tuple[int, int, int]]:
        """
        Rank entities by score (descending), handling ties.

        Top-3 rule: rank by total performance hours; ties share a rank, e.g. two
        firsts are both rank 1 and the next entity is rank 3.

        Args:
            values: (entity_id, score) tuples.

        Returns:
            (entity_id, rank, dense_rank) tuples in descending score order.
            rank: competition rank (two firsts -> next is 3).
            dense_rank: dense rank (two firsts -> next is 2).
        """
        if not values:
            return []

        ordered = sorted(values, key=lambda item: item[1], reverse=True)

        results: List[Tuple[int, int, int]] = []
        prev_score: Any = None
        rank = 0
        dense_rank = 0
        for position, (entity_id, score) in enumerate(ordered, start=1):
            # The first row always opens a new rank group, even if its score is None.
            if position == 1 or score != prev_score:
                rank = position
                dense_rank += 1
                prev_score = score
            results.append((entity_id, rank, dense_rank))

        return results

    # ==========================================================================
    # Walk-in guest filtering
    # ==========================================================================

    def is_guest(self, member_id: Optional[int]) -> bool:
        """
        Return True for a walk-in guest.

        Guests (member_id 0 or None) are excluded from customer-dimension statistics.
        """
        return member_id is None or member_id == 0

    # ==========================================================================
    # Utilities
    # ==========================================================================

    def safe_decimal(self, value: Any, default: Decimal = Decimal('0')) -> Decimal:
        """Convert value to Decimal via str(); return `default` for None or unparsable input."""
        if value is None:
            return default
        try:
            return Decimal(str(value))
        # CHANGE [2026-02-14] intent: catch InvalidOperation so non-numeric strings do not raise.
        # assumptions: callers expect safe_decimal never to raise.
        except (ValueError, TypeError, InvalidOperation):
            return default

    def safe_int(self, value: Any, default: int = 0) -> int:
        """Convert value with int(); return `default` for None or unconvertible input (incl. inf)."""
        if value is None:
            return default
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return default

    def seconds_to_hours(self, seconds: int) -> Decimal:
        """Convert seconds to hours as a Decimal."""
        return Decimal(str(seconds)) / Decimal('3600')

    def hours_to_seconds(self, hours: Decimal) -> int:
        """Convert hours to whole seconds (fractional seconds are truncated toward zero)."""
        return int(hours * Decimal('3600'))