#!/usr/bin/env python3
"""DDL 与数据库实际表结构对比脚本。
# AI_CHANGELOG [2026-02-13] 修复列名以 UNIQUE/CHECK 开头被误判为约束行的 bug新增 CREATE VIEW 解析支持(视图仅检查存在性)
解析 database/schema_*.sql 中的 CREATE TABLE 语句,
查询 information_schema.columns 获取数据库实际结构,
逐表逐字段对比并输出差异报告。
用法:
python scripts/compare_ddl_db.py --pg-dsn "postgresql://..." --schema ods --ddl-path database/schema_ODS_doc.sql
python scripts/compare_ddl_db.py --schema dwd --ddl-path database/schema_dwd_doc.sql # 从 .env 读取 PG_DSN
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional
class DiffKind(str, Enum):
    """Classification of a single DDL-vs-database difference."""
    MISSING_TABLE = "MISSING_TABLE"          # table exists in the DB but not in the DDL
    EXTRA_TABLE = "EXTRA_TABLE"              # table exists in the DDL but not in the DB
    MISSING_COLUMN = "MISSING_COLUMN"        # column exists in the DB but not in the DDL
    EXTRA_COLUMN = "EXTRA_COLUMN"            # column exists in the DDL but not in the DB
    TYPE_MISMATCH = "TYPE_MISMATCH"          # column data types disagree
    NULLABLE_MISMATCH = "NULLABLE_MISMATCH"  # NULL / NOT NULL constraint disagrees
@dataclass
class SchemaDiff:
    """One concrete difference between the DDL and the live database."""
    kind: DiffKind
    table: str
    column: Optional[str] = None
    ddl_value: Optional[str] = None
    db_value: Optional[str] = None

    def __str__(self) -> str:
        text = f"[{self.kind.value}] {self.table}"
        if self.column:
            text += f".{self.column}"
        if not (self.ddl_value is None and self.db_value is None):
            text += f" DDL={self.ddl_value} DB={self.db_value}"
        return text
# ---------------------------------------------------------------------------
# DDL 列定义
# ---------------------------------------------------------------------------
@dataclass
class ColumnDef:
    """A single column definition parsed from DDL."""
    name: str                      # lower-cased column name
    data_type: str                 # type string after normalize_type()
    nullable: bool = True          # False when NOT NULL (or part of a primary key)
    is_pk: bool = False            # True for inline or table-level PRIMARY KEY members
    default: Optional[str] = None  # raw DEFAULT expression text, if present
@dataclass
class TableDef:
    """A single table (or view) definition, parsed from DDL or the database."""
    name: str  # table name without schema prefix, lower-cased
    columns: dict[str, ColumnDef] = field(default_factory=dict)  # keyed by lower-cased column name
    pk_columns: list[str] = field(default_factory=list)          # primary-key columns
    is_view: bool = False  # views are only checked for existence, never column-by-column
# ---------------------------------------------------------------------------
# 类型标准化:将 DDL 类型和 information_schema 类型映射到统一表示
# ---------------------------------------------------------------------------
# PostgreSQL information_schema.data_type → short-form mapping.
# Keys are looked up AFTER lower-casing, so every key must be lower-case.
# (The former "ARRAY" / "USER-DEFINED" keys were unreachable for that reason.)
_PG_TYPE_MAP: dict[str, str] = {
    "bigint": "bigint",
    "integer": "integer",
    "smallint": "smallint",
    "boolean": "boolean",
    "text": "text",
    "jsonb": "jsonb",
    "json": "json",
    "date": "date",
    "bytea": "bytea",
    "double precision": "double precision",
    "real": "real",
    "uuid": "uuid",
    "timestamp without time zone": "timestamp",
    "timestamp with time zone": "timestamptz",
    "time without time zone": "time",
    "time with time zone": "timetz",
    "character varying": "varchar",
    "character": "char",
    "array": "array",
    "user-defined": "user-defined",
}


def normalize_type(raw: str) -> str:
    """Normalize a DDL or information_schema type string for comparison.

    Rules:
    - lower-case everything, collapse internal whitespace
    - SERIAL / BIGSERIAL / SMALLSERIAL → underlying integer type
    - INT / INT4 → integer, INT8 → bigint, INT2 → smallint
    - FLOAT / FLOAT8 → double precision, FLOAT4 → real,
      FLOAT(p) → real when p <= 24, otherwise double precision
    - VARCHAR(n) / CHARACTER VARYING(n) → varchar(n)
    - CHAR(n) / CHARACTER(n) → char(n); bare char → char(1)
    - NUMERIC(p,s) / DECIMAL(p,s) → numeric(p,s)
    - TIMESTAMP [WITHOUT TIME ZONE] → timestamp,
      TIMESTAMPTZ / TIMESTAMP WITH TIME ZONE → timestamptz
    - element[] → normalized element type + "[]" (so int4[] == integer[])
    - everything else: _PG_TYPE_MAP lookup, or returned unchanged
    """
    t = re.sub(r"\s+", " ", raw.strip().lower())
    # Array types: normalize the element type, keep the [] suffix so the
    # DB-side "int4[]" matches the DDL-side "INTEGER[]".
    if t.endswith("[]"):
        return normalize_type(t[:-2]) + "[]"
    # serial family → underlying integer type (serial is integer + sequence)
    if t == "bigserial":
        return "bigint"
    if t in ("serial", "serial4"):
        return "integer"
    if t == "smallserial":
        return "smallint"
    # numeric / decimal with precision
    m = re.match(r"(?:numeric|decimal)\s*\((\d+)\s*,\s*(\d+)\)", t)
    if m:
        return f"numeric({m.group(1)},{m.group(2)})"
    m = re.match(r"(?:numeric|decimal)\s*\((\d+)\)", t)
    if m:
        return f"numeric({m.group(1)})"
    if t in ("numeric", "decimal"):
        return "numeric"
    # varchar / character varying
    m = re.match(r"(?:varchar|character varying)\s*\((\d+)\)", t)
    if m:
        return f"varchar({m.group(1)})"
    if t in ("varchar", "character varying"):
        return "varchar"
    # char / character (bare char means char(1) in PostgreSQL)
    m = re.match(r"(?:char|character)\s*\((\d+)\)", t)
    if m:
        return f"char({m.group(1)})"
    if t in ("char", "character"):
        return "char(1)"
    # timestamp family
    if t in ("timestamptz", "timestamp with time zone"):
        return "timestamptz"
    if t in ("timestamp", "timestamp without time zone"):
        return "timestamp"
    # float family: FLOAT(p) is real up to 24 mantissa bits, double above
    m = re.match(r"float\s*\((\d+)\)", t)
    if m:
        return "real" if int(m.group(1)) <= 24 else "double precision"
    if t in ("float", "float8"):
        return "double precision"
    if t == "float4":
        return "real"
    # integer aliases
    if t in ("int8", "bigint"):
        return "bigint"
    if t in ("int", "int4", "integer"):
        return "integer"
    if t in ("int2", "smallint"):
        return "smallint"
    # boolean
    if t in ("bool", "boolean"):
        return "boolean"
    # remaining information_schema spellings (time/timetz, etc.)
    return _PG_TYPE_MAP.get(t, t)
# ---------------------------------------------------------------------------
# DDL 解析器
# ---------------------------------------------------------------------------
# Matches: CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
# group(1) = optional schema, group(2) = table name
_CREATE_TABLE_RE = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
    r"(?:(\w+)\.)?(\w+)\s*\(",
    re.IGNORECASE,
)
# Matches: DROP TABLE [IF EXISTS] [schema.]table_name [CASCADE];
# NOTE(review): not referenced anywhere in this file — confirm before removing.
_DROP_TABLE_RE = re.compile(
    r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:\w+\.)?(\w+)",
    re.IGNORECASE,
)
# Matches: CREATE [OR REPLACE] VIEW [schema.]view_name AS SELECT ...
# group(1) = optional schema, group(2) = view name
_CREATE_VIEW_RE = re.compile(
    r"CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+"
    r"(?:(\w+)\.)?(\w+)\s+AS\s+",
    re.IGNORECASE,
)
def _strip_sql_comments(sql: str) -> str:
"""移除 SQL 单行注释(-- ...)和块注释(/* ... */)。"""
# 块注释
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
# 单行注释
sql = re.sub(r"--[^\n]*", "", sql)
return sql
def _find_matching_paren(text: str, start: int) -> int:
"""从 start 位置(应为 '(')开始,找到匹配的 ')' 位置。
处理嵌套括号和字符串字面量中的括号。
"""
depth = 0
in_string = False
string_char = ""
i = start
while i < len(text):
ch = text[i]
if in_string:
if ch == string_char:
# 检查转义
if i + 1 < len(text) and text[i + 1] == string_char:
i += 2
continue
in_string = False
else:
if ch in ("'", '"'):
in_string = True
string_char = ch
elif ch == "(":
depth += 1
elif ch == ")":
depth -= 1
if depth == 0:
return i
i += 1
return -1
def _parse_column_line(line: str) -> Optional[ColumnDef]:
    """Parse one line of a CREATE TABLE body into a ColumnDef.

    Returns None for blank lines and table-level constraint lines
    (PRIMARY KEY / UNIQUE / CHECK / FOREIGN KEY / EXCLUDE / CONSTRAINT).
    """
    line = line.strip().rstrip(",")
    if not line:
        return None
    upper = line.upper()
    # Table-level constraint lines start with a constraint keyword followed
    # by "(" or whitespace. A column NAME that merely starts with such a
    # keyword (e.g. "unique_customers INTEGER", "check_status INT")
    # continues with a word character and therefore does not match.
    if re.match(
        r"(?:PRIMARY\s+KEY|UNIQUE|CHECK|FOREIGN\s+KEY|EXCLUDE)"
        r"(?:\s*\(|\s+(?![\w]))",
        upper,
    ) or upper.startswith("CONSTRAINT"):
        return None
    # Column name (optionally double-quoted) followed by type + constraints.
    m = re.match(r'(?:"([^"]+)"|(\w+))\s+(.+)', line)
    if not m:
        return None
    col_name = (m.group(1) or m.group(2)).lower()
    rest = m.group(3).strip()
    # The type runs up to the earliest constraint keyword (if any); types may
    # themselves contain parentheses, e.g. NUMERIC(18,2) or VARCHAR(50).
    type_end_keywords = [
        "NOT NULL", "NULL", "DEFAULT", "PRIMARY KEY", "UNIQUE",
        "REFERENCES", "CHECK", "CONSTRAINT", "GENERATED",
    ]
    rest_upper = rest.upper()
    best_idx = len(rest)
    for kw in type_end_keywords:
        idx = rest_upper.find(kw)
        # idx > 0 so a keyword at position 0 cannot eat the type itself;
        # rest[:idx] is never all-whitespace because rest is stripped.
        if 0 < idx < best_idx and rest[:idx].strip():
            best_idx = idx
    type_str = rest
    constraint_part = ""
    if best_idx < len(rest):
        type_str = rest[:best_idx].strip()
        constraint_part = rest[best_idx:]
    type_str = type_str.rstrip(",").strip()
    constraint_upper = constraint_part.upper()
    nullable = "NOT NULL" not in constraint_upper
    is_pk = "PRIMARY KEY" in constraint_upper
    # Extract the DEFAULT expression: capture up to the next constraint
    # keyword, a comma, or end-of-string.
    # BUGFIX: the previous pattern required whitespace before end-of-string
    # (`\s+(?:...|$)`), so a trailing "DEFAULT 0" was never captured.
    default_val = None
    dm = re.search(
        r"DEFAULT\s+(.+?)"
        r"(?:\s+(?:NOT\s+NULL|NULL|PRIMARY|UNIQUE|REFERENCES|CHECK|CONSTRAINT)|\s*,|\s*$)",
        constraint_part,
        re.IGNORECASE,
    )
    if dm:
        default_val = dm.group(1).strip().rstrip(",")
    return ColumnDef(
        name=col_name,
        data_type=normalize_type(type_str),
        nullable=nullable,
        is_pk=is_pk,
        default=default_val,
    )
def _extract_pk_from_body(body: str) -> list[str]:
"""从 CREATE TABLE 体中提取表级 PRIMARY KEY 约束的列名列表。"""
# PRIMARY KEY (col1, col2, ...)
# 也可能是 CONSTRAINT xxx PRIMARY KEY (col1, col2)
m = re.search(r"PRIMARY\s+KEY\s*\(([^)]+)\)", body, re.IGNORECASE)
if not m:
return []
cols_str = m.group(1)
return [c.strip().strip('"').lower() for c in cols_str.split(",")]
def parse_ddl(sql_text: str, target_schema: Optional[str] = None) -> dict[str, TableDef]:
    """Parse DDL text into table definitions.

    Args:
        sql_text: full SQL DDL text.
        target_schema: when given, keep only tables qualified with this
            schema; tables WITHOUT a schema prefix are always kept (the
            DDL may rely on SET search_path).

    Returns:
        {table name (lower-cased): TableDef}. CREATE VIEW statements are
        recorded existence-only (is_view=True), their columns are not parsed.
    """
    cleaned = _strip_sql_comments(sql_text)
    wanted = target_schema.lower() if target_schema else None
    tables: dict[str, TableDef] = {}

    for match in _CREATE_TABLE_RE.finditer(cleaned):
        schema_part = match.group(1)
        name = match.group(2).lower()
        if wanted and schema_part and schema_part.lower() != wanted:
            continue  # explicitly qualified with a different schema
        open_paren = match.end() - 1  # match.end() is one past the '('
        close_paren = _find_matching_paren(cleaned, open_paren)
        if close_paren < 0:
            continue  # unbalanced parentheses: skip malformed statement
        body = cleaned[open_paren + 1: close_paren]
        table = TableDef(name=name)
        table_level_pk = _extract_pk_from_body(body)
        for raw_line in body.split("\n"):
            parsed = _parse_column_line(raw_line)
            if parsed:
                table.columns[parsed.name] = parsed
        if table_level_pk:
            # Table-level PK wins; PK membership implies NOT NULL.
            table.pk_columns = table_level_pk
            for pk in table_level_pk:
                if pk in table.columns:
                    table.columns[pk].is_pk = True
                    table.columns[pk].nullable = False
        inline_pk = [c.name for c in table.columns.values() if c.is_pk]
        if inline_pk and not table.pk_columns:
            table.pk_columns = inline_pk
            for pk in inline_pk:
                table.columns[pk].nullable = False
        tables[name] = table

    # Views: only their existence is recorded (columns come from base tables).
    for match in _CREATE_VIEW_RE.finditer(cleaned):
        schema_part = match.group(1)
        view_name = match.group(2).lower()
        if wanted and schema_part and schema_part.lower() != wanted:
            continue
        if view_name not in tables:
            tables[view_name] = TableDef(name=view_name)
        tables[view_name].is_view = True

    return tables
# ---------------------------------------------------------------------------
# 数据库 schema 读取
# ---------------------------------------------------------------------------
@dataclass
class DbColumnInfo:
    """Column metadata as queried from information_schema.

    NOTE(review): not referenced anywhere in this file — fetch_db_schema
    builds ColumnDef instances directly; confirm before removing.
    """
    name: str
    data_type: str  # type string after normalize_type()
    nullable: bool
    is_pk: bool = False
def fetch_db_schema(pg_dsn: str, schema_name: str) -> dict[str, TableDef]:
    """Query information_schema for every table and column of *schema_name*.

    Args:
        pg_dsn: PostgreSQL connection string.
        schema_name: name of the schema to inspect.

    Returns:
        {table name (lower-cased): TableDef}; empty dict when the schema
        does not exist in the database.
    """
    import psycopg2  # local import: only needed when actually hitting the DB
    conn = psycopg2.connect(pg_dsn)
    try:
        with conn.cursor() as cur:
            # Bail out early if the schema itself is missing.
            cur.execute(
                "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s",
                (schema_name,),
            )
            if not cur.fetchone():
                print(f"⚠ schema '{schema_name}' 在数据库中不存在,跳过", file=sys.stderr)
                return {}
            # All columns of all tables in the schema, in declaration order.
            cur.execute("""
                SELECT
                    c.table_name,
                    c.column_name,
                    c.data_type,
                    c.is_nullable,
                    c.character_maximum_length,
                    c.numeric_precision,
                    c.numeric_scale,
                    c.udt_name
                FROM information_schema.columns c
                WHERE c.table_schema = %s
                ORDER BY c.table_name, c.ordinal_position
            """, (schema_name,))
            rows = cur.fetchall()
            # Primary-key membership per table, in key order.
            cur.execute("""
                SELECT
                    tc.table_name,
                    kcu.column_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                    ON tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                WHERE tc.table_schema = %s
                AND tc.constraint_type = 'PRIMARY KEY'
                ORDER BY tc.table_name, kcu.ordinal_position
            """, (schema_name,))
            pk_rows = cur.fetchall()
    finally:
        conn.close()
    # Build the PK mapping: {table_name: [col1, col2, ...]}.
    pk_map: dict[str, list[str]] = {}
    for tbl, col in pk_rows:
        pk_map.setdefault(tbl.lower(), []).append(col.lower())
    # Fold the flat column rows into TableDef objects.
    tables: dict[str, TableDef] = {}
    for tbl, col_name, data_type, is_nullable, char_max_len, num_prec, num_scale, udt_name in rows:
        tbl_lower = tbl.lower()
        col_lower = col_name.lower()
        if tbl_lower not in tables:
            tables[tbl_lower] = TableDef(
                name=tbl_lower,
                pk_columns=pk_map.get(tbl_lower, []),
            )
        # Rebuild a precise comparable type string (varchar(n), numeric(p,s), ...).
        type_str = _build_db_type_string(data_type, char_max_len, num_prec, num_scale, udt_name)
        is_pk = col_lower in pk_map.get(tbl_lower, [])
        nullable = is_nullable == "YES"
        tables[tbl_lower].columns[col_lower] = ColumnDef(
            name=col_lower,
            data_type=normalize_type(type_str),
            nullable=nullable,
            is_pk=is_pk,
        )
    return tables
def _build_db_type_string(
data_type: str,
char_max_len: Optional[int],
num_prec: Optional[int],
num_scale: Optional[int],
udt_name: str,
) -> str:
"""根据 information_schema 字段构建可比较的类型字符串。"""
dt = data_type.lower()
# character varying → varchar(n)
if dt == "character varying":
if char_max_len:
return f"varchar({char_max_len})"
return "varchar"
# character → char(n)
if dt == "character":
if char_max_len:
return f"char({char_max_len})"
return "char(1)"
# numeric → numeric(p,s)
if dt == "numeric":
if num_prec is not None and num_scale is not None:
return f"numeric({num_prec},{num_scale})"
if num_prec is not None:
return f"numeric({num_prec})"
return "numeric"
# USER-DEFINED → 使用 udt_name如 jsonb, geometry 等)
if dt == "user-defined":
return udt_name.lower()
# ARRAY → 使用 udt_name 去掉前缀 _
if dt == "array":
base = udt_name.lstrip("_").lower()
return f"{base}[]"
return dt
# ---------------------------------------------------------------------------
# 对比逻辑
# ---------------------------------------------------------------------------
def compare_tables(
    ddl_tables: dict[str, TableDef],
    db_tables: dict[str, TableDef],
) -> list[SchemaDiff]:
    """Diff parsed DDL definitions against the live database structure.

    Diff categories:
    - MISSING_TABLE: present in the DB, absent from the DDL
    - EXTRA_TABLE: present in the DDL, absent from the DB
    - MISSING_COLUMN: column present in the DB, absent from the DDL
    - EXTRA_COLUMN: column present in the DDL, absent from the DB
    - TYPE_MISMATCH: column data types disagree
    - NULLABLE_MISMATCH: NULL/NOT NULL constraints disagree
    """
    diffs: list[SchemaDiff] = []
    for tbl in sorted(ddl_tables.keys() | db_tables.keys()):
        ddl_def = ddl_tables.get(tbl)
        db_def = db_tables.get(tbl)
        if ddl_def is None:
            diffs.append(SchemaDiff(kind=DiffKind.MISSING_TABLE, table=tbl))
            continue
        if db_def is None:
            diffs.append(SchemaDiff(kind=DiffKind.EXTRA_TABLE, table=tbl))
            continue
        # Both sides exist. Views are existence-only: skip column diffing.
        if getattr(ddl_def, "is_view", False):
            continue
        ddl_cols = ddl_def.columns
        db_cols = db_def.columns
        for col in sorted(ddl_cols.keys() | db_cols.keys()):
            ddl_col = ddl_cols.get(col)
            db_col = db_cols.get(col)
            if ddl_col is None:
                diffs.append(SchemaDiff(
                    kind=DiffKind.MISSING_COLUMN,
                    table=tbl,
                    column=col,
                    db_value=db_col.data_type,
                ))
                continue
            if db_col is None:
                diffs.append(SchemaDiff(
                    kind=DiffKind.EXTRA_COLUMN,
                    table=tbl,
                    column=col,
                    ddl_value=ddl_col.data_type,
                ))
                continue
            # "unknown" marks columns whose DDL type could not be derived:
            # skip both the type and the nullability comparison for them.
            if ddl_col.data_type == "unknown":
                continue
            if ddl_col.data_type != db_col.data_type:
                diffs.append(SchemaDiff(
                    kind=DiffKind.TYPE_MISMATCH,
                    table=tbl,
                    column=col,
                    ddl_value=ddl_col.data_type,
                    db_value=db_col.data_type,
                ))
            if ddl_col.nullable != db_col.nullable:
                diffs.append(SchemaDiff(
                    kind=DiffKind.NULLABLE_MISMATCH,
                    table=tbl,
                    column=col,
                    ddl_value="NULL" if ddl_col.nullable else "NOT NULL",
                    db_value="NULL" if db_col.nullable else "NOT NULL",
                ))
    return diffs
def compare_schema(ddl_path: str, schema_name: str, pg_dsn: str) -> list[SchemaDiff]:
    """Full pipeline: parse the DDL file, fetch the live schema, diff them.

    Args:
        ddl_path: path of the DDL file.
        schema_name: database schema name.
        pg_dsn: PostgreSQL connection string.

    Returns:
        List of differences (empty when the file is missing or all matches).
    """
    ddl_file = Path(ddl_path)
    if not ddl_file.exists():
        print(f"✗ DDL 文件不存在: {ddl_path}", file=sys.stderr)
        return []
    ddl_tables = parse_ddl(ddl_file.read_text(encoding="utf-8"), target_schema=schema_name)
    if not ddl_tables:
        print(f"⚠ DDL 文件中未解析到任何表: {ddl_path}", file=sys.stderr)
    return compare_tables(ddl_tables, fetch_db_schema(pg_dsn, schema_name))
# ---------------------------------------------------------------------------
# 报告输出
# ---------------------------------------------------------------------------
def print_report(diffs: list[SchemaDiff], schema_name: str, ddl_path: str) -> None:
    """Print the diff report to stdout, grouped by table."""
    if not diffs:
        print(f"\n{schema_name} ({ddl_path}): 无差异")
        return
    print(f"\n{'='*60}")
    print(f" 差异报告: {schema_name}{ddl_path}")
    print(f"{len(diffs)} 项差异")
    print(f"{'='*60}")
    # Group the diffs per table so related problems print together.
    grouped: dict[str, list[SchemaDiff]] = {}
    for diff in diffs:
        grouped.setdefault(diff.table, []).append(diff)
    labels = {
        DiffKind.MISSING_TABLE: "🔴 DDL 缺表",
        DiffKind.EXTRA_TABLE: "🟡 DDL 多表",
        DiffKind.MISSING_COLUMN: "🔴 DDL 缺字段",
        DiffKind.EXTRA_COLUMN: "🟡 DDL 多字段",
        DiffKind.TYPE_MISMATCH: "🟠 类型不一致",
        DiffKind.NULLABLE_MISMATCH: "🔵 可空不一致",
    }
    for tbl in sorted(grouped):
        print(f"\n{tbl}")
        for diff in grouped[tbl]:
            label = labels.get(diff.kind, diff.kind.value)
            detail = f" {label}: {diff.column}" if diff.column else f" {label}"
            if diff.ddl_value is not None or diff.db_value is not None:
                detail += f" (DDL={diff.ddl_value}, DB={diff.db_value})"
            print(detail)
    print()
# ---------------------------------------------------------------------------
# CLI 入口
# ---------------------------------------------------------------------------
# Predefined schema → DDL file mapping, used by --all and by a bare --schema.
# CHANGE 2026-02-15 | aligned with the six-layer layout of the new etl_feiqiu database
DEFAULT_SCHEMA_MAP: dict[str, str] = {
    "ods": "database/schema_ODS_doc.sql",
    "dwd": "database/schema_dwd_doc.sql",
    "dws": "database/schema_dws.sql",
    "meta": "database/schema_etl_admin.sql",
}
def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    Returns 0 when every compared schema matches its DDL; 1 on any
    difference, missing configuration, or bad arguments.
    """
    parser = argparse.ArgumentParser(
        description="对比 DDL 文件与数据库实际表结构",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 对比单个 schema
python scripts/compare_ddl_db.py --schema ods --ddl-path database/schema_ODS_doc.sql
# 对比所有预定义 schema从 .env 读取 PG_DSN
python scripts/compare_ddl_db.py --all
# 指定连接字符串
python scripts/compare_ddl_db.py --all --pg-dsn "postgresql://user:pass@host/db"
""",
    )
    parser.add_argument("--pg-dsn", help="PostgreSQL 连接字符串(默认从 PG_DSN 环境变量或 .env 读取)")
    parser.add_argument("--schema", help="要对比的 schema 名称")
    parser.add_argument("--ddl-path", help="DDL 文件路径")
    parser.add_argument("--all", action="store_true", help="对比所有预定义 schema")
    args = parser.parse_args(argv)
    # Load .env when python-dotenv is available (optional dependency).
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass
    pg_dsn = args.pg_dsn or os.environ.get("PG_DSN")
    if not pg_dsn:
        print("✗ 未提供 PG_DSN请通过 --pg-dsn 参数或 PG_DSN 环境变量指定", file=sys.stderr)
        return 1
    # Build the (schema, ddl file) pairs to compare.
    pairs: list[tuple[str, str]] = []
    if args.all:
        for schema, ddl in DEFAULT_SCHEMA_MAP.items():
            pairs.append((schema, ddl))
    elif args.schema and args.ddl_path:
        pairs.append((args.schema, args.ddl_path))
    elif args.schema:
        # Fall back to the predefined mapping for well-known schemas.
        ddl = DEFAULT_SCHEMA_MAP.get(args.schema)
        if ddl:
            pairs.append((args.schema, ddl))
        else:
            print(f"✗ 未知 schema '{args.schema}',请通过 --ddl-path 指定 DDL 文件", file=sys.stderr)
            return 1
    else:
        parser.print_help()
        return 1
    total_diffs = 0
    for schema_name, ddl_path in pairs:
        if not Path(ddl_path).exists():
            print(f"⚠ DDL 文件不存在,跳过: {ddl_path}", file=sys.stderr)
            continue
        try:
            diffs = compare_schema(ddl_path, schema_name, pg_dsn)
        except Exception as e:
            # Keep going: one failing schema must not abort the other reports.
            print(f"✗ 对比 {schema_name} 时出错: {e}", file=sys.stderr)
            continue
        print_report(diffs, schema_name, ddl_path)
        total_diffs += len(diffs)
    if total_diffs > 0:
        print(f"共发现 {total_diffs} 项差异")
        return 1
    print("所有 schema 对比通过,无差异 ✓")
    return 0
if __name__ == "__main__":  # script entry point; exit code doubles as diff status
    sys.exit(main())