# -*- coding: utf-8 -*-
"""
Rebuild the billiards_ods.* tables from a local directory of sample JSON files and load the sample data.

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.rebuild_ods_from_json [--dsn ...] [--json-dir ...] [--include ...] [--drop-schema-first]

Environment variables:
    PG_DSN                 PostgreSQL connection string (required)
    PG_CONNECT_TIMEOUT     optional; connect timeout in seconds, default 10
    JSON_DOC_DIR           optional; JSON directory, default C:\\dev\\LLTQ\\export\\test-json-doc
    ODS_INCLUDE_FILES      optional; comma-separated file names (without .json)
    ODS_DROP_SCHEMA_FIRST  optional; true/false, default true
"""
from __future__ import annotations

import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Iterable, List, Tuple

import psycopg2
from psycopg2 import sql
from psycopg2.extras import Json, execute_values

DEFAULT_JSON_DIR = r"C:\dev\LLTQ\export\test-json-doc"

# File stem (lower-case) -> path of keys leading to the record list inside the JSON document.
SPECIAL_LIST_PATHS: dict[str, tuple[str, ...]] = {
    "assistant_accounts_master": ("data", "assistantInfos"),
    "assistant_cancellation_records": ("data", "abolitionAssistants"),
    "assistant_service_records": ("data", "orderAssistantDetails"),
    "goods_stock_movements": ("data", "queryDeliveryRecordsList"),
    "goods_stock_summary": ("data",),
    "group_buy_packages": ("data", "packageCouponList"),
    "group_buy_redemption_records": ("data", "siteTableUseDetailsList"),
    "member_balance_changes": ("data", "tenantMemberCardLogs"),
    "member_profiles": ("data", "tenantMemberInfos"),
    "member_stored_value_cards": ("data", "tenantMemberCards"),
    "recharge_settlements": ("data", "settleList"),
    "settlement_records": ("data", "settleList"),
    "site_tables_master": ("data", "siteTables"),
    "stock_goods_category_tree": ("data", "goodsCategoryList"),
    "store_goods_master": ("data", "orderGoodsList"),
    "store_goods_sales_records": ("data", "orderGoodsLedgers"),
    "table_fee_discount_records": ("data", "taiFeeAdjustInfos"),
    "table_fee_transactions": ("data", "siteTableUseDetailsList"),
    "tenant_goods_master": ("data", "tenantGoodsList"),
}


def sanitize_identifier(name: str) -> str:
    """Turn an arbitrary string into a usable SQL identifier (lower-case, non-alphanumerics become underscores)."""
    cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", name.strip())
    if not cleaned:
        cleaned = "col"
    if cleaned[0].isdigit():
        cleaned = f"_{cleaned}"
    return cleaned.lower()


def _extract_list_via_path(node, path: tuple[str, ...]):
    """Walk `path` through nested dicts and return the list found there, or [] if absent."""
    cur = node
    for key in path:
        if isinstance(cur, dict):
            cur = cur.get(key)
        else:
            return []
    return cur if isinstance(cur, list) else []


def load_records(payload, list_path: tuple[str, ...] | None = None) -> list:
    """
    Try to extract the record list from a JSON structure:
    - the payload itself is a list -> return it
    - a dict whose "data" value is a list -> return it
    - a dict whose "data" value is a dict -> return the first list field inside it
    - any other value of the dict is a list -> return it
    - otherwise wrap the payload as a single record
    """
    if list_path:
        if isinstance(payload, list):
            merged: list = []
            for item in payload:
                merged.extend(_extract_list_via_path(item, list_path))
            if merged:
                return merged
        elif isinstance(payload, dict):
            lst = _extract_list_via_path(payload, list_path)
            if lst:
                return lst
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        data_node = payload.get("data")
        if isinstance(data_node, list):
            return data_node
        if isinstance(data_node, dict):
            for v in data_node.values():
                if isinstance(v, list):
                    return v
        for v in payload.values():
            if isinstance(v, list):
                return v
    return [payload]


def collect_columns(records: Iterable[dict]) -> List[str]:
    """Collect all top-level keys as table columns; only dict records are considered."""
    cols: set[str] = set()
    for rec in records:
        if isinstance(rec, dict):
            cols.update(rec.keys())
    return sorted(cols)


def create_table(cur, schema: str, table: str, columns: List[Tuple[str, str]]):
    """
    Create the table: every data column is jsonb, plus source_file, record_index, payload and ingested_at.
    columns: [(col_name, original_key)]
    """
    # Build the column list as one sequence so an empty `columns` (e.g. a file of
    # scalar values) does not leave a dangling comma in the DDL.
    col_defs = [
        sql.SQL("source_file text"),
        sql.SQL("record_index integer"),
        *[sql.SQL("{} jsonb").format(sql.Identifier(col)) for col, _ in columns],
        sql.SQL("payload jsonb"),
        sql.SQL("ingested_at timestamptz default now()"),
        sql.SQL("CONSTRAINT {} UNIQUE (source_file, record_index)").format(
            sql.Identifier(f"uq_{table}_source_record")
        ),
    ]
    ddl = sql.SQL("CREATE TABLE IF NOT EXISTS {schema}.{table} ({cols});").format(
        schema=sql.Identifier(schema),
        table=sql.Identifier(table),
        cols=sql.SQL(", ").join(col_defs),
    )
    cur.execute(ddl)


def insert_records(cur, schema: str, table: str, columns: List[Tuple[str, str]],
                   records: list, source_file: str):
    """Insert the records in batches."""
    col_idents = [sql.Identifier(col) for col, _ in columns]
    orig_keys = [orig for _, orig in columns]
    all_cols = [sql.Identifier("source_file"), sql.Identifier("record_index")] + col_idents + [
        sql.Identifier("payload")
    ]
    rows = []
    for idx, rec in enumerate(records):
        if not isinstance(rec, dict):
            rec = {"value": rec}
        row_values = [source_file, idx]
        for key in orig_keys:
            row_values.append(Json(rec.get(key)))
        row_values.append(Json(rec))
        rows.append(row_values)
    insert_sql = sql.SQL("INSERT INTO {}.{} ({}) VALUES %s ON CONFLICT DO NOTHING").format(
        sql.Identifier(schema),
        sql.Identifier(table),
        sql.SQL(",").join(all_cols),
    )
    execute_values(cur, insert_sql, rows, page_size=500)


def rebuild(schema: str = "billiards_ods", data_dir: str | Path = DEFAULT_JSON_DIR):
    parser = argparse.ArgumentParser(description="Rebuild the billiards_ods.* tables and import the sample JSON files")
    parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN (defaults to the PG_DSN environment variable)")
    parser.add_argument("--json-dir", dest="json_dir", help=f"JSON directory, default {DEFAULT_JSON_DIR}")
    parser.add_argument(
        "--include",
        dest="include_files",
        help="Only import these file names (comma-separated, without .json); default: all files",
    )
    parser.add_argument(
        "--drop-schema-first",
        dest="drop_schema_first",
        action="store_true",
        help="Drop and recreate the schema first (default: true)",
    )
    parser.add_argument(
        "--no-drop-schema-first",
        dest="drop_schema_first",
        action="store_false",
        help="Keep the existing schema and only deduplicate on conflict during import",
    )
    parser.set_defaults(drop_schema_first=None)
    args = parser.parse_args()

    dsn = args.dsn or os.environ.get("PG_DSN")
    if not dsn:
        print("Missing --dsn / PG_DSN environment variable; cannot connect to the database.")
        sys.exit(1)

    timeout = max(1, min(int(os.environ.get("PG_CONNECT_TIMEOUT", 10)), 60))
    env_drop = os.environ.get("ODS_DROP_SCHEMA_FIRST") or os.environ.get("DROP_SCHEMA_FIRST")
    drop_schema_first = (
        args.drop_schema_first
        if args.drop_schema_first is not None
        else str(env_drop or "true").lower() in ("1", "true", "yes")
    )

    include_files_env = args.include_files or os.environ.get("ODS_INCLUDE_FILES") or os.environ.get("INCLUDE_FILES")
    include_files = set()
    if include_files_env:
        include_files = {p.strip().lower() for p in include_files_env.split(",") if p.strip()}

    base_dir = Path(args.json_dir or os.environ.get("JSON_DOC_DIR") or data_dir or DEFAULT_JSON_DIR)
    if not base_dir.exists():
        print(f"JSON directory does not exist: {base_dir}")
        sys.exit(1)

    conn = psycopg2.connect(dsn, connect_timeout=timeout)
    conn.autocommit = False
    cur = conn.cursor()

    if drop_schema_first:
        print(f"Dropping schema {schema} ...")
        cur.execute(sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE;").format(sql.Identifier(schema)))
        cur.execute(sql.SQL("CREATE SCHEMA {};").format(sql.Identifier(schema)))
    else:
        cur.execute(
            sql.SQL("SELECT schema_name FROM information_schema.schemata WHERE schema_name=%s"),
            (schema,),
        )
        if not cur.fetchone():
            cur.execute(sql.SQL("CREATE SCHEMA {};").format(sql.Identifier(schema)))

    json_files = sorted(base_dir.glob("*.json"))
    for path in json_files:
        stem_lower = path.stem.lower()
        if include_files and stem_lower not in include_files:
            continue
        print(f"Processing {path.name} ...")
        payload = json.loads(path.read_text(encoding="utf-8"))
        list_path = SPECIAL_LIST_PATHS.get(stem_lower)
        records = load_records(payload, list_path=list_path)
        columns_raw = collect_columns(records)
        columns = [(sanitize_identifier(c), c) for c in columns_raw]
        table_name = sanitize_identifier(path.stem)
        create_table(cur, schema, table_name, columns)
        if records:
            insert_records(cur, schema, table_name, columns, records, path.name)
        print(f"  -> rows: {len(records)}, columns: {len(columns)}")

    conn.commit()
    cur.close()
    conn.close()
    print("Rebuild done.")


if __name__ == "__main__":
    rebuild()
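
# Example run (a sketch only; the DSN, password and directory below are placeholders,
# not values from this project):
#
#   set PG_DSN=postgresql://user:pass@localhost:5432/billiards
#   set JSON_DOC_DIR=C:\dev\LLTQ\export\test-json-doc
#   PYTHONPATH=. python -m etl_billiards.scripts.rebuild_ods_from_json --include member_profiles,settlement_records
#
# A quick check from psql afterwards (member_profiles is one of the file stems
# listed in SPECIAL_LIST_PATHS, so it becomes a table of the same name):
#
#   SELECT source_file, count(*) FROM billiards_ods.member_profiles GROUP BY source_file;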