在前后端开发联调前的提交 2026-02-23

This commit is contained in:
Neo
2026-02-23 23:02:20 +08:00
parent 254ccb1e77
commit fafc95e64c
1142 changed files with 10366960 additions and 36957 deletions

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
"""BaseTask._accumulate_counts() 防御层单元测试。
验证需求 1.2:list 类型值转为 len() 后累加。
纯单元测试,不涉及数据库或外部依赖。
"""
from __future__ import annotations
from tasks.base_task import BaseTask
class TestAccumulateCountsListDefense:
    """Defensive handling of list-typed values in BaseTask._accumulate_counts."""

    def test_list_value_converted_to_len(self):
        """A list value is converted via len() before being accumulated."""
        merged = BaseTask._accumulate_counts(
            {}, {"errors": [{"table": "dim_a", "error": "boom"}]}
        )
        assert merged["errors"] == 1

    def test_list_value_accumulates_across_calls(self):
        """len() of each list accumulates correctly across multiple calls."""
        running = {}
        for batch in (
            {"errors": [{"t": "a"}]},
            {"errors": [{"t": "b"}, {"t": "c"}]},
        ):
            BaseTask._accumulate_counts(running, batch)
        assert running["errors"] == 3

    def test_list_and_int_mixed_accumulation(self):
        """Accumulating an int first, then a list, yields the combined total."""
        running = {}
        BaseTask._accumulate_counts(running, {"errors": 2})
        BaseTask._accumulate_counts(running, {"errors": [{"t": "x"}]})
        assert running["errors"] == 3

    def test_empty_list_adds_zero(self):
        """An empty list contributes zero to the running total."""
        running = {"errors": 5}
        BaseTask._accumulate_counts(running, {"errors": []})
        assert running["errors"] == 5

    def test_int_float_still_work(self):
        """Existing int/float accumulation behaviour is unchanged."""
        running = {}
        for part in ({"inserted": 10, "rate": 0.5}, {"inserted": 3, "rate": 0.2}):
            BaseTask._accumulate_counts(running, part)
        assert running["inserted"] == 13
        assert abs(running["rate"] - 0.7) < 1e-9

    def test_non_numeric_non_list_uses_setdefault(self):
        """Non-numeric, non-list values keep their first-seen value (setdefault)."""
        running = {}
        BaseTask._accumulate_counts(running, {"status": "ok"})
        BaseTask._accumulate_counts(running, {"status": "fail"})
        assert running["status"] == "ok"

    def test_none_current_is_safe(self):
        """A None current argument must not raise."""
        assert BaseTask._accumulate_counts({"x": 1}, None) == {"x": 1}

    def test_empty_current_is_safe(self):
        """An empty current dict must not raise."""
        assert BaseTask._accumulate_counts({"x": 1}, {}) == {"x": 1}

View File

@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
"""属性测试废除聚合逻辑正确性is_trash 驱动)
纯单元测试,通过 stub 隔离外部依赖DB/API/Config
验证 _aggregate_by_assistant_date 中 is_trash 分流逻辑的正确性。
# Feature: assistant-abolish-cleanup, Property 1: 废除聚合逻辑正确性
# **Validates: Requirements 3.4, 9.3**
"""
from __future__ import annotations
from datetime import date
from decimal import Decimal
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock
from hypothesis import given, settings
from hypothesis import strategies as st
from tasks.dws.assistant_daily_task import AssistantDailyTask
# ── Stub 构造 ──────────────────────────────────────────────────────────
def _make_stub_task() -> AssistantDailyTask:
    """Build a minimal AssistantDailyTask with every external dependency stubbed.

    config.get must distinguish keys: the timezone key returns a string,
    the tenant/store keys return integers, anything else falls back to the
    caller-supplied default.
    """
    known = {
        "app.timezone": "Asia/Shanghai",
        "app.tenant_id": 1,
        "app.store_id": 1,
    }
    config = MagicMock()
    config.get.side_effect = lambda key, default=None: known.get(key, default)

    task = AssistantDailyTask(config, MagicMock(), MagicMock(), MagicMock())

    # Stub the SCD2 level lookup — always return one fixed level.
    task.get_assistant_level_asof = MagicMock(
        return_value={"level_code": 10, "level_name": "初级"}
    )

    # Stub the course-type lookup — every skill resolves to BASE.
    from tasks.dws.base_dws_task import CourseType, ConfigCache
    task.get_course_type = MagicMock(return_value=CourseType.BASE)

    # Stub load_config_cache (get_course_type is already mocked; this only
    # prevents a real DB round-trip).
    task.load_config_cache = MagicMock(return_value=ConfigCache(
        performance_tiers=[],
        level_prices=[],
        bonus_rules=[],
        area_categories={},
        skill_types={},
        loaded_at=None,
    ))
    return task
# ── Hypothesis strategies ──────────────────────────────────────────────
# Fixed assistant id and service date so that every generated record
# aggregates into the same single output row.
_FIXED_ASSISTANT_ID = 10001
_FIXED_SERVICE_DATE = date(2025, 6, 15)
_FIXED_SITE_ID = 1
# Strategy for one service record.
_service_record_st = st.fixed_dictionaries({
    "assistant_id": st.just(_FIXED_ASSISTANT_ID),
    "service_date": st.just(_FIXED_SERVICE_DATE),
    "assistant_nickname": st.just("测试助教"),
    "assistant_level": st.just(10),
    "assistant_service_id": st.integers(min_value=1, max_value=999999),
    "skill_id": st.just(100),
    "member_id": st.integers(min_value=1, max_value=9999),
    "table_id": st.integers(min_value=1, max_value=50),
    "income_seconds": st.integers(min_value=0, max_value=36000),
    "ledger_amount": st.integers(min_value=0, max_value=10000),
    "is_trash": st.integers(min_value=0, max_value=1),
})
# Strategy for a list of 1–30 service records.
_service_records_st = st.lists(_service_record_st, min_size=1, max_size=30)
# ── 属性测试 ───────────────────────────────────────────────────────────
class TestProperty1TrashAggregation:
    """Property 1: trash-driven aggregation correctness (is_trash flag).

    **Validates: Requirements 3.4, 9.3**
    """

    @staticmethod
    def _aggregate(records):
        """Aggregate *records* on a fresh stub task; return (task, rows)."""
        task = _make_stub_task()
        return task, task._aggregate_by_assistant_date(records, _FIXED_SITE_ID)

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_trashed_seconds_equals_sum_of_trash_income(self, records: List[Dict[str, Any]]):
        """trashed_seconds equals the income_seconds sum over is_trash=1 records."""
        task, rows = self._aggregate(records)
        expected = sum(
            task.safe_int(r.get("income_seconds", 0))
            for r in records
            if bool(r.get("is_trash", 0))
        )
        assert len(rows) == 1
        assert rows[0]["trashed_seconds"] == expected

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_trashed_count_equals_trash_record_count(self, records: List[Dict[str, Any]]):
        """trashed_count equals the number of is_trash=1 records."""
        task, rows = self._aggregate(records)
        expected = sum(1 for r in records if bool(r.get("is_trash", 0)))
        assert len(rows) == 1
        assert rows[0]["trashed_count"] == expected

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_total_service_count_equals_non_trash_count(self, records: List[Dict[str, Any]]):
        """total_service_count equals the number of is_trash=0 records."""
        task, rows = self._aggregate(records)
        expected = sum(1 for r in records if not bool(r.get("is_trash", 0)))
        assert len(rows) == 1
        assert rows[0]["total_service_count"] == expected

    @given(records=_service_records_st)
    @settings(max_examples=200)
    def test_total_seconds_equals_sum_of_non_trash_income(self, records: List[Dict[str, Any]]):
        """total_seconds equals the income_seconds sum over is_trash=0 records."""
        task, rows = self._aggregate(records)
        expected = sum(
            task.safe_int(r.get("income_seconds", 0))
            for r in records
            if not bool(r.get("is_trash", 0))
        )
        assert len(rows) == 1
        assert rows[0]["total_seconds"] == expected

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,7 @@
"""CLI 参数解析单元测试
验证 --data-source 新参数、--pipeline-flow 弃用映射、
--pipeline + --tasks 同时使用、以及 build_cli_overrides 集成行为。
以及 build_cli_overrides 集成行为。
需求: 3.1, 3.3, 3.5
"""
@@ -80,7 +80,7 @@ class TestBuildCliOverrides:
"""构造最小 Namespace未指定的参数设为 None/False"""
defaults = dict(
store_id=None, tasks=None, dry_run=False,
flow=None, pipeline_deprecated=None,
flow=None,
processing_mode="increment_only",
fetch_before_verify=False, verify_tables=None,
window_split="none", lookback_hours=24, overlap_seconds=3600,
@@ -96,6 +96,7 @@ class TestBuildCliOverrides:
data_source=None, pipeline_flow=None,
fetch_root=None, ingest_source=None, write_pretty_json=False,
idle_start=None, idle_end=None, allow_empty_advance=False,
force_full=False,
)
defaults.update(kwargs)
return Namespace(**defaults)
@@ -121,19 +122,3 @@ class TestBuildCliOverrides:
assert overrides["run"]["data_source"] == "hybrid"
# ---------------------------------------------------------------------------
# 4. --pipeline + --tasks 同时使用
# ---------------------------------------------------------------------------
class TestPipelineAndTasks:
    """Behaviour when --pipeline and --tasks are supplied together."""
    def test_pipeline_and_tasks_both_parsed(self):
        """--pipeline (deprecated alias) and --tasks can be parsed side by side."""
        # Patch argv so parse_args() sees both flags in one invocation.
        with patch("sys.argv", [
            "cli",
            "--pipeline", "api_full",
            "--tasks", "ODS_MEMBER,ODS_ORDER",
        ]):
            args = parse_args()
            assert args.pipeline_deprecated == "api_full"
            assert args.tasks == "ODS_MEMBER,ODS_ORDER"

View File

@@ -0,0 +1,353 @@
# -*- coding: utf-8 -*-
"""
数据一致性检查器单元测试
测试 quality/consistency_checker.py 的核心纯函数逻辑,
不依赖数据库连接。
"""
import json
import pytest
from pathlib import Path
from quality.consistency_checker import (
FieldCheckResult,
TableCheckResult,
ConsistencyReport,
check_api_vs_ods_fields,
check_ods_vs_dwd_mappings,
extract_api_fields_from_json,
generate_markdown_report,
_validate_ods_expression,
_extract_records,
)
# ---------------------------------------------------------------------------
# extract_api_fields_from_json
# ---------------------------------------------------------------------------
class TestExtractApiFields:
    """Field extraction from API JSON dumps."""

    @staticmethod
    def _dump(tmp_path: Path, payload, name: str = "test.json") -> Path:
        """Serialize *payload* as JSON under tmp_path and return the file path."""
        target = tmp_path / name
        target.write_text(json.dumps(payload), encoding="utf-8")
        return target

    def test_flat_list(self, tmp_path: Path):
        """A bare list of records."""
        path = self._dump(tmp_path, [{"id": 1, "name": "test", "amount": 100}])
        assert extract_api_fields_from_json(path) == {"id", "name", "amount"}

    def test_nested_data_key(self, tmp_path: Path):
        """Records wrapped in a {"data": [...]} envelope."""
        path = self._dump(tmp_path, {"data": [{"id": 1, "foo": "bar"}]})
        assert extract_api_fields_from_json(path) == {"id", "foo"}

    def test_nested_data_with_list_key(self, tmp_path: Path):
        """Records wrapped in a {"data": {"settleList": [...]}} envelope."""
        path = self._dump(tmp_path, {"data": {"settleList": [{"id": 1, "amount": 50}]}})
        assert extract_api_fields_from_json(path) == {"id", "amount"}

    def test_nonexistent_file(self, tmp_path: Path):
        """A missing file yields None."""
        assert extract_api_fields_from_json(tmp_path / "nonexistent.json") is None

    def test_invalid_json(self, tmp_path: Path):
        """Unparseable JSON yields None."""
        bad = tmp_path / "bad.json"
        bad.write_text("not json", encoding="utf-8")
        assert extract_api_fields_from_json(bad) is None

    def test_empty_list(self, tmp_path: Path):
        """An empty top-level list yields None."""
        empty = tmp_path / "empty.json"
        empty.write_text("[]", encoding="utf-8")
        assert extract_api_fields_from_json(empty) is None

    def test_merges_multiple_records(self, tmp_path: Path):
        """Fields from several records are unioned."""
        path = self._dump(tmp_path, [{"id": 1, "a": 1}, {"id": 2, "b": 2}])
        assert extract_api_fields_from_json(path) == {"id", "a", "b"}
# ---------------------------------------------------------------------------
# check_api_vs_ods_fields
# ---------------------------------------------------------------------------
class TestCheckApiVsOds:
    """API-vs-ODS field completeness checks."""

    def test_all_fields_present(self):
        """Every API field exists on the ODS side."""
        outcome = check_api_vs_ods_fields(
            {"id", "name", "amount"},
            {"id", "name", "amount", "fetched_at", "content_hash"},
        )
        assert outcome.passed is True
        assert outcome.missing_fields == 0
        assert outcome.passed_fields == 3

    def test_missing_fields(self):
        """Some API fields are absent from the ODS table."""
        outcome = check_api_vs_ods_fields(
            {"id", "name", "extra_field"},
            {"id", "name", "fetched_at"},
        )
        assert outcome.passed is False
        assert outcome.missing_fields == 1
        gaps = [f for f in outcome.field_results if f.status == "missing"]
        assert len(gaps) == 1
        assert gaps[0].field_name == "extra_field"

    def test_ods_meta_columns_excluded(self):
        """ODS metadata columns are left out of the comparison."""
        outcome = check_api_vs_ods_fields(
            {"id"},
            {"id", "payload", "fetched_at", "content_hash", "source_file"},
        )
        assert outcome.passed is True
        assert outcome.total_fields == 1

    def test_case_insensitive(self):
        """Field matching ignores character case."""
        outcome = check_api_vs_ods_fields({"Id", "Name"}, {"id", "name", "fetched_at"})
        assert outcome.passed is True

    def test_empty_api_fields(self):
        """An empty API field set trivially passes."""
        outcome = check_api_vs_ods_fields(set(), {"id", "fetched_at"})
        assert outcome.passed is True
        assert outcome.total_fields == 0
# ---------------------------------------------------------------------------
# check_ods_vs_dwd_mappings
# ---------------------------------------------------------------------------
class TestCheckOdsVsDwd:
    """ODS-vs-DWD mapping correctness checks."""

    @staticmethod
    def _run(dwd_cols, ods_cols, mappings, dwd="dwd.test", ods="ods.test"):
        """Invoke the checker; table names default to the generic test pair."""
        return check_ods_vs_dwd_mappings(dwd, ods, dwd_cols, ods_cols, mappings)

    def test_all_mapped_explicitly(self):
        """Every DWD column carries an explicit mapping."""
        outcome = self._run(
            {"order_id", "amount", "scd2_is_current", "scd2_version"},
            {"id", "total_amount"},
            [("order_id", "id", None), ("amount", "total_amount", None)],
            dwd="dwd.test_table",
            ods="ods.test_table",
        )
        assert outcome.passed is True
        # SCD2 columns never take part in the check.
        assert outcome.total_fields == 2

    def test_auto_mapping(self):
        """Same-named columns map automatically."""
        outcome = self._run({"id", "name", "amount"}, {"id", "name", "amount"}, None)
        assert outcome.passed is True
        assert outcome.passed_fields == 3

    def test_missing_mapping(self):
        """A DWD column without any mapping source fails the check."""
        outcome = self._run({"id", "orphan_col"}, {"id"}, None)
        assert outcome.passed is False
        assert outcome.missing_fields == 1
        gaps = [f for f in outcome.field_results if f.status == "missing"]
        assert gaps[0].field_name == "orphan_col"

    def test_json_expression_mapping(self):
        """A JSON-path expression mapping passes."""
        outcome = self._run(
            {"shop_name"},
            {"siteprofile"},
            [("shop_name", "siteprofile->>'shop_name'", None)],
        )
        assert outcome.passed is True

    def test_sql_expression_mapping(self):
        """A SQL expression mapping (e.g. CASE WHEN) passes."""
        outcome = self._run(
            {"is_leaf"},
            {"categoryboxes"},
            [
                ("is_leaf",
                 "CASE WHEN categoryboxes IS NULL OR jsonb_array_length(categoryboxes)=0 THEN 1 ELSE 0 END",
                 None),
            ],
        )
        assert outcome.passed is True

    def test_scd2_columns_excluded(self):
        """SCD2 bookkeeping columns are excluded from the check."""
        outcome = self._run(
            {"id", "scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"},
            {"id"},
            None,
        )
        assert outcome.total_fields == 1
        assert outcome.passed is True

    def test_invalid_ods_reference(self):
        """An explicit mapping pointing at a nonexistent ODS column is a mismatch."""
        outcome = self._run({"bad_col"}, {"id"}, [("bad_col", "nonexistent_col", None)])
        assert outcome.passed is False
        assert outcome.mismatch_fields == 1

    def test_null_expression(self):
        """A NULL literal mapping passes."""
        outcome = self._run({"null_col"}, {"id"}, [("null_col", "NULL", None)])
        assert outcome.passed is True

    def test_mixed_auto_and_explicit(self):
        """Automatic and explicit mappings can coexist."""
        outcome = self._run(
            {"id", "name", "mapped_col", "scd2_is_current"},
            {"id", "name", "source_col"},
            [("mapped_col", "source_col", None)],
        )
        assert outcome.passed is True
        assert outcome.passed_fields == 3  # id, name, mapped_col
