Files
ZQYY.FQ-ETL/tests/unit/test_relation_index_base.py

134 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""关系指数基础能力单测。"""
from __future__ import annotations
import logging
from datetime import date
from typing import Any, Dict, List, Optional
from tasks.dws.index.base_index_task import BaseIndexTask
from tasks.dws.index.ml_manual_import_task import MlManualImportTask
class _DummyConfig:
"""最小配置桩对象。"""
def __init__(self, values: Optional[Dict[str, Any]] = None):
self._values = values or {}
def get(self, key: str, default: Any = None) -> Any:
return self._values.get(key, default)
class _DummyDB:
"""最小数据库桩对象。"""
def __init__(self) -> None:
self.query_calls: List[tuple] = []
def query(self, sql: str, params=None):
self.query_calls.append((sql, params))
index_type = (params or [None])[0]
if index_type == "RS":
return [{"param_name": "lookback_days", "param_value": 60}]
if index_type == "MS":
return [{"param_name": "lookback_days", "param_value": 30}]
return []
class _DummyIndexTask(BaseIndexTask):
    """Smallest concrete ``BaseIndexTask`` subclass, used to exercise base-class helpers."""

    # Index type reported by ``get_index_type``.
    INDEX_TYPE = "RS"

    def get_index_type(self) -> str:
        return self.INDEX_TYPE

    def get_task_code(self) -> str:  # pragma: no cover - test stub
        return "DUMMY_INDEX"

    def get_target_table(self) -> str:  # pragma: no cover - test stub
        return "dummy_table"

    def get_primary_keys(self) -> List[str]:  # pragma: no cover - test stub
        return list(("id",))

    def extract(self, context):  # pragma: no cover - test stub
        # Not exercised by these tests; yields no rows.
        return list()

    def load(self, transformed, context):  # pragma: no cover - test stub
        # Not exercised by these tests; reports nothing.
        return dict()
def test_load_index_parameters_cache_isolated_by_index_type():
    """The parameter cache must be keyed per index_type so one task never mixes parameters."""
    config = _DummyConfig({"app.timezone": "Asia/Shanghai"})
    task = _DummyIndexTask(
        config,
        _DummyDB(),
        None,
        logging.getLogger("test_index_cache"),
    )

    # Cold RS load, cold MS load, then a repeated RS load that should be cached.
    load = task.load_index_parameters
    rs_cold, ms_cold, rs_warm = (
        load(index_type="RS"),
        load(index_type="MS"),
        load(index_type="RS"),
    )

    expectations = ((rs_cold, 60.0), (ms_cold, 30.0), (rs_warm, 60.0))
    for loaded, lookback in expectations:
        assert loaded["lookback_days"] == lookback

    # Exactly two DB round-trips: one for RS and one for MS; the second RS
    # lookup must be served from the cache.
    assert len(task.db.query_calls) == 2
def test_batch_normalize_passes_index_type_to_smoothing_chain():
    """batch_normalize_to_display must forward index_type into the smoothing chain."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_smoothing"),
    )
    captured: Dict[str, Any] = {}

    def _fake_apply(site_id, current_p5, current_p95, alpha=None, index_type=None):
        # Record the index_type the smoothing step received and pass the
        # p5/p95 values through unchanged so normalization still produces output.
        captured["index_type"] = index_type
        return current_p5, current_p95

    # Shadow the method on this instance only; the class itself is untouched.
    task._apply_ewma_smoothing = _fake_apply  # type: ignore[method-assign]

    result = task.batch_normalize_to_display(
        raw_scores=[("a", 1.0), ("b", 2.0), ("c", 3.0)],
        use_smoothing=True,
        site_id=1,
        index_type="ML",
    )

    assert result
    # Fail with an explicit message, rather than a bare KeyError, when the
    # smoothing hook was never reached.
    assert "index_type" in captured, "_apply_ewma_smoothing was not called"
    assert captured["index_type"] == "ML"
def test_ml_manual_import_scope_day_and_p30_boundary():
    """Within 30 days the scope is a single day; beyond that it falls into a fixed-epoch 30-day bucket."""
    reference_day = date(2026, 2, 8)

    # Exactly 30 days before reference_day: still inside the per-day window.
    inside = date(2026, 1, 9)
    day_scope = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=inside,
        today=reference_day,
    )
    assert day_scope.scope_type == "DAY"
    assert (day_scope.start_date, day_scope.end_date) == (inside, inside)

    # 31 days before reference_day: outside the per-day window.
    outside = date(2026, 1, 8)
    bucket_scope = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=outside,
        today=reference_day,
    )
    assert bucket_scope.scope_type == "P30"
    # With the fixed epoch 2026-01-01, the first bucket spans 2026-01-01 .. 2026-01-30.
    assert (bucket_scope.start_date, bucket_scope.end_date) == (
        date(2026, 1, 1),
        date(2026, 1, 30),
    )