Files
ZQYY.FQ-ETL/tasks/dws/index/ml_manual_import_task.py

624 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
ML 人工台账导入任务。
设计目标:
1. 人工台账作为 ML 唯一真源;
2. 同一订单支持多助教归因,默认均分;
3. 覆盖策略:
- 近 30 天:按 site_id + biz_date 日覆盖;
- 超过 30 天:按固定纪元(2026-01-01)切 30 天批次覆盖。
"""
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext
@dataclass(frozen=True)
class ImportScope:
    """Overwrite range that a ledger row belongs to.

    A scope covers one site and either a single business day ("DAY") or a
    multi-day bucket (rendered with the "P30" prefix).
    """

    site_id: int
    scope_type: str  # "DAY" or "P30"
    start_date: date
    end_date: date

    @property
    def scope_key(self) -> str:
        """Return the string key that identifies this scope.

        A "DAY" scope renders as ``DAY:<site_id>:<start>``. Any other scope
        type renders as ``P30:<site_id>:<start>:<end>``.
        """
        first_day = self.start_date.isoformat()
        if self.scope_type != "DAY":
            last_day = self.end_date.isoformat()
            return ":".join(("P30", str(self.site_id), first_day, last_day))
        return ":".join(("DAY", str(self.site_id), first_day))
class MlManualImportTask(BaseIndexTask):
    """Import the manual ML ledger and split it into two tables.

    Each ledger row becomes one order-level "wide" row plus one allocation
    ("narrow") row per attributed assistant.
    """

    INDEX_TYPE = "ML"
    # Fixed origin from which the historical buckets are counted.
    EPOCH_ANCHOR = date(2026, 1, 1)
    # Length of one historical bucket, and the age limit (in days) up to
    # which a row is still overwritten per single day.
    HISTORICAL_BUCKET_DAYS = 30
    # Number of assistant id/name column pairs in the template.
    ASSISTANT_SLOT_COUNT = 5

    # Excel template fields, in column order.
    TEMPLATE_COLUMNS = [
        "site_id",
        "biz_date",
        "external_id",
        "member_id",
        "pay_time",
        "order_amount",
        "currency",
        # assistant_id_1, assistant_name_1, ... one id/name pair per slot.
        *(
            f"assistant_{field}_{slot}"
            for slot in range(1, ASSISTANT_SLOT_COUNT + 1)
            for field in ("id", "name")
        ),
        "remark",
    ]
def get_task_code(self) -> str:
    """Return the code that identifies this task."""
    task_code = "DWS_ML_MANUAL_IMPORT"
    return task_code
def get_target_table(self) -> str:
    """Return the name of the order-level (wide) target table."""
    target_table = "dws_ml_manual_order_source"
    return target_table
def get_primary_keys(self) -> List[str]:
    """Return the primary-key column names of the target table.

    A new list is built on every call, so callers may mutate the result.
    """
    key_columns = ("site_id", "external_id", "import_scope_key", "row_no")
    return list(key_columns)
def get_index_type(self) -> str:
    """Return the index type, read from the ``INDEX_TYPE`` attribute."""
    index_type = self.INDEX_TYPE
    return index_type
def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
    """Import the ledger file into the source and allocation tables.

    The task is driven by a file rather than a time window; ``context`` is
    accepted for interface compatibility and is not read.

    Every row is normalized and validated before any database statement
    runs. The rows already stored for each touched scope are then deleted,
    the new rows are written, and the transaction is committed. If any of
    the delete/insert/commit steps fails, the transaction is rolled back
    before the exception propagates, so a partially applied overwrite
    (deleted but not re-inserted) cannot be committed later on the same
    connection.

    Returns:
        A dict with ``"status"`` and a ``"counts"`` dict (``source_rows``,
        ``alloc_rows``, ``deleted_source_rows``, ``deleted_alloc_rows``,
        ``scopes``).

    Raises:
        ValueError: when no ledger file path is configured.
        Exception: anything raised while reading, normalizing or writing
            rows is propagated unchanged.
    """
    file_path = self._resolve_file_path()
    if not file_path:
        raise ValueError(
            "未找到 ML 台账文件,请通过环境变量 ML_MANUAL_LEDGER_FILE 或配置 run.ml_manual_ledger_file 指定"
        )

    rows = self._read_excel_rows(file_path)
    if not rows:
        self.logger.warning("台账文件为空:%s", file_path)
        return {
            "status": "SUCCESS",
            "counts": {
                "source_rows": 0,
                "alloc_rows": 0,
                "deleted_source_rows": 0,
                "deleted_alloc_rows": 0,
                "scopes": 0,
            },
        }

    now = datetime.now(self.tz)
    today = now.date()
    # Audit fields shared by every row written in this run.
    audit = {
        "import_batch_no": self._build_import_batch_no(now),
        "import_file_name": Path(file_path).name,
        "import_user": self._resolve_import_user(),
        "import_time": now,
    }

    source_rows: List[Dict[str, Any]] = []
    alloc_rows: List[Dict[str, Any]] = []
    scope_set: Dict[Tuple[int, str, date, date], ImportScope] = {}

    # Numbering starts at 2 because row 1 of the sheet is the header.
    # NOTE(review): _read_excel_rows drops blank rows, so row_no counts
    # non-empty data rows and may differ from the sheet's own row number.
    for idx, raw in enumerate(rows, start=2):
        normalized = self._normalize_row(raw, row_no=idx, file_path=file_path)
        row_scope = self.resolve_scope(
            site_id=normalized["site_id"],
            biz_date=normalized["biz_date"],
            today=today,
        )
        scope_id = (
            row_scope.site_id,
            row_scope.scope_type,
            row_scope.start_date,
            row_scope.end_date,
        )
        scope_set[scope_id] = row_scope
        source_rows.append(
            self._build_source_row(normalized=normalized, scope=row_scope, **audit)
        )
        alloc_rows.extend(
            self._build_alloc_rows(normalized=normalized, scope=row_scope, **audit)
        )

    scopes = list(scope_set.values())
    try:
        deleted_source_rows, deleted_alloc_rows = self._delete_by_scopes(scopes)
        inserted_source = self._insert_source_rows(source_rows)
        upserted_alloc = self._upsert_alloc_rows(alloc_rows)
        self.db.conn.commit()
    except Exception:
        # Undo the scope deletes so the overwrite is all-or-nothing.
        self.db.conn.rollback()
        raise

    self.logger.info(
        "ML 人工台账导入完成: file=%s source=%d alloc=%d scopes=%d",
        file_path,
        inserted_source,
        upserted_alloc,
        len(scopes),
    )
    return {
        "status": "SUCCESS",
        "counts": {
            "source_rows": inserted_source,
            "alloc_rows": upserted_alloc,
            "deleted_source_rows": deleted_source_rows,
            "deleted_alloc_rows": deleted_alloc_rows,
            "scopes": len(scopes),
        },
    }
def _resolve_file_path(self) -> Optional[str]:
    """Resolve the ledger file path from configuration or the environment.

    Lookup order: config ``run.ml_manual_ledger_file``, config
    ``run.ml_manual_file``, then environment variable
    ``ML_MANUAL_LEDGER_FILE``. The first truthy value wins. ``~`` is
    expanded and a relative path is resolved against the current working
    directory.

    Returns:
        The path as a string, or ``None`` when nothing is configured.

    Raises:
        FileNotFoundError: when a path is configured but does not exist.
    """
    config_keys = ("run.ml_manual_ledger_file", "run.ml_manual_file")
    # map() is lazy, so the second key is only read when the first is falsy.
    configured = next((value for value in map(self.config.get, config_keys) if value), None)
    configured = configured or os.getenv("ML_MANUAL_LEDGER_FILE")
    if not configured:
        return None

    ledger_path = Path(str(configured)).expanduser()
    if not ledger_path.is_absolute():
        ledger_path = Path.cwd() / ledger_path
    if ledger_path.exists():
        return str(ledger_path)
    raise FileNotFoundError(f"台账文件不存在: {ledger_path}")
def _read_excel_rows(self, file_path: str) -> List[Dict[str, Any]]:
    """Read the active worksheet into a list of ``{header: value}`` dicts.

    The first sheet row supplies the headers (stripped; an empty cell
    becomes ``""``). Every later row is mapped onto those headers, and rows
    that ``_is_empty_row`` reports as empty are dropped.

    Raises:
        RuntimeError: when ``openpyxl`` cannot be imported.
    """
    try:
        from openpyxl import load_workbook
    except Exception as exc:  # noqa: BLE001
        raise RuntimeError(
            "缺少 openpyxl 依赖,无法读取 Excel请先安装 openpyxl"
        ) from exc

    workbook = load_workbook(file_path, data_only=True)
    sheet = workbook.active

    first_row = next(sheet.iter_rows(min_row=1, max_row=1, values_only=True), None)
    if not first_row:
        return []
    headers = ["" if cell is None else str(cell).strip() for cell in first_row]
    if not headers:
        return []

    records: List[Dict[str, Any]] = []
    for values in sheet.iter_rows(min_row=2, values_only=True):
        if values is None:
            continue
        # zip() stops at the shorter of the two sequences.
        record = dict(zip(headers, values))
        if not self._is_empty_row(record):
            records.append(record)
    return records
@staticmethod
def _is_empty_row(row: Dict[str, Any]) -> bool:
    """Return True when every value is ``None`` or a blank string."""
    return all(
        value is None or (isinstance(value, str) and not value.strip())
        for value in row.values()
    )
def _normalize_row(
    self,
    raw: Dict[str, Any],
    row_no: int,
    file_path: str,
) -> Dict[str, Any]:
    """Convert one raw sheet row into typed, validated fields.

    Args:
        raw: header-to-cell-value mapping for the row.
        row_no: row number used in error messages and stored as ``row_no``.
        file_path: ledger file path, used in error messages only.

    Returns:
        A dict with ``site_id``, ``biz_date``, ``external_id``,
        ``member_id``, ``pay_time``, ``order_amount``, ``currency``,
        ``assistants`` (list of ``(id, name)`` tuples for the slots that
        carry an id), ``remark`` and ``row_no``.

    Raises:
        ValueError: when ``site_id`` cannot be resolved from the row or from
            config ``app.store_id``, when ``biz_date`` / ``pay_time`` cannot
            be parsed, or when ``external_id`` is missing. Each message
            names the row.
    """
    # Coerce the configured default too, so site_id is always an int.
    default_site_id = self._to_int(self.config.get("app.store_id"))
    site_id = self._to_int(raw.get("site_id"), fallback=default_site_id)
    if site_id is None:
        raise ValueError(f"台账行 {row_no} 缺少 site_id: {file_path}")

    try:
        biz_date = self._to_date(raw.get("biz_date"))
        pay_time = self._to_datetime(raw.get("pay_time"), fallback_date=biz_date)
    except ValueError as exc:
        # Re-raise with the row number so the operator can find the bad cell.
        raise ValueError(f"台账行 {row_no} 日期字段无效: {exc}: {file_path}") from exc

    external_id = str(raw.get("external_id") or "").strip()
    if not external_id:
        raise ValueError(f"台账行 {row_no} 缺少 external_id订单ID: {file_path}")

    member_id = self._to_int(raw.get("member_id"), fallback=0)
    order_amount = self._to_decimal(raw.get("order_amount"))
    currency = str(raw.get("currency") or "CNY").strip().upper() or "CNY"
    remark = str(raw.get("remark") or "").strip()

    # A slot counts only when it carries an id; a name without an id is ignored.
    assistants: List[Tuple[int, str]] = []
    for slot in range(1, self.ASSISTANT_SLOT_COUNT + 1):
        assistant_id = self._to_int(raw.get(f"assistant_id_{slot}"), fallback=None)
        if assistant_id is None:
            continue
        assistant_name = str(raw.get(f"assistant_name_{slot}") or "").strip()
        assistants.append((assistant_id, assistant_name))

    return {
        "site_id": site_id,
        "biz_date": biz_date,
        "external_id": external_id,
        "member_id": member_id,
        "pay_time": pay_time,
        "order_amount": order_amount,
        "currency": currency,
        "assistants": assistants,
        "remark": remark,
        "row_no": row_no,
    }
def _build_source_row(
    self,
    *,
    normalized: Dict[str, Any],
    scope: ImportScope,
    import_batch_no: str,
    import_file_name: str,
    import_user: str,
    import_time: datetime,
) -> Dict[str, Any]:
    """Build the order-level (wide table) row for one normalized ledger row.

    The assistants are spread over ``assistant_id_N`` / ``assistant_name_N``
    columns for N in 1..``ASSISTANT_SLOT_COUNT``; unused slots hold ``None``.
    """
    slot_count = self.ASSISTANT_SLOT_COUNT
    filled = list(normalized["assistants"])[:slot_count]
    # Pad with empty slots so every assistant column is always present.
    padded = filled + [(None, None)] * (slot_count - len(filled))

    copied_fields = (
        "site_id",
        "biz_date",
        "external_id",
        "member_id",
        "pay_time",
        "order_amount",
        "currency",
    )
    record: Dict[str, Any] = {field: normalized[field] for field in copied_fields}
    record.update(
        import_batch_no=import_batch_no,
        import_file_name=import_file_name,
        import_scope_key=scope.scope_key,
        import_time=import_time,
        import_user=import_user,
        row_no=normalized["row_no"],
        remark=normalized["remark"],
    )
    for slot, (assistant_id, assistant_name) in enumerate(padded, start=1):
        record[f"assistant_id_{slot}"] = assistant_id
        record[f"assistant_name_{slot}"] = assistant_name
    return record
def _build_alloc_rows(
    self,
    *,
    normalized: Dict[str, Any],
    scope: ImportScope,
    import_batch_no: str,
    import_file_name: str,
    import_user: str,
    import_time: datetime,
) -> List[Dict[str, Any]]:
    """Build the per-assistant allocation rows for one ledger row.

    The order amount is split evenly across the distinct assistants of the
    row. An assistant id listed more than once counts once (the first name
    seen is kept): the allocation table is keyed by
    ``(site_id, external_id, assistant_id)``, so duplicate rows would
    overwrite each other on upsert and the order would end up
    under-allocated.

    Returns:
        One dict per distinct assistant, or an empty list when the row has
        no assistants.
    """
    # dict preserves first-seen order while dropping repeated ids.
    distinct: Dict[int, str] = {}
    for assistant_id, assistant_name in normalized["assistants"]:
        distinct.setdefault(assistant_id, assistant_name)
    if not distinct:
        return []

    share_ratio = Decimal(1) / Decimal(len(distinct))
    allocated_amount = normalized["order_amount"] * share_ratio
    # Fields that are identical for every assistant of this order.
    shared = {
        "site_id": normalized["site_id"],
        "biz_date": normalized["biz_date"],
        "external_id": normalized["external_id"],
        "member_id": normalized["member_id"],
        "pay_time": normalized["pay_time"],
        "order_amount": normalized["order_amount"],
        "share_ratio": share_ratio,
        "allocated_amount": allocated_amount,
        "currency": normalized["currency"],
        "import_scope_key": scope.scope_key,
        "import_batch_no": import_batch_no,
        "import_file_name": import_file_name,
        "import_time": import_time,
        "import_user": import_user,
    }
    return [
        {**shared, "assistant_id": assistant_id, "assistant_name": assistant_name}
        for assistant_id, assistant_name in distinct.items()
    ]
@classmethod
def resolve_scope(cls, site_id: int, biz_date: date, today: date) -> ImportScope:
    """Choose the overwrite scope for a row dated ``biz_date``.

    A row at most ``HISTORICAL_BUCKET_DAYS`` days older than ``today``
    (future dates included) gets a single-day "DAY" scope. Anything older
    gets the "P30" bucket computed by ``resolve_p30_bucket``.
    """
    age_in_days = (today - biz_date).days
    if age_in_days > cls.HISTORICAL_BUCKET_DAYS:
        scope_type = "P30"
        first_day, last_day = cls.resolve_p30_bucket(biz_date)
    else:
        scope_type = "DAY"
        first_day = last_day = biz_date
    return ImportScope(
        site_id=site_id,
        scope_type=scope_type,
        start_date=first_day,
        end_date=last_day,
    )
@classmethod
def resolve_p30_bucket(cls, biz_date: date) -> Tuple[date, date]:
    """Return the first and last day of the fixed bucket containing ``biz_date``.

    Buckets are ``HISTORICAL_BUCKET_DAYS`` days long and counted from
    ``EPOCH_ANCHOR``; dates before the anchor fall into buckets counted
    backwards from it.
    """
    bucket_days = cls.HISTORICAL_BUCKET_DAYS
    days_from_anchor = (biz_date - cls.EPOCH_ANCHOR).days
    # divmod floors, so the remainder is never negative, even before the anchor.
    _, days_into_bucket = divmod(days_from_anchor, bucket_days)
    first_day = biz_date - timedelta(days=days_into_bucket)
    last_day = first_day + timedelta(days=bucket_days - 1)
    return first_day, last_day
def _delete_by_scopes(self, scopes: Iterable[ImportScope]) -> Tuple[int, int]:
    """Delete the stored rows of every scope, ahead of re-inserting them.

    A "DAY" scope deletes rows whose ``biz_date`` equals the scope's start
    date. Any other scope deletes the inclusive ``start_date``..``end_date``
    range. Both the order source table and the allocation table are
    cleared, source first. Nothing is committed here.

    Returns:
        ``(deleted_source_rows, deleted_alloc_rows)``. A negative cursor
        ``rowcount`` is counted as 0.
    """
    source_table = "billiards_dws.dws_ml_manual_order_source"
    alloc_table = "billiards_dws.dws_ml_manual_order_alloc"
    deleted = {source_table: 0, alloc_table: 0}

    with self.db.conn.cursor() as cur:
        for scope in scopes:
            if scope.scope_type == "DAY":
                condition = "site_id = %s AND biz_date = %s"
                params: Tuple[Any, ...] = (scope.site_id, scope.start_date)
            else:
                condition = "site_id = %s AND biz_date >= %s AND biz_date <= %s"
                params = (scope.site_id, scope.start_date, scope.end_date)
            # Table names are local constants; the values stay bound parameters.
            for table in (source_table, alloc_table):
                cur.execute(f"DELETE FROM {table} WHERE {condition}", params)
                deleted[table] += max(cur.rowcount, 0)

    return deleted[source_table], deleted[alloc_table]
def _insert_source_rows(self, rows: List[Dict[str, Any]]) -> int:
    """Insert the order-level rows into ``dws_ml_manual_order_source``.

    Bound values are derived from the same column list that builds the SQL,
    so columns and values cannot drift apart. ``created_at`` and
    ``updated_at`` are both filled from the row's ``import_time``. Nothing
    is committed here.

    Returns:
        The sum of the cursor ``rowcount`` over all inserts (a negative
        ``rowcount`` counts as 0); 0 when ``rows`` is empty.
    """
    if not rows:
        return 0

    # Columns read straight from the row dict, in insert order.
    data_columns = [
        "site_id",
        "biz_date",
        "external_id",
        "member_id",
        "pay_time",
        "order_amount",
        "currency",
        "assistant_id_1",
        "assistant_name_1",
        "assistant_id_2",
        "assistant_name_2",
        "assistant_id_3",
        "assistant_name_3",
        "assistant_id_4",
        "assistant_name_4",
        "assistant_id_5",
        "assistant_name_5",
        "import_batch_no",
        "import_file_name",
        "import_scope_key",
        "import_time",
        "import_user",
        "row_no",
        "remark",
    ]
    # Audit timestamps are not part of the row dict.
    columns = data_columns + ["created_at", "updated_at"]
    placeholders = ", ".join(["%s"] * len(columns))
    sql = f"""
        INSERT INTO billiards_dws.dws_ml_manual_order_source ({", ".join(columns)})
        VALUES ({placeholders})
    """

    inserted = 0
    with self.db.conn.cursor() as cur:
        for row in rows:
            import_time = row.get("import_time")
            values = [row.get(column) for column in data_columns]
            values += [import_time, import_time]
            cur.execute(sql, values)
            inserted += max(cur.rowcount, 0)
    return inserted
def _upsert_alloc_rows(self, rows: List[Dict[str, Any]]) -> int:
    """Upsert the per-assistant rows into ``dws_ml_manual_order_alloc``.

    Rows conflict on ``(site_id, external_id, assistant_id)``. On conflict
    every other data column is overwritten from the incoming row and
    ``updated_at`` is set to ``NOW()``; ``created_at`` is kept. The column
    list, the bound values and the ``DO UPDATE SET`` clause are all derived
    from one column list, so they cannot drift apart. Nothing is committed
    here.

    Returns:
        The sum of the cursor ``rowcount`` over all statements (a negative
        ``rowcount`` counts as 0); 0 when ``rows`` is empty.
    """
    if not rows:
        return 0

    conflict_keys = ("site_id", "external_id", "assistant_id")
    # Columns read straight from the row dict, in insert order.
    data_columns = [
        "site_id",
        "biz_date",
        "external_id",
        "member_id",
        "pay_time",
        "order_amount",
        "assistant_id",
        "assistant_name",
        "share_ratio",
        "allocated_amount",
        "currency",
        "import_scope_key",
        "import_batch_no",
        "import_file_name",
        "import_time",
        "import_user",
    ]
    # Audit timestamps are not part of the row dict.
    columns = data_columns + ["created_at", "updated_at"]
    assignments = [
        f"{column} = EXCLUDED.{column}"
        for column in data_columns
        if column not in conflict_keys
    ]
    assignments.append("updated_at = NOW()")
    placeholders = ", ".join(["%s"] * len(columns))
    sql = f"""
        INSERT INTO billiards_dws.dws_ml_manual_order_alloc ({", ".join(columns)})
        VALUES ({placeholders})
        ON CONFLICT ({", ".join(conflict_keys)})
        DO UPDATE SET
            {", ".join(assignments)}
    """

    affected = 0
    with self.db.conn.cursor() as cur:
        for row in rows:
            import_time = row.get("import_time")
            values = [row.get(column) for column in data_columns]
            values += [import_time, import_time]
            cur.execute(sql, values)
            affected += max(cur.rowcount, 0)
    return affected
@staticmethod
def _to_int(value: Any, fallback: Optional[int] = None) -> Optional[int]:
    """Convert ``value`` to ``int``, returning ``fallback`` when it cannot be.

    ``None`` and blank strings count as missing. Values that ``int()``
    rejects (non-numeric text, NaN, infinity, unsupported types) also yield
    ``fallback``. Only the errors ``int()`` raises for such values are
    caught, so unrelated failures are not hidden.
    """
    if value is None:
        return fallback
    if isinstance(value, str) and not value.strip():
        return fallback
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type; ValueError: bad text or NaN;
        # OverflowError: infinity.
        return fallback
@staticmethod
def _to_decimal(value: Any) -> Decimal:
    """Convert ``value`` to ``Decimal``, treating missing values as zero.

    ``None``, the empty string and whitespace-only strings all yield
    ``Decimal("0")``, which matches how ``_to_int`` treats blank cells.
    Other values go through ``str()`` first, so floats convert by their
    printed form (``0.1`` becomes ``Decimal("0.1")``).

    Raises:
        decimal.InvalidOperation: when the text is not a valid number.
    """
    if value is None:
        return Decimal("0")
    if isinstance(value, str):
        value = value.strip()
        if not value:
            return Decimal("0")
    return Decimal(str(value))
@staticmethod
def _to_date(value: Any) -> date:
    """Convert a cell value to a ``date``.

    Accepts ``datetime`` (time part dropped), ``date``, and text. Text may
    use ``-`` or ``/`` as separator (``/`` is also what ``_to_datetime``
    accepts), may carry a trailing time part, and may omit zero padding
    (``2026/1/5``).

    Raises:
        ValueError: for blank text, unparseable text, or any other type
            (including ``None``).
    """
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    if isinstance(value, str):
        text = value.strip()
        if not text:
            raise ValueError("biz_date 不能为空")
        text = text.replace("/", "-")
        try:
            # Fast path: a zero-padded ISO date fills the first 10 characters.
            return datetime.fromisoformat(text[:10]).date()
        except ValueError:
            # Non-padded dates are shorter, so cut at the date/time separator.
            date_part = text.replace("T", " ").split(" ", 1)[0]
            return datetime.strptime(date_part, "%Y-%m-%d").date()
    raise ValueError(f"无法解析 biz_date: {value}")
@staticmethod
def _to_datetime(value: Any, fallback_date: date) -> datetime:
    """Convert a cell value to a ``datetime``.

    A ``datetime`` is returned as is and a ``date`` becomes midnight of that
    day. Non-blank text has ``/`` replaced by ``-`` and is parsed as ISO
    format; if that fails, the first 19 characters are parsed as
    ``YYYY-MM-DD HH:MM:SS`` when the text is that long, otherwise the first
    10 characters are parsed as an ISO date. Anything else (``None``, blank
    text, other types) yields midnight of ``fallback_date``.

    Raises:
        ValueError: when non-blank text matches none of the formats.
    """
    midnight = datetime.min.time()
    if isinstance(value, datetime):
        return value
    if isinstance(value, date):
        return datetime.combine(value, midnight)

    text = value.strip().replace("/", "-") if isinstance(value, str) else ""
    if not text:
        return datetime.combine(fallback_date, midnight)

    try:
        return datetime.fromisoformat(text)
    except Exception:  # noqa: BLE001
        # Tolerate trailing characters after a full timestamp or a date.
        if len(text) < 19:
            return datetime.fromisoformat(text[:10])
        return datetime.strptime(text[:19], "%Y-%m-%d %H:%M:%S")
@staticmethod
def _build_import_batch_no(now: datetime) -> str:
    """Return a batch number of the form ``MLM_<YYYYmmddHHMMSS>_<8 hex chars>``.

    The suffix is the first 8 characters of a random UUID4, so two runs in
    the same second still get different batch numbers.
    """
    timestamp = now.strftime("%Y%m%d%H%M%S")
    random_suffix = uuid.uuid4().hex[:8]
    return "_".join(("MLM", timestamp, random_suffix))
@staticmethod
def _resolve_import_user() -> str:
    """Return the operator name recorded on imported rows.

    The first non-empty environment variable among ``ETL_OPERATOR``,
    ``USERNAME`` and ``USER`` is used; ``"system"`` when none is set.
    """
    candidates = ("ETL_OPERATOR", "USERNAME", "USER")
    for variable in candidates:
        operator = os.getenv(variable)
        if operator:
            return operator
    return "system"
__all__ = ["MlManualImportTask", "ImportScope"]