# ---------------------------------------------------------------------------
# _validate_ods_expression
# ---------------------------------------------------------------------------
class TestValidateOdsExpression:
    """Validation of ODS-side mapping expressions."""

    def test_simple_column(self):
        """A bare column name present in the ODS set is valid."""
        known = {"name", "id"}
        assert _validate_ods_expression("name", known) is True

    def test_missing_column(self):
        """A bare column name absent from the ODS set is invalid."""
        known = {"name", "id"}
        assert _validate_ods_expression("missing", known) is False

    def test_quoted_column(self):
        """A double-quoted identifier matches case-insensitively."""
        assert _validate_ods_expression('"siteGoodsId"', {"sitegoodsid"}) is True

    def test_json_path(self):
        """A JSON-path expression is valid when its base column exists."""
        assert _validate_ods_expression("siteprofile->>'shop_name'", {"siteprofile"}) is True

    def test_json_path_missing_base(self):
        """A JSON-path expression with a missing base column is invalid."""
        assert _validate_ods_expression("missing_col->>'field'", {"siteprofile"}) is False

    def test_null_literal(self):
        """A NULL literal is always valid."""
        assert _validate_ods_expression("NULL", set()) is True

    def test_case_expression(self):
        """A CASE expression is accepted without column resolution."""
        assert _validate_ods_expression("CASE WHEN x=1 THEN 'a' ELSE 'b' END", set()) is True

    def test_coalesce_expression(self):
        """A COALESCE expression is accepted without column resolution."""
        assert _validate_ods_expression("COALESCE(a, b)", set()) is True
# ---------------------------------------------------------------------------
# _extract_records
# ---------------------------------------------------------------------------
class TestExtractRecords:
    """Record extraction from API response envelopes."""

    def test_direct_list(self):
        """A bare list passes through unchanged."""
        assert _extract_records([{"a": 1}]) == [{"a": 1}]

    def test_data_list(self):
        """A {"data": [...]} envelope is unwrapped."""
        assert _extract_records({"data": [{"a": 1}]}) == [{"a": 1}]

    def test_nested_list_key(self):
        """A list nested one level below "data" is located."""
        unwrapped = _extract_records({"data": {"items": [{"a": 1}]}})
        assert unwrapped == [{"a": 1}]

    def test_empty_data(self):
        """An empty dict yields an empty list."""
        assert _extract_records({}) == []

    def test_none_input(self):
        """None yields an empty list."""
        assert _extract_records(None) == []
# ---------------------------------------------------------------------------
# generate_markdown_report
# ---------------------------------------------------------------------------
class TestGenerateMarkdownReport:
    """Markdown report rendering."""

    def test_empty_report(self):
        """An empty report renders as all-passed."""
        md = generate_markdown_report(ConsistencyReport(generated_at="2026-01-01T00:00:00"))
        assert "数据一致性黑盒测试报告" in md
        assert "全部通过" in md

    def test_failed_report(self):
        """A failing table surfaces its missing field in the output."""
        failing = TableCheckResult(
            table_name="dwd.test",
            check_type="ods_vs_dwd",
            passed=False,
            total_fields=3,
            passed_fields=2,
            missing_fields=1,
            field_results=[
                FieldCheckResult("col_a", "pass", "自动映射"),
                FieldCheckResult("col_b", "pass", "显式映射"),
                FieldCheckResult("col_c", "missing", "无映射源"),
            ],
        )
        md = generate_markdown_report(
            ConsistencyReport(
                generated_at="2026-01-01T00:00:00",
                ods_vs_dwd_results=[failing],
            )
        )
        for needle in ("存在异常", "col_c", "missing"):
            assert needle in md

    def test_report_contains_summary(self):
        """The summary line counts passed tables out of the total."""
        report = ConsistencyReport(
            generated_at="2026-01-01T00:00:00",
            api_vs_ods_results=[
                TableCheckResult("ods.t1", "api_vs_ods", passed=True, total_fields=5, passed_fields=5),
                TableCheckResult("ods.t2", "api_vs_ods", passed=False, total_fields=3, passed_fields=2, missing_fields=1),
            ],
        )
        assert "1/2 张表通过" in generate_markdown_report(report)

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""配置层属性测试 — 验证 AppConfig 的深度合并、store_id 验证、DSN 组装、点号路径 get。
Feature: etl-pipeline-debug
Feature: etl-flow-debug
使用 hypothesis 对 AppConfig 的 4 个核心正确性属性进行属性测试。
"""
from __future__ import annotations

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""DWD/DWS 层属性测试 — 验证 DwdLoadTask 和 BaseTask 的核心正确性属性。
Feature: etl-pipeline-debug
Feature: etl-flow-debug
使用 hypothesis 对 DWD 列映射完整性、only_tables 过滤和 DWS 分段累加进行属性测试。
"""
from __future__ import annotations
@@ -58,7 +58,7 @@ def _get_ods_table_names() -> set[str]:
# ===========================================================================
# Property 6: DWD FACT_MAPPINGS 列映射完整性
# Feature: etl-pipeline-debug, Property 6: DWD FACT_MAPPINGS 列映射完整性
# Feature: etl-flow-debug, Property 6: DWD FACT_MAPPINGS 列映射完整性
# Validates: Requirements 2.4
#
# 验证策略:静态检查 FACT_MAPPINGS 中每个映射条目,当 ods_expr 是简单列名时,
@@ -122,7 +122,7 @@ def test_property6_fact_mappings_column_integrity(idx):
# ===========================================================================
# Property 7: DWD only_tables 过滤
# Feature: etl-pipeline-debug, Property 7: DWD only_tables 过滤
# Feature: etl-flow-debug, Property 7: DWD only_tables 过滤
# Validates: Requirements 2.6
#
# 验证策略:模拟 DwdLoadTask.load() 中的 only_tables 过滤逻辑,
@@ -209,7 +209,7 @@ def test_property7_dwd_only_tables_filter(only_tables_cfg):
# ===========================================================================
# Property 8: DWS 分段累加一致性
# Feature: etl-pipeline-debug, Property 8: DWS 分段累加一致性
# Feature: etl-flow-debug, Property 8: DWS 分段累加一致性
# Validates: Requirements 3.3
#
# 验证策略:直接测试 BaseTask._accumulate_counts 静态方法,

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""ODS 层属性测试 — 验证 BaseOdsTask 的核心正确性属性。
Feature: etl-pipeline-debug
Feature: etl-flow-debug
使用 hypothesis 对 ODS 任务的关键行为进行属性测试。
"""
from __future__ import annotations
@@ -166,7 +166,7 @@ _COLUMNS_WITH_IS_DELETE = [
# ===========================================================================
# Property 1: ODS 任务提取记录数一致性
# Feature: etl-pipeline-debug, Property 1: ODS 任务提取记录数一致性
# Feature: etl-flow-debug, Property 1: ODS 任务提取记录数一致性
# Validates: Requirements 1.1, 1.2
# ===========================================================================
@@ -207,7 +207,7 @@ def test_property1_ods_record_count_consistency(tmp_path, records):
# ===========================================================================
# Property 2: ODS 冲突处理策略正确性
# Feature: etl-pipeline-debug, Property 2: ODS 冲突处理策略正确性
# Feature: etl-flow-debug, Property 2: ODS 冲突处理策略正确性
# Validates: Requirements 1.3
# ===========================================================================
@@ -256,7 +256,7 @@ def test_property2_ods_conflict_mode_sql(tmp_path, conflict_mode):
# ===========================================================================
# Property 3: ODS 跳过缺失主键记录
# Feature: etl-pipeline-debug, Property 3: ODS 跳过缺失主键记录
# Feature: etl-flow-debug, Property 3: ODS 跳过缺失主键记录
# Validates: Requirements 1.4
# ===========================================================================
@@ -304,7 +304,7 @@ def test_property3_ods_skip_missing_pk(tmp_path, valid_records, missing_pk_recor
# ===========================================================================
# Property 4: ODS content_hash 去重
# Feature: etl-pipeline-debug, Property 4: ODS content_hash 去重
# Feature: etl-flow-debug, Property 4: ODS content_hash 去重
# Validates: Requirements 1.5
# ===========================================================================
@@ -337,7 +337,7 @@ def test_property4_ods_content_hash_deterministic(tmp_path, record):
# ===========================================================================
# Property 5: ODS 快照删除标记INSERT 语义)
# Feature: etl-pipeline-debug, Property 5: ODS 快照删除标记
# Feature: etl-flow-debug, Property 5: ODS 快照删除标记
# Validates: Requirements 1.7, 7.1, 7.4
# ===========================================================================

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""编排层属性测试 — 验证 FlowRunner、TaskExecutor、CLI 的核心正确性属性。
Feature: etl-pipeline-debug
Feature: etl-flow-debug
使用 hypothesis 对 Flow 层解析、无效 Flow 拒绝、工具类任务跳过游标、
CLI data_source 解析进行属性测试。
"""
@@ -58,7 +58,7 @@ def test_property9_pipeline_runner_flow_layer_resolution(flow_name: str):
expected_layers = FlowRunner.FLOW_LAYERS[flow_name]
# 直接验证 FLOW_LAYERS 字典查找——这是 FlowRunner.run() 内部
# 解析层列表的唯一路径:`layers = self.FLOW_LAYERS[pipeline]`
# 解析层列表的唯一路径:`layers = self.FLOW_LAYERS[flow]`
assert flow_name in FlowRunner.FLOW_LAYERS
assert FlowRunner.FLOW_LAYERS[flow_name] == expected_layers
@@ -97,7 +97,7 @@ def test_property10_pipeline_runner_rejects_invalid_flow(invalid_flow: str):
避免构造完整的 FlowRunner 实例(需要真实 DB/API 连接)。
"""
# FlowRunner.run() 的第一行就是:
# if pipeline not in self.FLOW_LAYERS: raise ValueError(...)
# if flow not in self.FLOW_LAYERS: raise ValueError(...)
# 我们直接验证这个守卫条件
assert invalid_flow not in FlowRunner.FLOW_LAYERS
@@ -116,7 +116,7 @@ def test_property10_pipeline_runner_rejects_invalid_flow(invalid_flow: str):
)
with pytest.raises(ValueError, match="无效的 Flow 名称"):
runner.run(pipeline=invalid_flow)
runner.run(flow=invalid_flow)
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
"""任务 3.3: 验证 DwdLoadTask 自动列映射包含 birthday。
验证点:
1. _get_columns() 从 information_schema 读取 DWD 表列名时birthday 被自动包含
2. _is_row_changed() 的 SCD2 变化检测自动包含 birthday因为 birthday 不在 SCD_COLS 中)
3. _merge_dim_scd2() 构建 SELECT 表达式时birthday 作为同名列被自动映射
需求: 4.2, 4.3
"""
from __future__ import annotations
import os
import sys
from datetime import datetime, date
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
from zoneinfo import ZoneInfo
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
from tasks.dwd.dwd_load_task import DwdLoadTask
def _make_bare_task() -> DwdLoadTask:
    """Create a bare DwdLoadTask carrying only what _is_row_changed reads."""
    bare = DwdLoadTask.__new__(DwdLoadTask)  # bypass __init__ and its wiring
    bare.tz = ZoneInfo("Asia/Shanghai")
    return bare
# ---------------------------------------------------------------------------
# dim_member column definitions (state after migration C1, birthday included)
# ---------------------------------------------------------------------------
_DIM_MEMBER_COLUMNS = [
    {"column_name": "member_id"},
    {"column_name": "scd2_version"},
    {"column_name": "nickname"},
    {"column_name": "mobile"},
    {"column_name": "birthday"},
    {"column_name": "register_site_id"},
    {"column_name": "scd2_start_time"},
    {"column_name": "scd2_end_time"},
    {"column_name": "scd2_is_current"},
]
# ODS-side member columns; birthday exists here too, so it auto-maps by name.
_ODS_MEMBER_COLUMNS = [
    {"column_name": "id"},
    {"column_name": "nickname"},
    {"column_name": "mobile"},
    {"column_name": "birthday"},
    {"column_name": "fetched_at"},
    {"column_name": "payload"},
]
# ---------------------------------------------------------------------------
# 测试
# ---------------------------------------------------------------------------
class TestDwdBirthdayColumnMapping:
    """Automatic column mapping and SCD2 change detection must cover birthday."""

    @staticmethod
    def _dwd_col_names() -> list:
        """Column names of dim_member as _get_columns() would report them."""
        return [row["column_name"] for row in _DIM_MEMBER_COLUMNS]

    @staticmethod
    def _current_row(birthday):
        """A stored dim_member row with the given birthday value."""
        return {
            "member_id": 100,
            "nickname": "张三",
            "mobile": "13800138000",
            "birthday": birthday,
            "register_site_id": 1001,
            "scd2_version": 1,
            "scd2_start_time": datetime(2025, 1, 1),
            "scd2_end_time": datetime(9999, 12, 31),
            "scd2_is_current": 1,
        }

    @staticmethod
    def _incoming_row(birthday):
        """An incoming ODS row with the given birthday value."""
        return {
            "member_id": 100,
            "nickname": "张三",
            "mobile": "13800138000",
            "birthday": birthday,
            "register_site_id": 1001,
        }

    def test_get_columns_includes_birthday(self):
        """birthday must appear in the lower-cased dim_member column list.

        _get_columns() selects column_name from information_schema.columns and
        lower-cases the result; this simulates that query's return value.
        """
        columns = [row["column_name"].lower() for row in _DIM_MEMBER_COLUMNS]
        assert "birthday" in columns, (
            "dim_member 列列表应包含 birthday"
            f"实际列: {columns}"
        )
        # SCD2 metadata columns must also be present (they are skipped later
        # by _is_row_changed, but _get_columns still reports them).
        for scd_col in ("scd2_start_time", "scd2_end_time", "scd2_is_current", "scd2_version"):
            assert scd_col in columns, f"dim_member 应包含 SCD2 列 {scd_col}"

    def test_birthday_not_in_scd_cols(self):
        """birthday must not be in SCD_COLS, or change detection would skip it."""
        assert "birthday" not in DwdLoadTask.SCD_COLS, (
            "birthday 不应在 SCD_COLS 中,否则 SCD2 变化检测会跳过它。"
            f"当前 SCD_COLS: {DwdLoadTask.SCD_COLS}"
        )

    def test_is_row_changed_detects_birthday_change(self):
        """A changed birthday must make _is_row_changed() return True (Req 4.3)."""
        task = _make_bare_task()
        # birthday changes: 1990-05-15 → 1991-06-20
        changed = task._is_row_changed(
            self._current_row(date(1990, 5, 15)),
            self._incoming_row(date(1991, 6, 20)),
            self._dwd_col_names(),
        )
        assert changed is True, (
            "birthday 值变化时_is_row_changed 应返回 True"
        )

    def test_is_row_changed_no_change_when_birthday_same(self):
        """With birthday (and every other field) unchanged, the result is False."""
        task = _make_bare_task()
        changed = task._is_row_changed(
            self._current_row(date(1990, 5, 15)),
            self._incoming_row(date(1990, 5, 15)),
            self._dwd_col_names(),
        )
        assert changed is False, (
            "所有字段(含 birthday不变时_is_row_changed 应返回 False"
        )

    def test_is_row_changed_birthday_null_to_value(self):
        """birthday going from NULL to a value must be detected as a change."""
        task = _make_bare_task()
        changed = task._is_row_changed(
            self._current_row(None),
            self._incoming_row(date(1990, 5, 15)),
            self._dwd_col_names(),
        )
        assert changed is True, (
            "birthday 从 None 变为有值时_is_row_changed 应返回 True"
        )

    def test_birthday_in_scd2_select_expressions(self):
        """birthday is auto-mapped by name when SCD2 SELECT expressions are built.

        _merge_dim_scd2 walks dwd_cols; a column outside SCD_COLS and outside
        the explicit mapping, but present in ods_set, is selected as-is.
        """
        ods_set = {row["column_name"].lower() for row in _ODS_MEMBER_COLUMNS}
        scd_cols = DwdLoadTask.SCD_COLS
        mapping = {}  # dim_member has no explicit FACT_MAPPINGS entry
        selected_cols = []
        for col in self._dwd_col_names():
            lowered = col.lower()
            if lowered in scd_cols:
                continue
            # explicit mapping wins; otherwise a same-named ODS column maps
            if lowered in mapping or lowered in ods_set:
                selected_cols.append(lowered)
        assert "birthday" in selected_cols, (
            "birthday 应出现在 SCD2 SELECT 表达式的列列表中,"
            f"实际选中列: {selected_cols}"
        )

    def test_scd2_change_detection_columns_include_birthday(self):
        """The column set compared by _is_row_changed must include birthday."""
        scd_cols = DwdLoadTask.SCD_COLS
        change_detection_cols = [
            col for col in self._dwd_col_names() if col.lower() not in scd_cols
        ]
        assert "birthday" in change_detection_cols, (
            "SCD2 变化检测列应包含 birthday"
            f"实际检测列: {change_detection_cols}"
        )
        # SCD2 metadata columns must be excluded from the comparison.
        for scd_col in scd_cols:
            assert scd_col not in change_detection_cols, (
                f"SCD2 元数据列 {scd_col} 不应参与变化检测"
            )

View File

@@ -0,0 +1,144 @@
# -*- coding: utf-8 -*-
"""DwdLoadTask.load() 返回值格式验证 — 单元测试。
验证需求 1.1load() 返回 errors: int, error_details: list[dict]。
使用 FakeDB/FakeAPI不涉及真实数据库连接。
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from tasks.dwd.dwd_load_task import DwdLoadTask
def _make_task() -> DwdLoadTask:
    """Assemble a minimal DwdLoadTask wired to fake config/DB/API objects."""
    config = MagicMock()
    config.get = MagicMock(side_effect=lambda key, default=None: default)

    # Fake cursor usable as a context manager; cursor() accepts any kwargs.
    cursor = MagicMock()
    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock(return_value=False)

    connection = MagicMock()
    connection.cursor.return_value = cursor

    db = MagicMock()
    db.conn = connection

    return DwdLoadTask(config, db, MagicMock(), logging.getLogger("test_dwd_return_format"))
