824 lines
26 KiB
Python
824 lines
26 KiB
Python
#!/usr/bin/env python3
|
||
"""DDL 与数据库实际表结构对比脚本。
|
||
|
||
# AI_CHANGELOG [2026-02-13] 修复列名以 UNIQUE/CHECK 开头被误判为约束行的 bug;新增 CREATE VIEW 解析支持(视图仅检查存在性)
|
||
|
||
解析 database/schema_*.sql 中的 CREATE TABLE 语句,
|
||
查询 information_schema.columns 获取数据库实际结构,
|
||
逐表逐字段对比并输出差异报告。
|
||
|
||
用法:
|
||
python scripts/compare_ddl_db.py --pg-dsn "postgresql://..." --schema ods --ddl-path database/schema_ODS_doc.sql
|
||
python scripts/compare_ddl_db.py --schema dwd --ddl-path database/schema_dwd_doc.sql # 从 .env 读取 PG_DSN
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import os
|
||
import re
|
||
import sys
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
|
||
class DiffKind(str, Enum):
    """Classification of a single DDL-vs-database difference."""

    MISSING_TABLE = "MISSING_TABLE"          # table exists in DB but not in DDL
    EXTRA_TABLE = "EXTRA_TABLE"              # table exists in DDL but not in DB
    MISSING_COLUMN = "MISSING_COLUMN"        # column exists in DB but not in DDL
    EXTRA_COLUMN = "EXTRA_COLUMN"            # column exists in DDL but not in DB
    TYPE_MISMATCH = "TYPE_MISMATCH"          # column data types differ
    NULLABLE_MISMATCH = "NULLABLE_MISMATCH"  # NULL / NOT NULL constraint differs
|
||
|
||
|
||
@dataclass
class SchemaDiff:
    """One recorded difference between the DDL definition and the live DB.

    ``column`` is None for table-level differences; ``ddl_value`` /
    ``db_value`` carry the two compared values when relevant.
    """

    kind: DiffKind
    table: str
    column: Optional[str] = None
    ddl_value: Optional[str] = None
    db_value: Optional[str] = None

    def __str__(self) -> str:
        # Render as: [KIND] table[.column][ DDL=x DB=y]
        text = f"[{self.kind.value}] {self.table}"
        if self.column:
            text += f".{self.column}"
        if self.ddl_value is not None or self.db_value is not None:
            text += f" DDL={self.ddl_value} DB={self.db_value}"
        return text
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# DDL 列定义
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
class ColumnDef:
    """A single column definition parsed from DDL (also reused for DB rows)."""

    name: str
    data_type: str  # normalized type string (see normalize_type)
    nullable: bool = True
    is_pk: bool = False
    default: Optional[str] = None  # raw DEFAULT expression text, if present
|
||
|
||
|
||
@dataclass
class TableDef:
    """A single table definition parsed from DDL (also reused for DB tables)."""

    name: str  # table name without schema prefix, lower-cased
    columns: dict[str, ColumnDef] = field(default_factory=dict)
    pk_columns: list[str] = field(default_factory=list)
    is_view: bool = False  # view marker: skip column-level comparison
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 类型标准化:将 DDL 类型和 information_schema 类型映射到统一表示
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# PostgreSQL information_schema.data_type → short-form mapping.
# Used by normalize_type as a final lookup for types without special-case
# handling. NOTE: the "ARRAY" / "USER-DEFINED" keys are upper-case while
# normalize_type lowercases its input first, so those two entries are never
# hit from there — array/user-defined types are resolved in
# _build_db_type_string instead.
_PG_TYPE_MAP: dict[str, str] = {
    "bigint": "bigint",
    "integer": "integer",
    "smallint": "smallint",
    "boolean": "boolean",
    "text": "text",
    "jsonb": "jsonb",
    "json": "json",
    "date": "date",
    "bytea": "bytea",
    "double precision": "double precision",
    "real": "real",
    "uuid": "uuid",
    "timestamp without time zone": "timestamp",
    "timestamp with time zone": "timestamptz",
    "time without time zone": "time",
    "time with time zone": "timetz",
    "character varying": "varchar",
    "character": "char",
    "ARRAY": "array",
    "USER-DEFINED": "user-defined",
}
|
||
|
||
|
||
def normalize_type(raw: str) -> str:
    """Normalize a type string from DDL or information_schema for comparison.

    Rules (all output lower-case):
      - BIGINT / INT8 → bigint; INTEGER / INT / INT4 → integer;
        SMALLINT / INT2 → smallint; BOOLEAN / BOOL → boolean
      - VARCHAR(n) / CHARACTER VARYING(n) → varchar(n)
      - CHAR(n) / CHARACTER(n) → char(n); bare CHAR → char(1)
      - NUMERIC(p,s) / DECIMAL(p,s) → numeric(p,s)
      - SERIAL → integer, BIGSERIAL → bigint, SMALLSERIAL → smallint
        (serial is just an integer plus a sequence)
      - TIMESTAMP [WITHOUT TIME ZONE] → timestamp;
        TIMESTAMPTZ / TIMESTAMP WITH TIME ZONE → timestamptz
      - everything else: _PG_TYPE_MAP lookup, or returned as-is
    """
    t = re.sub(r"\s+", " ", raw.strip().lower())

    # serial family → underlying integer type
    serial_aliases = {
        "bigserial": "bigint",
        "serial": "integer",
        "serial4": "integer",
        "smallserial": "smallint",
    }
    if t in serial_aliases:
        return serial_aliases[t]

    # numeric / decimal with precision (and optional scale)
    full_numeric = re.match(r"(?:numeric|decimal)\s*\((\d+)\s*,\s*(\d+)\)", t)
    if full_numeric:
        return f"numeric({full_numeric.group(1)},{full_numeric.group(2)})"
    prec_numeric = re.match(r"(?:numeric|decimal)\s*\((\d+)\)", t)
    if prec_numeric:
        return f"numeric({prec_numeric.group(1)})"
    if t in ("numeric", "decimal"):
        return "numeric"

    # varchar / character varying (checked before char so that
    # "character varying(n)" is not claimed by the char pattern)
    sized_varchar = re.match(r"(?:varchar|character varying)\s*\((\d+)\)", t)
    if sized_varchar:
        return f"varchar({sized_varchar.group(1)})"
    if t in ("varchar", "character varying"):
        return "varchar"

    # char / character; a bare char defaults to length 1 in PostgreSQL
    sized_char = re.match(r"(?:char|character)\s*\((\d+)\)", t)
    if sized_char:
        return f"char({sized_char.group(1)})"
    if t in ("char", "character"):
        return "char(1)"

    # timestamp family, integer aliases, boolean aliases
    simple_aliases = {
        "timestamptz": "timestamptz",
        "timestamp with time zone": "timestamptz",
        "timestamp": "timestamp",
        "timestamp without time zone": "timestamp",
        "int8": "bigint",
        "bigint": "bigint",
        "int": "integer",
        "int4": "integer",
        "integer": "integer",
        "int2": "smallint",
        "smallint": "smallint",
        "bool": "boolean",
        "boolean": "boolean",
    }
    if t in simple_aliases:
        return simple_aliases[t]

    # fall back to the information_schema mapping, else pass through
    return _PG_TYPE_MAP.get(t, t)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# DDL 解析器
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Matches: CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
# group(1) = optional schema, group(2) = table name; the trailing '(' is
# part of the match so m.end()-1 is the opening paren of the column list.
_CREATE_TABLE_RE = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
    r"(?:(\w+)\.)?(\w+)\s*\(",
    re.IGNORECASE,
)

# Matches: DROP TABLE [IF EXISTS] [schema.]table_name [CASCADE];
# NOTE(review): not referenced anywhere in this file — confirm before removing.
_DROP_TABLE_RE = re.compile(
    r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:\w+\.)?(\w+)",
    re.IGNORECASE,
)

# Matches: CREATE [OR REPLACE] VIEW [schema.]view_name AS SELECT ...
# group(1) = optional schema, group(2) = view name.
_CREATE_VIEW_RE = re.compile(
    r"CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+"
    r"(?:(\w+)\.)?(\w+)\s+AS\s+",
    re.IGNORECASE,
)
|
||
|
||
|
||
def _strip_sql_comments(sql: str) -> str:
|
||
"""移除 SQL 单行注释(-- ...)和块注释(/* ... */)。"""
|
||
# 块注释
|
||
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
|
||
# 单行注释
|
||
sql = re.sub(r"--[^\n]*", "", sql)
|
||
return sql
|
||
|
||
|
||
def _find_matching_paren(text: str, start: int) -> int:
|
||
"""从 start 位置(应为 '(')开始,找到匹配的 ')' 位置。
|
||
|
||
处理嵌套括号和字符串字面量中的括号。
|
||
"""
|
||
depth = 0
|
||
in_string = False
|
||
string_char = ""
|
||
i = start
|
||
while i < len(text):
|
||
ch = text[i]
|
||
if in_string:
|
||
if ch == string_char:
|
||
# 检查转义
|
||
if i + 1 < len(text) and text[i + 1] == string_char:
|
||
i += 2
|
||
continue
|
||
in_string = False
|
||
else:
|
||
if ch in ("'", '"'):
|
||
in_string = True
|
||
string_char = ch
|
||
elif ch == "(":
|
||
depth += 1
|
||
elif ch == ")":
|
||
depth -= 1
|
||
if depth == 0:
|
||
return i
|
||
i += 1
|
||
return -1
|
||
|
||
|
||
def _parse_column_line(line: str) -> Optional[ColumnDef]:
    """Parse one line of a CREATE TABLE body into a ColumnDef.

    Args:
        line: One raw line from inside the parentheses of a CREATE TABLE.

    Returns:
        A ColumnDef with normalized type, nullability, inline-PK flag and
        the raw DEFAULT expression, or None for blank lines and table-level
        constraint lines (PRIMARY KEY / UNIQUE / CHECK / FOREIGN KEY /
        EXCLUDE / CONSTRAINT ...).
    """
    line = line.strip().rstrip(",")
    if not line:
        return None

    upper = line.upper()
    # Skip table-level constraint lines.  A constraint keyword must be
    # followed by '(' or by whitespace not followed by an identifier, so
    # column names such as "unique_customers" or "check_status" are still
    # parsed as columns.
    if re.match(
        r"(?:PRIMARY\s+KEY|UNIQUE|CHECK|FOREIGN\s+KEY|EXCLUDE)"
        r"(?:\s*\(|\s+(?![\w]))",
        upper,
    ) or upper.startswith("CONSTRAINT"):
        return None

    # column_name TYPE [constraints...]; the name may be double-quoted.
    m = re.match(r'(?:"([^"]+)"|(\w+))\s+(.+)', line)
    if not m:
        return None

    col_name = (m.group(1) or m.group(2)).lower()
    rest = m.group(3).strip()
    rest_upper = rest.upper()  # hoisted: scanned once per keyword below

    # The type runs up to the earliest constraint keyword (or end of line);
    # it may itself contain parentheses, e.g. NUMERIC(18,2), VARCHAR(50).
    type_end_keywords = [
        "NOT NULL", "NULL", "DEFAULT", "PRIMARY KEY", "UNIQUE",
        "REFERENCES", "CHECK", "CONSTRAINT", "GENERATED",
    ]
    type_str = rest
    constraint_part = ""
    best_idx = len(rest)
    for kw in type_end_keywords:
        idx = rest_upper.find(kw)
        # idx > 0: a keyword at position 0 would leave an empty type.
        if 0 < idx < best_idx and rest[:idx].strip():
            best_idx = idx
    if best_idx < len(rest):
        type_str = rest[:best_idx].strip()
        constraint_part = rest[best_idx:]

    type_str = type_str.rstrip(",").strip()

    constraint_upper = constraint_part.upper()
    nullable = "NOT NULL" not in constraint_upper
    is_pk = "PRIMARY KEY" in constraint_upper

    # Extract the DEFAULT expression.  A lookahead terminator is used so a
    # DEFAULT at the very end of the line is captured too: the previous
    # pattern required whitespace before end-of-string and therefore missed
    # e.g. "... NOT NULL DEFAULT 0".
    default_val = None
    dm = re.search(
        r"DEFAULT\s+(.+?)"
        r"(?=\s+(?:NOT\s+NULL|NULL|PRIMARY|UNIQUE|REFERENCES|CHECK|CONSTRAINT)\b"
        r"|\s+,|\s*$)",
        constraint_part,
        re.IGNORECASE,
    )
    if dm:
        default_val = dm.group(1).strip().rstrip(",")

    return ColumnDef(
        name=col_name,
        data_type=normalize_type(type_str),
        nullable=nullable,
        is_pk=is_pk,
        default=default_val,
    )
|
||
|
||
|
||
def _extract_pk_from_body(body: str) -> list[str]:
|
||
"""从 CREATE TABLE 体中提取表级 PRIMARY KEY 约束的列名列表。"""
|
||
# PRIMARY KEY (col1, col2, ...)
|
||
# 也可能是 CONSTRAINT xxx PRIMARY KEY (col1, col2)
|
||
m = re.search(r"PRIMARY\s+KEY\s*\(([^)]+)\)", body, re.IGNORECASE)
|
||
if not m:
|
||
return []
|
||
cols_str = m.group(1)
|
||
return [c.strip().strip('"').lower() for c in cols_str.split(",")]
|
||
|
||
|
||
def parse_ddl(sql_text: str, target_schema: Optional[str] = None) -> dict[str, TableDef]:
    """Parse DDL text and extract every CREATE TABLE definition.

    CREATE [OR REPLACE] VIEW statements are recorded too, but only as an
    existence marker (``is_view=True``); their columns are not parsed.

    Args:
        sql_text: Full SQL DDL text.
        target_schema: When given, keep only tables explicitly in that
            schema; tables without a schema prefix are also accepted
            (e.g. DDL that relies on SET search_path).

    Returns:
        Mapping of lower-cased table/view name to TableDef.
    """
    # Strip comments first so commented-out definitions are not matched.
    # (The original comment here mentioned collecting DROP TABLE names,
    # but no such collection happens in this function.)
    cleaned = _strip_sql_comments(sql_text)

    tables: dict[str, TableDef] = {}

    # Match each CREATE TABLE statement in turn.
    for m in _CREATE_TABLE_RE.finditer(cleaned):
        schema_part = m.group(1)
        table_name = m.group(2).lower()

        # Schema filter: skip tables explicitly in another schema.
        if target_schema:
            ts = target_schema.lower()
            if schema_part and schema_part.lower() != ts:
                continue
            # Tables without a schema prefix are accepted as well.

        # The regex ends on the '(' that opens the column list, so
        # m.end() - 1 is that paren's index.
        paren_start = m.end() - 1
        paren_end = _find_matching_paren(cleaned, paren_start)
        if paren_end < 0:
            continue  # unbalanced parentheses — skip this statement

        body = cleaned[paren_start + 1: paren_end]

        table_def = TableDef(name=table_name)

        # Table-level PRIMARY KEY constraint, if any.
        pk_cols = _extract_pk_from_body(body)

        # Parse the body line by line; constraint lines yield None.
        for raw_line in body.split("\n"):
            col = _parse_column_line(raw_line)
            if col:
                table_def.columns[col.name] = col

        # Apply table-level PK info to the parsed columns.
        if pk_cols:
            table_def.pk_columns = pk_cols
            for pk_col in pk_cols:
                if pk_col in table_def.columns:
                    table_def.columns[pk_col].is_pk = True
                    # PRIMARY KEY implies NOT NULL.
                    table_def.columns[pk_col].nullable = False

        # Fall back to inline (column-level) PRIMARY KEY markers.
        inline_pk = [c.name for c in table_def.columns.values() if c.is_pk]
        if inline_pk and not table_def.pk_columns:
            table_def.pk_columns = inline_pk
            for pk_col in inline_pk:
                table_def.columns[pk_col].nullable = False

        tables[table_name] = table_def

    # Record CREATE VIEW names; only existence is tracked — columns come
    # from the database side during comparison.
    for m in _CREATE_VIEW_RE.finditer(cleaned):
        schema_part = m.group(1)
        view_name = m.group(2).lower()

        if target_schema:
            ts = target_schema.lower()
            if schema_part and schema_part.lower() != ts:
                continue

        if view_name not in tables:
            # Register the view as an empty TableDef (no column parsing).
            tables[view_name] = TableDef(name=view_name)
            # Mark as a view so column-level comparison is skipped.
            tables[view_name].is_view = True

    return tables
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 数据库 schema 读取
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
class DbColumnInfo:
    """Column info as read from information_schema.

    NOTE(review): not referenced anywhere in this file — fetch_db_schema
    builds ColumnDef objects directly. Confirm before removing.
    """

    name: str
    data_type: str  # normalized type string
    nullable: bool
    is_pk: bool = False
|
||
|
||
|
||
def fetch_db_schema(pg_dsn: str, schema_name: str) -> dict[str, TableDef]:
    """Query information_schema for all tables and columns of a schema.

    Args:
        pg_dsn: PostgreSQL connection string.
        schema_name: Schema to inspect.

    Returns:
        Mapping of lower-cased table name to TableDef; empty when the
        schema does not exist in the database.
    """
    # Imported lazily so the DDL-parsing part of this script can run
    # without psycopg2 installed.
    import psycopg2

    conn = psycopg2.connect(pg_dsn)
    try:
        with conn.cursor() as cur:
            # Bail out early if the schema itself is missing.
            cur.execute(
                "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s",
                (schema_name,),
            )
            if not cur.fetchone():
                print(f"⚠ schema '{schema_name}' 在数据库中不存在,跳过", file=sys.stderr)
                return {}

            # All column metadata, ordered so each table's columns group
            # together in declaration order.
            cur.execute("""
                SELECT
                    c.table_name,
                    c.column_name,
                    c.data_type,
                    c.is_nullable,
                    c.character_maximum_length,
                    c.numeric_precision,
                    c.numeric_scale,
                    c.udt_name
                FROM information_schema.columns c
                WHERE c.table_schema = %s
                ORDER BY c.table_name, c.ordinal_position
            """, (schema_name,))

            rows = cur.fetchall()

            # Primary-key columns per table.
            cur.execute("""
                SELECT
                    tc.table_name,
                    kcu.column_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                    ON tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                WHERE tc.table_schema = %s
                    AND tc.constraint_type = 'PRIMARY KEY'
                ORDER BY tc.table_name, kcu.ordinal_position
            """, (schema_name,))

            pk_rows = cur.fetchall()
    finally:
        # Always release the connection, even on query failure.
        conn.close()

    # Build the PK mapping: {table_name: [col1, col2, ...]}.
    pk_map: dict[str, list[str]] = {}
    for tbl, col in pk_rows:
        pk_map.setdefault(tbl.lower(), []).append(col.lower())

    # Assemble TableDef objects from the column rows.
    tables: dict[str, TableDef] = {}
    for tbl, col_name, data_type, is_nullable, char_max_len, num_prec, num_scale, udt_name in rows:
        tbl_lower = tbl.lower()
        col_lower = col_name.lower()

        if tbl_lower not in tables:
            tables[tbl_lower] = TableDef(
                name=tbl_lower,
                pk_columns=pk_map.get(tbl_lower, []),
            )

        # Build a precise, comparable type string (varchar(n), numeric(p,s), ...).
        type_str = _build_db_type_string(data_type, char_max_len, num_prec, num_scale, udt_name)

        is_pk = col_lower in pk_map.get(tbl_lower, [])
        nullable = is_nullable == "YES"

        tables[tbl_lower].columns[col_lower] = ColumnDef(
            name=col_lower,
            data_type=normalize_type(type_str),
            nullable=nullable,
            is_pk=is_pk,
        )

    return tables
|
||
|
||
|
||
def _build_db_type_string(
|
||
data_type: str,
|
||
char_max_len: Optional[int],
|
||
num_prec: Optional[int],
|
||
num_scale: Optional[int],
|
||
udt_name: str,
|
||
) -> str:
|
||
"""根据 information_schema 字段构建可比较的类型字符串。"""
|
||
dt = data_type.lower()
|
||
|
||
# character varying → varchar(n)
|
||
if dt == "character varying":
|
||
if char_max_len:
|
||
return f"varchar({char_max_len})"
|
||
return "varchar"
|
||
|
||
# character → char(n)
|
||
if dt == "character":
|
||
if char_max_len:
|
||
return f"char({char_max_len})"
|
||
return "char(1)"
|
||
|
||
# numeric → numeric(p,s)
|
||
if dt == "numeric":
|
||
if num_prec is not None and num_scale is not None:
|
||
return f"numeric({num_prec},{num_scale})"
|
||
if num_prec is not None:
|
||
return f"numeric({num_prec})"
|
||
return "numeric"
|
||
|
||
# USER-DEFINED → 使用 udt_name(如 jsonb, geometry 等)
|
||
if dt == "user-defined":
|
||
return udt_name.lower()
|
||
|
||
# ARRAY → 使用 udt_name 去掉前缀 _
|
||
if dt == "array":
|
||
base = udt_name.lstrip("_").lower()
|
||
return f"{base}[]"
|
||
|
||
return dt
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 对比逻辑
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def compare_tables(
    ddl_tables: dict[str, TableDef],
    db_tables: dict[str, TableDef],
) -> list[SchemaDiff]:
    """Compare DDL definitions against the live database structure.

    Diff categories emitted:
      - MISSING_TABLE: present in DB, absent from DDL
      - EXTRA_TABLE: present in DDL, absent from DB
      - MISSING_COLUMN: column present in DB, absent from DDL
      - EXTRA_COLUMN: column present in DDL, absent from DB
      - TYPE_MISMATCH: column data types differ
      - NULLABLE_MISMATCH: NULL / NOT NULL differs

    Tables flagged ``is_view`` in the DDL are only checked for existence.
    """
    diffs: list[SchemaDiff] = []

    for tbl in sorted(ddl_tables.keys() | db_tables.keys()):
        ddl_def = ddl_tables.get(tbl)
        db_def = db_tables.get(tbl)

        # Table-level presence checks (guard clauses).
        if ddl_def is None:
            diffs.append(SchemaDiff(kind=DiffKind.MISSING_TABLE, table=tbl))
            continue
        if db_def is None:
            diffs.append(SchemaDiff(kind=DiffKind.EXTRA_TABLE, table=tbl))
            continue

        # Both sides have the table; views skip column-level comparison.
        if getattr(ddl_def, 'is_view', False):
            continue

        ddl_cols = ddl_def.columns
        db_cols = db_def.columns

        for col in sorted(ddl_cols.keys() | db_cols.keys()):
            ddl_col = ddl_cols.get(col)
            db_col = db_cols.get(col)

            if ddl_col is None:
                diffs.append(SchemaDiff(
                    kind=DiffKind.MISSING_COLUMN,
                    table=tbl,
                    column=col,
                    db_value=db_col.data_type,
                ))
                continue

            if db_col is None:
                diffs.append(SchemaDiff(
                    kind=DiffKind.EXTRA_COLUMN,
                    table=tbl,
                    column=col,
                    ddl_value=ddl_col.data_type,
                ))
                continue

            # Both sides have the column: compare types.  DDL columns of
            # type "unknown" (views) skip both comparisons below.
            if ddl_col.data_type != db_col.data_type and ddl_col.data_type != "unknown":
                diffs.append(SchemaDiff(
                    kind=DiffKind.TYPE_MISMATCH,
                    table=tbl,
                    column=col,
                    ddl_value=ddl_col.data_type,
                    db_value=db_col.data_type,
                ))

            # Compare nullability.
            if ddl_col.nullable != db_col.nullable and ddl_col.data_type != "unknown":
                diffs.append(SchemaDiff(
                    kind=DiffKind.NULLABLE_MISMATCH,
                    table=tbl,
                    column=col,
                    ddl_value="NULL" if ddl_col.nullable else "NOT NULL",
                    db_value="NULL" if db_col.nullable else "NOT NULL",
                ))

    return diffs
|
||
|
||
|
||
def compare_schema(ddl_path: str, schema_name: str, pg_dsn: str) -> list[SchemaDiff]:
    """Run the full DDL-file vs database-schema comparison.

    Args:
        ddl_path: Path of the DDL file.
        schema_name: Database schema name.
        pg_dsn: PostgreSQL connection string.

    Returns:
        List of differences; empty when the DDL file does not exist.
    """
    ddl_file = Path(ddl_path)
    if not ddl_file.exists():
        print(f"✗ DDL 文件不存在: {ddl_path}", file=sys.stderr)
        return []

    ddl_tables = parse_ddl(
        ddl_file.read_text(encoding="utf-8"),
        target_schema=schema_name,
    )
    if not ddl_tables:
        # Not fatal: an empty DDL still gets compared (DB-only tables
        # will show up as MISSING_TABLE).
        print(f"⚠ DDL 文件中未解析到任何表: {ddl_path}", file=sys.stderr)

    db_tables = fetch_db_schema(pg_dsn, schema_name)
    return compare_tables(ddl_tables, db_tables)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 报告输出
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def print_report(diffs: list[SchemaDiff], schema_name: str, ddl_path: str) -> None:
    """Print the diff report to stdout, grouped by table."""
    if not diffs:
        print(f"\n✓ {schema_name} ({ddl_path}): 无差异")
        return

    separator = "=" * 60
    print(f"\n{separator}")
    print(f" 差异报告: {schema_name} ← {ddl_path}")
    print(f" 共 {len(diffs)} 项差异")
    print(f"{separator}")

    # Group the diffs by table name.
    grouped: dict[str, list[SchemaDiff]] = {}
    for diff in diffs:
        grouped.setdefault(diff.table, []).append(diff)

    # Human-readable label per diff kind; falls back to the raw enum value.
    labels = {
        DiffKind.MISSING_TABLE: "🔴 DDL 缺表",
        DiffKind.EXTRA_TABLE: "🟡 DDL 多表",
        DiffKind.MISSING_COLUMN: "🔴 DDL 缺字段",
        DiffKind.EXTRA_COLUMN: "🟡 DDL 多字段",
        DiffKind.TYPE_MISMATCH: "🟠 类型不一致",
        DiffKind.NULLABLE_MISMATCH: "🔵 可空不一致",
    }

    for table in sorted(grouped):
        print(f"\n ▸ {table}")
        for diff in grouped[table]:
            label = labels.get(diff.kind, diff.kind.value)
            detail = f"    {label}: {diff.column}" if diff.column else f"    {label}"
            if diff.ddl_value is not None or diff.db_value is not None:
                detail += f" (DDL={diff.ddl_value}, DB={diff.db_value})"
            print(detail)

    print()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CLI 入口
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Predefined schema → DDL file mapping used by --all and bare --schema.
# CHANGE 2026-02-15 | aligned with the six-layer layout of the new etl_feiqiu database
DEFAULT_SCHEMA_MAP: dict[str, str] = {
    "ods": "database/schema_ODS_doc.sql",
    "dwd": "database/schema_dwd_doc.sql",
    "dws": "database/schema_dws.sql",
    "meta": "database/schema_etl_admin.sql",
}
|
||
|
||
|
||
def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    Args:
        argv: Argument list for argparse; None means sys.argv[1:].

    Returns:
        0 when every comparison passes with no differences; 1 on any
        difference, missing PG_DSN, or usage error.
    """
    parser = argparse.ArgumentParser(
        description="对比 DDL 文件与数据库实际表结构",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 对比单个 schema
  python scripts/compare_ddl_db.py --schema ods --ddl-path database/schema_ODS_doc.sql

  # 对比所有预定义 schema(从 .env 读取 PG_DSN)
  python scripts/compare_ddl_db.py --all

  # 指定连接字符串
  python scripts/compare_ddl_db.py --all --pg-dsn "postgresql://user:pass@host/db"
""",
    )
    parser.add_argument("--pg-dsn", help="PostgreSQL 连接字符串(默认从 PG_DSN 环境变量或 .env 读取)")
    parser.add_argument("--schema", help="要对比的 schema 名称")
    parser.add_argument("--ddl-path", help="DDL 文件路径")
    parser.add_argument("--all", action="store_true", help="对比所有预定义 schema")

    args = parser.parse_args(argv)

    # Load .env if python-dotenv is installed; it is an optional dependency.
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    # CLI flag takes precedence over the environment variable.
    pg_dsn = args.pg_dsn or os.environ.get("PG_DSN")
    if not pg_dsn:
        print("✗ 未提供 PG_DSN,请通过 --pg-dsn 参数或 PG_DSN 环境变量指定", file=sys.stderr)
        return 1

    # Resolve the (schema, ddl_path) pairs to compare.
    pairs: list[tuple[str, str]] = []
    if args.all:
        for schema, ddl in DEFAULT_SCHEMA_MAP.items():
            pairs.append((schema, ddl))
    elif args.schema and args.ddl_path:
        pairs.append((args.schema, args.ddl_path))
    elif args.schema:
        # Only --schema given: fall back to the predefined mapping.
        ddl = DEFAULT_SCHEMA_MAP.get(args.schema)
        if ddl:
            pairs.append((args.schema, ddl))
        else:
            print(f"✗ 未知 schema '{args.schema}',请通过 --ddl-path 指定 DDL 文件", file=sys.stderr)
            return 1
    else:
        parser.print_help()
        return 1

    total_diffs = 0
    for schema_name, ddl_path in pairs:
        if not Path(ddl_path).exists():
            print(f"⚠ DDL 文件不存在,跳过: {ddl_path}", file=sys.stderr)
            continue

        try:
            diffs = compare_schema(ddl_path, schema_name, pg_dsn)
        except Exception as e:
            # Keep going so one failing schema does not abort the others.
            print(f"✗ 对比 {schema_name} 时出错: {e}", file=sys.stderr)
            continue

        print_report(diffs, schema_name, ddl_path)
        total_diffs += len(diffs)

    if total_diffs > 0:
        print(f"共发现 {total_diffs} 项差异")
        return 1

    print("所有 schema 对比通过,无差异 ✓")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate the diff/no-diff exit code (0 = clean, 1 = differences/error).
    sys.exit(main())
|