初始提交:飞球 ETL 系统全量代码

This commit is contained in:
Neo
2026-02-13 08:05:34 +08:00
commit 3c51f5485d
441 changed files with 117631 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Utility tasks (schema initialization, manual ingest, data integrity checks, etc.)."""

View File

@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
"""Task: report last successful cursor cutoff times from etl_admin."""
from __future__ import annotations
from typing import Any
from tasks.base_task import BaseTask
class CheckCutoffTask(BaseTask):
    """Report per-task cursor cutoff times (etl_admin.etl_cursor.last_end).

    Logs each enabled task's cursor window for the configured store, derives
    an overall cutoff (the minimum ``last_end`` across non-INIT_* tasks), then
    probes ODS ``fetched_at`` maxima and a fixed set of DWD/DWS time columns
    so operators can gauge pipeline freshness at a glance.
    """

    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "CHECK_CUTOFF"

    def execute(self, cursor_data: dict | None = None) -> dict:
        """Collect cutoff diagnostics, log them, and return a SUCCESS payload.

        Args:
            cursor_data: Accepted for interface compatibility; not used here.

        Returns:
            dict with status/counts plus a ``report`` section holding the raw
            cursor rows, the overall cutoff, and the ODS/DW probe results.
        """
        store_id = int(self.config.get("app.store_id"))
        # Optional filter: a comma-separated string is normalized to an
        # upper-cased list of task codes.
        filter_codes = self.config.get("run.cutoff_task_codes") or None
        if isinstance(filter_codes, str):
            filter_codes = [c.strip().upper() for c in filter_codes.split(",") if c.strip()]
        # LEFT JOIN so tasks that never ran (no cursor row) still appear.
        sql = """
        SELECT
            t.task_code,
            c.last_start,
            c.last_end,
            c.last_id,
            c.last_run_id,
            c.updated_at
        FROM etl_admin.etl_task t
        LEFT JOIN etl_admin.etl_cursor c
            ON c.task_id = t.task_id AND c.store_id = t.store_id
        WHERE t.store_id = %s
          AND t.enabled = TRUE
        ORDER BY t.task_code
        """
        rows = self.db.query(sql, (store_id,))
        if filter_codes:
            wanted = {str(c).upper() for c in filter_codes}
            rows = [r for r in rows if str(r.get("task_code", "")).upper() in wanted]

        def _ts(v: Any) -> str:
            # Render falsy values (None, empty) as "-" for log readability.
            return "-" if not v else str(v)

        self.logger.info("截止时间检查: 门店ID=%s 启用任务数=%s", store_id, len(rows))
        for r in rows:
            self.logger.info(
                "截止时间检查: %-24s 结束时间=%s 开始时间=%s 运行ID=%s",
                str(r.get("task_code") or ""),
                _ts(r.get("last_end")),
                _ts(r.get("last_start")),
                _ts(r.get("last_run_id")),
            )
        # INIT_* tasks are one-off bootstraps; excluding them keeps the
        # overall cutoff meaningful for incremental tasks.
        cutoff_candidates = [
            r.get("last_end")
            for r in rows
            if r.get("last_end") is not None and not str(r.get("task_code", "")).upper().startswith("INIT_")
        ]
        cutoff = min(cutoff_candidates) if cutoff_candidates else None
        self.logger.info("截止时间检查: 总体截止时间(最小结束时间,排除INIT_*)=%s", _ts(cutoff))
        ods_fetched = self._probe_ods_fetched_at(store_id)
        if ods_fetched:
            non_null = [v["max_fetched_at"] for v in ods_fetched.values() if v.get("max_fetched_at") is not None]
            ods_cutoff = min(non_null) if non_null else None
            self.logger.info("截止时间检查: ODS截止时间(最小抓取时间)=%s", _ts(ods_cutoff))
            # Show up to 8 tables, oldest max_fetched_at first; the key tuple
            # pushes None (unprobed/errored) entries to the end.
            worst = sorted(
                ((k, v.get("max_fetched_at")) for k, v in ods_fetched.items()),
                key=lambda kv: (kv[1] is None, kv[1]),
            )[:8]
            for table, mx in worst:
                self.logger.info("截止时间检查: ODS表=%s 最大抓取时间=%s", table, _ts(mx))
        dw_checks = self._probe_dw_time_columns()
        for name, value in dw_checks.items():
            self.logger.info("截止时间检查: %s=%s", name, _ts(value))
        return {
            "status": "SUCCESS",
            "counts": {"fetched": len(rows), "inserted": 0, "updated": 0, "skipped": 0, "errors": 0},
            "window": None,
            "request_params": {"store_id": store_id, "filter_task_codes": filter_codes or []},
            "report": {
                "rows": rows,
                "overall_cutoff": cutoff,
                "ods_fetched_at": ods_fetched,
                "dw_max_times": dw_checks,
            },
        }

    def _probe_ods_fetched_at(self, store_id: int) -> dict[str, dict[str, Any]]:
        """Return MAX(fetched_at)/COUNT(*) per ODS table referenced by DwdLoadTask.

        Per-table errors are captured in the result instead of raised, so one
        missing table cannot abort the whole report.
        """
        try:
            from tasks.dwd.dwd_load_task import DwdLoadTask  # local import to avoid circulars
        except Exception:
            # Without the table map there is nothing to probe.
            return {}
        ods_tables = sorted({str(t) for t in DwdLoadTask.TABLE_MAP.values() if str(t).startswith("billiards_ods.")})
        results: dict[str, dict[str, Any]] = {}
        for table in ods_tables:
            try:
                row = self.db.query(f"SELECT MAX(fetched_at) AS mx, COUNT(*) AS cnt FROM {table}")[0]
                results[table] = {"max_fetched_at": row.get("mx"), "count": row.get("cnt")}
            except Exception as exc:  # noqa: BLE001
                results[table] = {"max_fetched_at": None, "count": None, "error": str(exc)}
        return results

    def _probe_dw_time_columns(self) -> dict[str, Any]:
        """Probe a fixed set of DWD/DWS time columns; errors become strings."""
        checks: dict[str, Any] = {}
        probes = {
            "DWD.max_settlement_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_settlement_head",
            "DWD.max_payment_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_payment",
            "DWD.max_refund_pay_time": "SELECT MAX(pay_time) AS mx FROM billiards_dwd.dwd_refund",
            "DWS.max_order_date": "SELECT MAX(order_date) AS mx FROM billiards_dws.dws_order_summary",
            "DWS.max_updated_at": "SELECT MAX(updated_at) AS mx FROM billiards_dws.dws_order_summary",
        }
        for name, sql2 in probes.items():
            try:
                row = self.db.query(sql2)[0]
                checks[name] = row.get("mx")
            except Exception as exc:  # noqa: BLE001
                checks[name] = f"ERROR: {exc}"
        return checks

View File

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""Data integrity task that checks API -> ODS -> DWD completeness."""
from __future__ import annotations
from datetime import datetime
from zoneinfo import ZoneInfo
from dateutil import parser as dtparser
from utils.windowing import build_window_segments, calc_window_minutes
from tasks.base_task import BaseTask
from quality.integrity_service import run_history_flow, run_window_flow, write_report
class DataIntegrityTask(BaseTask):
    """Check data completeness across API -> ODS -> DWD."""

    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "DATA_INTEGRITY_CHECK"

    def execute(self, cursor_data: dict | None = None) -> dict:
        """Run the integrity check in either "window" or "history" mode.

        Mode is taken from ``integrity.mode`` but forced to "window" whenever
        a CLI window override is present. Both modes delegate the actual
        comparison/backfill work to ``quality.integrity_service`` and write a
        report file whose path is returned in the result.
        """
        tz = ZoneInfo(self.config.get("app.timezone", "Asia/Taipei"))
        mode = str(self.config.get("integrity.mode", "history") or "history").lower()
        include_dimensions = bool(self.config.get("integrity.include_dimensions", False))
        task_codes = str(self.config.get("integrity.ods_task_codes", "") or "").strip()
        auto_backfill = bool(self.config.get("integrity.auto_backfill", False))
        # These three options default to True when unset (None); a plain
        # bool(...) cast could not distinguish "unset" from "explicit False".
        compare_content = self.config.get("integrity.compare_content")
        if compare_content is None:
            compare_content = True
        content_sample_limit = self.config.get("integrity.content_sample_limit")
        backfill_mismatch = self.config.get("integrity.backfill_mismatch")
        if backfill_mismatch is None:
            backfill_mismatch = True
        recheck_after_backfill = self.config.get("integrity.recheck_after_backfill")
        if recheck_after_backfill is None:
            recheck_after_backfill = True
        # Switch to window mode when CLI override parameters are provided.
        window_override_start = self.config.get("run.window_override.start")
        window_override_end = self.config.get("run.window_override.end")
        if window_override_start or window_override_end:
            self.logger.info(
                "Detected CLI window override. Switching to window mode: %s ~ %s",
                window_override_start,
                window_override_end,
            )
            mode = "window"
        if mode == "window":
            base_start, base_end, _ = self._get_time_window(cursor_data)
            segments = build_window_segments(
                self.config,
                base_start,
                base_end,
                tz=tz,
                override_only=True,
            )
            if not segments:
                # No split configured: check the base window as one segment.
                segments = [(base_start, base_end)]
            total_segments = len(segments)
            if total_segments > 1:
                self.logger.info("Data integrity check split into %s segments.", total_segments)
            report, counts = run_window_flow(
                cfg=self.config,
                windows=segments,
                include_dimensions=include_dimensions,
                task_codes=task_codes,
                logger=self.logger,
                compare_content=bool(compare_content),
                content_sample_limit=content_sample_limit,
                do_backfill=bool(auto_backfill),
                include_mismatch=bool(backfill_mismatch),
                recheck_after_backfill=bool(recheck_after_backfill),
                page_size=int(self.config.get("api.page_size") or 200),
                chunk_size=500,
            )
            overall_start = segments[0][0]
            overall_end = segments[-1][1]
            report_path = write_report(report, prefix="data_integrity_window", tz=tz)
            report["report_path"] = report_path
            return {
                "status": "SUCCESS",
                "counts": counts,
                "window": {
                    "start": overall_start,
                    "end": overall_end,
                    "minutes": calc_window_minutes(overall_start, overall_end),
                },
                "report_path": report_path,
                "backfill_result": report.get("backfill_result"),
            }
        # History mode: normalize start/end into timezone-aware datetimes.
        history_start = str(self.config.get("integrity.history_start", "2025-07-01") or "2025-07-01")
        history_end = str(self.config.get("integrity.history_end", "") or "").strip()
        start_dt = dtparser.parse(history_start)
        if start_dt.tzinfo is None:
            start_dt = start_dt.replace(tzinfo=tz)
        else:
            start_dt = start_dt.astimezone(tz)
        end_dt = None
        if history_end:
            end_dt = dtparser.parse(history_end)
            if end_dt.tzinfo is None:
                end_dt = end_dt.replace(tzinfo=tz)
            else:
                end_dt = end_dt.astimezone(tz)
        report, counts = run_history_flow(
            cfg=self.config,
            start_dt=start_dt,
            end_dt=end_dt,
            include_dimensions=include_dimensions,
            task_codes=task_codes,
            logger=self.logger,
            compare_content=bool(compare_content),
            content_sample_limit=content_sample_limit,
            do_backfill=bool(auto_backfill),
            include_mismatch=bool(backfill_mismatch),
            recheck_after_backfill=bool(recheck_after_backfill),
            page_size=int(self.config.get("api.page_size") or 200),
            chunk_size=500,
        )
        report_path = write_report(report, prefix="data_integrity_history", tz=tz)
        report["report_path"] = report_path
        # When no explicit end was configured, fall back to the end recorded
        # in the report (if parseable), else to the start, so the returned
        # window is always well-formed.
        end_dt_used = end_dt
        if end_dt_used is None:
            end_str = report.get("end")
            if end_str:
                parsed = dtparser.parse(end_str)
                if parsed.tzinfo is None:
                    end_dt_used = parsed.replace(tzinfo=tz)
                else:
                    end_dt_used = parsed.astimezone(tz)
        if end_dt_used is None:
            end_dt_used = start_dt
        return {
            "status": "SUCCESS",
            "counts": counts,
            "window": {
                "start": start_dt,
                "end": end_dt_used,
                "minutes": int((end_dt_used - start_dt).total_seconds() // 60) if end_dt_used > start_dt else 0,
            },
            "report_path": report_path,
            "backfill_result": report.get("backfill_result"),
        }

View File

@@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
"""Build DWS order summary table from DWD fact tables."""
from __future__ import annotations
from datetime import date
from typing import Any
from tasks.base_task import BaseTask, TaskContext
from utils.windowing import build_window_segments, calc_window_minutes
# 原先从 scripts.rebuild.build_dws_order_summary 导入脚本已归档SQL 内联于此
SQL_BUILD_SUMMARY = r"""
WITH base AS (
SELECT
sh.site_id,
sh.order_settle_id,
sh.order_trade_no,
COALESCE(sh.pay_time, sh.create_time)::date AS order_date,
sh.tenant_id,
sh.member_id,
COALESCE(sh.is_bind_member, FALSE) AS member_flag,
(COALESCE(sh.consume_money, 0) = 0 AND COALESCE(sh.pay_amount, 0) > 0) AS recharge_order_flag,
COALESCE(sh.member_discount_amount, 0) AS member_discount_amount,
COALESCE(sh.adjust_amount, 0) AS manual_discount_amount,
COALESCE(sh.pay_amount, 0) AS total_paid_amount,
COALESCE(sh.balance_amount, 0) + COALESCE(sh.recharge_card_amount, 0) + COALESCE(sh.gift_card_amount, 0) AS stored_card_deduct,
COALESCE(sh.coupon_amount, 0) AS total_coupon_deduction,
COALESCE(sh.table_charge_money, 0) AS settle_table_fee_amount,
COALESCE(sh.assistant_pd_money, 0) + COALESCE(sh.assistant_cx_money, 0) AS settle_assistant_service_amount,
COALESCE(sh.real_goods_money, 0) AS settle_goods_amount
FROM billiards_dwd.dwd_settlement_head sh
WHERE (%(site_id)s IS NULL OR sh.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR COALESCE(sh.pay_time, sh.create_time)::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR COALESCE(sh.pay_time, sh.create_time)::date <= %(end_date)s)
),
table_fee AS (
SELECT
site_id,
order_settle_id,
SUM(COALESCE(real_table_charge_money, 0)) AS table_fee_amount
FROM billiards_dwd.dwd_table_fee_log
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_use_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_use_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
assistant_fee AS (
SELECT
site_id,
order_settle_id,
SUM(COALESCE(ledger_amount, 0)) AS assistant_service_amount
FROM billiards_dwd.dwd_assistant_service_log
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR start_use_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR start_use_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
goods_fee AS (
SELECT
site_id,
order_settle_id,
COUNT(*) AS item_count,
SUM(COALESCE(ledger_count, 0)) AS total_item_quantity,
SUM(COALESCE(real_goods_money, 0)) AS goods_amount
FROM billiards_dwd.dwd_store_goods_sale
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR create_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR create_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
group_fee AS (
SELECT
site_id,
order_settle_id,
SUM(COALESCE(ledger_amount, 0)) AS group_amount
FROM billiards_dwd.dwd_groupbuy_redemption
WHERE COALESCE(is_delete, 0) = 0
AND (%(site_id)s IS NULL OR site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR create_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR create_time::date <= %(end_date)s)
GROUP BY site_id, order_settle_id
),
refunds AS (
SELECT
r.site_id,
r.relate_id AS order_settle_id,
SUM(COALESCE(rx.refund_amount, 0)) AS refund_amount
FROM billiards_dwd.dwd_refund r
LEFT JOIN billiards_dwd.dwd_refund_ex rx ON r.refund_id = rx.refund_id
WHERE (%(site_id)s IS NULL OR r.site_id = %(site_id)s)
AND (%(start_date)s IS NULL OR r.pay_time::date >= %(start_date)s)
AND (%(end_date)s IS NULL OR r.pay_time::date <= %(end_date)s)
GROUP BY r.site_id, r.relate_id
)
INSERT INTO billiards_dws.dws_order_summary (
site_id, order_settle_id, order_trade_no, order_date, tenant_id,
member_id, member_flag, recharge_order_flag,
item_count, total_item_quantity,
table_fee_amount, assistant_service_amount, goods_amount, group_amount,
total_coupon_deduction, member_discount_amount, manual_discount_amount,
order_original_amount, order_final_amount,
stored_card_deduct, external_paid_amount, total_paid_amount,
book_table_flow, book_assistant_flow, book_goods_flow, book_group_flow, book_order_flow,
order_effective_consume_cash, order_effective_recharge_cash, order_effective_flow,
refund_amount, net_income, created_at, updated_at
)
SELECT
b.site_id, b.order_settle_id, b.order_trade_no::text, b.order_date, b.tenant_id,
b.member_id, b.member_flag, b.recharge_order_flag,
COALESCE(gf.item_count, 0),
COALESCE(gf.total_item_quantity, 0),
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount),
COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount),
COALESCE(gf.goods_amount, b.settle_goods_amount),
COALESCE(gr.group_amount, 0),
b.total_coupon_deduction, b.member_discount_amount, b.manual_discount_amount,
(b.total_paid_amount + b.total_coupon_deduction + b.member_discount_amount + b.manual_discount_amount),
b.total_paid_amount,
b.stored_card_deduct,
GREATEST(b.total_paid_amount - b.stored_card_deduct, 0),
b.total_paid_amount,
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount),
COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount),
COALESCE(gf.goods_amount, b.settle_goods_amount),
COALESCE(gr.group_amount, 0),
COALESCE(tf.table_fee_amount, b.settle_table_fee_amount)
+ COALESCE(af.assistant_service_amount, b.settle_assistant_service_amount)
+ COALESCE(gf.goods_amount, b.settle_goods_amount)
+ COALESCE(gr.group_amount, 0),
GREATEST(b.total_paid_amount - b.stored_card_deduct, 0),
0,
b.total_paid_amount,
COALESCE(rf.refund_amount, 0),
b.total_paid_amount - COALESCE(rf.refund_amount, 0),
now(), now()
FROM base b
LEFT JOIN table_fee tf ON b.site_id = tf.site_id AND b.order_settle_id = tf.order_settle_id
LEFT JOIN assistant_fee af ON b.site_id = af.site_id AND b.order_settle_id = af.order_settle_id
LEFT JOIN goods_fee gf ON b.site_id = gf.site_id AND b.order_settle_id = gf.order_settle_id
LEFT JOIN group_fee gr ON b.site_id = gr.site_id AND b.order_settle_id = gr.order_settle_id
LEFT JOIN refunds rf ON b.site_id = rf.site_id AND b.order_settle_id = rf.order_settle_id
ON CONFLICT (site_id, order_settle_id) DO UPDATE SET
order_trade_no = EXCLUDED.order_trade_no,
order_date = EXCLUDED.order_date,
tenant_id = EXCLUDED.tenant_id,
member_id = EXCLUDED.member_id,
member_flag = EXCLUDED.member_flag,
recharge_order_flag = EXCLUDED.recharge_order_flag,
item_count = EXCLUDED.item_count,
total_item_quantity = EXCLUDED.total_item_quantity,
table_fee_amount = EXCLUDED.table_fee_amount,
assistant_service_amount = EXCLUDED.assistant_service_amount,
goods_amount = EXCLUDED.goods_amount,
group_amount = EXCLUDED.group_amount,
total_coupon_deduction = EXCLUDED.total_coupon_deduction,
member_discount_amount = EXCLUDED.member_discount_amount,
manual_discount_amount = EXCLUDED.manual_discount_amount,
order_original_amount = EXCLUDED.order_original_amount,
order_final_amount = EXCLUDED.order_final_amount,
stored_card_deduct = EXCLUDED.stored_card_deduct,
external_paid_amount = EXCLUDED.external_paid_amount,
total_paid_amount = EXCLUDED.total_paid_amount,
book_table_flow = EXCLUDED.book_table_flow,
book_assistant_flow = EXCLUDED.book_assistant_flow,
book_goods_flow = EXCLUDED.book_goods_flow,
book_group_flow = EXCLUDED.book_group_flow,
book_order_flow = EXCLUDED.book_order_flow,
order_effective_consume_cash = EXCLUDED.order_effective_consume_cash,
order_effective_recharge_cash = EXCLUDED.order_effective_recharge_cash,
order_effective_flow = EXCLUDED.order_effective_flow,
refund_amount = EXCLUDED.refund_amount,
net_income = EXCLUDED.net_income,
updated_at = now();
"""
class DwsBuildOrderSummaryTask(BaseTask):
    """Recompute/refresh `billiards_dws.dws_order_summary` for a date window."""

    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "DWS_BUILD_ORDER_SUMMARY"

    def execute(self, cursor_data: dict | None = None) -> dict:
        """Run the rebuild over one or more window segments.

        The base window comes from the cursor context; build_window_segments
        may split it (CLI override only). Each segment runs the standard
        extract/transform/load cycle in its own transaction; counts, deleted
        totals, and request params are accumulated across segments.
        """
        base_context = self._build_context(cursor_data)
        task_code = self.get_task_code()
        segments = build_window_segments(
            self.config,
            base_context.window_start,
            base_context.window_end,
            tz=self.tz,
            override_only=True,
        )
        if not segments:
            # No split configured: run the base window as one segment.
            segments = [(base_context.window_start, base_context.window_end)]
        total_segments = len(segments)
        if total_segments > 1:
            self.logger.info("%s: 分段执行 共%s", task_code, total_segments)
        total_counts: dict = {}
        segment_results: list[dict] = []
        request_params_list: list[dict] = []
        total_deleted = 0
        for idx, (window_start, window_end) in enumerate(segments, start=1):
            context = self._build_context_for_window(window_start, window_end, cursor_data)
            self.logger.info(
                "%s: 开始执行(%s/%s), 窗口[%s ~ %s]",
                task_code,
                idx,
                total_segments,
                context.window_start,
                context.window_end,
            )
            try:
                extracted = self.extract(context)
                transformed = self.transform(extracted, context)
                load_result = self.load(transformed, context) or {}
                # Commit per segment so a later failure does not undo
                # already-completed segments.
                self.db.commit()
            except Exception:
                self.db.rollback()
                self.logger.error("%s: 执行失败", task_code, exc_info=True)
                raise
            counts = load_result.get("counts") or {}
            # _accumulate_counts is provided by BaseTask (not shown here);
            # presumably it sums count dicts key-wise — verify against base.
            self._accumulate_counts(total_counts, counts)
            extra = load_result.get("extra") or {}
            deleted = int(extra.get("deleted") or 0)
            total_deleted += deleted
            request_params = load_result.get("request_params")
            if request_params:
                request_params_list.append(request_params)
            if total_segments > 1:
                # Per-segment breakdown is only reported for split runs.
                segment_results.append(
                    {
                        "window": {
                            "start": context.window_start,
                            "end": context.window_end,
                            "minutes": context.window_minutes,
                        },
                        "counts": counts,
                        "extra": extra,
                    }
                )
        overall_start = segments[0][0]
        overall_end = segments[-1][1]
        result = {"status": "SUCCESS", "counts": total_counts}
        result["window"] = {
            "start": overall_start,
            "end": overall_end,
            "minutes": calc_window_minutes(overall_start, overall_end),
        }
        if segment_results:
            result["segments"] = segment_results
        if request_params_list:
            # Single segment: report a plain dict; multiple: the full list.
            result["request_params"] = request_params_list[0] if len(request_params_list) == 1 else request_params_list
        if total_deleted:
            result["extra"] = {"deleted": total_deleted}
        self.logger.info("%s: 完成, 统计=%s", task_code, total_counts)
        return result

    def extract(self, context: TaskContext) -> dict[str, Any]:
        """Resolve rebuild parameters (site, date range, refresh flags).

        full_refresh clears the date filters entirely; otherwise missing
        start/end dates fall back to the context window's dates.
        """
        store_id = int(self.config.get("app.store_id"))
        full_refresh = bool(self.config.get("dws.order_summary.full_refresh", False))
        site_id = self.config.get("dws.order_summary.site_id", store_id)
        # Explicit empty/null config means "all sites" (no site filter).
        if site_id in ("", None, "null", "NULL"):
            site_id = None
        start_date = self.config.get("dws.order_summary.start_date")
        end_date = self.config.get("dws.order_summary.end_date")
        if not full_refresh:
            if not start_date:
                start_date = context.window_start.date()
            if not end_date:
                end_date = context.window_end.date()
        else:
            start_date = None
            end_date = None
        delete_before_insert = bool(self.config.get("dws.order_summary.delete_before_insert", True))
        return {
            "site_id": site_id,
            "start_date": start_date,
            "end_date": end_date,
            "full_refresh": full_refresh,
            "delete_before_insert": delete_before_insert,
        }

    def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
        """Optionally delete affected rows, then upsert the summary rows.

        Raises:
            RuntimeError: if the target DWS table does not exist yet.
        """
        sql_params = {
            "site_id": extracted["site_id"],
            "start_date": extracted["start_date"],
            "end_date": extracted["end_date"],
        }
        request_params = {
            "site_id": extracted["site_id"],
            "start_date": _jsonable_date(extracted["start_date"]),
            "end_date": _jsonable_date(extracted["end_date"]),
        }
        with self.db.conn.cursor() as cur:
            # Guard: the DWS schema must have been initialized first.
            cur.execute("SELECT to_regclass('billiards_dws.dws_order_summary') AS reg;")
            row = cur.fetchone()
            reg = row[0] if row else None
            if not reg:
                raise RuntimeError("DWS 表不存在:请先运行任务 INIT_DWS_SCHEMA")
            deleted = 0
            if extracted["delete_before_insert"]:
                if extracted["full_refresh"] and extracted["site_id"] is None:
                    # Unscoped full refresh: TRUNCATE is cheaper than DELETE.
                    cur.execute("TRUNCATE TABLE billiards_dws.dws_order_summary;")
                    self.logger.info("DWS订单汇总: 已清空 billiards_dws.dws_order_summary")
                else:
                    # Build the DELETE filter incrementally from the present
                    # parameters; WHERE 1=1 keeps the AND-appends uniform.
                    delete_sql = "DELETE FROM billiards_dws.dws_order_summary WHERE 1=1"
                    delete_args: list[Any] = []
                    if extracted["site_id"] is not None:
                        delete_sql += " AND site_id = %s"
                        delete_args.append(extracted["site_id"])
                    if extracted["start_date"] is not None:
                        delete_sql += " AND order_date >= %s"
                        delete_args.append(_as_date(extracted["start_date"]))
                    if extracted["end_date"] is not None:
                        delete_sql += " AND order_date <= %s"
                        delete_args.append(_as_date(extracted["end_date"]))
                    cur.execute(delete_sql, delete_args)
                    deleted = cur.rowcount
                    self.logger.info("DWS订单汇总: 删除=%s 语句=%s", deleted, delete_sql)
            cur.execute(SQL_BUILD_SUMMARY, sql_params)
            affected = cur.rowcount
        return {
            "counts": {"fetched": 0, "inserted": affected, "updated": 0, "skipped": 0, "errors": 0},
            "request_params": request_params,
            "extra": {"deleted": deleted},
        }
def _as_date(v: Any) -> date:
if isinstance(v, date):
return v
return date.fromisoformat(str(v))
def _jsonable_date(v: Any):
if v is None:
return None
if isinstance(v, date):
return v.isoformat()
return str(v)

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
"""初始化 DWD Schema执行 schema_dwd_doc.sql可选先 DROP SCHEMA。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class InitDwdSchemaTask(BaseTask):
    """Initialize the DWD schema by executing schema_dwd_doc.sql (optionally dropping the schema first)."""

    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "INIT_DWD_SCHEMA"

    def extract(self, context: TaskContext) -> dict[str, Any]:
        """Read the DWD DDL file and the drop-first flag.

        Raises:
            FileNotFoundError: if the configured (or default) DDL file is absent.
        """
        base_dir = Path(__file__).resolve().parents[1] / "database"
        dwd_path = Path(self.config.get("schema.dwd_file", base_dir / "schema_dwd_doc.sql"))
        if not dwd_path.exists():
            raise FileNotFoundError(f"未找到 DWD schema 文件: {dwd_path}")
        # Normalize to bool for consistency with InitDwsSchemaTask, which
        # wraps the same style of flag in bool().
        drop_first = bool(self.config.get("dwd.drop_schema_first", False))
        return {"dwd_sql": dwd_path.read_text(encoding="utf-8"), "dwd_file": str(dwd_path), "drop_first": drop_first}

    def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
        """Optionally DROP the schema, then execute the DWD DDL."""
        with self.db.conn.cursor() as cur:
            if extracted["drop_first"]:
                cur.execute("DROP SCHEMA IF EXISTS billiards_dwd CASCADE;")
                self.logger.info("已执行 DROP SCHEMA billiards_dwd CASCADE")
            self.logger.info("执行 DWD schema 文件: %s", extracted["dwd_file"])
            cur.execute(extracted["dwd_sql"])
        return {"executed": 1, "files": [extracted["dwd_file"]]}

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""Initialize DWS schema (billiards_dws)."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class InitDwsSchemaTask(BaseTask):
    """Apply the DWS schema DDL (billiards_dws)."""

    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "INIT_DWS_SCHEMA"

    def extract(self, context: TaskContext) -> dict[str, Any]:
        """Locate and read the DWS schema file; raise if it is missing."""
        schema_dir = Path(__file__).resolve().parents[1] / "database"
        schema_path = Path(self.config.get("schema.dws_file", schema_dir / "schema_dws.sql"))
        if not schema_path.exists():
            raise FileNotFoundError(f"未找到 DWS schema 文件: {schema_path}")
        return {
            "dws_sql": schema_path.read_text(encoding="utf-8"),
            "dws_file": str(schema_path),
            "drop_first": bool(self.config.get("dws.drop_schema_first", False)),
        }

    def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
        """Optionally drop billiards_dws, then execute the DDL."""
        with self.db.conn.cursor() as cur:
            if extracted["drop_first"]:
                cur.execute("DROP SCHEMA IF EXISTS billiards_dws CASCADE;")
                self.logger.info("已执行 DROP SCHEMA billiards_dws CASCADE")
            self.logger.info("执行 DWS schema 文件: %s", extracted["dws_file"])
            cur.execute(extracted["dws_sql"])
        return {"executed": 1, "files": [extracted["dws_file"]]}

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""任务:初始化运行环境,执行 ODS 与 etl_admin 的 DDL并准备日志/导出目录。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class InitOdsSchemaTask(BaseTask):
    """Scheduled initialization: create required directories and run the ODS and etl_admin DDL."""

    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "INIT_ODS_SCHEMA"

    def extract(self, context: TaskContext) -> dict[str, Any]:
        """Resolve SQL file paths and collect the directories to prepare.

        Raises:
            FileNotFoundError: if either DDL file cannot be found.
        """
        database_dir = Path(__file__).resolve().parents[1] / "database"
        ods_path = Path(self.config.get("schema.ods_file", database_dir / "schema_ODS_doc.sql"))
        admin_path = Path(self.config.get("schema.etl_admin_file", database_dir / "schema_etl_admin.sql"))
        if not ods_path.exists():
            raise FileNotFoundError(f"找不到 ODS schema 文件: {ods_path}")
        if not admin_path.exists():
            raise FileNotFoundError(f"找不到 etl_admin schema 文件: {admin_path}")
        log_root = Path(self.config.get("io.log_root") or self.config["io"]["log_root"])
        export_root = Path(self.config.get("io.export_root") or self.config["io"]["export_root"])
        fetch_root = Path(self.config.get("pipeline.fetch_root") or self.config["pipeline"]["fetch_root"])
        ingest_dir = Path(self.config.get("pipeline.ingest_source_dir") or fetch_root)
        return {
            "ods_sql": ods_path.read_text(encoding="utf-8"),
            "admin_sql": admin_path.read_text(encoding="utf-8"),
            "ods_file": str(ods_path),
            "admin_file": str(admin_path),
            "dirs": [log_root, export_root, fetch_root, ingest_dir],
        }

    def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
        """Create directories, lightly sanitize the ODS DDL, then run both DDL files.

        The ODS DDL file may carry a prose header or fragile COMMENT ON lines
        (e.g. unquoted CamelCase identifiers); both are stripped here so that
        non-SQL text cannot break execution.
        """
        for directory in extracted["dirs"]:
            Path(directory).mkdir(parents=True, exist_ok=True)
            self.logger.info("已确保目录存在: %s", directory)
        # Trim everything before the first DROP SCHEMA, then drop COMMENT ON lines.
        raw_sql: str = extracted["ods_sql"]
        marker = raw_sql.find("DROP SCHEMA")
        if marker > 0:
            raw_sql = raw_sql[marker:]
        ods_sql = "\n".join(
            line for line in raw_sql.splitlines()
            if not line.strip().upper().startswith("COMMENT ON ")
        )
        with self.db.conn.cursor() as cur:
            self.logger.info("执行 etl_admin schema 文件: %s", extracted["admin_file"])
            cur.execute(extracted["admin_sql"])
            self.logger.info("执行 ODS schema 文件: %s", extracted["ods_file"])
            cur.execute(ods_sql)
        return {
            "executed": 2,
            "files": [extracted["admin_file"], extracted["ods_file"]],
            "dirs_prepared": [str(p) for p in extracted["dirs"]],
        }

View File

@@ -0,0 +1,463 @@
# -*- coding: utf-8 -*-
"""手工示例数据灌入:按 schema_ODS_doc.sql 的表结构写入 ODS。"""
from __future__ import annotations
import hashlib
import json
import os
from datetime import datetime
from typing import Any, Iterable
from psycopg2.extras import Json, execute_values
from tasks.base_task import BaseTask
class ManualIngestTask(BaseTask):
"""本地示例 JSON 灌入 ODS确保表名/主键/插入列与 schema_ODS_doc.sql 对齐。"""
FILE_MAPPING: list[tuple[tuple[str, ...], str]] = [
(("member_profiles",), "billiards_ods.member_profiles"),
(("member_balance_changes",), "billiards_ods.member_balance_changes"),
(("member_stored_value_cards",), "billiards_ods.member_stored_value_cards"),
(("recharge_settlements",), "billiards_ods.recharge_settlements"),
(("settlement_records",), "billiards_ods.settlement_records"),
(("assistant_cancellation_records",), "billiards_ods.assistant_cancellation_records"),
(("assistant_accounts_master",), "billiards_ods.assistant_accounts_master"),
(("assistant_service_records",), "billiards_ods.assistant_service_records"),
(("site_tables_master",), "billiards_ods.site_tables_master"),
(("table_fee_discount_records",), "billiards_ods.table_fee_discount_records"),
(("table_fee_transactions",), "billiards_ods.table_fee_transactions"),
(("goods_stock_movements",), "billiards_ods.goods_stock_movements"),
(("stock_goods_category_tree",), "billiards_ods.stock_goods_category_tree"),
(("goods_stock_summary",), "billiards_ods.goods_stock_summary"),
(("payment_transactions",), "billiards_ods.payment_transactions"),
(("refund_transactions",), "billiards_ods.refund_transactions"),
(("platform_coupon_redemption_records",), "billiards_ods.platform_coupon_redemption_records"),
(("group_buy_redemption_records",), "billiards_ods.group_buy_redemption_records"),
(("group_buy_packages",), "billiards_ods.group_buy_packages"),
(("settlement_ticket_details",), "billiards_ods.settlement_ticket_details"),
(("store_goods_master",), "billiards_ods.store_goods_master"),
(("tenant_goods_master",), "billiards_ods.tenant_goods_master"),
(("store_goods_sales_records",), "billiards_ods.store_goods_sales_records"),
]
TABLE_SPECS: dict[str, dict[str, Any]] = {
"billiards_ods.member_profiles": {"pk": "id"},
"billiards_ods.member_balance_changes": {"pk": "id"},
"billiards_ods.member_stored_value_cards": {"pk": "id"},
"billiards_ods.recharge_settlements": {"pk": "id"},
"billiards_ods.settlement_records": {"pk": "id"},
"billiards_ods.assistant_cancellation_records": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.assistant_accounts_master": {"pk": "id"},
"billiards_ods.assistant_service_records": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.site_tables_master": {"pk": "id"},
"billiards_ods.table_fee_discount_records": {"pk": "id", "json_cols": ["siteProfile", "tableProfile"]},
"billiards_ods.table_fee_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.goods_stock_movements": {"pk": "siteGoodsStockId"},
"billiards_ods.stock_goods_category_tree": {"pk": "id", "json_cols": ["categoryBoxes"]},
"billiards_ods.goods_stock_summary": {"pk": "siteGoodsId"},
"billiards_ods.payment_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.refund_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
"billiards_ods.platform_coupon_redemption_records": {"pk": "id"},
"billiards_ods.tenant_goods_master": {"pk": "id"},
"billiards_ods.group_buy_packages": {"pk": "id"},
"billiards_ods.group_buy_redemption_records": {"pk": "id"},
"billiards_ods.settlement_ticket_details": {
"pk": "orderSettleId",
"json_cols": ["memberProfile", "orderItem", "tenantMemberCardLogs"],
},
"billiards_ods.store_goods_master": {"pk": "id"},
"billiards_ods.store_goods_sales_records": {"pk": "id"},
}
    def get_task_code(self) -> str:
        """Return the scheduler task code."""
        return "MANUAL_INGEST"
    def execute(self, cursor_data: dict | None = None) -> dict:
        """Read JSON files from a directory and bulk-ingest them per table spec.

        Commits once per file so a single bad file (or a dropped connection)
        cannot roll back — or hold a long transaction over — the whole batch.
        """
        data_dir = (
            self.config.get("manual.data_dir")
            or self.config.get("pipeline.ingest_source_dir")
            or os.path.join("tests", "testdata_json")
        )
        if not os.path.exists(data_dir):
            self.logger.error("Data directory not found: %s", data_dir)
            return {"status": "error", "message": "Directory not found"}
        counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}
        # Optional case-insensitive whitelist of file stems to ingest.
        include_files_cfg = self.config.get("manual.include_files") or []
        include_files = {str(x).strip().lower() for x in include_files_cfg if str(x).strip()} if include_files_cfg else set()
        for filename in sorted(os.listdir(data_dir)):
            if not filename.endswith(".json"):
                continue
            stem = os.path.splitext(filename)[0].lower()
            if include_files and stem not in include_files:
                continue
            filepath = os.path.join(data_dir, filename)
            try:
                with open(filepath, "r", encoding="utf-8") as fh:
                    raw_entries = json.load(fh)
            except Exception:
                # Unreadable/invalid JSON: count the error and move on.
                counts["errors"] += 1
                self.logger.exception("Failed to read %s", filename)
                continue
            entries = raw_entries if isinstance(raw_entries, list) else [raw_entries]
            records = self._extract_records(entries)
            if not records:
                counts["skipped"] += 1
                continue
            target_table = self._match_by_filename(filename)
            if not target_table:
                self.logger.warning("No mapping found for file: %s", filename)
                counts["skipped"] += 1
                continue
            self.logger.info("Ingesting %s into %s", filename, target_table)
            try:
                inserted, updated, row_errors = self._ingest_table(target_table, records, filename)
                counts["inserted"] += inserted
                counts["updated"] += updated
                counts["fetched"] += len(records)
                counts["errors"] += row_errors
                # Commit per file: keeps each transaction small so one failure
                # or an unstable connection cannot force a full-batch rollback.
                self.db.commit()
            except Exception:
                counts["errors"] += 1
                self.logger.exception("Error processing %s", filename)
                try:
                    self.db.rollback()
                except Exception:
                    pass
                # If the connection is gone, later files cannot proceed either;
                # re-raise so the caller can reconnect and retry the run.
                if getattr(self.db.conn, "closed", 0):
                    raise
                continue
        return {"status": "SUCCESS", "counts": counts}
def _match_by_filename(self, filename: str) -> str | None:
"""根据文件名关键字匹配目标表。"""
for keywords, table in self.FILE_MAPPING:
if any(keyword and keyword in filename for keyword in keywords):
return table
return None
def _extract_records(self, raw_entries: Iterable[Any]) -> list[dict]:
"""兼容多层 data/list 包装,抽取记录列表。"""
records: list[dict] = []
for entry in raw_entries:
if isinstance(entry, dict):
preferred = entry
if "data" in entry and not any(k not in {"data", "code"} for k in entry.keys()):
preferred = entry["data"]
data = preferred
if isinstance(data, dict):
# 特殊处理 settleList充值、结算记录展开 data.settleList 下的 settleList抛弃上层 siteProfile
if "settleList" in data:
settle_list_val = data.get("settleList")
if isinstance(settle_list_val, dict):
settle_list_iter = [settle_list_val]
elif isinstance(settle_list_val, list):
settle_list_iter = settle_list_val
else:
settle_list_iter = []
handled = False
for item in settle_list_iter or []:
if not isinstance(item, dict):
continue
inner = item.get("settleList")
merged = dict(inner) if isinstance(inner, dict) else dict(item)
# 保留 siteProfile 供后续字段补充,但不落库
site_profile = data.get("siteProfile")
if isinstance(site_profile, dict):
merged.setdefault("siteProfile", site_profile)
records.append(merged)
handled = True
if handled:
continue
list_used = False
for v in data.values():
if isinstance(v, list) and v and isinstance(v[0], dict):
records.extend(v)
list_used = True
break
if list_used:
continue
if isinstance(data, list) and data and isinstance(data[0], dict):
records.extend(data)
elif isinstance(data, dict):
records.append(data)
elif isinstance(entry, list):
records.extend([item for item in entry if isinstance(item, dict)])
return records
def _get_table_columns(self, table: str) -> list[tuple[str, str, str]]:
    """Return (column_name, data_type, udt_name) for *table*, cached per instance.

    The table name may be schema-qualified ("schema.table"); an unqualified
    name defaults to the ``public`` schema. Type names are lowercased.
    """
    cache = getattr(self, "_table_columns_cache", {})
    if table in cache:
        return cache[table]
    schema, sep, name = table.partition(".")
    if not sep:
        schema, name = "public", table
    sql = """
        SELECT column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with self.db.conn.cursor() as cur:
        cur.execute(sql, (schema, name))
        rows = cur.fetchall()
    cols = [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in rows]
    cache[table] = cols
    self._table_columns_cache = cache
    return cols
def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int, int]:
    """
    Build an INSERT ... ON CONFLICT statement and execute it in batches
    (vectorised execute_values first, in small chunks).

    Design goals:
    - Bound the size of a single SQL statement (one huge VALUES list can
      OOM/terminate the server backend);
    - On failure, degrade to row-by-row execution with SAVEPOINTs so bad
      rows are skipped instead of aborting the file;
    - Counting favours "best effort": inserted/updated numbers are
      approximate (RETURNING is deliberately not used).

    Returns:
        (inserted, updated, errors) — approximate row counts.
    """
    spec = self.TABLE_SPECS.get(table)
    if not spec:
        raise ValueError(f"No table spec for {table}")
    pk_col = spec.get("pk")
    json_cols = set(spec.get("json_cols", []))
    json_cols_lower = {c.lower() for c in json_cols}
    columns_info = self._get_table_columns(table)
    columns = [c[0] for c in columns_info]
    # Columns declared as json/jsonb on the database side.
    db_json_cols_lower = {
        c[0].lower() for c in columns_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
    }
    # Resolve the PK's exact DB-side column name (case-insensitive match).
    pk_col_db = None
    if pk_col:
        pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)
    pk_index = None
    if pk_col_db:
        try:
            pk_index = next(i for i, c in enumerate(columns_info) if c[0] == pk_col_db)
        except Exception:
            pk_index = None
    has_content_hash = any(c[0].lower() == "content_hash" for c in columns_info)
    col_list = ", ".join(f'"{c}"' for c in columns)
    sql_prefix = f"INSERT INTO {table} ({col_list}) VALUES %s"
    if pk_col_db:
        if has_content_hash:
            # Versioned tables: identical (pk, content_hash) rows are no-ops.
            sql_prefix += f' ON CONFLICT ("{pk_col_db}", "content_hash") DO NOTHING'
        else:
            # Plain upsert: overwrite all non-PK columns on conflict.
            update_cols = [c for c in columns if c != pk_col_db]
            set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
            sql_prefix += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
    params = []
    now = datetime.now()
    json_dump = lambda v: json.dumps(v, ensure_ascii=False)  # noqa: E731
    for rec in records:
        merged_rec = rec if isinstance(rec, dict) else {}
        # Flatten nested "data" envelopes; outer keys win on conflict.
        data_part = merged_rec.get("data")
        while isinstance(data_part, dict):
            merged_rec = {**data_part, **merged_rec}
            data_part = data_part.get("data")
        # For recharge/settlement tables, backfill shop info from siteProfile.
        if table in {
            "billiards_ods.recharge_settlements",
            "billiards_ods.settlement_records",
        }:
            site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile")
            if isinstance(site_profile, dict):
                merged_rec.setdefault("tenantid", site_profile.get("tenant_id") or site_profile.get("tenantId"))
                merged_rec.setdefault("siteid", site_profile.get("id") or site_profile.get("siteId"))
                merged_rec.setdefault("sitename", site_profile.get("shop_name") or site_profile.get("siteName"))
        pk_val = self._get_value_case_insensitive(merged_rec, pk_col) if pk_col else None
        if pk_col and (pk_val is None or pk_val == ""):
            # Rows without a usable PK cannot be upserted — drop them.
            continue
        content_hash = None
        if has_content_hash:
            # Keep hash semantics aligned with ODS task ingestion:
            # fetched_at is ETL metadata and should not create a new content version.
            content_hash = self._compute_content_hash(merged_rec, include_fetched_at=False)
        row_vals = []
        for col_name, data_type, udt in columns_info:
            col_lower = col_name.lower()
            if col_lower == "payload":
                # payload stores the ORIGINAL record, not the flattened copy.
                row_vals.append(Json(rec, dumps=json_dump))
                continue
            if col_lower == "source_file":
                row_vals.append(source_file)
                continue
            if col_lower == "fetched_at":
                row_vals.append(merged_rec.get(col_name, now))
                continue
            if col_lower == "content_hash":
                row_vals.append(content_hash)
                continue
            value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))
            if col_lower in json_cols_lower or col_lower in db_json_cols_lower:
                row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                continue
            casted = self._cast_value(value, data_type)
            row_vals.append(casted)
        params.append(tuple(row_vals))
    if not params:
        return 0, 0, 0
    # Try vectorised execution first (fast); on failure, degrade to
    # row-by-row execution with SAVEPOINTs to skip the offending rows.
    try:
        with self.db.conn.cursor() as cur:
            # Chunked execution: keeps each statement small so a server-side
            # failure cannot abort the whole connection.
            affected = 0
            chunk_size = int(self.config.get("manual.execute_values_page_size", 50) or 50)
            chunk_size = max(1, min(chunk_size, 500))
            for i in range(0, len(params), chunk_size):
                chunk = params[i : i + chunk_size]
                execute_values(cur, sql_prefix, chunk, page_size=len(chunk))
                if cur.rowcount is not None and cur.rowcount > 0:
                    affected += int(cur.rowcount)
            # Without RETURNING, inserted vs updated cannot be split exactly;
            # report affected rows ≈ inserted.
            return int(affected), 0, 0
    except Exception as exc:
        self.logger.warning("批量入库失败,准备降级逐行处理:table=%s, err=%s", table, exc)
        try:
            self.db.rollback()
        except Exception:
            pass
    inserted = 0
    updated = 0
    errors = 0
    with self.db.conn.cursor() as cur:
        for row in params:
            cur.execute("SAVEPOINT sp_manual_ingest_row")
            try:
                cur.execute(sql_prefix.replace(" VALUES %s", f" VALUES ({', '.join(['%s'] * len(row))})"), row)
                inserted += 1
                cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
            except Exception as exc:  # noqa: BLE001
                errors += 1
                try:
                    cur.execute("ROLLBACK TO SAVEPOINT sp_manual_ingest_row")
                    cur.execute("RELEASE SAVEPOINT sp_manual_ingest_row")
                except Exception:
                    pass
                pk_val = None
                if pk_index is not None:
                    try:
                        pk_val = row[pk_index]
                    except Exception:
                        pk_val = None
                self.logger.warning("跳过异常行:table=%s pk=%s err=%s", table, pk_val, exc)
    return inserted, updated, errors