class TestLoadReturnFormat:
    """Verify load() returns ``errors`` as int and ``error_details`` as list[dict]."""

    def test_no_errors_returns_zero_and_empty_list(self):
        """With no failures: errors=0 and error_details=[]."""
        task = _make_task()
        # An empty TABLE_MAP makes load() return an empty result immediately.
        with patch.object(DwdLoadTask, "TABLE_MAP", {}):
            ctx = SimpleNamespace(
                window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
                window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
            )
            result = task.load({"now": datetime.now(timezone.utc)}, ctx)
        assert isinstance(result["errors"], int), (
            f"errors 应为 int实际为 {type(result['errors'])}"
        )
        assert result["errors"] == 0
        assert isinstance(result["error_details"], list)
        assert result["error_details"] == []
        assert "tables" in result

    def test_with_errors_returns_count_and_details(self):
        """On failure: errors == len(error_details) and details carry the error dict."""
        task = _make_task()
        # Simulate one table failing to load: _get_columns raises for it.
        fake_cursor = task.db.conn.cursor.return_value.__enter__.return_value
        fake_cursor.fetchall.return_value = []

        def fake_get_columns(cur, table):
            if "dim_test" in table:
                raise RuntimeError("模拟装载失败")
            return []

        task._get_columns = fake_get_columns
        with patch.object(
            DwdLoadTask, "TABLE_MAP", {"dwd.dim_test": "ods.test_source"}
        ):
            ctx = SimpleNamespace(
                window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
                window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
            )
            result = task.load({"now": datetime.now(timezone.utc)}, ctx)
        assert isinstance(result["errors"], int)
        assert result["errors"] == 1
        assert isinstance(result["error_details"], list)
        assert len(result["error_details"]) == 1
        assert result["error_details"][0]["table"] == "dwd.dim_test"
        assert "模拟装载失败" in result["error_details"][0]["error"]

    def test_errors_equals_len_of_error_details(self):
        """The errors value always equals the length of error_details."""
        task = _make_task()
        # Simulate both tables failing.
        def fake_get_columns(cur, table):
            raise RuntimeError(f"失败: {table}")

        task._get_columns = fake_get_columns
        with patch.object(
            DwdLoadTask,
            "TABLE_MAP",
            {"dwd.dim_a": "ods.a", "dwd.dim_b": "ods.b"},
        ):
            ctx = SimpleNamespace(
                window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
                window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
            )
            result = task.load({"now": datetime.now(timezone.utc)}, ctx)
        assert result["errors"] == len(result["error_details"])
        assert result["errors"] == 2

    def test_successful_tables_in_summary(self):
        """Successfully loaded tables appear in `tables` without affecting the error count."""
        task = _make_task()
        # dim_ok returns no columns -> skipped (neither error nor summary entry);
        # dim_fail raises -> counted in errors.
        call_count = 0

        def fake_get_columns(cur, table):
            nonlocal call_count
            call_count += 1
            if "fail" in table:
                raise RuntimeError("boom")
            return []  # no columns -> table is skipped

        task._get_columns = fake_get_columns
        with patch.object(
            DwdLoadTask,
            "TABLE_MAP",
            {"dwd.dim_ok": "ods.ok", "dwd.dim_fail": "ods.fail"},
        ):
            ctx = SimpleNamespace(
                window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
                window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
            )
            result = task.load({"now": datetime.now(timezone.utc)}, ctx)
        assert result["errors"] == 1
        assert len(result["error_details"]) == 1
        # dim_ok was skipped for having no columns, so it is absent from the summary.
        assert result["tables"] == []

View File

@@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
"""
DWS 任务 birthday 字段恢复测试
验证需求 4.4DWS 任务从 dim_member.birthday 读取生日字段并写入 DWS 目标表。
纯单元测试,使用 FakeDB/Mock不涉及真实数据库连接。
"""
import pytest
from datetime import date
from decimal import Decimal
from unittest.mock import MagicMock
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from tests.unit.task_test_utils import FakeDBOperations
# ---------------------------------------------------------------------------
# Helper: 创建 MemberVisitTask 实例
# ---------------------------------------------------------------------------
def _create_member_visit_task():
    """Return a (MemberVisitTask, FakeDBOperations) pair for unit tests."""
    from tasks.dws.member_visit_task import MemberVisitTask

    def _cfg_get(key, default=None):
        # Only app.tenant_id resolves to a concrete value; all other keys
        # fall back to the caller-supplied default.
        return 1 if key == "app.tenant_id" else default

    cfg = MagicMock()
    cfg.get.side_effect = _cfg_get
    fake_db = FakeDBOperations()
    task = MemberVisitTask(cfg, fake_db, MagicMock(), MagicMock())
    return task, fake_db
def _create_member_consumption_task():
    """Return a (MemberConsumptionTask, FakeDBOperations) pair for unit tests."""
    from tasks.dws.member_consumption_task import MemberConsumptionTask

    def _cfg_get(key, default=None):
        # Only app.tenant_id resolves to a concrete value; everything else
        # falls back to the caller-supplied default.
        return 1 if key == "app.tenant_id" else default

    cfg = MagicMock()
    cfg.get.side_effect = _cfg_get
    fake_db = FakeDBOperations()
    task = MemberConsumptionTask(cfg, fake_db, MagicMock(), MagicMock())
    return task, fake_db
# ===========================================================================
# MemberVisitTask._extract_member_info — birthday 字段
# ===========================================================================
class TestMemberVisitBirthday:
    """Verify birthday extraction and write-through in MemberVisitTask."""

    def test_extract_member_info_sql_includes_birthday(self):
        """The _extract_member_info SQL must select the birthday column."""
        task, db = _create_member_visit_task()
        db.query_results.append([
            {"member_id": 100, "nickname": "张三", "mobile": "13800001111", "birthday": date(1990, 5, 15)},
        ])
        result = task._extract_member_info(site_id=1)
        # The executed SQL must reference birthday.
        sql_executed = db.executes[0]["sql"]
        assert "birthday" in sql_executed
        # The returned mapping must carry the birthday value through.
        assert result[100]["birthday"] == date(1990, 5, 15)

    def test_extract_member_info_birthday_none(self):
        """A NULL birthday must round-trip as None without errors."""
        task, db = _create_member_visit_task()
        db.query_results.append([
            {"member_id": 200, "nickname": "李四", "mobile": "13900002222", "birthday": None},
        ])
        result = task._extract_member_info(site_id=1)
        assert result[200]["birthday"] is None

    def test_transform_writes_member_birthday(self):
        """transform() must copy birthday into the member_birthday output field."""
        task, db = _create_member_visit_task()
        extracted = {
            "settlements": [
                {
                    "member_id": 100,
                    "order_settle_id": 1001,
                    "visit_date": date(2026, 2, 20),
                    "create_time": "2026-02-20 14:00:00",
                    "table_id": 10,
                    "table_charge_money": Decimal("50.00"),
                    "goods_money": Decimal("30.00"),
                    "assistant_pd_money": Decimal("20.00"),
                    "assistant_cx_money": Decimal("10.00"),
                    "consume_money": Decimal("110.00"),
                    "pay_amount": Decimal("100.00"),
                    "balance_amount": Decimal("0"),
                    "gift_card_amount": Decimal("0"),
                    "coupon_amount": Decimal("10.00"),
                    "discount_money": Decimal("0"),
                    "free_money": Decimal("0"),
                },
            ],
            "assistant_services": [],
            "member_info": {
                100: {
                    "member_id": 100,
                    "nickname": "张三",
                    "mobile": "13800001111",
                    "birthday": date(1990, 5, 15),
                },
            },
            "table_info": {
                10: {"table_id": 10, "table_name": "1号台", "area_name": "大厅"},
            },
            "table_fee_durations": [],
            "site_id": 1,
        }
        from tasks.dws.base_dws_task import TaskContext
        from datetime import datetime
        ctx = TaskContext(
            store_id=1,
            window_start=datetime(2026, 2, 20),
            window_end=datetime(2026, 2, 20, 23, 59, 59),
            window_minutes=1440,
        )
        results = task.transform(extracted, ctx)
        assert len(results) == 1
        assert results[0]["member_birthday"] == date(1990, 5, 15)

    def test_transform_member_birthday_none_when_no_birthday(self):
        """member_birthday must be None when the member has no birthday."""
        task, db = _create_member_visit_task()
        extracted = {
            "settlements": [
                {
                    "member_id": 300,
                    "order_settle_id": 3001,
                    "visit_date": date(2026, 2, 20),
                    "create_time": "2026-02-20 15:00:00",
                    "table_id": 10,
                    "table_charge_money": Decimal("0"),
                    "goods_money": Decimal("0"),
                    "assistant_pd_money": Decimal("0"),
                    "assistant_cx_money": Decimal("0"),
                    "consume_money": Decimal("0"),
                    "pay_amount": Decimal("0"),
                    "balance_amount": Decimal("0"),
                    "gift_card_amount": Decimal("0"),
                    "coupon_amount": Decimal("0"),
                    "discount_money": Decimal("0"),
                    "free_money": Decimal("0"),
                },
            ],
            "assistant_services": [],
            "member_info": {
                300: {"member_id": 300, "nickname": "王五", "mobile": None, "birthday": None},
            },
            "table_info": {10: {"table_id": 10, "table_name": "1号台", "area_name": None}},
            "table_fee_durations": [],
            "site_id": 1,
        }
        from tasks.dws.base_dws_task import TaskContext
        from datetime import datetime
        ctx = TaskContext(
            store_id=1,
            window_start=datetime(2026, 2, 20),
            window_end=datetime(2026, 2, 20, 23, 59, 59),
            window_minutes=1440,
        )
        results = task.transform(extracted, ctx)
        assert results[0]["member_birthday"] is None
# ===========================================================================
# MemberConsumptionTask._extract_member_info — birthday 字段
# ===========================================================================
class TestMemberConsumptionBirthday:
    """Verify birthday extraction in MemberConsumptionTask."""

    def test_extract_member_info_sql_includes_birthday(self):
        """The _extract_member_info SQL must select the birthday column."""
        task, db = _create_member_consumption_task()
        db.query_results.append([
            {
                "member_id": 100, "nickname": "张三", "mobile": "13800001111",
                "member_card_grade_name": "金卡", "register_date": date(2025, 1, 1),
                "recharge_money_sum": Decimal("500.00"), "birthday": date(1995, 8, 20),
            },
        ])
        result = task._extract_member_info(site_id=1)
        sql_executed = db.executes[0]["sql"]
        assert "birthday" in sql_executed
        assert result[100]["birthday"] == date(1995, 8, 20)

    def test_extract_member_info_birthday_none(self):
        """A NULL birthday must round-trip as None without errors."""
        task, db = _create_member_consumption_task()
        db.query_results.append([
            {
                "member_id": 200, "nickname": "李四", "mobile": None,
                "member_card_grade_name": None, "register_date": None,
                "recharge_money_sum": Decimal("0"), "birthday": None,
            },
        ])
        result = task._extract_member_info(site_id=1)
        assert result[200]["birthday"] is None

View File

@@ -124,14 +124,14 @@ class TestFlowModeE2E:
)
result = runner.run(
pipeline="api_ods",
flow="api_ods",
processing_mode="increment_only",
data_source="hybrid",
)
# 结构验证
assert result["status"] == "SUCCESS"
assert result["pipeline"] == "api_ods"
assert result["flow"] == "api_ods"
assert result["layers"] == ["ODS"]
assert isinstance(result["results"], list)
# TaskExecutor 被调用
@@ -159,7 +159,7 @@ class TestFlowModeE2E:
# 校验框架可能未安装mock 掉 _run_verification
with patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}):
result = runner.run(
pipeline="api_ods",
flow="api_ods",
processing_mode="verify_only",
data_source="hybrid",
)
@@ -218,5 +218,5 @@ class TestSchedulerThinWrapper:
scheduler.task_executor.run_tasks.assert_called_once()
# run_flow_with_verification 委托
scheduler.run_flow_with_verification(pipeline="api_ods")
scheduler.run_flow_with_verification(flow="api_ods")
scheduler.flow_runner.run.assert_called_once()

View File

