#!/usr/bin/env python3 """DDL 与数据库实际表结构对比脚本。 # AI_CHANGELOG [2026-02-13] 修复列名以 UNIQUE/CHECK 开头被误判为约束行的 bug;新增 CREATE VIEW 解析支持(视图仅检查存在性) 解析 database/schema_*.sql 中的 CREATE TABLE 语句, 查询 information_schema.columns 获取数据库实际结构, 逐表逐字段对比并输出差异报告。 用法: python scripts/compare_ddl_db.py --pg-dsn "postgresql://..." --schema ods --ddl-path database/schema_ODS_doc.sql python scripts/compare_ddl_db.py --schema dwd --ddl-path database/schema_dwd_doc.sql # 从 .env 读取 PG_DSN """ from __future__ import annotations import argparse import os import re import sys from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import Optional from neozqyy_shared.repo_root import ensure_repo_root ensure_repo_root() class DiffKind(str, Enum): """差异分类枚举。""" MISSING_TABLE = "MISSING_TABLE" # DDL 缺表(数据库有,DDL 没有) EXTRA_TABLE = "EXTRA_TABLE" # DDL 多表(DDL 有,数据库没有) MISSING_COLUMN = "MISSING_COLUMN" # DDL 缺字段 EXTRA_COLUMN = "EXTRA_COLUMN" # DDL 多字段 TYPE_MISMATCH = "TYPE_MISMATCH" # 字段类型不一致 NULLABLE_MISMATCH = "NULLABLE_MISMATCH" # 可空约束不一致 @dataclass class SchemaDiff: """单条差异记录。""" kind: DiffKind table: str column: Optional[str] = None ddl_value: Optional[str] = None db_value: Optional[str] = None def __str__(self) -> str: parts = [f"[{self.kind.value}] {self.table}"] if self.column: parts.append(f".{self.column}") if self.ddl_value is not None or self.db_value is not None: parts.append(f" DDL={self.ddl_value} DB={self.db_value}") return "".join(parts) # --------------------------------------------------------------------------- # DDL 列定义 # --------------------------------------------------------------------------- @dataclass class ColumnDef: """从 DDL 解析出的单个字段定义。""" name: str data_type: str # 标准化后的类型字符串 nullable: bool = True is_pk: bool = False default: Optional[str] = None @dataclass class TableDef: """从 DDL 解析出的单张表定义。""" name: str # 不含 schema 前缀的表名(小写) columns: dict[str, ColumnDef] = field(default_factory=dict) pk_columns: list[str] = field(default_factory=list) 
is_view: bool = False # 视图标记,跳过列级对比 # --------------------------------------------------------------------------- # 类型标准化:将 DDL 类型和 information_schema 类型映射到统一表示 # --------------------------------------------------------------------------- # PostgreSQL information_schema.data_type → 简写映射 _PG_TYPE_MAP: dict[str, str] = { "bigint": "bigint", "integer": "integer", "smallint": "smallint", "boolean": "boolean", "text": "text", "jsonb": "jsonb", "json": "json", "date": "date", "bytea": "bytea", "double precision": "double precision", "real": "real", "uuid": "uuid", "timestamp without time zone": "timestamp", "timestamp with time zone": "timestamptz", "time without time zone": "time", "time with time zone": "timetz", "character varying": "varchar", "character": "char", "ARRAY": "array", "USER-DEFINED": "user-defined", } def normalize_type(raw: str) -> str: """将 DDL 或 information_schema 中的类型字符串标准化为可比较的形式。 规则: - 全部小写 - BIGINT / INT8 → bigint - INTEGER / INT / INT4 → integer - SMALLINT / INT2 → smallint - BOOLEAN / BOOL → boolean - VARCHAR(n) / CHARACTER VARYING(n) → varchar(n) - CHAR(n) / CHARACTER(n) → char(n) - NUMERIC(p,s) / DECIMAL(p,s) → numeric(p,s) - SERIAL → integer(serial 本质是 integer + sequence) - BIGSERIAL → bigint - TIMESTAMP → timestamp - TIMESTAMPTZ / TIMESTAMP WITH TIME ZONE → timestamptz - TEXT → text - JSONB → jsonb """ t = raw.strip().lower() # 去掉多余空格 t = re.sub(r"\s+", " ", t) # serial 家族 → 底层整数类型 if t == "bigserial": return "bigint" if t in ("serial", "serial4"): return "integer" if t == "smallserial": return "smallint" # 带精度的 numeric / decimal m = re.match(r"(?:numeric|decimal)\s*\((\d+)\s*,\s*(\d+)\)", t) if m: return f"numeric({m.group(1)},{m.group(2)})" m = re.match(r"(?:numeric|decimal)\s*\((\d+)\)", t) if m: return f"numeric({m.group(1)})" if t in ("numeric", "decimal"): return "numeric" # varchar / character varying m = re.match(r"(?:varchar|character varying)\s*\((\d+)\)", t) if m: return f"varchar({m.group(1)})" if t in ("varchar", "character 
varying"): return "varchar" # char / character m = re.match(r"(?:char|character)\s*\((\d+)\)", t) if m: return f"char({m.group(1)})" if t in ("char", "character"): return "char(1)" # timestamp 家族 if t in ("timestamptz", "timestamp with time zone"): return "timestamptz" if t in ("timestamp", "timestamp without time zone"): return "timestamp" # 整数别名 if t in ("int8", "bigint"): return "bigint" if t in ("int", "int4", "integer"): return "integer" if t in ("int2", "smallint"): return "smallint" # 布尔 if t in ("bool", "boolean"): return "boolean" # information_schema 映射 if t in _PG_TYPE_MAP: return _PG_TYPE_MAP[t] return t # --------------------------------------------------------------------------- # DDL 解析器 # --------------------------------------------------------------------------- # 匹配 CREATE TABLE [IF NOT EXISTS] [schema.]table_name ( _CREATE_TABLE_RE = re.compile( r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?" r"(?:(\w+)\.)?(\w+)\s*\(", re.IGNORECASE, ) # 匹配 DROP TABLE [IF EXISTS] [schema.]table_name [CASCADE]; _DROP_TABLE_RE = re.compile( r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:\w+\.)?(\w+)", re.IGNORECASE, ) # 匹配 CREATE [OR REPLACE] VIEW [schema.]view_name AS SELECT ... _CREATE_VIEW_RE = re.compile( r"CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+" r"(?:(\w+)\.)?(\w+)\s+AS\s+", re.IGNORECASE, ) def _strip_sql_comments(sql: str) -> str: """移除 SQL 单行注释(-- ...)和块注释(/* ... 
*/)。""" # 块注释 sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL) # 单行注释 sql = re.sub(r"--[^\n]*", "", sql) return sql def _find_matching_paren(text: str, start: int) -> int: """从 start 位置(应为 '(')开始,找到匹配的 ')' 位置。 处理嵌套括号和字符串字面量中的括号。 """ depth = 0 in_string = False string_char = "" i = start while i < len(text): ch = text[i] if in_string: if ch == string_char: # 检查转义 if i + 1 < len(text) and text[i + 1] == string_char: i += 2 continue in_string = False else: if ch in ("'", '"'): in_string = True string_char = ch elif ch == "(": depth += 1 elif ch == ")": depth -= 1 if depth == 0: return i i += 1 return -1 def _parse_column_line(line: str) -> Optional[ColumnDef]: """解析单行字段定义,返回 ColumnDef 或 None(如果是约束行)。""" line = line.strip().rstrip(",") if not line: return None upper = line.upper() # 跳过表级约束行 # 注意:需要区分约束行(如 "UNIQUE (...)")和以约束关键字开头的列名 # (如 "unique_customers INTEGER"、"check_status INT") # 约束行的关键字后面紧跟空格+左括号或直接左括号,而列名后面跟下划线或字母 if re.match( r"(?:PRIMARY\s+KEY|UNIQUE|CHECK|FOREIGN\s+KEY|EXCLUDE)" r"(?:\s*\(|\s+(?![\w]))", upper, ) or upper.startswith("CONSTRAINT"): return None # 字段名 类型 [约束...] 
# 字段名可能被双引号包裹 m = re.match(r'(?:"([^"]+)"|(\w+))\s+(.+)', line) if not m: return None col_name = (m.group(1) or m.group(2)).lower() rest = m.group(3).strip() # 提取类型:取到第一个(位置最靠前的)已知约束关键字或行尾 # 类型可能包含括号,如 NUMERIC(18,2)、VARCHAR(50) type_end_keywords = [ "NOT NULL", "NULL", "DEFAULT", "PRIMARY KEY", "UNIQUE", "REFERENCES", "CHECK", "CONSTRAINT", "GENERATED", ] type_str = rest constraint_part = "" # 找所有关键字中位置最靠前的 best_idx = len(rest) for kw in type_end_keywords: idx = rest.upper().find(kw) if idx > 0 and idx < best_idx: candidate = rest[:idx].strip() if candidate: best_idx = idx if best_idx < len(rest): type_str = rest[:best_idx].strip() constraint_part = rest[best_idx:] # 去掉类型末尾的逗号 type_str = type_str.rstrip(",").strip() nullable = True if "NOT NULL" in constraint_part.upper(): nullable = False is_pk = "PRIMARY KEY" in constraint_part.upper() # 提取 DEFAULT 值 default_val = None dm = re.search(r"DEFAULT\s+(.+?)(?:\s+(?:NOT\s+NULL|NULL|PRIMARY|UNIQUE|REFERENCES|CHECK|CONSTRAINT|,|$))", constraint_part, re.IGNORECASE) if dm: default_val = dm.group(1).strip().rstrip(",") return ColumnDef( name=col_name, data_type=normalize_type(type_str), nullable=nullable, is_pk=is_pk, default=default_val, ) def _extract_pk_from_body(body: str) -> list[str]: """从 CREATE TABLE 体中提取表级 PRIMARY KEY 约束的列名列表。""" # PRIMARY KEY (col1, col2, ...) 
# 也可能是 CONSTRAINT xxx PRIMARY KEY (col1, col2) m = re.search(r"PRIMARY\s+KEY\s*\(([^)]+)\)", body, re.IGNORECASE) if not m: return [] cols_str = m.group(1) return [c.strip().strip('"').lower() for c in cols_str.split(",")] def parse_ddl(sql_text: str, target_schema: Optional[str] = None) -> dict[str, TableDef]: """解析 DDL 文本,提取所有 CREATE TABLE 定义。 Args: sql_text: 完整的 SQL DDL 文本 target_schema: 如果指定,只保留该 schema 下的表(或无 schema 前缀的表) Returns: {表名(小写): TableDef} 字典 """ # 先收集被 DROP 的表名,后续 CREATE 会覆盖 cleaned = _strip_sql_comments(sql_text) tables: dict[str, TableDef] = {} # 逐个匹配 CREATE TABLE for m in _CREATE_TABLE_RE.finditer(cleaned): schema_part = m.group(1) table_name = m.group(2).lower() # schema 过滤 if target_schema: ts = target_schema.lower() if schema_part and schema_part.lower() != ts: continue # 无 schema 前缀的表也接受(DWD DDL 中 SET search_path 后不带前缀) # 找到 CREATE TABLE ... ( 的左括号位置 paren_start = m.end() - 1 # m.end() 指向 '(' 后一位 paren_end = _find_matching_paren(cleaned, paren_start) if paren_end < 0: continue body = cleaned[paren_start + 1: paren_end] # 按行解析字段 table_def = TableDef(name=table_name) # 提取表级 PRIMARY KEY pk_cols = _extract_pk_from_body(body) # 逐行解析 for raw_line in body.split("\n"): col = _parse_column_line(raw_line) if col: table_def.columns[col.name] = col # 合并表级 PK 信息 if pk_cols: table_def.pk_columns = pk_cols for pk_col in pk_cols: if pk_col in table_def.columns: table_def.columns[pk_col].is_pk = True # PK 隐含 NOT NULL table_def.columns[pk_col].nullable = False # 合并内联 PK inline_pk = [c.name for c in table_def.columns.values() if c.is_pk] if inline_pk and not table_def.pk_columns: table_def.pk_columns = inline_pk for pk_col in inline_pk: table_def.columns[pk_col].nullable = False tables[table_name] = table_def # 解析 CREATE VIEW,仅标记视图存在(列信息由数据库侧提供) for m in _CREATE_VIEW_RE.finditer(cleaned): schema_part = m.group(1) view_name = m.group(2).lower() if target_schema: ts = target_schema.lower() if schema_part and schema_part.lower() != ts: continue if view_name not in 
tables: # 视图仅标记存在,不解析列(列由底层表决定) tables[view_name] = TableDef(name=view_name) # 标记为视图,跳过列级对比 tables[view_name].is_view = True return tables # --------------------------------------------------------------------------- # 数据库 schema 读取 # --------------------------------------------------------------------------- @dataclass class DbColumnInfo: """从 information_schema 查询到的字段信息。""" name: str data_type: str # 标准化后 nullable: bool is_pk: bool = False def fetch_db_schema(pg_dsn: str, schema_name: str) -> dict[str, TableDef]: """从数据库 information_schema 查询指定 schema 的所有表和字段。 Returns: {表名(小写): TableDef} 字典 """ import psycopg2 conn = psycopg2.connect(pg_dsn) try: with conn.cursor() as cur: # 检查 schema 是否存在 cur.execute( "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s", (schema_name,), ) if not cur.fetchone(): print(f"⚠ schema '{schema_name}' 在数据库中不存在,跳过", file=sys.stderr) return {} # 查询所有列信息 cur.execute(""" SELECT c.table_name, c.column_name, c.data_type, c.is_nullable, c.character_maximum_length, c.numeric_precision, c.numeric_scale, c.udt_name FROM information_schema.columns c WHERE c.table_schema = %s ORDER BY c.table_name, c.ordinal_position """, (schema_name,)) rows = cur.fetchall() # 查询主键信息 cur.execute(""" SELECT tc.table_name, kcu.column_name FROM information_schema.table_constraints tc JOIN information_schema.key_column_usage kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema WHERE tc.table_schema = %s AND tc.constraint_type = 'PRIMARY KEY' ORDER BY tc.table_name, kcu.ordinal_position """, (schema_name,)) pk_rows = cur.fetchall() finally: conn.close() # 构建 PK 映射: {table_name: [col1, col2, ...]} pk_map: dict[str, list[str]] = {} for tbl, col in pk_rows: pk_map.setdefault(tbl.lower(), []).append(col.lower()) # 构建 TableDef tables: dict[str, TableDef] = {} for tbl, col_name, data_type, is_nullable, char_max_len, num_prec, num_scale, udt_name in rows: tbl_lower = tbl.lower() col_lower = col_name.lower() if tbl_lower not in 
tables: tables[tbl_lower] = TableDef( name=tbl_lower, pk_columns=pk_map.get(tbl_lower, []), ) # 构建精确类型字符串 type_str = _build_db_type_string(data_type, char_max_len, num_prec, num_scale, udt_name) is_pk = col_lower in pk_map.get(tbl_lower, []) nullable = is_nullable == "YES" tables[tbl_lower].columns[col_lower] = ColumnDef( name=col_lower, data_type=normalize_type(type_str), nullable=nullable, is_pk=is_pk, ) return tables def _build_db_type_string( data_type: str, char_max_len: Optional[int], num_prec: Optional[int], num_scale: Optional[int], udt_name: str, ) -> str: """根据 information_schema 字段构建可比较的类型字符串。""" dt = data_type.lower() # character varying → varchar(n) if dt == "character varying": if char_max_len: return f"varchar({char_max_len})" return "varchar" # character → char(n) if dt == "character": if char_max_len: return f"char({char_max_len})" return "char(1)" # numeric → numeric(p,s) if dt == "numeric": if num_prec is not None and num_scale is not None: return f"numeric({num_prec},{num_scale})" if num_prec is not None: return f"numeric({num_prec})" return "numeric" # USER-DEFINED → 使用 udt_name(如 jsonb, geometry 等) if dt == "user-defined": return udt_name.lower() # ARRAY → 使用 udt_name 去掉前缀 _ if dt == "array": base = udt_name.lstrip("_").lower() return f"{base}[]" return dt # --------------------------------------------------------------------------- # 对比逻辑 # --------------------------------------------------------------------------- def compare_tables( ddl_tables: dict[str, TableDef], db_tables: dict[str, TableDef], ) -> list[SchemaDiff]: """对比 DDL 定义与数据库实际结构,返回差异列表。 差异分类: - MISSING_TABLE: 数据库有但 DDL 没有 - EXTRA_TABLE: DDL 有但数据库没有 - MISSING_COLUMN: 数据库有但 DDL 没有的字段 - EXTRA_COLUMN: DDL 有但数据库没有的字段 - TYPE_MISMATCH: 字段类型不一致 - NULLABLE_MISMATCH: 可空约束不一致 """ diffs: list[SchemaDiff] = [] all_tables = sorted(set(ddl_tables.keys()) | set(db_tables.keys())) for tbl in all_tables: in_ddl = tbl in ddl_tables in_db = tbl in db_tables if in_db and not in_ddl: 
diffs.append(SchemaDiff(kind=DiffKind.MISSING_TABLE, table=tbl)) continue if in_ddl and not in_db: diffs.append(SchemaDiff(kind=DiffKind.EXTRA_TABLE, table=tbl)) continue # 两边都有,逐字段对比 # 视图仅检查存在性,跳过列级对比 ddl_def = ddl_tables[tbl] if getattr(ddl_def, 'is_view', False): continue ddl_cols = ddl_def.columns db_cols = db_tables[tbl].columns all_cols = sorted(set(ddl_cols.keys()) | set(db_cols.keys())) for col in all_cols: col_in_ddl = col in ddl_cols col_in_db = col in db_cols if col_in_db and not col_in_ddl: diffs.append(SchemaDiff( kind=DiffKind.MISSING_COLUMN, table=tbl, column=col, db_value=db_cols[col].data_type, )) continue if col_in_ddl and not col_in_db: diffs.append(SchemaDiff( kind=DiffKind.EXTRA_COLUMN, table=tbl, column=col, ddl_value=ddl_cols[col].data_type, )) continue # 两边都有,比较类型 ddl_type = ddl_cols[col].data_type db_type = db_cols[col].data_type # 视图列从 DDL 解析时类型为 unknown,跳过类型比较 if ddl_type != db_type and ddl_type != "unknown": diffs.append(SchemaDiff( kind=DiffKind.TYPE_MISMATCH, table=tbl, column=col, ddl_value=ddl_type, db_value=db_type, )) # 比较可空性(视图列跳过) ddl_nullable = ddl_cols[col].nullable db_nullable = db_cols[col].nullable if ddl_nullable != db_nullable and ddl_type != "unknown": diffs.append(SchemaDiff( kind=DiffKind.NULLABLE_MISMATCH, table=tbl, column=col, ddl_value="NULL" if ddl_nullable else "NOT NULL", db_value="NULL" if db_nullable else "NOT NULL", )) return diffs def compare_schema(ddl_path: str, schema_name: str, pg_dsn: str) -> list[SchemaDiff]: """对比 DDL 文件与数据库 schema 的完整流程。 Args: ddl_path: DDL 文件路径 schema_name: 数据库 schema 名称 pg_dsn: PostgreSQL 连接字符串 Returns: 差异列表 """ path = Path(ddl_path) if not path.exists(): print(f"✗ DDL 文件不存在: {ddl_path}", file=sys.stderr) return [] sql_text = path.read_text(encoding="utf-8") ddl_tables = parse_ddl(sql_text, target_schema=schema_name) if not ddl_tables: print(f"⚠ DDL 文件中未解析到任何表: {ddl_path}", file=sys.stderr) db_tables = fetch_db_schema(pg_dsn, schema_name) return compare_tables(ddl_tables, db_tables) 
# ---------------------------------------------------------------------------
# Report output
# ---------------------------------------------------------------------------


def print_report(diffs: list[SchemaDiff], schema_name: str, ddl_path: str) -> None:
    """Print the diff report to the console, grouped by table."""
    if not diffs:
        print(f"\n✓ {schema_name} ({ddl_path}): 无差异")
        return

    print(f"\n{'='*60}")
    print(f" 差异报告: {schema_name} ← {ddl_path}")
    print(f" 共 {len(diffs)} 项差异")
    print(f"{'='*60}")

    # Display label per diff kind — hoisted here; the original rebuilt this
    # dict once per diff row inside the innermost loop.
    kind_labels = {
        DiffKind.MISSING_TABLE: "🔴 DDL 缺表",
        DiffKind.EXTRA_TABLE: "🟡 DDL 多表",
        DiffKind.MISSING_COLUMN: "🔴 DDL 缺字段",
        DiffKind.EXTRA_COLUMN: "🟡 DDL 多字段",
        DiffKind.TYPE_MISMATCH: "🟠 类型不一致",
        DiffKind.NULLABLE_MISMATCH: "🔵 可空不一致",
    }

    # group by table
    by_table: dict[str, list[SchemaDiff]] = {}
    for d in diffs:
        by_table.setdefault(d.table, []).append(d)

    for tbl in sorted(by_table.keys()):
        items = by_table[tbl]
        print(f"\n ▸ {tbl}")
        for d in items:
            icon = kind_labels.get(d.kind, d.kind.value)
            if d.column:
                detail = f" {icon}: {d.column}"
            else:
                detail = f" {icon}"
            if d.ddl_value is not None or d.db_value is not None:
                detail += f" (DDL={d.ddl_value}, DB={d.db_value})"
            print(detail)
    print()


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------

# Predefined schema -> DDL file mapping
# CHANGE 2026-02-15 | aligned with the six-layer architecture of the new etl_feiqiu database
DEFAULT_SCHEMA_MAP: dict[str, str] = {
    "ods": "database/schema_ODS_doc.sql",
    "dwd": "database/schema_dwd_doc.sql",
    "dws": "database/schema_dws.sql",
    "meta": "database/schema_etl_admin.sql",
}


def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    Returns a process exit code: 0 when every schema compared cleanly,
    1 when differences were found, an argument was invalid, or — bug fix —
    when any schema comparison raised an error (previously errors were
    printed but the process could still exit 0).
    """
    parser = argparse.ArgumentParser(
        description="对比 DDL 文件与数据库实际表结构",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 对比单个 schema
  python scripts/compare_ddl_db.py --schema ods --ddl-path database/schema_ODS_doc.sql

  # 对比所有预定义 schema(从 .env 读取 PG_DSN)
  python scripts/compare_ddl_db.py --all

  # 指定连接字符串
  python scripts/compare_ddl_db.py --all --pg-dsn "postgresql://user:pass@host/db"
""",
    )
    parser.add_argument("--pg-dsn", help="PostgreSQL 连接字符串(默认从 PG_DSN 环境变量或 .env 读取)")
    parser.add_argument("--schema", help="要对比的 schema 名称")
    parser.add_argument("--ddl-path", help="DDL 文件路径")
    parser.add_argument("--all", action="store_true", help="对比所有预定义 schema")
    args = parser.parse_args(argv)

    # load .env when python-dotenv is available (best-effort)
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    pg_dsn = args.pg_dsn or os.environ.get("PG_DSN")
    if not pg_dsn:
        print("✗ 未提供 PG_DSN,请通过 --pg-dsn 参数或 PG_DSN 环境变量指定", file=sys.stderr)
        return 1

    # resolve the list of (schema, ddl_path) pairs to compare
    pairs: list[tuple[str, str]] = []
    if args.all:
        for schema, ddl in DEFAULT_SCHEMA_MAP.items():
            pairs.append((schema, ddl))
    elif args.schema and args.ddl_path:
        pairs.append((args.schema, args.ddl_path))
    elif args.schema:
        # fall back to the predefined mapping
        ddl = DEFAULT_SCHEMA_MAP.get(args.schema)
        if ddl:
            pairs.append((args.schema, ddl))
        else:
            print(f"✗ 未知 schema '{args.schema}',请通过 --ddl-path 指定 DDL 文件", file=sys.stderr)
            return 1
    else:
        parser.print_help()
        return 1

    total_diffs = 0
    error_count = 0
    for schema_name, ddl_path in pairs:
        if not Path(ddl_path).exists():
            print(f"⚠ DDL 文件不存在,跳过: {ddl_path}", file=sys.stderr)
            continue
        try:
            diffs = compare_schema(ddl_path, schema_name, pg_dsn)
        except Exception as e:
            # broad catch is deliberate at this CLI boundary: report the
            # failing schema and keep comparing the rest
            print(f"✗ 对比 {schema_name} 时出错: {e}", file=sys.stderr)
            error_count += 1
            continue
        print_report(diffs, schema_name, ddl_path)
        total_diffs += len(diffs)

    if total_diffs > 0:
        print(f"共发现 {total_diffs} 项差异")
        return 1
    if error_count > 0:
        # bug fix: a schema that failed to compare must not yield exit code 0
        print(f"✗ {error_count} 个 schema 对比出错", file=sys.stderr)
        return 1
    print("所有 schema 对比通过,无差异 ✓")
    return 0


if __name__ == "__main__":
    sys.exit(main())