Files
feiqiu-ETL/etl_billiards/tests/test_dws_tasks.py
2026-02-04 21:39:01 +08:00

473 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
DWS任务单元测试
测试内容:
- BaseDwsTask基类方法
- 时间计算方法
- 配置应用方法
- 排名计算方法
"""
import pytest
from datetime import date, datetime, timedelta
from decimal import Decimal
from unittest.mock import MagicMock, patch
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from etl_billiards.tasks.dws.base_dws_task import (
BaseDwsTask,
TimeLayer,
TimeWindow,
CourseType,
TimeRange,
ConfigCache
)
from etl_billiards.tasks.dws.finance_daily_task import FinanceDailyTask
from etl_billiards.tasks.dws.assistant_monthly_task import AssistantMonthlyTask
class TestTimeLayerRange:
    """Rolling time-layer ranges, each ending on (and including) the base date."""

    _BASE_DATE = date(2026, 2, 1)

    def _layer(self, layer):
        """Resolve *layer* against ``_BASE_DATE`` using a fresh stub task."""
        return create_mock_task().get_time_layer_range(layer, self._BASE_DATE)

    def test_last_2_days(self):
        """LAST_2_DAYS covers the base date and the day before it."""
        result = self._layer(TimeLayer.LAST_2_DAYS)
        assert (result.start, result.end) == (date(2026, 1, 31), self._BASE_DATE)

    def test_last_1_month(self):
        """LAST_1_MONTH starts 30 days before the base date."""
        result = self._layer(TimeLayer.LAST_1_MONTH)
        assert (result.start, result.end) == (date(2026, 1, 2), self._BASE_DATE)

    def test_last_3_months(self):
        """LAST_3_MONTHS starts 90 days before the base date."""
        result = self._layer(TimeLayer.LAST_3_MONTHS)
        assert (result.start, result.end) == (date(2025, 11, 3), self._BASE_DATE)
class TestTimeWindowRange:
    """Calendar-aligned time windows (weeks start on Monday)."""

    @staticmethod
    def _window(window, base_date):
        """Resolve *window* relative to *base_date* using a fresh stub task."""
        return create_mock_task().get_time_window_range(window, base_date)

    def test_this_week_monday_start(self):
        """THIS_WEEK runs from the current week's Monday up to the base date."""
        # 2026-02-01 is a Sunday; that week began on Monday 2026-01-26.
        result = self._window(TimeWindow.THIS_WEEK, date(2026, 2, 1))
        assert (result.start, result.end) == (date(2026, 1, 26), date(2026, 2, 1))

    def test_last_week(self):
        """LAST_WEEK is the full Monday..Sunday week before the current one."""
        # Previous Monday is 2026-01-19, previous Sunday is 2026-01-25.
        result = self._window(TimeWindow.LAST_WEEK, date(2026, 2, 1))
        assert (result.start, result.end) == (date(2026, 1, 19), date(2026, 1, 25))

    def test_this_month(self):
        """THIS_MONTH runs from the 1st of the month up to the base date."""
        result = self._window(TimeWindow.THIS_MONTH, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2026, 2, 1), date(2026, 2, 15))

    def test_last_month(self):
        """LAST_MONTH is the whole previous calendar month."""
        result = self._window(TimeWindow.LAST_MONTH, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2026, 1, 1), date(2026, 1, 31))

    def test_last_3_months_excl_current(self):
        """Three whole calendar months, excluding the current month."""
        result = self._window(TimeWindow.LAST_3_MONTHS_EXCL_CURRENT, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2025, 11, 1), date(2026, 1, 31))

    def test_last_3_months_incl_current(self):
        """Three calendar months including the current (partial) month."""
        result = self._window(TimeWindow.LAST_3_MONTHS_INCL_CURRENT, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2025, 12, 1), date(2026, 2, 15))

    def test_this_quarter(self):
        """THIS_QUARTER runs from the quarter's first day up to the base date."""
        result = self._window(TimeWindow.THIS_QUARTER, date(2026, 2, 15))
        assert (result.start, result.end) == (date(2026, 1, 1), date(2026, 2, 15))

    def test_last_6_months(self):
        """LAST_6_MONTHS: six whole months ending with last month (current excluded)."""
        result = self._window(TimeWindow.LAST_6_MONTHS, date(2026, 2, 15))
        # Ends on the last day of the previous month and reaches back six months.
        assert (result.end, result.start) == (date(2026, 1, 31), date(2025, 8, 1))