@@ -80,31 +80,27 @@ class TestLayersArgParsing:
args = parse_args()
assert args.layers is None
def test_pipeline_still_works(self):
"""--pipeline 参数保留可用(需求 6.3),存入 pipeline_deprecated弃用别名"""
with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
args = parse_args()
assert args.pipeline_deprecated == "api_full"
# ---------------------------------------------------------------------------
# 3. --layers 与 --pipeline 互斥(需求 6.4
# ---------------------------------------------------------------------------
class TestLayersPipelineMutualExclusion:
"""--layers 和 --flow/--pipeline 互斥校验
class TestLayersFlowMutualExclusion:
"""--layers 和 --flow 互斥校验
互斥校验在 main() 中实现(非 argparse 层),
此处验证两个参数可以同时被解析(互斥由 main 层处理)。
"""
def test_both_args_can_be_parsed(self):
"""argparse 层允许同时传入 --layers 和 --pipeline弃用别名,互斥由 main() 校验"""
"""argparse 层允许同时传入 --layers 和 --flow,互斥由 main() 校验"""
with patch("sys.argv", [
"cli", "--layers", "ODS,DWD", "--pipeline", "api_full",
"cli", "--layers", "ODS,DWD", "--flow", "api_full",
]):
args = parse_args()
assert args.layers == "ODS,DWD"
assert args.pipeline_deprecated == "api_full"
assert args.flow == "api_full"
# ---------------------------------------------------------------------------
@@ -208,7 +204,7 @@ class TestParseLayersProperties:
# 5. --flow / --pipeline 弃用别名测试(需求 9.3, 9.4
# ---------------------------------------------------------------------------
class TestFlowParameter:
"""--flow 作为主参数、--pipeline 作为弃用别名"""
"""--flow 参数测试(--pipeline 已移除)"""
def test_flow_parsed(self):
"""--flow 作为主参数可正常解析"""
@@ -222,52 +218,6 @@ class TestFlowParameter:
args = parse_args()
assert args.flow is None
def test_pipeline_deprecated_parsed(self):
"""--pipeline 仍可解析,存入 pipeline_deprecated"""
with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
args = parse_args()
assert args.pipeline_deprecated == "api_full"
assert args.flow is None # --flow 未指定
def test_pipeline_emits_deprecation_warning(self):
"""使用 --pipeline 时应发出 DeprecationWarning需求 9.4
直接模拟 main() 中的弃用逻辑,避免进入数据库连接。
"""
import warnings
with patch("sys.argv", ["cli", "--pipeline", "api_full"]):
args = parse_args()
# 模拟 main() 中的弃用处理逻辑
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
if args.pipeline_deprecated:
warnings.warn(
"--pipeline 参数已弃用,请使用 --flow",
DeprecationWarning,
stacklevel=2,
)
if not args.flow:
args.flow = args.pipeline_deprecated
dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert len(dep_warnings) == 1
assert "--pipeline 参数已弃用" in str(dep_warnings[0].message)
# 验证值已合并到 args.flow
assert args.flow == "api_full"
def test_flow_and_pipeline_mutually_exclusive(self):
"""--flow 和 --pipeline 不能同时指定(需求 9.3
argparse 层允许同时传入,互斥由 main() 中的逻辑处理。
"""
with patch("sys.argv", ["cli", "--flow", "api_full", "--pipeline", "api_ods"]):
args = parse_args()
# 两者同时存在时main() 应 sys.exit(2)
assert args.flow == "api_full"
assert args.pipeline_deprecated == "api_ods"
def test_layers_and_flow_mutually_exclusive(self):
"""--layers 和 --flow 互斥argparse 层可同时解析main() 校验)"""
with patch("sys.argv", ["cli", "--layers", "ODS,DWD", "--flow", "api_full"]):
@@ -275,21 +225,3 @@ class TestFlowParameter:
assert args.layers == "ODS,DWD"
assert args.flow == "api_full"
def test_layers_and_pipeline_deprecated_mutually_exclusive(self):
"""--layers 和 --pipeline弃用别名也互斥
--pipeline 先合并到 args.flow然后 --layers vs --flow 互斥生效。
"""
with patch("sys.argv", ["cli", "--layers", "ODS,DWD", "--pipeline", "api_full"]):
args = parse_args()
assert args.layers == "ODS,DWD"
assert args.pipeline_deprecated == "api_full"
def test_pipeline_value_merges_to_flow(self):
"""--pipeline 的值在弃用处理后应合并到 args.flow"""
with patch("sys.argv", ["cli", "--pipeline", "dwd_dws"]):
args = parse_args()
# 模拟 main() 中的合并逻辑
if args.pipeline_deprecated and not args.flow:
args.flow = args.pipeline_deprecated
assert args.flow == "dwd_dws"

View File

@@ -0,0 +1,903 @@
# -*- coding: utf-8 -*-
"""助教月度聚合属性测试 — 使用 hypothesis 验证档位分段聚合的正确性。
纯单元测试,模拟 _extract_daily_aggregates() 的 SQL GROUP BY 逻辑,不涉及真实数据库连接。
# Feature: etl-aggregation-fix, Property 3: 档位分段聚合正确性
# **Validates: Requirements 2.1**
"""
from __future__ import annotations
from collections import defaultdict
from datetime import date, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Set, Tuple
from hypothesis import given, settings
from hypothesis import strategies as st
# ── 业绩指标字段列表(与 _extract_daily_aggregates SQL 中 SUM 聚合的字段一致)──
# Metric columns SUM-aggregated by the _extract_daily_aggregates SQL.
_METRIC_FIELDS = [
    "total_service_count",
    "base_service_count",
    "bonus_service_count",
    "room_service_count",
    "total_hours",
    "base_hours",
    "bonus_hours",
    "room_hours",
    "total_ledger_amount",
    "base_ledger_amount",
    "bonus_ledger_amount",
    "room_ledger_amount",
    "unique_customers",
    "unique_tables",
    "trashed_seconds",
    "trashed_count",
]


# Python model of the production SQL:
#   SELECT assistant_id, assistant_level_code, assistant_level_name,
#          (ARRAY_AGG(assistant_nickname ORDER BY stat_date DESC))[1],
#          DATE_TRUNC('month', stat_date)::DATE AS stat_month,
#          COUNT(DISTINCT stat_date) AS work_days,
#          SUM(<each metric field>) ...
#   FROM dws.dws_assistant_daily_detail
#   WHERE site_id = %s AND (stat_date >= ... AND stat_date < ...)
#   GROUP BY assistant_id, assistant_level_code, assistant_level_name,
#            DATE_TRUNC('month', stat_date)
def simulate_extract_daily_aggregates(
    daily_rows: List[Dict[str, Any]],
    site_id: int,
    month: date,
) -> List[Dict[str, Any]]:
    """Simulate the GROUP BY logic of _extract_daily_aggregates().

    Rows are grouped by (assistant_id, assistant_level_code,
    assistant_level_name) within the target month; every metric field is
    SUM-aggregated, work_days is COUNT(DISTINCT stat_date), and the
    nickname is taken from the most recent stat_date.
    """
    first_day = month.replace(day=1)
    # Day 28 + 4 days always lands in the following month; snap to day 1.
    following_month = (first_day.replace(day=28) + timedelta(days=4)).replace(day=1)

    # Bucket rows by (assistant_id, level_code, level_name), keeping only
    # those matching the site and falling inside the month window.
    buckets: Dict[Tuple, List[Dict]] = defaultdict(list)
    for row in daily_rows:
        if row.get("site_id") != site_id:
            continue
        if not (first_day <= row["stat_date"] < following_month):
            continue
        group_key = (
            row["assistant_id"],
            row.get("assistant_level_code"),
            row.get("assistant_level_name"),
        )
        buckets[group_key].append(row)

    aggregated: List[Dict[str, Any]] = []
    for (aid, level_code, level_name), members in buckets.items():
        record: Dict[str, Any] = {
            "assistant_id": aid,
            "assistant_level_code": level_code,
            "assistant_level_name": level_name,
            "stat_month": first_day,
            "work_days": len({r["stat_date"] for r in members}),
        }
        # SUM over every metric field (missing values count as 0).
        for metric in _METRIC_FIELDS:
            record[metric] = sum(Decimal(str(r.get(metric, 0))) for r in members)
        # Nickname from the row with the latest stat_date (stable on ties).
        newest_first = sorted(members, key=lambda r: r["stat_date"], reverse=True)
        record["assistant_nickname"] = newest_first[0].get("assistant_nickname", "")
        aggregated.append(record)
    return aggregated
# ── Generation strategies ──────────────────────────────────────
# Assistant ID.
_assistant_id_st = st.integers(min_value=1, max_value=50)
# Level code: A/B/C/D/E represent the distinct tiers.
_level_code_st = st.sampled_from(["A", "B", "C", "D", "E"])
# Level-name mapping (bound to the code so one code always maps to one name).
_LEVEL_NAME_MAP = {
    "A": "初级",
    "B": "中级",
    "C": "高级",
    "D": "资深",
    "E": "专家",
}
# Nickname: short strings drawn from common Chinese surnames.
_nickname_st = st.text(
    alphabet=st.sampled_from("张李王赵刘陈杨黄周吴"),
    min_size=2,
    max_size=4,
)
# Non-negative count-style metric values (integers).
_count_metric_st = st.integers(min_value=0, max_value=500)
# Non-negative amount/hour-style metric values (Decimal, 2 places).
_amount_metric_st = st.decimals(
    min_value=0,
    max_value=Decimal("99999.99"),
    places=2,
    allow_nan=False,
    allow_infinity=False,
)
@st.composite
def _daily_detail_row_st(draw, assistant_id: int, level_code: str, month: date):
    """Generate one dws_assistant_daily_detail daily-detail row.

    assistant_id and level_code are supplied by the caller so grouping is
    controllable; stat_date is drawn uniformly within the given month.
    """
    # Random day within the month.
    month_start = month.replace(day=1)
    next_month = (month_start.replace(day=28) + timedelta(days=4)).replace(day=1)
    max_day = (next_month - timedelta(days=1)).day
    day = draw(st.integers(min_value=1, max_value=max_day))
    stat_date = month_start.replace(day=day)
    return {
        "site_id": 1,  # fixed site_id keeps the test simple
        "assistant_id": assistant_id,
        "assistant_level_code": level_code,
        "assistant_level_name": _LEVEL_NAME_MAP[level_code],
        "assistant_nickname": draw(_nickname_st),
        "stat_date": stat_date,
        "total_service_count": draw(_count_metric_st),
        "base_service_count": draw(_count_metric_st),
        "bonus_service_count": draw(_count_metric_st),
        "room_service_count": draw(_count_metric_st),
        "total_hours": float(draw(_amount_metric_st)),
        "base_hours": float(draw(_amount_metric_st)),
        "bonus_hours": float(draw(_amount_metric_st)),
        "room_hours": float(draw(_amount_metric_st)),
        "total_ledger_amount": float(draw(_amount_metric_st)),
        "base_ledger_amount": float(draw(_amount_metric_st)),
        "bonus_ledger_amount": float(draw(_amount_metric_st)),
        "room_ledger_amount": float(draw(_amount_metric_st)),
        "unique_customers": draw(_count_metric_st),
        "unique_tables": draw(_count_metric_st),
        "trashed_seconds": draw(_count_metric_st),
        "trashed_count": draw(_count_metric_st),
    }
@st.composite
def _multi_level_scenario_st(draw):
    """Generate a scenario where one assistant holds several levels in one month.

    Returns a dict with:
    - daily_rows: list of daily-detail rows
    - assistant_id: the assistant's ID
    - month: the statistics month
    - level_codes: the set of distinct level codes that month
    """
    assistant_id = draw(_assistant_id_st)
    # Month is pinned to 2025-01.
    month = date(2025, 1, 1)
    # Draw 1-4 distinct level codes.
    n_levels = draw(st.integers(min_value=1, max_value=4))
    level_codes = draw(
        st.lists(
            _level_code_st,
            min_size=n_levels,
            max_size=n_levels,
            unique=True,
        )
    )
    # Generate 1-5 daily-detail rows per level.
    daily_rows: List[Dict[str, Any]] = []
    for code in level_codes:
        n_rows = draw(st.integers(min_value=1, max_value=5))
        for _ in range(n_rows):
            row = draw(_daily_detail_row_st(assistant_id, code, month))
            daily_rows.append(row)
    return {
        "daily_rows": daily_rows,
        "assistant_id": assistant_id,
        "month": month,
        "level_codes": set(level_codes),
    }
# ── Property 3: 档位分段聚合正确性 ───────────────────────────
# Feature: etl-aggregation-fix, Property 3: 档位分段聚合正确性
# **Validates: Requirements 2.1**
#
# 对于任意助教在同一月内存在 N 个不同 assistant_level_code 的日度数据,
# _extract_daily_aggregates() 应返回恰好 N 行记录,
# 每行的业绩指标之和应等于该助教该月的总业绩。
class TestProperty3LevelSegmentAggregation:
    """Property 3: per-level segmented aggregation is correct."""

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_row_count_equals_distinct_level_codes(self, scenario):
        """Result row count must equal the number of distinct level codes."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]
        month = scenario["month"]
        level_codes = scenario["level_codes"]
        results = simulate_extract_daily_aggregates(daily_rows, site_id=1, month=month)
        # Only inspect rows for the assistant under test.
        assistant_results = [r for r in results if r["assistant_id"] == assistant_id]
        assert len(assistant_results) == len(level_codes), (
            f"助教 {assistant_id}{len(level_codes)} 个不同档位 {level_codes}"
            f"但聚合结果返回 {len(assistant_results)}"
        )

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_metric_sums_equal_total(self, scenario):
        """Summing all level rows must reproduce the assistant's monthly totals."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]
        month = scenario["month"]
        results = simulate_extract_daily_aggregates(daily_rows, site_id=1, month=month)
        assistant_results = [r for r in results if r["assistant_id"] == assistant_id]
        # Manually compute the monthly total per metric, ignoring levels.
        for field in _METRIC_FIELDS:
            expected_total = sum(
                Decimal(str(r.get(field, 0)))
                for r in daily_rows
                if r["assistant_id"] == assistant_id
            )
            actual_total = sum(r[field] for r in assistant_results)
            assert actual_total == expected_total, (
                f"指标 {field}: 各档位之和={actual_total} != 总业绩={expected_total}"
            )

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_each_row_has_correct_level_code(self, scenario):
        """Every result row's level code must come from the source data's set."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]
        month = scenario["month"]
        level_codes = scenario["level_codes"]
        results = simulate_extract_daily_aggregates(daily_rows, site_id=1, month=month)
        assistant_results = [r for r in results if r["assistant_id"] == assistant_id]
        result_codes = {r["assistant_level_code"] for r in assistant_results}
        assert result_codes == level_codes, (
            f"聚合结果的档位集合 {result_codes} != 原始数据的档位集合 {level_codes}"
        )

    @given(scenario=_multi_level_scenario_st())
    @settings(max_examples=100)
    def test_per_level_metrics_match_source(self, scenario):
        """Each level row's metrics must equal the SUM over that level's daily rows."""
        daily_rows = scenario["daily_rows"]
        assistant_id = scenario["assistant_id"]
        month = scenario["month"]
        results = simulate_extract_daily_aggregates(daily_rows, site_id=1, month=month)
        assistant_results = {
            r["assistant_level_code"]: r
            for r in results
            if r["assistant_id"] == assistant_id
        }
        # Compute expected values grouped by level.
        by_level: Dict[str, List[Dict]] = defaultdict(list)
        for r in daily_rows:
            if r["assistant_id"] == assistant_id:
                by_level[r["assistant_level_code"]].append(r)
        for code, rows in by_level.items():
            assert code in assistant_results, (
                f"档位 {code} 在原始数据中存在但聚合结果中缺失"
            )
            agg = assistant_results[code]
            for field in _METRIC_FIELDS:
                expected = sum(Decimal(str(r.get(field, 0))) for r in rows)
                assert agg[field] == expected, (
                    f"档位 {code} 指标 {field}: 聚合值={agg[field]} != 期望={expected}"
                )
# ── Property 4: nickname 按时间倒序取值 ─────────────────────────
# Feature: etl-aggregation-fix, Property 4: nickname 按时间倒序取值
# **Validates: Requirements 2.3, 2.5, 2.6**
#
# 对于任意助教在聚合周期内有多条不同 nickname 的记录,
# 聚合结果中的 nickname 应等于时间最晚的那条记录的 nickname。
# 此属性适用于 AssistantMonthlyTask、AssistantFinanceTask、AssistantCustomerTask。
# ── 模拟函数FinanceTask._extract_daily_revenue() 的 nickname 取值逻辑 ──
#
# 实际 SQL
# SELECT DATE(s.start_use_time) AS stat_date,
# s.site_assistant_id AS assistant_id,
# (ARRAY_AGG(s.nickname ORDER BY s.start_use_time DESC))[1] AS assistant_nickname,
# ...
# GROUP BY DATE(s.start_use_time), s.site_assistant_id
def simulate_finance_daily_revenue(
    service_rows: List[Dict[str, Any]],
    site_id: int,
    start_date: date,
    end_date: date,
) -> List[Dict[str, Any]]:
    """Simulate AssistantFinanceTask._extract_daily_revenue() grouping.

    Rows are grouped by (stat_date, assistant_id); the nickname is taken
    from the row with the latest start_use_time, mirroring the SQL's
    (ARRAY_AGG(nickname ORDER BY start_use_time DESC))[1].
    """
    # Bucket rows that match the site, fall inside [start_date, end_date]
    # (inclusive on both ends) and are not soft-deleted.
    buckets: Dict[Tuple, List[Dict]] = defaultdict(list)
    for row in service_rows:
        if row.get("site_id") != site_id:
            continue
        if row.get("is_delete", 0) != 0:
            continue
        if not (start_date <= row["stat_date"] <= end_date):
            continue
        buckets[(row["stat_date"], row["assistant_id"])].append(row)

    output: List[Dict[str, Any]] = []
    for (day, aid), members in buckets.items():
        # Latest start_use_time wins; max() keeps the first maximal row,
        # matching the stable descending sort of the original.
        newest = max(members, key=lambda r: r["start_use_time"])
        output.append({
            "stat_date": day,
            "assistant_id": aid,
            "assistant_nickname": newest.get("nickname", ""),
        })
    return output
# ── 模拟函数CustomerTask._extract_service_pairs() 的 nickname 取值逻辑 ──
#
# 实际 SQL
# SELECT assistant_id,
# (ARRAY_AGG(assistant_nickname ORDER BY service_date DESC))[1] AS assistant_nickname,
# member_id, ...
# GROUP BY assistant_id, member_id
def simulate_customer_service_pairs(
    service_rows: List[Dict[str, Any]],
    site_id: int,
) -> List[Dict[str, Any]]:
    """Simulate AssistantCustomerTask._extract_service_pairs() grouping.

    Rows are grouped by (assistant_id, member_id); the nickname is taken
    from the row with the latest service_date, mirroring the SQL's
    (ARRAY_AGG(assistant_nickname ORDER BY service_date DESC))[1].
    """
    # Keep rows for the site with a real member (non-null, non-zero) that
    # are not soft-deleted.
    buckets: Dict[Tuple, List[Dict]] = defaultdict(list)
    for row in service_rows:
        if row.get("site_id") != site_id:
            continue
        member = row.get("member_id")
        if member is None or member == 0:
            continue
        if row.get("is_delete", 0) != 0:
            continue
        buckets[(row["assistant_id"], member)].append(row)

    output: List[Dict[str, Any]] = []
    for (aid, mid), members in buckets.items():
        # Latest service_date wins (stable on ties, like the original sort).
        newest = max(members, key=lambda r: r["service_date"])
        output.append({
            "assistant_id": aid,
            "member_id": mid,
            "assistant_nickname": newest.get("assistant_nickname", ""),
        })
    return output
