# -*- coding: utf-8 -*-
|
||
"""
|
||
DWS层任务基类
|
||
|
||
功能说明:
|
||
- 提供从DWD层读取数据的标准方法
|
||
- 提供时间分层查询功能(近2天/近1月/近3月/近6月/全量)
|
||
- 提供配置表读取方法
|
||
- 提供幂等更新机制(delete-before-insert)
|
||
- 提供SCD2维度as-of取值方法
|
||
- 提供滚动窗口统计方法
|
||
|
||
时间口径说明:
|
||
- 周起始日:周一
|
||
- 月/季度起始:第一天0点
|
||
- 环比规则:对比上一个等长区间
|
||
- 前3个月:含/不含本月(用于财务筛选)
|
||
- 最近半年:不含本月
|
||
|
||
更新频率:
|
||
- 日度表:每日更新
|
||
- 实时表:每小时更新
|
||
- 月度表:每日更新当月数据
|
||
|
||
作者:ETL团队
|
||
创建日期:2026-02-01
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import calendar
|
||
from abc import abstractmethod
|
||
from dataclasses import dataclass
|
||
from datetime import date, datetime, timedelta
|
||
from decimal import Decimal
|
||
from enum import Enum
|
||
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar
|
||
|
||
from ..base_task import BaseTask, TaskContext
|
||
|
||
# =============================================================================
|
||
# 类型定义
|
||
# =============================================================================
|
||
|
||
# Generic type variable available for annotations in this module.
T = TypeVar("T")
|
||
|
||
|
||
class TimeLayer(Enum):
    """Time layers used to decide how far back data is selected."""

    # Most recent 2 days
    LAST_2_DAYS = "LAST_2_DAYS"
    # Most recent month
    LAST_1_MONTH = "LAST_1_MONTH"
    # Most recent 3 months
    LAST_3_MONTHS = "LAST_3_MONTHS"
    # Most recent 6 months, current month excluded
    LAST_6_MONTHS = "LAST_6_MONTHS"
    # Full history
    ALL = "ALL"
|
||
|
||
|
||
class TimeWindow(Enum):
    """Time window kinds used by financial reports."""

    # Current week (weeks start on Monday)
    THIS_WEEK = "THIS_WEEK"
    # Previous week
    LAST_WEEK = "LAST_WEEK"
    # Current month
    THIS_MONTH = "THIS_MONTH"
    # Previous month
    LAST_MONTH = "LAST_MONTH"
    # Previous 3 months, current month excluded
    LAST_3_MONTHS_EXCL_CURRENT = "LAST_3_MONTHS_EXCL_CURRENT"
    # Previous 3 months, current month included
    LAST_3_MONTHS_INCL_CURRENT = "LAST_3_MONTHS_INCL_CURRENT"
    # Current quarter
    THIS_QUARTER = "THIS_QUARTER"
    # Previous quarter
    LAST_QUARTER = "LAST_QUARTER"
    # Most recent half year, current month excluded
    LAST_6_MONTHS = "LAST_6_MONTHS"
|
||
|
||
|
||
class CourseType(Enum):
    """Course type classification."""

    # Base course / accompanied play
    BASE = "BASE"
    # Add-on (bonus) course
    BONUS = "BONUS"
    # Private-room course
    ROOM = "ROOM"
|
||
|
||
|
||
class DiscountType(Enum):
    """Discount type classification."""

    # Group-buy discount
    GROUPBUY = "GROUPBUY"
    # Member (VIP) discount
    VIP = "VIP"
    # Gift-card deduction
    GIFT_CARD = "GIFT_CARD"
    # Manual adjustment
    MANUAL = "MANUAL"
    # Rounding off small change
    ROUNDING = "ROUNDING"
    # Key-account discount
    BIG_CUSTOMER = "BIG_CUSTOMER"
    # Any other discount
    OTHER = "OTHER"
|
||
|
||
|
||
@dataclass
class TimeRange:
    """A date range described by its first and last day."""

    # First day of the range
    start: date
    # Last day of the range
    end: date
|
||
|
||
|
||
@dataclass
class ConfigCache:
    """Snapshot of the configuration tables together with its load time."""

    # Performance tier configuration rows
    performance_tiers: List[Dict[str, Any]]
    # Level pricing configuration rows
    level_prices: List[Dict[str, Any]]
    # Bonus rule configuration rows
    bonus_rules: List[Dict[str, Any]]
    # Area category mapping
    area_categories: Dict[str, Dict[str, Any]]
    # Skill type mapping
    skill_types: Dict[int, Dict[str, Any]]
    # Moment at which the snapshot was loaded
    loaded_at: datetime
|
||
|
||
|
||
# =============================================================================
|
||
# DWS任务基类
|
||
# =============================================================================
|
||
|
||
class BaseDwsTask(BaseTask):
    """
    Base class for DWS-layer tasks.

    Shared capabilities offered to DWS tasks:
    1. Reading data from the DWD layer
    2. Time layer and time window calculation
    3. Configuration table caching and lookup
    4. As-of lookups on SCD2 dimensions
    5. Idempotent updates
    6. Rolling-window statistics
    """

    # Configuration cache slot, declared at class level
    _config_cache: Optional[ConfigCache] = None
    # Cache lifetime in seconds
    _config_cache_ttl: int = 300

    # Schema names
    DWS_SCHEMA = "billiards_dws"
    DWD_SCHEMA = "billiards_dwd"

    # Rolling-window sizes, in days
    ROLLING_WINDOWS = [7, 10, 15, 30, 60, 90]
|
||
|
||
# ==========================================================================
|
||
# 抽象方法(子类必须实现)
|
||
# ==========================================================================
|
||
|
||
@abstractmethod
def get_target_table(self) -> str:
    """
    Return the target table name, without the schema prefix.

    Subclasses are expected to override this method.

    Returns:
        The target table name, e.g. 'dws_assistant_daily_detail'.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "子类需实现 get_target_table 方法"
    raise NotImplementedError(message)
