init: 项目初始提交 - NeoZQYY Monorepo 完整代码
This commit is contained in:
822
apps/etl/pipelines/feiqiu/scripts/compare_ddl_db.py
Normal file
822
apps/etl/pipelines/feiqiu/scripts/compare_ddl_db.py
Normal file
@@ -0,0 +1,822 @@
|
||||
#!/usr/bin/env python3
|
||||
"""DDL 与数据库实际表结构对比脚本。
|
||||
|
||||
# AI_CHANGELOG [2026-02-13] 修复列名以 UNIQUE/CHECK 开头被误判为约束行的 bug;新增 CREATE VIEW 解析支持(视图仅检查存在性)
|
||||
|
||||
解析 database/schema_*.sql 中的 CREATE TABLE 语句,
|
||||
查询 information_schema.columns 获取数据库实际结构,
|
||||
逐表逐字段对比并输出差异报告。
|
||||
|
||||
用法:
|
||||
python scripts/compare_ddl_db.py --pg-dsn "postgresql://..." --schema billiards_ods --ddl-path database/schema_ODS_doc.sql
|
||||
python scripts/compare_ddl_db.py --schema billiards_dwd --ddl-path database/schema_dwd_doc.sql # 从 .env 读取 PG_DSN
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DiffKind(str, Enum):
    """Classification of a single schema difference."""
    MISSING_TABLE = "MISSING_TABLE"          # table exists in DB but not in the DDL
    EXTRA_TABLE = "EXTRA_TABLE"              # table exists in the DDL but not in DB
    MISSING_COLUMN = "MISSING_COLUMN"        # column exists in DB but not in the DDL
    EXTRA_COLUMN = "EXTRA_COLUMN"            # column exists in the DDL but not in DB
    TYPE_MISMATCH = "TYPE_MISMATCH"          # column data type differs
    NULLABLE_MISMATCH = "NULLABLE_MISMATCH"  # nullability constraint differs
|
||||
|
||||
|
||||
@dataclass
class SchemaDiff:
    """One recorded difference between the DDL and the live database."""
    kind: DiffKind                     # difference category
    table: str                         # table the difference belongs to
    column: Optional[str] = None       # column name, for column-level diffs
    ddl_value: Optional[str] = None    # value on the DDL side (type / nullability)
    db_value: Optional[str] = None     # value on the database side

    def __str__(self) -> str:
        """Render as '[KIND] table[.column][ DDL=x DB=y]'."""
        text = f"[{self.kind.value}] {self.table}"
        if self.column:
            text += f".{self.column}"
        if not (self.ddl_value is None and self.db_value is None):
            text += f" DDL={self.ddl_value} DB={self.db_value}"
        return text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DDL 列定义
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class ColumnDef:
    """A single column definition parsed from DDL."""
    name: str                      # column name, lowercase
    data_type: str                 # normalized type string (see normalize_type)
    nullable: bool = True          # False when NOT NULL (or part of a PK)
    is_pk: bool = False            # True for inline or table-level primary key columns
    default: Optional[str] = None  # raw DEFAULT expression text, if any
|
||||
|
||||
|
||||
@dataclass
class TableDef:
    """A single table (or view) definition parsed from DDL."""
    name: str  # table name without schema prefix (lowercase)
    columns: dict[str, ColumnDef] = field(default_factory=dict)
    pk_columns: list[str] = field(default_factory=list)
    is_view: bool = False  # view marker; column-level comparison is skipped for views
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Type normalization: map DDL types and information_schema types to one
# comparable representation
# ---------------------------------------------------------------------------

# PostgreSQL information_schema.data_type → short-form mapping
_PG_TYPE_MAP: dict[str, str] = {
    "bigint": "bigint",
    "integer": "integer",
    "smallint": "smallint",
    "boolean": "boolean",
    "text": "text",
    "jsonb": "jsonb",
    "json": "json",
    "date": "date",
    "bytea": "bytea",
    "double precision": "double precision",
    "real": "real",
    "uuid": "uuid",
    "timestamp without time zone": "timestamp",
    "timestamp with time zone": "timestamptz",
    "time without time zone": "time",
    "time with time zone": "timetz",
    "character varying": "varchar",
    "character": "char",
    # NOTE(review): normalize_type() lowercases its input before looking keys up,
    # so these two upper-case keys are never matched there; _build_db_type_string
    # handles ARRAY/USER-DEFINED before normalization. Looks dead — confirm
    # before removing.
    "ARRAY": "array",
    "USER-DEFINED": "user-defined",
}
|
||||
|
||||
|
||||
def normalize_type(raw: str) -> str:
    """Normalize a type string from DDL or information_schema into a comparable form.

    Rules:
    - everything lowercased, internal whitespace collapsed
    - BIGINT / INT8 → bigint; INTEGER / INT / INT4 → integer; SMALLINT / INT2 → smallint
    - BOOLEAN / BOOL → boolean
    - VARCHAR(n) / CHARACTER VARYING(n) → varchar(n)
    - CHAR(n) / CHARACTER(n) → char(n); bare CHAR → char(1)
    - NUMERIC(p,s) / DECIMAL(p,s) → numeric(p,s)
    - SERIAL → integer, BIGSERIAL → bigint (serial is just integer + sequence)
    - TIMESTAMP [WITHOUT TIME ZONE] → timestamp; TIMESTAMPTZ / ... WITH TIME ZONE → timestamptz
    - anything else falls through to _PG_TYPE_MAP, then to the cleaned input itself
    """
    cleaned = re.sub(r"\s+", " ", raw.strip().lower())

    # serial family → underlying integer type
    serial_aliases = {
        "bigserial": "bigint",
        "serial": "integer",
        "serial4": "integer",
        "smallserial": "smallint",
    }
    if cleaned in serial_aliases:
        return serial_aliases[cleaned]

    # numeric / decimal with precision and scale, or precision only
    prec_scale = re.match(r"(?:numeric|decimal)\s*\((\d+)\s*,\s*(\d+)\)", cleaned)
    if prec_scale:
        return f"numeric({prec_scale.group(1)},{prec_scale.group(2)})"
    prec_only = re.match(r"(?:numeric|decimal)\s*\((\d+)\)", cleaned)
    if prec_only:
        return f"numeric({prec_only.group(1)})"
    if cleaned in ("numeric", "decimal"):
        return "numeric"

    # varchar / character varying
    sized = re.match(r"(?:varchar|character varying)\s*\((\d+)\)", cleaned)
    if sized:
        return f"varchar({sized.group(1)})"
    if cleaned in ("varchar", "character varying"):
        return "varchar"

    # char / character (bare char defaults to length 1, matching the DB side)
    sized = re.match(r"(?:char|character)\s*\((\d+)\)", cleaned)
    if sized:
        return f"char({sized.group(1)})"
    if cleaned in ("char", "character"):
        return "char(1)"

    # timestamp family
    if cleaned in ("timestamptz", "timestamp with time zone"):
        return "timestamptz"
    if cleaned in ("timestamp", "timestamp without time zone"):
        return "timestamp"

    # integer and boolean aliases
    int_bool_aliases = {
        "int8": "bigint", "bigint": "bigint",
        "int": "integer", "int4": "integer", "integer": "integer",
        "int2": "smallint", "smallint": "smallint",
        "bool": "boolean", "boolean": "boolean",
    }
    if cleaned in int_bool_aliases:
        return int_bool_aliases[cleaned]

    # information_schema long names → short form; unknown types pass through
    return _PG_TYPE_MAP.get(cleaned, cleaned)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# DDL parser
# ---------------------------------------------------------------------------

# Matches CREATE TABLE [IF NOT EXISTS] [schema.]table_name (
# group(1) = optional schema, group(2) = table name
_CREATE_TABLE_RE = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
    r"(?:(\w+)\.)?(\w+)\s*\(",
    re.IGNORECASE,
)

# Matches DROP TABLE [IF EXISTS] [schema.]table_name [CASCADE];
# NOTE(review): not referenced anywhere in this file — looks like a leftover;
# confirm before removing.
_DROP_TABLE_RE = re.compile(
    r"DROP\s+TABLE\s+(?:IF\s+EXISTS\s+)?(?:\w+\.)?(\w+)",
    re.IGNORECASE,
)

# Matches CREATE [OR REPLACE] VIEW [schema.]view_name AS SELECT ...
# group(1) = optional schema, group(2) = view name
_CREATE_VIEW_RE = re.compile(
    r"CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+"
    r"(?:(\w+)\.)?(\w+)\s+AS\s+",
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
def _strip_sql_comments(sql: str) -> str:
|
||||
"""移除 SQL 单行注释(-- ...)和块注释(/* ... */)。"""
|
||||
# 块注释
|
||||
sql = re.sub(r"/\*.*?\*/", "", sql, flags=re.DOTALL)
|
||||
# 单行注释
|
||||
sql = re.sub(r"--[^\n]*", "", sql)
|
||||
return sql
|
||||
|
||||
|
||||
def _find_matching_paren(text: str, start: int) -> int:
|
||||
"""从 start 位置(应为 '(')开始,找到匹配的 ')' 位置。
|
||||
|
||||
处理嵌套括号和字符串字面量中的括号。
|
||||
"""
|
||||
depth = 0
|
||||
in_string = False
|
||||
string_char = ""
|
||||
i = start
|
||||
while i < len(text):
|
||||
ch = text[i]
|
||||
if in_string:
|
||||
if ch == string_char:
|
||||
# 检查转义
|
||||
if i + 1 < len(text) and text[i + 1] == string_char:
|
||||
i += 2
|
||||
continue
|
||||
in_string = False
|
||||
else:
|
||||
if ch in ("'", '"'):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
elif ch == "(":
|
||||
depth += 1
|
||||
elif ch == ")":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return i
|
||||
i += 1
|
||||
return -1
|
||||
|
||||
|
||||
def _parse_column_line(line: str) -> Optional[ColumnDef]:
    """Parse one column-definition line; return ColumnDef, or None for constraint/empty lines."""
    line = line.strip().rstrip(",")
    if not line:
        return None

    upper = line.upper()
    # Skip table-level constraint lines.
    # We must distinguish a constraint line (e.g. "UNIQUE (...)") from a column
    # whose NAME merely starts with a constraint keyword (e.g.
    # "unique_customers INTEGER", "check_status INT"): a constraint keyword is
    # followed by "(" (optionally after spaces), while a column name continues
    # with a word character.
    if re.match(
        r"(?:PRIMARY\s+KEY|UNIQUE|CHECK|FOREIGN\s+KEY|EXCLUDE)"
        r"(?:\s*\(|\s+(?![\w]))",
        upper,
    ) or upper.startswith("CONSTRAINT"):
        return None

    # column_name type [constraints...]; the name may be double-quoted
    m = re.match(r'(?:"([^"]+)"|(\w+))\s+(.+)', line)
    if not m:
        return None

    col_name = (m.group(1) or m.group(2)).lower()
    rest = m.group(3).strip()

    # Extract the type: everything up to the earliest known constraint keyword
    # (or end of line). The type itself may contain parentheses, e.g.
    # NUMERIC(18,2), VARCHAR(50).
    type_end_keywords = [
        "NOT NULL", "NULL", "DEFAULT", "PRIMARY KEY", "UNIQUE",
        "REFERENCES", "CHECK", "CONSTRAINT", "GENERATED",
    ]
    type_str = rest
    constraint_part = ""
    # Find the keyword occurrence with the smallest position.
    # NOTE(review): plain substring search — a keyword embedded inside another
    # token (not at a word boundary) would split the type too early; appears
    # not to occur with the types used in these DDL files, but confirm.
    best_idx = len(rest)
    for kw in type_end_keywords:
        idx = rest.upper().find(kw)
        if idx > 0 and idx < best_idx:
            candidate = rest[:idx].strip()
            if candidate:
                best_idx = idx
    if best_idx < len(rest):
        type_str = rest[:best_idx].strip()
        constraint_part = rest[best_idx:]

    # Drop a trailing comma left on the type
    type_str = type_str.rstrip(",").strip()

    nullable = True
    if "NOT NULL" in constraint_part.upper():
        nullable = False

    is_pk = "PRIMARY KEY" in constraint_part.upper()

    # Extract the DEFAULT value.
    # NOTE(review): the terminator alternation requires whitespace before the
    # next keyword/comma/end — a DEFAULT value at end-of-line without trailing
    # whitespace is not captured (default stays None); confirm intended.
    default_val = None
    dm = re.search(r"DEFAULT\s+(.+?)(?:\s+(?:NOT\s+NULL|NULL|PRIMARY|UNIQUE|REFERENCES|CHECK|CONSTRAINT|,|$))",
                   constraint_part, re.IGNORECASE)
    if dm:
        default_val = dm.group(1).strip().rstrip(",")

    return ColumnDef(
        name=col_name,
        data_type=normalize_type(type_str),
        nullable=nullable,
        is_pk=is_pk,
        default=default_val,
    )
|
||||
|
||||
|
||||
def _extract_pk_from_body(body: str) -> list[str]:
|
||||
"""从 CREATE TABLE 体中提取表级 PRIMARY KEY 约束的列名列表。"""
|
||||
# PRIMARY KEY (col1, col2, ...)
|
||||
# 也可能是 CONSTRAINT xxx PRIMARY KEY (col1, col2)
|
||||
m = re.search(r"PRIMARY\s+KEY\s*\(([^)]+)\)", body, re.IGNORECASE)
|
||||
if not m:
|
||||
return []
|
||||
cols_str = m.group(1)
|
||||
return [c.strip().strip('"').lower() for c in cols_str.split(",")]
|
||||
|
||||
|
||||
def parse_ddl(sql_text: str, target_schema: Optional[str] = None) -> dict[str, TableDef]:
    """Parse DDL text and extract all CREATE TABLE / CREATE VIEW definitions.

    Args:
        sql_text: full SQL DDL text
        target_schema: if given, keep only tables in that schema
            (tables without a schema prefix are accepted as well)

    Returns:
        {table_name (lowercase): TableDef} mapping
    """
    # Strip comments first so commented-out statements are not parsed.
    cleaned = _strip_sql_comments(sql_text)

    tables: dict[str, TableDef] = {}

    # Match each CREATE TABLE statement
    for m in _CREATE_TABLE_RE.finditer(cleaned):
        schema_part = m.group(1)
        table_name = m.group(2).lower()

        # schema filter
        if target_schema:
            ts = target_schema.lower()
            if schema_part and schema_part.lower() != ts:
                continue
            # Tables without a schema prefix are accepted too (the DWD DDL
            # relies on SET search_path and omits the prefix).

        # Locate the opening parenthesis of CREATE TABLE ... (
        paren_start = m.end() - 1  # m.end() points one past '('
        paren_end = _find_matching_paren(cleaned, paren_start)
        if paren_end < 0:
            continue

        body = cleaned[paren_start + 1: paren_end]

        # Parse the column list line by line
        table_def = TableDef(name=table_name)

        # Extract the table-level PRIMARY KEY constraint
        pk_cols = _extract_pk_from_body(body)

        for raw_line in body.split("\n"):
            col = _parse_column_line(raw_line)
            if col:
                table_def.columns[col.name] = col

        # Merge table-level PK information into the column defs
        if pk_cols:
            table_def.pk_columns = pk_cols
            for pk_col in pk_cols:
                if pk_col in table_def.columns:
                    table_def.columns[pk_col].is_pk = True
                    # PK implies NOT NULL
                    table_def.columns[pk_col].nullable = False

        # Merge inline (column-level) PKs; table-level PK takes precedence
        inline_pk = [c.name for c in table_def.columns.values() if c.is_pk]
        if inline_pk and not table_def.pk_columns:
            table_def.pk_columns = inline_pk
            for pk_col in inline_pk:
                table_def.columns[pk_col].nullable = False

        tables[table_name] = table_def

    # Parse CREATE VIEW; views are only recorded as existing
    # (their columns come from the database side).
    for m in _CREATE_VIEW_RE.finditer(cleaned):
        schema_part = m.group(1)
        view_name = m.group(2).lower()

        if target_schema:
            ts = target_schema.lower()
            if schema_part and schema_part.lower() != ts:
                continue

        if view_name not in tables:
            # Only record existence; columns are determined by the underlying tables
            tables[view_name] = TableDef(name=view_name)
        # Mark as a view so column-level comparison is skipped
        tables[view_name].is_view = True

    return tables
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Database schema introspection
# ---------------------------------------------------------------------------

@dataclass
class DbColumnInfo:
    """Column information as queried from information_schema.

    NOTE(review): not referenced elsewhere in this file — fetch_db_schema
    builds ColumnDef instances instead; confirm before removing.
    """
    name: str         # column name, lowercase
    data_type: str    # normalized type string
    nullable: bool    # True when is_nullable = 'YES'
    is_pk: bool = False
|
||||
|
||||
|
||||
def fetch_db_schema(pg_dsn: str, schema_name: str) -> dict[str, TableDef]:
    """Query information_schema for all tables and columns of *schema_name*.

    Args:
        pg_dsn: PostgreSQL connection string
        schema_name: schema to introspect

    Returns:
        {table_name (lowercase): TableDef} mapping; empty when the schema
        does not exist (a warning is printed to stderr in that case).
    """
    # Imported lazily so the parser part of this script works without psycopg2.
    import psycopg2

    conn = psycopg2.connect(pg_dsn)
    try:
        with conn.cursor() as cur:
            # Verify the schema exists before querying its columns
            cur.execute(
                "SELECT 1 FROM information_schema.schemata WHERE schema_name = %s",
                (schema_name,),
            )
            if not cur.fetchone():
                print(f"⚠ schema '{schema_name}' 在数据库中不存在,跳过", file=sys.stderr)
                return {}

            # Fetch all column metadata for the schema
            cur.execute("""
                SELECT
                    c.table_name,
                    c.column_name,
                    c.data_type,
                    c.is_nullable,
                    c.character_maximum_length,
                    c.numeric_precision,
                    c.numeric_scale,
                    c.udt_name
                FROM information_schema.columns c
                WHERE c.table_schema = %s
                ORDER BY c.table_name, c.ordinal_position
            """, (schema_name,))

            rows = cur.fetchall()

            # Fetch primary-key membership
            cur.execute("""
                SELECT
                    tc.table_name,
                    kcu.column_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                    ON tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                WHERE tc.table_schema = %s
                    AND tc.constraint_type = 'PRIMARY KEY'
                ORDER BY tc.table_name, kcu.ordinal_position
            """, (schema_name,))

            pk_rows = cur.fetchall()
    finally:
        conn.close()

    # Build the PK map: {table_name: [col1, col2, ...]}
    pk_map: dict[str, list[str]] = {}
    for tbl, col in pk_rows:
        pk_map.setdefault(tbl.lower(), []).append(col.lower())

    # Build TableDef objects from the column rows
    tables: dict[str, TableDef] = {}
    for tbl, col_name, data_type, is_nullable, char_max_len, num_prec, num_scale, udt_name in rows:
        tbl_lower = tbl.lower()
        col_lower = col_name.lower()

        if tbl_lower not in tables:
            tables[tbl_lower] = TableDef(
                name=tbl_lower,
                pk_columns=pk_map.get(tbl_lower, []),
            )

        # Build a precise, comparable type string (varchar(n), numeric(p,s), ...)
        type_str = _build_db_type_string(data_type, char_max_len, num_prec, num_scale, udt_name)

        is_pk = col_lower in pk_map.get(tbl_lower, [])
        nullable = is_nullable == "YES"

        tables[tbl_lower].columns[col_lower] = ColumnDef(
            name=col_lower,
            data_type=normalize_type(type_str),
            nullable=nullable,
            is_pk=is_pk,
        )

    return tables
|
||||
|
||||
|
||||
def _build_db_type_string(
|
||||
data_type: str,
|
||||
char_max_len: Optional[int],
|
||||
num_prec: Optional[int],
|
||||
num_scale: Optional[int],
|
||||
udt_name: str,
|
||||
) -> str:
|
||||
"""根据 information_schema 字段构建可比较的类型字符串。"""
|
||||
dt = data_type.lower()
|
||||
|
||||
# character varying → varchar(n)
|
||||
if dt == "character varying":
|
||||
if char_max_len:
|
||||
return f"varchar({char_max_len})"
|
||||
return "varchar"
|
||||
|
||||
# character → char(n)
|
||||
if dt == "character":
|
||||
if char_max_len:
|
||||
return f"char({char_max_len})"
|
||||
return "char(1)"
|
||||
|
||||
# numeric → numeric(p,s)
|
||||
if dt == "numeric":
|
||||
if num_prec is not None and num_scale is not None:
|
||||
return f"numeric({num_prec},{num_scale})"
|
||||
if num_prec is not None:
|
||||
return f"numeric({num_prec})"
|
||||
return "numeric"
|
||||
|
||||
# USER-DEFINED → 使用 udt_name(如 jsonb, geometry 等)
|
||||
if dt == "user-defined":
|
||||
return udt_name.lower()
|
||||
|
||||
# ARRAY → 使用 udt_name 去掉前缀 _
|
||||
if dt == "array":
|
||||
base = udt_name.lstrip("_").lower()
|
||||
return f"{base}[]"
|
||||
|
||||
return dt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 对比逻辑
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compare_tables(
    ddl_tables: dict[str, TableDef],
    db_tables: dict[str, TableDef],
) -> list[SchemaDiff]:
    """Compare DDL definitions against the live database structure.

    Diff categories:
    - MISSING_TABLE: in DB but not in DDL
    - EXTRA_TABLE: in DDL but not in DB
    - MISSING_COLUMN: column in DB but not in DDL
    - EXTRA_COLUMN: column in DDL but not in DB
    - TYPE_MISMATCH: column data type differs
    - NULLABLE_MISMATCH: nullability differs

    Views parsed from DDL are only checked for existence; their columns are
    not compared.
    """
    diffs: list[SchemaDiff] = []

    for tbl in sorted(ddl_tables.keys() | db_tables.keys()):
        ddl_def = ddl_tables.get(tbl)
        db_def = db_tables.get(tbl)

        if ddl_def is None:
            diffs.append(SchemaDiff(kind=DiffKind.MISSING_TABLE, table=tbl))
            continue
        if db_def is None:
            diffs.append(SchemaDiff(kind=DiffKind.EXTRA_TABLE, table=tbl))
            continue
        if getattr(ddl_def, 'is_view', False):
            # Present on both sides and declared as a view — existence is enough.
            continue

        # Present on both sides: compare column by column.
        ddl_cols = ddl_def.columns
        db_cols = db_def.columns
        for col in sorted(ddl_cols.keys() | db_cols.keys()):
            ddl_col = ddl_cols.get(col)
            db_col = db_cols.get(col)

            if ddl_col is None:
                diffs.append(SchemaDiff(
                    kind=DiffKind.MISSING_COLUMN,
                    table=tbl,
                    column=col,
                    db_value=db_col.data_type,
                ))
                continue
            if db_col is None:
                diffs.append(SchemaDiff(
                    kind=DiffKind.EXTRA_COLUMN,
                    table=tbl,
                    column=col,
                    ddl_value=ddl_col.data_type,
                ))
                continue

            # Columns whose DDL type is "unknown" skip both value comparisons.
            if ddl_col.data_type == "unknown":
                continue

            if ddl_col.data_type != db_col.data_type:
                diffs.append(SchemaDiff(
                    kind=DiffKind.TYPE_MISMATCH,
                    table=tbl,
                    column=col,
                    ddl_value=ddl_col.data_type,
                    db_value=db_col.data_type,
                ))

            if ddl_col.nullable != db_col.nullable:
                diffs.append(SchemaDiff(
                    kind=DiffKind.NULLABLE_MISMATCH,
                    table=tbl,
                    column=col,
                    ddl_value="NULL" if ddl_col.nullable else "NOT NULL",
                    db_value="NULL" if db_col.nullable else "NOT NULL",
                ))

    return diffs
|
||||
|
||||
|
||||
def compare_schema(ddl_path: str, schema_name: str, pg_dsn: str) -> list[SchemaDiff]:
    """Run the full DDL-vs-database comparison for one schema.

    Args:
        ddl_path: path to the DDL file
        schema_name: database schema name
        pg_dsn: PostgreSQL connection string

    Returns:
        List of differences; empty when the DDL file does not exist
        (a message is printed to stderr in that case).
    """
    ddl_file = Path(ddl_path)
    if not ddl_file.exists():
        print(f"✗ DDL 文件不存在: {ddl_path}", file=sys.stderr)
        return []

    ddl_tables = parse_ddl(ddl_file.read_text(encoding="utf-8"), target_schema=schema_name)
    if not ddl_tables:
        print(f"⚠ DDL 文件中未解析到任何表: {ddl_path}", file=sys.stderr)

    # The database side is fetched even when the DDL parsed to nothing, so
    # that MISSING_TABLE diffs are still reported.
    db_tables = fetch_db_schema(pg_dsn, schema_name)
    return compare_tables(ddl_tables, db_tables)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 报告输出
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def print_report(diffs: list[SchemaDiff], schema_name: str, ddl_path: str) -> None:
    """Print the difference report to the console, grouped by table."""
    if not diffs:
        print(f"\n✓ {schema_name} ({ddl_path}): 无差异")
        return

    rule = "=" * 60
    print(f"\n{rule}")
    print(f" 差异报告: {schema_name} ← {ddl_path}")
    print(f" 共 {len(diffs)} 项差异")
    print(f"{rule}")

    # Group the differences by table name
    grouped: dict[str, list[SchemaDiff]] = {}
    for diff in diffs:
        grouped.setdefault(diff.table, []).append(diff)

    # Human-readable label per diff kind (hoisted out of the loop)
    labels = {
        DiffKind.MISSING_TABLE: "🔴 DDL 缺表",
        DiffKind.EXTRA_TABLE: "🟡 DDL 多表",
        DiffKind.MISSING_COLUMN: "🔴 DDL 缺字段",
        DiffKind.EXTRA_COLUMN: "🟡 DDL 多字段",
        DiffKind.TYPE_MISMATCH: "🟠 类型不一致",
        DiffKind.NULLABLE_MISMATCH: "🔵 可空不一致",
    }

    for table in sorted(grouped):
        print(f"\n ▸ {table}")
        for diff in grouped[table]:
            label = labels.get(diff.kind, diff.kind.value)
            line = f" {label}: {diff.column}" if diff.column else f" {label}"
            if diff.ddl_value is not None or diff.db_value is not None:
                line += f" (DDL={diff.ddl_value}, DB={diff.db_value})"
            print(line)

    print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------

# Predefined schema → DDL file mapping used by --all and bare --schema
DEFAULT_SCHEMA_MAP: dict[str, str] = {
    "billiards_ods": "database/schema_ODS_doc.sql",
    "billiards_dwd": "database/schema_dwd_doc.sql",
    "billiards_dws": "database/schema_dws.sql",
    "etl_admin": "database/schema_etl_admin.sql",
}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    Args:
        argv: argument list for argparse; None means sys.argv[1:]

    Returns:
        Process exit code: 0 when no differences were found, 1 on any
        differences or on usage/configuration errors.
    """
    parser = argparse.ArgumentParser(
        description="对比 DDL 文件与数据库实际表结构",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 对比单个 schema
  python scripts/compare_ddl_db.py --schema billiards_ods --ddl-path database/schema_ODS_doc.sql

  # 对比所有预定义 schema(从 .env 读取 PG_DSN)
  python scripts/compare_ddl_db.py --all

  # 指定连接字符串
  python scripts/compare_ddl_db.py --all --pg-dsn "postgresql://user:pass@host/db"
""",
    )
    parser.add_argument("--pg-dsn", help="PostgreSQL 连接字符串(默认从 PG_DSN 环境变量或 .env 读取)")
    parser.add_argument("--schema", help="要对比的 schema 名称")
    parser.add_argument("--ddl-path", help="DDL 文件路径")
    parser.add_argument("--all", action="store_true", help="对比所有预定义 schema")

    args = parser.parse_args(argv)

    # Load .env when python-dotenv is available (optional dependency)
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    pg_dsn = args.pg_dsn or os.environ.get("PG_DSN")
    if not pg_dsn:
        print("✗ 未提供 PG_DSN,请通过 --pg-dsn 参数或 PG_DSN 环境变量指定", file=sys.stderr)
        return 1

    # Determine the (schema, ddl_path) pairs to compare
    pairs: list[tuple[str, str]] = []
    if args.all:
        for schema, ddl in DEFAULT_SCHEMA_MAP.items():
            pairs.append((schema, ddl))
    elif args.schema and args.ddl_path:
        pairs.append((args.schema, args.ddl_path))
    elif args.schema:
        # Fall back to the predefined mapping for a bare --schema
        ddl = DEFAULT_SCHEMA_MAP.get(args.schema)
        if ddl:
            pairs.append((args.schema, ddl))
        else:
            print(f"✗ 未知 schema '{args.schema}',请通过 --ddl-path 指定 DDL 文件", file=sys.stderr)
            return 1
    else:
        parser.print_help()
        return 1

    total_diffs = 0
    for schema_name, ddl_path in pairs:
        if not Path(ddl_path).exists():
            print(f"⚠ DDL 文件不存在,跳过: {ddl_path}", file=sys.stderr)
            continue

        try:
            diffs = compare_schema(ddl_path, schema_name, pg_dsn)
        except Exception as e:
            # Keep going so one failing schema does not abort the whole run
            print(f"✗ 对比 {schema_name} 时出错: {e}", file=sys.stderr)
            continue

        print_report(diffs, schema_name, ddl_path)
        total_diffs += len(diffs)

    if total_diffs > 0:
        print(f"共发现 {total_diffs} 项差异")
        return 1

    print("所有 schema 对比通过,无差异 ✓")
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (0 = no differences, 1 = differences or errors).
    sys.exit(main())
|
||||
Reference in New Issue
Block a user