feiqiu-ETL/etl_billiards/scripts/rebuild_ods_from_json.py

# -*- coding: utf-8 -*-
"""
Rebuild the billiards_ods.* tables from a local directory of sample JSON files and import the sample data.

Usage:
    PYTHONPATH=. python -m etl_billiards.scripts.rebuild_ods_from_json [--dsn ...] [--json-dir ...] [--include ...] [--drop-schema-first]

Environment variables:
    PG_DSN                 PostgreSQL connection string (required)
    PG_CONNECT_TIMEOUT     optional, in seconds, default 10
    JSON_DOC_DIR           optional, JSON directory, default C:\\dev\\LLTQ\\export\\test-json-doc
    ODS_INCLUDE_FILES      optional, comma-separated file names (without .json)
    ODS_DROP_SCHEMA_FIRST  optional, true/false, default true
"""

from __future__ import annotations

import argparse
import os
import re
import sys
import json
from pathlib import Path
from typing import Iterable, List, Tuple

import psycopg2
from psycopg2 import sql
from psycopg2.extras import Json, execute_values

DEFAULT_JSON_DIR = r"C:\dev\LLTQ\export\test-json-doc"
SPECIAL_LIST_PATHS: dict[str, tuple[str, ...]] = {
    "assistant_accounts_master": ("data", "assistantInfos"),
    "assistant_cancellation_records": ("data", "abolitionAssistants"),
    "assistant_service_records": ("data", "orderAssistantDetails"),
    "goods_stock_movements": ("data", "queryDeliveryRecordsList"),
    "goods_stock_summary": ("data",),
    "group_buy_packages": ("data", "packageCouponList"),
    "group_buy_redemption_records": ("data", "siteTableUseDetailsList"),
    "member_balance_changes": ("data", "tenantMemberCardLogs"),
    "member_profiles": ("data", "tenantMemberInfos"),
    "member_stored_value_cards": ("data", "tenantMemberCards"),
    "recharge_settlements": ("data", "settleList"),
    "settlement_records": ("data", "settleList"),
    "site_tables_master": ("data", "siteTables"),
    "stock_goods_category_tree": ("data", "goodsCategoryList"),
    "store_goods_master": ("data", "orderGoodsList"),
    "store_goods_sales_records": ("data", "orderGoodsLedgers"),
    "table_fee_discount_records": ("data", "taiFeeAdjustInfos"),
    "table_fee_transactions": ("data", "siteTableUseDetailsList"),
    "tenant_goods_master": ("data", "tenantGoodsList"),
}


def sanitize_identifier(name: str) -> str:
    """Turn an arbitrary string into a usable SQL identifier (lowercase; non-alphanumerics become underscores)."""
    cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", name.strip())
    if not cleaned:
        cleaned = "col"
    if cleaned[0].isdigit():
        cleaned = f"_{cleaned}"
    return cleaned.lower()
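# For illustration (hypothetical input): sanitize_identifier("2024 Revenue (CNY)") -> "_2024_revenue__cny_".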


def _extract_list_via_path(node, path: tuple[str, ...]):
    cur = node
    for key in path:
        if isinstance(cur, dict):
            cur = cur.get(key)
        else:
            return []
    return cur if isinstance(cur, list) else []


def load_records(payload, list_path: tuple[str, ...] | None = None) -> list:
    """
    Try to extract the list of records from a JSON structure:
    - the payload itself is a list -> return it
    - the payload is a dict whose "data" is a list -> return it
    - "data" is a dict -> return its first list-valued field
    - otherwise return the first list-valued field of the payload
    - in all remaining cases, wrap the payload as a single record
    """
    if list_path:
        if isinstance(payload, list):
            merged: list = []
            for item in payload:
                merged.extend(_extract_list_via_path(item, list_path))
            if merged:
                return merged
        elif isinstance(payload, dict):
            lst = _extract_list_via_path(payload, list_path)
            if lst:
                return lst
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        data_node = payload.get("data")
        if isinstance(data_node, list):
            return data_node
        if isinstance(data_node, dict):
            for v in data_node.values():
                if isinstance(v, list):
                    return v
        for v in payload.values():
            if isinstance(v, list):
                return v
    return [payload]
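# For illustration, a hypothetical payload {"data": {"rows": [{"id": 1}, {"id": 2}]}}
# resolves to [{"id": 1}, {"id": 2}] via the "first list-valued field under data" rule.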


def collect_columns(records: Iterable[dict]) -> List[str]:
    """Collect the union of all top-level keys as table columns; only dict records are considered."""
    cols: set[str] = set()
    for rec in records:
        if isinstance(rec, dict):
            cols.update(rec.keys())
    return sorted(cols)


def create_table(cur, schema: str, table: str, columns: List[Tuple[str, str]]):
    """
    Create the table: every data field is jsonb, plus source_file, record_index, payload and ingested_at.
    columns: [(col_name, original_key)]
    """
    fields = [sql.SQL("{} jsonb").format(sql.Identifier(col)) for col, _ in columns]
    constraint_name = f"uq_{table}_source_record"
    # An empty column list would otherwise leave a dangling comma in the DDL, so append the separator only when needed.
    cols_sql = sql.SQL(",").join(fields) + sql.SQL(",") if fields else sql.SQL("")
    ddl = sql.SQL(
        "CREATE TABLE IF NOT EXISTS {schema}.{table} ("
        "source_file text,"
        "record_index integer,"
        "{cols}"
        "payload jsonb,"
        "ingested_at timestamptz default now(),"
        "CONSTRAINT {constraint} UNIQUE (source_file, record_index)"
        ");"
    ).format(
        schema=sql.Identifier(schema),
        table=sql.Identifier(table),
        cols=cols_sql,
        constraint=sql.Identifier(constraint_name),
    )
    cur.execute(ddl)
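# For illustration, a "member_profiles" file whose records expose a single key "memberId"
# would produce DDL roughly like (identifiers are quoted by psycopg2.sql.Identifier):
#   CREATE TABLE IF NOT EXISTS "billiards_ods"."member_profiles" (
#       source_file text, record_index integer, "memberid" jsonb, payload jsonb,
#       ingested_at timestamptz default now(),
#       CONSTRAINT "uq_member_profiles_source_record" UNIQUE (source_file, record_index));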


def insert_records(cur, schema: str, table: str, columns: List[Tuple[str, str]], records: list, source_file: str):
    """Bulk-insert the records."""
    col_idents = [sql.Identifier(col) for col, _ in columns]
    orig_keys = [orig for _, orig in columns]
    all_cols = [sql.Identifier("source_file"), sql.Identifier("record_index")] + col_idents + [
        sql.Identifier("payload")
    ]
    rows = []
    for idx, rec in enumerate(records):
        if not isinstance(rec, dict):
            rec = {"value": rec}
        row_values = [source_file, idx]
        for key in orig_keys:
            row_values.append(Json(rec.get(key)))
        row_values.append(Json(rec))
        rows.append(row_values)
    insert_sql = sql.SQL("INSERT INTO {}.{} ({}) VALUES %s ON CONFLICT DO NOTHING").format(
        sql.Identifier(schema),
        sql.Identifier(table),
        sql.SQL(",").join(all_cols),
    )
    execute_values(cur, insert_sql, rows, page_size=500)
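# Reruns with --no-drop-schema-first are idempotent per file: rows colliding on
# (source_file, record_index) are skipped by ON CONFLICT DO NOTHING.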


def rebuild(schema: str = "billiards_ods", data_dir: str | Path = DEFAULT_JSON_DIR):
    parser = argparse.ArgumentParser(description="Rebuild the billiards_ods.* tables and import the sample JSON files")
    parser.add_argument("--dsn", dest="dsn", help="PostgreSQL DSN; defaults to the PG_DSN environment variable")
    parser.add_argument("--json-dir", dest="json_dir", help=f"JSON directory, default {DEFAULT_JSON_DIR}")
    parser.add_argument(
        "--include",
        dest="include_files",
        help="Only import these file names (comma-separated, without .json); default: all files",
    )
    parser.add_argument(
        "--drop-schema-first",
        dest="drop_schema_first",
        action="store_true",
        help="Drop and recreate the schema first (default true)",
    )
    parser.add_argument(
        "--no-drop-schema-first",
        dest="drop_schema_first",
        action="store_false",
        help="Keep the existing schema and rely on conflict handling to deduplicate",
    )
    parser.set_defaults(drop_schema_first=None)
    args = parser.parse_args()

    dsn = args.dsn or os.environ.get("PG_DSN")
    if not dsn:
        print("Missing PG_DSN (argument or environment variable); cannot connect to the database.")
        sys.exit(1)
    timeout = max(1, min(int(os.environ.get("PG_CONNECT_TIMEOUT", 10)), 60))

    env_drop = os.environ.get("ODS_DROP_SCHEMA_FIRST") or os.environ.get("DROP_SCHEMA_FIRST")
    drop_schema_first = (
        args.drop_schema_first
        if args.drop_schema_first is not None
        else str(env_drop or "true").lower() in ("1", "true", "yes")
    )
    include_files_env = args.include_files or os.environ.get("ODS_INCLUDE_FILES") or os.environ.get("INCLUDE_FILES")
    include_files = set()
    if include_files_env:
        include_files = {p.strip().lower() for p in include_files_env.split(",") if p.strip()}

    base_dir = Path(args.json_dir or data_dir or DEFAULT_JSON_DIR)
    if not base_dir.exists():
        print(f"JSON directory does not exist: {base_dir}")
        sys.exit(1)
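
    # One transaction covers the schema rebuild and all inserts; the single commit happens after the last file.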
    conn = psycopg2.connect(dsn, connect_timeout=timeout)
    conn.autocommit = False
    cur = conn.cursor()

    if drop_schema_first:
        print(f"Dropping schema {schema} ...")
        cur.execute(sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE;").format(sql.Identifier(schema)))
        cur.execute(sql.SQL("CREATE SCHEMA {};").format(sql.Identifier(schema)))
    else:
        cur.execute(
            sql.SQL("SELECT schema_name FROM information_schema.schemata WHERE schema_name=%s"),
            (schema,),
        )
        if not cur.fetchone():
            cur.execute(sql.SQL("CREATE SCHEMA {};").format(sql.Identifier(schema)))

    json_files = sorted(base_dir.glob("*.json"))
    for path in json_files:
        stem_lower = path.stem.lower()
        if include_files and stem_lower not in include_files:
            continue
        print(f"Processing {path.name} ...")
        payload = json.loads(path.read_text(encoding="utf-8"))
        list_path = SPECIAL_LIST_PATHS.get(stem_lower)
        records = load_records(payload, list_path=list_path)
        columns_raw = collect_columns(records)
        columns = [(sanitize_identifier(c), c) for c in columns_raw]
        table_name = sanitize_identifier(path.stem)
        create_table(cur, schema, table_name, columns)
        if records:
            insert_records(cur, schema, table_name, columns, records, path.name)
        print(f" -> rows: {len(records)}, columns: {len(columns)}")

    conn.commit()
    cur.close()
    conn.close()
    print("Rebuild done.")


if __name__ == "__main__":
    rebuild()