|
||
|
||
@abstractmethod
def get_primary_keys(self) -> List[str]:
    """
    Return the primary-key column names (used for idempotent updates).

    Subclasses are expected to override this method.

    Returns:
        The key columns, e.g. ['site_id', 'assistant_id', 'stat_date'].

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "子类需实现 get_primary_keys 方法"
    raise NotImplementedError(message)
|
||
|
||
# ==========================================================================
|
||
# 时间计算方法
|
||
# ==========================================================================
|
||
|
||
def get_time_layer_range(
    self,
    layer: TimeLayer,
    base_date: Optional[date] = None
) -> TimeRange:
    """
    Resolve a time layer into a concrete date range.

    Args:
        layer: The time layer to resolve.
        base_date: Anchor date; today's date is used when omitted.

    Returns:
        TimeRange holding the computed start and end dates. Any layer that
        is not handled explicitly falls through to the full-history range
        starting on 2000-01-01.
    """
    anchor = date.today() if base_date is None else base_date

    if layer == TimeLayer.LAST_2_DAYS:
        return TimeRange(start=anchor - timedelta(days=1), end=anchor)

    if layer == TimeLayer.LAST_1_MONTH:
        return TimeRange(start=anchor - timedelta(days=30), end=anchor)

    if layer == TimeLayer.LAST_3_MONTHS:
        return TimeRange(start=anchor - timedelta(days=90), end=anchor)

    if layer == TimeLayer.LAST_6_MONTHS:
        # Current month excluded: the range ends on the last day of the
        # previous month and starts six months before the current month.
        this_month_head = self.get_month_first_day(anchor)
        last_day_prev_month = this_month_head - timedelta(days=1)
        six_months_back = self._shift_months(this_month_head, -6)
        return TimeRange(
            start=self.get_month_first_day(six_months_back),
            end=last_day_prev_month,
        )

    # ALL (and anything unrecognised): full history up to the anchor date.
    return TimeRange(start=date(2000, 1, 1), end=anchor)
|
||
|
||
def get_time_window_range(
    self,
    window: TimeWindow,
    base_date: Optional[date] = None
) -> TimeRange:
    """
    Resolve a reporting time window into a concrete date range.

    Conventions used by the code:
    - Weeks start on Monday.
    - THIS_* windows run from the period start up to the anchor date.
    - LAST_* windows cover whole periods before the current one.

    Args:
        window: The time window to resolve.
        base_date: Anchor date; today's date is used when omitted.

    Returns:
        TimeRange holding the computed start and end dates.

    Raises:
        ValueError: If ``window`` is not one of the handled members.
    """
    anchor = date.today() if base_date is None else base_date

    if window == TimeWindow.THIS_WEEK:
        # date.weekday() is 0 on Monday, so subtracting it lands on Monday.
        monday = anchor - timedelta(days=anchor.weekday())
        return TimeRange(start=monday, end=anchor)

    if window == TimeWindow.LAST_WEEK:
        monday = anchor - timedelta(days=anchor.weekday())
        prev_sunday = monday - timedelta(days=1)
        prev_monday = prev_sunday - timedelta(days=6)
        return TimeRange(start=prev_monday, end=prev_sunday)

    if window == TimeWindow.THIS_MONTH:
        return TimeRange(start=anchor.replace(day=1), end=anchor)

    if window == TimeWindow.LAST_MONTH:
        prev_month_end = anchor.replace(day=1) - timedelta(days=1)
        return TimeRange(start=prev_month_end.replace(day=1), end=prev_month_end)

    if window == TimeWindow.LAST_3_MONTHS_EXCL_CURRENT:
        # From the first day three months back to the last day of last month.
        month_head = self.get_month_first_day(anchor)
        prev_month_end = month_head - timedelta(days=1)
        first_day = self.get_month_first_day(self._shift_months(month_head, -3))
        return TimeRange(start=first_day, end=prev_month_end)

    if window == TimeWindow.LAST_3_MONTHS_INCL_CURRENT:
        # From the first day two months back up to the anchor date.
        month_head = self.get_month_first_day(anchor)
        first_day = self.get_month_first_day(self._shift_months(month_head, -2))
        return TimeRange(start=first_day, end=anchor)

    if window == TimeWindow.THIS_QUARTER:
        quarter_first_month = ((anchor.month - 1) // 3) * 3 + 1
        quarter_head = anchor.replace(month=quarter_first_month, day=1)
        return TimeRange(start=quarter_head, end=anchor)

    if window == TimeWindow.LAST_QUARTER:
        quarter_first_month = ((anchor.month - 1) // 3) * 3 + 1
        quarter_head = anchor.replace(month=quarter_first_month, day=1)
        # The day before this quarter starts is the previous quarter's last
        # day; its own month/year determine the previous quarter's start.
        prev_quarter_end = quarter_head - timedelta(days=1)
        prev_first_month = ((prev_quarter_end.month - 1) // 3) * 3 + 1
        prev_quarter_head = prev_quarter_end.replace(month=prev_first_month, day=1)
        return TimeRange(start=prev_quarter_head, end=prev_quarter_end)

    if window == TimeWindow.LAST_6_MONTHS:
        # Most recent half year, current month excluded.
        month_head = self.get_month_first_day(anchor)
        prev_month_end = month_head - timedelta(days=1)
        first_day = self.get_month_first_day(self._shift_months(month_head, -6))
        return TimeRange(start=first_day, end=prev_month_end)

    raise ValueError(f"不支持的时间窗口类型: {window}")
|
||
|
||
def get_comparison_range(self, time_range: TimeRange) -> TimeRange:
    """
    Compute the period-over-period comparison range.

    The comparison range is the range of equal length that ends on the day
    before ``time_range`` starts.

    Args:
        time_range: The current range (both ends counted).

    Returns:
        The immediately preceding range of the same length.
    """
    # Both ends are counted, hence the +1.
    span_days = (time_range.end - time_range.start).days + 1
    previous_end = time_range.start - timedelta(days=1)
    # Stepping back span_days from the current start gives the same date as
    # stepping back span_days - 1 from previous_end.
    previous_start = time_range.start - timedelta(days=span_days)
    return TimeRange(start=previous_start, end=previous_end)
|
||
|
||
def get_month_first_day(self, dt: date) -> date:
    """
    Return the first day of the month that ``dt`` falls in.

    Args:
        dt: Any date.

    Returns:
        ``dt`` with its day replaced by 1.
    """
    first_of_month = dt.replace(day=1)
    return first_of_month
|
||
|
||
def get_month_last_day(self, dt: date) -> date:
    """
    Return the last day of the month that ``dt`` falls in.

    Args:
        dt: Any date.

    Returns:
        ``dt`` with its day replaced by the number of days in its month.
    """
    # monthrange returns (weekday of the 1st, number of days in the month).
    _, days_in_month = calendar.monthrange(dt.year, dt.month)
    return dt.replace(day=days_in_month)
|
||
|
||
def _shift_months(self, base_date: date, months: int) -> date:
    """
    Shift a date by a number of months, clamping the day when necessary.

    Args:
        base_date: The date to shift.
        months: Months to add; negative values move backwards.

    Returns:
        The shifted date. When the target month is shorter than
        ``base_date.day``, the last day of the target month is used.
    """
    # Work with a zero-based month index so divmod splits year and month.
    month_index = base_date.year * 12 + (base_date.month - 1) + months
    target_year, zero_based_month = divmod(month_index, 12)
    target_month = zero_based_month + 1
    _, days_in_target = calendar.monthrange(target_year, target_month)
    return date(target_year, target_month, min(base_date.day, days_in_target))
|
||
|
||
def is_new_hire_in_month(self, hire_date: date, stat_month: date) -> bool:
    """
    Tell whether an assistant counts as a new hire for a statistics month.

    Rule applied: anyone hired on or after the first day of the statistics
    month counts as a new hire.

    Args:
        hire_date: The hire date.
        stat_month: Any date in the statistics month (normalised to its
            first day here).

    Returns:
        True when ``hire_date`` is not earlier than the month's first day.
    """
    first_day_of_stat_month = self.get_month_first_day(stat_month)
    hired_before_month = hire_date < first_day_of_stat_month
    return not hired_before_month
|
||
|
||
# ==========================================================================
|
||
# 配置表读取方法
|
||
# ==========================================================================
|
||
|
||
def load_config_cache(self, force_reload: bool = False) -> ConfigCache:
    """
    Return the configuration cache, reloading it when stale or forced.

    Args:
        force_reload: When True, skip the freshness check and reload.

    Returns:
        The cached ConfigCache when it is younger than ``_config_cache_ttl``
        seconds; otherwise a freshly loaded one.
    """
    now = datetime.now(self.tz)

    cached = self._config_cache
    if not force_reload and cached is not None:
        age_seconds = (now - cached.loaded_at).total_seconds()
        if age_seconds < self._config_cache_ttl:
            return cached

    self.logger.debug("重新加载DWS配置表缓存")

    # Keyword arguments are evaluated in order, so the loaders run as:
    # performance tiers, level prices, bonus rules, area categories,
    # skill types.
    # NOTE(review): this assignment goes through ``self``, so the refreshed
    # cache lives on the instance rather than on the class - confirm whether
    # sharing across instances was intended.
    self._config_cache = ConfigCache(
        performance_tiers=self._load_performance_tiers(),
        level_prices=self._load_level_prices(),
        bonus_rules=self._load_bonus_rules(),
        area_categories=self._load_area_categories(),
        skill_types=self._load_skill_types(),
        loaded_at=now,
    )
    return self._config_cache
|
||
|
||
def _load_performance_tiers(self) -> List[Dict[str, Any]]:
    """
    Load the performance tier configuration rows.

    Field notes carried over from the original docstring (which cites the
    DWS requirements document):
    - base_deduction: amount per hour the venue deducts from base-course
      income (yuan/hour)
    - bonus_deduction_ratio: ratio the venue deducts from add-on course
      income
    - vacation_days: vacation days available in the following month
    - vacation_unlimited: unlimited-vacation flag (TRUE for the top tier)

    Returns:
        One dict per row, ordered by tier_level then effective_from, or an
        empty list when the query yields nothing.
    """
    query = """
        SELECT
            tier_id, tier_code, tier_name, tier_level,
            min_hours, max_hours,
            base_deduction, bonus_deduction_ratio,
            vacation_days, vacation_unlimited,
            is_new_hire_tier, effective_from, effective_to
        FROM billiards_dws.cfg_performance_tier
        ORDER BY tier_level ASC, effective_from ASC
    """
    fetched = self.db.query(query)
    if not fetched:
        return []
    return [dict(record) for record in fetched]
|
||
|
||
def _load_level_prices(self) -> List[Dict[str, Any]]:
    """
    Load the assistant level pricing configuration rows.

    Returns:
        One dict per row, ordered by level_code ascending and
        effective_from descending, or an empty list when the query yields
        nothing.
    """
    query = """
        SELECT
            price_id, level_code, level_name,
            base_course_price, bonus_course_price,
            effective_from, effective_to
        FROM billiards_dws.cfg_assistant_level_price
        ORDER BY level_code ASC, effective_from DESC
    """
    fetched = self.db.query(query)
    if not fetched:
        return []
    return [dict(record) for record in fetched]
|
||
|
||
def _load_bonus_rules(self) -> List[Dict[str, Any]]:
    """
    Load the bonus rule configuration rows.

    Returns:
        One dict per row, ordered by rule_type, then priority descending,
        then effective_from descending, or an empty list when the query
        yields nothing.
    """
    query = """
        SELECT
            rule_id, rule_type, rule_code, rule_name,
            threshold_hours, rank_position, bonus_amount,
            is_cumulative, priority,
            effective_from, effective_to
        FROM billiards_dws.cfg_bonus_rules
        ORDER BY rule_type, priority DESC, effective_from DESC
    """
    fetched = self.db.query(query)
    if not fetched:
        return []
    return [dict(record) for record in fetched]
|
||
|
||
def _load_area_categories(self) -> Dict[str, Dict[str, Any]]:
    """
    Load the active area category mapping.

    Returns:
        Mapping from source_area_name to its row dict, built in
        match_priority order. When a name occurs more than once the later
        row replaces the earlier one. Empty dict when the query yields
        nothing.
    """
    query = """
        SELECT
            source_area_name, category_code, category_name,
            match_type, match_priority
        FROM billiards_dws.cfg_area_category
        WHERE is_active = TRUE
        ORDER BY match_priority ASC
    """
    fetched = self.db.query(query)
    if not fetched:
        return {}

    records = (dict(item) for item in fetched)
    return {record['source_area_name']: record for record in records}
|
||
|
||
def _load_skill_types(self) -> Dict[int, Dict[str, Any]]:
    """
    Load the active skill type mapping.

    Returns:
        Mapping from skill_id (coerced to int) to its row dict, or an empty
        dict when the query yields nothing.
    """
    query = """
        SELECT
            skill_id, skill_name,
            course_type_code, course_type_name
        FROM billiards_dws.cfg_skill_type
        WHERE is_active = TRUE
    """
    fetched = self.db.query(query)
    if not fetched:
        return {}

    records = (dict(item) for item in fetched)
    return {int(record['skill_id']): record for record in records}
|
||
|
||
# ==========================================================================
|
||
# 配置应用方法
|
||
# ==========================================================================
|
||
|
||
def _filter_by_effective_date(
    self,
    items: List[Dict[str, Any]],
    effective_date: Optional[date]
) -> List[Dict[str, Any]]:
    """
    Keep only the configuration items that are in effect on a given date.

    Args:
        items: Configuration rows carrying optional 'effective_from' and
            'effective_to' values.
        effective_date: Reference date; today's date is used when falsy.

    Returns:
        The items whose effective period contains the reference date, in
        their original order. A missing or falsy bound is treated as open.
    """
    reference = effective_date or date.today()

    def _in_effect(entry: Dict[str, Any]) -> bool:
        starts_on = entry.get('effective_from')
        ends_on = entry.get('effective_to')
        if starts_on and reference < starts_on:
            return False
        if ends_on and reference > ends_on:
            return False
        return True

    return [entry for entry in items if _in_effect(entry)]
|
||
|
||
def get_performance_tier(
    self,
    effective_hours: Decimal,
    is_new_hire: bool,
    effective_date: Optional[date] = None,
    max_tier_level: Optional[int] = None
) -> Optional[Dict[str, Any]]:
    """
    Match a performance tier from the effective performance hours.

    Tiers are scanned in the order the configuration cache provides them
    and the first tier whose [min_hours, max_hours) interval contains
    ``effective_hours`` is returned. Tiers flagged ``is_new_hire_tier`` are
    skipped.

    Args:
        effective_hours: Effective performance hours.
        is_new_hire: Kept for caller compatibility and not used here (the
            original comment says the new-hire cap is handled by the
            monthly task).
        effective_date: Date used to pick the configuration in effect
            (for historical months).
        max_tier_level: When given, tiers whose tier_level exceeds this
            value are ignored; tiers without a tier_level are kept.

    Returns:
        The matching tier row, or None when no tier matches.
    """
    _ = is_new_hire  # intentionally unused, see docstring
    config = self.load_config_cache()
    tiers = self._filter_by_effective_date(config.performance_tiers, effective_date)

    if max_tier_level is not None:
        tiers = [
            t for t in tiers
            if t.get('tier_level') is None or int(t['tier_level']) <= max_tier_level
        ]

    for tier in tiers:
        if tier.get('is_new_hire_tier'):
            continue

        # Rows loaded from the database carry the key with a None value when
        # the column is NULL, so a dict.get default would never apply.
        # A NULL lower bound is treated as 0, a NULL upper bound as open.
        raw_min = tier.get('min_hours')
        min_hours = Decimal('0') if raw_min is None else Decimal(str(raw_min))
        raw_max = tier.get('max_hours')
        max_hours = None if raw_max is None else Decimal(str(raw_max))

        if effective_hours >= min_hours and (max_hours is None or effective_hours < max_hours):
            return tier

    return None
|
||
|
||
def get_performance_tier_by_id(
    self,
    tier_id: Optional[int],
    effective_date: Optional[date] = None
) -> Optional[Dict[str, Any]]:
    """
    Look up a performance tier by its id among the tiers in effect.

    Args:
        tier_id: Tier id to find; a falsy value (None or 0) yields None
            without touching the configuration cache.
        effective_date: Date used to pick the configuration in effect.

    Returns:
        The first tier row whose 'tier_id' equals ``tier_id``, or None.
    """
    if not tier_id:
        return None

    config = self.load_config_cache()
    in_effect = self._filter_by_effective_date(config.performance_tiers, effective_date)
    matches = (entry for entry in in_effect if entry.get('tier_id') == tier_id)
    return next(matches, None)
|
||
|
||
def get_level_price(
    self,
    level_code: int,
    effective_date: Optional[date] = None
) -> Optional[Dict[str, Any]]:
    """
    Return the pricing row for an assistant level as of a given date.

    Args:
        level_code: Level code to look up.
        effective_date: Date used to pick the configuration in effect.

    Returns:
        The first in-effect pricing row whose 'level_code' equals
        ``level_code``, or None when there is none.
    """
    config = self.load_config_cache()
    in_effect = self._filter_by_effective_date(config.level_prices, effective_date)
    matches = (entry for entry in in_effect if entry.get('level_code') == level_code)
    return next(matches, None)
|
||
|
||
def get_course_type(self, skill_id: int) -> CourseType:
    """
    Map a skill id to its course type.

    Args:
        skill_id: Skill id looked up in the cached skill type mapping.

    Returns:
        CourseType.BONUS or CourseType.ROOM when the configured
        'course_type_code' says so; CourseType.BASE for any other code and
        for unknown skill ids.
    """
    config = self.load_config_cache()
    entry = config.skill_types.get(skill_id)

    # Unknown skill: fall back to the base course type.
    if not entry:
        return CourseType.BASE

    code = entry.get('course_type_code', 'BASE')
    if code == 'BONUS':
        return CourseType.BONUS
    if code == 'ROOM':
        return CourseType.ROOM
    return CourseType.BASE
|
||
|
||
def get_area_category(self, area_name: Optional[str]) -> Dict[str, str]:
    """
    Classify an area name (exact match, then fuzzy match, then fallback).

    Lookup order:
    1. An entry keyed by ``area_name`` whose match_type is 'EXACT'.
    2. The first 'LIKE' entry, in mapping order, whose key with all '%'
       removed is non-empty and occurs inside ``area_name``.
    3. The entry keyed 'DEFAULT', when present.
    4. The built-in OTHER category.

    Args:
        area_name: Raw area name; a falsy value yields the OTHER category.

    Returns:
        Dict with 'category_code' and 'category_name'.
    """
    config = self.load_config_cache()

    if not area_name:
        return {'category_code': 'OTHER', 'category_name': '其他区域'}

    categories = config.area_categories

    def _as_result(entry: Dict[str, Any]) -> Dict[str, str]:
        return {
            'category_code': entry['category_code'],
            'category_name': entry['category_name'],
        }

    # 1. Exact match
    if area_name in categories:
        candidate = categories[area_name]
        if candidate.get('match_type') == 'EXACT':
            return _as_result(candidate)

    # 2. Fuzzy match, in the mapping's own order
    for pattern_key, candidate in categories.items():
        if candidate.get('match_type') != 'LIKE':
            continue
        needle = pattern_key.replace('%', '')
        if needle and needle in area_name:
            return _as_result(candidate)

    # 3. Configured fallback
    if 'DEFAULT' in categories:
        return _as_result(categories['DEFAULT'])

    # 4. Built-in fallback
    return {'category_code': 'OTHER', 'category_name': '其他区域'}
|
||
|
||
def calculate_sprint_bonus(
    self,
    effective_hours: Decimal,
    effective_date: Optional[date] = None
) -> Decimal:
    """
    Compute the sprint bonus (not cumulative; one rule is paid).

    The in-effect rules of type 'SPRINT' are ordered by descending
    priority and the first rule whose threshold is reached is paid.

    Args:
        effective_hours: Effective performance hours.
        effective_date: Date used to pick the configuration in effect.

    Returns:
        The bonus amount of the first satisfied rule, or Decimal('0') when
        no rule is satisfied.
    """
    config = self.load_config_cache()
    in_effect = self._filter_by_effective_date(config.bonus_rules, effective_date)

    def _number(raw: Any) -> Decimal:
        # Database NULLs arrive as None under an existing key, so a
        # dict.get default never applies; treat them as zero instead of
        # letting Decimal(str(None)) raise.
        return Decimal('0') if raw is None else Decimal(str(raw))

    sprint_rules = [rule for rule in in_effect if rule.get('rule_type') == 'SPRINT']
    # A NULL priority sorts as 0 rather than breaking the comparison.
    sprint_rules.sort(key=lambda rule: rule.get('priority') or 0, reverse=True)

    for rule in sprint_rules:
        if effective_hours >= _number(rule.get('threshold_hours')):
            return _number(rule.get('bonus_amount'))

    return Decimal('0')
|
||
|
||
def calculate_top_rank_bonus(
    self,
    rank: int,
    effective_date: Optional[date] = None
) -> Decimal:
    """
    Compute the Top-3 ranking bonus.

    Amounts come from the in-effect rules of type 'TOP_RANK'; the original
    docstring lists 1000 / 600 / 400 yuan for ranks 1 / 2 / 3 and states
    that tied ranks are all paid.

    Args:
        rank: Rank with ties already taken into account.
        effective_date: Date used to pick the configuration in effect.

    Returns:
        The bonus amount of the first rule whose 'rank_position' equals
        ``rank``; Decimal('0') for ranks above 3 or when no rule matches.
    """
    # Ranks beyond the Top 3 never earn a bonus, so skip the configuration
    # load entirely for them.
    if rank > 3:
        return Decimal('0')

    config = self.load_config_cache()
    in_effect = self._filter_by_effective_date(config.bonus_rules, effective_date)

    for rule in in_effect:
        if rule.get('rule_type') != 'TOP_RANK':
            continue
        if rule.get('rank_position') == rank:
            amount = rule.get('bonus_amount')
            # A NULL amount arrives as None; treat it as zero instead of
            # letting Decimal(str(None)) raise.
            return Decimal('0') if amount is None else Decimal(str(amount))

    return Decimal('0')
|
||
|
||
# ==========================================================================
|
||
# DWD数据读取方法
|
||
# ==========================================================================
|
||
|
||
def iter_dwd_rows(
    self,
    table_name: str,
    columns: List[str],
    start_date: date,
    end_date: date,
    date_col: str = "created_at",
    where_clause: str = "",
    order_by: str = "",
    batch_size: int = 1000
) -> Iterator[List[Dict[str, Any]]]:
    """
    Iterate over a DWD table in batches using LIMIT/OFFSET paging.

    Args:
        table_name: DWD table name, without schema.
        columns: Columns to select.
        start_date: First date to include.
        end_date: Last date to include.
        date_col: Column the date filter is applied to.
        where_clause: Extra condition, without the WHERE keyword.
        order_by: Sort expression, without the ORDER BY keywords; when
            empty, rows are ordered by ``date_col`` ascending.
        batch_size: Number of rows requested per query.

    Yields:
        One list of row dicts per batch. Iteration stops after an empty
        batch or a batch shorter than ``batch_size``.
    """
    # Identifiers and extra conditions are interpolated as given; only the
    # date bounds and paging values travel as query parameters.
    select_list = ", ".join(columns)

    conditions = [f"DATE({date_col}) >= %s", f"DATE({date_col}) <= %s"]
    if where_clause:
        conditions.append(f"({where_clause})")
    where_sql = " AND ".join(conditions)

    ordering = order_by if order_by else f"{date_col} ASC"

    # NOTE(review): OFFSET paging relies on a stable ordering; the default
    # ordering by date_col alone may not be unique - confirm with callers.
    sql = f"""
        SELECT {select_list}
        FROM {self.DWD_SCHEMA}.{table_name}
        WHERE {where_sql}
        ORDER BY {ordering}
        LIMIT %s OFFSET %s
    """

    offset = 0
    while True:
        page = self.db.query(sql, (start_date, end_date, batch_size, offset))
        if not page:
            return

        yield [dict(record) for record in page]

        if len(page) < batch_size:
            return
        offset += batch_size
|
||
|
||
def query_dwd(
    self,
    sql: str,
    params: Optional[Tuple] = None
) -> List[Dict[str, Any]]:
    """
    Run a query through ``self.db`` and return the rows as dicts.

    Args:
        sql: SQL statement.
        params: Parameter tuple passed through to the database layer.

    Returns:
        One dict per row, or an empty list when the query yields nothing.
    """
    fetched = self.db.query(sql, params)
    if not fetched:
        return []
    return [dict(record) for record in fetched]
|
||
|
||
# ==========================================================================
|
||
# SCD2维度取值方法
|
||
# ==========================================================================
|
||
|
||
def get_assistant_level_asof(
    self,
    assistant_id: int,
    asof_date: date
) -> Optional[Dict[str, Any]]:
    """
    Return an assistant's level as of a given date (SCD2 as-of lookup).

    The assistant level is an SCD2 dimension, so historical months must not
    use the current level; the version whose validity period covers
    ``asof_date`` is selected instead.

    Args:
        assistant_id: Assistant id.
        asof_date: Date the level is wanted for.

    Returns:
        Row dict with assistant_id, nickname, level_code, level_name and
        the SCD2 start/end times, or None when no version covers the date.
    """
    query = """
        SELECT
            assistant_id,
            nickname,
            level AS level_code,
            CASE level
                WHEN 8 THEN '助教管理'
                WHEN 10 THEN '初级'
                WHEN 20 THEN '中级'
                WHEN 30 THEN '高级'
                WHEN 40 THEN '星级'
                ELSE '未知'
            END AS level_name,
            scd2_start_time,
            scd2_end_time
        FROM billiards_dwd.dim_assistant
        WHERE assistant_id = %s
          AND scd2_start_time <= %s
          AND (scd2_end_time IS NULL OR scd2_end_time > %s)
        ORDER BY scd2_start_time DESC
        LIMIT 1
    """
    # The as-of date is bound twice: once per validity bound.
    fetched = self.db.query(query, (assistant_id, asof_date, asof_date))
    if not fetched:
        return None
    return dict(fetched[0])
|
||
|
||
def get_member_card_balance_asof(
    self,
    member_id: int,
    asof_date: date
) -> Dict[str, Decimal]:
    """
    Return a member's card balances as of a given date (SCD2 as-of lookup).

    Non-deleted card account versions whose validity period covers
    ``asof_date`` are summed per card category. Card types that are neither
    the cash card nor one of the gift card types are ignored.

    Args:
        member_id: Member id (matched against tenant_member_id).
        asof_date: Date the balances are wanted for.

    Returns:
        Dict with 'cash_balance', 'gift_balance' and 'total_balance'
        (cash + gift). All three are Decimal('0') when no row matches.
    """
    sql = """
        SELECT
            card_type_id,
            balance
        FROM billiards_dwd.dim_member_card_account
        WHERE tenant_member_id = %s
          AND scd2_start_time <= %s
          AND (scd2_end_time IS NULL OR scd2_end_time > %s)
          AND COALESCE(is_delete, 0) = 0
    """
    rows = self.db.query(sql, (member_id, asof_date, asof_date))

    # Card type id mapping
    CASH_CARD_TYPE_ID = 2793249295533893
    GIFT_CARD_TYPE_IDS = frozenset((
        2791990152417157,  # table-fee card
        2793266846533445,  # activity voucher
        2794699703437125,  # beverage card
    ))

    cash_balance = Decimal('0')
    gift_balance = Decimal('0')

    for row in (rows or []):
        row_dict = dict(row)
        card_type_id = row_dict.get('card_type_id')
        raw_balance = row_dict.get('balance')
        # A NULL balance arrives as None under an existing key, so a
        # dict.get default never applies; count it as zero instead of
        # letting Decimal(str(None)) raise.
        balance = Decimal('0') if raw_balance is None else Decimal(str(raw_balance))

        if card_type_id == CASH_CARD_TYPE_ID:
            cash_balance += balance
        elif card_type_id in GIFT_CARD_TYPE_IDS:
            gift_balance += balance

    return {
        'cash_balance': cash_balance,
        'gift_balance': gift_balance,
        'total_balance': cash_balance + gift_balance,
    }
|
||
|
||
# ==========================================================================
|
||
# 幂等更新方法
|
||
# ==========================================================================
|
||
|
||
def delete_existing_data(
    self,
    context: TaskContext,
    date_col: str = "stat_date",
    extra_conditions: Optional[Dict[str, Any]] = None
) -> int:
    """
    Delete rows already present for the task window (idempotent update).

    Rows of the target table are removed where site_id equals
    ``context.store_id`` and ``date_col`` lies between the window start and
    end dates, plus any extra equality conditions.

    Args:
        context: Task context providing store_id, window_start and
            window_end.
        date_col: Name of the date column to filter on.
        extra_conditions: Extra ``column = value`` conditions.

    Returns:
        The cursor's row count for the DELETE statement.
    """
    def _to_date(value: Any) -> Any:
        # Objects exposing .date() (e.g. datetime) are reduced to dates;
        # anything else is used as-is.
        return value.date() if hasattr(value, 'date') else value

    full_table = f"{self.DWS_SCHEMA}.{self.get_target_table()}"

    conditions = ["site_id = %s", f"{date_col} >= %s", f"{date_col} <= %s"]
    params: List[Any] = [
        context.store_id,
        _to_date(context.window_start),
        _to_date(context.window_end),
    ]

    for column, value in (extra_conditions or {}).items():
        conditions.append(f"{column} = %s")
        params.append(value)

    where_str = " AND ".join(conditions)
    statement = f"DELETE FROM {full_table} WHERE {where_str}"

    with self.db.conn.cursor() as cur:
        cur.execute(statement, params)
        deleted = cur.rowcount

    self.logger.debug(
        "%s: 删除已存在数据 %d 行,条件: %s",
        self.get_task_code(), deleted, where_str
    )
    return deleted
|
||
|
||
def bulk_insert(
    self,
    rows: List[Dict[str, Any]],
    columns: Optional[List[str]] = None
) -> int:
    """
    Insert rows into the target table, one INSERT statement per row.

    Args:
        rows: Row dicts to insert.
        columns: Column list; taken from the first row's keys when None.
            Keys missing from a row are inserted as None.

    Returns:
        Sum of the cursor row counts, or 0 when ``rows`` is empty.
    """
    if not rows:
        return 0

    full_table = f"{self.DWS_SCHEMA}.{self.get_target_table()}"
    column_names = list(rows[0].keys()) if columns is None else columns

    column_list = ", ".join(column_names)
    value_slots = ", ".join("%s" for _ in column_names)
    statement = f"INSERT INTO {full_table} ({column_list}) VALUES ({value_slots})"

    total = 0
    with self.db.conn.cursor() as cur:
        for record in rows:
            cur.execute(statement, [record.get(name) for name in column_names])
            total += cur.rowcount
    return total
|
||
|
||
def upsert(
    self,
    rows: List[Dict[str, Any]],
    columns: Optional[List[str]] = None,
    update_columns: Optional[List[str]] = None
) -> Tuple[int, int]:
    """
    Upsert rows into the target table (INSERT ... ON CONFLICT DO UPDATE).

    The conflict target is the column list from ``get_primary_keys()``.
    ``updated_at`` is always set to NOW() on conflict.

    Args:
        rows: Row dicts to write.
        columns: All columns to insert; taken from the first row's keys
            when None.
        update_columns: Columns overwritten on conflict; defaults to every
            column except the primary keys and 'created_at'.

    Returns:
        (inserted, updated) tuple. The statement does not reveal whether a
        row was inserted or updated, so every affected row is counted as
        inserted and ``updated`` is always 0.
    """
    if not rows:
        return 0, 0

    target_table = self.get_target_table()
    full_table = f"{self.DWS_SCHEMA}.{target_table}"
    primary_keys = self.get_primary_keys()

    if columns is None:
        columns = list(rows[0].keys())

    if update_columns is None:
        update_columns = [c for c in columns if c not in primary_keys and c not in ('created_at',)]

    cols_str = ", ".join(columns)
    placeholders = ", ".join(["%s"] * len(columns))
    conflict_cols = ", ".join(primary_keys)

    # updated_at is always assigned NOW() below; also copying it from
    # EXCLUDED would assign the same column twice, which PostgreSQL rejects.
    update_parts = [
        f"{col} = EXCLUDED.{col}" for col in update_columns if col != 'updated_at'
    ]
    update_parts.append("updated_at = NOW()")
    update_str = ", ".join(update_parts)

    sql = f"""
        INSERT INTO {full_table} ({cols_str})
        VALUES ({placeholders})
        ON CONFLICT ({conflict_cols})
        DO UPDATE SET {update_str}
    """

    inserted = 0
    updated = 0

    with self.db.conn.cursor() as cur:
        for row in rows:
            values = [row.get(col) for col in columns]
            cur.execute(sql, values)
            # A positive row count means the statement wrote a row; inserts
            # and updates cannot be told apart here.
            if cur.rowcount > 0:
                inserted += 1

    return inserted, updated
|
||
|
||
# ==========================================================================
|
||
# 滚动窗口统计方法
|
||
# ==========================================================================
|
||
|
||
def calculate_rolling_stats(
    self,
    base_date: date,
    entity_id: int,
    entity_type: str,
    stat_sql: str,
    windows: Optional[List[int]] = None
) -> Dict[str, Any]:
    """
    Compute statistics over several rolling windows ending on a base date.

    For each window size the SQL template is rendered with ``str.format``
    and executed; the columns of the first returned row are stored under
    ``"<column>_<days>d"``.

    Args:
        base_date: Last day of every window.
        entity_id: Entity id (e.g. an assistant or member id) substituted
            into the template.
        entity_type: Entity type; accepted but not used by this method.
        stat_sql: SQL template; the placeholders {window_days},
            {entity_id}, {start_date} and {end_date} are filled in.
        windows: Window sizes in days; ROLLING_WINDOWS when None.

    Returns:
        Dict of statistics keyed by column name plus window suffix. Windows
        whose query yields nothing contribute no keys.
    """
    window_sizes = self.ROLLING_WINDOWS if windows is None else windows

    stats: Dict[str, Any] = {}
    for size in window_sizes:
        # A window of N days ending on base_date starts N - 1 days earlier.
        window_start = base_date - timedelta(days=size - 1)

        # NOTE(review): values are formatted into the SQL text rather than
        # bound as parameters - callers must only pass trusted values.
        rendered = stat_sql.format(
            window_days=size,
            entity_id=entity_id,
            start_date=window_start,
            end_date=base_date,
        )

        fetched = self.db.query(rendered)
        if not fetched:
            continue

        first_row = dict(fetched[0])
        stats.update({f"{column}_{size}d": value for column, value in first_row.items()})

    return stats
|
||
|
||
# ==========================================================================
|
||
# 排名计算方法
|
||
# ==========================================================================
|
||
|
||
def calculate_rank_with_ties(
    self,
    values: List[Tuple[int, Decimal]]
) -> List[Tuple[int, int, int]]:
    """
    Rank entities by score, highest first, with ties sharing a rank.

    Top-3 convention from the original docstring: tied entities all count,
    e.g. two first places are followed by a third place.

    Args:
        values: (entity_id, score) tuples.

    Returns:
        (entity_id, rank, dense_rank) tuples ordered by descending score.
        rank: competition rank (two firsts are followed by rank 3).
        dense_rank: dense rank (two firsts are followed by rank 2).
        Entities with equal scores keep their input order.
    """
    if not values:
        return []

    ordered = sorted(values, key=lambda pair: pair[1], reverse=True)

    results: List[Tuple[int, int, int]] = []
    previous_score = None
    rank = 0
    dense_rank = 0

    for position, (entity_id, score) in enumerate(ordered, start=1):
        # A new score group starts at the first row and whenever the score
        # changes: rank jumps to the row position, dense rank advances by 1.
        if position == 1 or score != previous_score:
            rank = position
            dense_rank += 1
            previous_score = score
        results.append((entity_id, rank, dense_rank))

    return results
|
||
|
||
# ==========================================================================
|
||
# 散客过滤
|
||
# ==========================================================================
|
||
|
||
def is_guest(self, member_id: Optional[int]) -> bool:
    """
    Tell whether a customer is a walk-in guest.

    Per the original note, customers with member_id = 0 are walk-in guests
    and are excluded from customer-dimension statistics; a missing
    member_id is treated the same way here.

    Args:
        member_id: Member id, possibly None.

    Returns:
        True when ``member_id`` is None or equal to 0.
    """
    if member_id is None:
        return True
    return member_id == 0
|
||
|
||
# ==========================================================================
|
||
# 工具方法
|
||
# ==========================================================================
|
||
|
||
def safe_decimal(self, value: Any, default: Decimal = Decimal('0')) -> Decimal:
    """
    Convert a value to Decimal, falling back to a default on failure.

    Args:
        value: Value to convert; it is converted through ``str(value)``.
        default: Returned when ``value`` is None or cannot be converted.

    Returns:
        The converted Decimal, or ``default``.
    """
    if value is None:
        return default
    try:
        return Decimal(str(value))
    except (ArithmeticError, ValueError, TypeError):
        # Decimal signals malformed text with decimal.InvalidOperation,
        # which derives from ArithmeticError rather than ValueError.
        return default
|
||
|
||
def safe_int(self, value: Any, default: int = 0) -> int:
    """
    Convert a value to int, falling back to a default on failure.

    Args:
        value: Value to convert with ``int(value)``.
        default: Returned when ``value`` is None or cannot be converted.

    Returns:
        The converted int, or ``default``.
    """
    if value is None:
        return default
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers infinities (float('inf'), Decimal('Infinity')),
        # which int() rejects with that exception rather than ValueError.
        return default
|
||
|
||
def seconds_to_hours(self, seconds: int) -> Decimal:
    """
    Convert a number of seconds to hours.

    Args:
        seconds: Duration in seconds.

    Returns:
        The duration in hours as a Decimal (seconds / 3600).
    """
    seconds_per_hour = Decimal('3600')
    duration = Decimal(str(seconds))
    return duration / seconds_per_hour
|
||
|
||
def hours_to_seconds(self, hours: Decimal) -> int:
    """
    Convert a number of hours to whole seconds.

    Args:
        hours: Duration in hours.

    Returns:
        hours * 3600 passed through ``int()``, which drops any fractional
        second.
    """
    seconds_per_hour = Decimal('3600')
    total_seconds = hours * seconds_per_hour
    return int(total_seconds)
|