"""Unit tests for the DDL parser and the schema comparison logic.

Covered:
- the DDL parser extracts table names, columns, types and constraints
- type normalization
- diff detection recognizes each kind of difference
- edge cases: empty input, COMMENT text containing special characters
"""

from pathlib import Path

import pytest

from scripts.compare_ddl_db import (
    ColumnDef,
    DiffKind,
    SchemaDiff,
    TableDef,
    compare_tables,
    normalize_type,
    parse_ddl,
)


# =========================================================================
# normalize_type tests
# =========================================================================

class TestNormalizeType:
    """Tests for type normalization."""

    @pytest.mark.parametrize("raw,expected", [
        ("BIGINT", "bigint"),
        ("INT8", "bigint"),
        ("INTEGER", "integer"),
        ("INT", "integer"),
        ("INT4", "integer"),
        ("SMALLINT", "smallint"),
        ("INT2", "smallint"),
        ("BOOLEAN", "boolean"),
        ("BOOL", "boolean"),
        ("TEXT", "text"),
        ("JSONB", "jsonb"),
        ("JSON", "json"),
        ("DATE", "date"),
        ("BYTEA", "bytea"),
        ("UUID", "uuid"),
    ])
    def test_simple_types(self, raw, expected):
        """Parameterless types and their aliases map to one canonical name."""
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("NUMERIC(18,2)", "numeric(18,2)"),
        ("NUMERIC(10,6)", "numeric(10,6)"),
        ("DECIMAL(5,2)", "numeric(5,2)"),
        ("NUMERIC(10)", "numeric(10)"),
        ("NUMERIC", "numeric"),
    ])
    def test_numeric_types(self, raw, expected):
        """NUMERIC/DECIMAL keep precision and scale; DECIMAL becomes numeric."""
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("VARCHAR(50)", "varchar(50)"),
        ("CHARACTER VARYING(100)", "varchar(100)"),
        ("VARCHAR", "varchar"),
        ("CHAR(1)", "char(1)"),
        ("CHARACTER(10)", "char(10)"),
    ])
    def test_string_types(self, raw, expected):
        """Long-form character type names collapse to varchar/char."""
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP WITHOUT TIME ZONE", "timestamp"),
        ("TIMESTAMPTZ", "timestamptz"),
        ("TIMESTAMP WITH TIME ZONE", "timestamptz"),
    ])
    def test_timestamp_types(self, raw, expected):
        """Timestamp spellings normalize to timestamp / timestamptz."""
        assert normalize_type(raw) == expected

    @pytest.mark.parametrize("raw,expected", [
        ("BIGSERIAL", "bigint"),
        ("SERIAL", "integer"),
        ("SMALLSERIAL", "smallint"),
    ])
    def test_serial_types(self, raw, expected):
        """The serial family maps to its underlying integer type."""
        assert normalize_type(raw) == expected

    def test_case_insensitive(self):
        """Normalization does not depend on the case of the input."""
        assert normalize_type("bigint") == normalize_type("BIGINT")
        assert normalize_type("Numeric(18,2)") == normalize_type("NUMERIC(18,2)")


# =========================================================================
# parse_ddl tests
# =========================================================================

