在准备环境前提交次全部更改。

This commit is contained in:
Neo
2026-02-19 08:35:13 +08:00
parent ded6dfb9d8
commit 4eac07da47
1387 changed files with 6107191 additions and 33002 deletions

View File

@@ -0,0 +1,321 @@
# -*- coding: utf-8 -*-
"""数据库查看器路由单元测试
覆盖 4 个端点:
- GET /api/db/schemas
- GET /api/db/schemas/{name}/tables
- GET /api/db/tables/{schema}/{table}/columns
- POST /api/db/query
通过 mock 绕过数据库连接,专注路由逻辑验证。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
from psycopg2 import errors as pg_errors
from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app
_TEST_USER = CurrentUser(user_id=1, site_id=100)
def _override_auth():
return _TEST_USER
app.dependency_overrides[get_current_user] = _override_auth
client = TestClient(app)
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"
def _make_mock_conn(rows, description=None):
"""构造 mock 数据库连接cursor 返回指定行和列描述。"""
mock_conn = MagicMock()
mock_cur = MagicMock()
mock_cur.fetchall.return_value = rows
mock_cur.fetchmany.return_value = rows
mock_cur.description = description
mock_conn.cursor.return_value.__enter__ = lambda s: mock_cur
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn, mock_cur
# ---------------------------------------------------------------------------
# GET /api/db/schemas
# ---------------------------------------------------------------------------
class TestListSchemas:
@patch(_MOCK_CONN)
def test_returns_schema_list(self, mock_get_conn):
conn, cur = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 3
assert data[0]["name"] == "dwd"
assert data[2]["name"] == "ods"
# 验证 site_id 传递
mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
conn.close.assert_called_once()
@patch(_MOCK_CONN)
def test_empty_schemas(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/schemas/{name}/tables
# ---------------------------------------------------------------------------
class TestListTables:
@patch(_MOCK_CONN)
def test_returns_tables_with_row_count(self, mock_get_conn):
conn, cur = _make_mock_conn([
("dim_member", 1500),
("fact_order", 32000),
])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas/dwd/tables")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert data[0]["name"] == "dim_member"
assert data[0]["row_count"] == 1500
assert data[1]["name"] == "fact_order"
assert data[1]["row_count"] == 32000
@patch(_MOCK_CONN)
def test_null_row_count(self, mock_get_conn):
"""pg_stat_user_tables 可能没有统计信息row_count 为 None。"""
conn, cur = _make_mock_conn([("new_table", None)])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas/ods/tables")
assert resp.status_code == 200
data = resp.json()
assert data[0]["row_count"] is None
@patch(_MOCK_CONN)
def test_empty_schema(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/db/schemas/empty_schema/tables")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# GET /api/db/tables/{schema}/{table}/columns
# ---------------------------------------------------------------------------
class TestListColumns:
@patch(_MOCK_CONN)
def test_returns_column_definitions(self, mock_get_conn):
conn, cur = _make_mock_conn([
("id", "bigint", "NO", None),
("name", "character varying", "YES", None),
("created_at", "timestamp with time zone", "NO", "now()"),
])
mock_get_conn.return_value = conn
resp = client.get("/api/db/tables/dwd/dim_member/columns")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 3
assert data[0]["name"] == "id"
assert data[0]["data_type"] == "bigint"
assert data[0]["is_nullable"] is False
assert data[0]["column_default"] is None
assert data[1]["is_nullable"] is True
assert data[2]["column_default"] == "now()"
@patch(_MOCK_CONN)
def test_empty_table(self, mock_get_conn):
conn, cur = _make_mock_conn([])
mock_get_conn.return_value = conn
resp = client.get("/api/db/tables/dwd/nonexistent/columns")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# POST /api/db/query
# ---------------------------------------------------------------------------
class TestExecuteQuery:
@patch(_MOCK_CONN)
def test_successful_select(self, mock_get_conn):
description = [("id",), ("name",)]
conn, cur = _make_mock_conn(
[(1, "Alice"), (2, "Bob")],
description=description,
)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
assert resp.status_code == 200
data = resp.json()
assert data["columns"] == ["id", "name"]
assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
assert data["row_count"] == 2
@patch(_MOCK_CONN)
def test_empty_result(self, mock_get_conn):
description = [("id",)]
conn, cur = _make_mock_conn([], description=description)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
assert resp.status_code == 200
data = resp.json()
assert data["columns"] == ["id"]
assert data["rows"] == []
assert data["row_count"] == 0
# ── 写操作拦截 ──
@pytest.mark.parametrize("keyword", [
"INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
"insert", "update", "delete", "drop", "truncate",
"Insert", "Update", "Delete", "Drop", "Truncate",
])
def test_blocks_write_operations(self, keyword):
resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
assert resp.status_code == 400
assert "只读" in resp.json()["detail"] or "禁止" in resp.json()["detail"]
def test_blocks_mixed_case_write(self):
resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
assert resp.status_code == 400
def test_blocks_write_in_subquery(self):
"""写操作关键词出现在 SQL 任意位置都应拦截。"""
resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
assert resp.status_code == 400
# ── 空 SQL ──
def test_empty_sql(self):
resp = client.post("/api/db/query", json={"sql": ""})
assert resp.status_code == 400
def test_whitespace_only_sql(self):
resp = client.post("/api/db/query", json={"sql": " "})
assert resp.status_code == 400
# ── SQL 语法错误 ──
@patch(_MOCK_CONN)
def test_sql_syntax_error(self, mock_get_conn):
conn = MagicMock()
mock_cur = MagicMock()
# 第一次 execute 设置 timeout 成功,第二次抛异常
mock_cur.execute.side_effect = [None, Exception("syntax error at or near \"SELEC\"")]
mock_cur.description = None
conn.cursor.return_value.__enter__ = lambda s: mock_cur
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
assert resp.status_code == 400
assert "SQL 执行错误" in resp.json()["detail"]
# ── 查询超时 ──
@patch(_MOCK_CONN)
def test_query_timeout(self, mock_get_conn):
conn = MagicMock()
mock_cur = MagicMock()
mock_cur.execute.side_effect = [None, pg_errors.QueryCanceled()]
mock_cur.description = None
conn.cursor.return_value.__enter__ = lambda s: mock_cur
conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
assert resp.status_code == 408
assert "超时" in resp.json()["detail"]
# ── 行数限制验证 ──
@patch(_MOCK_CONN)
def test_row_limit(self, mock_get_conn):
"""验证 fetchmany 被调用时传入 1000 行限制。"""
description = [("id",)]
conn, cur = _make_mock_conn(
[(i,) for i in range(1000)],
description=description,
)
mock_get_conn.return_value = conn
resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
assert resp.status_code == 200
# 验证 fetchmany 被调用时传入了 1000
cur.fetchmany.assert_called_once_with(1000)
# ── 超时设置验证 ──
@patch(_MOCK_CONN)
def test_sets_statement_timeout(self, mock_get_conn):
"""验证查询前设置了 statement_timeout。"""
description = [("id",)]
conn, cur = _make_mock_conn([(1,)], description=description)
mock_get_conn.return_value = conn
client.post("/api/db/query", json={"sql": "SELECT 1"})
# 第一次 execute 应该是设置超时
first_call = cur.execute.call_args_list[0]
assert "statement_timeout" in first_call[0][0]
# ---------------------------------------------------------------------------
# 认证测试
# ---------------------------------------------------------------------------
class TestDbViewerAuth:
def test_requires_auth(self):
"""移除 auth override 后,所有端点应返回 401/403。"""
original = app.dependency_overrides.pop(get_current_user, None)
try:
endpoints = [
("GET", "/api/db/schemas"),
("GET", "/api/db/schemas/dwd/tables"),
("GET", "/api/db/tables/dwd/dim_member/columns"),
("POST", "/api/db/query"),
]
for method, url in endpoints:
if method == "POST":
resp = client.request(method, url, json={"sql": "SELECT 1"})
else:
resp = client.request(method, url)
assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
finally:
if original:
app.dependency_overrides[get_current_user] = original