164 lines
5.8 KiB
Python
164 lines
5.8 KiB
Python
# -*- coding: utf-8 -*-
"""Fix duplicate-key problems caused by partial imports during migration:
TRUNCATE the affected target tables first, then re-COPY from the source."""

import sys
import io

import psycopg2

# Force UTF-8 console output on Windows so non-ASCII log messages
# don't raise UnicodeEncodeError on legacy code pages.
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")

# Connection settings (source and target live on the same server).
# NOTE(review): credentials are hard-coded — consider env vars or .pgpass.
DB_HOST = "100.64.0.4"
DB_PORT = 5432
DB_USER = "local-Python"
DB_PASS = "Neo-local-1991125"
OLD_DB = "LLZQ-test"    # migration source database
NEW_DB = "etl_feiqiu"   # migration target database

# Maps each source schema name to its renamed schema in the target DB.
SCHEMA_MAP = {
    "billiards_ods": "ods",
    "billiards_dwd": "dwd",
    "billiards_dws": "dws",
    "etl_admin": "meta",
}
|
|
|
|
def get_columns(conn, schema, table):
    """Return the column names of schema.table in ordinal order ([] if the table is absent)."""
    query = """
        SELECT column_name FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(query, (schema, table))
        rows = cur.fetchall()
    return [row[0] for row in rows]
|
|
|
|
def count_rows(conn, schema, table):
    """Return SELECT COUNT(*) for the double-quoted schema.table on *conn*."""
    sql = f'SELECT COUNT(*) FROM "{schema}"."{table}"'
    with conn.cursor() as cur:
        cur.execute(sql)
        row = cur.fetchone()
    return row[0]
|
|
|
|
def _list_tables(conn, schema):
    # Return every table name in *schema*, alphabetically.
    with conn.cursor() as cur:
        cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = %s ORDER BY tablename", (schema,))
        return [r[0] for r in cur.fetchall()]


def _find_mismatched(src, dst):
    # Compare row counts table-by-table. Returns tuples
    # (old_schema, new_schema, table, src_count, dst_count) for every
    # non-empty source table whose target row count differs.
    mismatched = []
    for old_s, new_s in SCHEMA_MAP.items():
        for t in _list_tables(src, old_s):
            s_cnt = count_rows(src, old_s, t)
            dst_cols = get_columns(dst, new_s, t)
            if not dst_cols:
                # Target table missing; reported later by _verify().
                continue
            d_cnt = count_rows(dst, new_s, t)
            if s_cnt > 0 and d_cnt != s_cnt:
                mismatched.append((old_s, new_s, t, s_cnt, d_cnt))
    return mismatched


def _repair_table(src, dst, old_s, new_s, t, s_cnt):
    # TRUNCATE the target table, then re-COPY all rows (common columns only),
    # so a previous partial import cannot cause duplicate-key errors.
    print(f"\n修复 {new_s}.{t} ...")
    src_cols = get_columns(src, old_s, t)
    dst_cols = get_columns(dst, new_s, t)
    common = [c for c in dst_cols if c in src_cols]
    cols_sql = ", ".join(f'"{c}"' for c in common)

    # CASCADE also clears dependent FK rows in other tables.
    with dst.cursor() as cur:
        cur.execute(f'TRUNCATE "{new_s}"."{t}" CASCADE')
    dst.commit()
    print(" TRUNCATE 完成")

    # Binary-format COPY through an in-memory buffer.
    # NOTE(review): the whole table is buffered in RAM; for very large tables
    # a tempfile.SpooledTemporaryFile would bound peak memory.
    buf = io.BytesIO()
    with src.cursor() as cur:
        cur.copy_expert(
            f'COPY (SELECT {cols_sql} FROM "{old_s}"."{t}") TO STDOUT WITH (FORMAT binary)', buf)
    buf.seek(0)
    with dst.cursor() as cur:
        cur.copy_expert(
            f'COPY "{new_s}"."{t}" ({cols_sql}) FROM STDIN WITH (FORMAT binary)', buf)
    dst.commit()

    final = count_rows(dst, new_s, t)
    status = "OK" if final == s_cnt else "MISMATCH"
    print(f" 导入完成: {final} 行 ({status})")


def _migrate_indexes(src, dst):
    # Recreate source-schema indexes in the target, skipping indexes that
    # back constraints (those travel with the table/constraint DDL).
    print("\n迁移索引...")
    for old_s, new_s in SCHEMA_MAP.items():
        with src.cursor() as cur:
            cur.execute("""
                SELECT indexname, indexdef FROM pg_indexes
                WHERE schemaname = %s
                AND indexname NOT IN (
                    SELECT conname FROM pg_constraint
                    WHERE connamespace = (SELECT oid FROM pg_namespace WHERE nspname = %s))
                ORDER BY indexname
            """, (old_s, old_s))
            indexes = cur.fetchall()

        created = 0
        for idx_name, idx_def in indexes:
            # Point the DDL at the renamed schema, then make it idempotent.
            # ("CREATE UNIQUE INDEX" does not contain the substring
            # "CREATE INDEX", so the two replaces cannot double-fire.)
            new_def = idx_def.replace(f'"{old_s}"', f'"{new_s}"').replace(f'{old_s}.', f'{new_s}.')
            new_def = new_def.replace("CREATE INDEX", "CREATE INDEX IF NOT EXISTS", 1)
            new_def = new_def.replace("CREATE UNIQUE INDEX", "CREATE UNIQUE INDEX IF NOT EXISTS", 1)
            try:
                with dst.cursor() as cur:
                    cur.execute(new_def)
                dst.commit()
                created += 1
            except Exception as e:
                # One bad index must not abort the run: roll back and report.
                dst.rollback()
                print(f" 索引失败 {idx_name}: {e}")
        print(f" {old_s} -> {new_s}: {created}/{len(indexes)} 索引")


def _verify(src, dst):
    # Final pass: every non-empty source table must exist in the target
    # with a matching row count.
    print("\n最终验证:")
    all_ok = True
    for old_s, new_s in SCHEMA_MAP.items():
        for t in _list_tables(src, old_s):
            s_cnt = count_rows(src, old_s, t)
            if s_cnt == 0:
                continue
            dst_cols = get_columns(dst, new_s, t)
            if not dst_cols:
                print(f" MISS {new_s}.{t}: 目标表不存在")
                all_ok = False
                continue
            d_cnt = count_rows(dst, new_s, t)
            if d_cnt != s_cnt:
                print(f" FAIL {new_s}.{t}: 源={s_cnt} 目标={d_cnt}")
                all_ok = False
    if all_ok:
        print(" 全部一致 OK")


def main():
    """Detect partially-imported tables (row-count mismatch vs. source),
    repair each via TRUNCATE + binary COPY, migrate standalone indexes,
    refresh planner statistics, then verify everything matches."""
    src = psycopg2.connect(host=DB_HOST, port=DB_PORT, dbname=OLD_DB, user=DB_USER, password=DB_PASS,
                           options="-c client_encoding=UTF8")
    dst = psycopg2.connect(host=DB_HOST, port=DB_PORT, dbname=NEW_DB, user=DB_USER, password=DB_PASS,
                           options="-c client_encoding=UTF8")
    try:
        mismatched = _find_mismatched(src, dst)
        if not mismatched:
            print("所有表行数一致,无需修复。")
            # Still refresh indexes / statistics below.
        else:
            print(f"发现 {len(mismatched)} 个不一致表:")
            for old_s, new_s, t, s_cnt, d_cnt in mismatched:
                print(f" {old_s}.{t}: 源={s_cnt} 目标={d_cnt}")

        for old_s, new_s, t, s_cnt, d_cnt in mismatched:
            _repair_table(src, dst, old_s, new_s, t, s_cnt)

        _migrate_indexes(src, dst)

        # BUGFIX: the old code looped over SCHEMA_MAP and ran the
        # *database-wide* ANALYZE once per schema (4x redundant work),
        # with a pointless dummy get_columns() call each iteration.
        # A single ANALYZE covers the whole target database.
        print("\n执行 ANALYZE...")
        dst.autocommit = True  # ANALYZE should not run inside a transaction block
        with dst.cursor() as cur:
            cur.execute("ANALYZE")
        print("ANALYZE 完成")

        _verify(src, dst)
    finally:
        # BUGFIX: connections were leaked if any step above raised.
        src.close()
        dst.close()
|
|
|
|
# Run the repair workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|