@staticmethod
def _get_value_case_insensitive(record: dict, col: str | None):
"""忽略大小写获取值,兼容 information_schema 与 JSON 原始字段。"""
if record is None or col is None:
return None
if col in record:
return record.get(col)
col_lower = col.lower()
for k, v in record.items():
if isinstance(k, str) and k.lower() == col_lower:
return v
return None
@staticmethod
def _normalize_scalar(value):
"""将空字符串/空 JSON 规范为 None避免类型转换错误。"""
if value == "" or value == "{}" or value == "[]":
return None
return value
@staticmethod
def _cast_value(value, data_type: str):
"""根据列类型做简单转换,保证批量插入兼容。"""
if value is None:
return None
dt = (data_type or "").lower()
if dt in ("integer", "bigint", "smallint"):
if isinstance(value, bool):
return int(value)
try:
return int(value)
except Exception:
return None
if dt in ("numeric", "double precision", "real", "decimal"):
if isinstance(value, bool):
return int(value)
try:
return float(value)
except Exception:
return None
if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
return value if isinstance(value, str) else None
return value
@staticmethod
def _hash_default(value):
if isinstance(value, datetime):
return value.isoformat()
return str(value)
@classmethod
def _sanitize_record_for_hash(cls, record: dict, *, include_fetched_at: bool) -> dict:
exclude = {
"data",
"payload",
"source_file",
"source_endpoint",
"content_hash",
"record_index",
}
if not include_fetched_at:
exclude.add("fetched_at")
def _strip(value):
if isinstance(value, dict):
cleaned = {}
for k, v in value.items():
if isinstance(k, str) and k.lower() in exclude:
continue
cleaned[k] = _strip(v)
return cleaned
if isinstance(value, list):
return [_strip(v) for v in value]
return value
return _strip(record or {})
@classmethod
def _compute_content_hash(cls, record: dict, *, include_fetched_at: bool) -> str:
    """SHA-256 hex digest over the canonical JSON form of the sanitised record.

    Canonicalisation (sorted keys, compact separators, non-ASCII preserved)
    makes the hash independent of key order; values JSON cannot encode fall
    back to ``_hash_default``.
    """
    canonical = json.dumps(
        cls._sanitize_record_for_hash(record, include_fetched_at=include_fetched_at),
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
        default=cls._hash_default,
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""
DWS配置数据初始化任务
功能说明:
执行 seed_dws_config.sql向配置表插入初始数据
执行前提:
- billiards_dws schema 已创建INIT_DWS_SCHEMA
- 配置表已存在
作者ETL团队
创建日期2026-02-01
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from tasks.base_task import BaseTask, TaskContext
class SeedDwsConfigTask(BaseTask):
    """Seed task for DWS configuration data.

    Runs ``seed_dws_config.sql`` to populate the configuration tables:

    - cfg_performance_tier: performance tier settings
    - cfg_assistant_level_price: assistant level pricing
    - cfg_bonus_rules: bonus rule settings
    - cfg_area_category: table-area category mapping
    - cfg_skill_type: skill/course type mapping

    Prerequisites: the ``billiards_dws`` schema (INIT_DWS_SCHEMA) and the
    configuration tables must already exist.
    """

    def get_task_code(self) -> str:
        """Stable task identifier used by the ETL admin tables."""
        return "SEED_DWS_CONFIG"

    def extract(self, context: TaskContext) -> dict[str, Any]:
        """Locate and read the seed SQL file.

        The path can be overridden via ``schema.seed_dws_file``; the default
        is ``database/seed_dws_config.sql`` relative to the package root.

        Raises:
            FileNotFoundError: when the seed file does not exist.
        """
        default_path = Path(__file__).resolve().parents[1] / "database" / "seed_dws_config.sql"
        seed_path = Path(self.config.get("schema.seed_dws_file", default_path))
        if not seed_path.exists():
            raise FileNotFoundError(f"未找到 DWS 配置数据文件: {seed_path}")
        return {
            "seed_sql": seed_path.read_text(encoding="utf-8"),
            "seed_file": str(seed_path),
        }

    def load(self, extracted: dict[str, Any], context: TaskContext) -> dict:
        """Execute the seed SQL on the task's database connection."""
        seed_file = extracted["seed_file"]
        with self.db.conn.cursor() as cur:
            self.logger.info("执行 DWS 配置数据文件: %s", seed_file)
            cur.execute(extracted["seed_sql"])
        self.logger.info("DWS 配置数据初始化完成")
        return {"executed": 1, "files": [seed_file]}