# -*- coding: utf-8 -*-
"""Unit tests for the database viewer router.

Covers 4 endpoints:
- GET  /api/db/schemas
- GET  /api/db/schemas/{name}/tables
- GET  /api/db/tables/{schema}/{table}/columns
- POST /api/db/query

The database connection factory is mocked out so the tests focus on the
routing logic only.
"""
import os

# Set before importing the app; presumably the auth module reads it at import
# time. NOTE(review): confirm.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import patch, MagicMock

import pytest
from fastapi.testclient import TestClient
from psycopg2 import errors as pg_errors

from app.auth.dependencies import CurrentUser, get_current_user
from app.main import app

_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Dependency override that authenticates every request as ``_TEST_USER``."""
    return _TEST_USER


# NOTE(review): this override is installed at import time and never removed
# (TestDbViewerAuth only lifts it temporarily). It therefore stays active for
# any other test module sharing ``app`` in the same pytest session. Consider
# moving it into a fixture.
app.dependency_overrides[get_current_user] = _override_auth

client = TestClient(app)

# Patch target: the connection factory as looked up inside the router module.
_MOCK_CONN = "app.routers.db_viewer.get_etl_readonly_connection"


def _make_mock_conn(rows, description=None, execute_side_effect=None):
    """Build a mock DB connection whose cursor returns the given rows.

    Args:
        rows: Value returned by both ``cursor.fetchall()`` and
            ``cursor.fetchmany()``.
        description: Value exposed as ``cursor.description``.
        execute_side_effect: Optional ``side_effect`` for ``cursor.execute``.
            This is either an exception or an iterable of return values /
            exceptions, one per call. ``None`` leaves ``execute`` as a plain
            mock.

    Returns:
        ``(mock_conn, mock_cur)``. ``mock_conn.cursor()`` can be used as a
        context manager that yields ``mock_cur``.
    """
    mock_conn = MagicMock()
    mock_cur = MagicMock()
    mock_cur.fetchall.return_value = rows
    mock_cur.fetchmany.return_value = rows
    mock_cur.description = description
    if execute_side_effect is not None:
        mock_cur.execute.side_effect = execute_side_effect
    cursor_cm = mock_conn.cursor.return_value
    cursor_cm.__enter__.return_value = mock_cur
    # Returning False from __exit__ lets exceptions raised inside the
    # ``with`` block propagate.
    cursor_cm.__exit__.return_value = False
    return mock_conn, mock_cur


# ---------------------------------------------------------------------------
# GET /api/db/schemas
# ---------------------------------------------------------------------------


class TestListSchemas:

    @patch(_MOCK_CONN)
    def test_returns_schema_list(self, mock_get_conn):
        conn, _ = _make_mock_conn([("dwd",), ("dws",), ("ods",)])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 3
        assert data[0]["name"] == "dwd"
        assert data[2]["name"] == "ods"
        # The caller's site_id must be forwarded to the connection factory.
        mock_get_conn.assert_called_once_with(_TEST_USER.site_id)
        conn.close.assert_called_once()

    @patch(_MOCK_CONN)
    def test_empty_schemas(self, mock_get_conn):
        conn, _ = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas")
        assert resp.status_code == 200
        assert resp.json() == []


# ---------------------------------------------------------------------------
# GET /api/db/schemas/{name}/tables
# ---------------------------------------------------------------------------


class TestListTables:

    @patch(_MOCK_CONN)
    def test_returns_tables_with_row_count(self, mock_get_conn):
        conn, _ = _make_mock_conn([
            ("dim_member", 1500),
            ("fact_order", 32000),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas/dwd/tables")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 2
        assert data[0]["name"] == "dim_member"
        assert data[0]["row_count"] == 1500
        assert data[1]["name"] == "fact_order"
        assert data[1]["row_count"] == 32000

    @patch(_MOCK_CONN)
    def test_null_row_count(self, mock_get_conn):
        """A table without statistics (pg_stat_user_tables) yields row_count None."""
        conn, _ = _make_mock_conn([("new_table", None)])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas/ods/tables")
        assert resp.status_code == 200
        data = resp.json()
        assert data[0]["row_count"] is None

    @patch(_MOCK_CONN)
    def test_empty_schema(self, mock_get_conn):
        conn, _ = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/schemas/empty_schema/tables")
        assert resp.status_code == 200
        assert resp.json() == []


# ---------------------------------------------------------------------------
# GET /api/db/tables/{schema}/{table}/columns
# ---------------------------------------------------------------------------


class TestListColumns:

    @patch(_MOCK_CONN)
    def test_returns_column_definitions(self, mock_get_conn):
        conn, _ = _make_mock_conn([
            ("id", "bigint", "NO", None),
            ("name", "character varying", "YES", None),
            ("created_at", "timestamp with time zone", "NO", "now()"),
        ])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/tables/dwd/dim_member/columns")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 3
        assert data[0]["name"] == "id"
        assert data[0]["data_type"] == "bigint"
        assert data[0]["is_nullable"] is False
        assert data[0]["column_default"] is None
        assert data[1]["is_nullable"] is True
        assert data[2]["column_default"] == "now()"

    @patch(_MOCK_CONN)
    def test_empty_table(self, mock_get_conn):
        conn, _ = _make_mock_conn([])
        mock_get_conn.return_value = conn

        resp = client.get("/api/db/tables/dwd/nonexistent/columns")
        assert resp.status_code == 200
        assert resp.json() == []


# ---------------------------------------------------------------------------
# POST /api/db/query
# ---------------------------------------------------------------------------


class TestExecuteQuery:

    @patch(_MOCK_CONN)
    def test_successful_select(self, mock_get_conn):
        description = [("id",), ("name",)]
        conn, _ = _make_mock_conn(
            [(1, "Alice"), (2, "Bob")],
            description=description,
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id, name FROM users"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id", "name"]
        assert data["rows"] == [[1, "Alice"], [2, "Bob"]]
        assert data["row_count"] == 2

    @patch(_MOCK_CONN)
    def test_empty_result(self, mock_get_conn):
        description = [("id",)]
        conn, _ = _make_mock_conn([], description=description)
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM empty_table"})
        assert resp.status_code == 200
        data = resp.json()
        assert data["columns"] == ["id"]
        assert data["rows"] == []
        assert data["row_count"] == 0

    # -- Write-operation guard --
    # The connection factory is patched in these tests so that a guard
    # regression surfaces as a failed status assertion rather than an
    # attempt to reach a real database.

    @pytest.mark.parametrize("keyword", [
        "INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
        "insert", "update", "delete", "drop", "truncate",
        "Insert", "Update", "Delete", "Drop", "Truncate",
    ])
    @patch(_MOCK_CONN)
    def test_blocks_write_operations(self, _mock_get_conn, keyword):
        resp = client.post("/api/db/query", json={"sql": f"{keyword} INTO some_table VALUES (1)"})
        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "只读" in detail or "禁止" in detail

    @patch(_MOCK_CONN)
    def test_blocks_mixed_case_write(self, _mock_get_conn):
        resp = client.post("/api/db/query", json={"sql": "DeLeTe FROM users WHERE id = 1"})
        assert resp.status_code == 400

    @patch(_MOCK_CONN)
    def test_blocks_write_in_subquery(self, _mock_get_conn):
        """A write keyword anywhere in the SQL must be rejected."""
        resp = client.post("/api/db/query", json={"sql": "SELECT * FROM (DELETE FROM users) sub"})
        assert resp.status_code == 400

    # -- Empty SQL --

    @patch(_MOCK_CONN)
    def test_empty_sql(self, _mock_get_conn):
        resp = client.post("/api/db/query", json={"sql": ""})
        assert resp.status_code == 400

    @patch(_MOCK_CONN)
    def test_whitespace_only_sql(self, _mock_get_conn):
        resp = client.post("/api/db/query", json={"sql": "   "})
        assert resp.status_code == 400

    # -- SQL syntax error --

    @patch(_MOCK_CONN)
    def test_sql_syntax_error(self, mock_get_conn):
        # The first execute (setting the timeout) succeeds; the second one
        # (the user's SQL) raises.
        conn, _ = _make_mock_conn(
            [],
            execute_side_effect=[None, Exception("syntax error at or near \"SELEC\"")],
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELEC * FROM users"})
        assert resp.status_code == 400
        assert "SQL 执行错误" in resp.json()["detail"]

    # -- Query timeout --

    @patch(_MOCK_CONN)
    def test_query_timeout(self, mock_get_conn):
        # The timeout setup succeeds; the user's query is cancelled by the server.
        conn, _ = _make_mock_conn(
            [],
            execute_side_effect=[None, pg_errors.QueryCanceled()],
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT pg_sleep(60)"})
        assert resp.status_code == 408
        assert "超时" in resp.json()["detail"]

    # -- Row limit --

    @patch(_MOCK_CONN)
    def test_row_limit(self, mock_get_conn):
        """fetchmany must be called with the 1000-row cap."""
        description = [("id",)]
        conn, cur = _make_mock_conn(
            [(i,) for i in range(1000)],
            description=description,
        )
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT id FROM big_table"})
        assert resp.status_code == 200
        cur.fetchmany.assert_called_once_with(1000)

    # -- Statement timeout --

    @patch(_MOCK_CONN)
    def test_sets_statement_timeout(self, mock_get_conn):
        """statement_timeout must be set before the user's query runs."""
        description = [("id",)]
        conn, cur = _make_mock_conn([(1,)], description=description)
        mock_get_conn.return_value = conn

        resp = client.post("/api/db/query", json={"sql": "SELECT 1"})
        assert resp.status_code == 200

        # The very first execute call must be the timeout setup.
        first_call = cur.execute.call_args_list[0]
        assert "statement_timeout" in first_call[0][0]


# ---------------------------------------------------------------------------
# Authentication
# ---------------------------------------------------------------------------


class TestDbViewerAuth:

    def test_requires_auth(self):
        """With the auth override removed, every endpoint must return 401/403."""
        original = app.dependency_overrides.pop(get_current_user, None)
        try:
            endpoints = [
                ("GET", "/api/db/schemas"),
                ("GET", "/api/db/schemas/dwd/tables"),
                ("GET", "/api/db/tables/dwd/dim_member/columns"),
                ("POST", "/api/db/query"),
            ]
            for method, url in endpoints:
                if method == "POST":
                    resp = client.request(method, url, json={"sql": "SELECT 1"})
                else:
                    resp = client.request(method, url)
                assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        finally:
            # Reinstall the module-level override so later tests stay authenticated.
            if original is not None:
                app.dependency_overrides[get_current_user] = original