class TestComparisonRange:
    """Period-over-period (comparison) ranges."""

    def test_comparison_7_days(self):
        """A 7-day range is compared against the 7 days immediately before it."""
        task = create_mock_task()
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 2, 7))
        result = task.get_comparison_range(current)
        # Previous 7 days: 2026-01-25 .. 2026-01-31
        assert result.start == date(2026, 1, 25)
        assert result.end == date(2026, 1, 31)

    def test_comparison_30_days(self):
        """A 30-day range is compared against the 30 days immediately before it."""
        task = create_mock_task()
        # 2026 is not a leap year, so Feb 1 .. Mar 2 is exactly 30 days inclusive.
        current = TimeRange(start=date(2026, 2, 1), end=date(2026, 3, 2))
        result = task.get_comparison_range(current)
        # Same length as the current range ...
        assert (result.end - result.start).days == (current.end - current.start).days
        # ... and directly adjacent to it (no gap, no overlap), mirroring the
        # 7-day case: 2026-01-02 .. 2026-01-31. A length-only check would accept
        # any equally long range, including one that overlaps `current`.
        assert result.end == current.start - timedelta(days=1)
        assert result.start == date(2026, 1, 2)
class TestFinanceDailyRecord:
    """Assembly of the daily finance record from its aggregated inputs."""

    def test_groupbuy_and_cashflow(self):
        """Group-buy discount split and cash-flow totals for a single order."""
        finance_task = create_finance_daily_task()
        day = date(2026, 2, 1)
        zero = Decimal('0')
        settlement = {
            'gross_amount': Decimal('1000'),
            'table_fee_amount': Decimal('1000'),
            'goods_amount': zero,
            'assistant_pd_amount': zero,
            'assistant_cx_amount': zero,
            'cash_pay_amount': Decimal('300'),
            'card_pay_amount': zero,
            'balance_pay_amount': zero,
            'gift_card_pay_amount': zero,
            'coupon_amount': Decimal('200'),
            'pl_coupon_sale_amount': zero,
            'adjust_amount': Decimal('50'),
            'member_discount_amount': Decimal('10'),
            'rounding_amount': zero,
            'order_count': 1,
            'member_order_count': 1,
            'guest_order_count': 0,
        }

        record = finance_task._build_daily_record(
            day,
            settlement,
            {'groupbuy_pay_total': Decimal('80')},  # group-buy
            {'recharge_cash': Decimal('20')},  # recharge
            {'expense_amount': Decimal('40')},  # expense
            {  # platform
                'settlement_amount': Decimal('60'),
                'commission_amount': Decimal('5'),
                'service_fee': Decimal('5'),
            },
            {'big_customer_amount': Decimal('20')},  # big customer
            # NOTE(review): trailing positional arg; presumably a site/tenant id — confirm.
            1,
        )

        expected = {
            'discount_groupbuy': Decimal('120'),
            'discount_other': Decimal('30'),
            'platform_settlement_amount': Decimal('60'),
            'platform_fee_amount': Decimal('10'),
            'cash_inflow_total': Decimal('380'),
            'cash_outflow_total': Decimal('50'),
            # Equals inflow minus outflow for these inputs (380 - 50).
            'cash_balance_change': Decimal('330'),
        }
        for field, amount in expected.items():
            assert record[field] == amount, field
