# -*- coding: utf-8 -*-
"""Repair duplicate-key failures caused by a partially-completed migration.

For every table whose row count differs between the old and the new
database, TRUNCATE the target table and re-COPY it from the source
(binary COPY of the common columns), then re-create non-constraint
indexes, refresh planner statistics with ANALYZE, and verify all row
counts one final time.
"""
import io
import sys

import psycopg2

# Windows consoles default to a legacy code page; force UTF-8 so the
# Chinese status messages below do not raise UnicodeEncodeError.
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")

# Connection settings.
# NOTE(review): credentials are hard-coded; consider moving them to
# environment variables before this script is shared or committed.
DB_HOST = "100.64.0.4"
DB_PORT = 5432
DB_USER = "local-Python"
DB_PASS = "Neo-local-1991125"
OLD_DB = "LLZQ-test"
NEW_DB = "etl_feiqiu"

# Maps source-database schema -> target-database schema.
SCHEMA_MAP = {
    "billiards_ods": "ods",
    "billiards_dwd": "dwd",
    "billiards_dws": "dws",
    "etl_admin": "meta",
}


def get_columns(conn, schema, table):
    """Return the column names of ``schema.table`` in ordinal order.

    Returns an empty list when the table does not exist, which callers
    use as an existence check.
    """
    with conn.cursor() as cur:
        cur.execute("""
            SELECT column_name FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            ORDER BY ordinal_position
        """, (schema, table))
        return [r[0] for r in cur.fetchall()]


def count_rows(conn, schema, table):
    """Return the exact row count of ``schema.table``."""
    with conn.cursor() as cur:
        cur.execute(f'SELECT COUNT(*) FROM "{schema}"."{table}"')
        return cur.fetchone()[0]


