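# -*- coding: utf-8 -*-
"""
ML 人工台账导入任务。

设计目标:
1. 人工台账作为 ML 唯一真源;
2. 同一订单支持多助教归因,默认均分;
3. 覆盖策略:
   - 近 30 天:按 site_id + biz_date 日覆盖;
   - 超过 30 天:按固定纪元(2026-01-01)切 30 天批次覆盖。

ML manual-ledger import task.

Design goals:
1. The manual ledger is the single source of truth for ML.
2. One order may be attributed to several assistants; by default the amount is split evenly.
3. Overwrite strategy:
   - last 30 days: overwrite per day, keyed by site_id + biz_date;
   - older than 30 days: overwrite in 30-day batches aligned to a fixed epoch (2026-01-01).
"""
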
from __future__ import annotations

import os
import uuid
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from .base_index_task import BaseIndexTask
from ..base_dws_task import TaskContext


@dataclass(frozen=True)
class ImportScope:
    """One overwrite range of a ledger import.

    Instances are immutable and compare/hash by value, so equal scopes collapse
    when collected into a set or used as dict keys.
    """

    site_id: int
    # "DAY" (single business day) or "P30" (fixed 30-day bucket)
    scope_type: str
    start_date: date
    end_date: date

    @property
    def scope_key(self) -> str:
        """Stable text key for this scope.

        ``DAY:<site>:<day>`` for day scopes; every other scope type is rendered as
        ``P30:<site>:<start>:<end>``.
        """
        if self.scope_type == "DAY":
            parts = ("DAY", self.site_id, self.start_date.isoformat())
        else:
            parts = (
                "P30",
                self.site_id,
                self.start_date.isoformat(),
                self.end_date.isoformat(),
            )
        return ":".join(str(part) for part in parts)


class MlManualImportTask(BaseIndexTask):
    """Import the ML manual ledger and split it into a wide and a narrow table.

    Each ledger line becomes one row in ``dws_ml_manual_order_source`` (the order plus
    up to ``ASSISTANT_SLOT_COUNT`` assistant slots) and one row per attributed assistant
    in ``dws_ml_manual_order_alloc`` (the evenly split share of the order amount).

    Existing data is replaced per scope (see ``resolve_scope``). A row whose ``biz_date``
    is at most ``HISTORICAL_BUCKET_DAYS`` days before today replaces that single day for
    its site. An older row replaces the whole fixed 30-day bucket, aligned to
    ``EPOCH_ANCHOR``, that contains it.

    NOTE(review): ``self.config``, ``self.db``, ``self.logger`` and ``self.tz`` are
    presumably provided by ``BaseIndexTask``; they are not defined in this class.
    """

    INDEX_TYPE = "ML"
    EPOCH_ANCHOR = date(2026, 1, 1)
    HISTORICAL_BUCKET_DAYS = 30
    ASSISTANT_SLOT_COUNT = 5

    # Allocated amounts are rounded to this unit. The rounding remainder goes to the
    # last assistant, so the allocations of one order always add up to its amount.
    AMOUNT_QUANT = Decimal("0.01")

    # Tables cleared per import scope (in this order) before the new rows are written.
    _SCOPED_TABLES = (
        "billiards_dws.dws_ml_manual_order_source",
        "billiards_dws.dws_ml_manual_order_alloc",
    )

    # Audit timestamp columns; both are filled from the row's ``import_time``.
    _TIMESTAMP_COLUMNS = ("created_at", "updated_at")

    # Excel template columns, in column order.
    TEMPLATE_COLUMNS = [
        "site_id",
        "biz_date",
        "external_id",
        "member_id",
        "pay_time",
        "order_amount",
        "currency",
        "assistant_id_1",
        "assistant_name_1",
        "assistant_id_2",
        "assistant_name_2",
        "assistant_id_3",
        "assistant_name_3",
        "assistant_id_4",
        "assistant_name_4",
        "assistant_id_5",
        "assistant_name_5",
        "remark",
    ]

    def get_task_code(self) -> str:
        return "DWS_ML_MANUAL_IMPORT"

    def get_target_table(self) -> str:
        return "dws_ml_manual_order_source"

    def get_primary_keys(self) -> List[str]:
        return ["site_id", "external_id", "import_scope_key", "row_no"]

    def get_index_type(self) -> str:
        return self.INDEX_TYPE

    # ------------------------------------------------------------------ entry point

    def execute(self, context: Optional[TaskContext]) -> Dict[str, Any]:
        """Run the import for the configured ledger file.

        The task is driven by a file, not by a time window. The scheduler triggers it
        directly as a utility task, so ``context`` is not used.

        Every row is parsed and validated before the first database write. The scope
        deletes and inserts then run in one transaction. That transaction is committed
        only after all writes succeed and is rolled back otherwise, so a failed run
        never leaves a scope deleted but not re-filled.

        Returns:
            ``{"status": "SUCCESS", "counts": {...}}`` with inserted/upserted and deleted
            row counts and the number of overwrite scopes.

        Raises:
            ValueError: no ledger file is configured, or a ledger row is invalid.
            FileNotFoundError: the configured path is not an existing file.
            RuntimeError: openpyxl is not installed.
        """
        file_path = self._resolve_file_path()
        if not file_path:
            raise ValueError(
                "未找到 ML 台账文件,请通过环境变量 ML_MANUAL_LEDGER_FILE 或配置 run.ml_manual_ledger_file 指定"
            )

        numbered_rows = self._read_excel_rows_numbered(file_path)
        if not numbered_rows:
            self.logger.warning("台账文件为空:%s", file_path)
            return self._build_result()

        now = datetime.now(self.tz)
        source_rows, alloc_rows, scopes = self._prepare_rows(
            numbered_rows, file_path=file_path, now=now
        )

        try:
            deleted_source_rows, deleted_alloc_rows = self._delete_by_scopes(scopes)
            inserted_source = self._insert_source_rows(source_rows)
            upserted_alloc = self._upsert_alloc_rows(alloc_rows)
            self.db.conn.commit()
        except Exception:
            # Undo the scope deletes. Otherwise a later commit on the same connection
            # could persist them without the replacement rows.
            try:
                self.db.conn.rollback()
            except Exception:  # noqa: BLE001 - keep the original error as the one raised
                self.logger.exception("ML 人工台账导入回滚失败: file=%s", file_path)
            raise

        self.logger.info(
            "ML 人工台账导入完成: file=%s source=%d alloc=%d scopes=%d",
            file_path,
            inserted_source,
            upserted_alloc,
            len(scopes),
        )
        return self._build_result(
            source_rows=inserted_source,
            alloc_rows=upserted_alloc,
            deleted_source_rows=deleted_source_rows,
            deleted_alloc_rows=deleted_alloc_rows,
            scopes=len(scopes),
        )

    @staticmethod
    def _build_result(
        source_rows: int = 0,
        alloc_rows: int = 0,
        deleted_source_rows: int = 0,
        deleted_alloc_rows: int = 0,
        scopes: int = 0,
    ) -> Dict[str, Any]:
        """Build the task result dict returned by ``execute``."""
        return {
            "status": "SUCCESS",
            "counts": {
                "source_rows": source_rows,
                "alloc_rows": alloc_rows,
                "deleted_source_rows": deleted_source_rows,
                "deleted_alloc_rows": deleted_alloc_rows,
                "scopes": scopes,
            },
        }

    def _prepare_rows(
        self,
        numbered_rows: Sequence[Tuple[int, Dict[str, Any]]],
        *,
        file_path: str,
        now: datetime,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[ImportScope]]:
        """Normalize every ledger row and build the wide rows, alloc rows and scopes.

        Args:
            numbered_rows: ``(excel_row_number, raw_row_dict)`` pairs.
            file_path: ledger path; used for error messages and ``import_file_name``.
            now: import timestamp shared by all rows of this run.

        Returns:
            ``(source_rows, alloc_rows, scopes)``. Scopes are unique and kept in
            first-seen order.
        """
        today = now.date()
        audit: Dict[str, Any] = {
            "import_batch_no": self._build_import_batch_no(now),
            "import_file_name": Path(file_path).name,
            "import_user": self._resolve_import_user(),
            "import_time": now,
        }

        source_rows: List[Dict[str, Any]] = []
        alloc_rows: List[Dict[str, Any]] = []
        # Dict used as an insertion-ordered set; ImportScope is a frozen (hashable) dataclass.
        scopes: Dict[ImportScope, None] = {}
        first_row_of_order: Dict[Tuple[int, str], int] = {}

        for row_no, raw in numbered_rows:
            normalized = self._normalize_row(raw, row_no=row_no, file_path=file_path)
            scope = self.resolve_scope(
                site_id=normalized["site_id"],
                biz_date=normalized["biz_date"],
                today=today,
            )
            scopes.setdefault(scope, None)

            order_key = (normalized["site_id"], normalized["external_id"])
            first_row = first_row_of_order.setdefault(order_key, row_no)
            if first_row != row_no:
                # The alloc upsert is keyed on (site_id, external_id, assistant_id), so a
                # repeated order silently overwrites the earlier row's allocations.
                self.logger.warning(
                    "台账订单重复: site_id=%s external_id=%s 行 %d 与行 %d",
                    order_key[0],
                    order_key[1],
                    row_no,
                    first_row,
                )

            source_rows.append(
                self._build_source_row(normalized=normalized, scope=scope, **audit)
            )
            alloc_rows.extend(
                self._build_alloc_rows(normalized=normalized, scope=scope, **audit)
            )

        return source_rows, alloc_rows, list(scopes)

    # ------------------------------------------------------------------ file input

    def _resolve_file_path(self) -> Optional[str]:
        """Resolve the ledger file path from config or environment.

        Sources are checked in order: ``run.ml_manual_ledger_file``, then
        ``run.ml_manual_file``, then the ``ML_MANUAL_LEDGER_FILE`` environment
        variable. ``~`` is expanded, and a relative path is taken relative to the
        current working directory.

        Returns:
            The absolute path as a string, or ``None`` when nothing is configured.

        Raises:
            FileNotFoundError: a path is configured but no file exists there.
        """
        raw_path = (
            self.config.get("run.ml_manual_ledger_file")
            or self.config.get("run.ml_manual_file")
            or os.getenv("ML_MANUAL_LEDGER_FILE")
        )
        if not raw_path:
            return None
        candidate = Path(str(raw_path)).expanduser()
        if not candidate.is_absolute():
            candidate = Path.cwd() / candidate
        # is_file() rather than exists(): a directory would only fail later inside openpyxl.
        if not candidate.is_file():
            raise FileNotFoundError(f"台账文件不存在: {candidate}")
        return str(candidate)

    def _read_excel_rows(self, file_path: str) -> List[Dict[str, Any]]:
        """Read the active sheet as ``{header: value}`` dicts, skipping blank rows."""
        return [row for _, row in self._read_excel_rows_numbered(file_path)]

    def _read_excel_rows_numbered(
        self, file_path: str
    ) -> List[Tuple[int, Dict[str, Any]]]:
        """Read the active sheet as ``(excel_row_number, {header: value})`` pairs.

        Row 1 is the header row; headers are stripped strings, and an empty header
        cell becomes ``""``. Blank data rows are skipped. The returned number is
        always the real worksheet row, so stored ``row_no`` values and error
        messages match what the user sees in Excel.

        Raises:
            RuntimeError: openpyxl is not installed.
        """
        try:
            from openpyxl import load_workbook
        except ImportError as exc:
            raise RuntimeError(
                "缺少 openpyxl 依赖,无法读取 Excel,请先安装 openpyxl"
            ) from exc

        # data_only=True returns cached formula results instead of formula text.
        wb = load_workbook(file_path, data_only=True)
        try:
            ws = wb.active
            header_row = next(ws.iter_rows(min_row=1, max_row=1, values_only=True), None)
            if not header_row:
                return []
            headers = [str(col).strip() if col is not None else "" for col in header_row]
            if not any(headers):
                return []

            numbered: List[Tuple[int, Dict[str, Any]]] = []
            data_rows = ws.iter_rows(min_row=2, values_only=True)
            for excel_row, values in enumerate(data_rows, start=2):
                if values is None:
                    continue
                # zip() stops at the shorter of headers/values, as the sheet may be ragged.
                row_dict = dict(zip(headers, values))
                if self._is_empty_row(row_dict):
                    continue
                numbered.append((excel_row, row_dict))
            return numbered
        finally:
            wb.close()

    @staticmethod
    def _is_empty_row(row: Dict[str, Any]) -> bool:
        """Return True when every value is ``None`` or a whitespace-only string."""
        for value in row.values():
            if value is None:
                continue
            if isinstance(value, str) and not value.strip():
                continue
            return False
        return True

    # ------------------------------------------------------------------ row building

    def _normalize_row(
        self,
        raw: Dict[str, Any],
        row_no: int,
        file_path: str,
    ) -> Dict[str, Any]:
        """Validate and convert one raw ledger row.

        ``site_id`` falls back to config ``app.store_id``, ``member_id`` to 0,
        ``currency`` to ``"CNY"`` and ``pay_time`` to midnight of ``biz_date``.

        Returns:
            A dict of typed fields. ``assistants`` is a de-duplicated list of
            ``(assistant_id, assistant_name)`` pairs.

        Raises:
            ValueError: a required field is missing or unparsable; the message
                includes the Excel row number and the file path.
        """
        store_fallback = self._to_int(self.config.get("app.store_id"), fallback=None)
        site_id = self._to_int(raw.get("site_id"), fallback=store_fallback)
        if site_id is None:
            raise ValueError(f"台账行 {row_no} 缺少 site_id(门店ID): {file_path}")

        try:
            biz_date = self._to_date(raw.get("biz_date"))
            pay_time = self._to_datetime(raw.get("pay_time"), fallback_date=biz_date)
            order_amount = self._to_decimal(raw.get("order_amount"))
        except ValueError as exc:
            raise ValueError(f"台账行 {row_no} 字段解析失败: {exc}: {file_path}") from exc

        external_id = self._to_text_id(raw.get("external_id"))
        if not external_id:
            raise ValueError(f"台账行 {row_no} 缺少 external_id(订单ID): {file_path}")

        member_id = self._to_int(raw.get("member_id"), fallback=0)
        currency = str(raw.get("currency") or "CNY").strip().upper() or "CNY"
        remark = str(raw.get("remark") or "").strip()

        return {
            "site_id": site_id,
            "biz_date": biz_date,
            "external_id": external_id,
            "member_id": member_id,
            "pay_time": pay_time,
            "order_amount": order_amount,
            "currency": currency,
            "assistants": self._collect_assistants(raw, row_no=row_no),
            "remark": remark,
            "row_no": row_no,
        }

    def _collect_assistants(self, raw: Dict[str, Any], *, row_no: int) -> List[Tuple[int, str]]:
        """Collect unique ``(assistant_id, name)`` pairs from the assistant slots.

        A slot without a parsable id is skipped, with a warning when it holds
        other data. A repeated id is merged into its first occurrence. Otherwise
        it would get its own share, and the alloc upsert would then collapse the
        duplicates and under-allocate the order.
        """
        collected: Dict[int, str] = {}
        for slot in range(1, self.ASSISTANT_SLOT_COUNT + 1):
            raw_id = raw.get(f"assistant_id_{slot}")
            assistant_id = self._to_int(raw_id, fallback=None)
            name = str(raw.get(f"assistant_name_{slot}") or "").strip()
            if assistant_id is None:
                raw_text = str(raw_id).strip() if raw_id is not None else ""
                if name or raw_text:
                    self.logger.warning(
                        "台账行 %d 第 %d 位助教缺少有效 assistant_id,已忽略: id=%r name=%s",
                        row_no,
                        slot,
                        raw_id,
                        name,
                    )
                continue
            if assistant_id in collected:
                self.logger.warning(
                    "台账行 %d 助教 %s 重复出现,已合并", row_no, assistant_id
                )
                if not collected[assistant_id] and name:
                    collected[assistant_id] = name
                continue
            collected[assistant_id] = name
        return list(collected.items())

    def _build_source_row(
        self,
        *,
        normalized: Dict[str, Any],
        scope: ImportScope,
        import_batch_no: str,
        import_file_name: str,
        import_user: str,
        import_time: datetime,
    ) -> Dict[str, Any]:
        """Build one wide-table row; unused assistant slots are filled with ``None``."""
        assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
        row = {
            "site_id": normalized["site_id"],
            "biz_date": normalized["biz_date"],
            "external_id": normalized["external_id"],
            "member_id": normalized["member_id"],
            "pay_time": normalized["pay_time"],
            "order_amount": normalized["order_amount"],
            "currency": normalized["currency"],
            "import_batch_no": import_batch_no,
            "import_file_name": import_file_name,
            "import_scope_key": scope.scope_key,
            "import_time": import_time,
            "import_user": import_user,
            "row_no": normalized["row_no"],
            "remark": normalized["remark"],
        }
        for idx in range(1, self.ASSISTANT_SLOT_COUNT + 1):
            aid, aname = assistants[idx - 1] if idx - 1 < len(assistants) else (None, None)
            row[f"assistant_id_{idx}"] = aid
            row[f"assistant_name_{idx}"] = aname
        return row

    def _build_alloc_rows(
        self,
        *,
        normalized: Dict[str, Any],
        scope: ImportScope,
        import_batch_no: str,
        import_file_name: str,
        import_user: str,
        import_time: datetime,
    ) -> List[Dict[str, Any]]:
        """Build one narrow allocation row per assistant, splitting the amount evenly.

        ``share_ratio`` is ``1 / n``. Every assistant except the last gets the even
        share rounded (half-up) to ``AMOUNT_QUANT``. The last assistant gets the
        remainder, so the allocated amounts sum exactly to ``order_amount``
        (e.g. 100 / 3 -> 33.33, 33.33, 33.34).
        """
        assistants: Sequence[Tuple[int, str]] = normalized["assistants"]
        if not assistants:
            return []

        count = len(assistants)
        order_amount: Decimal = normalized["order_amount"]
        share_ratio = Decimal("1") / Decimal(count)
        even_share = (order_amount / count).quantize(self.AMOUNT_QUANT, rounding=ROUND_HALF_UP)
        amounts = [even_share] * (count - 1)
        amounts.append(order_amount - even_share * (count - 1))

        rows: List[Dict[str, Any]] = []
        for (assistant_id, assistant_name), allocated_amount in zip(assistants, amounts):
            rows.append(
                {
                    "site_id": normalized["site_id"],
                    "biz_date": normalized["biz_date"],
                    "external_id": normalized["external_id"],
                    "member_id": normalized["member_id"],
                    "pay_time": normalized["pay_time"],
                    "order_amount": order_amount,
                    "assistant_id": assistant_id,
                    "assistant_name": assistant_name,
                    "share_ratio": share_ratio,
                    "allocated_amount": allocated_amount,
                    "currency": normalized["currency"],
                    "import_scope_key": scope.scope_key,
                    "import_batch_no": import_batch_no,
                    "import_file_name": import_file_name,
                    "import_time": import_time,
                    "import_user": import_user,
                }
            )
        return rows

    # ------------------------------------------------------------------ scopes

    @classmethod
    def resolve_scope(cls, site_id: int, biz_date: date, today: date) -> ImportScope:
        """Return the overwrite scope for a row.

        A ``biz_date`` at most ``HISTORICAL_BUCKET_DAYS`` days before ``today``
        (including future dates) gets a single-day ``DAY`` scope. Older dates get
        the ``P30`` bucket that contains them (see ``resolve_p30_bucket``).

        NOTE(review): a P30 bucket can end inside the last ``HISTORICAL_BUCKET_DAYS``
        days. Its delete then also clears recent days that may not be in the file.
        Confirm this is intended.
        """
        day_diff = (today - biz_date).days
        if day_diff <= cls.HISTORICAL_BUCKET_DAYS:
            return ImportScope(
                site_id=site_id,
                scope_type="DAY",
                start_date=biz_date,
                end_date=biz_date,
            )

        bucket_start, bucket_end = cls.resolve_p30_bucket(biz_date)
        return ImportScope(
            site_id=site_id,
            scope_type="P30",
            start_date=bucket_start,
            end_date=bucket_end,
        )

    @classmethod
    def resolve_p30_bucket(cls, biz_date: date) -> Tuple[date, date]:
        """Return the inclusive ``(start, end)`` of the fixed bucket containing ``biz_date``.

        Buckets are ``HISTORICAL_BUCKET_DAYS`` long and aligned to ``EPOCH_ANCHOR``.
        Floor division also places dates before the anchor in the right bucket.
        """
        delta_days = (biz_date - cls.EPOCH_ANCHOR).days
        bucket_index = delta_days // cls.HISTORICAL_BUCKET_DAYS
        bucket_start = cls.EPOCH_ANCHOR + timedelta(days=bucket_index * cls.HISTORICAL_BUCKET_DAYS)
        bucket_end = bucket_start + timedelta(days=cls.HISTORICAL_BUCKET_DAYS - 1)
        return bucket_start, bucket_end

    # ------------------------------------------------------------------ database writes

    def _delete_by_scopes(self, scopes: Iterable[ImportScope]) -> Tuple[int, int]:
        """Delete existing rows inside each scope from both tables (delete-then-write).

        A DAY scope has ``start_date == end_date``, so one inclusive range query
        covers both scope types. Does not commit; ``execute`` owns the transaction.

        Returns:
            ``(deleted_source_rows, deleted_alloc_rows)``.
        """
        deleted = [0] * len(self._SCOPED_TABLES)
        with self.db.conn.cursor() as cur:
            for scope in scopes:
                for idx, table in enumerate(self._SCOPED_TABLES):
                    # The table name comes from a class constant; only values are bound.
                    cur.execute(
                        f"""
                        DELETE FROM {table}
                        WHERE site_id = %s AND biz_date >= %s AND biz_date <= %s
                        """,
                        (scope.site_id, scope.start_date, scope.end_date),
                    )
                    deleted[idx] += max(cur.rowcount, 0)
        return deleted[0], deleted[1]

    def _insert_source_rows(self, rows: List[Dict[str, Any]]) -> int:
        """Insert wide-table rows; ``created_at``/``updated_at`` are set to ``import_time``."""
        if not rows:
            return 0
        columns = [
            "site_id",
            "biz_date",
            "external_id",
            "member_id",
            "pay_time",
            "order_amount",
            "currency",
        ]
        for slot in range(1, self.ASSISTANT_SLOT_COUNT + 1):
            columns.extend((f"assistant_id_{slot}", f"assistant_name_{slot}"))
        columns.extend(
            [
                "import_batch_no",
                "import_file_name",
                "import_scope_key",
                "import_time",
                "import_user",
                "row_no",
                "remark",
                *self._TIMESTAMP_COLUMNS,
            ]
        )
        sql = f"""
            INSERT INTO billiards_dws.dws_ml_manual_order_source ({", ".join(columns)})
            VALUES ({", ".join(["%s"] * len(columns))})
        """
        return self._execute_per_row(sql, columns, rows)

    def _upsert_alloc_rows(self, rows: List[Dict[str, Any]]) -> int:
        """Upsert allocation rows on ``(site_id, external_id, assistant_id)``."""
        if not rows:
            return 0
        columns = [
            "site_id",
            "biz_date",
            "external_id",
            "member_id",
            "pay_time",
            "order_amount",
            "assistant_id",
            "assistant_name",
            "share_ratio",
            "allocated_amount",
            "currency",
            "import_scope_key",
            "import_batch_no",
            "import_file_name",
            "import_time",
            "import_user",
            *self._TIMESTAMP_COLUMNS,
        ]
        sql = f"""
            INSERT INTO billiards_dws.dws_ml_manual_order_alloc ({", ".join(columns)})
            VALUES ({", ".join(["%s"] * len(columns))})
            ON CONFLICT (site_id, external_id, assistant_id)
            DO UPDATE SET
                biz_date = EXCLUDED.biz_date,
                member_id = EXCLUDED.member_id,
                pay_time = EXCLUDED.pay_time,
                order_amount = EXCLUDED.order_amount,
                assistant_name = EXCLUDED.assistant_name,
                share_ratio = EXCLUDED.share_ratio,
                allocated_amount = EXCLUDED.allocated_amount,
                currency = EXCLUDED.currency,
                import_scope_key = EXCLUDED.import_scope_key,
                import_batch_no = EXCLUDED.import_batch_no,
                import_file_name = EXCLUDED.import_file_name,
                import_time = EXCLUDED.import_time,
                import_user = EXCLUDED.import_user,
                updated_at = NOW()
        """
        return self._execute_per_row(sql, columns, rows)

    def _execute_per_row(
        self, sql: str, columns: Sequence[str], rows: Iterable[Dict[str, Any]]
    ) -> int:
        """Execute ``sql`` once per row, binding values in ``columns`` order.

        Returns the summed ``cursor.rowcount`` (negative counts are treated as 0).
        """
        affected = 0
        with self.db.conn.cursor() as cur:
            for row in rows:
                values = [
                    row.get("import_time") if col in self._TIMESTAMP_COLUMNS else row.get(col)
                    for col in columns
                ]
                cur.execute(sql, values)
                affected += max(cur.rowcount, 0)
        return affected

    # ------------------------------------------------------------------ value parsing

    @staticmethod
    def _to_int(value: Any, fallback: Optional[int] = None) -> Optional[int]:
        """Convert a cell value to ``int``; return ``fallback`` when that is not possible.

        Strings holding an integral number such as ``"1001.0"`` or ``"1e3"`` are
        accepted. Strings with a fractional part (``"12.5"``) and blank strings give
        ``fallback``. Non-string values go through ``int()``.
        """
        if value is None:
            return fallback
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return fallback
            try:
                return int(text)
            except ValueError:
                pass
            try:
                number = Decimal(text)
                if number.is_finite() and number == number.to_integral_value():
                    return int(number)
            except (InvalidOperation, ValueError, OverflowError):
                pass
            return fallback
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError, ArithmeticError):
            return fallback

    @staticmethod
    def _to_text_id(value: Any) -> str:
        """Convert an id cell to text.

        Falsy values give ``""``. An integral float, which Excel produces for
        numeric cells, is rendered without the ``.0`` suffix.
        """
        if not value:
            return ""
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return str(value).strip()

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Convert an amount cell to ``Decimal``; ``None`` and blank strings give ``0``.

        Raises:
            ValueError: the value is not a finite decimal number.
        """
        if value is None:
            return Decimal("0")
        text = str(value).strip()
        if not text:
            return Decimal("0")
        try:
            number = Decimal(text)
        except InvalidOperation as exc:
            raise ValueError(f"无法解析 order_amount: {value}") from exc
        if not number.is_finite():
            raise ValueError(f"无法解析 order_amount: {value}")
        return number

    @staticmethod
    def _to_date(value: Any) -> date:
        """Convert a ``biz_date`` cell to ``date``.

        Accepts ``date``/``datetime`` objects and strings such as ``2026-01-05``,
        ``2026/01/05``, ``2026-1-5`` or ``2026-01-05 10:00``; any time part is ignored.

        Raises:
            ValueError: the value is empty or cannot be parsed.
        """
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        if isinstance(value, str):
            text = value.strip()
            if not text:
                raise ValueError("biz_date 不能为空")
            normalized = text.replace("/", "-")
            head = normalized.split()[0].split("T")[0]
            for candidate in (head, normalized[:10]):
                try:
                    return date.fromisoformat(candidate)
                except ValueError:
                    pass
                try:
                    # strptime also accepts non zero-padded parts such as "2026-1-5".
                    return datetime.strptime(candidate, "%Y-%m-%d").date()
                except ValueError:
                    pass
        raise ValueError(f"无法解析 biz_date: {value}")

    @staticmethod
    def _to_datetime(value: Any, fallback_date: date) -> datetime:
        """Convert a ``pay_time`` cell to ``datetime``.

        Blank or non-string values that are not dates fall back to midnight of
        ``fallback_date``. Strings may use ``/`` or ``-`` date separators and may
        omit seconds or zero padding.

        Raises:
            ValueError: a non-blank string cannot be parsed.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, date):
            return datetime.combine(value, datetime.min.time())
        if isinstance(value, str):
            text = value.strip().replace("/", "-")
            if text:
                try:
                    return datetime.fromisoformat(text)
                except ValueError:
                    pass
                for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y-%m-%d"):
                    try:
                        return datetime.strptime(text, fmt)
                    except ValueError:
                        continue
                # Last resort: leading "YYYY-MM-DD HH:MM:SS" or "YYYY-MM-DD" prefix.
                if len(text) >= 19:
                    return datetime.strptime(text[:19], "%Y-%m-%d %H:%M:%S")
                return datetime.fromisoformat(text[:10])
        return datetime.combine(fallback_date, datetime.min.time())

    # ------------------------------------------------------------------ audit metadata

    @staticmethod
    def _build_import_batch_no(now: datetime) -> str:
        """Return ``MLM_<YYYYmmddHHMMSS>_<8 hex chars of a random uuid4>``."""
        return f"MLM_{now.strftime('%Y%m%d%H%M%S')}_{str(uuid.uuid4())[:8]}"

    @staticmethod
    def _resolve_import_user() -> str:
        """Return the operator name from ``ETL_OPERATOR``/``USERNAME``/``USER``, else ``"system"``."""
        return (
            os.getenv("ETL_OPERATOR")
            or os.getenv("USERNAME")
            or os.getenv("USER")
            or "system"
        )


# Public names exported by `from <module> import *`.
__all__ = [
    "MlManualImportTask",
    "ImportScope",
]