class TestNewHireTier:
    """Performance-tier assignment rules for newly hired assistants."""

    def test_new_hire_tier_hours(self):
        """New-hire hours are scaled to a month as daily average * 30."""
        monthly_task = create_assistant_monthly_task()
        scaled = monthly_task._calc_new_hire_tier_hours(Decimal('15'), 5)
        # 15 h over 5 work days -> 3 h/day -> 90 h for a 30-day month
        assert scaled == Decimal('90')

    def test_max_tier_level_cap(self):
        """max_tier_level caps the tier a new hire can be placed in."""
        task = create_mock_task()
        valid_from, valid_to = date(2020, 1, 1), date(2099, 1, 1)
        # (tier level/id, min_hours, max_hours); max None = open-ended top band
        hour_bands = [(1, 0, 100), (2, 100, 200), (3, 200, 300), (4, 300, None)]
        tiers = [
            {
                'tier_id': level,
                'tier_level': level,
                'min_hours': low,
                'max_hours': high,
                'is_new_hire_tier': False,
                'effective_from': valid_from,
                'effective_to': valid_to,
            }
            for level, low, high in hour_bands
        ]
        task._config_cache = ConfigCache(
            performance_tiers=tiers,
            level_prices=[],
            bonus_rules=[],
            area_categories={},
            skill_types={},
            loaded_at=datetime.now(),
        )

        # 350 h falls in the open-ended tier-4 band, but the cap limits it to 3.
        tier = task.get_performance_tier(
            Decimal('350'),
            is_new_hire=True,
            effective_date=date(2026, 2, 1),
            max_tier_level=3,
        )
        assert tier['tier_level'] == 3
class TestNewHireCheck:
    """is_new_hire_in_month: whether a hire date falls in the statistics month."""

    # `is True` / `is False` rather than `== True` / `== False` (PEP 8, E712);
    # this also pins that the predicate returns a real bool, not e.g. None.

    def test_new_hire_in_month(self):
        """Hired mid-month -> counted as a new hire for that month."""
        task = create_mock_task()
        hire_date = date(2026, 2, 5)
        stat_month = date(2026, 2, 1)
        assert task.is_new_hire_in_month(hire_date, stat_month) is True

    def test_not_new_hire(self):
        """Hired in an earlier month -> not a new hire."""
        task = create_mock_task()
        hire_date = date(2026, 1, 15)
        stat_month = date(2026, 2, 1)
        assert task.is_new_hire_in_month(hire_date, stat_month) is False

    def test_hire_on_first_day(self):
        """Hired on the 1st of the month -> counted as a new hire."""
        task = create_mock_task()
        hire_date = date(2026, 2, 1)
        stat_month = date(2026, 2, 1)
        assert task.is_new_hire_in_month(hire_date, stat_month) is True
class TestRankWithTies:
    """calculate_rank_with_ties: tied values share a rank and the next rank is skipped."""

    def test_no_ties(self):
        """Distinct values receive consecutive ranks."""
        task = create_mock_task()
        values = [
            (1, Decimal('100')),
            (2, Decimal('90')),
            (3, Decimal('80')),
        ]
        result = task.calculate_rank_with_ties(values)
        # Pin the length so extra or missing rows cannot slip past the indexed checks.
        assert len(result) == len(values)
        assert result[0] == (1, 1, 1)  # 1st
        assert result[1] == (2, 2, 2)  # 2nd
        assert result[2] == (3, 3, 3)  # 3rd

    def test_with_ties(self):
        """Two values tied for 1st; the next value is ranked 3rd (2 is skipped)."""
        task = create_mock_task()
        values = [
            (1, Decimal('100')),
            (2, Decimal('100')),  # tied for 1st
            (3, Decimal('80')),
        ]
        result = task.calculate_rank_with_ties(values)
        assert len(result) == len(values)
        assert result[0][1] == 1  # 1st
        assert result[1][1] == 1  # tied 1st
        assert result[2][1] == 3  # 3rd, rank 2 skipped

    def test_all_ties(self):
        """All values equal -> every entry is ranked 1st."""
        task = create_mock_task()
        values = [
            (1, Decimal('100')),
            (2, Decimal('100')),
            (3, Decimal('100')),
        ]
        result = task.calculate_rank_with_ties(values)
        # Without this, an empty result would pass: all() of nothing is True.
        assert len(result) == len(values)
        assert all(r[1] == 1 for r in result)
class TestGuestCheck:
    """is_guest: walk-in (non-member) customer detection by member_id."""

    # `is True` / `is False` rather than `== True` / `== False` (PEP 8, E712);
    # this also pins that the predicate returns a real bool.

    def test_guest_zero(self):
        """member_id == 0 is treated as a walk-in guest."""
        task = create_mock_task()
        assert task.is_guest(0) is True

    def test_guest_none(self):
        """member_id None is treated as a walk-in guest."""
        task = create_mock_task()
        assert task.is_guest(None) is True

    def test_not_guest(self):
        """A regular member id is not a guest."""
        task = create_mock_task()
        assert task.is_guest(12345) is False