def _list_tables(conn, schema):
    """Return all table names in *schema*, sorted by name."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT tablename FROM pg_tables WHERE schemaname = %s ORDER BY tablename",
            (schema,))
        return [r[0] for r in cur.fetchall()]


def _find_mismatched(src, dst):
    """Compare row counts of every mapped table between *src* and *dst*.

    Returns a list of ``(old_schema, new_schema, table, src_count,
    dst_count)`` tuples for non-empty source tables whose counts differ.
    Tables missing on the target are skipped here; the final
    verification step reports them.
    """
    mismatched = []
    for old_s, new_s in SCHEMA_MAP.items():
        for t in _list_tables(src, old_s):
            s_cnt = count_rows(src, old_s, t)
            if not get_columns(dst, new_s, t):
                continue  # target table missing — reported by _final_verify
            d_cnt = count_rows(dst, new_s, t)
            if s_cnt > 0 and d_cnt != s_cnt:
                mismatched.append((old_s, new_s, t, s_cnt, d_cnt))
    return mismatched


def _repair_table(src, dst, old_s, new_s, t, s_cnt):
    """TRUNCATE the target table and re-COPY it from the source.

    Only columns present in BOTH tables are copied, using PostgreSQL
    binary COPY.  *s_cnt* is the expected source row count, used for
    the post-import status line.
    """
    src_cols = get_columns(src, old_s, t)
    dst_cols = get_columns(dst, new_s, t)
    common = [c for c in dst_cols if c in src_cols]
    cols_sql = ", ".join(f'"{c}"' for c in common)

    # CASCADE: dependent FK tables are emptied too, so the re-import
    # cannot hit duplicate keys.
    with dst.cursor() as cur:
        cur.execute(f'TRUNCATE "{new_s}"."{t}" CASCADE')
    dst.commit()
    print(" TRUNCATE 完成")

    # NOTE(review): the whole table is buffered in memory; fine for
    # moderate sizes, consider a streaming pipe for very large tables.
    buf = io.BytesIO()
    with src.cursor() as cur:
        cur.copy_expert(
            f'COPY (SELECT {cols_sql} FROM "{old_s}"."{t}") TO STDOUT WITH (FORMAT binary)',
            buf)
    buf.seek(0)
    with dst.cursor() as cur:
        cur.copy_expert(
            f'COPY "{new_s}"."{t}" ({cols_sql}) FROM STDIN WITH (FORMAT binary)',
            buf)
    dst.commit()

    final = count_rows(dst, new_s, t)
    status = "OK" if final == s_cnt else "MISMATCH"
    print(f" 导入完成: {final} 行 ({status})")


def _migrate_indexes(src, dst):
    """Re-create non-constraint indexes from each source schema on the target.

    Constraint-backed indexes (PK/UNIQUE created via constraints) are
    excluded; schema names inside the index definition are rewritten,
    and ``IF NOT EXISTS`` makes re-runs idempotent.  Failures are
    logged per index and rolled back without aborting the run.
    """
    for old_s, new_s in SCHEMA_MAP.items():
        with src.cursor() as cur:
            cur.execute("""
                SELECT indexname, indexdef FROM pg_indexes
                WHERE schemaname = %s AND indexname NOT IN (
                    SELECT conname FROM pg_constraint
                    WHERE connamespace = (SELECT oid FROM pg_namespace WHERE nspname = %s))
                ORDER BY indexname
            """, (old_s, old_s))
            indexes = cur.fetchall()
        created = 0
        for idx_name, idx_def in indexes:
            # Point the definition at the target schema; pg_indexes may
            # emit the schema either quoted or as a bare prefix.
            new_def = (idx_def
                       .replace(f'"{old_s}"', f'"{new_s}"')
                       .replace(f'{old_s}.', f'{new_s}.'))
            new_def = new_def.replace("CREATE INDEX", "CREATE INDEX IF NOT EXISTS", 1)
            new_def = new_def.replace("CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
            try:
                with dst.cursor() as cur:
                    cur.execute(new_def)
                dst.commit()
                created += 1
            except Exception as e:
                dst.rollback()
                print(f" 索引失败 {idx_name}: {e}")
        print(f" {old_s} -> {new_s}: {created}/{len(indexes)} 索引")


def _run_analyze(dst):
    """Refresh planner statistics for the whole target database.

    A single database-wide ANALYZE covers every schema, so no per-schema
    loop is needed.
    """
    dst.autocommit = True
    with dst.cursor() as cur:
        cur.execute("ANALYZE")
    print("ANALYZE 完成")


def _final_verify(src, dst):
    """Re-check every non-empty source table against the target.

    Prints a MISS/FAIL line per problem table and returns ``True`` when
    everything matches.
    """
    all_ok = True
    for old_s, new_s in SCHEMA_MAP.items():
        for t in _list_tables(src, old_s):
            s_cnt = count_rows(src, old_s, t)
            if s_cnt == 0:
                continue
            if not get_columns(dst, new_s, t):
                print(f" MISS {new_s}.{t}: 目标表不存在")
                all_ok = False
                continue
            d_cnt = count_rows(dst, new_s, t)
            if d_cnt != s_cnt:
                print(f" FAIL {new_s}.{t}: 源={s_cnt} 目标={d_cnt}")
                all_ok = False
    return all_ok


def main():
    """Detect mismatched tables, repair them, migrate indexes, ANALYZE, verify."""
    src = psycopg2.connect(host=DB_HOST, port=DB_PORT, dbname=OLD_DB,
                           user=DB_USER, password=DB_PASS,
                           options="-c client_encoding=UTF8")
    dst = psycopg2.connect(host=DB_HOST, port=DB_PORT, dbname=NEW_DB,
                           user=DB_USER, password=DB_PASS,
                           options="-c client_encoding=UTF8")
    try:
        mismatched = _find_mismatched(src, dst)
        if not mismatched:
            print("所有表行数一致,无需修复。")  # still migrate indexes / ANALYZE below
        else:
            print(f"发现 {len(mismatched)} 个不一致表:")
            for old_s, new_s, t, s_cnt, d_cnt in mismatched:
                print(f" {old_s}.{t}: 源={s_cnt} 目标={d_cnt}")
            for old_s, new_s, t, s_cnt, d_cnt in mismatched:
                print(f"\n修复 {new_s}.{t} ...")
                _repair_table(src, dst, old_s, new_s, t, s_cnt)

        print("\n迁移索引...")
        _migrate_indexes(src, dst)

        print("\n执行 ANALYZE...")
        _run_analyze(dst)

        print("\n最终验证:")
        if _final_verify(src, dst):
            print(" 全部一致 OK")
    finally:
        # Close both connections even when a step above raises.
        src.close()
        dst.close()


if __name__ == "__main__":
    main()