427 lines
15 KiB
Python
427 lines
15 KiB
Python
"""DDL 解析器和对比逻辑的单元测试。

测试范围:
- DDL 解析器正确提取表名、字段、类型、约束
- 类型标准化逻辑
- 差异检测逻辑识别各类差异
- 边界情况:空文件、COMMENT 含特殊字符
"""

import pytest
|
||
|
||
from scripts.compare_ddl_db import (
|
||
ColumnDef,
|
||
DiffKind,
|
||
SchemaDiff,
|
||
TableDef,
|
||
compare_tables,
|
||
normalize_type,
|
||
parse_ddl,
|
||
)


# =========================================================================
# normalize_type 测试
# =========================================================================

class TestNormalizeType:
    """Tests for ``normalize_type``.

    Every case maps a raw SQL type spelling to the canonical lower-case form
    that the comparison logic works with.
    """

    # Parameter-less types, including PostgreSQL aliases (INT8, INT4, BOOL...).
    _SIMPLE_CASES = [
        ("BIGINT", "bigint"),
        ("INT8", "bigint"),
        ("INTEGER", "integer"),
        ("INT", "integer"),
        ("INT4", "integer"),
        ("SMALLINT", "smallint"),
        ("INT2", "smallint"),
        ("BOOLEAN", "boolean"),
        ("BOOL", "boolean"),
        ("TEXT", "text"),
        ("JSONB", "jsonb"),
        ("JSON", "json"),
        ("DATE", "date"),
        ("BYTEA", "bytea"),
        ("UUID", "uuid"),
    ]

    # NUMERIC / DECIMAL with and without precision and scale; DECIMAL folds to numeric.
    _NUMERIC_CASES = [
        ("NUMERIC(18,2)", "numeric(18,2)"),
        ("NUMERIC(10,6)", "numeric(10,6)"),
        ("DECIMAL(5,2)", "numeric(5,2)"),
        ("NUMERIC(10)", "numeric(10)"),
        ("NUMERIC", "numeric"),
    ]

    # Character types; the SQL-standard long spellings fold to varchar / char.
    _STRING_CASES = [
        ("VARCHAR(50)", "varchar(50)"),
        ("CHARACTER VARYING(100)", "varchar(100)"),
        ("VARCHAR", "varchar"),
        ("CHAR(1)", "char(1)"),
        ("CHARACTER(10)", "char(10)"),
    ]

    # Timestamp spellings with and without an explicit time-zone clause.
    _TIMESTAMP_CASES = [
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP WITHOUT TIME ZONE", "timestamp"),
        ("TIMESTAMPTZ", "timestamptz"),
        ("TIMESTAMP WITH TIME ZONE", "timestamptz"),
    ]

    # serial pseudo-types and the integer type each one is stored as.
    _SERIAL_CASES = [
        ("BIGSERIAL", "bigint"),
        ("SERIAL", "integer"),
        ("SMALLSERIAL", "smallint"),
    ]

    @pytest.mark.parametrize(("raw", "expected"), _SIMPLE_CASES)
    def test_simple_types(self, raw, expected):
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize(("raw", "expected"), _NUMERIC_CASES)
    def test_numeric_types(self, raw, expected):
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize(("raw", "expected"), _STRING_CASES)
    def test_string_types(self, raw, expected):
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize(("raw", "expected"), _TIMESTAMP_CASES)
    def test_timestamp_types(self, raw, expected):
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize(("raw", "expected"), _SERIAL_CASES)
    def test_serial_types(self, raw, expected):
        """The serial family must map to its underlying integer type."""
        assert normalize_type(raw) == expected

    def test_case_insensitive(self):
        """Normalisation does not depend on the case of the input spelling."""
        pairs = (
            ("bigint", "BIGINT"),
            ("Numeric(18,2)", "NUMERIC(18,2)"),
        )
        for mixed, upper in pairs:
            assert normalize_type(mixed) == normalize_type(upper)


# =========================================================================
# parse_ddl 测试
# =========================================================================

