# -*- coding: utf-8 -*-
"""Synchronize API fields into the ODS database tables.

1. Detect differences between the API JSON fields and the ODS table columns.
2. Generate and execute DDL statements that add the missing columns.
3. Ignore nested-object fields such as ``siteProfile``.
"""
import json
import os
import sys
from datetime import datetime
from pathlib import Path

# Make the project package importable; must happen before the local imports below.
project_root = Path(__file__).parent.parent / "etl_billiards"
sys.path.insert(0, str(project_root))

from dotenv import load_dotenv

# Load DB credentials (PG_DSN) from the project's .env file.
load_dotenv(project_root / ".env")

from database.connection import DatabaseConnection
# Fields that must never become ODS columns: members of nested objects
# (siteProfile / tableProfile) plus bookkeeping columns the ETL already writes.
IGNORED_FIELDS = {
    # Members nested inside the siteProfile object
    "siteprofile",
    "address", "avatar", "business_tel",
    "customer_service_qrcode", "customer_service_wechat",
    "fixed_pay_qrcode", "full_address",
    "latitude", "longitude",
    "light_status", "light_token", "light_type",
    "org_id", "prod_env",
    "shop_name", "shop_status",
    "site_label", "site_type", "tenant_site_region_id",
    "wifi_name", "wifi_password",
    "attendance_distance", "attendance_enabled",
    "auto_light", "ewelink_client_id",
    # Members nested inside the tableProfile object
    "tableprofile",
    # System/bookkeeping columns that already exist
    "content_hash", "payload",
    "source_file", "source_endpoint",
    "fetched_at", "record_index",
}
# API 字段类型推断规则
|
|
def infer_column_type(field_name: str, sample_value=None) -> str:
|
|
"""根据字段名和样本值推断 PostgreSQL 列类型"""
|
|
fn = field_name.lower()
|
|
|
|
# ID 字段
|
|
if fn.endswith("_id") or fn in ("id", "tenant_id", "member_id", "site_id", "table_id",
|
|
"operator_id", "relate_id", "order_id"):
|
|
return "BIGINT"
|
|
|
|
# 金额字段
|
|
if any(x in fn for x in ("_money", "_amount", "_price", "_cost", "_discount", "_balance",
|
|
"_deduct", "_fee", "_charge", "money", "amount", "price")):
|
|
return "NUMERIC(18,2)"
|
|
|
|
# 时间字段
|
|
if any(x in fn for x in ("_time", "time", "_date", "date")) or fn.startswith("create") or fn.startswith("update"):
|
|
return "TIMESTAMP"
|
|
|
|
# 布尔/状态字段
|
|
if fn.startswith("is_") or fn.startswith("can_") or fn.startswith("able_"):
|
|
return "INTEGER"
|
|
|
|
# 数量/计数字段
|
|
if any(x in fn for x in ("_count", "_num", "_seconds", "_minutes", "count", "num", "seconds")):
|
|
return "INTEGER"
|
|
|
|
# 比率/折扣率
|
|
if any(x in fn for x in ("_radio", "_ratio", "_rate")):
|
|
return "NUMERIC(10,4)"
|
|
|
|
# 根据样本值推断
|
|
if sample_value is not None:
|
|
if isinstance(sample_value, bool):
|
|
return "BOOLEAN"
|
|
if isinstance(sample_value, int):
|
|
if sample_value > 2147483647 or sample_value < -2147483648:
|
|
return "BIGINT"
|
|
return "INTEGER"
|
|
if isinstance(sample_value, float):
|
|
return "NUMERIC(18,2)"
|
|
if isinstance(sample_value, (list, dict)):
|
|
return "JSONB"
|
|
|
|
# 默认文本
|
|
return "TEXT"
|
|
|
|
|
|
def get_db_table_columns(db: DatabaseConnection, table_name: str) -> set:
    """Return the lower-cased column names of *table_name*.

    *table_name* may be schema-qualified (``schema.table``); when no schema
    prefix is present, ``public`` is assumed.
    """
    if "." in table_name:
        schema, name = table_name.split(".", 1)
    else:
        schema, name = "public", table_name
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
    """
    return {row["column_name"].lower() for row in db.query(sql, (schema, name))}
def get_api_fields_from_comparison(comparison_file: Path) -> dict:
    """Load the API-vs-ODS comparison data from *comparison_file*.

    Returns an empty dict when the file does not exist.
    """
    if not comparison_file.exists():
        return {}
    return json.loads(comparison_file.read_text(encoding="utf-8"))
def generate_ddl_for_missing_fields(table_name: str, missing_fields: list, api_data: dict = None) -> list:
    """Build one ``ALTER TABLE ... ADD COLUMN IF NOT EXISTS`` statement per field.

    When *api_data* is given, up to ten records of ``api_data["data"]`` are
    scanned for a sample value so the type inference can use it.
    """
    statements = []
    for field in missing_fields:
        sample = None
        if api_data:
            # First occurrence of the field among the first 10 records, if any.
            sample = next(
                (rec[field]
                 for rec in api_data.get("data", [])[:10]
                 if isinstance(rec, dict) and field in rec),
                None,
            )
        col_type = infer_column_type(field, sample)
        statements.append(
            f'ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS "{field}" {col_type};'
        )
    return statements
def main():
    """Detect and add ODS columns missing for each task in the comparison file.

    Reads ``api_ods_comparison.json`` next to this script, filters out
    ignored fields, re-checks against the live ``information_schema``,
    executes one DDL statement per truly-missing column (each in its own
    transaction) and writes a JSON execution log.

    Returns:
        bool: True when every generated DDL executed successfully,
        False on configuration errors or any DDL failure.
    """
    print("=" * 80)
    print("API → ODS 字段同步脚本")
    print("时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print("=" * 80)

    # Connect to the database.
    dsn = os.getenv("PG_DSN")
    if not dsn:
        print("[错误] 未找到 PG_DSN 环境变量")
        # FIX: was a bare ``return`` (None); the caller builds the process
        # exit code from this value, so return an explicit bool.
        return False

    db = DatabaseConnection(dsn)

    # Load the comparison data produced by the diff step.
    comparison_file = Path(__file__).parent / "api_ods_comparison.json"
    comparison = get_api_fields_from_comparison(comparison_file)

    if not comparison:
        print("[错误] 未找到对比文件 api_ods_comparison.json")
        db.close()
        return False  # FIX: was a bare ``return`` (None)

    all_ddl = []
    executed_ddl = []
    failed_ddl = []

    try:
        for task_code, data in comparison.items():
            table_name = data.get("table_name")
            missing = data.get("missing_in_ods", [])

            if not table_name or not missing:
                continue

            # Drop nested-object / bookkeeping fields that must not become columns.
            filtered_missing = [
                f for f in missing
                if f.lower() not in IGNORED_FIELDS
            ]

            if not filtered_missing:
                continue

            # Re-check against the live schema: the comparison file may be stale.
            current_cols = get_db_table_columns(db, table_name)
            truly_missing = [
                f for f in filtered_missing
                if f.lower() not in current_cols
            ]

            if not truly_missing:
                print(f"\n【{task_code}】({table_name})")
                print(f" 所有缺失字段已在数据库中存在,跳过")
                continue

            print(f"\n【{task_code}】({table_name})")
            print(f" 需要添加 {len(truly_missing)} 列: {', '.join(truly_missing)}")

            # Generate the DDL for this table.
            ddl_list = generate_ddl_for_missing_fields(table_name, truly_missing)
            all_ddl.extend(ddl_list)

            # Execute each DDL in its own transaction so one failure does not
            # roll back (or block) the others.
            for ddl in ddl_list:
                try:
                    db.execute(ddl)
                    db.commit()
                    executed_ddl.append(ddl)
                    print(f" [成功] {ddl[:80]}...")
                except Exception as e:
                    db.rollback()
                    failed_ddl.append((ddl, str(e)))
                    print(f" [失败] {ddl[:60]}... - {e}")
    finally:
        # FIX: previously the connection was closed only on the happy path;
        # an exception escaping the loop leaked it.
        db.close()

    # Summary.
    print("\n" + "=" * 80)
    print("执行汇总")
    print("=" * 80)
    print(f"总计生成 DDL: {len(all_ddl)} 条")
    print(f"执行成功: {len(executed_ddl)} 条")
    print(f"执行失败: {len(failed_ddl)} 条")

    if failed_ddl:
        print("\n失败的 DDL:")
        for ddl, err in failed_ddl:
            print(f" - {ddl}")
            print(f" 错误: {err}")

    # Persist the execution log next to this script.
    log_file = Path(__file__).parent / "sync_ods_columns_log.json"
    log = {
        "executed_at": datetime.now().isoformat(),
        "total_ddl": len(all_ddl),
        "success_count": len(executed_ddl),
        "failed_count": len(failed_ddl),
        "executed_ddl": executed_ddl,
        "failed_ddl": [{"ddl": d, "error": e} for d, e in failed_ddl],
    }
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log, f, ensure_ascii=False, indent=2)

    print(f"\n执行日志已保存到: {log_file}")

    return len(failed_ddl) == 0
if __name__ == "__main__":
    # Exit 0 only when every DDL statement executed successfully.
    sys.exit(0 if main() else 1)