在前后端开发联调前的提交 20260223
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""BaseTask._accumulate_counts() 防御层单元测试。
|
||||
|
||||
验证需求 1.2:list 类型值转为 len() 后累加。
|
||||
纯单元测试,不涉及数据库或外部依赖。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from tasks.base_task import BaseTask
|
||||
|
||||
|
||||
class TestAccumulateCountsListDefense:
    """Defensive handling of list-typed values in ``_accumulate_counts``.

    Covers requirement 1.2: a list value is converted via ``len()`` before
    being accumulated.  Pure unit tests — no database or external deps.
    """

    def test_list_value_converted_to_len(self):
        """A list value is accumulated as its length."""
        acc = {}
        outcome = BaseTask._accumulate_counts(
            acc, {"errors": [{"table": "dim_a", "error": "boom"}]}
        )
        assert outcome["errors"] == 1

    def test_list_value_accumulates_across_calls(self):
        """Lengths of list values add up over repeated calls."""
        acc = {}
        for batch in ([{"t": "a"}], [{"t": "b"}, {"t": "c"}]):
            BaseTask._accumulate_counts(acc, {"errors": batch})
        assert acc["errors"] == 3

    def test_list_and_int_mixed_accumulation(self):
        """An int increment followed by a list increment yields the same total."""
        acc = {}
        BaseTask._accumulate_counts(acc, {"errors": 2})
        BaseTask._accumulate_counts(acc, {"errors": [{"t": "x"}]})
        assert acc["errors"] == 3

    def test_empty_list_adds_zero(self):
        """An empty list contributes zero to the running total."""
        acc = {"errors": 5}
        BaseTask._accumulate_counts(acc, {"errors": []})
        assert acc["errors"] == 5

    def test_int_float_still_work(self):
        """Plain int/float accumulation is unaffected by the list defense."""
        acc = {}
        BaseTask._accumulate_counts(acc, {"inserted": 10, "rate": 0.5})
        BaseTask._accumulate_counts(acc, {"inserted": 3, "rate": 0.2})
        assert acc["inserted"] == 13
        assert abs(acc["rate"] - 0.7) < 1e-9

    def test_non_numeric_non_list_uses_setdefault(self):
        """Non-numeric, non-list values keep the first value seen (setdefault)."""
        acc = {}
        BaseTask._accumulate_counts(acc, {"status": "ok"})
        BaseTask._accumulate_counts(acc, {"status": "fail"})
        assert acc["status"] == "ok"

    def test_none_current_is_safe(self):
        """Passing None as *current* is a no-op rather than an error."""
        acc = {"x": 1}
        assert BaseTask._accumulate_counts(acc, None) == {"x": 1}

    def test_empty_current_is_safe(self):
        """Passing an empty dict as *current* is a no-op rather than an error."""
        acc = {"x": 1}
        assert BaseTask._accumulate_counts(acc, {}) == {"x": 1}
|
||||
@@ -0,0 +1,157 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""属性测试:废除聚合逻辑正确性(is_trash 驱动)
|
||||
|
||||
纯单元测试,通过 stub 隔离外部依赖(DB/API/Config),
|
||||
验证 _aggregate_by_assistant_date 中 is_trash 分流逻辑的正确性。
|
||||
|
||||
# Feature: assistant-abolish-cleanup, Property 1: 废除聚合逻辑正确性
|
||||
# **Validates: Requirements 3.4, 9.3**
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from tasks.dws.assistant_daily_task import AssistantDailyTask
|
||||
|
||||
|
||||
# ── Stub 构造 ──────────────────────────────────────────────────────────
|
||||
|
||||
def _make_stub_task() -> AssistantDailyTask:
    """Build a minimal AssistantDailyTask with all external deps stubbed.

    config / db / api / logger are MagicMocks; the SCD2 level lookup, the
    course-type lookup and the config-cache loader are replaced so no real
    DB or API call can ever be issued.
    """
    # config.get must distinguish keys: timezone returns a string, the
    # tenant/store ids return integers, everything else yields the default.
    known_settings = {
        "app.timezone": "Asia/Shanghai",
        "app.tenant_id": 1,
        "app.store_id": 1,
    }

    config = MagicMock()
    config.get.side_effect = lambda key, default=None: known_settings.get(key, default)

    db, api, logger = MagicMock(), MagicMock(), MagicMock()

    task = AssistantDailyTask(config, db, api, logger)

    # SCD2 level lookup stub — always the same fixed level.
    task.get_assistant_level_asof = MagicMock(
        return_value={"level_code": 10, "level_name": "初级"}
    )

    # Course-type lookup stub — every skill resolves to BASE.
    from tasks.dws.base_dws_task import CourseType, ConfigCache
    task.get_course_type = MagicMock(return_value=CourseType.BASE)

    # load_config_cache stub (get_course_type is already mocked; this only
    # prevents a real DB round-trip).
    task.load_config_cache = MagicMock(
        return_value=ConfigCache(
            performance_tiers=[],
            level_prices=[],
            bonus_rules=[],
            area_categories={},
            skill_types={},
            loaded_at=None,
        )
    )

    return task
|
||||
|
||||
|
||||
# ── Hypothesis strategies ──────────────────────────────────────────────

# Fixed assistant id, service date and site id so that every generated
# record aggregates into the same single output row.
_FIXED_ASSISTANT_ID = 10001
_FIXED_SERVICE_DATE = date(2025, 6, 15)
_FIXED_SITE_ID = 1

# Strategy for one service record; only the numeric fields vary, so the
# aggregation key (assistant_id, service_date) is constant.
_service_record_st = st.fixed_dictionaries({
    "assistant_id": st.just(_FIXED_ASSISTANT_ID),
    "service_date": st.just(_FIXED_SERVICE_DATE),
    "assistant_nickname": st.just("测试助教"),
    "assistant_level": st.just(10),
    "assistant_service_id": st.integers(min_value=1, max_value=999999),
    "skill_id": st.just(100),
    "member_id": st.integers(min_value=1, max_value=9999),
    "table_id": st.integers(min_value=1, max_value=50),
    "income_seconds": st.integers(min_value=0, max_value=36000),
    "ledger_amount": st.integers(min_value=0, max_value=10000),
    # 0/1 flag driving the trash/non-trash aggregation split under test.
    "is_trash": st.integers(min_value=0, max_value=1),
})

# Strategy for a batch of 1–30 service records.
_service_records_st = st.lists(_service_record_st, min_size=1, max_size=30)
|
||||
|
||||
|
||||
# ── 属性测试 ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestProperty1TrashAggregation:
    """Property 1: trash-driven aggregation correctness (is_trash split).

    **Validates: Requirements 3.4, 9.3**
    """

    @staticmethod
    def _aggregate(records):
        """Run the aggregation under test; return (task, the single result row)."""
        task = _make_stub_task()
        rows = task._aggregate_by_assistant_date(records, _FIXED_SITE_ID)
        assert len(rows) == 1
        return task, rows[0]

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_trashed_seconds_equals_sum_of_trash_income(self, records: List[Dict[str, Any]]):
        """trashed_seconds == sum of income_seconds over is_trash=1 records."""
        task, row = self._aggregate(records)
        expected = sum(
            task.safe_int(r.get("income_seconds", 0))
            for r in records
            if bool(r.get("is_trash", 0))
        )
        assert row["trashed_seconds"] == expected

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_trashed_count_equals_trash_record_count(self, records: List[Dict[str, Any]]):
        """trashed_count == number of is_trash=1 records."""
        _, row = self._aggregate(records)
        expected = sum(1 for r in records if bool(r.get("is_trash", 0)))
        assert row["trashed_count"] == expected

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_total_service_count_equals_non_trash_count(self, records: List[Dict[str, Any]]):
        """total_service_count == number of is_trash=0 records."""
        _, row = self._aggregate(records)
        expected = sum(1 for r in records if not bool(r.get("is_trash", 0)))
        assert row["total_service_count"] == expected

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_total_seconds_equals_sum_of_non_trash_income(self, records: List[Dict[str, Any]]):
        """total_seconds == sum of income_seconds over is_trash=0 records."""
        task, row = self._aggregate(records)
        expected = sum(
            task.safe_int(r.get("income_seconds", 0))
            for r in records
            if not bool(r.get("is_trash", 0))
        )
        assert row["total_seconds"] == expected