class TestParseDdl:
    """Tests for the DDL parser."""

    def test_basic_create_table(self):
        """Parse a basic CREATE TABLE statement."""
        sql = """
        CREATE TABLE IF NOT EXISTS myschema.users (
            id BIGINT NOT NULL,
            name TEXT,
            age INTEGER,
            PRIMARY KEY (id)
        );
        """
        tables = parse_ddl(sql, target_schema="myschema")
        assert "users" in tables
        t = tables["users"]
        assert len(t.columns) == 3
        assert t.pk_columns == ["id"]
        assert t.columns["id"].data_type == "bigint"
        assert t.columns["id"].nullable is False
        assert t.columns["name"].data_type == "text"
        assert t.columns["name"].nullable is True
        assert t.columns["age"].data_type == "integer"

    def test_inline_primary_key(self):
        """Inline PRIMARY KEY constraint on a column."""
        sql = """
        CREATE TABLE test_schema.items (
            item_id BIGSERIAL PRIMARY KEY,
            label TEXT NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="test_schema")
        t = tables["items"]
        assert t.columns["item_id"].is_pk is True
        # BIGSERIAL → bigint
        assert t.columns["item_id"].data_type == "bigint"
        assert t.columns["item_id"].nullable is False
        assert t.columns["label"].nullable is False

    def test_composite_primary_key(self):
        """Composite primary key declared at table level."""
        sql = """
        CREATE TABLE IF NOT EXISTS billiards_ods.member_profiles (
            id BIGINT,
            content_hash TEXT NOT NULL,
            name TEXT,
            PRIMARY KEY (id, content_hash)
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        t = tables["member_profiles"]
        assert t.pk_columns == ["id", "content_hash"]
        assert t.columns["id"].is_pk is True
        assert t.columns["content_hash"].is_pk is True
        # A primary key implies NOT NULL.
        assert t.columns["id"].nullable is False

    def test_various_data_types(self):
        """Assorted PostgreSQL data types."""
        sql = """
        CREATE TABLE s.t (
            a BIGINT,
            b VARCHAR(50),
            c NUMERIC(18,2),
            d TIMESTAMP,
            e TIMESTAMPTZ DEFAULT now(),
            f BOOLEAN DEFAULT TRUE,
            g JSONB NOT NULL,
            h TEXT,
            i INTEGER
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["a"].data_type == "bigint"
        assert t.columns["b"].data_type == "varchar(50)"
        assert t.columns["c"].data_type == "numeric(18,2)"
        assert t.columns["d"].data_type == "timestamp"
        assert t.columns["e"].data_type == "timestamptz"
        assert t.columns["f"].data_type == "boolean"
        assert t.columns["g"].data_type == "jsonb"
        assert t.columns["g"].nullable is False
        assert t.columns["h"].data_type == "text"
        assert t.columns["i"].data_type == "integer"

    def test_without_schema_prefix(self):
        """CREATE TABLE without a schema prefix (e.g. after SET search_path)."""
        sql = """
        SET search_path TO billiards_dwd;
        CREATE TABLE IF NOT EXISTS dim_site (
            site_id BIGINT,
            shop_name TEXT,
            PRIMARY KEY (site_id)
        );
        """
        # With target_schema given, unprefixed tables must be accepted too.
        tables = parse_ddl(sql, target_schema="billiards_dwd")
        assert "dim_site" in tables

    def test_schema_filter(self):
        """Schema filter: only tables of the target schema are kept."""
        sql = """
        CREATE TABLE schema_a.t1 (id BIGINT);
        CREATE TABLE schema_b.t2 (id BIGINT);
        """
        tables = parse_ddl(sql, target_schema="schema_a")
        assert "t1" in tables
        assert "t2" not in tables

    def test_empty_ddl(self):
        """An empty DDL input yields an empty dict."""
        tables = parse_ddl("", target_schema="any")
        assert tables == {}

    def test_comments_ignored(self):
        """SQL comments do not affect parsing."""
        # The fixture must stay multi-line: a `--` comment runs to end of line.
        sql = """
        -- 这是注释
        /* 块注释 */
        CREATE TABLE s.t (
            id BIGINT, -- 行内注释
            name TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        assert "t" in tables
        assert len(tables["t"].columns) == 2

    def test_comment_on_statements_ignored(self):
        """COMMENT ON statements do not affect table parsing."""
        sql = """
        CREATE TABLE billiards_ods.test_table (
            id BIGINT NOT NULL,
            name TEXT,
            PRIMARY KEY (id)
        );
        COMMENT ON TABLE billiards_ods.test_table IS '测试表:含特殊字符 ''引号'' 和 (括号)';
        COMMENT ON COLUMN billiards_ods.test_table.id IS '【说明】主键 ID。【示例】12345。';
        COMMENT ON COLUMN billiards_ods.test_table.name IS '【说明】名称,含 ''单引号'' 和 "双引号"。';
        """
        tables = parse_ddl(sql, target_schema="billiards_ods")
        assert "test_table" in tables
        assert len(tables["test_table"].columns) == 2

    def test_drop_then_create(self):
        """CREATE TABLE following a DROP TABLE is parsed normally."""
        sql = """
        DROP TABLE IF EXISTS billiards_dws.cfg_test CASCADE;
        CREATE TABLE billiards_dws.cfg_test (
            id SERIAL PRIMARY KEY,
            value TEXT
        );
        """
        tables = parse_ddl(sql, target_schema="billiards_dws")
        assert "cfg_test" in tables
        assert tables["cfg_test"].columns["id"].data_type == "integer"

    def test_default_values_parsed(self):
        """DEFAULT clauses do not disturb type or constraint parsing."""
        sql = """
        CREATE TABLE s.t (
            enabled BOOLEAN DEFAULT TRUE,
            created_at TIMESTAMPTZ DEFAULT now(),
            count INTEGER DEFAULT 0 NOT NULL,
            label VARCHAR(20) NOT NULL
        );
        """
        tables = parse_ddl(sql, target_schema="s")
        t = tables["t"]
        assert t.columns["enabled"].data_type == "boolean"
        assert t.columns["enabled"].nullable is True
        assert t.columns["created_at"].data_type == "timestamptz"
        assert t.columns["count"].data_type == "integer"
        assert t.columns["count"].nullable is False
        assert t.columns["label"].data_type == "varchar(20)"
        assert t.columns["label"].nullable is False

    def test_constraint_lines_skipped(self):
        """Table-level constraint lines (CONSTRAINT, UNIQUE, FOREIGN KEY) are skipped."""
        sql = """
        CREATE TABLE etl_admin.etl_task (
            task_id BIGSERIAL PRIMARY KEY,
            task_code TEXT NOT NULL,
            store_id BIGINT NOT NULL,
            UNIQUE (task_code, store_id)
        );
        """
        tables = parse_ddl(sql, target_schema="etl_admin")
        t = tables["etl_task"]
        assert len(t.columns) == 3
        assert "task_id" in t.columns
        assert "task_code" in t.columns
        assert "store_id" in t.columns

    def test_real_ods_ddl_parseable(self):
        """The real ODS DDL file, when present, can be parsed."""
        # Path is relative to the working directory; skip when it is absent.
        ddl_path = Path("database/schema_ODS_doc.sql")
        if not ddl_path.exists():
            pytest.skip("DDL 文件不存在")
        sql = ddl_path.read_text(encoding="utf-8")
        tables = parse_ddl(sql, target_schema="billiards_ods")
        # Expect at least 20 tables.
        assert len(tables) >= 20
        # Every table must have at least one column.
        for tbl in tables.values():
            assert len(tbl.columns) > 0


# =========================================================================
# compare_tables tests
# =========================================================================

class TestCompareTables:
    """Tests for the diff detection logic."""

    def _make_table(
        self,
        name: str,
        columns: dict[str, tuple[str, bool]],
        pk: list[str] | None = None,
    ) -> TableDef:
        """Build a TableDef quickly for use in tests.

        Args:
            name: Table name.
            columns: Mapping of ``{col_name: (data_type, nullable)}``.
            pk: Primary-key column names; ``None`` means no primary key.

        Returns:
            A TableDef whose ColumnDef entries have ``is_pk`` set for every
            column listed in ``pk``.
        """
        pk_columns = pk or []
        cols = {
            col_name: ColumnDef(
                name=col_name,
                data_type=dtype,
                nullable=nullable,
                is_pk=col_name in pk_columns,
            )
            for col_name, (dtype, nullable) in columns.items()
        }
        return TableDef(name=name, columns=cols, pk_columns=pk_columns)

    def test_no_diff(self):
        """Identical definitions produce an empty diff list."""
        t = self._make_table(
            "t",
            {"id": ("bigint", False), "name": ("text", True)},
            pk=["id"],
        )
        diffs = compare_tables({"t": t}, {"t": t})
        assert diffs == []

    def test_missing_table(self):
        """Table in the database but not in the DDL → MISSING_TABLE."""
        ddl = {}
        db = {"extra": self._make_table("extra", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_TABLE
        assert diffs[0].table == "extra"

    def test_extra_table(self):
        """Table in the DDL but not in the database → EXTRA_TABLE."""
        ddl = {"orphan": self._make_table("orphan", {"id": ("bigint", False)})}
        db = {}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_TABLE
        assert diffs[0].table == "orphan"

    def test_missing_column(self):
        """Column in the database but not in the DDL → MISSING_COLUMN."""
        ddl = {"t": self._make_table("t", {"id": ("bigint", False)})}
        db = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "new_col": ("text", True),
        })}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.MISSING_COLUMN
        assert diffs[0].column == "new_col"

    def test_extra_column(self):
        """Column in the DDL but not in the database → EXTRA_COLUMN."""
        ddl = {"t": self._make_table("t", {
            "id": ("bigint", False),
            "old_col": ("text", True),
        })}
        db = {"t": self._make_table("t", {"id": ("bigint", False)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.EXTRA_COLUMN
        assert diffs[0].column == "old_col"

    def test_type_mismatch(self):
        """Differing column types → TYPE_MISMATCH."""
        ddl = {"t": self._make_table("t", {"val": ("text", True)})}
        db = {"t": self._make_table("t", {"val": ("varchar(100)", True)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.TYPE_MISMATCH
        assert diffs[0].ddl_value == "text"
        assert diffs[0].db_value == "varchar(100)"

    def test_nullable_mismatch(self):
        """Differing nullability → NULLABLE_MISMATCH."""
        ddl = {"t": self._make_table("t", {"val": ("text", True)})}
        db = {"t": self._make_table("t", {"val": ("text", False)})}
        diffs = compare_tables(ddl, db)
        assert len(diffs) == 1
        assert diffs[0].kind == DiffKind.NULLABLE_MISMATCH
        assert diffs[0].ddl_value == "NULL"
        assert diffs[0].db_value == "NOT NULL"

    def test_multiple_diffs(self):
        """Several kinds of difference present at once."""
        ddl = {
            "t1": self._make_table("t1", {
                "id": ("bigint", False),
                "extra": ("text", True),
            }),
            "ddl_only": self._make_table("ddl_only", {"x": ("integer", True)}),
        }
        db = {
            "t1": self._make_table("t1", {
                "id": ("integer", False),  # TYPE_MISMATCH
                "missing": ("text", True),  # MISSING_COLUMN
            }),
            "db_only": self._make_table("db_only", {"y": ("text", True)}),
        }
        diffs = compare_tables(ddl, db)
        kinds = {d.kind for d in diffs}
        assert DiffKind.MISSING_TABLE in kinds  # db_only
        assert DiffKind.EXTRA_TABLE in kinds  # ddl_only
        assert DiffKind.MISSING_COLUMN in kinds  # t1.missing
        assert DiffKind.EXTRA_COLUMN in kinds  # t1.extra
        assert DiffKind.TYPE_MISMATCH in kinds  # t1.id

    def test_empty_both(self):
        """Both sides empty yields an empty diff list."""
        assert compare_tables({}, {}) == []