在准备环境前提交次全部更改。
This commit is contained in:
321
apps/backend/tests/test_db_viewer_router.py
Normal file
321
apps/backend/tests/test_db_viewer_router.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""数据库查看器路由单元测试
|
||||
|
||||
覆盖 4 个端点:
|
||||
- GET /api/db/schemas
|
||||
- GET /api/db/schemas/{name}/tables
|
||||
- GET /api/db/tables/{schema}/{table}/columns
|
||||
- POST /api/db/query
|
||||
|
||||
通过 mock 绕过数据库连接,专注路由逻辑验证。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from psycopg2 import errors as pg_errors
|
||||
|
||||
from app.auth.dependencies import CurrentUser, get_current_user
|
||||
from app.main import app
|
||||
|
||||
_TEST_USER = CurrentUser(user_id=1, site_id=100)
|
||||
|
||||
|
||||
def _override_auth():
|
||||
return _TEST_USER
|
||||
|
||||
|
||||
app.dependency_overrides[get_current_user] = _override_auth
|
||||
client = TestClient(app)
|
||||
|
||||
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
|
||||
|
||||
|
||||
def _make_mock_conn(rows, description=None):
|
||||
"""构造 mock 数据库连接,cursor 返回指定行和列描述。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.fetchall.return_value = rows
|
||||
mock_cur.fetchmany.return_value = rows
|
||||
mock_cur.description = description
|
||||
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn, mock_cur
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListSchemas:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_schema_list(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 3
|
||||
assert data[0]["name"] == "dwd"
|
||||
assert data[2]["name"] == "ods"
|
||||
|
||||
# 验证 site_id 传递
|
||||
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
|
||||
conn.close.assert_called_once()
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_schemas(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/schemas/{name}/tables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListTables:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_tables_with_row_count(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("dim_member", 1500),
|
||||
("fact_order", 32000),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/dwd/tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["name"] == "dim_member"
|
||||
assert data[0]["row_count"] == 1500
|
||||
assert data[1]["name"] == "fact_order"
|
||||
assert data[1]["row_count"] == 32000
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_null_row_count(self, mock_get_conn):
|
||||
"""pg_stat_user_tables 可能没有统计信息,row_count 为 None。"""
|
||||
conn, cur = _make_mock_conn([("new_table", None)])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/ods/tables")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data[0]["row_count"] is None
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_schema(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/schemas/empty_schema/tables")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/db/tables/{schema}/{table}/columns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListColumns:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_returns_column_definitions(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([
|
||||
("id", "bigint", "NO", None),
|
||||
("name", "character varying", "YES", None),
|
||||
("created_at", "timestamp with time zone", "NO", "now()"),
|
||||
])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/tables/dwd/dim_member/columns")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 3
|
||||
|
||||
assert data[0]["name"] == "id"
|
||||
assert data[0]["data_type"] == "bigint"
|
||||
assert data[0]["is_nullable"] is False
|
||||
assert data[0]["column_default"] is None
|
||||
|
||||
assert data[1]["is_nullable"] is True
|
||||
assert data[2]["column_default"] == "now()"
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_table(self, mock_get_conn):
|
||||
conn, cur = _make_mock_conn([])
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.get("/api/db/tables/dwd/nonexistent/columns")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/db/query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExecuteQuery:
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_successful_select(self, mock_get_conn):
|
||||
description = [("id",), ("name",)]
|
||||
conn, cur = _make_mock_conn(
|
||||
[(1, "Alice"), (2, "Bob")],
|
||||
description=description,
|
||||
)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["columns"] == ["id", "name"]
|
||||
assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
|
||||
assert data["row_count"] == 2
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_empty_result(self, mock_get_conn):
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn([], description=description)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["columns"] == ["id"]
|
||||
assert data["rows"] == []
|
||||
assert data["row_count"] == 0
|
||||
|
||||
# ── 写操作拦截 ──
|
||||
|
||||
@pytest.mark.parametrize("keyword", [
|
||||
"INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
|
||||
"insert", "update", "delete", "drop", "truncate",
|
||||
"Insert", "Update", "Delete", "Drop", "Truncate",
|
||||
])
|
||||
def test_blocks_write_operations(self, keyword):
|
||||
resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
|
||||
assert resp.status_code == 400
|
||||
assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]
|
||||
|
||||
def test_blocks_mixed_case_write(self):
|
||||
resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_blocks_write_in_subquery(self):
|
||||
"""写操作关键词出现在 SQL 任意位置都应拦截。"""
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── 空 SQL ──
|
||||
|
||||
def test_empty_sql(self):
|
||||
resp = client.post("/api/db/query", json={"sql": ""})
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_whitespace_only_sql(self):
|
||||
resp = client.post("/api/db/query", json={"sql": " "})
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── SQL 语法错误 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_sql_syntax_error(self, mock_get_conn):
|
||||
conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
# 第一次 execute 设置 timeout 成功,第二次抛异常
|
||||
mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
|
||||
mock_cur.description = None
|
||||
conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
|
||||
assert resp.status_code == 400
|
||||
assert "SQL 执行错误" in resp.json()["detail"]
|
||||
|
||||
# ── 查询超时 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_query_timeout(self, mock_get_conn):
|
||||
conn = MagicMock()
|
||||
mock_cur = MagicMock()
|
||||
mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
|
||||
mock_cur.description = None
|
||||
conn.cursor.return_value.__enter__ = lambda s: mock_cur
|
||||
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
|
||||
assert resp.status_code == 408
|
||||
assert "超时" in resp.json()["detail"]
|
||||
|
||||
# ── 行数限制验证 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_row_limit(self, mock_get_conn):
|
||||
"""验证 fetchmany 被调用时传入 1000 行限制。"""
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn(
|
||||
[(i,) for i in range(1000)],
|
||||
description=description,
|
||||
)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
|
||||
assert resp.status_code == 200
|
||||
# 验证 fetchmany 被调用时传入了 1000
|
||||
cur.fetchmany.assert_called_once_with(1000)
|
||||
|
||||
# ── 超时设置验证 ──
|
||||
|
||||
@patch(_MOCK_CONN)
|
||||
def test_sets_statement_timeout(self, mock_get_conn):
|
||||
"""验证查询前设置了 statement_timeout。"""
|
||||
description = [("id",)]
|
||||
conn, cur = _make_mock_conn([(1,)], description=description)
|
||||
mock_get_conn.return_value = conn
|
||||
|
||||
client.post("/api/db/query", json={"sql": "SELECT 1"})
|
||||
|
||||
# 第一次 execute 应该是设置超时
|
||||
first_call = cur.execute.call_args_list[0]
|
||||
assert "statement_timeout" in first_call[0][0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 认证测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDbViewerAuth:
|
||||
|
||||
def test_requires_auth(self):
|
||||
"""移除 auth override 后,所有端点应返回 401/403。"""
|
||||
original = app.dependency_overrides.pop(get_current_user, None)
|
||||
try:
|
||||
endpoints = [
|
||||
("GET", "/api/db/schemas"),
|
||||
("GET", "/api/db/schemas/dwd/tables"),
|
||||
("GET", "/api/db/tables/dwd/dim_member/columns"),
|
||||
("POST", "/api/db/query"),
|
||||
]
|
||||
for method, url in endpoints:
|
||||
if method == "POST":
|
||||
resp = client.request(method, url, json={"sql": "SELECT 1"})
|
||||
else:
|
||||
resp = client.request(method, url)
|
||||
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
|
||||
finally:
|
||||
if original:
|
||||
app.dependency_overrides[get_current_user] = original
|
||||
Reference in New Issue
Block a user