95 lines
3.6 KiB
Python
95 lines
3.6 KiB
Python
"""
Unit tests for the database connection module.

Covers: creation of the ETL read-only connection, setting the RLS site_id,
read-only mode, and error handling.
"""

import os
from unittest.mock import MagicMock, call, patch  # NOTE(review): `call` looks unused here

import pytest

# Provide a placeholder JWT secret *before* importing any app module;
# setdefault keeps a real value if one is already exported. Presumably the
# app's settings read JWT_SECRET_KEY at import time, which is why this line
# must precede the app import below — confirm against the app config.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from app.database import get_etl_readonly_connection


# ---------------------------------------------------------------------------
# get_etl_readonly_connection
# ---------------------------------------------------------------------------

class TestGetEtlReadonlyConnection:
    """ETL read-only connection: verify connect params, read-only mode and RLS isolation."""

    @staticmethod
    def _build_mocks(mock_connect):
        """Wire a fake connection/cursor pair into the patched ``psycopg2.connect``.

        After this call, ``mock_connect(...)`` returns the fake connection and
        ``with conn.cursor() as cur:`` yields the fake cursor, so tests can
        inspect the statements executed on it.

        Args:
            mock_connect: the ``MagicMock`` that replaces ``psycopg2.connect``.

        Returns:
            tuple: ``(mock_conn, mock_cursor)``.
        """
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        cursor_cm = mock_conn.cursor.return_value
        cursor_cm.__enter__.return_value = mock_cursor
        # A falsy __exit__ lets exceptions raised inside the ``with`` block
        # propagate instead of being swallowed by the mock.
        cursor_cm.__exit__.return_value = False
        mock_connect.return_value = mock_conn
        return mock_conn, mock_cursor

    @patch("app.database.psycopg2.connect")
    def test_sets_readonly_and_site_id(self, mock_connect):
        """After connecting, SET read_only then SET LOCAL site_id must run, in that order."""
        mock_conn, mock_cursor = self._build_mocks(mock_connect)

        conn = get_etl_readonly_connection(site_id=42)

        # autocommit must be switched off explicitly.
        assert mock_conn.autocommit is False

        # Both SET statements must have run. Check the count first so a
        # missing statement fails with a readable message, not an IndexError.
        executed = [c.args[0] for c in mock_cursor.execute.call_args_list]
        assert len(executed) >= 2, executed
        assert "SET default_transaction_read_only = on" in executed[0]
        assert "SET LOCAL app.current_site_id" in executed[1]

        # site_id must be sent as a bound parameter (SQL-injection safe),
        # not interpolated into the SQL text.
        site_id_call = mock_cursor.execute.call_args_list[1]
        assert site_id_call.args[1] == ("42",)

        # NOTE(review): in PostgreSQL, SET LOCAL only lasts until the end of
        # the current transaction. If app.database commits *after* the
        # SET LOCAL, the site_id setting is discarded before any RLS-protected
        # query runs — confirm the implementation re-applies it or uses a
        # session-level setting.
        mock_conn.commit.assert_called_once()
        assert conn is mock_conn

    @patch("app.database.psycopg2.connect")
    def test_accepts_string_site_id(self, mock_connect):
        """A string site_id should work and be bound unchanged."""
        _, mock_cursor = self._build_mocks(mock_connect)

        get_etl_readonly_connection(site_id="99")

        calls = mock_cursor.execute.call_args_list
        assert len(calls) >= 2, calls
        assert calls[1].args[1] == ("99",)

    @patch("app.database.psycopg2.connect")
    def test_closes_connection_on_setup_error(self, mock_connect):
        """If a SET statement fails, the connection is closed and the error re-raised."""
        mock_conn, mock_cursor = self._build_mocks(mock_connect)
        mock_cursor.execute.side_effect = Exception("SET failed")

        with pytest.raises(Exception, match="SET failed"):
            get_etl_readonly_connection(site_id=1)

        mock_conn.close.assert_called_once()

    @patch("app.database.psycopg2.connect")
    def test_uses_etl_config_params(self, mock_connect):
        """The connection must be opened with the ETL_DB_* configuration values."""
        self._build_mocks(mock_connect)

        get_etl_readonly_connection(site_id=1)

        # Fail clearly if connect was never called (call_args would be None).
        mock_connect.assert_called_once()
        connect_kwargs = mock_connect.call_args.kwargs
        # The ETL database name must be used (default: etl_feiqiu).
        assert connect_kwargs["dbname"] == "etl_feiqiu"
|