# ── Property 4 generation strategies ──────────────────────────────
# Guarantee several distinct nicknames within one group so the
# "latest wins" rule is actually exercised.
_distinct_nickname_st = st.lists(
    st.text(
        alphabet=st.sampled_from("张李王赵刘陈杨黄周吴徐孙马朱胡林郭何高"),
        min_size=2,
        max_size=4,
    ),
    min_size=2,
    max_size=6,
    unique=True,
)
@st.composite
def _monthly_nickname_scenario_st(draw):
    """Generate a nickname scenario for AssistantMonthlyTask.

    One assistant, one level, several daily rows in the same month with
    distinct nicknames; the expected aggregate nickname is the one from
    the latest stat_date.
    """
    assistant_id = draw(_assistant_id_st)
    level_code = draw(_level_code_st)
    month = date(2025, 1, 1)
    nicknames = draw(_distinct_nickname_st)
    month_start = month.replace(day=1)
    max_day = 31  # January 2025 has 31 days
    # One row per nickname, each on a distinct day so the descending
    # order (and therefore the winner) is deterministic.
    daily_rows: List[Dict[str, Any]] = []
    used_days: Set[int] = set()
    for nick in nicknames:
        day = draw(
            st.integers(min_value=1, max_value=max_day).filter(
                lambda d, _used=frozenset(used_days): d not in _used
            )
        )
        used_days.add(day)
        row = {
            "site_id": 1,
            "assistant_id": assistant_id,
            "assistant_level_code": level_code,
            "assistant_level_name": _LEVEL_NAME_MAP[level_code],
            "assistant_nickname": nick,
            "stat_date": month_start.replace(day=day),
            # Metric values are fixed; this test does not care about them.
            **{f: 1 for f in _METRIC_FIELDS},
        }
        daily_rows.append(row)
    # Expected nickname: the row with the greatest stat_date.
    expected_nickname = max(daily_rows, key=lambda r: r["stat_date"])["assistant_nickname"]
    return {
        "daily_rows": daily_rows,
        "assistant_id": assistant_id,
        "level_code": level_code,
        "month": month,
        "expected_nickname": expected_nickname,
    }
@st.composite
def _finance_nickname_scenario_st(draw):
    """Generate a nickname scenario for AssistantFinanceTask.

    One assistant, one day, several service rows with distinct nicknames
    at distinct start_use_time hours; the expected nickname is the one
    with the latest start_use_time.
    """
    from datetime import datetime
    assistant_id = draw(_assistant_id_st)
    stat_date = date(2025, 1, 15)
    nicknames = draw(_distinct_nickname_st)
    service_rows: List[Dict[str, Any]] = []
    used_hours: Set[int] = set()
    for nick in nicknames:
        # Distinct hours keep start_use_time values unique and the winner
        # deterministic.
        hour = draw(
            st.integers(min_value=0, max_value=23).filter(
                lambda h, _used=frozenset(used_hours): h not in _used
            )
        )
        used_hours.add(hour)
        minute = draw(st.integers(min_value=0, max_value=59))
        service_rows.append({
            "site_id": 1,
            "assistant_id": assistant_id,
            "nickname": nick,
            "stat_date": stat_date,
            "start_use_time": datetime(2025, 1, 15, hour, minute, 0),
            "is_delete": 0,
        })
    expected_nickname = max(
        service_rows, key=lambda r: r["start_use_time"]
    )["nickname"]
    return {
        "service_rows": service_rows,
        "assistant_id": assistant_id,
        "stat_date": stat_date,
        "expected_nickname": expected_nickname,
    }
@st.composite
def _customer_nickname_scenario_st(draw):
    """Generate a nickname scenario for AssistantCustomerTask.

    One assistant/member pair, several service rows with distinct
    nicknames on distinct service_dates; the expected nickname is the
    one with the latest service_date.
    """
    assistant_id = draw(_assistant_id_st)
    member_id = draw(st.integers(min_value=1, max_value=9999))
    nicknames = draw(_distinct_nickname_st)
    service_rows: List[Dict[str, Any]] = []
    used_days: Set[int] = set()
    for nick in nicknames:
        # Distinct days keep service_date values unique and the winner
        # deterministic.
        day = draw(
            st.integers(min_value=1, max_value=28).filter(
                lambda d, _used=frozenset(used_days): d not in _used
            )
        )
        used_days.add(day)
        service_rows.append({
            "site_id": 1,
            "assistant_id": assistant_id,
            "member_id": member_id,
            "assistant_nickname": nick,
            "service_date": date(2025, 1, day),
            "is_delete": 0,
        })
    expected_nickname = max(
        service_rows, key=lambda r: r["service_date"]
    )["assistant_nickname"]
    return {
        "service_rows": service_rows,
        "assistant_id": assistant_id,
        "member_id": member_id,
        "expected_nickname": expected_nickname,
    }
# ── Property 4 测试类 ────────────────────────────────────────
class TestProperty4NicknameDescOrder:
    """Property 4: nickname is taken in descending time order.

    For any assistant with several differently-named rows inside the
    aggregation window, the aggregated nickname must equal the nickname
    of the latest record.
    """

    # ── 4a: AssistantMonthlyTask — nickname by stat_date DESC ──
    @given(scenario=_monthly_nickname_scenario_st())
    @settings(max_examples=200)
    def test_monthly_nickname_equals_latest_stat_date(self, scenario):
        """AssistantMonthlyTask: nickname must match the latest stat_date row."""
        results = simulate_extract_daily_aggregates(
            scenario["daily_rows"], site_id=1, month=scenario["month"],
        )
        # Locate the aggregate row for this assistant + level.
        target = [
            r for r in results
            if r["assistant_id"] == scenario["assistant_id"]
            and r["assistant_level_code"] == scenario["level_code"]
        ]
        assert len(target) == 1, f"期望 1 行聚合结果,实际 {len(target)}"
        assert target[0]["assistant_nickname"] == scenario["expected_nickname"], (
            f"MonthlyTask nickname 应为 '{scenario['expected_nickname']}'"
            f"实际为 '{target[0]['assistant_nickname']}'"
        )

    # ── 4b: AssistantFinanceTask — nickname by start_use_time DESC ──
    @given(scenario=_finance_nickname_scenario_st())
    @settings(max_examples=200)
    def test_finance_nickname_equals_latest_start_use_time(self, scenario):
        """AssistantFinanceTask: nickname must match the latest start_use_time row."""
        results = simulate_finance_daily_revenue(
            scenario["service_rows"],
            site_id=1,
            start_date=scenario["stat_date"],
            end_date=scenario["stat_date"],
        )
        target = [
            r for r in results
            if r["assistant_id"] == scenario["assistant_id"]
            and r["stat_date"] == scenario["stat_date"]
        ]
        assert len(target) == 1, f"期望 1 行聚合结果,实际 {len(target)}"
        assert target[0]["assistant_nickname"] == scenario["expected_nickname"], (
            f"FinanceTask nickname 应为 '{scenario['expected_nickname']}'"
            f"实际为 '{target[0]['assistant_nickname']}'"
        )

    # ── 4c: AssistantCustomerTask — nickname by service_date DESC ──
    @given(scenario=_customer_nickname_scenario_st())
    @settings(max_examples=200)
    def test_customer_nickname_equals_latest_service_date(self, scenario):
        """AssistantCustomerTask: nickname must match the latest service_date row."""
        results = simulate_customer_service_pairs(
            scenario["service_rows"], site_id=1,
        )
        target = [
            r for r in results
            if r["assistant_id"] == scenario["assistant_id"]
            and r["member_id"] == scenario["member_id"]
        ]
        assert len(target) == 1, f"期望 1 行聚合结果,实际 {len(target)}"
        assert target[0]["assistant_nickname"] == scenario["expected_nickname"], (
            f"CustomerTask nickname 应为 '{scenario['expected_nickname']}'"
            f"实际为 '{target[0]['assistant_nickname']}'"
        )
# ── Property 5: 工资按档位分段计算 ─────────────────────────────
# Feature: etl-aggregation-fix, Property 5: 工资按档位分段计算
# **Validates: Requirements 2.4**
#
# 对于任意助教在同一月有多个档位的月度汇总记录,
# AssistantSalaryTask 应为每个档位分别计算工资,
# 每个档位使用对应的 level_price 和 tier 配置,
# 且所有档位的工资记录数等于月度汇总的行数。
# ── 模拟 AssistantSalaryTask.transform() 的工资计算逻辑 ──
#
# 实际流程:
# 1. extract() 从 dws_assistant_monthly_summary 取出多行(同一助教不同档位)
# 2. transform() 遍历每行,调用 _calculate_salary()
# 3. _calculate_salary() 按行的 assistant_level_code 获取 level_price 和 tier
# 4. 每行独立计算工资,生成一条工资记录
#
# 本测试不模拟完整工资公式,只验证"分段"行为:
# - 工资记录数 == 月度汇总行数
# - 每条工资记录使用的是对应档位的 level_price
# 等级定价配置(模拟 cfg_assistant_level_price
_LEVEL_PRICE_CONFIG = {
"A": {"base_course_price": Decimal("98"), "bonus_course_price": Decimal("190")},
"B": {"base_course_price": Decimal("108"), "bonus_course_price": Decimal("190")},
"C": {"base_course_price": Decimal("118"), "bonus_course_price": Decimal("190")},
"D": {"base_course_price": Decimal("128"), "bonus_course_price": Decimal("190")},
"E": {"base_course_price": Decimal("138"), "bonus_course_price": Decimal("190")},
}
# 档位配置(模拟 cfg_performance_tier
_TIER_CONFIG = {
1: {"tier_code": "T1", "tier_name": "1档", "base_deduction": Decimal("18"), "bonus_deduction_ratio": Decimal("0.40")},
2: {"tier_code": "T2", "tier_name": "2档", "base_deduction": Decimal("15"), "bonus_deduction_ratio": Decimal("0.38")},
3: {"tier_code": "T3", "tier_name": "3档", "base_deduction": Decimal("13"), "bonus_deduction_ratio": Decimal("0.35")},
}
def _simulate_get_level_price(level_code: str) -> Dict[str, Any]:
"""模拟 get_level_price():按档位代码返回等级定价。"""
return _LEVEL_PRICE_CONFIG.get(level_code, _LEVEL_PRICE_CONFIG["A"])
def _simulate_get_tier(tier_id: int) -> Dict[str, Any]:
"""模拟 get_performance_tier_by_id():按 tier_id 返回档位配置。"""
return _TIER_CONFIG.get(tier_id, _TIER_CONFIG[1])
def simulate_salary_transform(
    monthly_summary: List[Dict[str, Any]],
    site_id: int,
    salary_month: date,
) -> List[Dict[str, Any]]:
    """Simulate the core of AssistantSalaryTask.transform().

    Each monthly-summary row is priced independently with the level_price
    and tier config matching that row, so an assistant holding several
    levels in one month yields one salary record per level.

    Deliberately simplified: only course-hour income is computed, which is
    enough to verify the per-level "segmentation" behaviour.
    """
    salary_records: List[Dict[str, Any]] = []
    one = Decimal("1")
    for row in monthly_summary:
        level_code = row.get("assistant_level_code")
        tier_id = row.get("tier_id", 1)
        # Resolve pricing and tier configuration for this specific row.
        price_cfg = _simulate_get_level_price(level_code)
        tier_cfg = _simulate_get_tier(tier_id)
        base_hours = Decimal(str(row.get("base_hours", 0)))
        bonus_hours = Decimal(str(row.get("bonus_hours", 0)))
        base_price = price_cfg["base_course_price"]
        bonus_price = price_cfg["bonus_course_price"]
        base_deduction = tier_cfg["base_deduction"]
        bonus_ratio = tier_cfg["bonus_deduction_ratio"]
        salary_records.append({
            "site_id": site_id,
            "assistant_id": row["assistant_id"],
            "salary_month": salary_month,
            "assistant_level_code": level_code,
            "base_course_price": base_price,
            "bonus_course_price": bonus_price,
            "base_deduction": base_deduction,
            "bonus_deduction_ratio": bonus_ratio,
            "tier_id": tier_id,
            # base income = base hours x (customer price - per-hour deduction)
            "base_income": base_hours * (base_price - base_deduction),
            # bonus income = bonus hours x bonus price x (1 - deduction ratio)
            "bonus_income": bonus_hours * bonus_price * (one - bonus_ratio),
        })
    return salary_records
# ── Property 5 generation strategies ─────────────────────────
# Tier ids (the three configured performance tiers).
_tier_id_st = st.sampled_from([1, 2, 3])
# Hour counts (Decimal, 2 decimal places, bounded and finite).
_hours_st = st.decimals(
    min_value=0,
    max_value=Decimal("300"),
    places=2,
    allow_nan=False,
    allow_infinity=False,
)


@st.composite
def _salary_multi_level_scenario_st(draw):
    """Generate monthly-summary rows for one assistant holding several levels.

    Returns a dict with:
        - monthly_summary: summary rows (one per level segment)
        - assistant_id: assistant id shared by all rows
        - salary_month: salary month (fixed at 2025-01-01)
        - level_codes: the distinct level codes drawn
    """
    assistant_id = draw(_assistant_id_st)
    salary_month = date(2025, 1, 1)
    # Draw 2~4 distinct levels (at least 2 so "segmentation" is observable).
    n_levels = draw(st.integers(min_value=2, max_value=4))
    level_codes = draw(
        st.lists(
            _level_code_st,
            min_size=n_levels,
            max_size=n_levels,
            unique=True,
        )
    )
    monthly_summary: List[Dict[str, Any]] = []
    for code in level_codes:
        tier_id = draw(_tier_id_st)
        monthly_summary.append({
            "assistant_id": assistant_id,
            "assistant_nickname": f"助教{assistant_id}",
            "stat_month": salary_month,
            "assistant_level_code": code,
            "assistant_level_name": _LEVEL_NAME_MAP[code],
            "hire_date": date(2024, 6, 1),
            "is_new_hire": False,
            "effective_hours": float(draw(_hours_st)),
            "base_hours": float(draw(_hours_st)),
            "bonus_hours": float(draw(_hours_st)),
            "room_hours": float(draw(_hours_st)),
            "tier_id": tier_id,
            "tier_code": f"T{tier_id}",
            "tier_name": f"{tier_id}档",
            "rank_with_ties": draw(st.integers(min_value=1, max_value=10)),
        })
    return {
        "monthly_summary": monthly_summary,
        "assistant_id": assistant_id,
        "salary_month": salary_month,
        "level_codes": level_codes,
    }
# ── Property 5 test class ────────────────────────────────────
class TestProperty5SalaryPerLevelSegment:
    """Property 5: salary is computed per level segment.

    For any assistant whose monthly summary contains several level segments
    in the same month, AssistantSalaryTask must produce one salary record
    per segment, each using that segment's level_price and tier config,
    and the number of salary records must equal the summary row count.
    """

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_salary_record_count_equals_summary_rows(self, scenario):
        """Salary record count equals monthly-summary row count (one per level)."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )
        assert len(results) == len(scenario["monthly_summary"]), (
            f"月度汇总 {len(scenario['monthly_summary'])} 行,"
            f"但工资记录 {len(results)}"
        )

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_each_salary_uses_correct_level_price(self, scenario):
        """Every salary record uses the level_price of its own level code."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )
        for record in results:
            level_code = record["assistant_level_code"]
            expected_price = _LEVEL_PRICE_CONFIG[level_code]
            assert record["base_course_price"] == expected_price["base_course_price"], (
                f"档位 {level_code}: base_course_price 应为 {expected_price['base_course_price']}"
                f"实际为 {record['base_course_price']}"
            )
            assert record["bonus_course_price"] == expected_price["bonus_course_price"], (
                f"档位 {level_code}: bonus_course_price 应为 {expected_price['bonus_course_price']}"
                f"实际为 {record['bonus_course_price']}"
            )

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_each_salary_uses_correct_tier_config(self, scenario):
        """Every salary record uses the tier config matching its row's tier_id."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )
        # Rows and records are index-aligned: the transform preserves input order.
        for i, record in enumerate(results):
            tier_id = scenario["monthly_summary"][i]["tier_id"]
            expected_tier = _TIER_CONFIG[tier_id]
            assert record["base_deduction"] == expected_tier["base_deduction"], (
                f"tier_id={tier_id}: base_deduction 应为 {expected_tier['base_deduction']}"
                f"实际为 {record['base_deduction']}"
            )
            assert record["bonus_deduction_ratio"] == expected_tier["bonus_deduction_ratio"], (
                f"tier_id={tier_id}: bonus_deduction_ratio 应为 {expected_tier['bonus_deduction_ratio']}"
                f"实际为 {record['bonus_deduction_ratio']}"
            )

    @given(scenario=_salary_multi_level_scenario_st())
    @settings(max_examples=200)
    def test_salary_level_codes_match_summary(self, scenario):
        """The set of level codes in salary records matches the summary's set."""
        results = simulate_salary_transform(
            scenario["monthly_summary"],
            site_id=1,
            salary_month=scenario["salary_month"],
        )
        result_codes = {r["assistant_level_code"] for r in results}
        expected_codes = set(scenario["level_codes"])
        assert result_codes == expected_codes, (
            f"工资记录档位 {result_codes} != 月度汇总档位 {expected_codes}"
        )

