Files
Neo-ZQYY/apps/backend/tests/test_db_viewer_router.py

322 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""数据库查看器路由单元测试
覆盖 4 个端点:
- GET /api/db/schemas
- GET /api/db/schemas/{name}/tables
- GET /api/db/tables/{schema}/{table}/columns
- POST /api/db/query
通过 mock 绕过数据库连接,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from psycopg2 import errors as pg_errors
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
_TEST_USER = CurrentUser(site_id=100, user_id=1)


def _override_auth():
    """Replacement for ``get_current_user``: always resolve to ``_TEST_USER``."""
    return _TEST_USER


# Installed at import time: every request sent through ``client`` below is
# authenticated as ``_TEST_USER`` unless a test removes this override.
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)

# Dotted path of the connection factory as referenced from the router module;
# this is the target handed to ``unittest.mock.patch`` throughout the tests.
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
def _make_mock_conn(rows, description=None):
"""构造 mock 数据库连接cursor 返回指定行和列描述。"""
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = rows
mock_cur.fetchmany.return_value = rows
mock_cur.description = description
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cur
# ---------------------------------------------------------------------------
# GET /api/db/schemas
# ---------------------------------------------------------------------------
class TestListSchemas:
    """Tests for GET /api/db/schemas."""

    @patch(_MOCK_CONN)
    def test_returns_schema_list(self, mock_get_conn):
        conn, _ = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas")

        assert resp.status_code == 200
        schemas = resp.json()
        assert len(schemas) == 3
        first, last = schemas[0], schemas[-1]
        assert first["name"] == "dwd"
        assert last["name"] == "ods"
        # The connection is requested for the caller's site_id and closed afterwards.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        conn.close.assert_called_once()

    @patch(_MOCK_CONN)
    def test_empty_schemas(self, mock_get_conn):
        mock_get_conn.return_value = _make_mock_conn([])[0]

        resp = client.get("/api/db/schemas")

        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/schemas/{name}/tables
# ---------------------------------------------------------------------------
class TestListTables:
    """Tests for GET /api/db/schemas/{name}/tables."""

    @patch(_MOCK_CONN)
    def test_returns_tables_with_row_count(self, mock_get_conn):
        expected = [
            ("dim_member", 1500),
            ("fact_order", 32000),
        ]
        mock_get_conn.return_value = _make_mock_conn(expected)[0]

        resp = client.get("/api/db/schemas/dwd/tables")

        assert resp.status_code == 200
        tables = resp.json()
        assert len(tables) == 2
        for table, (name, row_count) in zip(tables, expected):
            assert table["name"] == name
            assert table["row_count"] == row_count

    @patch(_MOCK_CONN)
    def test_null_row_count(self, mock_get_conn):
        """pg_stat_user_tables may lack statistics, leaving row_count as None."""
        mock_get_conn.return_value = _make_mock_conn([("new_table", None)])[0]

        resp = client.get("/api/db/schemas/ods/tables")

        assert resp.status_code == 200
        tables = resp.json()
        assert tables[0]["row_count"] is None

    @patch(_MOCK_CONN)
    def test_empty_schema(self, mock_get_conn):
        mock_get_conn.return_value = _make_mock_conn([])[0]

        resp = client.get("/api/db/schemas/empty_schema/tables")

        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/tables/{schema}/{table}/columns
# ---------------------------------------------------------------------------
class TestListColumns:
    """Tests for GET /api/db/tables/{schema}/{table}/columns."""

    @patch(_MOCK_CONN)
    def test_returns_column_definitions(self, mock_get_conn):
        column_rows = [
            ("id", "bigint", "NO", None),
            ("name", "character varying", "YES", None),
            ("created_at", "timestamp with time zone", "NO", "now()"),
        ]
        mock_get_conn.return_value = _make_mock_conn(column_rows)[0]

        resp = client.get("/api/db/tables/dwd/dim_member/columns")

        assert resp.status_code == 200
        columns = resp.json()
        assert len(columns) == 3
        id_col, name_col, created_col = columns
        assert id_col["name"] == "id"
        assert id_col["data_type"] == "bigint"
        # The mocked "NO"/"YES" nullability flags are expected back as booleans.
        assert id_col["is_nullable"] is False
        assert id_col["column_default"] is None
        assert name_col["is_nullable"] is True
        assert created_col["column_default"] == "now()"

    @patch(_MOCK_CONN)
    def test_empty_table(self, mock_get_conn):
        mock_get_conn.return_value = _make_mock_conn([])[0]

        resp = client.get("/api/db/tables/dwd/nonexistent/columns")

        assert resp.status_code == 200
        assert resp.json() == []
