# -*- coding: utf-8 -*-
"""Manual sample-data ingestion: write rows into ODS following schema_ODS_doc.sql."""
from __future__ import annotations

import json
import os
from datetime import datetime
from typing import Any, Iterable

from psycopg2.extras import Json

from .base_task import BaseTask


class ManualIngestTask(BaseTask):
    """Load local sample JSON files into ODS tables.

    Table names, primary keys and inserted columns are meant to stay aligned
    with schema_ODS_doc.sql.
    """

    # (filename keywords, target table). The first entry with a keyword that
    # occurs as a substring of the file name wins, so order matters.
    FILE_MAPPING: list[tuple[tuple[str, ...], str]] = [
        (("member_profiles",), "billiards_ods.member_profiles"),
        (("member_balance_changes",), "billiards_ods.member_balance_changes"),
        (("member_stored_value_cards",), "billiards_ods.member_stored_value_cards"),
        (("recharge_settlements",), "billiards_ods.recharge_settlements"),
        (("settlement_records",), "billiards_ods.settlement_records"),
        (("assistant_cancellation_records",), "billiards_ods.assistant_cancellation_records"),
        (("assistant_accounts_master",), "billiards_ods.assistant_accounts_master"),
        (("assistant_service_records",), "billiards_ods.assistant_service_records"),
        (("site_tables_master",), "billiards_ods.site_tables_master"),
        (("table_fee_discount_records",), "billiards_ods.table_fee_discount_records"),
        (("table_fee_transactions",), "billiards_ods.table_fee_transactions"),
        (("goods_stock_movements",), "billiards_ods.goods_stock_movements"),
        (("stock_goods_category_tree",), "billiards_ods.stock_goods_category_tree"),
        (("goods_stock_summary",), "billiards_ods.goods_stock_summary"),
        (("payment_transactions",), "billiards_ods.payment_transactions"),
        (("refund_transactions",), "billiards_ods.refund_transactions"),
        (("platform_coupon_redemption_records",), "billiards_ods.platform_coupon_redemption_records"),
        (("group_buy_redemption_records",), "billiards_ods.group_buy_redemption_records"),
        (("group_buy_packages",), "billiards_ods.group_buy_packages"),
        (("settlement_ticket_details",), "billiards_ods.settlement_ticket_details"),
        (("store_goods_master",), "billiards_ods.store_goods_master"),
        (("tenant_goods_master",), "billiards_ods.tenant_goods_master"),
        (("store_goods_sales_records",), "billiards_ods.store_goods_sales_records"),
    ]

    # Per-table settings: "pk" is the conflict-target key as named in the JSON
    # records (matched case-insensitively against the DB column names);
    # "json_cols" lists columns whose values are wrapped as JSON on insert.
    TABLE_SPECS: dict[str, dict[str, Any]] = {
        "billiards_ods.member_profiles": {"pk": "id"},
        "billiards_ods.member_balance_changes": {"pk": "id"},
        "billiards_ods.member_stored_value_cards": {"pk": "id"},
        "billiards_ods.recharge_settlements": {"pk": "id"},
        "billiards_ods.settlement_records": {"pk": "id"},
        "billiards_ods.assistant_cancellation_records": {"pk": "id", "json_cols": ["siteProfile"]},
        "billiards_ods.assistant_accounts_master": {"pk": "id"},
        "billiards_ods.assistant_service_records": {"pk": "id", "json_cols": ["siteProfile"]},
        "billiards_ods.site_tables_master": {"pk": "id"},
        "billiards_ods.table_fee_discount_records": {"pk": "id", "json_cols": ["siteProfile", "tableProfile"]},
        "billiards_ods.table_fee_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
        "billiards_ods.goods_stock_movements": {"pk": "siteGoodsStockId"},
        "billiards_ods.stock_goods_category_tree": {"pk": "id", "json_cols": ["categoryBoxes"]},
        "billiards_ods.goods_stock_summary": {"pk": "siteGoodsId"},
        "billiards_ods.payment_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
        "billiards_ods.refund_transactions": {"pk": "id", "json_cols": ["siteProfile"]},
        "billiards_ods.platform_coupon_redemption_records": {"pk": "id"},
        "billiards_ods.tenant_goods_master": {"pk": "id"},
        "billiards_ods.group_buy_packages": {"pk": "id"},
        "billiards_ods.group_buy_redemption_records": {"pk": "id"},
        "billiards_ods.settlement_ticket_details": {
            "pk": "orderSettleId",
            "json_cols": ["memberProfile", "orderItem", "tenantMemberCardLogs"],
        },
        "billiards_ods.store_goods_master": {"pk": "id"},
        "billiards_ods.store_goods_sales_records": {"pk": "id"},
    }

    def get_task_code(self) -> str:
        """Return the task code."""
        return "MANUAL_INGEST"

    def execute(self, cursor_data: dict | None = None) -> dict:
        """Read every ``*.json`` file in the data directory and upsert it into ODS.

        The directory is taken from the ``manual.data_dir`` config key, then
        ``pipeline.ingest_source_dir``, then a hard-coded developer path.

        Each file is committed on its own. A failure in one file therefore
        rolls back only that file, and the returned counters only include
        rows that were actually committed.

        Args:
            cursor_data: Not used by this task.

        Returns:
            ``{"status": "SUCCESS", "counts": {...}}`` on completion, or
            ``{"status": "error", "message": ...}`` when the data directory
            is missing. ``inserted``/``updated``/``fetched`` count records,
            while ``skipped``/``errors`` count files.

        Raises:
            Exception: Whatever ``self.db.commit()`` raises; the transaction
                is rolled back before re-raising.
        """
        data_dir = (
            self.config.get("manual.data_dir")
            or self.config.get("pipeline.ingest_source_dir")
            or r"c:\dev\LLTQ\ETL\feiqiu-ETL\etl_billiards\tests\testdata_json"
        )
        # isdir (not exists): a regular file at this path would otherwise make
        # os.listdir raise instead of producing the error result below.
        if not os.path.isdir(data_dir):
            self.logger.error("Data directory not found: %s", data_dir)
            return {"status": "error", "message": "Directory not found"}

        counts = {"fetched": 0, "inserted": 0, "updated": 0, "skipped": 0, "errors": 0}

        for filename in sorted(os.listdir(data_dir)):
            if not filename.endswith(".json"):
                continue

            filepath = os.path.join(data_dir, filename)
            try:
                with open(filepath, "r", encoding="utf-8") as fh:
                    raw_entries = json.load(fh)
            except Exception:
                # Best effort per file: an unreadable file must not stop the run.
                counts["errors"] += 1
                self.logger.exception("Failed to read %s", filename)
                continue

            entries = raw_entries if isinstance(raw_entries, list) else [raw_entries]
            records = self._extract_records(entries)
            if not records:
                counts["skipped"] += 1
                continue

            target_table = self._match_by_filename(filename)
            if not target_table:
                self.logger.warning("No mapping found for file: %s", filename)
                counts["skipped"] += 1
                continue

            self.logger.info("Ingesting %s into %s", filename, target_table)
            try:
                inserted, updated = self._ingest_table(target_table, records, filename)
            except Exception:
                counts["errors"] += 1
                self.logger.exception("Error processing %s", filename)
                # Only this file's statements are pending here, because every
                # earlier successful file has already been committed.
                self.db.rollback()
                continue

            try:
                self.db.commit()
            except Exception:
                self.db.rollback()
                raise

            # Count only after the commit succeeded.
            counts["inserted"] += inserted
            counts["updated"] += updated
            counts["fetched"] += len(records)

        return {"status": "SUCCESS", "counts": counts}

    def _match_by_filename(self, filename: str) -> str | None:
        """Return the target table for *filename*, or ``None`` if nothing matches.

        A mapping matches when any of its non-empty keywords is a substring of
        the file name; the first matching entry of ``FILE_MAPPING`` wins.
        """
        for keywords, table in self.FILE_MAPPING:
            if any(keyword and keyword in filename for keyword in keywords):
                return table
        return None

    def _extract_records(self, raw_entries: Iterable[Any]) -> list[dict]:
        """Flatten the various ``data`` / list wrappers into a list of records.

        For each entry:

        * A dict whose only keys are ``data`` and/or ``code`` is unwrapped to
          its ``data`` value.
        * A dict with a ``settleList`` key yields one record per dict item of
          that value (preferring the item's own nested ``settleList`` dict),
          with the outer ``siteProfile`` attached when the record lacks one.
        * Otherwise the first value that is a non-empty list of dicts is used
          as the record list.
        * Failing all that, the dict itself is one record.
        * A list entry contributes its dict items.

        Entries of any other type are ignored.
        """
        wrapper_keys = {"data", "code"}
        records: list[dict] = []

        for entry in raw_entries:
            if isinstance(entry, list):
                records.extend(item for item in entry if isinstance(item, dict))
                continue
            if not isinstance(entry, dict):
                continue

            data = entry
            if "data" in entry and all(key in wrapper_keys for key in entry):
                data = entry["data"]

            if isinstance(data, dict):
                # Special case for settleList (recharge / settlement records):
                # expand the settleList nested under data.settleList.
                if "settleList" in data:
                    settle_list_val = data.get("settleList")
                    if isinstance(settle_list_val, dict):
                        settle_items = [settle_list_val]
                    elif isinstance(settle_list_val, list):
                        settle_items = settle_list_val
                    else:
                        settle_items = []

                    handled = False
                    for item in settle_items:
                        if not isinstance(item, dict):
                            continue
                        inner = item.get("settleList")
                        merged = dict(inner) if isinstance(inner, dict) else dict(item)
                        # Carry the outer siteProfile along so that
                        # _ingest_table can fill in store fields from it.
                        site_profile = data.get("siteProfile")
                        if isinstance(site_profile, dict):
                            merged.setdefault("siteProfile", site_profile)
                        records.append(merged)
                        handled = True
                    if handled:
                        continue

                # Generic wrapper: take the first value that is a non-empty
                # list of dicts (only its first element is type-checked).
                list_used = False
                for value in data.values():
                    if isinstance(value, list) and value and isinstance(value[0], dict):
                        records.extend(value)
                        list_used = True
                        break
                if list_used:
                    continue

            if isinstance(data, list) and data and isinstance(data[0], dict):
                records.extend(data)
            elif isinstance(data, dict):
                records.append(data)

        return records

    def _get_table_columns(self, table: str) -> list[tuple[str, str, str]]:
        """Return ``(column_name, data_type, udt_name)`` for every column of *table*.

        Columns are read from ``information_schema.columns`` in ordinal order;
        ``data_type`` and ``udt_name`` are lower-cased. Results are cached per
        table name on the instance. A table name without a schema prefix is
        looked up in ``public``.
        """
        cache = getattr(self, "_table_columns_cache", {})
        if table in cache:
            return cache[table]

        if "." in table:
            schema, name = table.split(".", 1)
        else:
            schema, name = "public", table

        sql = """
            SELECT column_name, data_type, udt_name
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            ORDER BY ordinal_position
        """
        with self.db.conn.cursor() as cur:
            cur.execute(sql, (schema, name))
            cols = [(r[0], (r[1] or "").lower(), (r[2] or "").lower()) for r in cur.fetchall()]

        cache[table] = cols
        self._table_columns_cache = cache
        return cols

    def _ingest_table(self, table: str, records: list[dict], source_file: str) -> tuple[int, int]:
        """Build an ``INSERT ... ON CONFLICT`` statement and run it once per record.

        Args:
            table: Target table; must be a key of ``TABLE_SPECS``.
            records: Records as produced by ``_extract_records``.
            source_file: Value written to a ``source_file`` column, if present.

        Returns:
            ``(inserted, updated)`` row counts. The statement returns
            ``(xmax = 0) AS inserted``; a true flag is counted as an insert,
            anything else as an update. Records without a primary-key value
            are skipped and counted in neither.

        Raises:
            ValueError: If *table* has no entry in ``TABLE_SPECS`` or no
                columns were found for it in the database.
        """
        spec = self.TABLE_SPECS.get(table)
        if not spec:
            raise ValueError(f"No table spec for {table}")

        pk_col = spec.get("pk")
        json_cols_lower = {c.lower() for c in spec.get("json_cols", [])}

        columns_info = self._get_table_columns(table)
        if not columns_info:
            # Without this guard the statement below would be the invalid
            # "INSERT INTO t () VALUES ()".
            raise ValueError(f"No columns found for {table}")

        columns = [c[0] for c in columns_info]
        db_json_cols_lower = {
            c[0].lower() for c in columns_info if c[1] in ("json", "jsonb") or c[2] in ("json", "jsonb")
        }

        # Resolve the spec's key name to the actual DB column name, ignoring case.
        pk_col_db = None
        if pk_col:
            pk_col_db = next((c for c in columns if c.lower() == pk_col.lower()), pk_col)

        placeholders = ", ".join(["%s"] * len(columns))
        col_list = ", ".join(f'"{c}"' for c in columns)
        sql = f'INSERT INTO {table} ({col_list}) VALUES ({placeholders})'
        if pk_col_db:
            update_cols = [c for c in columns if c != pk_col_db]
            set_clause = ", ".join(f'"{c}"=EXCLUDED."{c}"' for c in update_cols)
            sql += f' ON CONFLICT ("{pk_col_db}") DO UPDATE SET {set_clause}'
        sql += " RETURNING (xmax = 0) AS inserted"

        def json_dump(value: Any) -> str:
            # Keep non-ASCII text readable in the stored JSON.
            return json.dumps(value, ensure_ascii=False)

        settlement_tables = {
            "billiards_ods.recharge_settlements",
            "billiards_ods.settlement_records",
        }

        params = []
        now = datetime.now()
        for rec in records:
            merged_rec = rec if isinstance(rec, dict) else {}

            # Flatten nested "data" dicts into the record; keys already present
            # on the outer level take precedence over nested ones.
            data_part = merged_rec.get("data")
            while isinstance(data_part, dict):
                merged_rec = {**data_part, **merged_rec}
                data_part = data_part.get("data")

            # For recharge / settlement rows, fill in the store fields from
            # siteProfile. NOTE: when no nested "data" was flattened,
            # merged_rec is rec itself, so these keys are added to rec as well
            # and therefore also show up in the payload column.
            if table in settlement_tables:
                site_profile = merged_rec.get("siteProfile") or merged_rec.get("site_profile")
                if isinstance(site_profile, dict):
                    merged_rec.setdefault(
                        "tenantid", site_profile.get("tenant_id") or site_profile.get("tenantId")
                    )
                    merged_rec.setdefault("siteid", site_profile.get("id") or site_profile.get("siteId"))
                    merged_rec.setdefault(
                        "sitename", site_profile.get("shop_name") or site_profile.get("siteName")
                    )

            pk_val = self._get_value_case_insensitive(merged_rec, pk_col) if pk_col else None
            if pk_col and (pk_val is None or pk_val == ""):
                # A row without a key cannot be upserted; skip it.
                continue

            row_vals = []
            for col_name, data_type, _udt in columns_info:
                col_lower = col_name.lower()
                if col_lower == "payload":
                    row_vals.append(Json(rec, dumps=json_dump))
                    continue
                if col_lower == "source_file":
                    row_vals.append(source_file)
                    continue
                if col_lower == "fetched_at":
                    row_vals.append(merged_rec.get(col_name, now))
                    continue

                value = self._normalize_scalar(self._get_value_case_insensitive(merged_rec, col_name))

                if col_lower in json_cols_lower or col_lower in db_json_cols_lower:
                    row_vals.append(Json(value, dumps=json_dump) if value is not None else None)
                    continue

                row_vals.append(self._cast_value(value, data_type))
            params.append(tuple(row_vals))

        if not params:
            return 0, 0

        inserted = 0
        updated = 0
        with self.db.conn.cursor() as cur:
            for row in params:
                cur.execute(sql, row)
                flag = cur.fetchone()[0]
                if flag:
                    inserted += 1
                else:
                    updated += 1
        return inserted, updated

    @staticmethod
    def _get_value_case_insensitive(record: dict, col: str | None):
        """Look up *col* in *record*, ignoring case when there is no exact match.

        This bridges the column names reported by information_schema and the
        original JSON field names. An exact key match is preferred; otherwise
        the first string key that matches case-insensitively is used. Returns
        ``None`` when *record* or *col* is ``None`` or nothing matches.
        """
        if record is None or col is None:
            return None
        if col in record:
            return record.get(col)
        col_lower = col.lower()
        for k, v in record.items():
            if isinstance(k, str) and k.lower() == col_lower:
                return v
        return None

    @staticmethod
    def _normalize_scalar(value):
        """Map ``""``, ``"{}"`` and ``"[]"`` to ``None``; return anything else as is.

        This avoids type-conversion errors on empty placeholder strings.
        """
        if value == "" or value == "{}" or value == "[]":
            return None
        return value

    @staticmethod
    def _cast_value(value, data_type: str):
        """Coerce *value* to fit a column of the given ``data_type``.

        * Integer types: ``int(value)``, or ``None`` if that fails.
        * Numeric / floating types: ``float(value)`` (``int`` for booleans),
          or ``None`` if that fails.
        * Timestamp / date / time / interval types: strings pass through,
          anything else becomes ``None``.
        * Every other type: the value is returned unchanged.

        ``None`` always stays ``None``.
        """
        if value is None:
            return None

        dt = (data_type or "").lower()
        if dt in ("integer", "bigint", "smallint"):
            if isinstance(value, bool):
                return int(value)
            try:
                return int(value)
            except Exception:
                return None
        if dt in ("numeric", "double precision", "real", "decimal"):
            if isinstance(value, bool):
                return int(value)
            try:
                return float(value)
            except Exception:
                return None
        if dt.startswith("timestamp") or dt in ("date", "time", "interval"):
            return value if isinstance(value, str) else None
        return value