# -*- coding: utf-8 -*-
"""
Sync API fields into the ODS database tables.

1. Detect differences between API JSON fields and ODS table columns
   (read from the pre-computed ``api_ods_comparison.json`` next to this script).
2. Generate and execute DDL that adds the missing columns.
3. Ignore siteProfile-style nested-object fields and system bookkeeping fields.
"""
import json
import os
import sys
from datetime import datetime
from pathlib import Path

# Make the etl_billiards project importable (``database.connection`` lives there).
project_root = Path(__file__).parent.parent / "etl_billiards"
sys.path.insert(0, str(project_root))

from dotenv import load_dotenv

load_dotenv(project_root / ".env")

from database.connection import DatabaseConnection

# Field names that are never added as columns. Entries are lowercase because
# candidate fields are compared via ``field.lower()``.
IGNORED_FIELDS = {
    # Fields embedded in siteProfile
    "siteprofile",
    "address",
    "avatar",
    "business_tel",
    "customer_service_qrcode",
    "customer_service_wechat",
    "fixed_pay_qrcode",
    "full_address",
    "latitude",
    "longitude",
    "light_status",
    "light_token",
    "light_type",
    "org_id",
    "prod_env",
    "shop_name",
    "shop_status",
    "site_label",
    "site_type",
    "tenant_site_region_id",
    "wifi_name",
    "wifi_password",
    "attendance_distance",
    "attendance_enabled",
    "auto_light",
    "ewelink_client_id",
    # Fields embedded in tableProfile
    "tableprofile",
    # Existing system fields
    "content_hash",
    "payload",
    "source_file",
    "source_endpoint",
    "fetched_at",
    "record_index",
}


def infer_column_type(field_name: str, sample_value=None) -> str:
    """Infer a PostgreSQL column type from a field name and an optional sample value.

    Name-based rules are checked first, in order, and the first match wins;
    the sample value is only consulted when no name rule matches.

    Args:
        field_name: API field name (matched case-insensitively).
        sample_value: Optional example value of the field taken from API data.

    Returns:
        A PostgreSQL type name: ``"BIGINT"``, ``"NUMERIC(18,2)"``,
        ``"TIMESTAMP"``, ``"INTEGER"``, ``"NUMERIC(10,4)"``, ``"BOOLEAN"``,
        ``"JSONB"`` or ``"TEXT"`` (the fallback).
    """
    fn = field_name.lower()

    # ID fields
    if fn.endswith("_id") or fn in ("id", "tenant_id", "member_id", "site_id",
                                    "table_id", "operator_id", "relate_id", "order_id"):
        return "BIGINT"

    # Monetary fields
    if any(x in fn for x in ("_money", "_amount", "_price", "_cost", "_discount",
                             "_balance", "_deduct", "_fee", "_charge",
                             "money", "amount", "price")):
        return "NUMERIC(18,2)"

    # Time fields.
    # NOTE(review): plain substring match, so names like "timeout" also land here.
    if any(x in fn for x in ("_time", "time", "_date", "date")) or fn.startswith(("create", "update")):
        return "TIMESTAMP"

    # Boolean/status flags, deliberately stored as INTEGER rather than BOOLEAN.
    if fn.startswith(("is_", "can_", "able_")):
        return "INTEGER"

    # Quantity/count fields.
    # NOTE(review): "num" is a substring match, so e.g. "phone_number" becomes
    # INTEGER; confirm that is intended.
    if any(x in fn for x in ("_count", "_num", "_seconds", "_minutes", "count", "num", "seconds")):
        return "INTEGER"

    # Ratios / discount rates ("_radio" presumably mirrors an upstream spelling
    # of "ratio"; kept as-is).
    if any(x in fn for x in ("_radio", "_ratio", "_rate")):
        return "NUMERIC(10,4)"

    # Fall back to the sample value's Python type.
    if sample_value is not None:
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(sample_value, bool):
            return "BOOLEAN"
        if isinstance(sample_value, int):
            # Values outside the signed 32-bit range do not fit INTEGER.
            if sample_value > 2147483647 or sample_value < -2147483648:
                return "BIGINT"
            return "INTEGER"
        if isinstance(sample_value, float):
            return "NUMERIC(18,2)"
        if isinstance(sample_value, (list, dict)):
            return "JSONB"

    # Default: text
    return "TEXT"


def get_db_table_columns(db: DatabaseConnection, table_name: str) -> set:
    """Return the lowercase column names of *table_name*.

    Args:
        db: Open database connection exposing ``query(sql, params)``.
        table_name: ``"schema.table"``, or a bare table name resolved in ``public``.

    Returns:
        Set of lowercase column names; empty if information_schema has no
        match for the table (the schema/table name is matched exactly).
    """
    schema, name = table_name.split(".", 1) if "." in table_name else ("public", table_name)
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
    """
    rows = db.query(sql, (schema, name))
    return {r["column_name"].lower() for r in rows}


def get_api_fields_from_comparison(comparison_file: Path) -> dict:
    """Load the API-vs-ODS comparison JSON.

    Returns:
        The parsed JSON object, or ``{}`` when *comparison_file* does not exist.
    """
    if not comparison_file.exists():
        return {}
    with open(comparison_file, "r", encoding="utf-8") as f:
        return json.load(f)


def _quote_ident(name: str) -> str:
    """Quote *name* as a PostgreSQL identifier, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


