# -*- coding: utf-8 -*-
"""Unit tests for the basic capabilities of the relationship index tasks."""

from __future__ import annotations

import logging
from datetime import date
from typing import Any, Dict, List, Optional

from tasks.dws.index.base_index_task import BaseIndexTask
from tasks.dws.index.ml_manual_import_task import MlManualImportTask


class _DummyConfig:
    """Minimal configuration stub backed by a plain dict."""

    def __init__(self, values: Optional[Dict[str, Any]] = None):
        # A missing (or empty) mapping behaves like an empty configuration.
        self._values = values or {}

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key``, or ``default`` if absent."""
        return self._values.get(key, default)


class _DummyDB:
    """Minimal database stub that records every query it receives."""

    def __init__(self) -> None:
        # One ``(sql, params)`` tuple per ``query`` call, in call order.
        self.query_calls: List[tuple] = []

    def query(self, sql: str, params: Any = None) -> List[Dict[str, Any]]:
        """Record the call and return canned parameter rows.

        The first element of ``params`` is treated as the index type:
        ``"RS"`` yields ``lookback_days=60``, ``"MS"`` yields
        ``lookback_days=30``, and anything else yields no rows.
        """
        self.query_calls.append((sql, params))
        # Falsy params (None / empty sequence) are treated as "no index type".
        index_type = (params or [None])[0]
        if index_type == "RS":
            return [{"param_name": "lookback_days", "param_value": 60}]
        if index_type == "MS":
            return [{"param_name": "lookback_days", "param_value": 30}]
        return []


class _DummyIndexTask(BaseIndexTask):
    """Smallest concrete subclass used to exercise ``BaseIndexTask``."""

    INDEX_TYPE = "RS"

    def get_task_code(self) -> str:  # pragma: no cover - test stub
        return "DUMMY_INDEX"

    def get_target_table(self) -> str:  # pragma: no cover - test stub
        return "dummy_table"

    def get_primary_keys(self) -> List[str]:  # pragma: no cover - test stub
        return ["id"]

    def get_index_type(self) -> str:
        return self.INDEX_TYPE

    def extract(self, context):  # pragma: no cover - test stub
        return []

    def load(self, transformed, context):  # pragma: no cover - test stub
        return {}


def test_load_index_parameters_cache_isolated_by_index_type():
    """The parameter cache must be keyed by index_type.

    Otherwise a single task that loads parameters for several index types
    would see one type's parameters leak into another.
    """
    # NOTE(review): the third positional argument is passed as None; its
    # meaning is defined by BaseIndexTask, which is not visible here.
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_cache"),
    )

    rs_first = task.load_index_parameters(index_type="RS")
    ms_first = task.load_index_parameters(index_type="MS")
    rs_second = task.load_index_parameters(index_type="RS")

    assert rs_first["lookback_days"] == 60.0
    assert ms_first["lookback_days"] == 30.0
    assert rs_second["lookback_days"] == 60.0
    # Exactly two queries are expected: one for RS and one for MS; the
    # second RS lookup should be served from the cache.
    assert len(task.db.query_calls) == 2


def test_batch_normalize_passes_index_type_to_smoothing_chain():
    """batch_normalize_to_display should forward index_type to smoothing."""
    task = _DummyIndexTask(
        _DummyConfig({"app.timezone": "Asia/Shanghai"}),
        _DummyDB(),
        None,
        logging.getLogger("test_index_smoothing"),
    )
    captured: Dict[str, Any] = {}

    def _fake_apply(site_id, current_p5, current_p95, alpha=None, index_type=None):
        # Record what the smoothing hook received and pass the percentiles
        # through unchanged.
        captured["index_type"] = index_type
        return current_p5, current_p95

    task._apply_ewma_smoothing = _fake_apply  # type: ignore[method-assign]

    result = task.batch_normalize_to_display(
        raw_scores=[("a", 1.0), ("b", 2.0), ("c", 3.0)],
        use_smoothing=True,
        site_id=1,
        index_type="ML",
    )

    assert result
    # Fail with a readable message (instead of a bare KeyError) when the
    # smoothing hook was never invoked at all.
    assert "index_type" in captured, "_apply_ewma_smoothing was never called"
    assert captured["index_type"] == "ML"


def test_ml_manual_import_scope_day_and_p30_boundary():
    """Check the DAY / P30 scope switch at the 30-day boundary.

    A business date within 30 days of ``today`` is expected to resolve to a
    single-day scope; an older date is expected to fall into a 30-day bucket
    anchored at a fixed epoch.
    """
    today = date(2026, 2, 8)

    day_scope = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 9),  # 30 days before today
        today=today,
    )
    assert day_scope.scope_type == "DAY"
    assert day_scope.start_date == date(2026, 1, 9)
    assert day_scope.end_date == date(2026, 1, 9)

    p30_scope = MlManualImportTask.resolve_scope(
        site_id=1,
        biz_date=date(2026, 1, 8),  # 31 days before today
        today=today,
    )
    assert p30_scope.scope_type == "P30"
    # With the fixed epoch at 2026-01-01, the first bucket should span
    # 2026-01-01 ~ 2026-01-30.
    assert p30_scope.start_date == date(2026, 1, 1)
    assert p30_scope.end_date == date(2026, 1, 30)