# -*- coding: utf-8 -*-
"""
ML manual ledger import task.

Design goals:
1. The manual ledger is the single source of truth for ML;
2. One order may be attributed to several assistants, split equally by default;
3. Overwrite strategy:
   - within the last 30 days: overwrite per (site_id, biz_date) day;
   - older than 30 days: overwrite per 30-day bucket anchored at a fixed
     epoch (2026-01-01).
"""

from __future__ import annotations

import os
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext


@dataclass(frozen=True)
class ImportScope:
    """Definition of the range that one import overwrites."""

    site_id: int
    scope_type: str  # DAY / P30
    start_date: date
    end_date: date

    @property
    def scope_key(self) -> str:
        """Return the string key identifying this scope in imported rows."""
        if self.scope_type == "DAY":
            return f"DAY:{self.site_id}:{self.start_date.isoformat()}"
        return (
            f"P30:{self.site_id}:{self.start_date.isoformat()}:{self.end_date.isoformat()}"
        )


class MlManualImportTask(BaseIndexTask):
    """Import and split the ML manual ledger (order wide table + assistant allocation table)."""

    INDEX_TYPE = "ML"
    EPOCH_ANCHOR = date(2026, 1, 1)
    HISTORICAL_BUCKET_DAYS = 30
    ASSISTANT_SLOT_COUNT = 5

    # Excel template fields (in column order)
    TEMPLATE_COLUMNS = [
        "site_id",
        "biz_date",
        "external_id",
        "member_id",
        "pay_time",
        "order_amount",
        "currency",
        "assistant_id_1",
        "assistant_name_1",
        "assistant_id_2",
        "assistant_name_2",
        "assistant_id_3",
        "assistant_name_3",
        "assistant_id_4",
        "assistant_name_4",
        "assistant_id_5",
        "assistant_name_5",
        "remark",
    ]

    def get_task_code(self) -> str:
        return "DWS_ML_MANUAL_IMPORT"

    def get_target_table(self) -> str:
        return "dws_ml_manual_order_source"

    def get_primary_keys(self) -> List[str]:
        return ["site_id", "external_id", "import_scope_key", "row_no"]

    def get_index_type(self) -> str:
        return self.INDEX_TYPE

    def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
        """
        Run the import.

        The task is driven by a ledger file rather than by a time window, so
        ``context`` is not read here.

        Returns:
            A dict with ``status`` and a ``counts`` sub-dict (source/alloc rows
            written, rows deleted per table, number of scopes overwritten).

        Raises:
            ValueError: no ledger file is configured, or a row is invalid.
            FileNotFoundError: the configured ledger file does not exist.
        """
        file_path = self._resolve_file_path()
        if not file_path:
            raise ValueError(
                "未找到 ML 台账文件,请通过环境变量 ML_MANUAL_LEDGER_FILE 或配置 run.ml_manual_ledger_file 指定"
            )

        rows = self._read_excel_rows(file_path)
        if not rows:
            self.logger.warning("台账文件为空:%s", file_path)
            return {
                "status": "SUCCESS",
                "counts": {
                    "source_rows": 0,
                    "alloc_rows": 0,
                    "deleted_source_rows": 0,
                    "deleted_alloc_rows": 0,
                    "scopes": 0,
                },
            }

        now = datetime.now(self.tz)
        today = now.date()
        import_batch_no = self._build_import_batch_no(now)
        import_file_name = Path(file_path).name
        import_user = self._resolve_import_user()

        source_rows: List[Dict[str, Any]] = []
        alloc_rows: List[Dict[str, Any]] = []
        scope_set: Dict[Tuple[int, str, date, date], ImportScope] = {}

        # All rows are normalized and validated before any database write, so
        # an invalid row aborts the import without touching existing data.
        # NOTE(review): numbering starts at 2 to skip the header row, but blank
        # rows are dropped by _read_excel_rows, so row_no can drift from the
        # real Excel row number when the sheet contains blank rows.
        for idx, raw in enumerate(rows, start=2):
            normalized = self._normalize_row(raw, row_no=idx, file_path=file_path)
            row_scope = self.resolve_scope(
                site_id=normalized["site_id"],
                biz_date=normalized["biz_date"],
                today=today,
            )
            scope_id = (
                row_scope.site_id,
                row_scope.scope_type,
                row_scope.start_date,
                row_scope.end_date,
            )
            scope_set[scope_id] = row_scope

            source_rows.append(
                self._build_source_row(
                    normalized=normalized,
                    scope=row_scope,
                    import_batch_no=import_batch_no,
                    import_file_name=import_file_name,
                    import_user=import_user,
                    import_time=now,
                )
            )
            alloc_rows.extend(
                self._build_alloc_rows(
                    normalized=normalized,
                    scope=row_scope,
                    import_batch_no=import_batch_no,
                    import_file_name=import_file_name,
                    import_user=import_user,
                    import_time=now,
                )
            )

        scopes = list(scope_set.values())
        try:
            deleted_source_rows, deleted_alloc_rows = self._delete_by_scopes(scopes)
            inserted_source = self._insert_source_rows(source_rows)
            upserted_alloc = self._upsert_alloc_rows(alloc_rows)
            self.db.conn.commit()
        except Exception:
            # Delete + insert must be all-or-nothing: undo the partial work so
            # the connection is not left inside a half-applied transaction.
            self.db.conn.rollback()
            raise

        self.logger.info(
            "ML 人工台账导入完成: file=%s source=%d alloc=%d scopes=%d",
            file_path,
            inserted_source,
            upserted_alloc,
            len(scopes),
        )
        return {
            "status": "SUCCESS",
            "counts": {
                "source_rows": inserted_source,
                "alloc_rows": upserted_alloc,
                "deleted_source_rows": deleted_source_rows,
                "deleted_alloc_rows": deleted_alloc_rows,
                "scopes": len(scopes),
            },
        }

    def _resolve_file_path(self) -> Optional[str]:
        """Resolve the ledger file path from config, then from the environment.

        Returns:
            The absolute path as a string, or ``None`` when nothing is configured.

        Raises:
            FileNotFoundError: a path is configured but does not exist.
        """
        raw_path = (
            self.config.get("run.ml_manual_ledger_file")
            or self.config.get("run.ml_manual_file")
            or os.getenv("ML_MANUAL_LEDGER_FILE")
        )
        if not raw_path:
            return None
        candidate = Path(str(raw_path)).expanduser()
        if not candidate.is_absolute():
            # Relative paths are resolved against the current working directory.
            candidate = Path.cwd() / candidate
        if not candidate.exists():
            raise FileNotFoundError(f"台账文件不存在: {candidate}")
        return str(candidate)

    def _read_excel_rows(self, file_path: str) -> List[Dict[str, Any]]:
        """Read the active worksheet into a list of ``{header: value}`` dicts.

        The first row is the header; rows whose cells are all empty are skipped.

        Raises:
            RuntimeError: openpyxl is not installed.
        """
        try:
            from openpyxl import load_workbook
        except ImportError as exc:
            raise RuntimeError(
                "缺少 openpyxl 依赖,无法读取 Excel,请先安装 openpyxl"
            ) from exc

        wb = load_workbook(file_path, data_only=True)
        try:
            ws = wb.active
            header_row = next(ws.iter_rows(min_row=1, max_row=1, values_only=True), None)
            if not header_row:
                return []

            headers = [str(col).strip() if col is not None else "" for col in header_row]
            if not headers:
                return []

            rows: List[Dict[str, Any]] = []
            for values in ws.iter_rows(min_row=2, values_only=True):
                if values is None:
                    continue
                # zip() stops at the shorter of headers / values.
                row_dict = dict(zip(headers, values))
                if self._is_empty_row(row_dict):
                    continue
                rows.append(row_dict)
            return rows
        finally:
            wb.close()

    @staticmethod
    def _is_empty_row(row: Dict[str, Any]) -> bool:
        """Return True when every value is ``None`` or a blank string."""
        for value in row.values():
            if value is None:
                continue
            if isinstance(value, str) and not value.strip():
                continue
            return False
        return True

    def _normalize_row(
        self,
        raw: Dict[str, Any],
        row_no: int,
        file_path: str,
    ) -> Dict[str, Any]:
        """Normalize the fields of a single ledger row.

        Args:
            raw: header -> cell value mapping for the row.
            row_no: row number used in error messages and stored with the row.
            file_path: ledger path, used in error messages only.

        Raises:
            ValueError: site_id, biz_date or external_id is missing or invalid.
        """
        # Fall back to the configured store id, coercing it to int as well so
        # a string config value does not leak into the scope key or SQL params.
        site_id = self._to_int(raw.get("site_id"))
        if site_id is None:
            site_id = self._to_int(self.config.get("app.store_id"))
        if site_id is None:
            # A NULL site would make the scope DELETE match nothing and break
            # the overwrite guarantee, so fail before any database write.
            raise ValueError(f"台账行 {row_no} 缺少 site_id(门店ID): {file_path}")

        try:
            biz_date = self._to_date(raw.get("biz_date"))
        except ValueError as exc:
            raise ValueError(
                f"台账行 {row_no} 的 biz_date 无效: {exc} ({file_path})"
            ) from exc
        pay_time = self._to_datetime(raw.get("pay_time"), fallback_date=biz_date)

        external_id = str(raw.get("external_id") or "").strip()
        if not external_id:
            raise ValueError(f"台账行 {row_no} 缺少 external_id(订单ID): {file_path}")

        member_id = self._to_int(raw.get("member_id"), fallback=0)
        order_amount = self._to_decimal(raw.get("order_amount"))
        currency = str(raw.get("currency") or "CNY").strip().upper() or "CNY"
        remark = str(raw.get("remark") or "").strip()

        # Collect the filled assistant slots in slot order; a slot without a
        # parsable id is skipped even if it carries a name.
        assistants: List[Tuple[int, str]] = []
        for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
            aid = self._to_int(raw.get(f"assistant_id_{idx}"), fallback=None)
            name = str(raw.get(f"assistant_name_{idx}") or "").strip()
            if aid is None:
                continue
            assistants.append((aid, name))

        return {
            "site_id": site_id,
            "biz_date": biz_date,
            "external_id": external_id,
            "member_id": member_id,
            "pay_time": pay_time,
            "order_amount": order_amount,
            "currency": currency,
            "assistants": assistants,
            "remark": remark,
            "row_no": row_no,
        }

    def _build_source_row(
        self,
        *,
        normalized: Dict[str, Any],
        scope: ImportScope,
        import_batch_no: str,
        import_file_name: str,
        import_user: str,
        import_time: datetime,
    ) -> Dict[str, Any]:
        """Build the wide-table row (one per ledger row) for insertion."""
        assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
        row = {
            "site_id": normalized["site_id"],
            "biz_date": normalized["biz_date"],
            "external_id": normalized["external_id"],
            "member_id": normalized["member_id"],
            "pay_time": normalized["pay_time"],
            "order_amount": normalized["order_amount"],
            "currency": normalized["currency"],
            "import_batch_no": import_batch_no,
            "import_file_name": import_file_name,
            "import_scope_key": scope.scope_key,
            "import_time": import_time,
            "import_user": import_user,
            "row_no": normalized["row_no"],
            "remark": normalized["remark"],
        }
        # Assistants are packed into slots 1..N in order; unused slots are NULL.
        for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
            aid, aname = assistants[idx - 1] if idx - 1 < len(assistants) else (None, None)
            row[f"assistant_id_{idx}"] = aid
            row[f"assistant_name_{idx}"] = aname
        return row

    def _build_alloc_rows(
        self,
        *,
        normalized: Dict[str, Any],
        scope: ImportScope,
        import_batch_no: str,
        import_file_name: str,
        import_user: str,
        import_time: datetime,
    ) -> List[Dict[str, Any]]:
        """Build the allocation rows (one per assistant) for a ledger row.

        The order amount is split equally among the listed assistants. Returns
        an empty list when the row names no assistant.
        """
        assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
        if not assistants:
            return []

        # NOTE(review): the ratio is left unrounded (e.g. 1/3 at Decimal context
        # precision), so stored shares may not sum exactly to order_amount once
        # the database rounds them — confirm against the column scale.
        # NOTE(review): if the same assistant id fills two slots, both rows hit
        # the same (site_id, external_id, assistant_id) conflict key and the
        # later one overwrites the earlier — confirm the intended handling.
        share_ratio = Decimal("1") / Decimal(len(assistants))
        allocated_amount = normalized["order_amount"] * share_ratio

        return [
            {
                "site_id": normalized["site_id"],
                "biz_date": normalized["biz_date"],
                "external_id": normalized["external_id"],
                "member_id": normalized["member_id"],
                "pay_time": normalized["pay_time"],
                "order_amount": normalized["order_amount"],
                "assistant_id": assistant_id,
                "assistant_name": assistant_name,
                "share_ratio": share_ratio,
                "allocated_amount": allocated_amount,
                "currency": normalized["currency"],
                "import_scope_key": scope.scope_key,
                "import_batch_no": import_batch_no,
                "import_file_name": import_file_name,
                "import_time": import_time,
                "import_user": import_user,
            }
            for assistant_id, assistant_name in assistants
        ]

    @classmethod
    def resolve_scope(cls, site_id: int, biz_date: date, today: date) -> ImportScope:
        """Resolve the overwrite scope for a row.

        Dates at most ``HISTORICAL_BUCKET_DAYS`` days before ``today`` (and any
        future date) get a single-day scope; older dates get their 30-day bucket.
        """
        day_diff = (today - biz_date).days
        if day_diff <= cls.HISTORICAL_BUCKET_DAYS:
            return ImportScope(
                site_id=site_id,
                scope_type="DAY",
                start_date=biz_date,
                end_date=biz_date,
            )

        bucket_start, bucket_end = cls.resolve_p30_bucket(biz_date)
        return ImportScope(
            site_id=site_id,
            scope_type="P30",
            start_date=bucket_start,
            end_date=bucket_end,
        )

    @classmethod
    def resolve_p30_bucket(cls, biz_date: date) -> Tuple[date, date]:
        """Return the (start, end) inclusive 30-day bucket containing ``biz_date``.

        Buckets are anchored at ``EPOCH_ANCHOR``; floor division keeps dates
        before the anchor in correctly aligned buckets as well.
        """
        delta_days = (biz_date - cls.EPOCH_ANCHOR).days
        bucket_index = delta_days // cls.HISTORICAL_BUCKET_DAYS
        bucket_start = cls.EPOCH_ANCHOR + timedelta(days=bucket_index * cls.HISTORICAL_BUCKET_DAYS)
        bucket_end = bucket_start + timedelta(days=cls.HISTORICAL_BUCKET_DAYS - 1)
        return bucket_start, bucket_end

    def _delete_by_scopes(self, scopes: Iterable[ImportScope]) -> Tuple[int, int]:
        """Delete existing rows for each scope so the import fully overwrites it.

        Returns:
            ``(deleted_source_rows, deleted_alloc_rows)``.
        """
        # Fixed, trusted table names; only the values are parameterized.
        source_table = "billiards_dws.dws_ml_manual_order_source"
        alloc_table = "billiards_dws.dws_ml_manual_order_alloc"

        deleted_source = 0
        deleted_alloc = 0
        with self.db.conn.cursor() as cur:
            for scope in scopes:
                if scope.scope_type == "DAY":
                    where = "site_id = %s AND biz_date = %s"
                    params: Tuple[Any, ...] = (scope.site_id, scope.start_date)
                else:
                    where = "site_id = %s AND biz_date >= %s AND biz_date <= %s"
                    params = (scope.site_id, scope.start_date, scope.end_date)

                cur.execute(f"DELETE FROM {source_table} WHERE {where}", params)
                # rowcount may be -1 when unknown; never count it as negative.
                deleted_source += max(cur.rowcount, 0)

                cur.execute(f"DELETE FROM {alloc_table} WHERE {where}", params)
                deleted_alloc += max(cur.rowcount, 0)
        return deleted_source, deleted_alloc

    def _insert_source_rows(self, rows: List[Dict[str, Any]]) -> int:
        """Insert wide-table rows; return the number of rows inserted."""
        if not rows:
            return 0

        data_columns = [
            "site_id",
            "biz_date",
            "external_id",
            "member_id",
            "pay_time",
            "order_amount",
            "currency",
            "assistant_id_1",
            "assistant_name_1",
            "assistant_id_2",
            "assistant_name_2",
            "assistant_id_3",
            "assistant_name_3",
            "assistant_id_4",
            "assistant_name_4",
            "assistant_id_5",
            "assistant_name_5",
            "import_batch_no",
            "import_file_name",
            "import_scope_key",
            "import_time",
            "import_user",
            "row_no",
            "remark",
        ]
        columns = data_columns + ["created_at", "updated_at"]
        column_list = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"""
            INSERT INTO billiards_dws.dws_ml_manual_order_source ({column_list})
            VALUES ({placeholders})
        """

        inserted = 0
        with self.db.conn.cursor() as cur:
            for row in rows:
                values = [row.get(col) for col in data_columns]
                # created_at / updated_at both take the import timestamp.
                values.extend([row.get("import_time"), row.get("import_time")])
                cur.execute(sql, values)
                inserted += max(cur.rowcount, 0)
        return inserted

    def _upsert_alloc_rows(self, rows: List[Dict[str, Any]]) -> int:
        """Upsert allocation rows keyed by (site_id, external_id, assistant_id).

        Returns:
            The number of rows affected.
        """
        if not rows:
            return 0

        data_columns = [
            "site_id",
            "biz_date",
            "external_id",
            "member_id",
            "pay_time",
            "order_amount",
            "assistant_id",
            "assistant_name",
            "share_ratio",
            "allocated_amount",
            "currency",
            "import_scope_key",
            "import_batch_no",
            "import_file_name",
            "import_time",
            "import_user",
        ]
        columns = data_columns + ["created_at", "updated_at"]
        column_list = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"""
            INSERT INTO billiards_dws.dws_ml_manual_order_alloc ({column_list})
            VALUES ({placeholders})
            ON CONFLICT (site_id, external_id, assistant_id)
            DO UPDATE SET
                biz_date = EXCLUDED.biz_date,
                member_id = EXCLUDED.member_id,
                pay_time = EXCLUDED.pay_time,
                order_amount = EXCLUDED.order_amount,
                assistant_name = EXCLUDED.assistant_name,
                share_ratio = EXCLUDED.share_ratio,
                allocated_amount = EXCLUDED.allocated_amount,
                currency = EXCLUDED.currency,
                import_scope_key = EXCLUDED.import_scope_key,
                import_batch_no = EXCLUDED.import_batch_no,
                import_file_name = EXCLUDED.import_file_name,
                import_time = EXCLUDED.import_time,
                import_user = EXCLUDED.import_user,
                updated_at = NOW()
        """

        affected = 0
        with self.db.conn.cursor() as cur:
            for row in rows:
                values = [row.get(col) for col in data_columns]
                # created_at / updated_at both take the import timestamp.
                values.extend([row.get("import_time"), row.get("import_time")])
                cur.execute(sql, values)
                affected += max(cur.rowcount, 0)
        return affected

    @staticmethod
    def _to_int(value: Any, fallback: Optional[int] = None) -> Optional[int]:
        """Convert ``value`` to int, returning ``fallback`` when empty or unparsable."""
        if value is None:
            return fallback
        if isinstance(value, str) and not value.strip():
            return fallback
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers float infinities coming from Excel cells.
            return fallback

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Convert ``value`` to Decimal; ``None`` and blank strings become 0."""
        if isinstance(value, str):
            value = value.strip()
        if value is None or value == "":
            return Decimal("0")
        # Go through str() so floats keep their short repr instead of the
        # exact binary expansion.
        return Decimal(str(value))

    @staticmethod
    def _to_date(value: Any) -> date:
        """Convert a date, datetime or ISO-like string into a ``date``.

        Strings may use ``-`` or ``/`` separators (matching ``_to_datetime``);
        only the first 10 characters are parsed, so a trailing time is ignored.

        Raises:
            ValueError: the value is empty, unparsable or of an unsupported type.
        """
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        if isinstance(value, str):
            text = value.strip()
            if not text:
                raise ValueError("biz_date 不能为空")
            text = text.replace("/", "-")
            return datetime.fromisoformat(text[:10]).date()
        raise ValueError(f"无法解析 biz_date: {value}")

    @staticmethod
    def _to_datetime(value: Any, fallback_date: date) -> datetime:
        """Convert ``value`` to a datetime, defaulting to midnight of ``fallback_date``.

        Raises:
            ValueError: a non-empty string cannot be parsed.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, date):
            return datetime.combine(value, datetime.min.time())
        if isinstance(value, str):
            text = value.strip()
            if text:
                text = text.replace("/", "-")
                try:
                    return datetime.fromisoformat(text)
                except ValueError:
                    # Retry on the leading "YYYY-MM-DD HH:MM:SS" part, then on
                    # the date part alone, to tolerate trailing garbage.
                    if len(text) >= 19:
                        return datetime.strptime(text[:19], "%Y-%m-%d %H:%M:%S")
                    return datetime.fromisoformat(text[:10])
        return datetime.combine(fallback_date, datetime.min.time())

    @staticmethod
    def _build_import_batch_no(now: datetime) -> str:
        """Build a batch number from the timestamp plus a short random suffix."""
        return f"MLM_{now.strftime('%Y%m%d%H%M%S')}_{str(uuid.uuid4())[:8]}"

    @staticmethod
    def _resolve_import_user() -> str:
        """Return the operator name from the environment, or ``"system"``."""
        return (
            os.getenv("ETL_OPERATOR")
            or os.getenv("USERNAME")
            or os.getenv("USER")
            or "system"
        )


__all__ = ["MlManualImportTask", "ImportScope"]