def generate_ddl_for_missing_fields(table_name: str, missing_fields: list, api_data: dict = None) -> list:
    """Build one ``ALTER TABLE ... ADD COLUMN IF NOT EXISTS`` statement per field.

    Args:
        table_name: Target table, interpolated into the DDL as-is (unquoted).
        missing_fields: Field names to add; each is emitted as a quoted identifier,
            so its case is preserved.
        api_data: Optional API payload; the first 10 entries of its ``"data"`` list
            are scanned for a non-None sample value to help type inference.

    Returns:
        List of DDL strings, in the same order as *missing_fields*.
    """
    sample_records = (api_data.get("data") or [])[:10] if api_data else []
    ddl_list = []
    for field in missing_fields:
        # Skip None values so a null in an early record does not hide a real
        # sample from a later one.
        sample_value = next(
            (
                record[field]
                for record in sample_records
                if isinstance(record, dict) and record.get(field) is not None
            ),
            None,
        )
        col_type = infer_column_type(field, sample_value)
        ddl_list.append(
            f"ALTER TABLE {table_name} ADD COLUMN IF NOT EXISTS {_quote_ident(field)} {col_type};"
        )
    return ddl_list


def _print_summary(all_ddl: list, executed_ddl: list, failed_ddl: list) -> None:
    """Print counts of generated/succeeded/failed DDL and details of each failure."""
    print("\n" + "=" * 80)
    print("执行汇总")
    print("=" * 80)
    print(f"总计生成 DDL: {len(all_ddl)} 条")
    print(f"执行成功: {len(executed_ddl)} 条")
    print(f"执行失败: {len(failed_ddl)} 条")

    if failed_ddl:
        print("\n失败的 DDL:")
        for ddl, err in failed_ddl:
            print(f" - {ddl}")
            print(f" 错误: {err}")


def _write_log(log_file: Path, all_ddl: list, executed_ddl: list, failed_ddl: list) -> None:
    """Write the execution log for this run as UTF-8 JSON to *log_file*."""
    log = {
        "executed_at": datetime.now().isoformat(),
        "total_ddl": len(all_ddl),
        "success_count": len(executed_ddl),
        "failed_count": len(failed_ddl),
        "executed_ddl": executed_ddl,
        "failed_ddl": [{"ddl": d, "error": e} for d, e in failed_ddl],
    }
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log, f, ensure_ascii=False, indent=2)


def main():
    """Add ODS columns for API fields reported missing in the comparison file.

    Each DDL statement is committed individually; a failing statement is rolled
    back and recorded without stopping the run.

    Returns:
        True when every generated DDL statement succeeded; False when any failed,
        or when PG_DSN or the comparison file is missing.
    """
    print("=" * 80)
    print("API → ODS 字段同步脚本")
    print("时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print("=" * 80)

    # Connect to the database
    dsn = os.getenv("PG_DSN")
    if not dsn:
        print("[错误] 未找到 PG_DSN 环境变量")
        return False

    db = DatabaseConnection(dsn)
    all_ddl = []
    executed_ddl = []
    failed_ddl = []

    # try/finally: release the connection even if a metadata query or a
    # malformed comparison entry raises.
    try:
        comparison_file = Path(__file__).parent / "api_ods_comparison.json"
        comparison = get_api_fields_from_comparison(comparison_file)
        if not comparison:
            print("[错误] 未找到对比文件 api_ods_comparison.json")
            return False

        for task_code, data in comparison.items():
            if not isinstance(data, dict):
                continue
            table_name = data.get("table_name")
            missing = data.get("missing_in_ods", [])
            if not table_name or not missing:
                continue

            # Drop nested-object and system fields.
            filtered_missing = [f for f in missing if f.lower() not in IGNORED_FIELDS]
            if not filtered_missing:
                continue

            # Second pass: the comparison file may be stale, so re-check the live table.
            current_cols = get_db_table_columns(db, table_name)
            truly_missing = [f for f in filtered_missing if f.lower() not in current_cols]

            print(f"\n【{task_code}】({table_name})")
            if not truly_missing:
                print(" 所有缺失字段已在数据库中存在,跳过")
                continue
            print(f" 需要添加 {len(truly_missing)} 列: {', '.join(truly_missing)}")

            ddl_list = generate_ddl_for_missing_fields(table_name, truly_missing)
            all_ddl.extend(ddl_list)

            # Commit per statement so one failure does not undo earlier successes.
            for ddl in ddl_list:
                try:
                    db.execute(ddl)
                    db.commit()
                except Exception as e:
                    db.rollback()
                    failed_ddl.append((ddl, str(e)))
                    print(f" [失败] {ddl[:60]}... - {e}")
                else:
                    executed_ddl.append(ddl)
                    print(f" [成功] {ddl[:80]}...")
    finally:
        db.close()

    _print_summary(all_ddl, executed_ddl, failed_ddl)

    # Save the execution log
    log_file = Path(__file__).parent / "sync_ods_columns_log.json"
    _write_log(log_file, all_ddl, executed_ddl, failed_ddl)
    print(f"\n执行日志已保存到: {log_file}")

    return not failed_ddl


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)