View File

@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
"""跨店会员可查属性测试 — 使用 hypothesis 验证多门店场景下会员维度信息的可达性。
纯单元测试,使用 FakeDBOperations 模拟数据库查询,不涉及真实数据库连接。
# Feature: etl-aggregation-fix, Property 6: 跨店会员可查
# **Validates: Requirements 3.1, 3.2**
"""
from __future__ import annotations
import re
from typing import Any, Dict, List, Set, Tuple
from hypothesis import given, settings
from hypothesis import strategies as st
# ── 核心逻辑:模拟"通过事实表反查 dim_member"的查询行为 ──────────
#
# 实际 SQL 模式:
# SELECT ... FROM dwd.dim_member
# WHERE member_id IN (
# SELECT DISTINCT tenant_member_id
# FROM dwd.{fact_table}
# WHERE site_id = %s AND tenant_member_id IS NOT NULL AND tenant_member_id != 0
# ) AND scd2_is_current = 1
#
# 属性要验证的是:无论会员在哪个门店注册,只要在目标门店有消费记录(事实表中存在),
# 该查询模式就能返回其维度信息。
def simulate_fact_table_lookup(
    fact_rows: List[Dict[str, Any]],
    dim_members: List[Dict[str, Any]],
    query_site_id: int,
) -> Dict[int, Dict[str, Any]]:
    """Simulate how a DWS task resolves dim_member rows via the fact table.

    Equivalent SQL:
        SELECT * FROM dim_member
        WHERE member_id IN (
            SELECT DISTINCT tenant_member_id FROM fact_table
            WHERE site_id = :query_site_id
              AND tenant_member_id IS NOT NULL AND tenant_member_id != 0
        ) AND scd2_is_current = 1
    """
    def _is_candidate(row: Dict[str, Any]) -> bool:
        # A fact row counts only when it belongs to the queried site and
        # carries a real (non-NULL, non-zero) member id.
        mid = row.get("tenant_member_id")
        return row.get("site_id") == query_site_id and mid is not None and mid != 0

    # Step 1: collect the member ids active at the queried site.
    visible_ids: Set[int] = {
        row["tenant_member_id"] for row in fact_rows if _is_candidate(row)
    }
    # Step 2: keep only the current SCD2 versions of those members.
    return {
        member["member_id"]: dict(member)
        for member in dim_members
        if member.get("scd2_is_current") == 1 and member["member_id"] in visible_ids
    }
def simulate_register_site_lookup(
    dim_members: List[Dict[str, Any]],
    query_site_id: int,
) -> Dict[int, Dict[str, Any]]:
    """Simulate the legacy ``WHERE register_site_id = %s`` pattern (for contrast)."""
    matched: Dict[int, Dict[str, Any]] = {}
    for member in dim_members:
        # Only current SCD2 versions registered at the queried site qualify.
        if member.get("scd2_is_current") != 1:
            continue
        if member.get("register_site_id") != query_site_id:
            continue
        matched[member["member_id"]] = dict(member)
    return matched
# ── Generation strategies ────────────────────────────────────
# Store ids (values 1..5; scenarios span 2~5 stores).
_site_id_st = st.integers(min_value=1, max_value=5)
# Member ids (positive integers).
_member_id_st = st.integers(min_value=1, max_value=200)
# Nickname / mobile number.
_nickname_st = st.text(
    alphabet=st.sampled_from("张李王赵刘陈杨黄周吴"),
    min_size=2,
    max_size=4,
)
_mobile_st = st.from_regex(r"1[3-9]\d{9}", fullmatch=True)


# dim_member row strategy.
@st.composite
def _dim_member_st(draw):
    """Generate one dim_member row (always the current SCD2 version)."""
    mid = draw(_member_id_st)
    return {
        "member_id": mid,
        "register_site_id": draw(_site_id_st),
        "nickname": draw(_nickname_st),
        "mobile": draw(_mobile_st),
        "scd2_is_current": 1,  # only current versions are generated
    }


# Fact-table row strategy (a member's consumption record at some store).
@st.composite
def _fact_row_st(draw, member_ids):
    """Generate one fact row; the member id is drawn from the given pool.

    Falls back to a fresh id when the pool is empty (callers in this file
    always pass non-empty pools).
    """
    mid = draw(st.sampled_from(member_ids)) if member_ids else draw(_member_id_st)
    return {
        "site_id": draw(_site_id_st),
        "tenant_member_id": mid,
    }
# Full cross-store scenario strategy.
@st.composite
def _cross_store_scenario_st(draw):
    """Generate a complete scenario containing a cross-store consumer.

    Guarantees at least one member with a consumption record at a store
    other than their registration store (whenever another store exists).
    """
    # 2~10 members, unique by member_id.
    members = draw(st.lists(_dim_member_st(), min_size=2, max_size=10, unique_by=lambda m: m["member_id"]))
    if not members:
        # Defensive only: unreachable given min_size=2 above.
        return None
    member_ids = [m["member_id"] for m in members]
    # 5~30 fact rows drawn from the member pool.
    fact_rows = draw(st.lists(
        _fact_row_st(member_ids),
        min_size=5,
        max_size=30,
    ))
    # Force at least one cross-store consumption record:
    # pick a member and add a fact row at a non-registration store.
    cross_member = draw(st.sampled_from(members))
    other_sites = [s for s in range(1, 6) if s != cross_member["register_site_id"]]
    if other_sites:
        cross_site = draw(st.sampled_from(other_sites))
        fact_rows.append({
            "site_id": cross_site,
            "tenant_member_id": cross_member["member_id"],
        })
    else:
        # No alternative store exists; degrade to a same-store scenario
        # (tests that require a true cross-store case skip it).
        cross_site = cross_member["register_site_id"]
    query_site_id = cross_site
    return {
        "members": members,
        "fact_rows": fact_rows,
        "query_site_id": query_site_id,
        "cross_member_id": cross_member["member_id"],
        "cross_member_register_site": cross_member["register_site_id"],
    }
# ── Property 6: cross-store members are queryable ────────────
# Feature: etl-aggregation-fix, Property 6: 跨店会员可查
# **Validates: Requirements 3.1, 3.2**
#
# For any member registered at store A but with consumption records at
# store B, store B's DWS tasks must be able to resolve the member's
# dimension info via the fact-table reverse lookup.
class TestProperty6CrossStoreMemberVisible:
    """Property 6: cross-store members are queryable."""

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=200)
    def test_cross_store_member_found_via_fact_lookup(self, scenario):
        """A member registered at store A but consuming at store B is found by B."""
        if scenario is None:
            return
        members = scenario["members"]
        fact_rows = scenario["fact_rows"]
        query_site_id = scenario["query_site_id"]
        cross_member_id = scenario["cross_member_id"]
        result = simulate_fact_table_lookup(fact_rows, members, query_site_id)
        # Core assertion: the cross-store member must be resolvable.
        assert cross_member_id in result, (
            f"会员 {cross_member_id}(注册于门店 {scenario['cross_member_register_site']}"
            f"在门店 {query_site_id} 有消费记录,但事实表反查未找到"
        )
        # Dimension-info completeness.
        member_info = result[cross_member_id]
        assert "nickname" in member_info, "缺少 nickname 维度信息"
        assert "mobile" in member_info, "缺少 mobile 维度信息"

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=200)
    def test_old_register_site_pattern_misses_cross_store(self, scenario):
        """The legacy WHERE register_site_id = %s pattern misses cross-store members
        (negative/contrast check)."""
        if scenario is None:
            return
        members = scenario["members"]
        query_site_id = scenario["query_site_id"]
        cross_member_id = scenario["cross_member_id"]
        cross_register_site = scenario["cross_member_register_site"]
        # Only meaningful when the registration store differs from the queried one.
        if cross_register_site == query_site_id:
            return  # not a cross-store scenario; skip
        old_result = simulate_register_site_lookup(members, query_site_id)
        # The legacy pattern must NOT find the cross-store member.
        assert cross_member_id not in old_result, (
            f"旧模式 register_site_id={query_site_id} 不应找到"
            f"注册于门店 {cross_register_site} 的会员 {cross_member_id}"
        )

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=200)
    def test_fact_lookup_superset_of_register_lookup(self, scenario):
        """Fact-table lookup results are a superset of register-site results.

        I.e. any member the legacy pattern finds who also has fact rows at
        the queried store must be found by the new pattern too.
        """
        if scenario is None:
            return
        members = scenario["members"]
        fact_rows = scenario["fact_rows"]
        query_site_id = scenario["query_site_id"]
        new_result = simulate_fact_table_lookup(fact_rows, members, query_site_id)
        old_result = simulate_register_site_lookup(members, query_site_id)
        # Member ids with valid fact rows at the queried store.
        fact_member_ids = {
            r["tenant_member_id"]
            for r in fact_rows
            if r["site_id"] == query_site_id
            and r["tenant_member_id"] is not None
            and r["tenant_member_id"] != 0
        }
        for mid in old_result:
            if mid in fact_member_ids:
                assert mid in new_result, (
                    f"会员 {mid} 在旧模式中可查且在事实表中有记录,"
                    f"但新模式未返回"
                )

    @given(
        members=st.lists(_dim_member_st(), min_size=1, max_size=10, unique_by=lambda m: m["member_id"]),
        query_site_id=_site_id_st,
    )
    @settings(max_examples=100)
    def test_no_fact_records_returns_empty(self, members, query_site_id):
        """With no fact rows for the store, the reverse lookup returns nothing."""
        # Empty fact table.
        result = simulate_fact_table_lookup([], members, query_site_id)
        assert len(result) == 0, (
            f"事实表为空时,门店 {query_site_id} 的反查结果应为空,"
            f"实际返回 {len(result)}"
        )

    @given(scenario=_cross_store_scenario_st())
    @settings(max_examples=100)
    def test_null_and_zero_member_ids_excluded(self, scenario):
        """Fact rows whose tenant_member_id is NULL or 0 never join to members."""
        if scenario is None:
            return
        members = scenario["members"]
        query_site_id = scenario["query_site_id"]
        # Fact table containing only NULL / 0 member ids.
        bad_fact_rows = [
            {"site_id": query_site_id, "tenant_member_id": None},
            {"site_id": query_site_id, "tenant_member_id": 0},
        ]
        result = simulate_fact_table_lookup(bad_fact_rows, members, query_site_id)
        assert len(result) == 0, (
            f"tenant_member_id 为 NULL/0 时不应返回任何会员,"
            f"实际返回 {len(result)}"
        )

View File

@@ -0,0 +1,270 @@
# -*- coding: utf-8 -*-
"""ODS member_profiles birthday 字段提取验证。
任务 3.2: 确认 ODS 入库逻辑能从 JSON payload 中提取 birthday 字段。
ODS_MEMBER 任务使用 schema-aware 动态入库_insert_records_schema_aware
当 ods.member_profiles 表包含 birthday 列时,会自动从 API JSON 中按列名匹配提取。
本测试通过 FakeDB 模拟包含 birthday 列的表结构,验证提取行为。
"""
from __future__ import annotations
import logging
import os
import sys
from pathlib import Path
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
from tasks.ods.ods_tasks import ODS_TASK_CLASSES, ODS_TASK_SPECS, BaseOdsTask
from .task_test_utils import (
create_test_config,
FakeDBOperations,
FakeAPIClient,
FakeCursor,
FakeConnection,
)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_MEMBER_CODE = "ODS_MEMBER"  # ODS task code under test
_MEMBER_TABLE = "ods.member_profiles"  # target ODS table
# NOTE(review): _MEMBER_ENDPOINT is not referenced in this chunk; the tests
# below use spec.endpoint instead — confirm whether this constant is needed.
_MEMBER_ENDPOINT = "/MemberProfile/GetTenantMemberList"
# member_profiles column definitions (incl. the birthday DATE column — state
# after migration C1). Each tuple presumably mirrors an
# information_schema.columns row (name, data_type, udt_name) — verify against
# FakeCursor._fake_columns.
_MEMBER_COLUMNS_WITH_BIRTHDAY = [
    ("id", "bigint", "int8"),
    ("record_index", "integer", "int4"),
    ("content_hash", "text", "text"),
    ("payload", "jsonb", "jsonb"),
    ("birthday", "date", "date"),
]
# ---------------------------------------------------------------------------
# FakeDB (supports custom column definitions)
# ---------------------------------------------------------------------------
class _MemberFakeCursor(FakeCursor):
    """FakeCursor extension serving custom column definitions and PK metadata.

    Intercepts the information_schema queries issued by schema-aware ODS
    ingestion so tests can pretend the target table has (or lacks) specific
    columns such as ``birthday``.
    """

    def __init__(self, recorder, db_ops=None, columns_map=None, pk_map=None):
        super().__init__(recorder, db_ops)
        # table name (bare or schema-qualified) -> list of column-definition rows
        self._columns_map = columns_map or {}
        # table name (bare or schema-qualified) -> list of PK column names
        self._pk_map = pk_map or {}

    def execute(self, sql, params=None):
        """Record the statement, then emulate the queries the ODS task issues.

        Dispatch order matters: schema-metadata queries are answered from the
        maps, content-hash probes return empty, everything else is treated as
        a write over the rows staged in ``self._pending_rows``.
        """
        sql_text = sql.decode("utf-8", errors="ignore") if isinstance(sql, (bytes, bytearray)) else str(sql)
        self.recorder.append({"sql": sql_text.strip(), "params": params})
        self._fetchall_rows = []
        lowered = sql_text.lower()
        # Intercept information_schema.columns -> serve custom column defs.
        if "from information_schema.columns" in lowered:
            table_name = None
            if params and len(params) >= 2:
                table_name = params[1]
            # params assumed to be (schema, table) — matches the query shape.
            full_table = f"{params[0]}.{params[1]}" if params and len(params) >= 2 else None
            if full_table and full_table in self._columns_map:
                self._fetchall_rows = list(self._columns_map[full_table])
            elif table_name and table_name in self._columns_map:
                self._fetchall_rows = list(self._columns_map[table_name])
            else:
                # Fall back to the base class's generic column set.
                self._fetchall_rows = self._fake_columns(table_name)
            return
        # Intercept PK-constraint lookups.
        if "from information_schema.table_constraints" in lowered:
            table_name = None
            if params and len(params) >= 2:
                table_name = params[1]
            full_table = f"{params[0]}.{params[1]}" if params and len(params) >= 2 else None
            if full_table and full_table in self._pk_map:
                self._fetchall_rows = [(col,) for col in self._pk_map[full_table]]
            elif table_name and table_name in self._pk_map:
                self._fetchall_rows = [(col,) for col in self._pk_map[table_name]]
            else:
                # Default primary key when nothing is configured.
                self._fetchall_rows = [("id",)]
            return
        # Intercept content_hash dedup probes -> pretend nothing exists yet.
        if "distinct on" in lowered and "content_hash" in lowered:
            self._fetchall_rows = []
            self._pending_rows = []
            return
        # Default: treat as an INSERT/UPDATE over the staged rows.
        if self._pending_rows:
            self.rowcount = len(self._pending_rows)
            self._record_upserts(sql_text)
            if "returning" in lowered:
                # One truthy row per staged row for RETURNING queries.
                self._fetchall_rows = [(True,)] * len(self._pending_rows)
            self._pending_rows = []
        else:
            self.rowcount = 0
class _MemberFakeConnection(FakeConnection):
    """FakeConnection whose cursors serve the custom schema metadata."""

    def __init__(self, db_ops, columns_map=None, pk_map=None):
        super().__init__(db_ops)
        self._columns_map = {} if columns_map is None else columns_map
        self._pk_map = {} if pk_map is None else pk_map

    def cursor(self):
        """Return a schema-aware fake cursor bound to this connection's recorder."""
        return _MemberFakeCursor(
            self.statements,
            self._db_ops,
            columns_map=self._columns_map,
            pk_map=self._pk_map,
        )
class _MemberFakeDB(FakeDBOperations):
    """FakeDBOperations whose connection serves custom column/PK metadata."""

    def __init__(self, columns_map=None, pk_map=None):
        super().__init__()
        self._columns_map = columns_map if columns_map else {}
        self._pk_map = pk_map if pk_map else {}
        # Replace the default connection with the schema-aware fake.
        self.conn = _MemberFakeConnection(self, self._columns_map, self._pk_map)
# ---------------------------------------------------------------------------
# 辅助
# ---------------------------------------------------------------------------
def _build_config(tmp_path):
    """Build a test config: update conflict mode, snapshot deletion disabled."""
    cfg = create_test_config("ONLINE", tmp_path / "archive", tmp_path / "temp")
    # setdefault returns the (possibly new) "run" section dict.
    run_cfg = cfg.config.setdefault("run", {})
    run_cfg["ods_conflict_mode"] = "update"
    run_cfg["snapshot_missing_delete"] = False
    return cfg
