322 lines
11 KiB
Python
322 lines
11 KiB
Python
# -*- coding: utf-8 -*-
"""Unit tests for the database viewer routes.

Covers four endpoints:
- GET /api/db/schemas
- GET /api/db/schemas/{name}/tables
- GET /api/db/tables/{schema}/{table}/columns
- POST /api/db/query

Database connections are bypassed with mocks so the tests focus on routing logic.
"""
|
||
|
||
import os
|
||
|
||
# Provide a JWT secret before the app modules below are imported; setdefault
# keeps any value that is already present in the environment.
# NOTE(review): presumably the app reads JWT_SECRET_KEY at import time — confirm.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||
|
||
from unittest.mock import patch, MagicMock
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
from psycopg2 import errors as pg_errors
|
||
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.main import app
|
||
|
||
# Fixed identity handed back by the auth override for requests in this module.
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||
|
||
|
||
def _override_auth():
    """Stand-in for ``get_current_user`` that always yields the fixed test user."""
    return _TEST_USER
|
||
|
||
|
||
# Swap the real authentication dependency for the stub on the shared app
# object. This is a module-level side effect; TestDbViewerAuth temporarily
# removes it to exercise the real dependency.
app.dependency_overrides[get_current_user] = _override_auth
|
||
# Single TestClient instance shared by every test in this module.
client = TestClient(app)
|
||
|
||
# Dotted path handed to ``patch`` so tests replace the connection factory as
# it is looked up from the db_viewer router module.
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
|
||
|
||
|
||
def _make_mock_conn(rows, description=None):
|
||
"""构造 mock 数据库连接,cursor 返回指定行和列描述。"""
|
||
mock_conn = MagicMock()
|
||
mock_cur = MagicMock()
|
||
mock_cur.fetchall.return_value = rows
|
||
mock_cur.fetchmany.return_value = rows
|
||
mock_cur.description = description
|
||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||
return mock_conn, mock_cur
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/db/schemas
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestListSchemas:
    """Tests for ``GET /api/db/schemas``."""

    @patch(_MOCK_CONN)
    def test_returns_schema_list(self, mock_get_conn):
        """Cursor rows come back as schema entries and the connection is closed."""
        schema_rows = [("dwd",), ("dws",), ("ods",)]
        fake_conn, _ = _make_mock_conn(schema_rows)
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/schemas")

        assert response.status_code == 200
        payload = response.json()
        assert len(payload) == 3
        assert payload[0]["name"] == "dwd"
        assert payload[2]["name"] == "ods"

        # The connection is requested with the caller's site_id.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        fake_conn.close.assert_called_once()

    @patch(_MOCK_CONN)
    def test_empty_schemas(self, mock_get_conn):
        """An empty cursor result yields an empty JSON list."""
        fake_conn, _ = _make_mock_conn([])
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/schemas")

        assert response.status_code == 200
        assert response.json() == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/db/schemas/{name}/tables
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestListTables:
    """Tests for ``GET /api/db/schemas/{name}/tables``."""

    @patch(_MOCK_CONN)
    def test_returns_tables_with_row_count(self, mock_get_conn):
        """Each cursor row is exposed as a table name plus its row count."""
        expected = [
            ("dim_member", 1500),
            ("fact_order", 32000),
        ]
        fake_conn, _ = _make_mock_conn(expected)
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/schemas/dwd/tables")

        assert response.status_code == 200
        payload = response.json()
        assert len(payload) == 2
        for entry, (table_name, row_count) in zip(payload, expected):
            assert entry["name"] == table_name
            assert entry["row_count"] == row_count

    @patch(_MOCK_CONN)
    def test_null_row_count(self, mock_get_conn):
        """pg_stat_user_tables may have no statistics, so row_count is None."""
        fake_conn, _ = _make_mock_conn([("new_table", None)])
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/schemas/ods/tables")

        assert response.status_code == 200
        first_entry = response.json()[0]
        assert first_entry["row_count"] is None

    @patch(_MOCK_CONN)
    def test_empty_schema(self, mock_get_conn):
        """A schema with no cursor rows yields an empty JSON list."""
        fake_conn, _ = _make_mock_conn([])
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/schemas/empty_schema/tables")

        assert response.status_code == 200
        assert response.json() == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/db/tables/{schema}/{table}/columns
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestListColumns:
    """Tests for ``GET /api/db/tables/{schema}/{table}/columns``."""

    @patch(_MOCK_CONN)
    def test_returns_column_definitions(self, mock_get_conn):
        """Column rows are mapped to name, type, nullability and default."""
        column_rows = [
            ("id", "bigint", "NO", None),
            ("name", "character varying", "YES", None),
            ("created_at", "timestamp with time zone", "NO", "now()"),
        ]
        fake_conn, _ = _make_mock_conn(column_rows)
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/tables/dwd/dim_member/columns")

        assert response.status_code == 200
        payload = response.json()
        assert len(payload) == 3

        id_col = payload[0]
        assert id_col["name"] == "id"
        assert id_col["data_type"] == "bigint"
        # The "NO" / "YES" flags from the cursor rows surface as booleans.
        assert id_col["is_nullable"] is False
        assert id_col["column_default"] is None

        assert payload[1]["is_nullable"] is True
        assert payload[2]["column_default"] == "now()"

    @patch(_MOCK_CONN)
    def test_empty_table(self, mock_get_conn):
        """No cursor rows for the requested table yields an empty JSON list."""
        fake_conn, _ = _make_mock_conn([])
        mock_get_conn.return_value = fake_conn

        response = client.get("/api/db/tables/dwd/nonexistent/columns")

        assert response.status_code == 200
        assert response.json() == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# POST /api/db/query
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestExecuteQuery:
    """Tests for ``POST /api/db/query``.

    Covers successful reads, rejection of write statements and empty SQL,
    error and timeout mapping, the row limit, and the statement timeout.
    """

    @patch(_MOCK_CONN)
    def test_successful_select(self, mock_get_conn):
        """Column names come from the cursor description; rows become lists."""
        description = [("id",), ("name",)]
        conn, _ = _make_mock_conn(
            [(1, "Alice"), (2, "Bob")],
            description=description,
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})

        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id", "name"]
        assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
        assert data["row_count"] == 2

    @patch(_MOCK_CONN)
    def test_empty_result(self, mock_get_conn):
        """A query with no rows still reports its columns and a zero count."""
        conn, _ = _make_mock_conn([], description=[("id",)])
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})

        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id"]
        assert data["rows"] == []
        assert data["row_count"] == 0

    # -- Write-operation blocking --

    @pytest.mark.parametrize("keyword", [
        "INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
        "insert", "update", "delete", "drop", "truncate",
        "Insert", "Update", "Delete", "Drop", "Truncate",
    ])
    def test_blocks_write_operations(self, keyword):
        """Statements starting with a write keyword are rejected with 400."""
        resp = client.post(
            "/api/db/query",
            json={"sql": f"{keyword} INTO some_table VALUES (1)"},
        )

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "只读" in detail or "禁止" in detail

    def test_blocks_mixed_case_write(self):
        """Keyword matching ignores letter case."""
        resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
        assert resp.status_code == 400

    def test_blocks_write_in_subquery(self):
        """A write keyword must be rejected wherever it appears in the SQL."""
        resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
        assert resp.status_code == 400

    # -- Empty SQL --

    def test_empty_sql(self):
        """An empty SQL string is rejected with 400."""
        resp = client.post("/api/db/query", json={"sql": ""})
        assert resp.status_code == 400

    def test_whitespace_only_sql(self):
        """SQL consisting only of whitespace is rejected with 400."""
        resp = client.post("/api/db/query", json={"sql": " "})
        assert resp.status_code == 400

    # -- SQL syntax error --

    @patch(_MOCK_CONN)
    def test_sql_syntax_error(self, mock_get_conn):
        """An exception raised while executing the query maps to 400."""
        # Reuse the shared helper instead of hand-building the same mocks.
        conn, cur = _make_mock_conn([])
        # The first execute (setting the timeout) succeeds; the second raises.
        cur.execute.side_effect = [None, Exception('syntax error at or near "SELEC"')]
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})

        assert resp.status_code == 400
        assert "SQL 执行错误" in resp.json()["detail"]

    # -- Query timeout --

    @patch(_MOCK_CONN)
    def test_query_timeout(self, mock_get_conn):
        """``QueryCanceled`` raised by the query maps to 408."""
        conn, cur = _make_mock_conn([])
        # The first execute (setting the timeout) succeeds; the second raises.
        cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})

        assert resp.status_code == 408
        assert "超时" in resp.json()["detail"]

    # -- Row limit --

    @patch(_MOCK_CONN)
    def test_row_limit(self, mock_get_conn):
        """Verify fetchmany is called with the 1000-row limit."""
        conn, cur = _make_mock_conn(
            [(i,) for i in range(1000)],
            description=[("id",)],
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})

        assert resp.status_code == 200
        cur.fetchmany.assert_called_once_with(1000)

    # -- Statement timeout --

    @patch(_MOCK_CONN)
    def test_sets_statement_timeout(self, mock_get_conn):
        """Verify statement_timeout is set before the query runs."""
        conn, cur = _make_mock_conn([(1,)], description=[("id",)])
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT 1"})

        # Check the request first so a failure here is reported clearly rather
        # than as an IndexError on the call list below.
        assert resp.status_code == 200
        assert cur.execute.call_args_list, "cursor.execute was never called"

        # The first execute call should be the one that sets the timeout.
        first_call = cur.execute.call_args_list[0]
        assert "statement_timeout" in first_call[0][0]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Authentication tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestDbViewerAuth:
    """Checks that the db-viewer endpoints reject unauthenticated requests."""

    def test_requires_auth(self):
        """With the auth override removed, every endpoint should return 401/403."""
        # ``None`` marks "no override was installed", so nothing is restored
        # in that case.
        original = app.dependency_overrides.pop(get_current_user, None)
        try:
            # (method, url, JSON body); ``None`` sends no body.
            endpoints = [
                ("GET", "/api/db/schemas", None),
                ("GET", "/api/db/schemas/dwd/tables", None),
                ("GET", "/api/db/tables/dwd/dim_member/columns", None),
                ("POST", "/api/db/query", {"sql": "SELECT 1"}),
            ]
            for method, url, payload in endpoints:
                resp = client.request(method, url, json=payload)
                assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        finally:
            # Compare against the sentinel rather than relying on truthiness.
            if original is not None:
                app.dependency_overrides[get_current_user] = original
|