class TestParseDdl:
    """Tests for ``parse_ddl``.

    ``parse_ddl(sql, target_schema=...)`` returns a dict keyed by bare table
    name (no schema prefix) whose values expose ``columns`` (keyed by column
    name) and ``pk_columns``.
    """

    def test_basic_create_table(self):
        """Basic CREATE TABLE: columns, types, NOT NULL and a table-level PRIMARY KEY."""
        sql = """
        CREATE TABLE IF NOT EXISTS myschema.users (
            id BIGINT NOT NULL,
            name TEXT,
            age INTEGER,
            PRIMARY KEY (id)
        );
        """
        tables = parse_ddl(sql, target_schema="myschema")
        assert "users" in tables
        t = tables["users"]
        assert len(t.columns) == 3
        assert t.pk_columns == ["id"]
        assert t.columns["id"].data_type == "bigint"
        assert t.columns["id"].nullable is False
        assert t.columns["name"].data_type == "text"
        assert t.columns["name"].nullable is True
        assert t.columns["age"].data_type == "integer"

    def test_inline_primary_key(self):
        """An inline ``PRIMARY KEY`` on the column definition marks the column as PK."""
        sql = """
        CREATE TABLE test_schema.items (
            item_id BIGSERIAL PRIMARY KEY,
            label TEXT NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="test_schema")
        t = tables["items"]
        assert t.columns["item_id"].is_pk is True
        # BIGSERIAL is normalised to its underlying type, bigint.
        assert t.columns["item_id"].data_type == "bigint"
        assert t.columns["item_id"].nullable is False
        assert t.columns["label"].nullable is False

    def test_composite_primary_key(self):
        """A multi-column table-level PRIMARY KEY."""
        sql = """
        CREATE TABLE IF NOT EXISTS billiards_ods.member_profiles (
            id BIGINT,
            content_hash TEXT NOT NULL,
            name TEXT,
            PRIMARY KEY (id, content_hash)
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        t = tables["member_profiles"]
        assert t.pk_columns == ["id", "content_hash"]
        assert t.columns["id"].is_pk is True
        assert t.columns["content_hash"].is_pk is True
        # A PK column is implicitly NOT NULL even without an explicit constraint.
        assert t.columns["id"].nullable is False

    def test_various_data_types(self):
        """A range of PostgreSQL data types, with and without DEFAULT / NOT NULL."""
        sql = """
        CREATE TABLE s.t (
            a BIGINT,
            b VARCHAR(50),
            c NUMERIC(18,2),
            d TIMESTAMP,
            e TIMESTAMPTZ DEFAULT now(),
            f BOOLEAN DEFAULT TRUE,
            g JSONB NOT NULL,
            h TEXT,
            i INTEGER
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["a"].data_type == "bigint"
        assert t.columns["b"].data_type == "varchar(50)"
        assert t.columns["c"].data_type == "numeric(18,2)"
        assert t.columns["d"].data_type == "timestamp"
        assert t.columns["e"].data_type == "timestamptz"
        assert t.columns["f"].data_type == "boolean"
        assert t.columns["g"].data_type == "jsonb"
        assert t.columns["g"].nullable is False
        assert t.columns["h"].data_type == "text"
        assert t.columns["i"].data_type == "integer"

    def test_without_schema_prefix(self):
        """Unqualified CREATE TABLE, e.g. after ``SET search_path`` in the DWD DDL."""
        sql = """
        SET search_path TO billiards_dwd;
        CREATE TABLE IF NOT EXISTS dim_site (
            site_id BIGINT,
            shop_name TEXT,
            PRIMARY KEY (site_id)
        );
        """
        # With target_schema given, tables without a schema prefix must also be accepted.
        tables = parse_ddl(sql, target_schema="billiards_dwd")
        assert "dim_site" in tables

    def test_schema_filter(self):
        """Only tables qualified with the target schema are kept."""
        sql = """
        CREATE TABLE schema_a.t1 (id BIGINT);
        CREATE TABLE schema_b.t2 (id BIGINT);
        """
        tables = parse_ddl(sql, target_schema="schema_a")
        assert "t1" in tables
        assert "t2" not in tables

    def test_empty_ddl(self):
        """An empty DDL text yields an empty dict."""
        tables = parse_ddl("", target_schema="any")
        assert tables == {}

    def test_comments_ignored(self):
        """Line, block and trailing SQL comments do not affect parsing."""
        sql = """
        -- 这是注释
        /* 块注释 */
        CREATE TABLE s.t (
            id BIGINT, -- 行内注释
            name TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        assert "t" in tables
        assert len(tables["t"].columns) == 2

    def test_comment_on_statements_ignored(self):
        """COMMENT ON statements, including quotes and parentheses in the text, are ignored."""
        sql = """
        CREATE TABLE billiards_ods.test_table (
            id BIGINT NOT NULL,
            name TEXT,
            PRIMARY KEY (id)
        );
        COMMENT ON TABLE billiards_ods.test_table IS '测试表:含特殊字符 ''引号'' 和 (括号)';
        COMMENT ON COLUMN billiards_ods.test_table.id IS '【说明】主键 ID。【示例】12345。';
        COMMENT ON COLUMN billiards_ods.test_table.name IS '【说明】名称,含 ''单引号'' 和 "双引号"。';
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        assert "test_table" in tables
        assert len(tables["test_table"].columns) == 2

    def test_drop_then_create(self):
        """A DROP TABLE preceding the CREATE TABLE does not prevent parsing."""
        sql = """
        DROP TABLE IF EXISTS billiards_dws.cfg_test CASCADE;
        CREATE TABLE billiards_dws.cfg_test (
            id SERIAL PRIMARY KEY,
            value TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_dws")
        assert "cfg_test" in tables
        assert tables["cfg_test"].columns["id"].data_type == "integer"

    def test_default_values_parsed(self):
        """DEFAULT clauses do not disturb type or nullability parsing."""
        sql = """
        CREATE TABLE s.t (
            enabled BOOLEAN DEFAULT TRUE,
            created_at TIMESTAMPTZ DEFAULT now(),
            count INTEGER DEFAULT 0 NOT NULL,
            label VARCHAR(20) NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["enabled"].data_type == "boolean"
        assert t.columns["enabled"].nullable is True
        assert t.columns["created_at"].data_type == "timestamptz"
        assert t.columns["count"].data_type == "integer"
        assert t.columns["count"].nullable is False
        assert t.columns["label"].data_type == "varchar(20)"
        assert t.columns["label"].nullable is False

    def test_constraint_lines_skipped(self):
        """A table-level constraint line (here ``UNIQUE (...)``) is not parsed as a column."""
        sql = """
        CREATE TABLE etl_admin.etl_task (
            task_id BIGSERIAL PRIMARY KEY,
            task_code TEXT NOT NULL,
            store_id BIGINT NOT NULL,
            UNIQUE (task_code, store_id)
        );
        """
        tables = parse_ddl(sql, target_schema="etl_admin")
        t = tables["etl_task"]
        assert len(t.columns) == 3
        assert "task_id" in t.columns
        assert "task_code" in t.columns
        assert "store_id" in t.columns

    def test_real_ods_ddl_parseable(self):
        """The repository's real ODS DDL file parses into tables that all have columns."""
        from pathlib import Path

        rel_path = Path("database/schema_ODS_doc.sql")
        # Try the CWD first (the original lookup), then every ancestor of this
        # test file, so the test still runs when pytest is started from a
        # directory other than the repository root instead of silently skipping.
        search_roots = [Path.cwd(), *Path(__file__).resolve().parents]
        ddl_path = next(
            (root / rel_path for root in search_roots if (root / rel_path).is_file()),
            None,
        )
        if ddl_path is None:
            pytest.skip("DDL 文件不存在")
        sql = ddl_path.read_text(encoding="utf-8")
        tables = parse_ddl(sql, target_schema="billiards_ods")
        # The ODS schema is expected to contain at least 20 tables.
        assert len(tables) >= 20
        # Every table must have at least one column; name the offender on failure.
        for table_name, tbl in tables.items():
            assert tbl.columns, f"{table_name} has no columns"


# =========================================================================
# compare_tables 测试
# =========================================================================

class TestCompareTables:
    """Tests for ``compare_tables(ddl_tables, db_tables)``.

    Diff-kind naming exercised below:
    - MISSING_*: present in the database but absent from the DDL.
    - EXTRA_*: present in the DDL but absent from the database.
    """

    def _make_table(self, name: str, columns: dict[str, tuple[str, bool]],
                    pk: list[str] | None = None) -> TableDef:
        """Build a ``TableDef`` concisely.

        Args:
            name: Table name.
            columns: ``{col_name: (data_type, nullable)}``.
            pk: Primary-key column names; those columns get ``is_pk=True``.

        Returns:
            A ``TableDef`` whose ``pk_columns`` is a copy of ``pk`` (or ``[]``).
        """
        # Copy so the TableDef never aliases the caller's list.
        pk_columns = list(pk) if pk else []
        cols = {
            col_name: ColumnDef(
                name=col_name,
                data_type=dtype,
                nullable=nullable,
                is_pk=col_name in pk_columns,
            )
            for col_name, (dtype, nullable) in columns.items()
        }
        return TableDef(name=name, columns=cols, pk_columns=pk_columns)

    def test_no_diff(self):
        """Identical schemas produce an empty list."""
        t = self._make_table("t", {"id": ("bigint", False), "name": ("text", True)}, pk=["id"])
        diffs = compare_tables({"t": t}, {"t": t})
        assert diffs == []

    def test_missing_table(self):
        """A table in the database but not in the DDL yields MISSING_TABLE."""
        ddl = {}
        db = {"extra": self._make_table("extra", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_TABLE
        assert diffs[0].table == "extra"

    def test_extra_table(self):
        """A table in the DDL but not in the database yields EXTRA_TABLE."""
        ddl = {"orphan": self._make_table("orphan", {"id": ("bigint", False)})}
        db = {}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_TABLE
        assert diffs[0].table == "orphan"

    def test_missing_column(self):
        """A column in the database but not in the DDL yields MISSING_COLUMN."""
        ddl = {"t": self._make_table("t", {"id": ("bigint", False)})}
        db = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "new_col": ("text", True),
        })}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_COLUMN
        assert diffs[0].column == "new_col"

    def test_extra_column(self):
        """A column in the DDL but not in the database yields EXTRA_COLUMN."""
        ddl = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "old_col": ("text", True),
        })}
        db = {"t": self._make_table("t", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_COLUMN
        assert diffs[0].column == "old_col"

    def test_type_mismatch(self):
        """Differing column types yield TYPE_MISMATCH carrying both values."""
        ddl = {"t": self._make_table("t", {"val": ("text", True)})}
        db = {"t": self._make_table("t", {"val": ("varchar(100)", True)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.TYPE_MISMATCH
        assert diffs[0].ddl_value == "text"
        assert diffs[0].db_value == "varchar(100)"

    def test_nullable_mismatch(self):
        """Differing nullability yields NULLABLE_MISMATCH rendered as NULL / NOT NULL."""
        ddl = {"t": self._make_table("t", {"val": ("text", True)})}
        db = {"t": self._make_table("t", {"val": ("text", False)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.NULLABLE_MISMATCH
        assert diffs[0].ddl_value == "NULL"
        assert diffs[0].db_value == "NOT NULL"

    def test_multiple_diffs(self):
        """Several kinds of difference at once: each reported exactly once, on the right object."""
        ddl = {
            "t1": self._make_table("t1", {
                "id": ("bigint", False),
                "extra": ("text", True),
            }),
            "ddl_only": self._make_table("ddl_only", {"x": ("integer", True)}),
        }
        db = {
            "t1": self._make_table("t1", {
                "id": ("integer", False),     # TYPE_MISMATCH
                "missing": ("text", True),    # MISSING_COLUMN
            }),
            "db_only": self._make_table("db_only", {"y": ("text", True)}),
        }
        diffs = compare_tables(ddl, db)
        # Exactly one diff per discrepancy: a set of kinds alone would hide
        # duplicated diffs or diffs attributed to the wrong table/column.
        assert len(diffs) == 5
        by_kind = {d.kind: d for d in diffs}
        assert set(by_kind) == {
            DiffKind.MISSING_TABLE,
            DiffKind.EXTRA_TABLE,
            DiffKind.MISSING_COLUMN,
            DiffKind.EXTRA_COLUMN,
            DiffKind.TYPE_MISMATCH,
        }
        assert by_kind[DiffKind.MISSING_TABLE].table == "db_only"
        assert by_kind[DiffKind.EXTRA_TABLE].table == "ddl_only"
        assert by_kind[DiffKind.MISSING_COLUMN].column == "missing"
        assert by_kind[DiffKind.EXTRA_COLUMN].column == "extra"
        assert by_kind[DiffKind.TYPE_MISMATCH].column == "id"

    def test_empty_both(self):
        """Two empty schemas produce an empty list."""
        assert compare_tables({}, {}) == []