# ---------------------------------------------------------------------------
# POST /api/db/query
# ---------------------------------------------------------------------------
class TestExecuteQuery:
    """Tests for POST /api/db/query."""

    @staticmethod
    def _make_failing_conn(exc):
        """Return a mock connection whose second ``cursor.execute`` raises ``exc``.

        The first ``execute`` call (the ``statement_timeout`` setup) succeeds,
        and the cursor has no result-set description.
        """
        conn, cur = _make_mock_conn([])
        cur.execute.side_effect = [None, exc]
        return conn

    @patch(_MOCK_CONN)
    def test_successful_select(self, mock_get_conn):
        description = [("id",), ("name",)]
        conn, cur = _make_mock_conn(
            [(1, "Alice"), (2, "Bob")],
            description=description,
        )
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id", "name"]
        assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
        assert data["row_count"] == 2

    @patch(_MOCK_CONN)
    def test_empty_result(self, mock_get_conn):
        description = [("id",)]
        conn, cur = _make_mock_conn([], description=description)
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id"]
        assert data["rows"] == []
        assert data["row_count"] == 0

    # -- Write-operation blocking --
    # The connection factory is patched even though these requests should be
    # rejected up front. If validation regresses, the test then fails on its
    # assertion instead of trying to reach a real database.
    @pytest.mark.parametrize("keyword", [
        "INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
        "insert", "update", "delete", "drop", "truncate",
        "Insert", "Update", "Delete", "Drop", "Truncate",
    ])
    @patch(_MOCK_CONN)
    def test_blocks_write_operations(self, mock_get_conn, keyword):
        resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "只读" in detail or "禁止" in detail

    @patch(_MOCK_CONN)
    def test_blocks_mixed_case_write(self, mock_get_conn):
        resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
        assert resp.status_code == 400

    @patch(_MOCK_CONN)
    def test_blocks_write_in_subquery(self, mock_get_conn):
        """A write keyword anywhere in the SQL must be rejected."""
        resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
        assert resp.status_code == 400

    # -- Empty SQL --
    @patch(_MOCK_CONN)
    def test_empty_sql(self, mock_get_conn):
        resp = client.post("/api/db/query", json={"sql": ""})
        assert resp.status_code == 400

    @patch(_MOCK_CONN)
    def test_whitespace_only_sql(self, mock_get_conn):
        resp = client.post("/api/db/query", json={"sql": " "})
        assert resp.status_code == 400

    # -- SQL syntax error --
    @patch(_MOCK_CONN)
    def test_sql_syntax_error(self, mock_get_conn):
        mock_get_conn.return_value = self._make_failing_conn(
            Exception("syntax error at or near \"SELEC\"")
        )
        resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
        assert resp.status_code == 400
        assert "SQL 执行错误" in resp.json()["detail"]

    # -- Query timeout --
    @patch(_MOCK_CONN)
    def test_query_timeout(self, mock_get_conn):
        mock_get_conn.return_value = self._make_failing_conn(pg_errors.QueryCanceled())
        resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
        assert resp.status_code == 408
        assert "超时" in resp.json()["detail"]

    # -- Row limit --
    @patch(_MOCK_CONN)
    def test_row_limit(self, mock_get_conn):
        """Check that results are fetched through fetchmany with a 1000-row limit."""
        description = [("id",)]
        conn, cur = _make_mock_conn(
            [(i,) for i in range(1000)],
            description=description,
        )
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
        assert resp.status_code == 200
        cur.fetchmany.assert_called_once_with(1000)

    # -- Statement timeout setup --
    @patch(_MOCK_CONN)
    def test_sets_statement_timeout(self, mock_get_conn):
        """Check that statement_timeout is set before the user's query runs."""
        description = [("id",)]
        conn, cur = _make_mock_conn([(1,)], description=description)
        mock_get_conn.return_value = conn
        resp = client.post("/api/db/query", json={"sql": "SELECT 1"})
        # Without this check, a failure after the first execute would go unnoticed.
        assert resp.status_code == 200
        first_sql = cur.execute.call_args_list[0].args[0]
        assert "statement_timeout" in first_sql
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestDbViewerAuth:
    """Authentication requirements for every db-viewer endpoint."""

    def test_requires_auth(self):
        """With the auth override removed, every endpoint must answer 401/403."""
        original = app.dependency_overrides.pop(get_current_user, None)
        try:
            # (method, url, JSON body or None when the request has no body)
            endpoints = [
                ("GET", "/api/db/schemas", None),
                ("GET", "/api/db/schemas/dwd/tables", None),
                ("GET", "/api/db/tables/dwd/dim_member/columns", None),
                ("POST", "/api/db/query", {"sql": "SELECT 1"}),
            ]
            for method, url, body in endpoints:
                resp = client.request(method, url, json=body)
                assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        finally:
            # Put the module-level override back so later tests stay
            # authenticated. Compare against None explicitly instead of relying
            # on the truthiness of the override callable.
            if original is not None:
                app.dependency_overrides[get_current_user] = original