def _get_member_spec():
    """Return the OdsTaskSpec registered under the ODS_MEMBER code."""
    matches = (spec for spec in ODS_TASK_SPECS if spec.code == _MEMBER_CODE)
    spec = next(matches, None)
    if spec is None:
        raise KeyError(f"未找到任务 spec: {_MEMBER_CODE}")
    return spec
# ---------------------------------------------------------------------------
# 测试
# ---------------------------------------------------------------------------
class TestOdsBirthdayExtraction:
    """Verify ODS_MEMBER extracts ``birthday`` when the table has that column."""

    def test_birthday_extracted_from_json_payload(self, tmp_path):
        """When ods.member_profiles has a birthday column, the birthday value
        in the API JSON should be extracted into that column position.
        """
        spec = _get_member_spec()
        task_cls = ODS_TASK_CLASSES[_MEMBER_CODE]
        config = _build_config(tmp_path)
        db = _MemberFakeDB(
            columns_map={_MEMBER_TABLE: _MEMBER_COLUMNS_WITH_BIRTHDAY},
            pk_map={_MEMBER_TABLE: ["id", "content_hash"]},
        )
        # Simulated API response: one member record carrying a birthday.
        member_record = {
            "id": 12345,
            "birthday": "1990-05-15",
            "nickname": "测试会员",
            "mobile": "13800138000",
        }
        api = FakeAPIClient({spec.endpoint: [member_record]})
        logger = logging.getLogger("test_birthday")
        task = task_cls(config, db, api, logger)
        result = task.execute()
        counts = result["counts"]
        assert counts["fetched"] == 1
        assert counts["inserted"] == 1
        # The recorded INSERTs must target member_profiles and carry birthday.
        insert_stmts = [
            s for s in db.conn.statements
            if "INSERT" in s.get("sql", "").upper() and "member_profiles" in s.get("sql", "")
        ]
        assert len(insert_stmts) >= 1, "应有至少一条 INSERT 语句"
        # The INSERT SQL should name the birthday column.
        assert any("birthday" in s["sql"].lower() for s in insert_stmts), (
            "INSERT SQL 应包含 birthday 列"
        )

    def test_birthday_null_when_not_in_json(self, tmp_path):
        """When the API JSON lacks a birthday field, ingestion still succeeds
        (the column is expected to be written as None).
        """
        spec = _get_member_spec()
        task_cls = ODS_TASK_CLASSES[_MEMBER_CODE]
        config = _build_config(tmp_path)
        db = _MemberFakeDB(
            columns_map={_MEMBER_TABLE: _MEMBER_COLUMNS_WITH_BIRTHDAY},
            pk_map={_MEMBER_TABLE: ["id", "content_hash"]},
        )
        # API record without a birthday field.
        member_record = {
            "id": 67890,
            "nickname": "无生日会员",
        }
        api = FakeAPIClient({spec.endpoint: [member_record]})
        logger = logging.getLogger("test_birthday_null")
        task = task_cls(config, db, api, logger)
        result = task.execute()
        counts = result["counts"]
        assert counts["fetched"] == 1
        assert counts["inserted"] == 1

    def test_birthday_column_in_insert_sql(self, tmp_path):
        """The INSERT SQL must include the birthday column name."""
        spec = _get_member_spec()
        task_cls = ODS_TASK_CLASSES[_MEMBER_CODE]
        config = _build_config(tmp_path)
        db = _MemberFakeDB(
            columns_map={_MEMBER_TABLE: _MEMBER_COLUMNS_WITH_BIRTHDAY},
            pk_map={_MEMBER_TABLE: ["id", "content_hash"]},
        )
        member_record = {"id": 11111, "birthday": "2000-01-01"}
        api = FakeAPIClient({spec.endpoint: [member_record]})
        logger = logging.getLogger("test_birthday_sql")
        task = task_cls(config, db, api, logger)
        task.execute()
        # Locate INSERT statements and confirm they include the birthday column.
        insert_stmts = [
            s for s in db.conn.statements
            if "INSERT" in s.get("sql", "").upper() and "member_profiles" in s.get("sql", "")
        ]
        assert any("birthday" in s["sql"].lower() for s in insert_stmts), (
            "INSERT SQL 应包含 birthday 列,"
            f"实际 SQL: {[s['sql'][:200] for s in insert_stmts]}"
        )

    def test_ods_member_spec_has_no_extra_columns(self):
        """ODS_MEMBER must have no extra_columns (schema-aware dynamic mapping).

        This means birthday needs no explicit OdsTaskSpec declaration: as
        soon as the DB table has the column, it is extracted from JSON.
        """
        spec = _get_member_spec()
        assert spec.extra_columns == (), (
            "ODS_MEMBER 不应有 extra_columns"
            "它依赖 schema-aware 动态入库自动映射所有 DB 列"
        )
        assert spec.table_name == "ods.member_profiles"
        assert spec.code == "ODS_MEMBER"

View File

@@ -126,36 +126,5 @@ def test_ods_settlement_records_ingest(tmp_path):
assert '"orderTradeNo": 8001' in row["payload"]
def test_ods_settlement_ticket_by_payment_relate_ids(tmp_path):
    """Ensure settlement tickets are fetched per payment relate_id and skip existing ones."""
    config = _build_config(tmp_path)
    ticket_payload = {"data": {"data": {"orderSettleId": 9001, "orderSettleNumber": "T001"}}}
    api = FakeAPIClient({"/Order/GetOrderSettleTicketNew": [ticket_payload]})
    task_cls = ODS_TASK_CLASSES["ODS_SETTLEMENT_TICKET"]
    with get_db_operations() as db_ops:
        # First query result: ticket ids already ingested (9002 should be
        # skipped); second: payment relate ids (incl. a None to be ignored).
        db_ops.query_results = [
            [{"order_settle_id": 9002}],
            [
                {"order_settle_id": 9001},
                {"order_settle_id": 9002},
                {"order_settle_id": None},
            ],
        ]
        task = task_cls(config, db_ops, api, logging.getLogger("test_ods_settlement_ticket"))
        result = task.execute()
    assert result["status"] == "SUCCESS"
    counts = result["counts"]
    # Only 9001 is new: one fetch, one insert, nothing updated or skipped.
    assert counts["fetched"] == 1
    assert counts["inserted"] == 1
    assert counts["updated"] == 0
    assert counts["skipped"] == 0
    assert '"orderSettleId": 9001' in db_ops.upserts[0]["rows"][0]["payload"]
    # The ticket endpoint must have been called with the new relate id.
    assert any(
        call["endpoint"] == "/Order/GetOrderSettleTicketNew"
        and call.get("params", {}).get("orderSettleId") == 9001
        for call in api.calls
    )

View File

@@ -93,27 +93,27 @@ class TestProperty5FlowNameToLayers:
"""对于任意有效的 Flow 名称FlowRunner 解析出的层列表应与
FLOW_LAYERS 字典中的定义完全一致。"""
@given(pipeline=flow_name_st)
@given(flow_name=flow_name_st)
@settings(max_examples=100)
def test_layers_match_flow_definition(self, pipeline):
"""run() 返回的 layers 字段与 FLOW_LAYERS[pipeline] 完全一致。"""
def test_layers_match_flow_definition(self, flow_name):
"""run() 返回的 layers 字段与 FLOW_LAYERS[flow_name] 完全一致。"""
executor = MagicMock()
executor.run_tasks.return_value = []
runner = _make_runner(task_executor=executor)
with patch(_TASK_LOGGER_PATH):
result = runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode="increment_only",
data_source="offline",
)
expected_layers = FlowRunner.FLOW_LAYERS[pipeline]
expected_layers = FlowRunner.FLOW_LAYERS[flow_name]
assert result["layers"] == expected_layers
@given(pipeline=flow_name_st)
@given(flow_name=flow_name_st)
@settings(max_examples=100)
def test_resolve_tasks_called_with_correct_layers(self, pipeline):
def test_resolve_tasks_called_with_correct_layers(self, flow_name):
"""_resolve_tasks 接收的层列表与 FLOW_LAYERS 定义一致。"""
executor = MagicMock()
executor.run_tasks.return_value = []
@@ -124,12 +124,12 @@ class TestProperty5FlowNameToLayers:
patch.object(runner, "_resolve_tasks", wraps=runner._resolve_tasks) as spy,
):
runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode="increment_only",
data_source="offline",
)
expected_layers = FlowRunner.FLOW_LAYERS[pipeline]
expected_layers = FlowRunner.FLOW_LAYERS[flow_name]
spy.assert_called_once_with(expected_layers)
@@ -143,12 +143,12 @@ class TestProperty6ProcessingModeControlsFlow:
校验流程执行当且仅当模式包含 verify。"""
@given(
pipeline=flow_name_st,
flow_name=flow_name_st,
mode=processing_mode_st,
data_source=data_source_st,
)
@settings(max_examples=100)
def test_increment_executes_iff_mode_contains_increment(self, pipeline, mode, data_source):
def test_increment_executes_iff_mode_contains_increment(self, flow_name, mode, data_source):
"""增量 ETLtask_executor.run_tasks执行当且仅当 mode 包含 'increment'"""
executor = MagicMock()
executor.run_tasks.return_value = []
@@ -159,7 +159,7 @@ class TestProperty6ProcessingModeControlsFlow:
patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}),
):
runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode=mode,
data_source=data_source,
)
@@ -176,12 +176,12 @@ class TestProperty6ProcessingModeControlsFlow:
)
@given(
pipeline=flow_name_st,
flow_name=flow_name_st,
mode=processing_mode_st,
data_source=data_source_st,
)
@settings(max_examples=100)
def test_verification_executes_iff_mode_contains_verify(self, pipeline, mode, data_source):
def test_verification_executes_iff_mode_contains_verify(self, flow_name, mode, data_source):
"""校验流程_run_verification执行当且仅当 mode 包含 'verify'"""
executor = MagicMock()
executor.run_tasks.return_value = []
@@ -192,7 +192,7 @@ class TestProperty6ProcessingModeControlsFlow:
patch.object(runner, "_run_verification", return_value={"status": "COMPLETED"}) as mock_verify,
):
runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode=mode,
data_source=data_source,
)
@@ -215,38 +215,38 @@ class TestProperty6ProcessingModeControlsFlow:
class TestProperty7FlowSummaryCompleteness:
"""对于任意一组任务执行结果FlowRunner 返回的汇总字典应包含
status/pipeline/layers/results 字段,且 results 长度等于实际执行的任务数。
(返回字典中 pipeline 键名保留以兼容下游消费方)"""
status/flow/layers/results 字段,且 results 长度等于实际执行的任务数。
"""
@given(
pipeline=flow_name_st,
flow_name=flow_name_st,
task_results=task_results_st,
)
@settings(max_examples=100)
def test_summary_has_required_fields(self, pipeline, task_results):
"""返回字典必须包含 status、pipeline、layers、results、verification_summary。"""
def test_summary_has_required_fields(self, flow_name, task_results):
"""返回字典必须包含 status、flow、layers、results、verification_summary。"""
executor = MagicMock()
executor.run_tasks.return_value = task_results
runner = _make_runner(task_executor=executor)
with patch(_TASK_LOGGER_PATH):
result = runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode="increment_only",
data_source="offline",
)
required_keys = {"status", "pipeline", "layers", "results", "verification_summary"}
required_keys = {"status", "flow", "layers", "results", "verification_summary"}
assert required_keys.issubset(result.keys()), (
f"缺少必要字段: {required_keys - result.keys()}"
)
@given(
pipeline=flow_name_st,
flow_name=flow_name_st,
task_results=task_results_st,
)
@settings(max_examples=100)
def test_results_length_equals_executed_tasks(self, pipeline, task_results):
def test_results_length_equals_executed_tasks(self, flow_name, task_results):
"""results 列表长度等于 task_executor.run_tasks 返回的任务数。"""
executor = MagicMock()
executor.run_tasks.return_value = task_results
@@ -254,7 +254,7 @@ class TestProperty7FlowSummaryCompleteness:
with patch(_TASK_LOGGER_PATH):
result = runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode="increment_only",
data_source="offline",
)
@@ -264,11 +264,11 @@ class TestProperty7FlowSummaryCompleteness:
)
@given(
pipeline=flow_name_st,
flow_name=flow_name_st,
task_results=task_results_st,
)
@settings(max_examples=100)
def test_flow_and_layers_match_input(self, pipeline, task_results):
def test_flow_and_layers_match_input(self, flow_name, task_results):
"""返回的 flow 标识和 layers 字段与输入一致。"""
executor = MagicMock()
executor.run_tasks.return_value = task_results
@@ -276,20 +276,20 @@ class TestProperty7FlowSummaryCompleteness:
with patch(_TASK_LOGGER_PATH):
result = runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode="increment_only",
data_source="offline",
)
assert result["pipeline"] == pipeline
assert result["layers"] == FlowRunner.FLOW_LAYERS[pipeline]
assert result["flow"] == flow_name
assert result["layers"] == FlowRunner.FLOW_LAYERS[flow_name]
@given(
pipeline=flow_name_st,
flow_name=flow_name_st,
task_results=task_results_st,
)
@settings(max_examples=100)
def test_increment_only_has_no_verification(self, pipeline, task_results):
def test_increment_only_has_no_verification(self, flow_name, task_results):
"""increment_only 模式下 verification_summary 应为 None。"""
executor = MagicMock()
executor.run_tasks.return_value = task_results
@@ -297,7 +297,7 @@ class TestProperty7FlowSummaryCompleteness:
with patch(_TASK_LOGGER_PATH):
result = runner.run(
pipeline=pipeline,
flow=flow_name,
processing_mode="increment_only",
data_source="offline",
)

View File