|
||||
1204
apps/etl/connectors/feiqiu/tests/unit/test_birthday_properties.py
Normal file
1204
apps/etl/connectors/feiqiu/tests/unit/test_birthday_properties.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
"""CLI 参数解析单元测试
|
||||
|
||||
验证 --data-source 新参数、--pipeline-flow 弃用映射、
|
||||
--pipeline + --tasks 同时使用、以及 build_cli_overrides 集成行为。
|
||||
以及 build_cli_overrides 集成行为。
|
||||
|
||||
需求: 3.1, 3.3, 3.5
|
||||
"""
|
||||
@@ -80,7 +80,7 @@ class TestBuildCliOverrides:
|
||||
"""构造最小 Namespace,未指定的参数设为 None/False"""
|
||||
defaults = dict(
|
||||
store_id=None, tasks=None, dry_run=False,
|
||||
flow=None, pipeline_deprecated=None,
|
||||
flow=None,
|
||||
processing_mode="increment_only",
|
||||
fetch_before_verify=False, verify_tables=None,
|
||||
window_split="none", lookback_hours=24, overlap_seconds=3600,
|
||||
@@ -96,6 +96,7 @@ class TestBuildCliOverrides:
|
||||
data_source=None, pipeline_flow=None,
|
||||
fetch_root=None, ingest_source=None, write_pretty_json=False,
|
||||
idle_start=None, idle_end=None, allow_empty_advance=False,
|
||||
force_full=False,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return Namespace(**defaults)
|
||||
@@ -121,19 +122,3 @@ class TestBuildCliOverrides:
|
||||
assert overrides["run"]["data_source"] == "hybrid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. --pipeline + --tasks 同时使用
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPipelineAndTasks:
    """Behavior when --pipeline and --tasks are supplied together."""

    def test_pipeline_and_tasks_both_parsed(self):
        """The deprecated --pipeline alias parses alongside --tasks."""
        argv = [
            "cli",
            "--pipeline", "api_full",
            "--tasks", "ODS_MEMBER,ODS_ORDER",
        ]
        with patch("sys.argv", argv):
            parsed = parse_args()
            assert parsed.pipeline_deprecated == "api_full"
            assert parsed.tasks == "ODS_MEMBER,ODS_ORDER"
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据一致性检查器单元测试
|
||||
|
||||
测试 quality/consistency_checker.py 的核心纯函数逻辑,
|
||||
不依赖数据库连接。
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from quality.consistency_checker import (
|
||||
FieldCheckResult,
|
||||
TableCheckResult,
|
||||
ConsistencyReport,
|
||||
check_api_vs_ods_fields,
|
||||
check_ods_vs_dwd_mappings,
|
||||
extract_api_fields_from_json,
|
||||
generate_markdown_report,
|
||||
_validate_ods_expression,
|
||||
_extract_records,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_api_fields_from_json
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractApiFields:
    """Field extraction from API JSON dump files."""

    @staticmethod
    def _dump(tmp_path: Path, name: str, payload) -> Path:
        """Serialize *payload* to a JSON file under tmp_path; return its path."""
        target = tmp_path / name
        target.write_text(json.dumps(payload), encoding="utf-8")
        return target

    def test_flat_list(self, tmp_path: Path):
        """A bare list of records."""
        path = self._dump(tmp_path, "test.json", [{"id": 1, "name": "test", "amount": 100}])
        assert extract_api_fields_from_json(path) == {"id", "name", "amount"}

    def test_nested_data_key(self, tmp_path: Path):
        """A {"data": [...]} envelope."""
        path = self._dump(tmp_path, "test.json", {"data": [{"id": 1, "foo": "bar"}]})
        assert extract_api_fields_from_json(path) == {"id", "foo"}

    def test_nested_data_with_list_key(self, tmp_path: Path):
        """A {"data": {"settleList": [...]}} envelope."""
        path = self._dump(
            tmp_path, "test.json", {"data": {"settleList": [{"id": 1, "amount": 50}]}}
        )
        assert extract_api_fields_from_json(path) == {"id", "amount"}

    def test_nonexistent_file(self, tmp_path: Path):
        """A missing file yields None."""
        assert extract_api_fields_from_json(tmp_path / "nonexistent.json") is None

    def test_invalid_json(self, tmp_path: Path):
        """Malformed JSON yields None."""
        broken = tmp_path / "bad.json"
        broken.write_text("not json", encoding="utf-8")
        assert extract_api_fields_from_json(broken) is None

    def test_empty_list(self, tmp_path: Path):
        """An empty record list yields None."""
        path = self._dump(tmp_path, "empty.json", [])
        assert extract_api_fields_from_json(path) is None

    def test_merges_multiple_records(self, tmp_path: Path):
        """Fields from every record are merged into one set."""
        path = self._dump(tmp_path, "test.json", [{"id": 1, "a": 1}, {"id": 2, "b": 2}])
        assert extract_api_fields_from_json(path) == {"id", "a", "b"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_api_vs_ods_fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckApiVsOds:
    """API vs ODS field-completeness checks."""

    def test_all_fields_present(self):
        """Every API field exists in the ODS table."""
        verdict = check_api_vs_ods_fields(
            {"id", "name", "amount"},
            {"id", "name", "amount", "fetched_at", "content_hash"},
        )
        assert verdict.passed is True
        assert verdict.missing_fields == 0
        assert verdict.passed_fields == 3

    def test_missing_fields(self):
        """Some API fields are absent from the ODS table."""
        verdict = check_api_vs_ods_fields(
            {"id", "name", "extra_field"},
            {"id", "name", "fetched_at"},
        )
        assert verdict.passed is False
        assert verdict.missing_fields == 1
        gaps = [f for f in verdict.field_results if f.status == "missing"]
        assert len(gaps) == 1
        assert gaps[0].field_name == "extra_field"

    def test_ods_meta_columns_excluded(self):
        """ODS metadata columns stay out of the comparison."""
        verdict = check_api_vs_ods_fields(
            {"id"},
            {"id", "payload", "fetched_at", "content_hash", "source_file"},
        )
        assert verdict.passed is True
        assert verdict.total_fields == 1

    def test_case_insensitive(self):
        """Matching ignores letter case."""
        verdict = check_api_vs_ods_fields({"Id", "Name"}, {"id", "name", "fetched_at"})
        assert verdict.passed is True

    def test_empty_api_fields(self):
        """An empty API field set trivially passes."""
        verdict = check_api_vs_ods_fields(set(), {"id", "fetched_at"})
        assert verdict.passed is True
        assert verdict.total_fields == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_ods_vs_dwd_mappings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckOdsVsDwd:
    """ODS vs DWD mapping-correctness checks."""

    @staticmethod
    def _check(dwd_cols, ods_cols, mappings,
               dwd_table="dwd.test", ods_table="ods.test"):
        """Invoke check_ods_vs_dwd_mappings with defaulted table names."""
        return check_ods_vs_dwd_mappings(
            dwd_table, ods_table, dwd_cols, ods_cols, mappings,
        )

    def test_all_mapped_explicitly(self):
        """Every DWD column carries an explicit mapping."""
        verdict = self._check(
            {"order_id", "amount", "scd2_is_current", "scd2_version"},
            {"id", "total_amount"},
            [("order_id", "id", None), ("amount", "total_amount", None)],
            dwd_table="dwd.test_table",
            ods_table="ods.test_table",
        )
        assert verdict.passed is True
        # SCD2 columns are excluded from the check.
        assert verdict.total_fields == 2

    def test_auto_mapping(self):
        """Identically-named columns map automatically."""
        verdict = self._check({"id", "name", "amount"}, {"id", "name", "amount"}, None)
        assert verdict.passed is True
        assert verdict.passed_fields == 3

    def test_missing_mapping(self):
        """A DWD column with no mapping source fails the check."""
        verdict = self._check({"id", "orphan_col"}, {"id"}, None)
        assert verdict.passed is False
        assert verdict.missing_fields == 1
        gaps = [f for f in verdict.field_results if f.status == "missing"]
        assert gaps[0].field_name == "orphan_col"

    def test_json_expression_mapping(self):
        """A JSON-path expression validates against its base column."""
        verdict = self._check(
            {"shop_name"},
            {"siteprofile"},
            [("shop_name", "siteprofile->>'shop_name'", None)],
        )
        assert verdict.passed is True

    def test_sql_expression_mapping(self):
        """A SQL expression mapping (e.g. CASE WHEN) passes."""
        case_expr = (
            "CASE WHEN categoryboxes IS NULL OR "
            "jsonb_array_length(categoryboxes)=0 THEN 1 ELSE 0 END"
        )
        verdict = self._check(
            {"is_leaf"},
            {"categoryboxes"},
            [("is_leaf", case_expr, None)],
        )
        assert verdict.passed is True

    def test_scd2_columns_excluded(self):
        """SCD2 metadata columns are skipped entirely."""
        verdict = self._check(
            {"id", "scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"},
            {"id"},
            None,
        )
        assert verdict.total_fields == 1
        assert verdict.passed is True

    def test_invalid_ods_reference(self):
        """An explicit mapping pointing at a missing ODS column fails."""
        verdict = self._check({"bad_col"}, {"id"}, [("bad_col", "nonexistent_col", None)])
        assert verdict.passed is False
        assert verdict.mismatch_fields == 1

    def test_null_expression(self):
        """A NULL-literal mapping is accepted."""
        verdict = self._check({"null_col"}, {"id"}, [("null_col", "NULL", None)])
        assert verdict.passed is True

    def test_mixed_auto_and_explicit(self):
        """Automatic and explicit mappings can be combined."""
        verdict = self._check(
            {"id", "name", "mapped_col", "scd2_is_current"},
            {"id", "name", "source_col"},
            [("mapped_col", "source_col", None)],
        )
        assert verdict.passed is True
        assert verdict.passed_fields == 3  # id, name, mapped_col
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_ods_expression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateOdsExpression:
    """Validation of ODS-side source expressions."""

    def test_simple_column(self):
        """A known plain column validates."""
        assert _validate_ods_expression("name", {"name", "id"}) is True

    def test_missing_column(self):
        """An unknown plain column is rejected."""
        assert _validate_ods_expression("missing", {"name", "id"}) is False

    def test_quoted_column(self):
        """A quoted mixed-case column matches its lower-case ODS name."""
        assert _validate_ods_expression('"siteGoodsId"', {"sitegoodsid"}) is True

    def test_json_path(self):
        """A JSON path validates against its base column."""
        assert _validate_ods_expression("siteprofile->>'shop_name'", {"siteprofile"}) is True

    def test_json_path_missing_base(self):
        """A JSON path whose base column is absent is rejected."""
        assert _validate_ods_expression("missing_col->>'field'", {"siteprofile"}) is False

    def test_null_literal(self):
        """The NULL literal validates even with no known columns."""
        assert _validate_ods_expression("NULL", set()) is True

    def test_case_expression(self):
        """A CASE expression validates even with no known columns."""
        assert _validate_ods_expression("CASE WHEN x=1 THEN 'a' ELSE 'b' END", set()) is True

    def test_coalesce_expression(self):
        """A COALESCE expression validates even with no known columns."""
        assert _validate_ods_expression("COALESCE(a, b)", set()) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_records
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractRecords:
    """Record extraction from API response envelopes."""

    def test_direct_list(self):
        """A bare list passes through unchanged."""
        assert _extract_records([{"a": 1}]) == [{"a": 1}]

    def test_data_list(self):
        """A {"data": [...]} envelope unwraps to the inner list."""
        assert _extract_records({"data": [{"a": 1}]}) == [{"a": 1}]

    def test_nested_list_key(self):
        """A {"data": {"items": [...]}} envelope unwraps the nested list."""
        assert _extract_records({"data": {"items": [{"a": 1}]}}) == [{"a": 1}]

    def test_empty_data(self):
        """An empty dict yields no records."""
        assert _extract_records({}) == []

    def test_none_input(self):
        """None yields no records."""
        assert _extract_records(None) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_markdown_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGenerateMarkdownReport:
    """Markdown report rendering."""

    def test_empty_report(self):
        """A report with no results renders as all-passing."""
        rendered = generate_markdown_report(
            ConsistencyReport(generated_at="2026-01-01T00:00:00")
        )
        assert "数据一致性黑盒测试报告" in rendered
        assert "全部通过" in rendered

    def test_failed_report(self):
        """A failing table surfaces its missing field and status."""
        failing_table = TableCheckResult(
            table_name="dwd.test",
            check_type="ods_vs_dwd",
            passed=False,
            total_fields=3,
            passed_fields=2,
            missing_fields=1,
            field_results=[
                FieldCheckResult("col_a", "pass", "自动映射"),
                FieldCheckResult("col_b", "pass", "显式映射"),
                FieldCheckResult("col_c", "missing", "无映射源"),
            ],
        )
        rendered = generate_markdown_report(
            ConsistencyReport(
                generated_at="2026-01-01T00:00:00",
                ods_vs_dwd_results=[failing_table],
            )
        )
        assert "存在异常" in rendered
        assert "col_c" in rendered
        assert "missing" in rendered

    def test_report_contains_summary(self):
        """The summary line counts how many tables passed."""
        rendered = generate_markdown_report(
            ConsistencyReport(
                generated_at="2026-01-01T00:00:00",
                api_vs_ods_results=[
                    TableCheckResult("ods.t1", "api_vs_ods", passed=True,
                                     total_fields=5, passed_fields=5),
                    TableCheckResult("ods.t2", "api_vs_ods", passed=False,
                                     total_fields=3, passed_fields=2, missing_fields=1),
                ],
            )
        )
        assert "1/2 张表通过" in rendered
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""配置层属性测试 — 验证 AppConfig 的深度合并、store_id 验证、DSN 组装、点号路径 get。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
Feature: etl-flow-debug
|
||||
使用 hypothesis 对 AppConfig 的 4 个核心正确性属性进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DWD/DWS 层属性测试 — 验证 DwdLoadTask 和 BaseTask 的核心正确性属性。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
Feature: etl-flow-debug
|
||||
使用 hypothesis 对 DWD 列映射完整性、only_tables 过滤和 DWS 分段累加进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@@ -58,7 +58,7 @@ def _get_ods_table_names() -> set[str]:
|
||||
|
||||
# ===========================================================================
|
||||
# Property 6: DWD FACT_MAPPINGS 列映射完整性
|
||||
# Feature: etl-pipeline-debug, Property 6: DWD FACT_MAPPINGS 列映射完整性
|
||||
# Feature: etl-flow-debug, Property 6: DWD FACT_MAPPINGS 列映射完整性
|
||||
# Validates: Requirements 2.4
|
||||
#
|
||||
# 验证策略:静态检查 FACT_MAPPINGS 中每个映射条目,当 ods_expr 是简单列名时,
|
||||
@@ -122,7 +122,7 @@ def test_property6_fact_mappings_column_integrity(idx):
|
||||
|
||||
# ===========================================================================
|
||||
# Property 7: DWD only_tables 过滤
|
||||
# Feature: etl-pipeline-debug, Property 7: DWD only_tables 过滤
|
||||
# Feature: etl-flow-debug, Property 7: DWD only_tables 过滤
|
||||
# Validates: Requirements 2.6
|
||||
#
|
||||
# 验证策略:模拟 DwdLoadTask.load() 中的 only_tables 过滤逻辑,
|
||||
@@ -209,7 +209,7 @@ def test_property7_dwd_only_tables_filter(only_tables_cfg):
|
||||
|
||||
# ===========================================================================
|
||||
# Property 8: DWS 分段累加一致性
|
||||
# Feature: etl-pipeline-debug, Property 8: DWS 分段累加一致性
|
||||
# Feature: etl-flow-debug, Property 8: DWS 分段累加一致性
|
||||
# Validates: Requirements 3.3
|
||||
#
|
||||
# 验证策略:直接测试 BaseTask._accumulate_counts 静态方法,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ODS 层属性测试 — 验证 BaseOdsTask 的核心正确性属性。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
Feature: etl-flow-debug
|
||||
使用 hypothesis 对 ODS 任务的关键行为进行属性测试。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
@@ -166,7 +166,7 @@ _COLUMNS_WITH_IS_DELETE = [
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: ODS 任务提取记录数一致性
|
||||
# Feature: etl-pipeline-debug, Property 1: ODS 任务提取记录数一致性
|
||||
# Feature: etl-flow-debug, Property 1: ODS 任务提取记录数一致性
|
||||
# Validates: Requirements 1.1, 1.2
|
||||
# ===========================================================================
|
||||
|
||||
@@ -207,7 +207,7 @@ def test_property1_ods_record_count_consistency(tmp_path, records):
|
||||
|
||||
# ===========================================================================
|
||||
# Property 2: ODS 冲突处理策略正确性
|
||||
# Feature: etl-pipeline-debug, Property 2: ODS 冲突处理策略正确性
|
||||
# Feature: etl-flow-debug, Property 2: ODS 冲突处理策略正确性
|
||||
# Validates: Requirements 1.3
|
||||
# ===========================================================================
|
||||
|
||||
@@ -256,7 +256,7 @@ def test_property2_ods_conflict_mode_sql(tmp_path, conflict_mode):
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: ODS 跳过缺失主键记录
|
||||
# Feature: etl-pipeline-debug, Property 3: ODS 跳过缺失主键记录
|
||||
# Feature: etl-flow-debug, Property 3: ODS 跳过缺失主键记录
|
||||
# Validates: Requirements 1.4
|
||||
# ===========================================================================
|
||||
|
||||
@@ -304,7 +304,7 @@ def test_property3_ods_skip_missing_pk(tmp_path, valid_records, missing_pk_recor
|
||||
|
||||
# ===========================================================================
|
||||
# Property 4: ODS content_hash 去重
|
||||
# Feature: etl-pipeline-debug, Property 4: ODS content_hash 去重
|
||||
# Feature: etl-flow-debug, Property 4: ODS content_hash 去重
|
||||
# Validates: Requirements 1.5
|
||||
# ===========================================================================
|
||||
|
||||
@@ -337,7 +337,7 @@ def test_property4_ods_content_hash_deterministic(tmp_path, record):
|
||||
|
||||
# ===========================================================================
|
||||
# Property 5: ODS 快照删除标记(INSERT 语义)
|
||||
# Feature: etl-pipeline-debug, Property 5: ODS 快照删除标记
|
||||
# Feature: etl-flow-debug, Property 5: ODS 快照删除标记
|
||||
# Validates: Requirements 1.7, 7.1, 7.4
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""编排层属性测试 — 验证 FlowRunner、TaskExecutor、CLI 的核心正确性属性。
|
||||
|
||||
Feature: etl-pipeline-debug
|
||||
Feature: etl-flow-debug
|
||||
使用 hypothesis 对 Flow 层解析、无效 Flow 拒绝、工具类任务跳过游标、
|
||||
CLI data_source 解析进行属性测试。
|
||||
"""
|
||||
@@ -58,7 +58,7 @@ def test_property9_pipeline_runner_flow_layer_resolution(flow_name: str):
|
||||
expected_layers = FlowRunner.FLOW_LAYERS[flow_name]
|
||||
|
||||
# 直接验证 FLOW_LAYERS 字典查找——这是 FlowRunner.run() 内部
|
||||
# 解析层列表的唯一路径:`layers = self.FLOW_LAYERS[pipeline]`
|
||||
# 解析层列表的唯一路径:`layers = self.FLOW_LAYERS[flow]`
|
||||
assert flow_name in FlowRunner.FLOW_LAYERS
|
||||
assert FlowRunner.FLOW_LAYERS[flow_name] == expected_layers
|
||||
|
||||
@@ -97,7 +97,7 @@ def test_property10_pipeline_runner_rejects_invalid_flow(invalid_flow: str):
|
||||
避免构造完整的 FlowRunner 实例(需要真实 DB/API 连接)。
|
||||
"""
|
||||
# FlowRunner.run() 的第一行就是:
|
||||
# if pipeline not in self.FLOW_LAYERS: raise ValueError(...)
|
||||
# if flow not in self.FLOW_LAYERS: raise ValueError(...)
|
||||
# 我们直接验证这个守卫条件
|
||||
assert invalid_flow not in FlowRunner.FLOW_LAYERS
|
||||
|
||||
@@ -116,7 +116,7 @@ def test_property10_pipeline_runner_rejects_invalid_flow(invalid_flow: str):
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="无效的 Flow 名称"):
|
||||
runner.run(pipeline=invalid_flow)
|
||||
runner.run(flow=invalid_flow)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,244 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""任务 3.3: 验证 DwdLoadTask 自动列映射包含 birthday。
|
||||
|
||||
验证点:
|
||||
1. _get_columns() 从 information_schema 读取 DWD 表列名时,birthday 被自动包含
|
||||
2. _is_row_changed() 的 SCD2 变化检测自动包含 birthday(因为 birthday 不在 SCD_COLS 中)
|
||||
3. _merge_dim_scd2() 构建 SELECT 表达式时,birthday 作为同名列被自动映射
|
||||
|
||||
需求: 4.2, 4.3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, date
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
|
||||
|
||||
def _make_bare_task() -> DwdLoadTask:
    """Build a DwdLoadTask skeleton with only what _is_row_changed needs.

    __new__ bypasses __init__ so no DB/API wiring is required; the
    timezone attribute is the single piece of state set up here.
    """
    bare = DwdLoadTask.__new__(DwdLoadTask)
    bare.tz = ZoneInfo("Asia/Shanghai")
    return bare
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# dim_member column definitions (state after migration C1, birthday included),
# shaped like information_schema.columns result rows.
# ---------------------------------------------------------------------------
_DIM_MEMBER_COLUMNS = [
    {"column_name": name}
    for name in (
        "member_id",
        "scd2_version",
        "nickname",
        "mobile",
        "birthday",
        "register_site_id",
        "scd2_start_time",
        "scd2_end_time",
        "scd2_is_current",
    )
]

# Matching ODS-side column definitions for the member source table.
_ODS_MEMBER_COLUMNS = [
    {"column_name": name}
    for name in ("id", "nickname", "mobile", "birthday", "fetched_at", "payload")
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDwdBirthdayColumnMapping:
    """DwdLoadTask auto column mapping and SCD2 change detection must cover birthday."""

    @staticmethod
    def _dim_cols():
        """Column-name list of dim_member, as information_schema would report it."""
        return [row["column_name"] for row in _DIM_MEMBER_COLUMNS]

    @staticmethod
    def _current_row(birthday):
        """A current dim_member row (SCD2 metadata included) carrying *birthday*."""
        return {
            "member_id": 100,
            "nickname": "张三",
            "mobile": "13800138000",
            "birthday": birthday,
            "register_site_id": 1001,
            "scd2_version": 1,
            "scd2_start_time": datetime(2025, 1, 1),
            "scd2_end_time": datetime(9999, 12, 31),
            "scd2_is_current": 1,
        }

    @staticmethod
    def _incoming_row(birthday):
        """An incoming row (no SCD2 metadata) carrying *birthday*."""
        return {
            "member_id": 100,
            "nickname": "张三",
            "mobile": "13800138000",
            "birthday": birthday,
            "register_site_id": 1001,
        }

    def test_get_columns_includes_birthday(self):
        """The lower-cased column list read from information_schema includes birthday.

        _get_columns() issues SELECT column_name FROM information_schema.columns
        and returns every column name lower-cased; once dim_member carries the
        birthday column, 'birthday' must appear in that list.
        """
        columns = [name.lower() for name in self._dim_cols()]

        assert "birthday" in columns, (
            "dim_member 列列表应包含 birthday,"
            f"实际列: {columns}"
        )
        # SCD2 metadata columns are also present (they get skipped later by
        # _is_row_changed, but _get_columns must still report them).
        for scd_col in ("scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"):
            assert scd_col in columns, f"dim_member 应包含 SCD2 列 {scd_col}"

    def test_birthday_not_in_scd_cols(self):
        """birthday must stay outside SCD_COLS so change detection compares it.

        _is_row_changed() walks dwd_cols and skips anything in SCD_COLS; since
        birthday is not in SCD_COLS it automatically participates.
        """
        assert "birthday" not in DwdLoadTask.SCD_COLS, (
            "birthday 不应在 SCD_COLS 中,否则 SCD2 变化检测会跳过它。"
            f"当前 SCD_COLS: {DwdLoadTask.SCD_COLS}"
        )

    def test_is_row_changed_detects_birthday_change(self):
        """A birthday change alone flips _is_row_changed() to True (req 4.3)."""
        task = _make_bare_task()

        # birthday changes: 1990-05-15 -> 1991-06-20
        changed = task._is_row_changed(
            self._current_row(date(1990, 5, 15)),
            self._incoming_row(date(1991, 6, 20)),
            self._dim_cols(),
        )
        assert changed is True, (
            "birthday 值变化时,_is_row_changed 应返回 True"
        )

    def test_is_row_changed_no_change_when_birthday_same(self):
        """Identical rows (birthday included) must not be flagged as changed."""
        task = _make_bare_task()
        same_birthday = date(1990, 5, 15)

        changed = task._is_row_changed(
            self._current_row(same_birthday),
            self._incoming_row(same_birthday),
            self._dim_cols(),
        )
        assert changed is False, (
            "所有字段(含 birthday)不变时,_is_row_changed 应返回 False"
        )

    def test_is_row_changed_birthday_null_to_value(self):
        """A NULL -> value transition on birthday must be detected as a change."""
        task = _make_bare_task()

        changed = task._is_row_changed(
            self._current_row(None),
            self._incoming_row(date(1990, 5, 15)),
            self._dim_cols(),
        )
        assert changed is True, (
            "birthday 从 None 变为有值时,_is_row_changed 应返回 True"
        )

    def test_birthday_in_scd2_select_expressions(self):
        """birthday maps by name when _merge_dim_scd2 builds its SELECT list.

        _merge_dim_scd2 walks dwd_cols; a column outside SCD_COLS without an
        explicit mapping is still selected when the ODS source exposes the same
        name, yielding a '"birthday" AS "birthday"' expression.
        """
        ods_set = {row["column_name"].lower() for row in _ODS_MEMBER_COLUMNS}
        scd_cols = DwdLoadTask.SCD_COLS
        mapping = {}  # dim_member declares no explicit FACT_MAPPINGS

        # Mirror the SELECT-expression construction inside _merge_dim_scd2.
        selected_cols = [
            lc
            for lc in (col.lower() for col in self._dim_cols())
            if lc not in scd_cols and (lc in mapping or lc in ods_set)
        ]

        assert "birthday" in selected_cols, (
            "birthday 应出现在 SCD2 SELECT 表达式的列列表中,"
            f"实际选中列: {selected_cols}"
        )

    def test_scd2_change_detection_columns_include_birthday(self):
        """The non-SCD2 column set compared by _is_row_changed includes birthday.

        _is_row_changed walks dwd_cols minus SCD_COLS; the surviving set must
        contain birthday and none of the SCD2 metadata columns.
        """
        scd_cols = DwdLoadTask.SCD_COLS

        # Columns that actually take part in change detection.
        change_detection_cols = [
            col for col in self._dim_cols() if col.lower() not in scd_cols
        ]

        assert "birthday" in change_detection_cols, (
            "SCD2 变化检测列应包含 birthday,"
            f"实际检测列: {change_detection_cols}"
        )
        # SCD2 metadata columns must be excluded from the comparison set.
        for scd_col in scd_cols:
            assert scd_col not in change_detection_cols, (
                f"SCD2 元数据列 {scd_col} 不应参与变化检测"
            )
|
||||
@@ -0,0 +1,144 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DwdLoadTask.load() 返回值格式验证 — 单元测试。
|
||||
|
||||
验证需求 1.1:load() 返回 errors: int, error_details: list[dict]。
|
||||
使用 FakeDB/FakeAPI,不涉及真实数据库连接。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
|
||||
|
||||
def _make_task() -> DwdLoadTask:
    """Build a minimal DwdLoadTask wired entirely to MagicMock doubles.

    The config returns every lookup's default, and the DB connection exposes a
    cursor usable as a context manager (cursor() accepts arbitrary kwargs).
    """
    cfg = MagicMock()
    cfg.get = MagicMock(side_effect=lambda key, default=None: default)

    # Cursor double usable via `with conn.cursor(...) as cur:`.
    cursor_stub = MagicMock()
    cursor_stub.__enter__ = MagicMock(return_value=cursor_stub)
    cursor_stub.__exit__ = MagicMock(return_value=False)

    conn_stub = MagicMock()
    conn_stub.cursor.return_value = cursor_stub

    db_stub = MagicMock()
    db_stub.conn = conn_stub

    test_logger = logging.getLogger("test_dwd_return_format")
    return DwdLoadTask(cfg, db_stub, MagicMock(), test_logger)
|
||||
|
||||
|
||||
class TestLoadReturnFormat:
    """load() must report errors as an int and error_details as list[dict]."""

    @staticmethod
    def _window_ctx():
        """A one-day UTC processing window handed to load()."""
        return SimpleNamespace(
            window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
            window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
        )

    def _run_load(self, task):
        """Invoke load() with a fresh 'now' payload and the standard window."""
        return task.load({"now": datetime.now(timezone.utc)}, self._window_ctx())

    def test_no_errors_returns_zero_and_empty_list(self):
        """With an empty TABLE_MAP: errors=0 and error_details=[]."""
        task = _make_task()
        # Empty TABLE_MAP -> load() short-circuits to an empty summary.
        with patch.object(DwdLoadTask, "TABLE_MAP", {}):
            result = self._run_load(task)

        assert isinstance(result["errors"], int), (
            f"errors 应为 int,实际为 {type(result['errors'])}"
        )
        assert result["errors"] == 0
        assert isinstance(result["error_details"], list)
        assert result["error_details"] == []
        assert "tables" in result

    def test_with_errors_returns_count_and_details(self):
        """One failing table: errors=len(error_details) with table/error keys."""
        task = _make_task()

        # Simulate a table-load failure by making _get_columns raise.
        cursor_stub = task.db.conn.cursor.return_value.__enter__.return_value
        cursor_stub.fetchall.return_value = []

        def failing_get_columns(cur, table):
            if "dim_test" in table:
                raise RuntimeError("模拟装载失败")
            return []

        task._get_columns = failing_get_columns

        with patch.object(
            DwdLoadTask, "TABLE_MAP", {"dwd.dim_test": "ods.test_source"}
        ):
            result = self._run_load(task)

        assert isinstance(result["errors"], int)
        assert result["errors"] == 1
        assert isinstance(result["error_details"], list)
        assert len(result["error_details"]) == 1
        assert result["error_details"][0]["table"] == "dwd.dim_test"
        assert "模拟装载失败" in result["error_details"][0]["error"]

    def test_errors_equals_len_of_error_details(self):
        """errors always equals len(error_details); here both tables fail."""
        task = _make_task()

        def always_failing_get_columns(cur, table):
            raise RuntimeError(f"失败: {table}")

        task._get_columns = always_failing_get_columns

        with patch.object(
            DwdLoadTask,
            "TABLE_MAP",
            {"dwd.dim_a": "ods.a", "dwd.dim_b": "ods.b"},
        ):
            result = self._run_load(task)

        assert result["errors"] == len(result["error_details"])
        assert result["errors"] == 2

    def test_successful_tables_in_summary(self):
        """A skipped (column-less) table neither errors nor lands in the summary."""
        task = _make_task()

        # dim_ok yields no columns -> skipped silently (no error, no summary row);
        # dim_fail raises -> counted in errors.
        seen_tables = []

        def mixed_get_columns(cur, table):
            seen_tables.append(table)
            if "fail" in table:
                raise RuntimeError("boom")
            return []  # no columns -> table skipped

        task._get_columns = mixed_get_columns

        with patch.object(
            DwdLoadTask,
            "TABLE_MAP",
            {"dwd.dim_ok": "ods.ok", "dwd.dim_fail": "ods.fail"},
        ):
            result = self._run_load(task)

        assert result["errors"] == 1
        assert len(result["error_details"]) == 1
        # dim_ok was skipped for having no columns, so it is absent from summary.
        assert result["tables"] == []
|
||||
@@ -0,0 +1,210 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
DWS 任务 birthday 字段恢复测试
|
||||
|
||||
验证需求 4.4:DWS 任务从 dim_member.birthday 读取生日字段并写入 DWS 目标表。
|
||||
纯单元测试,使用 FakeDB/Mock,不涉及真实数据库连接。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from tests.unit.task_test_utils import FakeDBOperations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: 创建 MemberVisitTask 实例
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _create_member_visit_task():
    """Build a MemberVisitTask wired to a FakeDBOperations double.

    Returns (task, fake_db); the config resolves app.tenant_id to 1 and every
    other key to its default.
    """
    from tasks.dws.member_visit_task import MemberVisitTask

    def _config_get(key, default=None):
        return 1 if key == "app.tenant_id" else default

    cfg = MagicMock()
    cfg.get.side_effect = _config_get
    fake_db = FakeDBOperations()
    return MemberVisitTask(cfg, fake_db, MagicMock(), MagicMock()), fake_db
|
||||
|
||||
|
||||
def _create_member_consumption_task():
    """Build a MemberConsumptionTask wired to a FakeDBOperations double.

    Returns (task, fake_db); the config resolves app.tenant_id to 1 and every
    other key to its default.
    """
    from tasks.dws.member_consumption_task import MemberConsumptionTask

    def _config_get(key, default=None):
        return 1 if key == "app.tenant_id" else default

    cfg = MagicMock()
    cfg.get.side_effect = _config_get
    fake_db = FakeDBOperations()
    return MemberConsumptionTask(cfg, fake_db, MagicMock(), MagicMock()), fake_db
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# MemberVisitTask._extract_member_info — birthday 字段
|
||||
# ===========================================================================
|
||||
|
||||
class TestMemberVisitBirthday:
    """Verify birthday extraction and write-through in MemberVisitTask."""

    def test_extract_member_info_sql_includes_birthday(self):
        """The SQL built by _extract_member_info should select the birthday column."""
        task, db = _create_member_visit_task()
        db.query_results.append([
            {"member_id": 100, "nickname": "张三", "mobile": "13800001111", "birthday": date(1990, 5, 15)},
        ])
        result = task._extract_member_info(site_id=1)

        # The executed SQL must reference birthday
        sql_executed = db.executes[0]["sql"]
        assert "birthday" in sql_executed

        # The returned member map must carry birthday through unchanged
        assert result[100]["birthday"] == date(1990, 5, 15)

    def test_extract_member_info_birthday_none(self):
        """A NULL birthday should come back as None without raising."""
        task, db = _create_member_visit_task()
        db.query_results.append([
            {"member_id": 200, "nickname": "李四", "mobile": "13900002222", "birthday": None},
        ])
        result = task._extract_member_info(site_id=1)
        assert result[200]["birthday"] is None

    def test_transform_writes_member_birthday(self):
        """transform() should copy birthday into the member_birthday output field."""
        task, db = _create_member_visit_task()

        extracted = {
            "settlements": [
                {
                    "member_id": 100,
                    "order_settle_id": 1001,
                    "visit_date": date(2026, 2, 20),
                    "create_time": "2026-02-20 14:00:00",
                    "table_id": 10,
                    "table_charge_money": Decimal("50.00"),
                    "goods_money": Decimal("30.00"),
                    "assistant_pd_money": Decimal("20.00"),
                    "assistant_cx_money": Decimal("10.00"),
                    "consume_money": Decimal("110.00"),
                    "pay_amount": Decimal("100.00"),
                    "balance_amount": Decimal("0"),
                    "gift_card_amount": Decimal("0"),
                    "coupon_amount": Decimal("10.00"),
                    "discount_money": Decimal("0"),
                    "free_money": Decimal("0"),
                },
            ],
            "assistant_services": [],
            "member_info": {
                100: {
                    "member_id": 100,
                    "nickname": "张三",
                    "mobile": "13800001111",
                    "birthday": date(1990, 5, 15),
                },
            },
            "table_info": {
                10: {"table_id": 10, "table_name": "1号台", "area_name": "大厅"},
            },
            "table_fee_durations": [],
            "site_id": 1,
        }

        from tasks.dws.base_dws_task import TaskContext
        from datetime import datetime
        ctx = TaskContext(
            store_id=1,
            window_start=datetime(2026, 2, 20),
            window_end=datetime(2026, 2, 20, 23, 59, 59),
            window_minutes=1440,
        )

        results = task.transform(extracted, ctx)
        assert len(results) == 1
        assert results[0]["member_birthday"] == date(1990, 5, 15)

    def test_transform_member_birthday_none_when_no_birthday(self):
        """member_birthday should be None when the member record has no birthday."""
        task, db = _create_member_visit_task()

        extracted = {
            "settlements": [
                {
                    "member_id": 300,
                    "order_settle_id": 3001,
                    "visit_date": date(2026, 2, 20),
                    "create_time": "2026-02-20 15:00:00",
                    "table_id": 10,
                    "table_charge_money": Decimal("0"),
                    "goods_money": Decimal("0"),
                    "assistant_pd_money": Decimal("0"),
                    "assistant_cx_money": Decimal("0"),
                    "consume_money": Decimal("0"),
                    "pay_amount": Decimal("0"),
                    "balance_amount": Decimal("0"),
                    "gift_card_amount": Decimal("0"),
                    "coupon_amount": Decimal("0"),
                    "discount_money": Decimal("0"),
                    "free_money": Decimal("0"),
                },
            ],
            "assistant_services": [],
            "member_info": {
                300: {"member_id": 300, "nickname": "王五", "mobile": None, "birthday": None},
            },
            "table_info": {10: {"table_id": 10, "table_name": "1号台", "area_name": None}},
            "table_fee_durations": [],
            "site_id": 1,
        }

        from tasks.dws.base_dws_task import TaskContext
        from datetime import datetime
        ctx = TaskContext(
            store_id=1,
            window_start=datetime(2026, 2, 20),
            window_end=datetime(2026, 2, 20, 23, 59, 59),
            window_minutes=1440,
        )
        results = task.transform(extracted, ctx)
        assert results[0]["member_birthday"] is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# MemberConsumptionTask._extract_member_info — birthday 字段
|
||||
# ===========================================================================
|
||||
|
||||
class TestMemberConsumptionBirthday:
    """Verify birthday extraction in MemberConsumptionTask."""

    def test_extract_member_info_sql_includes_birthday(self):
        """The SQL built by _extract_member_info should select birthday."""
        task, db = _create_member_consumption_task()
        member_row = {
            "member_id": 100,
            "nickname": "张三",
            "mobile": "13800001111",
            "member_card_grade_name": "金卡",
            "register_date": date(2025, 1, 1),
            "recharge_money_sum": Decimal("500.00"),
            "birthday": date(1995, 8, 20),
        }
        db.query_results.append([member_row])

        info = task._extract_member_info(site_id=1)

        executed_sql = db.executes[0]["sql"]
        assert "birthday" in executed_sql
        assert info[100]["birthday"] == date(1995, 8, 20)

    def test_extract_member_info_birthday_none(self):
        """A NULL birthday should round-trip as None."""
        task, db = _create_member_consumption_task()
        member_row = {
            "member_id": 200,
            "nickname": "李四",
            "mobile": None,
            "member_card_grade_name": None,
            "register_date": None,
            "recharge_money_sum": Decimal("0"),
            "birthday": None,
        }
        db.query_results.append([member_row])

        info = task._extract_member_info(site_id=1)
        assert info[200]["birthday"] is None
|
||||
@@ -124,14 +124,14 @@ class TestFlowModeE2E:
|
||||
)
|
||||
|
||||
result = runner.run(
|
||||
pipeline="api_ods",
|
||||
flow="api_ods",
|
||||
processing_mode="increment_only",
|
||||
data_source="hybrid",
|
||||
)
|
||||
|
||||
# 结构验证
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert result["pipeline"] == "api_ods"
|
||||
assert result["flow"] == "api_ods"
|
||||
assert result["layers"] == ["ODS"]
|
||||
assert isinstance(result["results"], list)
|
||||
# TaskExecutor 被调用
|
||||
@@ -159,7 +159,7 @@ class TestFlowModeE2E:
|
||||
# 校验框架可能未安装,mock 掉 _run_verification
|
||||
with patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}):
|
||||
result = runner.run(
|
||||
pipeline="api_ods",
|
||||
flow="api_ods",
|
||||
processing_mode="verify_only",
|
||||
data_source="hybrid",
|
||||
)
|
||||
@@ -218,5 +218,5 @@ class TestSchedulerThinWrapper:
|
||||
scheduler.task_executor.run_tasks.assert_called_once()
|
||||
|
||||
# run_flow_with_verification 委托
|
||||
scheduler.run_flow_with_verification(pipeline="api_ods")
|
||||
scheduler.run_flow_with_verification(flow="api_ods")
|
||||
scheduler.flow_runner.run.assert_called_once()
|
||||
|
||||
@@ -80,31 +80,27 @@ class TestLayersArgParsing:
|
||||
args = parse_args()
|
||||
assert args.layers is None
|
||||
|
||||
def test_pipeline_still_works(self):
|
||||
"""--pipeline 参数保留可用(需求 6.3),存入 pipeline_deprecated(弃用别名)"""
|
||||
with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
|
||||
args = parse_args()
|
||||
assert args.pipeline_deprecated == "api_full"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. --layers 与 --pipeline 互斥(需求 6.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestLayersPipelineMutualExclusion:
|
||||
"""--layers 和 --flow/--pipeline 互斥校验
|
||||
class TestLayersFlowMutualExclusion:
|
||||
"""--layers 和 --flow 互斥校验
|
||||
|
||||
互斥校验在 main() 中实现(非 argparse 层),
|
||||
此处验证两个参数可以同时被解析(互斥由 main 层处理)。
|
||||
"""
|
||||
|
||||
def test_both_args_can_be_parsed(self):
|
||||
"""argparse 层允许同时传入 --layers 和 --pipeline(弃用别名),互斥由 main() 校验"""
|
||||
"""argparse 层允许同时传入 --layers 和 --flow,互斥由 main() 校验"""
|
||||
with patch("sys.argv", [
|
||||
"cli", "--layers", "ODS,DWD", "--pipeline", "api_full",
|
||||
"cli", "--layers", "ODS,DWD", "--flow", "api_full",
|
||||
]):
|
||||
args = parse_args()
|
||||
assert args.layers == "ODS,DWD"
|
||||
assert args.pipeline_deprecated == "api_full"
|
||||
assert args.flow == "api_full"
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -208,7 +204,7 @@ class TestParseLayersProperties:
|
||||
# 5. --flow / --pipeline 弃用别名测试(需求 9.3, 9.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestFlowParameter:
|
||||
"""--flow 作为主参数、--pipeline 作为弃用别名"""
|
||||
"""--flow 参数测试(--pipeline 已移除)"""
|
||||
|
||||
def test_flow_parsed(self):
|
||||
"""--flow 作为主参数可正常解析"""
|
||||
@@ -222,52 +218,6 @@ class TestFlowParameter:
|
||||
args = parse_args()
|
||||
assert args.flow is None
|
||||
|
||||
def test_pipeline_deprecated_parsed(self):
|
||||
"""--pipeline 仍可解析,存入 pipeline_deprecated"""
|
||||
with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
|
||||
args = parse_args()
|
||||
assert args.pipeline_deprecated == "api_full"
|
||||
assert args.flow is None # --flow 未指定
|
||||
|
||||
def test_pipeline_emits_deprecation_warning(self):
|
||||
"""使用 --pipeline 时应发出 DeprecationWarning(需求 9.4)
|
||||
|
||||
直接模拟 main() 中的弃用逻辑,避免进入数据库连接。
|
||||
"""
|
||||
import warnings
|
||||
|
||||
with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
|
||||
args = parse_args()
|
||||
|
||||
# 模拟 main() 中的弃用处理逻辑
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
if args.pipeline_deprecated:
|
||||
warnings.warn(
|
||||
"--pipeline 参数已弃用,请使用 --flow",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not args.flow:
|
||||
args.flow = args.pipeline_deprecated
|
||||
|
||||
dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
|
||||
assert len(dep_warnings) == 1
|
||||
assert "--pipeline 参数已弃用" in str(dep_warnings[0].message)
|
||||
# 验证值已合并到 args.flow
|
||||
assert args.flow == "api_full"
|
||||
|
||||
def test_flow_and_pipeline_mutually_exclusive(self):
|
||||
"""--flow 和 --pipeline 不能同时指定(需求 9.3)
|
||||
|
||||
argparse 层允许同时传入,互斥由 main() 中的逻辑处理。
|
||||
"""
|
||||
with patch("sys.argv", ["cli", "--flow", "api_full", "--pipeline", "api_ods"]):
|
||||
args = parse_args()
|
||||
# 两者同时存在时,main() 应 sys.exit(2)
|
||||
assert args.flow == "api_full"
|
||||
assert args.pipeline_deprecated == "api_ods"
|
||||
|
||||
def test_layers_and_flow_mutually_exclusive(self):
|
||||
"""--layers 和 --flow 互斥(argparse 层可同时解析,main() 校验)"""
|
||||
with patch("sys.argv", ["cli", "--layers", "ODS,DWD", "--flow", "api_full"]):
|
||||
@@ -275,21 +225,3 @@ class TestFlowParameter:
|
||||
assert args.layers == "ODS,DWD"
|
||||
assert args.flow == "api_full"
|
||||
|
||||
def test_layers_and_pipeline_deprecated_mutually_exclusive(self):
|
||||
"""--layers 和 --pipeline(弃用别名)也互斥
|
||||
|
||||
--pipeline 先合并到 args.flow,然后 --layers vs --flow 互斥生效。
|
||||
"""
|
||||
with patch("sys.argv", ["cli", "--layers", "ODS,DWD", "--pipeline", "api_full"]):
|
||||
args = parse_args()
|
||||
assert args.layers == "ODS,DWD"
|
||||
assert args.pipeline_deprecated == "api_full"
|
||||
|
||||
def test_pipeline_value_merges_to_flow(self):
|
||||
"""--pipeline 的值在弃用处理后应合并到 args.flow"""
|
||||
with patch("sys.argv", ["cli", "--pipeline", "dwd_dws"]):
|
||||
args = parse_args()
|
||||
# 模拟 main() 中的合并逻辑
|
||||
if args.pipeline_deprecated and not args.flow:
|
||||
args.flow = args.pipeline_deprecated
|
||||
assert args.flow == "dwd_dws"
|
||||
|
||||
@@ -0,0 +1,903 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""助教月度聚合属性测试 — 使用 hypothesis 验证档位分段聚合的正确性。
|
||||
|
||||
纯单元测试,模拟 _extract_daily_aggregates() 的 SQL GROUP BY 逻辑,不涉及真实数据库连接。
|
||||
|
||||
# Feature: etl-aggregation-fix, Property 3: 档位分段聚合正确性
|
||||
# **Validates: Requirements 2.1**
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# ── 业绩指标字段列表(与 _extract_daily_aggregates SQL 中 SUM 聚合的字段一致)──
|
||||
|
||||
_METRIC_FIELDS = [
|
||||
"total_service_count",
|
||||
"base_service_count",
|
||||
"bonus_service_count",
|
||||
"room_service_count",
|
||||
"total_hours",
|
||||
"base_hours",
|
||||
"bonus_hours",
|
||||
"room_hours",
|
||||
"total_ledger_amount",
|
||||
"base_ledger_amount",
|
||||
"bonus_ledger_amount",
|
||||
"room_ledger_amount",
|
||||
"unique_customers",
|
||||
"unique_tables",
|
||||
"trashed_seconds",
|
||||
"trashed_count",
|
||||
]
|
||||
|
||||
|
||||
# ── 核心逻辑:模拟 _extract_daily_aggregates() 的 SQL GROUP BY 行为 ──
|
||||
#
|
||||
# 实际 SQL:
|
||||
# SELECT assistant_id, assistant_level_code, assistant_level_name,
|
||||
# (ARRAY_AGG(assistant_nickname ORDER BY stat_date DESC))[1] AS assistant_nickname,
|
||||
# DATE_TRUNC('month', stat_date)::DATE AS stat_month,
|
||||
# COUNT(DISTINCT stat_date) AS work_days,
|
||||
# SUM(total_service_count) AS total_service_count,
|
||||
# ...
|
||||
# FROM dws.dws_assistant_daily_detail
|
||||
# WHERE site_id = %s AND (stat_date >= ... AND stat_date < ...)
|
||||
# GROUP BY assistant_id, assistant_level_code, assistant_level_name,
|
||||
# DATE_TRUNC('month', stat_date)
|
||||
|
||||
|
||||
def simulate_extract_daily_aggregates(
    daily_rows: List[Dict[str, Any]],
    site_id: int,
    month: date,
) -> List[Dict[str, Any]]:
    """Replicate the GROUP BY behaviour of _extract_daily_aggregates().

    Rows matching *site_id* whose stat_date falls inside *month* are grouped
    by (assistant_id, assistant_level_code, assistant_level_name); every
    metric in _METRIC_FIELDS is SUM-aggregated and work_days counts distinct
    stat_date values per group.
    """
    first_of_month = month.replace(day=1)
    # Jump four days past the 28th to land safely in the following month.
    first_of_next = (first_of_month.replace(day=28) + timedelta(days=4)).replace(day=1)

    # Emulate the WHERE clause and GROUP BY in a single pass.
    buckets: Dict[Tuple, List[Dict]] = defaultdict(list)
    for rec in daily_rows:
        if rec.get("site_id") != site_id:
            continue
        if not (first_of_month <= rec["stat_date"] < first_of_next):
            continue
        group_key = (
            rec["assistant_id"],
            rec.get("assistant_level_code"),
            rec.get("assistant_level_name"),
        )
        buckets[group_key].append(rec)

    aggregated = []
    for (assistant_id, code, name), members in buckets.items():
        out_row: Dict[str, Any] = {
            "assistant_id": assistant_id,
            "assistant_level_code": code,
            "assistant_level_name": name,
            "stat_month": first_of_month,
            # COUNT(DISTINCT stat_date)
            "work_days": len({m["stat_date"] for m in members}),
        }
        # SUM() over every performance metric, via Decimal for exactness.
        for metric in _METRIC_FIELDS:
            out_row[metric] = sum(Decimal(str(m.get(metric, 0))) for m in members)
        # (ARRAY_AGG(assistant_nickname ORDER BY stat_date DESC))[1]:
        # nickname of the most recent row (first occurrence wins on ties,
        # matching stable reverse sort).
        latest = max(members, key=lambda m: m["stat_date"])
        out_row["assistant_nickname"] = latest.get("assistant_nickname", "")
        aggregated.append(out_row)

    return aggregated
|
||||
|
||||
|
||||
# ── Generation strategies ─────────────────────────────────────

# Assistant IDs
_assistant_id_st = st.integers(min_value=1, max_value=50)

# Level codes: A/B/C/D/E denote distinct tiers
_level_code_st = st.sampled_from(["A", "B", "C", "D", "E"])

# Level-name mapping (bound to the code, so one code always maps to one name)
_LEVEL_NAME_MAP = {
    "A": "初级",
    "B": "中级",
    "C": "高级",
    "D": "资深",
    "E": "专家",
}

# Nicknames
_nickname_st = st.text(
    alphabet=st.sampled_from("张李王赵刘陈杨黄周吴"),
    min_size=2,
    max_size=4,
)

# Non-negative metric values (integer-typed counters)
_count_metric_st = st.integers(min_value=0, max_value=500)

# Non-negative metric values (money / hours, expressed as Decimal)
_amount_metric_st = st.decimals(
    min_value=0,
    max_value=Decimal("99999.99"),
    places=2,
    allow_nan=False,
    allow_infinity=False,
)
|
||||
|
||||
|
||||
@st.composite
def _daily_detail_row_st(draw, assistant_id: int, level_code: str, month: date):
    """Generate one dws_assistant_daily_detail row for a fixed assistant/level.

    The caller pins assistant_id and level_code so grouping stays controllable;
    stat_date is drawn from the days of *month*.
    """
    first_day = month.replace(day=1)
    following_month = (first_day.replace(day=28) + timedelta(days=4)).replace(day=1)
    last_day_number = (following_month - timedelta(days=1)).day
    chosen_day = draw(st.integers(min_value=1, max_value=last_day_number))

    return {
        "site_id": 1,  # single fixed site keeps the simulation filter trivial
        "assistant_id": assistant_id,
        "assistant_level_code": level_code,
        "assistant_level_name": _LEVEL_NAME_MAP[level_code],
        "assistant_nickname": draw(_nickname_st),
        "stat_date": first_day.replace(day=chosen_day),
        "total_service_count": draw(_count_metric_st),
        "base_service_count": draw(_count_metric_st),
        "bonus_service_count": draw(_count_metric_st),
        "room_service_count": draw(_count_metric_st),
        "total_hours": float(draw(_amount_metric_st)),
        "base_hours": float(draw(_amount_metric_st)),
        "bonus_hours": float(draw(_amount_metric_st)),
        "room_hours": float(draw(_amount_metric_st)),
        "total_ledger_amount": float(draw(_amount_metric_st)),
        "base_ledger_amount": float(draw(_amount_metric_st)),
        "bonus_ledger_amount": float(draw(_amount_metric_st)),
        "room_ledger_amount": float(draw(_amount_metric_st)),
        "unique_customers": draw(_count_metric_st),
        "unique_tables": draw(_count_metric_st),
        "trashed_seconds": draw(_count_metric_st),
        "trashed_count": draw(_count_metric_st),
    }
|
||||
|
||||
|
||||
@st.composite
def _multi_level_scenario_st(draw):
    """Generate one assistant holding several level codes inside one month.

    Returns a dict with:
        daily_rows   -- generated daily-detail rows
        assistant_id -- the assistant under test
        month        -- the statistics month (fixed at 2025-01)
        level_codes  -- set of distinct level codes used
    """
    aid = draw(_assistant_id_st)
    stat_month = date(2025, 1, 1)  # pinned month keeps date arithmetic stable

    # Draw between 1 and 4 distinct level codes.
    distinct_levels = draw(st.integers(min_value=1, max_value=4))
    codes = draw(
        st.lists(
            _level_code_st,
            min_size=distinct_levels,
            max_size=distinct_levels,
            unique=True,
        )
    )

    # 1-5 daily-detail rows per level code.
    rows: List[Dict[str, Any]] = []
    for level in codes:
        row_count = draw(st.integers(min_value=1, max_value=5))
        for _ in range(row_count):
            rows.append(draw(_daily_detail_row_st(aid, level, stat_month)))

    return {
        "daily_rows": rows,
        "assistant_id": aid,
        "month": stat_month,
        "level_codes": set(codes),
    }
|
||||
|
||||
|
||||
# ── Property 3: 档位分段聚合正确性 ───────────────────────────
|
||||
# Feature: etl-aggregation-fix, Property 3: 档位分段聚合正确性
|
||||
# **Validates: Requirements 2.1**
|
||||
#
|
||||
# 对于任意助教在同一月内存在 N 个不同 assistant_level_code 的日度数据,
|
||||
# _extract_daily_aggregates() 应返回恰好 N 行记录,
|
||||
# 每行的业绩指标之和应等于该助教该月的总业绩。
|
||||
|
||||
|
||||
class TestProperty3LevelSegmentAggregation:
    """Property 3: per-level-segment aggregation correctness.

    For an assistant whose daily rows span N distinct assistant_level_code
    values within one month, the aggregation must emit exactly N rows, and
    those rows must partition the assistant's monthly metric totals.
    """

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_row_count_equals_distinct_level_codes(self, scenario):
        """The number of aggregate rows equals the number of distinct level codes."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]
        level_codes = scenario["level_codes"]

        results = simulate_extract_daily_aggregates(
            daily_rows, site_id=1, month=scenario["month"]
        )

        # Keep only rows belonging to the assistant under test.
        assistant_results = list(
            filter(lambda r: r["assistant_id"] == assistant_id, results)
        )

        assert len(assistant_results) == len(level_codes), (
            f"助教 {assistant_id} 有 {len(level_codes)} 个不同档位 {level_codes},"
            f"但聚合结果返回 {len(assistant_results)} 行"
        )

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_metric_sums_equal_total(self, scenario):
        """Summing the per-level rows of each metric reproduces the monthly total."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]

        results = simulate_extract_daily_aggregates(
            daily_rows, site_id=1, month=scenario["month"]
        )
        assistant_results = list(
            filter(lambda r: r["assistant_id"] == assistant_id, results)
        )

        # Recompute the assistant's monthly total per metric, ignoring levels.
        own_rows = [r for r in daily_rows if r["assistant_id"] == assistant_id]
        for field in _METRIC_FIELDS:
            expected_total = sum(Decimal(str(r.get(field, 0))) for r in own_rows)
            actual_total = sum(r[field] for r in assistant_results)

            assert actual_total == expected_total, (
                f"指标 {field}: 各档位之和={actual_total} != 总业绩={expected_total}"
            )

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_each_row_has_correct_level_code(self, scenario):
        """Every aggregate row carries a level code present in the source data."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]
        level_codes = scenario["level_codes"]

        results = simulate_extract_daily_aggregates(
            daily_rows, site_id=1, month=scenario["month"]
        )

        result_codes = {
            r["assistant_level_code"]
            for r in results
            if r["assistant_id"] == assistant_id
        }
        assert result_codes == level_codes, (
            f"聚合结果的档位集合 {result_codes} != 原始数据的档位集合 {level_codes}"
        )

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=100)
    def test_per_level_metrics_match_source(self, scenario):
        """Each level row's metrics equal the SUM over that level's daily rows."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]

        results = simulate_extract_daily_aggregates(
            daily_rows, site_id=1, month=scenario["month"]
        )
        assistant_results = {
            r["assistant_level_code"]: r
            for r in results
            if r["assistant_id"] == assistant_id
        }

        # Group the assistant's source rows by level code to build expectations.
        by_level: Dict[str, List[Dict]] = defaultdict(list)
        for r in daily_rows:
            if r["assistant_id"] == assistant_id:
                by_level[r["assistant_level_code"]].append(r)

        for code, rows in by_level.items():
            assert code in assistant_results, (
                f"档位 {code} 在原始数据中存在但聚合结果中缺失"
            )
            agg = assistant_results[code]
            for field in _METRIC_FIELDS:
                expected = sum(Decimal(str(r.get(field, 0))) for r in rows)
                assert agg[field] == expected, (
                    f"档位 {code} 指标 {field}: 聚合值={agg[field]} != 期望={expected}"
                )
|
||||
|
||||
|
||||
# ── Property 4: nickname 按时间倒序取值 ─────────────────────────
|
||||
# Feature: etl-aggregation-fix, Property 4: nickname 按时间倒序取值
|
||||
# **Validates: Requirements 2.3, 2.5, 2.6**
|
||||
#
|
||||
# 对于任意助教在聚合周期内有多条不同 nickname 的记录,
|
||||
# 聚合结果中的 nickname 应等于时间最晚的那条记录的 nickname。
|
||||
# 此属性适用于 AssistantMonthlyTask、AssistantFinanceTask、AssistantCustomerTask。
|
||||
|
||||
|
||||
# ── 模拟函数:FinanceTask._extract_daily_revenue() 的 nickname 取值逻辑 ──
|
||||
#
|
||||
# 实际 SQL:
|
||||
# SELECT DATE(s.start_use_time) AS stat_date,
|
||||
# s.site_assistant_id AS assistant_id,
|
||||
# (ARRAY_AGG(s.nickname ORDER BY s.start_use_time DESC))[1] AS assistant_nickname,
|
||||
# ...
|
||||
# GROUP BY DATE(s.start_use_time), s.site_assistant_id
|
||||
|
||||
|
||||
def simulate_finance_daily_revenue(
    service_rows: List[Dict[str, Any]],
    site_id: int,
    start_date: date,
    end_date: date,
) -> List[Dict[str, Any]]:
    """Simulate the GROUP BY logic of AssistantFinanceTask._extract_daily_revenue().

    Rows are grouped by (stat_date, assistant_id); the nickname comes from the
    record with the latest start_use_time (DESC order, first element).
    """
    # Filter and group in one pass, mirroring the SQL WHERE clause:
    # site match, date window, not soft-deleted.
    groups: Dict[Tuple, List[Dict]] = defaultdict(list)
    for row in service_rows:
        if (
            row.get("site_id") == site_id
            and start_date <= row["stat_date"] <= end_date
            and row.get("is_delete", 0) == 0
        ):
            groups[(row["stat_date"], row["assistant_id"])].append(row)

    # One output row per group.  max() returns the first row with the latest
    # start_use_time, which matches sorted(..., reverse=True)[0] on ties
    # because Python's sort is stable.
    results = []
    for (stat_date, aid), rows in groups.items():
        latest = max(rows, key=lambda r: r["start_use_time"])
        results.append({
            "stat_date": stat_date,
            "assistant_id": aid,
            "assistant_nickname": latest.get("nickname", ""),
        })
    return results
|
||||
|
||||
|
||||
# ── 模拟函数:CustomerTask._extract_service_pairs() 的 nickname 取值逻辑 ──
|
||||
#
|
||||
# 实际 SQL:
|
||||
# SELECT assistant_id,
|
||||
# (ARRAY_AGG(assistant_nickname ORDER BY service_date DESC))[1] AS assistant_nickname,
|
||||
# member_id, ...
|
||||
# GROUP BY assistant_id, member_id
|
||||
|
||||
|
||||
def simulate_customer_service_pairs(
    service_rows: List[Dict[str, Any]],
    site_id: int,
) -> List[Dict[str, Any]]:
    """Simulate the GROUP BY logic of AssistantCustomerTask._extract_service_pairs().

    Rows are grouped by (assistant_id, member_id); the nickname comes from the
    record with the latest service_date (DESC order, first element).
    """
    # Filter and group in one pass: site match, real member id, not deleted.
    groups: Dict[Tuple, List[Dict]] = defaultdict(list)
    for row in service_rows:
        if (
            row.get("site_id") == site_id
            and row.get("member_id") is not None
            and row.get("member_id") != 0
            and row.get("is_delete", 0) == 0
        ):
            groups[(row["assistant_id"], row["member_id"])].append(row)

    # max() picks the first row carrying the latest service_date, matching
    # the original stable sorted(..., reverse=True)[0] on ties.
    results = []
    for (aid, mid), rows in groups.items():
        latest = max(rows, key=lambda r: r["service_date"])
        results.append({
            "assistant_id": aid,
            "member_id": mid,
            "assistant_nickname": latest.get("assistant_nickname", ""),
        })
    return results
|
||||
|
||||
|
||||
# ── Property 4 生成策略 ──────────────────────────────────────
|
||||
|
||||
# Guarantee multiple distinct nicknames inside one aggregation group, so the
# "latest record wins" rule is actually exercised: 2-6 unique names, each a
# 2-4 character string of common CJK surname characters.
_distinct_nickname_st = st.lists(
    st.text(
        alphabet=st.sampled_from("张李王赵刘陈杨黄周吴徐孙马朱胡林郭何高"),
        min_size=2,
        max_size=4,
    ),
    min_size=2,
    max_size=6,
    unique=True,
)
|
||||
|
||||
|
||||
@st.composite
def _monthly_nickname_scenario_st(draw):
    """Generate a nickname scenario for AssistantMonthlyTask.

    One assistant with one level code gets several daily rows in the same
    month, each carrying a different nickname.  The aggregate's nickname
    must equal the nickname of the row with the latest stat_date.
    """
    assistant_id = draw(_assistant_id_st)
    level_code = draw(_level_code_st)
    month = date(2025, 1, 1)

    nicknames = draw(_distinct_nickname_st)
    month_start = month.replace(day=1)
    max_day = 31  # 2025-01 has 31 days

    # One row per nickname; stat_date values are kept distinct so the
    # "latest" row (and therefore the expected nickname) is deterministic.
    daily_rows: List[Dict[str, Any]] = []
    used_days: Set[int] = set()
    for nick in nicknames:
        day = draw(
            st.integers(min_value=1, max_value=max_day).filter(
                # Bind the current snapshot of used_days as a default arg to
                # avoid the late-binding-closure pitfall.
                lambda d, _used=frozenset(used_days): d not in _used
            )
        )
        used_days.add(day)
        row = {
            "site_id": 1,
            "assistant_id": assistant_id,
            "assistant_level_code": level_code,
            "assistant_level_name": _LEVEL_NAME_MAP[level_code],
            "assistant_nickname": nick,
            "stat_date": month_start.replace(day=day),
            # Metric values are fixed constants; this scenario only tests
            # nickname selection, not metric aggregation.
            **{f: 1 for f in _METRIC_FIELDS},
        }
        daily_rows.append(row)

    # Expected nickname = the one on the row with the greatest stat_date.
    expected_nickname = max(daily_rows, key=lambda r: r["stat_date"])["assistant_nickname"]

    return {
        "daily_rows": daily_rows,
        "assistant_id": assistant_id,
        "level_code": level_code,
        "month": month,
        "expected_nickname": expected_nickname,
    }
|
||||
|
||||
|
||||
@st.composite
def _finance_nickname_scenario_st(draw):
    """Generate a nickname scenario for AssistantFinanceTask.

    One assistant gets several service records on the same day, each with a
    different nickname and a distinct start_use_time hour.  The aggregate's
    nickname must equal the one on the record with the latest start_use_time.
    """
    from datetime import datetime

    assistant_id = draw(_assistant_id_st)
    stat_date = date(2025, 1, 15)
    nicknames = draw(_distinct_nickname_st)

    service_rows: List[Dict[str, Any]] = []
    used_hours: Set[int] = set()
    for nick in nicknames:
        # Unique hour per row guarantees distinct start_use_time values,
        # making the "latest" record unambiguous regardless of the minute.
        hour = draw(
            st.integers(min_value=0, max_value=23).filter(
                lambda h, _used=frozenset(used_hours): h not in _used
            )
        )
        used_hours.add(hour)
        minute = draw(st.integers(min_value=0, max_value=59))
        service_rows.append({
            "site_id": 1,
            "assistant_id": assistant_id,
            "nickname": nick,
            "stat_date": stat_date,
            "start_use_time": datetime(2025, 1, 15, hour, minute, 0),
            "is_delete": 0,
        })

    # Expected nickname = the one on the record with the latest start_use_time.
    expected_nickname = max(
        service_rows, key=lambda r: r["start_use_time"]
    )["nickname"]

    return {
        "service_rows": service_rows,
        "assistant_id": assistant_id,
        "stat_date": stat_date,
        "expected_nickname": expected_nickname,
    }
|
||||
|
||||
|
||||
@st.composite
def _customer_nickname_scenario_st(draw):
    """Generate a nickname scenario for AssistantCustomerTask.

    One (assistant, member) pair gets several service records on distinct
    service_date values, each carrying a different nickname.  The aggregate's
    nickname must equal the one on the record with the latest service_date.
    """
    assistant_id = draw(_assistant_id_st)
    member_id = draw(st.integers(min_value=1, max_value=9999))
    nicknames = draw(_distinct_nickname_st)

    service_rows: List[Dict[str, Any]] = []
    used_days: Set[int] = set()
    for nick in nicknames:
        # Distinct day-of-month per row keeps service_date values unique.
        day = draw(
            st.integers(min_value=1, max_value=28).filter(
                lambda d, _used=frozenset(used_days): d not in _used
            )
        )
        used_days.add(day)
        service_rows.append({
            "site_id": 1,
            "assistant_id": assistant_id,
            "member_id": member_id,
            "assistant_nickname": nick,
            "service_date": date(2025, 1, day),
            "is_delete": 0,
        })

    # Expected nickname = the one on the record with the latest service_date.
    expected_nickname = max(
        service_rows, key=lambda r: r["service_date"]
    )["assistant_nickname"]

    return {
        "service_rows": service_rows,
        "assistant_id": assistant_id,
        "member_id": member_id,
        "expected_nickname": expected_nickname,
    }
|
||||
|
||||
|
||||
# ── Property 4 测试类 ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestProperty4NicknameDescOrder:
    """Property 4: nickname is taken in descending time order.

    For any assistant with multiple records carrying different nicknames
    within an aggregation period, the aggregated nickname must equal the
    nickname of the latest record.  Covers AssistantMonthlyTask,
    AssistantFinanceTask, and AssistantCustomerTask.
    """

    # ── 4a: AssistantMonthlyTask — nickname by stat_date DESC ──

    @given(scenario=_monthly_nickname_scenario_st())
    @settings(max_examples=200)
    def test_monthly_nickname_equals_latest_stat_date(self, scenario):
        """AssistantMonthlyTask: nickname equals the latest stat_date record's."""
        results = simulate_extract_daily_aggregates(
            scenario["daily_rows"], site_id=1, month=scenario["month"],
        )
        # Locate the single aggregate row for this assistant + level code.
        target = [
            r for r in results
            if r["assistant_id"] == scenario["assistant_id"]
            and r["assistant_level_code"] == scenario["level_code"]
        ]
        assert len(target) == 1, f"期望 1 行聚合结果,实际 {len(target)} 行"
        assert target[0]["assistant_nickname"] == scenario["expected_nickname"], (
            f"MonthlyTask nickname 应为 '{scenario['expected_nickname']}',"
            f"实际为 '{target[0]['assistant_nickname']}'"
        )

    # ── 4b: AssistantFinanceTask — nickname by start_use_time DESC ──

    @given(scenario=_finance_nickname_scenario_st())
    @settings(max_examples=200)
    def test_finance_nickname_equals_latest_start_use_time(self, scenario):
        """AssistantFinanceTask: nickname equals the latest start_use_time record's."""
        results = simulate_finance_daily_revenue(
            scenario["service_rows"],
            site_id=1,
            start_date=scenario["stat_date"],
            end_date=scenario["stat_date"],
        )
        # Locate the single aggregate row for this assistant + stat_date.
        target = [
            r for r in results
            if r["assistant_id"] == scenario["assistant_id"]
            and r["stat_date"] == scenario["stat_date"]
        ]
        assert len(target) == 1, f"期望 1 行聚合结果,实际 {len(target)} 行"
        assert target[0]["assistant_nickname"] == scenario["expected_nickname"], (
            f"FinanceTask nickname 应为 '{scenario['expected_nickname']}',"
            f"实际为 '{target[0]['assistant_nickname']}'"
        )

    # ── 4c: AssistantCustomerTask — nickname by service_date DESC ──

    @given(scenario=_customer_nickname_scenario_st())
    @settings(max_examples=200)
    def test_customer_nickname_equals_latest_service_date(self, scenario):
        """AssistantCustomerTask: nickname equals the latest service_date record's."""
        results = simulate_customer_service_pairs(
            scenario["service_rows"], site_id=1,
        )
        # Locate the single aggregate row for this assistant + member pair.
        target = [
            r for r in results
            if r["assistant_id"] == scenario["assistant_id"]
            and r["member_id"] == scenario["member_id"]
        ]
        assert len(target) == 1, f"期望 1 行聚合结果,实际 {len(target)} 行"
        assert target[0]["assistant_nickname"] == scenario["expected_nickname"], (
            f"CustomerTask nickname 应为 '{scenario['expected_nickname']}',"
            f"实际为 '{target[0]['assistant_nickname']}'"
        )
|
||||
|
||||
|
||||
# ── Property 5: 工资按档位分段计算 ─────────────────────────────
|
||||
# Feature: etl-aggregation-fix, Property 5: 工资按档位分段计算
|
||||
# **Validates: Requirements 2.4**
|
||||
#
|
||||
# 对于任意助教在同一月有多个档位的月度汇总记录,
|
||||
# AssistantSalaryTask 应为每个档位分别计算工资,
|
||||
# 每个档位使用对应的 level_price 和 tier 配置,
|
||||
# 且所有档位的工资记录数等于月度汇总的行数。
|
||||
|
||||
|
||||
# ── 模拟 AssistantSalaryTask.transform() 的工资计算逻辑 ──
|
||||
#
|
||||
# 实际流程:
|
||||
# 1. extract() 从 dws_assistant_monthly_summary 取出多行(同一助教不同档位)
|
||||
# 2. transform() 遍历每行,调用 _calculate_salary()
|
||||
# 3. _calculate_salary() 按行的 assistant_level_code 获取 level_price 和 tier
|
||||
# 4. 每行独立计算工资,生成一条工资记录
|
||||
#
|
||||
# 本测试不模拟完整工资公式,只验证"分段"行为:
|
||||
# - 工资记录数 == 月度汇总行数
|
||||
# - 每条工资记录使用的是对应档位的 level_price
|
||||
|
||||
|
||||
# 等级定价配置(模拟 cfg_assistant_level_price)
|
||||
_LEVEL_PRICE_CONFIG = {
|
||||
"A": {"base_course_price": Decimal("98"), "bonus_course_price": Decimal("190")},
|
||||
"B": {"base_course_price": Decimal("108"), "bonus_course_price": Decimal("190")},
|
||||
"C": {"base_course_price": Decimal("118"), "bonus_course_price": Decimal("190")},
|
||||
"D": {"base_course_price": Decimal("128"), "bonus_course_price": Decimal("190")},
|
||||
"E": {"base_course_price": Decimal("138"), "bonus_course_price": Decimal("190")},
|
||||
}
|
||||
|
||||
# 档位配置(模拟 cfg_performance_tier)
|
||||
_TIER_CONFIG = {
|
||||
1: {"tier_code": "T1", "tier_name": "1档", "base_deduction": Decimal("18"), "bonus_deduction_ratio": Decimal("0.40")},
|
||||
2: {"tier_code": "T2", "tier_name": "2档", "base_deduction": Decimal("15"), "bonus_deduction_ratio": Decimal("0.38")},
|
||||
3: {"tier_code": "T3", "tier_name": "3档", "base_deduction": Decimal("13"), "bonus_deduction_ratio": Decimal("0.35")},
|
||||
}
|
||||
|
||||
|
||||
def _simulate_get_level_price(level_code: str) -> Dict[str, Any]:
    """Simulate get_level_price(): return the pricing entry for a level code.

    Unknown codes fall back to level "A" pricing.
    """
    fallback = _LEVEL_PRICE_CONFIG["A"]
    return _LEVEL_PRICE_CONFIG.get(level_code, fallback)
|
||||
|
||||
|
||||
def _simulate_get_tier(tier_id: int) -> Dict[str, Any]:
    """Simulate get_performance_tier_by_id(): return the tier config for an id.

    Unknown ids fall back to tier 1.
    """
    fallback = _TIER_CONFIG[1]
    return _TIER_CONFIG.get(tier_id, fallback)
|
||||
|
||||
|
||||
def simulate_salary_transform(
    monthly_summary: List[Dict[str, Any]],
    site_id: int,
    salary_month: date,
) -> List[Dict[str, Any]]:
    """Simulate the core logic of AssistantSalaryTask.transform().

    Every monthly-summary row is priced independently: the row's own
    assistant_level_code selects the level pricing and its tier_id selects
    the tier config.  Simplified model — only course-hour income is computed,
    which is sufficient to verify the per-level segmentation behavior.
    """
    salary_records = []
    for row in monthly_summary:
        code = row.get("assistant_level_code")
        tier_id = row.get("tier_id", 1)

        # Resolve pricing and tier configuration for THIS row's level/tier.
        price_cfg = _simulate_get_level_price(code)
        tier_cfg = _simulate_get_tier(tier_id)

        hours_base = Decimal(str(row.get("base_hours", 0)))
        hours_bonus = Decimal(str(row.get("bonus_hours", 0)))

        # Base-course income = base hours x (customer price - base deduction).
        base_income = hours_base * (
            price_cfg["base_course_price"] - tier_cfg["base_deduction"]
        )
        # Bonus-course income = bonus hours x bonus price x (1 - deduction ratio).
        bonus_income = (
            hours_bonus
            * price_cfg["bonus_course_price"]
            * (Decimal("1") - tier_cfg["bonus_deduction_ratio"])
        )

        salary_records.append({
            "site_id": site_id,
            "assistant_id": row["assistant_id"],
            "salary_month": salary_month,
            "assistant_level_code": code,
            "base_course_price": price_cfg["base_course_price"],
            "bonus_course_price": price_cfg["bonus_course_price"],
            "base_deduction": tier_cfg["base_deduction"],
            "bonus_deduction_ratio": tier_cfg["bonus_deduction_ratio"],
            "tier_id": tier_id,
            "base_income": base_income,
            "bonus_income": bonus_income,
        })

    return salary_records
|
||||
|
||||
|
||||
# ── Property 5 生成策略 ──────────────────────────────────────
|
||||
|
||||
# tier_id values — the ids present in _TIER_CONFIG
_tier_id_st = st.sampled_from([1, 2, 3])

# Course hours with Decimal precision: 0.00 .. 300.00, two decimal places
_hours_st = st.decimals(
    min_value=0,
    max_value=Decimal("300"),
    places=2,
    allow_nan=False,
    allow_infinity=False,
)
|
||||
|
||||
|
||||
@st.composite
def _salary_multi_level_scenario_st(draw):
    """Generate monthly-summary rows with multiple levels for one assistant.

    Returns:
    - monthly_summary: list of summary rows (one row per level code)
    - assistant_id: assistant id
    - salary_month: salary month
    - level_codes: list of level codes used
    """
    assistant_id = draw(_assistant_id_st)
    salary_month = date(2025, 1, 1)

    # Pick 2-4 distinct level codes (at least 2 to exercise segmentation).
    n_levels = draw(st.integers(min_value=2, max_value=4))
    level_codes = draw(
        st.lists(
            _level_code_st,
            min_size=n_levels,
            max_size=n_levels,
            unique=True,
        )
    )

    # One summary row per level code, each with its own randomly drawn tier.
    monthly_summary: List[Dict[str, Any]] = []
    for code in level_codes:
        tier_id = draw(_tier_id_st)
        monthly_summary.append({
            "assistant_id": assistant_id,
            "assistant_nickname": f"助教{assistant_id}",
            "stat_month": salary_month,
            "assistant_level_code": code,
            "assistant_level_name": _LEVEL_NAME_MAP[code],
            "hire_date": date(2024, 6, 1),
            "is_new_hire": False,
            "effective_hours": float(draw(_hours_st)),
            "base_hours": float(draw(_hours_st)),
            "bonus_hours": float(draw(_hours_st)),
            "room_hours": float(draw(_hours_st)),
            "tier_id": tier_id,
            "tier_code": f"T{tier_id}",
            "tier_name": f"{tier_id}档",
            "rank_with_ties": draw(st.integers(min_value=1, max_value=10)),
        })

    return {
        "monthly_summary": monthly_summary,
        "assistant_id": assistant_id,
        "salary_month": salary_month,
        "level_codes": level_codes,
    }
|
||||
|
||||
|
||||
# ── Property 5 测试类 ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestProperty5SalaryPerLevelSegment:
    """Property 5: salary is computed per level segment.

    For any assistant with multiple per-level monthly-summary rows in the
    same month, AssistantSalaryTask must compute salary separately for each
    level, using that level's own level_price and tier configuration, and
    produce exactly one salary record per summary row.
    """

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_salary_record_count_equals_summary_rows(self, scenario):
        """Salary record count equals summary row count (one per level)."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )

        assert len(results) == len(scenario["monthly_summary"]), (
            f"月度汇总 {len(scenario['monthly_summary'])} 行,"
            f"但工资记录 {len(results)} 条"
        )

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_each_salary_uses_correct_level_price(self, scenario):
        """Each salary record uses the level_price of its own level code."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )

        for record in results:
            level_code = record["assistant_level_code"]
            expected_price = _LEVEL_PRICE_CONFIG[level_code]

            assert record["base_course_price"] == expected_price["base_course_price"], (
                f"档位 {level_code}: base_course_price 应为 {expected_price['base_course_price']},"
                f"实际为 {record['base_course_price']}"
            )
            assert record["bonus_course_price"] == expected_price["bonus_course_price"], (
                f"档位 {level_code}: bonus_course_price 应为 {expected_price['bonus_course_price']},"
                f"实际为 {record['bonus_course_price']}"
            )

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_each_salary_uses_correct_tier_config(self, scenario):
        """Each salary record uses the tier config matching its tier_id.

        Relies on simulate_salary_transform() preserving input row order,
        so results[i] corresponds to monthly_summary[i].
        """
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )

        for i, record in enumerate(results):
            tier_id = scenario["monthly_summary"][i]["tier_id"]
            expected_tier = _TIER_CONFIG[tier_id]

            assert record["base_deduction"] == expected_tier["base_deduction"], (
                f"tier_id={tier_id}: base_deduction 应为 {expected_tier['base_deduction']},"
                f"实际为 {record['base_deduction']}"
            )
            assert record["bonus_deduction_ratio"] == expected_tier["bonus_deduction_ratio"], (
                f"tier_id={tier_id}: bonus_deduction_ratio 应为 {expected_tier['bonus_deduction_ratio']},"
                f"实际为 {record['bonus_deduction_ratio']}"
            )

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_salary_level_codes_match_summary(self, scenario):
        """The salary records' level-code set matches the summary's."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )

        result_codes = {r["assistant_level_code"] for r in results}
        expected_codes = set(scenario["level_codes"])

        assert result_codes == expected_codes, (
            f"工资记录档位 {result_codes} != 月度汇总档位 {expected_codes}"
        )
|
||||
@@ -0,0 +1,291 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""跨店会员可查属性测试 — 使用 hypothesis 验证多门店场景下会员维度信息的可达性。
|
||||
|
||||
纯单元测试,使用 FakeDBOperations 模拟数据库查询,不涉及真实数据库连接。
|
||||
|
||||
# Feature: etl-aggregation-fix, Property 6: 跨店会员可查
|
||||
# **Validates: Requirements 3.1, 3.2**
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
|
||||
# ── 核心逻辑:模拟"通过事实表反查 dim_member"的查询行为 ──────────
|
||||
#
|
||||
# 实际 SQL 模式:
|
||||
# SELECT ... FROM dwd.dim_member
|
||||
# WHERE member_id IN (
|
||||
# SELECT DISTINCT tenant_member_id
|
||||
# FROM dwd.{fact_table}
|
||||
# WHERE site_id = %s AND tenant_member_id IS NOT NULL AND tenant_member_id != 0
|
||||
# ) AND scd2_is_current = 1
|
||||
#
|
||||
# 属性要验证的是:无论会员在哪个门店注册,只要在目标门店有消费记录(事实表中存在),
|
||||
# 该查询模式就能返回其维度信息。
|
||||
|
||||
|
||||
def simulate_fact_table_lookup(
    fact_rows: List[Dict[str, Any]],
    dim_members: List[Dict[str, Any]],
    query_site_id: int,
) -> Dict[int, Dict[str, Any]]:
    """Simulate a DWS task resolving dim_member rows via the fact table.

    Equivalent to:
        SELECT * FROM dim_member
        WHERE member_id IN (
            SELECT DISTINCT tenant_member_id FROM fact_table
            WHERE site_id = :query_site_id
            AND tenant_member_id IS NOT NULL AND tenant_member_id != 0
        ) AND scd2_is_current = 1
    """
    # Step 1: distinct member ids seen at the queried site (NULL/0 excluded).
    member_ids_in_fact: Set[int] = {
        row["tenant_member_id"]
        for row in fact_rows
        if row.get("site_id") == query_site_id
        and row.get("tenant_member_id") is not None
        and row.get("tenant_member_id") != 0
    }

    # Step 2: keep only the current SCD2 version of each matched member.
    return {
        m["member_id"]: dict(m)
        for m in dim_members
        if m.get("scd2_is_current") == 1 and m["member_id"] in member_ids_in_fact
    }
|
||||
|
||||
|
||||
def simulate_register_site_lookup(
    dim_members: List[Dict[str, Any]],
    query_site_id: int,
) -> Dict[int, Dict[str, Any]]:
    """Simulate the legacy `WHERE register_site_id = %s` pattern (for contrast).

    Only members both registered at the queried site and on their current
    SCD2 version are returned, keyed by member_id.
    """
    return {
        m["member_id"]: dict(m)
        for m in dim_members
        if m.get("scd2_is_current") == 1
        and m.get("register_site_id") == query_site_id
    }
|
||||
|
||||
|
||||
# ── 生成策略 ──────────────────────────────────────────────────
|
||||
|
||||
# Site id: drawn from 1-5, i.e. scenarios may span up to five sites
_site_id_st = st.integers(min_value=1, max_value=5)

# Member id: positive integer
_member_id_st = st.integers(min_value=1, max_value=200)

# Nickname (2-4 common CJK surname characters) and CN mobile number
_nickname_st = st.text(
    alphabet=st.sampled_from("张李王赵刘陈杨黄周吴"),
    min_size=2,
    max_size=4,
)
_mobile_st = st.from_regex(r"1[3-9]\d{9}", fullmatch=True)
|
||||
|
||||
|
||||
# Generate one dim_member row
@st.composite
def _dim_member_st(draw):
    """Build a current-version dim_member row with random site/identity fields."""
    mid = draw(_member_id_st)
    return {
        "member_id": mid,
        "register_site_id": draw(_site_id_st),
        "nickname": draw(_nickname_st),
        "mobile": draw(_mobile_st),
        "scd2_is_current": 1,  # only the currently-active SCD2 version is generated
    }
|
||||
|
||||
|
||||
# Generate one fact-table row (a member's consumption record at some site)
@st.composite
def _fact_row_st(draw, member_ids):
    """Generate one fact row; member_id is sampled from the existing members.

    Falls back to a fresh random member id if *member_ids* is empty.
    """
    mid = draw(st.sampled_from(member_ids)) if member_ids else draw(_member_id_st)
    return {
        "site_id": draw(_site_id_st),
        "tenant_member_id": mid,
    }
|
||||
|
||||
|
||||
# Generate a complete cross-store scenario
@st.composite
def _cross_store_scenario_st(draw):
    """Generate a full scenario containing a cross-store consuming member.

    Guarantees at least one member has a consumption record at a site other
    than their registration site (when such a site exists).
    """
    # 2-10 members with unique member_id
    members = draw(st.lists(_dim_member_st(), min_size=2, max_size=10, unique_by=lambda m: m["member_id"]))
    if not members:
        return None  # defensive; unreachable given min_size=2 above

    member_ids = [m["member_id"] for m in members]

    # 5-30 fact rows sampled from the generated members
    fact_rows = draw(st.lists(
        _fact_row_st(member_ids),
        min_size=5,
        max_size=30,
    ))

    # Force at least one cross-store consumption record: pick a member and
    # add a fact row at a site different from their registration site.
    cross_member = draw(st.sampled_from(members))
    other_sites = [s for s in range(1, 6) if s != cross_member["register_site_id"]]
    if other_sites:
        cross_site = draw(st.sampled_from(other_sites))
        fact_rows.append({
            "site_id": cross_site,
            "tenant_member_id": cross_member["member_id"],
        })
    else:
        # No alternative site available — degrade to a same-store scenario.
        cross_site = cross_member["register_site_id"]

    query_site_id = cross_site

    return {
        "members": members,
        "fact_rows": fact_rows,
        "query_site_id": query_site_id,
        "cross_member_id": cross_member["member_id"],
        "cross_member_register_site": cross_member["register_site_id"],
    }
|
||||
|
||||
|
||||
# ── Property 6: 跨店会员可查 ─────────────────────────────────
|
||||
# Feature: etl-aggregation-fix, Property 6: 跨店会员可查
|
||||
# **Validates: Requirements 3.1, 3.2**
|
||||
#
|
||||
# 对于任意在 A 店注册但在 B 店有消费记录的会员,
|
||||
# B 店的 DWS 任务通过事实表反查 dim_member 时,
|
||||
# 应能获取到该会员的维度信息(nickname、mobile 等)。
|
||||
|
||||
|
||||
class TestProperty6CrossStoreMemberVisible:
    """Property 6: cross-store members are reachable via fact-table lookup."""

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=200)
    def test_cross_store_member_found_via_fact_lookup(self, scenario):
        """A member registered at store A but consuming at store B is found by B."""
        if scenario is None:
            return

        members = scenario["members"]
        fact_rows = scenario["fact_rows"]
        query_site_id = scenario["query_site_id"]
        cross_member_id = scenario["cross_member_id"]

        result = simulate_fact_table_lookup(fact_rows, members, query_site_id)

        # Core assertion: the cross-store member must be retrievable.
        assert cross_member_id in result, (
            f"会员 {cross_member_id}(注册于门店 {scenario['cross_member_register_site']})"
            f"在门店 {query_site_id} 有消费记录,但事实表反查未找到"
        )

        # Dimension-attribute completeness.
        member_info = result[cross_member_id]
        assert "nickname" in member_info, "缺少 nickname 维度信息"
        assert "mobile" in member_info, "缺少 mobile 维度信息"

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=200)
    def test_old_register_site_pattern_misses_cross_store(self, scenario):
        """The legacy WHERE register_site_id = %s pattern misses cross-store members."""
        if scenario is None:
            return

        members = scenario["members"]
        query_site_id = scenario["query_site_id"]
        cross_member_id = scenario["cross_member_id"]
        cross_register_site = scenario["cross_member_register_site"]

        # Only meaningful when registration site differs from the queried site.
        if cross_register_site == query_site_id:
            return  # not a cross-store case; skip

        old_result = simulate_register_site_lookup(members, query_site_id)

        # The legacy pattern must NOT find the cross-store member.
        assert cross_member_id not in old_result, (
            f"旧模式 register_site_id={query_site_id} 不应找到"
            f"注册于门店 {cross_register_site} 的会员 {cross_member_id}"
        )

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=200)
    def test_fact_lookup_superset_of_register_lookup(self, scenario):
        """Fact-table lookup is a superset of the register-site filter.

        Any member found by the legacy pattern (and present in the fact
        table) must also be found by the new pattern.
        """
        if scenario is None:
            return

        members = scenario["members"]
        fact_rows = scenario["fact_rows"]
        query_site_id = scenario["query_site_id"]

        new_result = simulate_fact_table_lookup(fact_rows, members, query_site_id)
        old_result = simulate_register_site_lookup(members, query_site_id)

        # Members visible to the legacy pattern that also have fact rows
        # must appear in the new pattern's result.
        fact_member_ids = {
            r["tenant_member_id"]
            for r in fact_rows
            if r["site_id"] == query_site_id
            and r["tenant_member_id"] is not None
            and r["tenant_member_id"] != 0
        }
        for mid in old_result:
            if mid in fact_member_ids:
                assert mid in new_result, (
                    f"会员 {mid} 在旧模式中可查且在事实表中有记录,"
                    f"但新模式未返回"
                )

    @given(
        members=st.lists(_dim_member_st(), min_size=1, max_size=10, unique_by=lambda m: m["member_id"]),
        query_site_id=_site_id_st,
    )
    @settings(max_examples=100)
    def test_no_fact_records_returns_empty(self, members, query_site_id):
        """With no fact rows for the site, the lookup must return nothing."""
        # Empty fact table.
        result = simulate_fact_table_lookup([], members, query_site_id)
        assert len(result) == 0, (
            f"事实表为空时,门店 {query_site_id} 的反查结果应为空,"
            f"实际返回 {len(result)} 条"
        )

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=100)
    def test_null_and_zero_member_ids_excluded(self, scenario):
        """Fact rows with tenant_member_id NULL or 0 must not match anything."""
        if scenario is None:
            return

        members = scenario["members"]
        query_site_id = scenario["query_site_id"]

        # Build a fact table containing only NULL and 0 member ids.
        bad_fact_rows = [
            {"site_id": query_site_id, "tenant_member_id": None},
            {"site_id": query_site_id, "tenant_member_id": 0},
        ]

        result = simulate_fact_table_lookup(bad_fact_rows, members, query_site_id)
        assert len(result) == 0, (
            f"tenant_member_id 为 NULL/0 时不应返回任何会员,"
            f"实际返回 {len(result)} 条"
        )
|
||||
@@ -0,0 +1,270 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ODS member_profiles birthday 字段提取验证。
|
||||
|
||||
任务 3.2: 确认 ODS 入库逻辑能从 JSON payload 中提取 birthday 字段。
|
||||
ODS_MEMBER 任务使用 schema-aware 动态入库(_insert_records_schema_aware),
|
||||
当 ods.member_profiles 表包含 birthday 列时,会自动从 API JSON 中按列名匹配提取。
|
||||
本测试通过 FakeDB 模拟包含 birthday 列的表结构,验证提取行为。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.ods.ods_tasks import ODS_TASK_CLASSES, ODS_TASK_SPECS, BaseOdsTask
|
||||
from .task_test_utils import (
|
||||
create_test_config,
|
||||
FakeDBOperations,
|
||||
FakeAPIClient,
|
||||
FakeCursor,
|
||||
FakeConnection,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 常量
|
||||
# ---------------------------------------------------------------------------
|
||||
_MEMBER_CODE = "ODS_MEMBER"
|
||||
_MEMBER_TABLE = "ods.member_profiles"
|
||||
_MEMBER_ENDPOINT = "/MemberProfile/GetTenantMemberList"
|
||||
|
||||
# member_profiles 表列定义(含 birthday DATE 列 — 迁移 C1 后的状态)
|
||||
_MEMBER_COLUMNS_WITH_BIRTHDAY = [
|
||||
("id", "bigint", "int8"),
|
||||
("record_index", "integer", "int4"),
|
||||
("content_hash", "text", "text"),
|
||||
("payload", "jsonb", "jsonb"),
|
||||
("birthday", "date", "date"),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FakeDB:支持自定义列定义
|
||||
# ---------------------------------------------------------------------------
|
||||
class _MemberFakeCursor(FakeCursor):
|
||||
"""扩展 FakeCursor,支持自定义列定义和 PK 查询。"""
|
||||
|
||||
def __init__(self, recorder, db_ops=None, columns_map=None, pk_map=None):
|
||||
super().__init__(recorder, db_ops)
|
||||
self._columns_map = columns_map or {}
|
||||
self._pk_map = pk_map or {}
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
|
||||
self.recorder.append({"sql": sql_text.strip(), "params": params})
|
||||
self._fetchall_rows = []
|
||||
|
||||
lowered = sql_text.lower()
|
||||
|
||||
# 拦截 information_schema.columns 查询 → 返回自定义列定义
|
||||
if "from information_schema.columns" in lowered:
|
||||
table_name = None
|
||||
if params and len(params) >= 2:
|
||||
table_name = params[1]
|
||||
full_table = f"{params[0]}.{params[1]}" if params and len(params) >= 2 else None
|
||||
if full_table and full_table in self._columns_map:
|
||||
self._fetchall_rows = list(self._columns_map[full_table])
|
||||
elif table_name and table_name in self._columns_map:
|
||||
self._fetchall_rows = list(self._columns_map[table_name])
|
||||
else:
|
||||
self._fetchall_rows = self._fake_columns(table_name)
|
||||
return
|
||||
|
||||
# 拦截 PK 约束查询
|
||||
if "from information_schema.table_constraints" in lowered:
|
||||
table_name = None
|
||||
if params and len(params) >= 2:
|
||||
table_name = params[1]
|
||||
full_table = f"{params[0]}.{params[1]}" if params and len(params) >= 2 else None
|
||||
if full_table and full_table in self._pk_map:
|
||||
self._fetchall_rows = [(col,) for col in self._pk_map[full_table]]
|
||||
elif table_name and table_name in self._pk_map:
|
||||
self._fetchall_rows = [(col,) for col in self._pk_map[table_name]]
|
||||
else:
|
||||
self._fetchall_rows = [("id",)]
|
||||
return
|
||||
|
||||
# 拦截 content_hash 查询
|
||||
if "distinct on" in lowered and "content_hash" in lowered:
|
||||
self._fetchall_rows = []
|
||||
self._pending_rows = []
|
||||
return
|
||||
|
||||
# 默认:处理 INSERT 等语句
|
||||
if self._pending_rows:
|
||||
self.rowcount = len(self._pending_rows)
|
||||
self._record_upserts(sql_text)
|
||||
if "returning" in lowered:
|
||||
self._fetchall_rows = [(True,)] * len(self._pending_rows)
|
||||
self._pending_rows = []
|
||||
else:
|
||||
self.rowcount = 0
|
||||
|
||||
|
||||
class _MemberFakeConnection(FakeConnection):
|
||||
def __init__(self, db_ops, columns_map=None, pk_map=None):
|
||||
super().__init__(db_ops)
|
||||
self._columns_map = columns_map or {}
|
||||
self._pk_map = pk_map or {}
|
||||
|
||||
def cursor(self):
|
||||
return _MemberFakeCursor(
|
||||
self.statements, self._db_ops,
|
||||
columns_map=self._columns_map,
|
||||
pk_map=self._pk_map,
|
||||
)
|
||||
|
||||
|
||||
class _MemberFakeDB(FakeDBOperations):
|
||||
def __init__(self, columns_map=None, pk_map=None):
|
||||
super().__init__()
|
||||
self._columns_map = columns_map or {}
|
||||
self._pk_map = pk_map or {}
|
||||
self.conn = _MemberFakeConnection(
|
||||
self, self._columns_map, self._pk_map,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
def _build_config(tmp_path):
|
||||
archive_dir = tmp_path / "archive"
|
||||
temp_dir = tmp_path / "temp"
|
||||
cfg = create_test_config("ONLINE", archive_dir, temp_dir)
|
||||
cfg.config.setdefault("run", {})
|
||||
cfg.config["run"]["ods_conflict_mode"] = "update"
|
||||
cfg.config["run"]["snapshot_missing_delete"] = False
|
||||
return cfg
|
||||
|
||||
|
||||
def _get_member_spec():
|
||||
for spec in ODS_TASK_SPECS:
|
||||
if spec.code == _MEMBER_CODE:
|
||||
return spec
|
||||
raise KeyError(f"未找到任务 spec: {_MEMBER_CODE}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestOdsBirthdayExtraction:
|
||||
"""验证 ODS_MEMBER 任务在表包含 birthday 列时能正确提取该字段。"""
|
||||
|
||||
def test_birthday_extracted_from_json_payload(self, tmp_path):
|
||||
"""当 ods.member_profiles 包含 birthday 列时,
|
||||
API JSON 中的 birthday 值应被提取并写入对应列位置。
|
||||
"""
|
||||
spec = _get_member_spec()
|
||||
task_cls = ODS_TASK_CLASSES[_MEMBER_CODE]
|
||||
|
||||
config = _build_config(tmp_path)
|
||||
db = _MemberFakeDB(
|
||||
columns_map={_MEMBER_TABLE: _MEMBER_COLUMNS_WITH_BIRTHDAY},
|
||||
pk_map={_MEMBER_TABLE: ["id", "content_hash"]},
|
||||
)
|
||||
# 模拟 API 返回包含 birthday 的会员记录
|
||||
member_record = {
|
||||
"id": 12345,
|
||||
"birthday": "1990-05-15",
|
||||
"nickname": "测试会员",
|
||||
"mobile": "13800138000",
|
||||
}
|
||||
api = FakeAPIClient({spec.endpoint: [member_record]})
|
||||
logger = logging.getLogger("test_birthday")
|
||||
|
||||
task = task_cls(config, db, api, logger)
|
||||
result = task.execute()
|
||||
|
||||
counts = result["counts"]
|
||||
assert counts["fetched"] == 1
|
||||
assert counts["inserted"] == 1
|
||||
|
||||
# 验证写入的 SQL 参数中包含 birthday 值
|
||||
insert_stmts = [
|
||||
s for s in db.conn.statements
|
||||
if "INSERT" in s.get("sql", "").upper() and "member_profiles" in s.get("sql", "")
|
||||
]
|
||||
assert len(insert_stmts) >= 1, "应有至少一条 INSERT 语句"
|
||||
|
||||
# INSERT SQL 应包含 birthday 列
|
||||
assert any("birthday" in s["sql"].lower() for s in insert_stmts), (
|
||||
"INSERT SQL 应包含 birthday 列"
|
||||
)
|
||||
|
||||
def test_birthday_null_when_not_in_json(self, tmp_path):
|
||||
"""当 API JSON 中不包含 birthday 字段时,
|
||||
对应列应写入 None。
|
||||
"""
|
||||
spec = _get_member_spec()
|
||||
task_cls = ODS_TASK_CLASSES[_MEMBER_CODE]
|
||||
|
||||
config = _build_config(tmp_path)
|
||||
db = _MemberFakeDB(
|
||||
columns_map={_MEMBER_TABLE: _MEMBER_COLUMNS_WITH_BIRTHDAY},
|
||||
pk_map={_MEMBER_TABLE: ["id", "content_hash"]},
|
||||
)
|
||||
# API 返回不含 birthday 的记录
|
||||
member_record = {
|
||||
"id": 67890,
|
||||
"nickname": "无生日会员",
|
||||
}
|
||||
api = FakeAPIClient({spec.endpoint: [member_record]})
|
||||
logger = logging.getLogger("test_birthday_null")
|
||||
|
||||
task = task_cls(config, db, api, logger)
|
||||
result = task.execute()
|
||||
|
||||
counts = result["counts"]
|
||||
assert counts["fetched"] == 1
|
||||
assert counts["inserted"] == 1
|
||||
|
||||
def test_birthday_column_in_insert_sql(self, tmp_path):
|
||||
"""INSERT SQL 中应包含 birthday 列名。"""
|
||||
spec = _get_member_spec()
|
||||
task_cls = ODS_TASK_CLASSES[_MEMBER_CODE]
|
||||
|
||||
config = _build_config(tmp_path)
|
||||
db = _MemberFakeDB(
|
||||
columns_map={_MEMBER_TABLE: _MEMBER_COLUMNS_WITH_BIRTHDAY},
|
||||
pk_map={_MEMBER_TABLE: ["id", "content_hash"]},
|
||||
)
|
||||
member_record = {"id": 11111, "birthday": "2000-01-01"}
|
||||
api = FakeAPIClient({spec.endpoint: [member_record]})
|
||||
logger = logging.getLogger("test_birthday_sql")
|
||||
|
||||
task = task_cls(config, db, api, logger)
|
||||
task.execute()
|
||||
|
||||
# 找到 INSERT 语句,确认包含 birthday 列
|
||||
insert_stmts = [
|
||||
s for s in db.conn.statements
|
||||
if "INSERT" in s.get("sql", "").upper() and "member_profiles" in s.get("sql", "")
|
||||
]
|
||||
assert any("birthday" in s["sql"].lower() for s in insert_stmts), (
|
||||
"INSERT SQL 应包含 birthday 列,"
|
||||
f"实际 SQL: {[s['sql'][:200] for s in insert_stmts]}"
|
||||
)
|
||||
|
||||
def test_ods_member_spec_has_no_extra_columns(self):
|
||||
"""确认 ODS_MEMBER 任务没有 extra_columns(依赖 schema-aware 动态映射)。
|
||||
|
||||
这意味着 birthday 字段无需在 OdsTaskSpec 中显式声明,
|
||||
只要 DB 表有该列,就会自动从 JSON 提取。
|
||||
"""
|
||||
spec = _get_member_spec()
|
||||
assert spec.extra_columns == (), (
|
||||
"ODS_MEMBER 不应有 extra_columns,"
|
||||
"它依赖 schema-aware 动态入库自动映射所有 DB 列"
|
||||
)
|
||||
assert spec.table_name == "ods.member_profiles"
|
||||
assert spec.code == "ODS_MEMBER"
|
||||
@@ -126,36 +126,5 @@ def test_ods_settlement_records_ingest(tmp_path):
|
||||
assert '"orderTradeNo": 8001' in row["payload"]
|
||||
|
||||
|
||||
def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
|
||||
"""Ensure settlement tickets are fetched per payment relate_id and skip existing ones."""
|
||||
config = _build_config(tmp_path)
|
||||
ticket_payload = {"data": {"data": {"orderSettleId": 9001, "orderSettleNumber": "T001"}}}
|
||||
api = FakeAPIClient({"/Order/GetOrderSettleTicketNew": [ticket_payload]})
|
||||
task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"]
|
||||
|
||||
with get_db_operations() as db_ops:
|
||||
# 第一次查询:已有的小票ID;第二次查询:支付关联ID
|
||||
db_ops.query_results = [
|
||||
[{"order_settle_id": 9002}],
|
||||
[
|
||||
{"order_settle_id": 9001},
|
||||
{"order_settle_id": 9002},
|
||||
{"order_settle_id": None},
|
||||
],
|
||||
]
|
||||
task = task_cls(config, db_ops, api, logging.getLogger("test_ods_settlement_ticket"))
|
||||
result = task.execute()
|
||||
|
||||
assert result["status"] == "SUCCESS"
|
||||
counts = result["counts"]
|
||||
assert counts["fetched"] == 1
|
||||
assert counts["inserted"] == 1
|
||||
assert counts["updated"] == 0
|
||||
assert counts["skipped"] == 0
|
||||
assert '"orderSettleId": 9001' in db_ops.upserts[0]["rows"][0]["payload"]
|
||||
assert any(
|
||||
call["endpoint"] == "/Order/GetOrderSettleTicketNew"
|
||||
and call.get("params", {}).get("orderSettleId") == 9001
|
||||
for call in api.calls
|
||||
)
|
||||
|
||||
|
||||
@@ -93,27 +93,27 @@ class TestProperty5FlowNameToLayers:
|
||||
"""对于任意有效的 Flow 名称,FlowRunner 解析出的层列表应与
|
||||
FLOW_LAYERS 字典中的定义完全一致。"""
|
||||
|
||||
@given(pipeline=flow_name_st)
|
||||
@given(flow_name=flow_name_st)
|
||||
@settings(max_examples=100)
|
||||
def test_layers_match_flow_definition(self, pipeline):
|
||||
"""run() 返回的 layers 字段与 FLOW_LAYERS[pipeline] 完全一致。"""
|
||||
def test_layers_match_flow_definition(self, flow_name):
|
||||
"""run() 返回的 layers 字段与 FLOW_LAYERS[flow_name] 完全一致。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = []
|
||||
runner = _make_runner(task_executor=executor)
|
||||
|
||||
with patch(_TASK_LOGGER_PATH):
|
||||
result = runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode="increment_only",
|
||||
data_source="offline",
|
||||
)
|
||||
|
||||
expected_layers = FlowRunner.FLOW_LAYERS[pipeline]
|
||||
expected_layers = FlowRunner.FLOW_LAYERS[flow_name]
|
||||
assert result["layers"] == expected_layers
|
||||
|
||||
@given(pipeline=flow_name_st)
|
||||
@given(flow_name=flow_name_st)
|
||||
@settings(max_examples=100)
|
||||
def test_resolve_tasks_called_with_correct_layers(self, pipeline):
|
||||
def test_resolve_tasks_called_with_correct_layers(self, flow_name):
|
||||
"""_resolve_tasks 接收的层列表与 FLOW_LAYERS 定义一致。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = []
|
||||
@@ -124,12 +124,12 @@ class TestProperty5FlowNameToLayers:
|
||||
patch.object(runner, "_resolve_tasks", wraps=runner._resolve_tasks) as spy,
|
||||
):
|
||||
runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode="increment_only",
|
||||
data_source="offline",
|
||||
)
|
||||
|
||||
expected_layers = FlowRunner.FLOW_LAYERS[pipeline]
|
||||
expected_layers = FlowRunner.FLOW_LAYERS[flow_name]
|
||||
spy.assert_called_once_with(expected_layers)
|
||||
|
||||
|
||||
@@ -143,12 +143,12 @@ class TestProperty6ProcessingModeControlsFlow:
|
||||
校验流程执行当且仅当模式包含 verify。"""
|
||||
|
||||
@given(
|
||||
pipeline=flow_name_st,
|
||||
flow_name=flow_name_st,
|
||||
mode=processing_mode_st,
|
||||
data_source=data_source_st,
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_increment_executes_iff_mode_contains_increment(self, pipeline, mode, data_source):
|
||||
def test_increment_executes_iff_mode_contains_increment(self, flow_name, mode, data_source):
|
||||
"""增量 ETL(task_executor.run_tasks)执行当且仅当 mode 包含 'increment'。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = []
|
||||
@@ -159,7 +159,7 @@ class TestProperty6ProcessingModeControlsFlow:
|
||||
patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}),
|
||||
):
|
||||
runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode=mode,
|
||||
data_source=data_source,
|
||||
)
|
||||
@@ -176,12 +176,12 @@ class TestProperty6ProcessingModeControlsFlow:
|
||||
)
|
||||
|
||||
@given(
|
||||
pipeline=flow_name_st,
|
||||
flow_name=flow_name_st,
|
||||
mode=processing_mode_st,
|
||||
data_source=data_source_st,
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_verification_executes_iff_mode_contains_verify(self, pipeline, mode, data_source):
|
||||
def test_verification_executes_iff_mode_contains_verify(self, flow_name, mode, data_source):
|
||||
"""校验流程(_run_verification)执行当且仅当 mode 包含 'verify'。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = []
|
||||
@@ -192,7 +192,7 @@ class TestProperty6ProcessingModeControlsFlow:
|
||||
patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}) as mock_verify,
|
||||
):
|
||||
runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode=mode,
|
||||
data_source=data_source,
|
||||
)
|
||||
@@ -215,38 +215,38 @@ class TestProperty6ProcessingModeControlsFlow:
|
||||
|
||||
class TestProperty7FlowSummaryCompleteness:
|
||||
"""对于任意一组任务执行结果,FlowRunner 返回的汇总字典应包含
|
||||
status/pipeline/layers/results 字段,且 results 长度等于实际执行的任务数。
|
||||
(返回字典中 pipeline 键名保留以兼容下游消费方)"""
|
||||
status/flow/layers/results 字段,且 results 长度等于实际执行的任务数。
|
||||
"""
|
||||
|
||||
@given(
|
||||
pipeline=flow_name_st,
|
||||
flow_name=flow_name_st,
|
||||
task_results=task_results_st,
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_summary_has_required_fields(self, pipeline, task_results):
|
||||
"""返回字典必须包含 status、pipeline、layers、results、verification_summary。"""
|
||||
def test_summary_has_required_fields(self, flow_name, task_results):
|
||||
"""返回字典必须包含 status、flow、layers、results、verification_summary。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = task_results
|
||||
runner = _make_runner(task_executor=executor)
|
||||
|
||||
with patch(_TASK_LOGGER_PATH):
|
||||
result = runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode="increment_only",
|
||||
data_source="offline",
|
||||
)
|
||||
|
||||
required_keys = {"status", "pipeline", "layers", "results", "verification_summary"}
|
||||
required_keys = {"status", "flow", "layers", "results", "verification_summary"}
|
||||
assert required_keys.issubset(result.keys()), (
|
||||
f"缺少必要字段: {required_keys - result.keys()}"
|
||||
)
|
||||
|
||||
@given(
|
||||
pipeline=flow_name_st,
|
||||
flow_name=flow_name_st,
|
||||
task_results=task_results_st,
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_results_length_equals_executed_tasks(self, pipeline, task_results):
|
||||
def test_results_length_equals_executed_tasks(self, flow_name, task_results):
|
||||
"""results 列表长度等于 task_executor.run_tasks 返回的任务数。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = task_results
|
||||
@@ -254,7 +254,7 @@ class TestProperty7FlowSummaryCompleteness:
|
||||
|
||||
with patch(_TASK_LOGGER_PATH):
|
||||
result = runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode="increment_only",
|
||||
data_source="offline",
|
||||
)
|
||||
@@ -264,11 +264,11 @@ class TestProperty7FlowSummaryCompleteness:
|
||||
)
|
||||
|
||||
@given(
|
||||
pipeline=flow_name_st,
|
||||
flow_name=flow_name_st,
|
||||
task_results=task_results_st,
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_flow_and_layers_match_input(self, pipeline, task_results):
|
||||
def test_flow_and_layers_match_input(self, flow_name, task_results):
|
||||
"""返回的 flow 标识和 layers 字段与输入一致。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = task_results
|
||||
@@ -276,20 +276,20 @@ class TestProperty7FlowSummaryCompleteness:
|
||||
|
||||
with patch(_TASK_LOGGER_PATH):
|
||||
result = runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode="increment_only",
|
||||
data_source="offline",
|
||||
)
|
||||
|
||||
assert result["pipeline"] == pipeline
|
||||
assert result["layers"] == FlowRunner.FLOW_LAYERS[pipeline]
|
||||
assert result["flow"] == flow_name
|
||||
assert result["layers"] == FlowRunner.FLOW_LAYERS[flow_name]
|
||||
|
||||
@given(
|
||||
pipeline=flow_name_st,
|
||||
flow_name=flow_name_st,
|
||||
task_results=task_results_st,
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_increment_only_has_no_verification(self, pipeline, task_results):
|
||||
def test_increment_only_has_no_verification(self, flow_name, task_results):
|
||||
"""increment_only 模式下 verification_summary 应为 None。"""
|
||||
executor = MagicMock()
|
||||
executor.run_tasks.return_value = task_results
|
||||
@@ -297,7 +297,7 @@ class TestProperty7FlowSummaryCompleteness:
|
||||
|
||||
with patch(_TASK_LOGGER_PATH):
|
||||
result = runner.run(
|
||||
pipeline=pipeline,
|
||||
flow=flow_name,
|
||||
processing_mode="increment_only",
|
||||
data_source="offline",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,219 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DwdLoadTask 返回值格式属性测试 — 使用 hypothesis 验证返回值的通用正确性属性。
|
||||
|
||||
纯单元测试,使用 MagicMock 构造 DwdLoadTask 实例,不涉及真实数据库连接。
|
||||
|
||||
# Feature: etl-aggregation-fix, Property 1: DwdLoadTask 返回值格式一致性
|
||||
# **Validates: Requirements 1.1**
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
|
||||
|
||||
# ── 辅助:构造最小可用 DwdLoadTask 实例 ─────────────────────
|
||||
|
||||
def _make_task() -> DwdLoadTask:
|
||||
config = MagicMock()
|
||||
config.get = MagicMock(side_effect=lambda key, default=None: default)
|
||||
|
||||
fake_conn = MagicMock()
|
||||
fake_cursor = MagicMock()
|
||||
fake_cursor.__enter__ = MagicMock(return_value=fake_cursor)
|
||||
fake_cursor.__exit__ = MagicMock(return_value=False)
|
||||
fake_conn.cursor.return_value = fake_cursor
|
||||
|
||||
db = MagicMock()
|
||||
db.conn = fake_conn
|
||||
api = MagicMock()
|
||||
logger = logging.getLogger("test_return_format_pbt")
|
||||
return DwdLoadTask(config, db, api, logger)
|
||||
|
||||
|
||||
# ── 生成策略 ──────────────────────────────────────────────────
|
||||
|
||||
# 生成 TABLE_MAP:1~8 张表,每张表随机决定是否失败
|
||||
_table_name_st = st.text(
|
||||
alphabet=st.sampled_from("abcdefghijklmnopqrstuvwxyz_"),
|
||||
min_size=3, max_size=12,
|
||||
)
|
||||
|
||||
# 生成一组 (dwd_table, ods_table, should_fail) 三元组
|
||||
_table_entry_st = st.tuples(
|
||||
_table_name_st, # dwd 表名后缀
|
||||
_table_name_st, # ods 表名后缀
|
||||
st.booleans(), # 是否模拟失败
|
||||
)
|
||||
|
||||
_table_map_st = st.lists(_table_entry_st, min_size=0, max_size=8, unique_by=lambda t: t[0])
|
||||
|
||||
|
||||
# ── Property 1: DwdLoadTask 返回值格式一致性 ─────────────────
|
||||
# Feature: etl-aggregation-fix, Property 1: DwdLoadTask 返回值格式一致性
|
||||
# **Validates: Requirements 1.1**
|
||||
#
|
||||
# 对于任意 DwdLoadTask.load() 的执行结果,返回字典中:
|
||||
# - `errors` 键的值应为 int 类型
|
||||
# - `error_details` 键的值应为 list 类型
|
||||
# - `errors` == len(`error_details`)
|
||||
|
||||
|
||||
class TestProperty1DwdLoadReturnFormat:
|
||||
"""Property 1: DwdLoadTask 返回值格式一致性。"""
|
||||
|
||||
@given(table_entries=_table_map_st)
|
||||
@settings(max_examples=200)
|
||||
def test_errors_is_int_and_equals_len_error_details(self, table_entries):
|
||||
"""对于任意表映射和失败模式,errors 始终为 int 且等于 len(error_details)。"""
|
||||
task = _make_task()
|
||||
|
||||
# 构造 TABLE_MAP 和失败集合
|
||||
table_map = {}
|
||||
fail_set = set()
|
||||
for dwd_suffix, ods_suffix, should_fail in table_entries:
|
||||
dwd_name = f"dwd.dim_{dwd_suffix}"
|
||||
ods_name = f"ods.{ods_suffix}"
|
||||
table_map[dwd_name] = ods_name
|
||||
if should_fail:
|
||||
fail_set.add(dwd_name)
|
||||
|
||||
# _get_columns:失败表抛异常,成功表返回空列(跳过,不计入 summary 也不计入 errors)
|
||||
def fake_get_columns(cur, table):
|
||||
if table in fail_set:
|
||||
raise RuntimeError(f"模拟失败: {table}")
|
||||
return [] # 空列 → 跳过
|
||||
|
||||
task._get_columns = fake_get_columns
|
||||
|
||||
with patch.object(DwdLoadTask, "TABLE_MAP", table_map):
|
||||
ctx = SimpleNamespace(
|
||||
window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
|
||||
)
|
||||
result = task.load({"now": datetime.now(timezone.utc)}, ctx)
|
||||
|
||||
# 核心断言:返回值格式一致性
|
||||
assert isinstance(result["errors"], int), (
|
||||
f"errors 应为 int,实际为 {type(result['errors']).__name__}"
|
||||
)
|
||||
assert isinstance(result["error_details"], list), (
|
||||
f"error_details 应为 list,实际为 {type(result['error_details']).__name__}"
|
||||
)
|
||||
assert result["errors"] == len(result["error_details"]), (
|
||||
f"errors={result['errors']} != len(error_details)={len(result['error_details'])}"
|
||||
)
|
||||
# errors 应等于预期失败数
|
||||
assert result["errors"] == len(fail_set), (
|
||||
f"errors={result['errors']} != 预期失败数={len(fail_set)}"
|
||||
)
|
||||
assert "tables" in result
|
||||
|
||||
|
||||
# ── Property 2: _accumulate_counts 类型安全累加 ─────────────
|
||||
# Feature: etl-aggregation-fix, Property 2: _accumulate_counts 类型安全累加
|
||||
# **Validates: Requirements 1.2**
|
||||
#
|
||||
# 对于任意包含 int、float、list 类型值的计数字典,
|
||||
# _accumulate_counts() 应将 int/float 直接累加,将 list 转为 len() 后累加,
|
||||
# 且不抛出异常。
|
||||
|
||||
from tasks.base_task import BaseTask
|
||||
|
||||
# ── 生成策略:计数字典值 ──────────────────────────────────────
|
||||
|
||||
# 合法计数值:int / float / list(任意元素)
|
||||
_count_value_st = st.one_of(
|
||||
st.integers(min_value=0, max_value=10_000),
|
||||
st.floats(min_value=0.0, max_value=10_000.0, allow_nan=False, allow_infinity=False),
|
||||
st.lists(st.integers(), min_size=0, max_size=50),
|
||||
)
|
||||
|
||||
# 计数字典:1~10 个键,值为 int/float/list
|
||||
_count_key_st = st.text(
|
||||
alphabet=st.sampled_from("abcdefghijklmnopqrstuvwxyz_"),
|
||||
min_size=2, max_size=10,
|
||||
)
|
||||
|
||||
_count_dict_st = st.dictionaries(
|
||||
keys=_count_key_st,
|
||||
values=_count_value_st,
|
||||
min_size=0,
|
||||
max_size=10,
|
||||
)
|
||||
|
||||
|
||||
def _numeric_value(v) -> int | float:
|
||||
"""将 int/float/list 统一转为数值(list → len)。"""
|
||||
if isinstance(v, (int, float)):
|
||||
return v
|
||||
if isinstance(v, list):
|
||||
return len(v)
|
||||
return 0
|
||||
|
||||
|
||||
class TestProperty2AccumulateCountsTypeSafe:
|
||||
"""Property 2: _accumulate_counts 类型安全累加。"""
|
||||
|
||||
@given(dicts=st.lists(_count_dict_st, min_size=1, max_size=5))
|
||||
@settings(max_examples=200)
|
||||
def test_accumulate_never_raises_and_sums_correctly(self, dicts):
|
||||
"""对于任意 int/float/list 值的计数字典序列,累加不抛异常且数值正确。"""
|
||||
total: dict = {}
|
||||
for d in dicts:
|
||||
BaseTask._accumulate_counts(total, d)
|
||||
|
||||
# 手动计算期望值
|
||||
expected: dict = {}
|
||||
for d in dicts:
|
||||
for k, v in d.items():
|
||||
nv = _numeric_value(v)
|
||||
if isinstance(nv, (int, float)):
|
||||
expected[k] = (expected.get(k) or 0) + nv
|
||||
else:
|
||||
expected.setdefault(k, v)
|
||||
|
||||
# 断言:所有可累加键的值应与手动计算一致
|
||||
for k, ev in expected.items():
|
||||
assert k in total, f"键 '{k}' 应存在于累加结果中"
|
||||
if isinstance(ev, (int, float)):
|
||||
assert abs(total[k] - ev) < 1e-9, (
|
||||
f"键 '{k}': 累加结果={total[k]}, 期望={ev}"
|
||||
)
|
||||
|
||||
@given(current=_count_dict_st)
|
||||
@settings(max_examples=200)
|
||||
def test_single_dict_values_match_numeric_conversion(self, current):
|
||||
"""单次累加:int/float 保持原值,list 转为 len()。"""
|
||||
total: dict = {}
|
||||
BaseTask._accumulate_counts(total, current)
|
||||
|
||||
for k, v in current.items():
|
||||
if isinstance(v, (int, float)):
|
||||
assert abs(total[k] - v) < 1e-9, (
|
||||
f"键 '{k}': int/float 应直接累加,结果={total[k]}, 原值={v}"
|
||||
)
|
||||
elif isinstance(v, list):
|
||||
assert total[k] == len(v), (
|
||||
f"键 '{k}': list 应转为 len()={len(v)}, 实际={total[k]}"
|
||||
)
|
||||
|
||||
@given(current=_count_dict_st)
|
||||
@settings(max_examples=100)
|
||||
def test_none_total_key_initializes_correctly(self, current):
|
||||
"""当 total 中键不存在时(get 返回 None),or 0 应正确初始化。"""
|
||||
total: dict = {}
|
||||
result = BaseTask._accumulate_counts(total, current)
|
||||
# 返回值应与 total 是同一对象
|
||||
assert result is total
|
||||
# 所有数值键应 >= 0
|
||||
for k, v in result.items():
|
||||
if isinstance(v, (int, float)):
|
||||
assert v >= 0, f"键 '{k}' 的值不应为负: {v}"
|
||||
129
apps/etl/connectors/feiqiu/tests/unit/test_staff_info.py
Normal file
129
apps/etl/connectors/feiqiu/tests/unit/test_staff_info.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""员工维度表(staff_info)单元测试:验证 P1、P2 正确性属性。"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
|
||||
|
||||
from tasks.ods.ods_tasks import (
|
||||
ODS_TASK_SPECS,
|
||||
ODS_TASK_CLASSES,
|
||||
ENABLED_ODS_CODES,
|
||||
SnapshotMode,
|
||||
)
|
||||
from api.client import DEFAULT_LIST_KEYS
|
||||
from tasks.dwd.dwd_load_task import DwdLoadTask
|
||||
from .task_test_utils import create_test_config, get_db_operations, FakeAPIClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P1:ODS 任务规格完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_staff_spec():
|
||||
"""从 ODS_TASK_SPECS 中提取 ODS_STAFF_INFO 规格。"""
|
||||
for spec in ODS_TASK_SPECS:
|
||||
if spec.code == "ODS_STAFF_INFO":
|
||||
return spec
|
||||
raise AssertionError("ODS_STAFF_INFO 未在 ODS_TASK_SPECS 中注册")
|
||||
|
||||
|
||||
class TestP1OdsStaffInfoSpec:
|
||||
"""P1:ODS 任务规格完整性验证。"""
|
||||
|
||||
def test_code(self):
|
||||
assert _get_staff_spec().code == "ODS_STAFF_INFO"
|
||||
|
||||
def test_table_name(self):
|
||||
assert _get_staff_spec().table_name == "ods.staff_info_master"
|
||||
|
||||
def test_endpoint(self):
|
||||
assert _get_staff_spec().endpoint == "/PersonnelManagement/SearchSystemStaffInfo"
|
||||
|
||||
def test_list_key(self):
|
||||
assert _get_staff_spec().list_key == "staffProfiles"
|
||||
|
||||
def test_snapshot_mode(self):
|
||||
assert _get_staff_spec().snapshot_mode == SnapshotMode.FULL_TABLE
|
||||
|
||||
def test_requires_window_false(self):
|
||||
assert _get_staff_spec().requires_window is False
|
||||
|
||||
def test_time_fields_none(self):
|
||||
assert _get_staff_spec().time_fields is None
|
||||
|
||||
def test_staff_profiles_in_default_list_keys(self):
|
||||
assert "staffProfiles" in DEFAULT_LIST_KEYS
|
||||
|
||||
def test_ods_staff_info_in_enabled_codes(self):
|
||||
assert "ODS_STAFF_INFO" in ENABLED_ODS_CODES
|
||||
|
||||
def test_task_class_registered(self):
|
||||
assert "ODS_STAFF_INFO" in ODS_TASK_CLASSES
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P2:DWD 映射完整性
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestP2DwdStaffMapping:
|
||||
"""P2:DWD 映射完整性验证。"""
|
||||
|
||||
def test_dim_staff_table_map(self):
|
||||
assert DwdLoadTask.TABLE_MAP["dwd.dim_staff"] == "ods.staff_info_master"
|
||||
|
||||
def test_dim_staff_ex_table_map(self):
|
||||
assert DwdLoadTask.TABLE_MAP["dwd.dim_staff_ex"] == "ods.staff_info_master"
|
||||
|
||||
def test_dim_staff_has_staff_id_mapping(self):
|
||||
mappings = DwdLoadTask.FACT_MAPPINGS["dwd.dim_staff"]
|
||||
pk_map = [(dst, src) for dst, src, _ in mappings if dst == "staff_id"]
|
||||
assert len(pk_map) == 1
|
||||
assert pk_map[0][1] == "id"
|
||||
|
||||
def test_dim_staff_ex_has_staff_id_mapping(self):
|
||||
mappings = DwdLoadTask.FACT_MAPPINGS["dwd.dim_staff_ex"]
|
||||
pk_map = [(dst, src) for dst, src, _ in mappings if dst == "staff_id"]
|
||||
assert len(pk_map) == 1
|
||||
assert pk_map[0][1] == "id"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ODS 落地功能测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_staff_info_ingest(tmp_path):
|
||||
"""验证 ODS_STAFF_INFO 任务能正确落地员工数据。"""
|
||||
config = create_test_config("ONLINE", tmp_path / "archive", tmp_path / "temp")
|
||||
sample = [
|
||||
{
|
||||
"id": 3020236636900101,
|
||||
"staff_name": "葛芃",
|
||||
"mobile": "13811638071",
|
||||
"job": "店长",
|
||||
"staff_identity": 2,
|
||||
"status": 1,
|
||||
"leave_status": 0,
|
||||
"site_id": 2790685415443269,
|
||||
"tenant_id": 2790683160709957,
|
||||
}
|
||||
]
|
||||
api = FakeAPIClient({"/PersonnelManagement/SearchSystemStaffInfo": sample})
|
||||
task_cls = ODS_TASK_CLASSES["ODS_STAFF_INFO"]
|
||||
|
||||
with get_db_operations() as db_ops:
|
||||
task = task_cls(config, db_ops, api, logging.getLogger("test_staff_info"))
|
||||
result = task.execute()
|
||||
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert result["counts"]["fetched"] == 1
|
||||
row = db_ops.upserts[0]["rows"][0]
|
||||
assert row["id"] == 3020236636900101
|
||||
assert row["record_index"] == 0
|
||||
assert '"staff_name": "葛芃"' in row["payload"] or '"staff_name": "\\u845b\\u82c3"' in row["payload"]
|
||||
209
apps/etl/connectors/feiqiu/tests/unit/test_timer.py
Normal file
209
apps/etl/connectors/feiqiu/tests/unit/test_timer.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""EtlTimer 单元测试"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pytest
|
||||
|
||||
from utils.timer import EtlTimer, StepRecord, _fmt_ms
|
||||
|
||||
|
||||
# ── _fmt_ms 格式化 ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFmtMs:
    """Formatting behaviour of the ``_fmt_ms`` duration helper."""

    def test_sub_second(self):
        """Durations under one second render as milliseconds."""
        assert _fmt_ms(123.4) == "123.4ms"

    def test_seconds(self):
        """Durations of a few seconds render with two decimals and an ``s`` suffix."""
        assert _fmt_ms(2500) == "2.50s"

    def test_minutes(self):
        """90 seconds is expressed in minutes."""
        formatted = _fmt_ms(90_000)
        assert formatted.startswith("1m")

    def test_hours(self):
        """~61.7 minutes is expressed in hours."""
        formatted = _fmt_ms(3_700_000)
        assert formatted.startswith("1h")
|
||||
|
||||
# ── StepRecord ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStepRecord:
    """Unit behaviour of the StepRecord data holder."""

    def test_elapsed_seconds(self):
        """elapsed_seconds derives from elapsed_ms by a factor of 1000."""
        record = StepRecord(name="test", start_time=None, elapsed_ms=1500.0)
        assert record.elapsed_seconds == 1.5

    def test_to_dict_without_end(self):
        """An unfinished record serialises with a None end_time and no children."""
        from datetime import datetime

        started = datetime.now(ZoneInfo("Asia/Shanghai"))
        record = StepRecord(name="s1", start_time=started, elapsed_ms=100.0)
        payload = record.to_dict()
        assert payload["name"] == "s1"
        assert payload["end_time"] is None
        assert payload["elapsed_ms"] == 100.0
        assert payload["children"] == []

    def test_to_dict_with_children(self):
        """Nested child records appear under the parent's ``children`` key."""
        from datetime import datetime

        shanghai = ZoneInfo("Asia/Shanghai")
        moment = datetime.now(shanghai)
        parent = StepRecord(name="p", start_time=moment, elapsed_ms=200.0)
        parent.children.append(
            StepRecord(name="c", start_time=moment, end_time=moment, elapsed_ms=50.0)
        )
        payload = parent.to_dict()
        assert len(payload["children"]) == 1
        assert payload["children"][0]["name"] == "c"
|
||||
|
||||
# ── EtlTimer 核心流程 ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEtlTimer:
    """Core EtlTimer workflow: steps, sub-steps, report generation and persistence."""

    def test_start_stop_step(self):
        """Stopping a started step yields a record with a positive duration."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("STEP_A")
        time.sleep(0.01)  # guarantee a measurable duration
        record = clock.stop_step("STEP_A")

        assert record.name == "STEP_A"
        assert record.end_time is not None
        assert record.elapsed_ms > 0

    def test_sub_steps(self):
        """Sub-steps attach to their parent in start order."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("PARENT")
        clock.start_sub_step("PARENT", "child_1")
        time.sleep(0.01)
        clock.stop_sub_step("PARENT", "child_1")
        clock.start_sub_step("PARENT", "child_2")
        clock.stop_sub_step("PARENT", "child_2")
        clock.stop_step("PARENT")

        parent_rec = clock.get_step("PARENT")
        assert parent_rec is not None
        assert len(parent_rec.children) == 2
        first_child = parent_rec.children[0]
        assert first_child.name == "child_1"
        assert first_child.elapsed_ms > 0

    def test_stop_unknown_step_raises(self):
        """Stopping a never-started step raises KeyError."""
        clock = EtlTimer()
        with pytest.raises(KeyError, match="未找到步骤"):
            clock.stop_step("NONEXISTENT")

    def test_start_sub_step_unknown_parent_raises(self):
        """A sub-step cannot be attached to a missing parent."""
        clock = EtlTimer()
        with pytest.raises(KeyError, match="未找到父步骤"):
            clock.start_sub_step("NONEXISTENT", "child")

    def test_stop_sub_step_unknown_raises(self):
        """Stopping an unknown sub-step under a valid parent raises KeyError."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("P")
        with pytest.raises(KeyError, match="未找到子步骤"):
            clock.stop_sub_step("P", "no_such_child")

    def test_multiple_steps(self):
        """Steps are recorded in the order they were started."""
        clock = EtlTimer()
        clock.start()
        for step_name in ("ODS_LOAD", "DWD_LOAD", "DWS_AGG"):
            clock.start_step(step_name)
            clock.stop_step(step_name)

        assert len(clock.steps) == 3
        assert clock.steps[0].name == "ODS_LOAD"
        assert clock.steps[2].name == "DWS_AGG"

    def test_to_dict(self):
        """to_dict exposes overall timing plus the per-step list."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("S1")
        clock.stop_step("S1")
        clock.finish(write_report=False)

        snapshot = clock.to_dict()
        assert snapshot["overall_start"] is not None
        assert snapshot["overall_end"] is not None
        assert snapshot["overall_elapsed_ms"] >= 0
        assert len(snapshot["steps"]) == 1

    def test_overall_elapsed(self):
        """Overall elapsed time covers the full start-to-finish span."""
        clock = EtlTimer()
        clock.start()
        time.sleep(0.02)
        clock.finish(write_report=False)
        assert clock.overall_elapsed_ms >= 15  # at least 15ms, leaving scheduling slack

    def test_finish_returns_markdown(self):
        """finish() returns a markdown report naming every step."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("TEST_STEP")
        clock.stop_step("TEST_STEP")
        markdown = clock.finish(write_report=False)

        assert "# ETL 执行计时报告" in markdown
        assert "TEST_STEP" in markdown
        assert "步骤汇总" in markdown

    def test_markdown_contains_sub_steps(self):
        """Sub-steps show up in the detailed section of the report."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("MAIN")
        clock.start_sub_step("MAIN", "sub_a")
        clock.stop_sub_step("MAIN", "sub_a")
        clock.stop_step("MAIN")
        markdown = clock.finish(write_report=False)

        assert "步骤详情" in markdown
        assert "sub_a" in markdown

    def test_write_report_requires_env(self):
        """Persisting a report without ETL_REPORT_ROOT set must raise KeyError."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("X")
        clock.stop_step("X")

        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(KeyError, match="ETL_REPORT_ROOT"):
                clock.finish(write_report=True)

    def test_write_report_creates_file(self, tmp_path: Path):
        """With ETL_REPORT_ROOT set, finish() writes exactly one markdown file."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("Y")
        clock.stop_step("Y")

        with patch.dict(os.environ, {"ETL_REPORT_ROOT": str(tmp_path)}):
            clock.finish(write_report=True)

        reports = list(tmp_path.glob("etl_timing_*.md"))
        assert len(reports) == 1
        written = reports[0].read_text(encoding="utf-8")
        assert "# ETL 执行计时报告" in written
        assert "Y" in written

    def test_elapsed_equals_end_minus_start(self):
        """Property 7 core check: elapsed_ms agrees with end_time - start_time."""
        clock = EtlTimer()
        clock.start()
        clock.start_step("VERIFY")
        time.sleep(0.05)
        record = clock.stop_step("VERIFY")

        # Milliseconds measured via datetime subtraction vs. the perf_counter
        # based elapsed_ms; allow ±50ms for system scheduling jitter.
        wall_ms = (record.end_time - record.start_time).total_seconds() * 1000
        assert abs(record.elapsed_ms - wall_ms) < 50
|
||||
Reference in New Issue
Block a user