# -*- coding: utf-8 -*-
"""
Rebuild the billiards_ods.* tables from the local JSON sample directory and import the sample data.

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.rebuild_ods_from_json [--dsn ...] [--json-dir ...] [--include ...] [--drop-schema-first]

Environment variables:
    PG_DSN                   PostgreSQL connection string (required)
    PG_CONNECT_TIMEOUT       optional, in seconds, default 10
    JSON_DOC_DIR             optional, JSON directory, default C:\\dev\\LLTQ\\export\\test-json-doc
    ODS_INCLUDE_FILES        optional, comma-separated file names (without the .json suffix)
    ODS_DROP_SCHEMA_FIRST    optional, true/false, default true
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Iterable, List, Tuple

import psycopg2
from psycopg2 import sql
from psycopg2.extras import Json, execute_values


DEFAULT_JSON_DIR = r"C:\dev\LLTQ\export\test-json-doc"
SPECIAL_LIST_PATHS: dict[str, tuple[str, ...]] = {
    "assistant_accounts_master": ("data", "assistantInfos"),
    "assistant_cancellation_records": ("data", "abolitionAssistants"),
    "assistant_service_records": ("data", "orderAssistantDetails"),
    "goods_stock_movements": ("data", "queryDeliveryRecordsList"),
    "goods_stock_summary": ("data",),
    "group_buy_packages": ("data", "packageCouponList"),
    "group_buy_redemption_records": ("data", "siteTableUseDetailsList"),
    "member_balance_changes": ("data", "tenantMemberCardLogs"),
    "member_profiles": ("data", "tenantMemberInfos"),
    "member_stored_value_cards": ("data", "tenantMemberCards"),
    "recharge_settlements": ("data", "settleList"),
    "settlement_records": ("data", "settleList"),
    "site_tables_master": ("data", "siteTables"),
    "stock_goods_category_tree": ("data", "goodsCategoryList"),
    "store_goods_master": ("data", "orderGoodsList"),
    "store_goods_sales_records": ("data", "orderGoodsLedgers"),
    "table_fee_discount_records": ("data", "taiFeeAdjustInfos"),
    "table_fee_transactions": ("data", "siteTableUseDetailsList"),
    "tenant_goods_master": ("data", "tenantGoodsList"),
}
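
# Illustrative (assumed) payload shape behind the mapping above: each export file wraps its
# record list under a nested key, so a file such as member_profiles.json is expected to look
# roughly like
#     {"code": 0, "data": {"tenantMemberInfos": [ {...}, {...} ]}}
# and the tuple ("data", "tenantMemberInfos") tells load_records() where to find that list.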


def sanitize_identifier(name: str) -> str:
    """Turn an arbitrary string into a usable SQL identifier (lowercase, non-alphanumerics become underscores)."""
    cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", name.strip())
    if not cleaned:
        cleaned = "col"
    if cleaned[0].isdigit():
        cleaned = f"_{cleaned}"
    return cleaned.lower()
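
# A few worked examples (hypothetical inputs, shown for illustration only):
#     sanitize_identifier("memberCard-No.1")  -> "membercard_no_1"
#     sanitize_identifier("2023Orders")       -> "_2023orders"
#     sanitize_identifier("")                 -> "col"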


def _extract_list_via_path(node, path: tuple[str, ...]):
    cur = node
    for key in path:
        if isinstance(cur, dict):
            cur = cur.get(key)
        else:
            return []
    return cur if isinstance(cur, list) else []


def load_records(payload, list_path: tuple[str, ...] | None = None) -> list:
    """
    Try to extract the list of records from a JSON structure:
    - payload is already a list -> return it
    - payload is a dict whose "data" is a list -> return it
    - payload is a dict whose "data" is a dict -> return its first list-valued field
    - any other value in the dict is a list -> return it
    - otherwise, wrap the payload as a single record
    """
    if list_path:
        if isinstance(payload, list):
            merged: list = []
            for item in payload:
                merged.extend(_extract_list_via_path(item, list_path))
            if merged:
                return merged
        elif isinstance(payload, dict):
            lst = _extract_list_via_path(payload, list_path)
            if lst:
                return lst

    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        data_node = payload.get("data")
        if isinstance(data_node, list):
            return data_node
        if isinstance(data_node, dict):
            for v in data_node.values():
                if isinstance(v, list):
                    return v
        for v in payload.values():
            if isinstance(v, list):
                return v
    return [payload]
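
# Quick illustration of the fallbacks above (hypothetical payloads, not from the real export):
#     load_records([{"id": 1}])                                  -> [{"id": 1}]
#     load_records({"data": [{"id": 1}]})                        -> [{"id": 1}]
#     load_records({"data": {"rows": [{"id": 1}], "total": 1}})  -> [{"id": 1}]
#     load_records({"code": 0})                                  -> [{"code": 0}]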


def collect_columns(records: Iterable[dict]) -> List[str]:
    """Collect every top-level key as a table column; only dict records are considered."""
    cols: set[str] = set()
    for rec in records:
        if isinstance(rec, dict):
            cols.update(rec.keys())
    return sorted(cols)


def create_table(cur, schema: str, table: str, columns: List[Tuple[str, str]]):
    """
    Create the table: every data column is jsonb, plus source_file, record_index, payload and ingested_at.
    columns: [(col_name, original_key)]
    """
    fields = [sql.SQL("{} jsonb").format(sql.Identifier(col)) for col, _ in columns]
    # A file may yield no dict records (and therefore no data columns); emit only the
    # bookkeeping columns in that case so the DDL stays valid.
    cols_sql = sql.SQL("{},").format(sql.SQL(",").join(fields)) if fields else sql.SQL("")
    constraint_name = f"uq_{table}_source_record"
    ddl = sql.SQL(
        "CREATE TABLE IF NOT EXISTS {schema}.{table} ("
        "source_file text,"
        "record_index integer,"
        "{cols}"
        "payload jsonb,"
        "ingested_at timestamptz default now(),"
        "CONSTRAINT {constraint} UNIQUE (source_file, record_index)"
        ");"
    ).format(
        schema=sql.Identifier(schema),
        table=sql.Identifier(table),
        cols=cols_sql,
        constraint=sql.Identifier(constraint_name),
    )
    cur.execute(ddl)
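
# For reference, the generated DDL for a hypothetical member_profiles.json whose records carry
# the keys "memberId" and "memberName" would look roughly like:
#     CREATE TABLE IF NOT EXISTS billiards_ods.member_profiles (
#         source_file text,
#         record_index integer,
#         memberid jsonb,
#         membername jsonb,
#         payload jsonb,
#         ingested_at timestamptz default now(),
#         CONSTRAINT uq_member_profiles_source_record UNIQUE (source_file, record_index)
#     );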


def insert_records(cur, schema: str, table: str, columns: List[Tuple[str, str]], records: list, source_file: str):
    """Insert the records in batches."""
    col_idents = [sql.Identifier(col) for col, _ in columns]
    orig_keys = [orig for _, orig in columns]
    all_cols = [sql.Identifier("source_file"), sql.Identifier("record_index")] + col_idents + [
        sql.Identifier("payload")
    ]

    rows = []
    for idx, rec in enumerate(records):
        if not isinstance(rec, dict):
            rec = {"value": rec}
        row_values = [source_file, idx]
        for key in orig_keys:
            row_values.append(Json(rec.get(key)))
        row_values.append(Json(rec))
        rows.append(row_values)

    insert_sql = sql.SQL("INSERT INTO {}.{} ({}) VALUES %s ON CONFLICT DO NOTHING").format(
        sql.Identifier(schema),
        sql.Identifier(table),
        sql.SQL(",").join(all_cols),
    )
    execute_values(cur, insert_sql, rows, page_size=500)
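
# execute_values expands the single VALUES %s into a batched multi-row VALUES list (up to
# page_size rows per statement), so the SQL sent to the server looks roughly like this
# (hypothetical two-row batch, values abridged):
#     INSERT INTO "billiards_ods"."member_profiles" ("source_file","record_index","memberid","payload")
#     VALUES ('member_profiles.json', 0, '{"memberId": 1}', '{...}'),
#            ('member_profiles.json', 1, '{"memberId": 2}', '{...}')
#     ON CONFLICT DO NOTHING
# Rows whose (source_file, record_index) pair already exists are skipped silently, thanks to
# the unique constraint created in create_table().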


def rebuild(schema: str = "billiards_ods", data_dir: str | Path = DEFAULT_JSON_DIR):
    parser = argparse.ArgumentParser(description="Rebuild the billiards_ods.* tables and import the JSON samples")
    parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN (defaults to the PG_DSN environment variable)")
    parser.add_argument("--json-dir", dest="json_dir", help=f"JSON directory, default {DEFAULT_JSON_DIR}")
    parser.add_argument(
        "--include",
        dest="include_files",
        help="Restrict the import to these file names (comma-separated, without .json); default: all files",
    )
    parser.add_argument(
        "--drop-schema-first",
        dest="drop_schema_first",
        action="store_true",
        help="Drop and recreate the schema first (default: true)",
    )
    parser.add_argument(
        "--no-drop-schema-first",
        dest="drop_schema_first",
        action="store_false",
        help="Keep the existing schema and only de-duplicate on conflict during import",
    )
    parser.set_defaults(drop_schema_first=None)
    args = parser.parse_args()
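
    # Example invocation (placeholder DSN and file names): environment variables supply the
    # defaults and the CLI flags override them, so a re-import that keeps existing tables
    # could look like
    #     PG_DSN=postgresql://etl:***@localhost:5432/warehouse \
    #     python -m etl_billiards.scripts.rebuild_ods_from_json \
    #         --include member_profiles,settlement_records --no-drop-schema-first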

    dsn = args.dsn or os.environ.get("PG_DSN")
    if not dsn:
        print("Missing the PG_DSN parameter/environment variable; cannot connect to the database.")
        sys.exit(1)
    timeout = max(1, min(int(os.environ.get("PG_CONNECT_TIMEOUT", 10)), 60))
    env_drop = os.environ.get("ODS_DROP_SCHEMA_FIRST") or os.environ.get("DROP_SCHEMA_FIRST")
    drop_schema_first = (
        args.drop_schema_first
        if args.drop_schema_first is not None
        else str(env_drop or "true").lower() in ("1", "true", "yes")
    )
    include_files_env = args.include_files or os.environ.get("ODS_INCLUDE_FILES") or os.environ.get("INCLUDE_FILES")
    include_files = set()
    if include_files_env:
        include_files = {p.strip().lower() for p in include_files_env.split(",") if p.strip()}

    # Honor JSON_DOC_DIR as documented in the module docstring.
    base_dir = Path(args.json_dir or os.environ.get("JSON_DOC_DIR") or data_dir or DEFAULT_JSON_DIR)
    if not base_dir.exists():
        print(f"JSON directory does not exist: {base_dir}")
        sys.exit(1)

    conn = psycopg2.connect(dsn, connect_timeout=timeout)
    conn.autocommit = False
    cur = conn.cursor()

    if drop_schema_first:
        print(f"Dropping schema {schema} ...")
        cur.execute(sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE;").format(sql.Identifier(schema)))
        cur.execute(sql.SQL("CREATE SCHEMA {};").format(sql.Identifier(schema)))
    else:
        cur.execute(
            sql.SQL("SELECT schema_name FROM information_schema.schemata WHERE schema_name=%s"),
            (schema,),
        )
        if not cur.fetchone():
            cur.execute(sql.SQL("CREATE SCHEMA {};").format(sql.Identifier(schema)))

    json_files = sorted(base_dir.glob("*.json"))
    for path in json_files:
        stem_lower = path.stem.lower()
        if include_files and stem_lower not in include_files:
            continue

        print(f"Processing {path.name} ...")
        payload = json.loads(path.read_text(encoding="utf-8"))
        list_path = SPECIAL_LIST_PATHS.get(stem_lower)
        records = load_records(payload, list_path=list_path)
        columns_raw = collect_columns(records)
        columns = [(sanitize_identifier(c), c) for c in columns_raw]

        table_name = sanitize_identifier(path.stem)
        create_table(cur, schema, table_name, columns)
        if records:
            insert_records(cur, schema, table_name, columns, records, path.name)
        print(f" -> rows: {len(records)}, columns: {len(columns)}")

    conn.commit()
    cur.close()
    conn.close()
    print("Rebuild done.")


if __name__ == "__main__":
    rebuild()