@@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-
"""DwdLoadTask 返回值格式属性测试 — 使用 hypothesis 验证返回值的通用正确性属性。
纯单元测试,使用 MagicMock 构造 DwdLoadTask 实例,不涉及真实数据库连接。
# Feature: etl-aggregation-fix, Property 1: DwdLoadTask 返回值格式一致性
# **Validates: Requirements 1.1**
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from hypothesis import given, settings
from hypothesis import strategies as st
from tasks.dwd.dwd_load_task import DwdLoadTask
# ── 辅助:构造最小可用 DwdLoadTask 实例 ─────────────────────
def _make_task() -> DwdLoadTask:
    """Build a minimal DwdLoadTask wired to MagicMock config/db/api."""
    config = MagicMock()
    # config.get(key, default) always yields the default.
    config.get = MagicMock(side_effect=lambda key, default=None: default)

    cursor = MagicMock()
    # Make the cursor usable as a context manager.
    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock(return_value=False)

    conn = MagicMock()
    conn.cursor.return_value = cursor

    db = MagicMock()
    db.conn = conn

    logger = logging.getLogger("test_return_format_pbt")
    return DwdLoadTask(config, db, MagicMock(), logger)
# ── Generation strategies ────────────────────────────────────
# TABLE_MAP entries: up to 8 tables, each randomly marked as failing.
_table_name_st = st.text(
    alphabet=st.sampled_from("abcdefghijklmnopqrstuvwxyz_"),
    min_size=3, max_size=12,
)
# (dwd_table, ods_table, should_fail) triples.
_table_entry_st = st.tuples(
    _table_name_st,  # dwd table-name suffix
    _table_name_st,  # ods table-name suffix
    st.booleans(),  # whether to simulate a failure
)
# Unique by dwd suffix so generated TABLE_MAP keys never collide.
_table_map_st = st.lists(_table_entry_st, min_size=0, max_size=8, unique_by=lambda t: t[0])
# ── Property 1: DwdLoadTask return-value format consistency ──
# Feature: etl-aggregation-fix, Property 1: DwdLoadTask 返回值格式一致性
# **Validates: Requirements 1.1**
#
# For any DwdLoadTask.load() run, the returned dict must satisfy:
# - the `errors` value is an int
# - the `error_details` value is a list
# - `errors` == len(`error_details`)
class TestProperty1DwdLoadReturnFormat:
    """Property 1: DwdLoadTask return-value format consistency."""

    @given(table_entries=_table_map_st)
    @settings(max_examples=200)
    def test_errors_is_int_and_equals_len_error_details(self, table_entries):
        """For any table map / failure pattern, errors is int == len(error_details)."""
        task = _make_task()
        # Build TABLE_MAP plus the set of tables that will fail.
        table_map = {}
        fail_set = set()
        for dwd_suffix, ods_suffix, should_fail in table_entries:
            dwd_name = f"dwd.dim_{dwd_suffix}"
            ods_name = f"ods.{ods_suffix}"
            table_map[dwd_name] = ods_name
            if should_fail:
                fail_set.add(dwd_name)
        # _get_columns: failing tables raise; others return no columns
        # (skipped — counted neither in the summary nor in errors).
        def fake_get_columns(cur, table):
            if table in fail_set:
                raise RuntimeError(f"模拟失败: {table}")
            return []  # no columns -> table skipped
        task._get_columns = fake_get_columns
        with patch.object(DwdLoadTask, "TABLE_MAP", table_map):
            ctx = SimpleNamespace(
                window_start=datetime(2025, 1, 1, tzinfo=timezone.utc),
                window_end=datetime(2025, 1, 2, tzinfo=timezone.utc),
            )
            result = task.load({"now": datetime.now(timezone.utc)}, ctx)
        # Core assertions: return-format consistency.
        assert isinstance(result["errors"], int), (
            f"errors 应为 int实际为 {type(result['errors']).__name__}"
        )
        assert isinstance(result["error_details"], list), (
            f"error_details 应为 list实际为 {type(result['error_details']).__name__}"
        )
        assert result["errors"] == len(result["error_details"]), (
            f"errors={result['errors']} != len(error_details)={len(result['error_details'])}"
        )
        # errors must equal the number of injected failures.
        assert result["errors"] == len(fail_set), (
            f"errors={result['errors']} != 预期失败数={len(fail_set)}"
        )
        assert "tables" in result
# ── Property 2: _accumulate_counts 类型安全累加 ─────────────
# Feature: etl-aggregation-fix, Property 2: _accumulate_counts 类型安全累加
# **Validates: Requirements 1.2**
#
# 对于任意包含 int、float、list 类型值的计数字典,
# _accumulate_counts() 应将 int/float 直接累加,将 list 转为 len() 后累加,
# 且不抛出异常。
from tasks.base_task import BaseTask
# ── Generation strategies: count-dict values ─────────────────
# Valid count values: int / float / list (with arbitrary int elements).
_count_value_st = st.one_of(
    st.integers(min_value=0, max_value=10_000),
    st.floats(min_value=0.0, max_value=10_000.0, allow_nan=False, allow_infinity=False),
    st.lists(st.integers(), min_size=0, max_size=50),
)
# Count-dict keys: short lowercase/underscore identifiers.
_count_key_st = st.text(
    alphabet=st.sampled_from("abcdefghijklmnopqrstuvwxyz_"),
    min_size=2, max_size=10,
)
# Count dicts: 0-10 keys, each value an int/float/list.
_count_dict_st = st.dictionaries(
    keys=_count_key_st,
    values=_count_value_st,
    min_size=0,
    max_size=10,
)
def _numeric_value(v) -> int | float:
"""将 int/float/list 统一转为数值list → len"""
if isinstance(v, (int, float)):
return v
if isinstance(v, list):
return len(v)
return 0
class TestProperty2AccumulateCountsTypeSafe:
    """Property 2: _accumulate_counts accumulates int/float/list values type-safely."""

    @given(dicts=st.lists(_count_dict_st, min_size=1, max_size=5))
    @settings(max_examples=200)
    def test_accumulate_never_raises_and_sums_correctly(self, dicts):
        """For any sequence of int/float/list count dicts, accumulation never raises and sums correctly."""
        total: dict = {}
        for d in dicts:
            BaseTask._accumulate_counts(total, d)
        # Compute the expected totals by hand. _numeric_value() always yields
        # an int or float (lists are folded to len()), so every key reduces to
        # a plain numeric sum; the former isinstance/setdefault branch was
        # unreachable dead code and has been removed.
        expected: dict = {}
        for d in dicts:
            for k, v in d.items():
                expected[k] = expected.get(k, 0) + _numeric_value(v)
        # Every accumulated key must match the hand-computed sum.
        for k, ev in expected.items():
            assert k in total, f"'{k}' 应存在于累加结果中"
            assert abs(total[k] - ev) < 1e-9, (
                f"'{k}': 累加结果={total[k]}, 期望={ev}"
            )

    @given(current=_count_dict_st)
    @settings(max_examples=200)
    def test_single_dict_values_match_numeric_conversion(self, current):
        """Single accumulation: int/float values pass through, lists become len()."""
        total: dict = {}
        BaseTask._accumulate_counts(total, current)
        for k, v in current.items():
            if isinstance(v, (int, float)):
                assert abs(total[k] - v) < 1e-9, (
                    f"'{k}': int/float 应直接累加,结果={total[k]}, 原值={v}"
                )
            elif isinstance(v, list):
                assert total[k] == len(v), (
                    f"'{k}': list 应转为 len()={len(v)}, 实际={total[k]}"
                )

    @given(current=_count_dict_st)
    @settings(max_examples=100)
    def test_none_total_key_initializes_correctly(self, current):
        """Keys missing from total (get -> None, `or 0`) must initialize correctly."""
        total: dict = {}
        result = BaseTask._accumulate_counts(total, current)
        # The return value must be the very same dict object that was passed in.
        assert result is total
        # All numeric values in the result must be non-negative.
        for k, v in result.items():
            if isinstance(v, (int, float)):
                assert v >= 0, f"'{k}' 的值不应为负: {v}"

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""员工维度表staff_info单元测试验证 P1、P2 正确性属性。"""
import logging
import os
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
os.environ.setdefault("ETL_SKIP_DOTENV", "1")
from tasks.ods.ods_tasks import (
ODS_TASK_SPECS,
ODS_TASK_CLASSES,
ENABLED_ODS_CODES,
SnapshotMode,
)
from api.client import DEFAULT_LIST_KEYS
from tasks.dwd.dwd_load_task import DwdLoadTask
from .task_test_utils import create_test_config, get_db_operations, FakeAPIClient
# ---------------------------------------------------------------------------
# P1ODS 任务规格完整性
# ---------------------------------------------------------------------------
def _get_staff_spec():
    """Return the ODS_STAFF_INFO spec registered in ODS_TASK_SPECS, or fail the test."""
    matches = [spec for spec in ODS_TASK_SPECS if spec.code == "ODS_STAFF_INFO"]
    if not matches:
        raise AssertionError("ODS_STAFF_INFO 未在 ODS_TASK_SPECS 中注册")
    return matches[0]
class TestP1OdsStaffInfoSpec:
    """P1: the ODS_STAFF_INFO task spec is registered with the expected fields."""

    def test_code(self):
        spec = _get_staff_spec()
        assert spec.code == "ODS_STAFF_INFO"

    def test_table_name(self):
        spec = _get_staff_spec()
        assert spec.table_name == "ods.staff_info_master"

    def test_endpoint(self):
        spec = _get_staff_spec()
        assert spec.endpoint == "/PersonnelManagement/SearchSystemStaffInfo"

    def test_list_key(self):
        spec = _get_staff_spec()
        assert spec.list_key == "staffProfiles"

    def test_snapshot_mode(self):
        spec = _get_staff_spec()
        assert spec.snapshot_mode == SnapshotMode.FULL_TABLE

    def test_requires_window_false(self):
        spec = _get_staff_spec()
        assert spec.requires_window is False

    def test_time_fields_none(self):
        spec = _get_staff_spec()
        assert spec.time_fields is None

    def test_staff_profiles_in_default_list_keys(self):
        assert "staffProfiles" in DEFAULT_LIST_KEYS

    def test_ods_staff_info_in_enabled_codes(self):
        assert "ODS_STAFF_INFO" in ENABLED_ODS_CODES

    def test_task_class_registered(self):
        assert "ODS_STAFF_INFO" in ODS_TASK_CLASSES
# ---------------------------------------------------------------------------
# P2DWD 映射完整性
# ---------------------------------------------------------------------------
class TestP2DwdStaffMapping:
    """P2: the DWD staff dimension tables map back to the ODS master table."""

    def test_dim_staff_table_map(self):
        table_map = DwdLoadTask.TABLE_MAP
        assert table_map["dwd.dim_staff"] == "ods.staff_info_master"

    def test_dim_staff_ex_table_map(self):
        table_map = DwdLoadTask.TABLE_MAP
        assert table_map["dwd.dim_staff_ex"] == "ods.staff_info_master"

    def test_dim_staff_has_staff_id_mapping(self):
        # Exactly one staff_id column mapping, sourced from the ODS "id" field.
        matches = [
            (dst, src)
            for dst, src, _ in DwdLoadTask.FACT_MAPPINGS["dwd.dim_staff"]
            if dst == "staff_id"
        ]
        assert len(matches) == 1
        assert matches[0][1] == "id"

    def test_dim_staff_ex_has_staff_id_mapping(self):
        # The extension table mirrors the same primary-key mapping.
        matches = [
            (dst, src)
            for dst, src, _ in DwdLoadTask.FACT_MAPPINGS["dwd.dim_staff_ex"]
            if dst == "staff_id"
        ]
        assert len(matches) == 1
        assert matches[0][1] == "id"
# ---------------------------------------------------------------------------
# ODS 落地功能测试
# ---------------------------------------------------------------------------
def test_staff_info_ingest(tmp_path):
    """The ODS_STAFF_INFO task should land staff data into the ODS table."""
    cfg = create_test_config("ONLINE", tmp_path / "archive", tmp_path / "temp")
    staff_rows = [
        {
            "id": 3020236636900101,
            "staff_name": "葛芃",
            "mobile": "13811638071",
            "job": "店长",
            "staff_identity": 2,
            "status": 1,
            "leave_status": 0,
            "site_id": 2790685415443269,
            "tenant_id": 2790683160709957,
        }
    ]
    fake_api = FakeAPIClient({"/PersonnelManagement/SearchSystemStaffInfo": staff_rows})
    with get_db_operations() as db_ops:
        task = ODS_TASK_CLASSES["ODS_STAFF_INFO"](
            cfg, db_ops, fake_api, logging.getLogger("test_staff_info")
        )
        outcome = task.execute()
        assert outcome["status"] == "SUCCESS"
        assert outcome["counts"]["fetched"] == 1
        # Inspect the first row handed to the upsert layer.
        first_row = db_ops.upserts[0]["rows"][0]
        assert first_row["id"] == 3020236636900101
        assert first_row["record_index"] == 0
        # The payload may be raw UTF-8 or ASCII-escaped JSON.
        assert (
            '"staff_name": "葛芃"' in first_row["payload"]
            or '"staff_name": "\\u845b\\u82c3"' in first_row["payload"]
        )

View File

@@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""EtlTimer 单元测试"""
from __future__ import annotations
import os
import time
from pathlib import Path
from unittest.mock import patch
from zoneinfo import ZoneInfo
import pytest
from utils.timer import EtlTimer, StepRecord, _fmt_ms
# ── _fmt_ms 格式化 ───────────────────────────────────────────
class TestFmtMs:
    """Human-readable formatting of millisecond durations via _fmt_ms."""

    def test_sub_second(self):
        assert _fmt_ms(123.4) == "123.4ms"

    def test_seconds(self):
        assert _fmt_ms(2500) == "2.50s"

    def test_minutes(self):
        # 90 seconds should render with a minutes prefix.
        formatted = _fmt_ms(90_000)
        assert formatted.startswith("1m")

    def test_hours(self):
        # ~61.7 minutes should render with an hours prefix.
        formatted = _fmt_ms(3_700_000)
        assert formatted.startswith("1h")
# ── StepRecord ───────────────────────────────────────────────
class TestStepRecord:
    """StepRecord construction, derived properties, and dict serialization."""

    def test_elapsed_seconds(self):
        record = StepRecord(name="test", start_time=None, elapsed_ms=1500.0)
        assert record.elapsed_seconds == 1.5

    def test_to_dict_without_end(self):
        from datetime import datetime
        started = datetime.now(ZoneInfo("Asia/Shanghai"))
        record = StepRecord(name="s1", start_time=started, elapsed_ms=100.0)
        serialized = record.to_dict()
        assert serialized["name"] == "s1"
        assert serialized["end_time"] is None
        assert serialized["elapsed_ms"] == 100.0
        assert serialized["children"] == []

    def test_to_dict_with_children(self):
        from datetime import datetime
        stamp = datetime.now(ZoneInfo("Asia/Shanghai"))
        outer = StepRecord(name="p", start_time=stamp, elapsed_ms=200.0)
        inner = StepRecord(name="c", start_time=stamp, end_time=stamp, elapsed_ms=50.0)
        outer.children.append(inner)
        serialized = outer.to_dict()
        assert len(serialized["children"]) == 1
        assert serialized["children"][0]["name"] == "c"
# ── EtlTimer 核心流程 ────────────────────────────────────────
class TestEtlTimer:
    """EtlTimer core flow: step timing, sub-steps, reports, and error paths."""

    def test_start_stop_step(self):
        timer = EtlTimer()
        timer.start()
        timer.start_step("STEP_A")
        time.sleep(0.01)  # guarantee a measurable elapsed time
        rec = timer.stop_step("STEP_A")
        assert rec.name == "STEP_A"
        assert rec.end_time is not None
        assert rec.elapsed_ms > 0

    def test_sub_steps(self):
        timer = EtlTimer()
        timer.start()
        timer.start_step("PARENT")
        timer.start_sub_step("PARENT", "child_1")
        time.sleep(0.01)
        timer.stop_sub_step("PARENT", "child_1")
        timer.start_sub_step("PARENT", "child_2")
        timer.stop_sub_step("PARENT", "child_2")
        timer.stop_step("PARENT")
        parent = timer.get_step("PARENT")
        assert parent is not None
        assert len(parent.children) == 2
        assert parent.children[0].name == "child_1"
        assert parent.children[0].elapsed_ms > 0

    def test_stop_unknown_step_raises(self):
        timer = EtlTimer()
        with pytest.raises(KeyError, match="未找到步骤"):
            timer.stop_step("NONEXISTENT")

    def test_start_sub_step_unknown_parent_raises(self):
        timer = EtlTimer()
        with pytest.raises(KeyError, match="未找到父步骤"):
            timer.start_sub_step("NONEXISTENT", "child")

    def test_stop_sub_step_unknown_raises(self):
        timer = EtlTimer()
        timer.start()
        timer.start_step("P")
        with pytest.raises(KeyError, match="未找到子步骤"):
            timer.stop_sub_step("P", "no_such_child")

    def test_multiple_steps(self):
        timer = EtlTimer()
        timer.start()
        for name in ["ODS_LOAD", "DWD_LOAD", "DWS_AGG"]:
            timer.start_step(name)
            timer.stop_step(name)
        assert len(timer.steps) == 3
        assert timer.steps[0].name == "ODS_LOAD"
        assert timer.steps[2].name == "DWS_AGG"

    def test_to_dict(self):
        timer = EtlTimer()
        timer.start()
        timer.start_step("S1")
        timer.stop_step("S1")
        # finish() is called before to_dict() (overall_end is asserted below);
        # its markdown return value is not needed here, so the former unused
        # local `report` has been dropped.
        timer.finish(write_report=False)
        d = timer.to_dict()
        assert d["overall_start"] is not None
        assert d["overall_end"] is not None
        assert d["overall_elapsed_ms"] >= 0
        assert len(d["steps"]) == 1

    def test_overall_elapsed(self):
        timer = EtlTimer()
        timer.start()
        time.sleep(0.02)
        timer.finish(write_report=False)
        assert timer.overall_elapsed_ms >= 15  # at least 15ms, with scheduling slack

    def test_finish_returns_markdown(self):
        timer = EtlTimer()
        timer.start()
        timer.start_step("TEST_STEP")
        timer.stop_step("TEST_STEP")
        md = timer.finish(write_report=False)
        assert "# ETL 执行计时报告" in md
        assert "TEST_STEP" in md
        assert "步骤汇总" in md

    def test_markdown_contains_sub_steps(self):
        timer = EtlTimer()
        timer.start()
        timer.start_step("MAIN")
        timer.start_sub_step("MAIN", "sub_a")
        timer.stop_sub_step("MAIN", "sub_a")
        timer.stop_step("MAIN")
        md = timer.finish(write_report=False)
        assert "步骤详情" in md
        assert "sub_a" in md

    def test_write_report_requires_env(self):
        """finish(write_report=True) must raise KeyError when ETL_REPORT_ROOT is unset."""
        timer = EtlTimer()
        timer.start()
        timer.start_step("X")
        timer.stop_step("X")
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(KeyError, match="ETL_REPORT_ROOT"):
                timer.finish(write_report=True)

    def test_write_report_creates_file(self, tmp_path: Path):
        """With ETL_REPORT_ROOT set, finish() must write a .md report file."""
        timer = EtlTimer()
        timer.start()
        timer.start_step("Y")
        timer.stop_step("Y")
        with patch.dict(os.environ, {"ETL_REPORT_ROOT": str(tmp_path)}):
            timer.finish(write_report=True)
        md_files = list(tmp_path.glob("etl_timing_*.md"))
        assert len(md_files) == 1
        content = md_files[0].read_text(encoding="utf-8")
        assert "# ETL 执行计时报告" in content
        assert "Y" in content

    def test_elapsed_equals_end_minus_start(self):
        """Property 7 core check: elapsed_ms ≈ end_time - start_time."""
        timer = EtlTimer()
        timer.start()
        timer.start_step("VERIFY")
        time.sleep(0.05)
        rec = timer.stop_step("VERIFY")
        # Milliseconds derived from the datetime difference should agree with
        # the perf_counter-based elapsed_ms within ±50ms (scheduler jitter).
        dt_diff_ms = (rec.end_time - rec.start_time).total_seconds() * 1000
        assert abs(rec.elapsed_ms - dt_diff_ms) < 50