Files
ZQYY.FQ-ETL/tasks/dws/base_dws_task.py

1223 lines
40 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
DWS层任务基类
功能说明:
- 提供从DWD层读取数据的标准方法
- 提供时间分层查询功能(近2天/近1月/近3月/近6月/全量)
- 提供配置表读取方法
- 提供幂等更新机制(delete-before-insert)
- 提供SCD2维度as-of取值方法
- 提供滚动窗口统计方法
时间口径说明:
- 周起始日:周一
- 月/季度起始:第一天0点
- 环比规则:对比上一个等长区间
- 前3个月含/不含本月(用于财务筛选)
- 最近半年:不含本月
更新频率:
- 日度表:每日更新
- 实时表:每小时更新
- 月度表:每日更新当月数据
作者:ETL团队
创建日期:2026-02-01
"""
from __future__ import annotations
import calendar
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar
from ..base_task import BaseTask, TaskContext
# =============================================================================
# 类型定义
# =============================================================================
# Generic type variable available to DWS helpers.
T = TypeVar("T")
class TimeLayer(Enum):
    """Time layers used to select how far back data is read."""

    LAST_2_DAYS = "LAST_2_DAYS"      # last 2 days
    LAST_1_MONTH = "LAST_1_MONTH"    # last month
    LAST_3_MONTHS = "LAST_3_MONTHS"  # last 3 months
    LAST_6_MONTHS = "LAST_6_MONTHS"  # last 6 months, current month excluded
    ALL = "ALL"                      # full history
class TimeWindow(Enum):
    """Reporting windows used by the finance reports."""

    THIS_WEEK = "THIS_WEEK"          # current week, starting Monday
    LAST_WEEK = "LAST_WEEK"          # previous Monday..Sunday
    THIS_MONTH = "THIS_MONTH"        # current month to date
    LAST_MONTH = "LAST_MONTH"        # previous calendar month
    LAST_3_MONTHS_EXCL_CURRENT = "LAST_3_MONTHS_EXCL_CURRENT"  # 3 full months before the current one
    LAST_3_MONTHS_INCL_CURRENT = "LAST_3_MONTHS_INCL_CURRENT"  # 2 full months plus the current one
    THIS_QUARTER = "THIS_QUARTER"    # current quarter to date
    LAST_QUARTER = "LAST_QUARTER"    # previous calendar quarter
    LAST_6_MONTHS = "LAST_6_MONTHS"  # last half year, current month excluded
class CourseType(Enum):
    """Kinds of assistant course."""

    BASE = "BASE"    # base course / accompanied play
    BONUS = "BONUS"  # add-on (bonus) course
    ROOM = "ROOM"    # private-room course
class DiscountType(Enum):
    """Categories of discount applied to an order."""

    GROUPBUY = "GROUPBUY"          # group-buy deal
    VIP = "VIP"                    # member discount
    GIFT_CARD = "GIFT_CARD"        # gift-card deduction
    MANUAL = "MANUAL"              # manual adjustment
    ROUNDING = "ROUNDING"          # rounding off small change
    BIG_CUSTOMER = "BIG_CUSTOMER"  # key-account discount
    OTHER = "OTHER"                # anything else
@dataclass
class TimeRange:
    """A date span; callers in this module treat both ends as inclusive."""

    start: date  # first day of the span
    end: date    # last day of the span
@dataclass
class ConfigCache:
    """Snapshot of the DWS configuration tables, stamped with its load time."""

    # Rows of cfg_performance_tier
    performance_tiers: List[Dict[str, Any]]
    # Rows of cfg_assistant_level_price
    level_prices: List[Dict[str, Any]]
    # Rows of cfg_bonus_rules
    bonus_rules: List[Dict[str, Any]]
    # cfg_area_category rows keyed by source_area_name
    area_categories: Dict[str, Dict[str, Any]]
    # cfg_skill_type rows keyed by skill_id
    skill_types: Dict[int, Dict[str, Any]]
    # When the snapshot was taken
    loaded_at: datetime
# =============================================================================
# DWS任务基类
# =============================================================================
class BaseDwsTask(BaseTask):
    """
    Base class for DWS-layer tasks.

    Shared functionality provided to subclasses:
    1. Reading data from the DWD layer
    2. Time-layer and time-window calculations
    3. Configuration-table caching and lookup
    4. SCD2 dimension as-of lookups
    5. Idempotent writes (delete-before-insert, upsert)
    6. Rolling-window statistics

    Subclasses must implement ``get_target_table`` and ``get_primary_keys``.
    """

    # Configuration cache, stored on the class so instances share it
    # (load_config_cache writes it via type(self)).
    _config_cache: Optional[ConfigCache] = None
    _config_cache_ttl: int = 300  # cache lifetime in seconds

    # Schema names
    DWS_SCHEMA = "billiards_dws"
    DWD_SCHEMA = "billiards_dwd"

    # Default rolling-window sizes, in days
    ROLLING_WINDOWS = [7, 10, 15, 30, 60, 90]

    # ==========================================================================
    # Abstract methods (subclasses must implement)
    # ==========================================================================

    @abstractmethod
    def get_target_table(self) -> str:
        """
        Return the target table name, without schema.

        Returns:
            Target table name, e.g. 'dws_assistant_daily_detail'.
        """
        raise NotImplementedError("子类需实现 get_target_table 方法")

    @abstractmethod
    def get_primary_keys(self) -> List[str]:
        """
        Return the primary-key columns (used as the upsert conflict target).

        Returns:
            Column list, e.g. ['site_id', 'assistant_id', 'stat_date'].
        """
        raise NotImplementedError("子类需实现 get_primary_keys 方法")

    # ==========================================================================
    # Time calculations
    # ==========================================================================

    def get_time_layer_range(
        self,
        layer: TimeLayer,
        base_date: Optional[date] = None
    ) -> TimeRange:
        """
        Return the date range for a time layer (both ends inclusive).

        Args:
            layer: Time layer.
            base_date: Reference date; defaults to today.

        Returns:
            TimeRange with start and end dates. Any layer not matched explicitly
            falls through to the ALL range.
        """
        if base_date is None:
            base_date = date.today()

        if layer == TimeLayer.LAST_2_DAYS:
            # The day before the reference date, plus the reference date itself.
            return TimeRange(start=base_date - timedelta(days=1), end=base_date)
        elif layer == TimeLayer.LAST_1_MONTH:
            return TimeRange(start=base_date - timedelta(days=30), end=base_date)
        elif layer == TimeLayer.LAST_3_MONTHS:
            return TimeRange(start=base_date - timedelta(days=90), end=base_date)
        elif layer == TimeLayer.LAST_6_MONTHS:
            # Excludes the current month: six full months ending on the last
            # day of the previous month.
            month_start = self.get_month_first_day(base_date)
            end = month_start - timedelta(days=1)
            start = self.get_month_first_day(self._shift_months(month_start, -6))
            return TimeRange(start=start, end=end)
        else:  # ALL
            return TimeRange(start=date(2000, 1, 1), end=base_date)

    def get_time_window_range(
        self,
        window: TimeWindow,
        base_date: Optional[date] = None
    ) -> TimeRange:
        """
        Return the date range of a finance-report window (both ends inclusive).

        Conventions:
        - Weeks start on Monday.
        - Months and quarters start on their first day.

        Args:
            window: Reporting window.
            base_date: Reference date; defaults to today.

        Returns:
            TimeRange for the window.

        Raises:
            ValueError: If the window type is not supported.
        """
        if base_date is None:
            base_date = date.today()

        if window == TimeWindow.THIS_WEEK:
            # Monday of the current week up to the reference date.
            days_since_monday = base_date.weekday()
            start = base_date - timedelta(days=days_since_monday)
            return TimeRange(start=start, end=base_date)
        elif window == TimeWindow.LAST_WEEK:
            days_since_monday = base_date.weekday()
            this_monday = base_date - timedelta(days=days_since_monday)
            end = this_monday - timedelta(days=1)  # last Sunday
            start = end - timedelta(days=6)        # last Monday
            return TimeRange(start=start, end=end)
        elif window == TimeWindow.THIS_MONTH:
            start = base_date.replace(day=1)
            return TimeRange(start=start, end=base_date)
        elif window == TimeWindow.LAST_MONTH:
            month_start = base_date.replace(day=1)
            end = month_start - timedelta(days=1)
            start = end.replace(day=1)
            return TimeRange(start=start, end=end)
        elif window == TimeWindow.LAST_3_MONTHS_EXCL_CURRENT:
            # From the first day three months back to the end of last month.
            current_month_start = self.get_month_first_day(base_date)
            end = current_month_start - timedelta(days=1)
            start = self.get_month_first_day(self._shift_months(current_month_start, -3))
            return TimeRange(start=start, end=end)
        elif window == TimeWindow.LAST_3_MONTHS_INCL_CURRENT:
            # From the first day two months back up to the reference date.
            current_month_start = self.get_month_first_day(base_date)
            start = self.get_month_first_day(self._shift_months(current_month_start, -2))
            return TimeRange(start=start, end=base_date)
        elif window == TimeWindow.THIS_QUARTER:
            quarter = (base_date.month - 1) // 3
            start_month = quarter * 3 + 1
            start = base_date.replace(month=start_month, day=1)
            return TimeRange(start=start, end=base_date)
        elif window == TimeWindow.LAST_QUARTER:
            quarter = (base_date.month - 1) // 3
            start_month = quarter * 3 + 1
            this_quarter_start = base_date.replace(month=start_month, day=1)
            end = this_quarter_start - timedelta(days=1)
            prev_quarter = (end.month - 1) // 3
            prev_start_month = prev_quarter * 3 + 1
            start = end.replace(month=prev_start_month, day=1)
            return TimeRange(start=start, end=end)
        elif window == TimeWindow.LAST_6_MONTHS:
            # Last half year, current month excluded.
            month_start = self.get_month_first_day(base_date)
            end = month_start - timedelta(days=1)
            start = self.get_month_first_day(self._shift_months(month_start, -6))
            return TimeRange(start=start, end=end)

        raise ValueError(f"不支持的时间窗口类型: {window}")

    def get_comparison_range(self, time_range: TimeRange) -> TimeRange:
        """
        Return the period-over-period comparison range.

        The comparison range is the immediately preceding range of equal length
        (in days, both ends inclusive).

        Args:
            time_range: Current range.

        Returns:
            The preceding range of equal length.
        """
        duration = (time_range.end - time_range.start).days + 1
        prev_end = time_range.start - timedelta(days=1)
        prev_start = prev_end - timedelta(days=duration - 1)
        return TimeRange(start=prev_start, end=prev_end)

    def get_month_first_day(self, dt: date) -> date:
        """Return the first day of dt's month."""
        return dt.replace(day=1)

    def get_month_last_day(self, dt: date) -> date:
        """Return the last day of dt's month."""
        last_day = calendar.monthrange(dt.year, dt.month)[1]
        return dt.replace(day=last_day)

    def _shift_months(self, base_date: date, months: int) -> date:
        """
        Shift a date by a number of months, clamping the day to the target month.

        For example, Mar 31 shifted by -1 month becomes the last day of February.
        """
        total_months = base_date.year * 12 + (base_date.month - 1) + months
        year = total_months // 12
        month = total_months % 12 + 1
        last_day = calendar.monthrange(year, month)[1]
        day = min(base_date.day, last_day)
        return date(year, month, day)

    def is_new_hire_in_month(self, hire_date: date, stat_month: date) -> bool:
        """
        Return whether an assistant counts as a new hire for the given month.

        Rule: anyone hired on or after 00:00 on the 1st of the month is a new hire.

        Args:
            hire_date: Hire date.
            stat_month: Any date in the statistics month (usually its first day).

        Returns:
            True if hire_date is on or after the first day of stat_month.
        """
        month_start = self.get_month_first_day(stat_month)
        return hire_date >= month_start

    # ==========================================================================
    # Configuration-table loading
    # ==========================================================================

    def load_config_cache(self, force_reload: bool = False) -> ConfigCache:
        """
        Return the configuration cache, reloading it when stale or forced.

        The cache is stored on the concrete task class (``type(self)``), so all
        instances of that class share it until ``_config_cache_ttl`` seconds
        have passed.

        Args:
            force_reload: Reload even if the cached snapshot is still fresh.

        Returns:
            The ConfigCache snapshot.
        """
        now = datetime.now(self.tz)
        cached = self._config_cache

        if (
            not force_reload
            and cached is not None
            and (now - cached.loaded_at).total_seconds() < self._config_cache_ttl
        ):
            return cached

        self.logger.debug("重新加载DWS配置表缓存")

        cached = ConfigCache(
            performance_tiers=self._load_performance_tiers(),
            level_prices=self._load_level_prices(),
            bonus_rules=self._load_bonus_rules(),
            area_categories=self._load_area_categories(),
            skill_types=self._load_skill_types(),
            loaded_at=now,
        )
        # Assigning via self would create an instance attribute and defeat the
        # class-level sharing declared above. Assign on the class instead.
        type(self)._config_cache = cached
        return cached

    def _load_performance_tiers(self) -> List[Dict[str, Any]]:
        """
        Load the performance-tier configuration.

        Field meanings (from the DWS requirements document):
        - base_deduction: per-hour amount the venue deducts from base courses
        - bonus_deduction_ratio: share the venue deducts from bonus-course income
        - vacation_days: leave days available next month
        - vacation_unlimited: unlimited leave flag (TRUE for the top tier)
        """
        sql = """
            SELECT
                tier_id, tier_code, tier_name, tier_level,
                min_hours, max_hours,
                base_deduction, bonus_deduction_ratio,
                vacation_days, vacation_unlimited,
                is_new_hire_tier, effective_from, effective_to
            FROM billiards_dws.cfg_performance_tier
            ORDER BY tier_level ASC, effective_from ASC
        """
        rows = self.db.query(sql)
        return [dict(row) for row in rows] if rows else []

    def _load_level_prices(self) -> List[Dict[str, Any]]:
        """Load assistant level prices, newest effective_from first per level."""
        sql = """
            SELECT
                price_id, level_code, level_name,
                base_course_price, bonus_course_price,
                effective_from, effective_to
            FROM billiards_dws.cfg_assistant_level_price
            ORDER BY level_code ASC, effective_from DESC
        """
        rows = self.db.query(sql)
        return [dict(row) for row in rows] if rows else []

    def _load_bonus_rules(self) -> List[Dict[str, Any]]:
        """Load bonus rules, ordered by type, then descending priority."""
        sql = """
            SELECT
                rule_id, rule_type, rule_code, rule_name,
                threshold_hours, rank_position, bonus_amount,
                is_cumulative, priority,
                effective_from, effective_to
            FROM billiards_dws.cfg_bonus_rules
            ORDER BY rule_type, priority DESC, effective_from DESC
        """
        rows = self.db.query(sql)
        return [dict(row) for row in rows] if rows else []

    def _load_area_categories(self) -> Dict[str, Dict[str, Any]]:
        """
        Load active area-category mappings keyed by source_area_name.

        Rows are read in ascending match_priority order, and dict insertion
        order preserves that for the LIKE scan in get_area_category.
        """
        sql = """
            SELECT
                source_area_name, category_code, category_name,
                match_type, match_priority
            FROM billiards_dws.cfg_area_category
            WHERE is_active = TRUE
            ORDER BY match_priority ASC
        """
        rows = self.db.query(sql)
        if not rows:
            return {}

        result = {}
        for row in rows:
            row_dict = dict(row)
            result[row_dict['source_area_name']] = row_dict
        return result

    def _load_skill_types(self) -> Dict[int, Dict[str, Any]]:
        """Load active skill-type mappings keyed by integer skill_id."""
        sql = """
            SELECT
                skill_id, skill_name,
                course_type_code, course_type_name
            FROM billiards_dws.cfg_skill_type
            WHERE is_active = TRUE
        """
        rows = self.db.query(sql)
        if not rows:
            return {}

        result = {}
        for row in rows:
            row_dict = dict(row)
            result[int(row_dict['skill_id'])] = row_dict
        return result

    # ==========================================================================
    # Applying configuration
    # ==========================================================================

    @staticmethod
    def _coerce_to_date(value: Any) -> Any:
        """Reduce a datetime to its date; return any other value unchanged."""
        # datetime is a subclass of date, and comparing the two raises TypeError.
        if isinstance(value, datetime):
            return value.date()
        return value

    def _filter_by_effective_date(
        self,
        items: List[Dict[str, Any]],
        effective_date: Optional[date]
    ) -> List[Dict[str, Any]]:
        """
        Keep configuration rows whose effective period covers a reference date.

        Both effective_from and effective_to are inclusive. A missing bound is
        treated as open. Comparison is done at day granularity: datetime values
        are reduced to dates, so mixing date and datetime does not raise.

        Args:
            items: Configuration rows with optional effective_from/effective_to.
            effective_date: Reference date; defaults to today.

        Returns:
            Matching rows, in their original order.
        """
        ref_date = self._coerce_to_date(effective_date or date.today())
        results: List[Dict[str, Any]] = []
        for item in items:
            eff_from = self._coerce_to_date(item.get('effective_from'))
            eff_to = self._coerce_to_date(item.get('effective_to'))
            if eff_from and ref_date < eff_from:
                continue
            if eff_to and ref_date > eff_to:
                continue
            results.append(item)
        return results

    def get_performance_tier(
        self,
        effective_hours: Decimal,
        is_new_hire: bool,
        effective_date: Optional[date] = None,
        max_tier_level: Optional[int] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Match a performance tier by effective hours.

        A tier matches when min_hours <= effective_hours < max_hours. A NULL
        max_hours means unbounded, and a NULL min_hours is treated as 0. Tiers
        flagged is_new_hire_tier are skipped.

        Args:
            effective_hours: Effective performance hours.
            is_new_hire: Accepted for caller compatibility. The new-hire cap is
                applied by the monthly task, not here.
            effective_date: Date used to select the config version (for
                historical months).
            max_tier_level: If given, only tiers at or below this level (or with
                no level) are considered.

        Returns:
            The first matching tier, or None if no tier matches.
        """
        _ = is_new_hire  # kept for interface compatibility
        config = self.load_config_cache()
        tiers = self._filter_by_effective_date(config.performance_tiers, effective_date)
        if max_tier_level is not None:
            tiers = [
                t for t in tiers
                if t.get('tier_level') is None or int(t.get('tier_level')) <= max_tier_level
            ]

        for tier in tiers:
            if tier.get('is_new_hire_tier'):
                continue
            # safe_decimal tolerates NULL values; Decimal(str(None)) would raise.
            min_hours = self.safe_decimal(tier.get('min_hours'))
            max_hours = tier.get('max_hours')
            if max_hours is not None:
                max_hours = self.safe_decimal(max_hours)
            if effective_hours >= min_hours:
                if max_hours is None or effective_hours < max_hours:
                    return tier

        return None

    def get_performance_tier_by_id(
        self,
        tier_id: Optional[int],
        effective_date: Optional[date] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Look up a tier by tier_id among the rows effective on effective_date.

        Returns None for a falsy tier_id, or when no effective row matches.
        """
        if not tier_id:
            return None
        config = self.load_config_cache()
        tiers = self._filter_by_effective_date(config.performance_tiers, effective_date)
        for tier in tiers:
            if tier.get('tier_id') == tier_id:
                return tier
        return None

    def get_level_price(
        self,
        level_code: int,
        effective_date: Optional[date] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Return the price configuration for an assistant level (SCD2-style, by effective date).

        Rows are loaded newest-first per level, so the first effective match is
        the latest applicable version.

        Args:
            level_code: Level code.
            effective_date: Reference date.

        Returns:
            The price row, or None if nothing matches.
        """
        config = self.load_config_cache()
        prices = self._filter_by_effective_date(config.level_prices, effective_date)
        for price in prices:
            if price.get('level_code') == level_code:
                return price
        return None

    def get_course_type(self, skill_id: int) -> CourseType:
        """
        Map a skill_id to its course type.

        Unknown skills, and codes other than BONUS/ROOM, map to CourseType.BASE.
        """
        config = self.load_config_cache()
        skill_config = config.skill_types.get(skill_id)
        if skill_config:
            code = skill_config.get('course_type_code', 'BASE')
            if code == 'BONUS':
                return CourseType.BONUS
            if code == 'ROOM':
                return CourseType.ROOM
            return CourseType.BASE
        # Default: base course
        return CourseType.BASE

    def get_area_category(self, area_name: Optional[str]) -> Dict[str, str]:
        """
        Classify an area name.

        Matching order:
        1. An EXACT entry for the name.
        2. LIKE entries in priority order: '%' is stripped from the key and the
           remainder is matched as a substring.
        3. The 'DEFAULT' entry, if configured.
        4. Otherwise OTHER / '其他区域'.

        Args:
            area_name: Raw area name.

        Returns:
            Dict with 'category_code' and 'category_name'.
        """
        config = self.load_config_cache()

        if not area_name:
            # No area name: return the built-in fallback.
            return {'category_code': 'OTHER', 'category_name': '其他区域'}

        # 1. Exact match
        if area_name in config.area_categories:
            cat = config.area_categories[area_name]
            if cat.get('match_type') == 'EXACT':
                return {
                    'category_code': cat['category_code'],
                    'category_name': cat['category_name']
                }

        # 2. Fuzzy match, in priority order
        for key, cat in config.area_categories.items():
            if cat.get('match_type') == 'LIKE':
                pattern = key.replace('%', '')
                if pattern and pattern in area_name:
                    return {
                        'category_code': cat['category_code'],
                        'category_name': cat['category_name']
                    }

        # 3. Configured fallback
        if 'DEFAULT' in config.area_categories:
            cat = config.area_categories['DEFAULT']
            return {
                'category_code': cat['category_code'],
                'category_name': cat['category_name']
            }

        return {'category_code': 'OTHER', 'category_name': '其他区域'}

    def calculate_sprint_bonus(
        self,
        effective_hours: Decimal,
        effective_date: Optional[date] = None
    ) -> Decimal:
        """
        Compute the sprint bonus: non-cumulative, highest qualifying rule wins.

        SPRINT rules from cfg_bonus_rules (possibly a historical version) are
        scanned in descending priority order. The first rule whose
        threshold_hours is reached pays its bonus_amount. NULL numeric values
        are treated as 0.

        Args:
            effective_hours: Effective performance hours.
            effective_date: Date used to select the rule version.

        Returns:
            Bonus amount, or Decimal('0') if no rule qualifies.
        """
        config = self.load_config_cache()
        bonus_rules = self._filter_by_effective_date(config.bonus_rules, effective_date)

        # `or 0` also covers a NULL priority, which would break sorting
        # (None cannot be compared with int).
        sprint_rules = sorted(
            (r for r in bonus_rules if r.get('rule_type') == 'SPRINT'),
            key=lambda r: r.get('priority') or 0,
            reverse=True,
        )

        for rule in sprint_rules:
            threshold = self.safe_decimal(rule.get('threshold_hours'))
            if effective_hours >= threshold:
                return self.safe_decimal(rule.get('bonus_amount'))

        return Decimal('0')

    def calculate_top_rank_bonus(
        self,
        rank: int,
        effective_date: Optional[date] = None
    ) -> Decimal:
        """
        Compute the Top-3 ranking bonus.

        Amounts come from TOP_RANK rules in cfg_bonus_rules, matched on
        rank_position. The documented scheme is 1st: 1000, 2nd: 600, 3rd: 400,
        and tied ranks all qualify. Ranks above 3 never pay.

        Args:
            rank: Competition rank with ties already applied.
            effective_date: Date used to select the rule version.

        Returns:
            Bonus amount, or Decimal('0') if none applies.
        """
        if rank > 3:
            # Checked first so no config load is needed for non-qualifying ranks.
            return Decimal('0')

        config = self.load_config_cache()
        bonus_rules = self._filter_by_effective_date(config.bonus_rules, effective_date)

        for rule in bonus_rules:
            if rule.get('rule_type') == 'TOP_RANK' and rule.get('rank_position') == rank:
                return self.safe_decimal(rule.get('bonus_amount'))

        return Decimal('0')

    # ==========================================================================
    # DWD reads
    # ==========================================================================

    def iter_dwd_rows(
        self,
        table_name: str,
        columns: List[str],
        start_date: date,
        end_date: date,
        date_col: str = "created_at",
        where_clause: str = "",
        order_by: str = "",
        batch_size: int = 1000
    ) -> Iterator[List[Dict[str, Any]]]:
        """
        Iterate over a DWD table in batches, filtered by a date range.

        Paging uses LIMIT/OFFSET with one query per batch. Pass an ``order_by``
        that is unique (e.g. ending in the table's key). The default orders
        only by ``date_col``, and with ties rows may be skipped or repeated
        across batches.

        ``table_name``, ``columns``, ``date_col``, ``where_clause`` and
        ``order_by`` are interpolated into the SQL text and must come from
        trusted code, not user input.

        Args:
            table_name: DWD table name, without schema.
            columns: Columns to select.
            start_date: Start date (inclusive).
            end_date: End date (inclusive).
            date_col: Column used for the date filter.
            where_clause: Extra conditions, without the WHERE keyword.
            order_by: Ordering, without the ORDER BY keyword.
            batch_size: Rows per batch.

        Yields:
            Lists of row dicts, one list per batch.
        """
        offset = 0
        cols_str = ", ".join(columns)

        where_parts = [f"DATE({date_col}) >= %s", f"DATE({date_col}) <= %s"]
        params: List[Any] = [start_date, end_date]
        if where_clause:
            where_parts.append(f"({where_clause})")
        where_str = " AND ".join(where_parts)

        order_str = f"ORDER BY {order_by}" if order_by else f"ORDER BY {date_col} ASC"

        while True:
            sql = f"""
                SELECT {cols_str}
                FROM {self.DWD_SCHEMA}.{table_name}
                WHERE {where_str}
                {order_str}
                LIMIT %s OFFSET %s
            """
            rows = self.db.query(sql, (*params, batch_size, offset))
            if not rows:
                break
            yield [dict(row) for row in rows]
            if len(rows) < batch_size:
                # A short batch means the result set is exhausted.
                break
            offset += batch_size

    def query_dwd(
        self,
        sql: str,
        params: Optional[Tuple] = None
    ) -> List[Dict[str, Any]]:
        """
        Run an arbitrary query and return the rows as dicts.

        Args:
            sql: SQL statement.
            params: Bind parameters.

        Returns:
            List of row dicts (empty if there are no rows).
        """
        rows = self.db.query(sql, params)
        return [dict(row) for row in rows] if rows else []

    # ==========================================================================
    # SCD2 dimension lookups
    # ==========================================================================

    def get_assistant_level_asof(
        self,
        assistant_id: int,
        asof_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Return an assistant's level as of a date (SCD2 as-of lookup).

        The level is an SCD2 dimension, so historical months must not use the
        current level. The version whose [scd2_start_time, scd2_end_time)
        covers asof_date is selected.

        Args:
            assistant_id: Assistant ID.
            asof_date: Lookup date.

        Returns:
            Dict with level_code, level_name and the SCD2 bounds, or None.
        """
        sql = """
            SELECT
                assistant_id,
                nickname,
                level AS level_code,
                CASE level
                    WHEN 8 THEN '助教管理'
                    WHEN 10 THEN '初级'
                    WHEN 20 THEN '中级'
                    WHEN 30 THEN '高级'
                    WHEN 40 THEN '星级'
                    ELSE '未知'
                END AS level_name,
                scd2_start_time,
                scd2_end_time
            FROM billiards_dwd.dim_assistant
            WHERE assistant_id = %s
              AND scd2_start_time <= %s
              AND (scd2_end_time IS NULL OR scd2_end_time > %s)
            ORDER BY scd2_start_time DESC
            LIMIT 1
        """
        rows = self.db.query(sql, (assistant_id, asof_date, asof_date))
        return dict(rows[0]) if rows else None

    def get_member_card_balance_asof(
        self,
        member_id: int,
        asof_date: date
    ) -> Dict[str, Decimal]:
        """
        Return a member's card balances as of a date (SCD2 as-of lookup).

        Balances of the cash card type are summed into cash_balance, and
        balances of the listed gift card types into gift_balance. Other card
        types are ignored. NULL balances count as 0.

        Args:
            member_id: Member ID (tenant_member_id).
            asof_date: Lookup date.

        Returns:
            Dict with cash_balance, gift_balance and total_balance.
        """
        sql = """
            SELECT
                card_type_id,
                balance
            FROM billiards_dwd.dim_member_card_account
            WHERE tenant_member_id = %s
              AND scd2_start_time <= %s
              AND (scd2_end_time IS NULL OR scd2_end_time > %s)
              AND COALESCE(is_delete, 0) = 0
        """
        rows = self.db.query(sql, (member_id, asof_date, asof_date))

        # Card type IDs
        CASH_CARD_TYPE_ID = 2793249295533893
        GIFT_CARD_TYPE_IDS = [
            2791990152417157,  # table-fee card
            2793266846533445,  # promotional voucher
            2794699703437125,  # drinks card
        ]

        cash_balance = Decimal('0')
        gift_balance = Decimal('0')

        for row in (rows or []):
            row_dict = dict(row)
            card_type_id = row_dict.get('card_type_id')
            # safe_decimal: a NULL balance would make Decimal(str(None)) raise.
            balance = self.safe_decimal(row_dict.get('balance'))
            if card_type_id == CASH_CARD_TYPE_ID:
                cash_balance += balance
            elif card_type_id in GIFT_CARD_TYPE_IDS:
                gift_balance += balance

        return {
            'cash_balance': cash_balance,
            'gift_balance': gift_balance,
            'total_balance': cash_balance + gift_balance
        }

    # ==========================================================================
    # Idempotent writes
    # ==========================================================================

    def delete_existing_data(
        self,
        context: TaskContext,
        date_col: str = "stat_date",
        extra_conditions: Optional[Dict[str, Any]] = None
    ) -> int:
        """
        Delete the target-table rows for the task window (delete-before-insert).

        Rows are deleted where site_id = context.store_id and date_col lies in
        [window_start, window_end], both reduced to dates. Each extra condition
        adds ``col = value``. No commit is issued here.

        Args:
            context: Task context.
            date_col: Date column to filter on.
            extra_conditions: Extra equality conditions (column -> value).

        Returns:
            Number of rows deleted.
        """
        target_table = self.get_target_table()
        full_table = f"{self.DWS_SCHEMA}.{target_table}"

        where_parts = ["site_id = %s"]
        params: List[Any] = [context.store_id]

        # Window bounds may be datetimes or dates.
        start_date = context.window_start.date() if hasattr(context.window_start, 'date') else context.window_start
        end_date = context.window_end.date() if hasattr(context.window_end, 'date') else context.window_end

        where_parts.append(f"{date_col} >= %s")
        params.append(start_date)
        where_parts.append(f"{date_col} <= %s")
        params.append(end_date)

        if extra_conditions:
            for col, val in extra_conditions.items():
                where_parts.append(f"{col} = %s")
                params.append(val)

        where_str = " AND ".join(where_parts)
        sql = f"DELETE FROM {full_table} WHERE {where_str}"

        with self.db.conn.cursor() as cur:
            cur.execute(sql, params)
            deleted = cur.rowcount

        self.logger.debug(
            "%s: 删除已存在数据 %d 行,条件: %s",
            self.get_task_code(), deleted, where_str
        )

        return deleted

    def bulk_insert(
        self,
        rows: List[Dict[str, Any]],
        columns: Optional[List[str]] = None
    ) -> int:
        """
        Insert rows into the target table, one INSERT per row.

        Args:
            rows: Row dicts. Missing columns are inserted as NULL.
            columns: Columns to insert; defaults to the keys of the first row.

        Returns:
            Number of rows inserted (the sum of cursor rowcounts).
        """
        if not rows:
            return 0

        target_table = self.get_target_table()
        full_table = f"{self.DWS_SCHEMA}.{target_table}"

        if columns is None:
            columns = list(rows[0].keys())

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {full_table} ({cols_str}) VALUES ({placeholders})"

        inserted = 0
        with self.db.conn.cursor() as cur:
            for row in rows:
                values = [row.get(col) for col in columns]
                cur.execute(sql, values)
                inserted += cur.rowcount

        return inserted

    def upsert(
        self,
        rows: List[Dict[str, Any]],
        columns: Optional[List[str]] = None,
        update_columns: Optional[List[str]] = None
    ) -> Tuple[int, int]:
        """
        Upsert rows into the target table (INSERT ... ON CONFLICT DO UPDATE).

        The conflict target is get_primary_keys(). updated_at is always set to
        NOW() on conflict. It is therefore never also copied from EXCLUDED,
        which PostgreSQL would reject as a second assignment to the same column.

        Args:
            rows: Row dicts.
            columns: All columns to insert; defaults to the keys of the first row.
            update_columns: Columns to overwrite on conflict. By default, every
                column except the primary keys and created_at.

        Returns:
            (inserted, updated). Inserts and updates are not distinguished:
            every affected row is counted as inserted, and updated is always 0.
        """
        if not rows:
            return 0, 0

        target_table = self.get_target_table()
        full_table = f"{self.DWS_SCHEMA}.{target_table}"
        primary_keys = self.get_primary_keys()

        if columns is None:
            columns = list(rows[0].keys())

        if update_columns is None:
            update_columns = [c for c in columns if c not in primary_keys and c not in ('created_at',)]

        cols_str = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        conflict_cols = ", ".join(primary_keys)

        update_parts = [
            f"{col} = EXCLUDED.{col}"
            for col in update_columns
            if col != 'updated_at'  # set to NOW() below
        ]
        update_parts.append("updated_at = NOW()")
        update_str = ", ".join(update_parts)

        sql = f"""
            INSERT INTO {full_table} ({cols_str})
            VALUES ({placeholders})
            ON CONFLICT ({conflict_cols})
            DO UPDATE SET {update_str}
        """

        inserted = 0
        updated = 0

        with self.db.conn.cursor() as cur:
            for row in rows:
                values = [row.get(col) for col in columns]
                cur.execute(sql, values)
                # rowcount is 1 for either an insert or an update; the two are
                # not told apart, so every hit is counted as inserted.
                if cur.rowcount > 0:
                    inserted += 1

        return inserted, updated

    # ==========================================================================
    # Rolling-window statistics
    # ==========================================================================

    def calculate_rolling_stats(
        self,
        base_date: date,
        entity_id: int,
        entity_type: str,
        stat_sql: str,
        windows: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """
        Run a statistics query once per rolling window ending at base_date.

        ``stat_sql`` is filled in with ``str.format`` and is not parameterized.
        The available fields are {window_days}, {entity_id}, {start_date} and
        {end_date}. Dates render as YYYY-MM-DD, so the template must quote
        them. The template must come from trusted code.

        Args:
            base_date: Last day of every window (inclusive).
            entity_id: Entity ID (e.g. assistant_id or member_id).
            entity_type: Entity type label (currently unused).
            stat_sql: SQL template.
            windows: Window sizes in days; defaults to ROLLING_WINDOWS.

        Returns:
            Columns of each window's first result row, keyed as
            ``{column}_{days}d``.
        """
        if windows is None:
            windows = self.ROLLING_WINDOWS

        results = {}
        for days in windows:
            # A window of N days includes base_date itself.
            start_date = base_date - timedelta(days=days - 1)
            sql = stat_sql.format(
                window_days=days,
                entity_id=entity_id,
                start_date=start_date,
                end_date=base_date
            )
            rows = self.db.query(sql)
            if rows:
                for key, value in dict(rows[0]).items():
                    results[f"{key}_{days}d"] = value

        return results

    # ==========================================================================
    # Ranking
    # ==========================================================================

    def calculate_rank_with_ties(
        self,
        values: List[Tuple[int, Decimal]]
    ) -> List[Tuple[int, int, int]]:
        """
        Rank entities by score (descending), handling ties.

        Top-3 ranking is by total performance hours, and ties all count. For
        example, two entities tied for first are both rank 1, and the next
        entity is rank 3.

        Args:
            values: (entity_id, score) tuples.

        Returns:
            (entity_id, rank, dense_rank) tuples in descending score order.
            rank: competition rank (two firsts, then 3).
            dense_rank: dense rank (two firsts, then 2).
        """
        if not values:
            return []

        ordered = sorted(values, key=lambda item: item[1], reverse=True)

        results = []
        prev_score = None
        competition_rank = 0
        dense_rank = 0
        for position, (entity_id, score) in enumerate(ordered, start=1):
            if position == 1 or score != prev_score:
                # A new score starts a new rank group.
                competition_rank = position
                dense_rank += 1
                prev_score = score
            results.append((entity_id, competition_rank, dense_rank))

        return results

    # ==========================================================================
    # Walk-in guest filtering
    # ==========================================================================

    def is_guest(self, member_id: Optional[int]) -> bool:
        """
        Return whether a member_id denotes a walk-in guest.

        Guests have member_id 0 (or no member_id) and are excluded from the
        customer-dimension statistics.
        """
        return member_id is None or member_id == 0

    # ==========================================================================
    # Utilities
    # ==========================================================================

    def safe_decimal(self, value: Any, default: Decimal = Decimal('0')) -> Decimal:
        """Convert a value to Decimal, returning default for None or unparseable input."""
        from decimal import InvalidOperation  # raised by Decimal() for bad strings

        if value is None:
            return default
        try:
            return Decimal(str(value))
        except (ValueError, TypeError, InvalidOperation):
            return default

    def safe_int(self, value: Any, default: int = 0) -> int:
        """Convert a value to int, returning default for None or unconvertible input."""
        if value is None:
            return default
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return default

    def seconds_to_hours(self, seconds: int) -> Decimal:
        """Convert seconds to hours as a Decimal."""
        return Decimal(str(seconds)) / Decimal('3600')

    def hours_to_seconds(self, hours: Decimal) -> int:
        """Convert hours to whole seconds, truncating toward zero."""
        return int(hours * Decimal('3600'))