# -*- coding: utf-8 -*-
"""
DWS materialized-view refresh tasks.

Notes:
- Materialized views are refreshed per time layer (L1/L2/L3/L4).
- By default the behaviour is driven by ``dws.mv.enabled`` together with the
  ``dws.retention.*`` settings.
"""

from __future__ import annotations

import json
from typing import Any, Dict, List, Optional

from .base_dws_task import BaseDwsTask, TaskContext, TimeLayer


class BaseMvRefreshTask(BaseDwsTask):
    """Base class for tasks that refresh the materialized views of one table.

    Subclasses set ``BASE_TABLE`` (and ``DATE_COL``) and implement
    ``get_task_code``. For every selected time layer the task refreshes the
    view named ``<VIEW_PREFIX><BASE_TABLE>_<suffix>`` in ``self.DWS_SCHEMA``.

    NOTE(review): ``self.config``, ``self.db``, ``self.logger`` and
    ``self.DWS_SCHEMA`` are not defined in this module; they are presumably
    provided by ``BaseDwsTask`` -- confirm against the base class.
    """

    # Name of the base table whose materialized views are refreshed.
    BASE_TABLE: str = ""
    # Date column of the base table (not referenced by the code in this module).
    DATE_COL: str = ""
    # Prefix prepended to the base table name to build a view name.
    VIEW_PREFIX = "mv_"

    # Order in which layers are refreshed; ``_layers_up_to`` returns a prefix
    # of this list.
    LAYER_ORDER = [
        TimeLayer.LAST_2_DAYS,
        TimeLayer.LAST_1_MONTH,
        TimeLayer.LAST_3_MONTHS,
        TimeLayer.LAST_6_MONTHS,
    ]

    # View-name suffix for each refreshable layer.
    LAYER_SUFFIX = {
        TimeLayer.LAST_2_DAYS: "l1",
        TimeLayer.LAST_1_MONTH: "l2",
        TimeLayer.LAST_3_MONTHS: "l3",
        TimeLayer.LAST_6_MONTHS: "l4",
    }

    # String spellings treated as "false" when a boolean option arrives as
    # text (compared after strip + lower-casing).
    _FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})

    def get_target_table(self) -> str:
        """Return the base table name (``BASE_TABLE``)."""
        return self.BASE_TABLE

    def get_primary_keys(self) -> List[str]:
        """Return an empty list; this task declares no primary keys."""
        return []

    def extract(self, context: TaskContext) -> Dict[str, Any]:
        """Return ``{"site_id": context.store_id}``; no data is read here."""
        return {"site_id": context.store_id}

    def transform(self, extracted: Dict[str, Any], context: TaskContext) -> Dict[str, Any]:
        """Return ``extracted`` unchanged."""
        return extracted

    def load(self, transformed: Dict[str, Any], context: TaskContext) -> Dict[str, Any]:
        """Refresh every existing materialized view of the selected layers.

        Returns:
            ``{"counts": {"refreshed": 0}}`` when the task is disabled,
            otherwise ``{"counts": {"refreshed": n}, "extra": {"details": [...]}}``
            where each detail entry holds the view name and the layer value.

        Errors raised by ``_view_exists`` / ``_refresh_view`` are not caught
        here, so a failing refresh aborts the remaining layers.
        """
        if not self._is_enabled():
            self.logger.info("%s: 未启用物化刷新,跳过", self.get_task_code())
            return {"counts": {"refreshed": 0}}

        layers = self._resolve_layers()
        refreshed = 0
        details = []
        for layer in layers:
            view_name = self._get_view_name(layer)
            if not view_name:
                # Layer without a suffix mapping, or BASE_TABLE not set.
                continue
            if not self._view_exists(view_name):
                self.logger.warning("%s: 物化视图不存在,跳过 %s", self.get_task_code(), view_name)
                continue
            self._refresh_view(view_name)
            refreshed += 1
            details.append({"view": view_name, "layer": layer.value})

        self.logger.info("%s: 刷新完成,物化视图数=%d", self.get_task_code(), refreshed)
        return {"counts": {"refreshed": refreshed}, "extra": {"details": details}}

    def _is_enabled(self) -> bool:
        """Return whether this task should refresh anything.

        Requires ``dws.mv.enabled`` to be truthy. If a table list is
        configured (``dws.mv.tables``, falling back to
        ``dws.retention.tables`` when the former is empty), ``BASE_TABLE``
        must be in it; with no list configured every table is allowed.
        """
        # _parse_bool rather than bool(): bool("false") is True, which would
        # enable the task when the option is supplied as text.
        if not self._parse_bool(self.config.get("dws.mv.enabled", False)):
            return False
        tables = self._parse_list(self.config.get("dws.mv.tables"))
        if not tables:
            tables = self._parse_list(self.config.get("dws.retention.tables"))
        if tables and self.BASE_TABLE not in tables:
            return False
        return True

    def _resolve_layers(self) -> List[TimeLayer]:
        """Return the layers to refresh, in priority order of configuration.

        1. ``dws.mv.layers`` -- an explicit list, used as given.
        2. The per-table entry in ``dws.mv.table_layers`` (or, when that
           option is empty, ``dws.retention.table_layers``).
        3. ``dws.retention.layer``.
        4. Otherwise all of ``LAYER_ORDER``.

        For 2 and 3 the result is ``LAYER_ORDER`` up to and including the
        configured layer; ``TimeLayer.ALL`` and unknown names are ignored.
        """
        # An explicit layer list takes precedence.
        configured = self._parse_layers(self.config.get("dws.mv.layers"))
        if configured:
            return configured

        # Per-table override: mv.table_layers first, then retention.table_layers.
        table_layers = self._resolve_layer_map(
            self.config.get("dws.mv.table_layers")
            or self.config.get("dws.retention.table_layers")
        )
        layer_name = table_layers.get(self.BASE_TABLE)
        if layer_name:
            layer = self._get_layer(layer_name)
            if layer and layer != TimeLayer.ALL:
                return self._layers_up_to(layer)

        # Default: follow retention.layer.
        retention_layer = self._get_layer(self.config.get("dws.retention.layer"))
        if retention_layer and retention_layer != TimeLayer.ALL:
            return self._layers_up_to(retention_layer)

        return list(self.LAYER_ORDER)

    def _layers_up_to(self, target: TimeLayer) -> List[TimeLayer]:
        """Return ``LAYER_ORDER`` up to and including ``target``.

        If ``target`` is not in ``LAYER_ORDER`` the whole list is returned.
        """
        layers = []
        for layer in self.LAYER_ORDER:
            layers.append(layer)
            if layer == target:
                break
        return layers

    def _get_view_name(self, layer: TimeLayer) -> Optional[str]:
        """Return the view name for ``layer``.

        Returns ``None`` when the layer has no entry in ``LAYER_SUFFIX`` or
        ``BASE_TABLE`` is empty.
        """
        suffix = self.LAYER_SUFFIX.get(layer)
        if not suffix or not self.BASE_TABLE:
            return None
        return f"{self.VIEW_PREFIX}{self.BASE_TABLE}_{suffix}"

    def _view_exists(self, view_name: str) -> bool:
        """Return whether ``<DWS_SCHEMA>.<view_name>`` resolves to a relation.

        Uses ``to_regclass``, which yields NULL instead of raising when the
        name does not exist. NOTE(review): assumes ``self.db.query`` returns
        a sequence of dict-like rows -- confirm against the DB helper.
        """
        sql = "SELECT to_regclass(%s) AS reg"
        rows = self.db.query(sql, (f"{self.DWS_SCHEMA}.{view_name}",))
        return bool(rows and rows[0].get("reg"))

    def _refresh_view(self, view_name: str) -> None:
        """Execute ``REFRESH MATERIALIZED VIEW`` for ``view_name``.

        ``CONCURRENTLY`` is added when ``dws.mv.refresh_concurrently`` is
        truthy. NOTE(review): per the PostgreSQL documentation that mode
        requires a unique index on the view -- confirm the views have one.
        """
        # The identifier is built from class constants and DWS_SCHEMA, not
        # from user input, and cannot be passed as a bind parameter.
        concurrently = self._parse_bool(
            self.config.get("dws.mv.refresh_concurrently", False)
        )
        keyword = "CONCURRENTLY " if concurrently else ""
        sql = f"REFRESH MATERIALIZED VIEW {keyword}{self.DWS_SCHEMA}.{view_name}"
        self.db.execute(sql)

    def _get_layer(self, layer_name: Optional[str]) -> Optional[TimeLayer]:
        """Look up a ``TimeLayer`` member by name, case-insensitively.

        Surrounding whitespace is ignored. Returns ``None`` for an empty
        value or a name that is not a ``TimeLayer`` member.
        """
        if not layer_name:
            return None
        # Strip so padded values from table-layer maps resolve the same way
        # as the already-stripped items produced by _parse_layers.
        name = str(layer_name).strip().upper()
        try:
            return TimeLayer[name]
        except KeyError:
            return None

    def _resolve_layer_map(self, raw: Any) -> Dict[str, str]:
        """Normalise a table -> layer-name mapping to ``Dict[str, str]``.

        Accepts a dict or a JSON string encoding an object. Anything else,
        including invalid JSON and JSON that is not an object, yields ``{}``.
        """
        if not raw:
            return {}
        if isinstance(raw, dict):
            return {str(k): str(v) for k, v in raw.items()}
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                return {}
            if isinstance(parsed, dict):
                return {str(k): str(v) for k, v in parsed.items()}
        return {}

    def _parse_layers(self, raw: Any) -> List[TimeLayer]:
        """Parse a layer list into unique ``TimeLayer`` members.

        ``raw`` may be a comma-separated string or a list/tuple/set of
        names. Unknown names are dropped, duplicates are removed and the
        input order is kept. Any other input type yields ``[]``.
        """
        if not raw:
            return []
        if isinstance(raw, str):
            items = [v.strip() for v in raw.split(",") if v.strip()]
        elif isinstance(raw, (list, tuple, set)):
            items = [str(v).strip() for v in raw if str(v).strip()]
        else:
            return []

        layers = []
        for item in items:
            layer = self._get_layer(item)
            if layer and layer not in layers:
                layers.append(layer)
        return layers

    def _parse_list(self, raw: Any) -> List[str]:
        """Parse a comma-separated string or list/tuple/set into stripped strings.

        Empty items are dropped; any other input type yields ``[]``.
        """
        if not raw:
            return []
        if isinstance(raw, str):
            return [v.strip() for v in raw.split(",") if v.strip()]
        if isinstance(raw, (list, tuple, set)):
            return [str(v).strip() for v in raw if str(v).strip()]
        return []

    def _parse_bool(self, raw: Any) -> bool:
        """Interpret a config value as a boolean.

        Strings are false when, after stripping and lower-casing, they are
        one of ``_FALSE_STRINGS`` ("", "0", "false", "no", "off") and true
        otherwise. Non-string values use plain ``bool()``.
        """
        if isinstance(raw, str):
            return raw.strip().lower() not in self._FALSE_STRINGS
        return bool(raw)


class DwsMvRefreshFinanceDailyTask(BaseMvRefreshTask):
    """Refresh the materialized views of ``dws_finance_daily_summary``."""

    BASE_TABLE = "dws_finance_daily_summary"
    DATE_COL = "stat_date"

    def get_task_code(self) -> str:
        """Return the task code ``DWS_MV_REFRESH_FINANCE_DAILY``."""
        return "DWS_MV_REFRESH_FINANCE_DAILY"


class DwsMvRefreshAssistantDailyTask(BaseMvRefreshTask):
    """Refresh the materialized views of ``dws_assistant_daily_detail``."""

    BASE_TABLE = "dws_assistant_daily_detail"
    DATE_COL = "stat_date"

    def get_task_code(self) -> str:
        """Return the task code ``DWS_MV_REFRESH_ASSISTANT_DAILY``."""
        return "DWS_MV_REFRESH_ASSISTANT_DAILY"


__all__ = ["DwsMvRefreshFinanceDailyTask", "DwsMvRefreshAssistantDailyTask"]