class TestUtilityMethods:
    """Value-conversion helpers on BaseDwsTask."""

    def test_safe_decimal(self):
        """safe_decimal parses numbers/strings and falls back to Decimal 0."""
        task = create_mock_task()
        cases = (
            (100, Decimal('100')),
            ('123.45', Decimal('123.45')),
            (None, Decimal('0')),
            ('invalid', Decimal('0')),
        )
        for raw, expected in cases:
            assert task.safe_decimal(raw) == expected

    def test_safe_int(self):
        """safe_int parses numbers/strings and falls back to 0."""
        task = create_mock_task()
        cases = (
            (100, 100),
            ('123', 123),
            (None, 0),
            ('invalid', 0),
        )
        for raw, expected in cases:
            assert task.safe_int(raw) == expected

    def test_seconds_to_hours(self):
        """seconds_to_hours divides by 3600 and yields Decimal hours."""
        task = create_mock_task()
        cases = (
            (3600, Decimal('1')),
            (5400, Decimal('1.5')),
            (0, Decimal('0')),
        )
        for seconds, hours in cases:
            assert task.seconds_to_hours(seconds) == hours

    def test_hours_to_seconds(self):
        """hours_to_seconds multiplies Decimal hours by 3600."""
        task = create_mock_task()
        cases = (
            (Decimal('1'), 3600),
            (Decimal('1.5'), 5400),
        )
        for hours, seconds in cases:
            assert task.hours_to_seconds(hours) == seconds
class TestCourseType:
    """String values of the CourseType enum."""

    def test_base_course(self):
        """The base course type's value is 'BASE'."""
        member = CourseType.BASE
        assert member.value == 'BASE'

    def test_bonus_course(self):
        """The bonus (add-on) course type's value is 'BONUS'."""
        member = CourseType.BONUS
        assert member.value == 'BONUS'
# =============================================================================
# 辅助函数
# =============================================================================
def create_mock_task():
    """
    Build a minimal concrete BaseDwsTask whose dependencies are all MagicMocks.

    The abstract hooks are stubbed with fixed return values, so only the
    base-class helpers exercised by the tests carry real logic. The mocked
    ``config.get`` always returns None, whatever key or default is passed.
    """

    # Defined locally: a module-level class named Test* would be picked up by
    # pytest collection.
    class _StubDwsTask(BaseDwsTask):
        def get_task_code(self):
            return "TEST_DWS_TASK"

        def get_target_table(self):
            return "test_table"

        def get_primary_keys(self):
            return ["id"]

        def extract(self, context):
            return {}

        def load(self, transformed, context):
            return {}

    config, db, api, logger = (MagicMock() for _ in range(4))
    config.get.return_value = None
    return _StubDwsTask(config, db, api, logger)
def create_finance_daily_task():
    """
    Build a FinanceDailyTask wired to MagicMock dependencies.

    The mocked ``config.get("app.tenant_id")`` yields 1. Every other key
    returns the caller-supplied default (None if omitted).
    """

    def fake_get(key, default=None):
        if key == "app.tenant_id":
            return 1
        return default

    config = MagicMock()
    config.get.side_effect = fake_get
    return FinanceDailyTask(config, MagicMock(), MagicMock(), MagicMock())
def create_assistant_monthly_task():
    """
    Build an AssistantMonthlyTask wired to MagicMock dependencies.

    The mocked ``config.get`` returns the caller-supplied default for every
    key (None if omitted).
    """

    def fake_get(key, default=None):
        return default

    config = MagicMock()
    config.get.side_effect = fake_get
    return AssistantMonthlyTask(config, MagicMock(), MagicMock(), MagicMock())
if __name__ == "__main__":
    # Propagate pytest's exit status so shells/CI see failures when the file
    # is run directly. Previously the return value was dropped and the exit
    # code was always 0.
    sys.exit(pytest.main([__file__, "-v"]))