292 lines
9.6 KiB
Python
292 lines
9.6 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""环境配置路由单元测试
|
||
|
||
覆盖 3 个端点:GET / PUT / GET /export
|
||
通过 mock 绕过文件 I/O,专注路由逻辑验证。
|
||
"""
|
||
|
||
import os
|
||
|
||
# Set before `from app.main import app` below, presumably because the app reads
# JWT_SECRET_KEY while it is being imported (TODO confirm). setdefault leaves any
# value already exported in the environment untouched.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||
|
||
from unittest.mock import patch, MagicMock
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.main import app
|
||
|
||
# Fixed identity that stands in for a real authenticated user in every request.
_TEST_USER = CurrentUser(user_id=1, site_id=100)


def _override_auth():
    """Replacement for ``get_current_user``: always resolve to ``_TEST_USER``."""
    return _TEST_USER


# Bypass authentication for the whole app under test.
# NOTE(review): this mutates the shared ``app`` object at import time and is
# never permanently undone, so it stays active for any other test module run
# later in the same session; consider moving it into a fixture.
app.dependency_overrides.update({get_current_user: _override_auth})
client = TestClient(app=app)
|
||
|
||
# 模拟 .env 文件内容
|
||
_SAMPLE_ENV = """\
|
||
# 数据库配置
|
||
DB_HOST=localhost
|
||
DB_PORT=5432
|
||
DB_PASSWORD=super_secret_123
|
||
JWT_SECRET_KEY=my-jwt-secret
|
||
|
||
# ETL 配置
|
||
ETL_DB_DSN=postgresql://user:pass@host/db
|
||
TIMEZONE=Asia/Shanghai
|
||
"""
|
||
|
||
# Dotted target patched in every file-backed test: the router's .env path object.
_MOCK_ENV_PATH = "app.routers.env_config._ENV_PATH"


def _mock_path(content: str | None = _SAMPLE_ENV, exists: bool = True):
    """Build a stand-in for the router's ``_ENV_PATH``.

    Args:
        content: Text returned by ``read_text()``. ``None`` simulates a file
            that cannot be read.
        exists: Value returned by ``exists()``.

    Returns:
        A ``MagicMock`` specced against ``pathlib.Path``, so calls to methods a
        real ``Path`` does not have fail instead of silently succeeding. Calls
        to ``write_text`` are recorded for later inspection.
    """
    from pathlib import Path

    mock = MagicMock(spec=Path)
    mock.exists.return_value = exists
    if exists and content is not None:
        mock.read_text.return_value = content
    else:
        # Mirror a real Path. Reading a missing or unreadable file raises,
        # instead of handing the parser an auto-created MagicMock.
        mock.read_text.side_effect = FileNotFoundError(_MOCK_ENV_PATH)
    return mock
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/env-config
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestGetEnvConfig:
    """GET /api/env-config: parse the .env file and mask sensitive values."""

    def test_returns_entries_with_masked_sensitive(self):
        with patch(_MOCK_ENV_PATH, _mock_path(_SAMPLE_ENV)):
            resp = client.get("/api/env-config")

        assert resp.status_code == 200
        entries = {e["key"]: e["value"] for e in resp.json()["entries"]}

        # Non-sensitive values are returned verbatim.
        assert entries["DB_HOST"] == "localhost"
        assert entries["DB_PORT"] == "5432"
        assert entries["TIMEZONE"] == "Asia/Shanghai"

        # Sensitive values are masked...
        assert entries["DB_PASSWORD"] == "****"
        assert entries["JWT_SECRET_KEY"] == "****"
        assert entries["ETL_DB_DSN"] == "****"

        # ...and the raw secrets appear nowhere else in the response body.
        for secret in ("super_secret_123", "my-jwt-secret", "postgresql://user:pass@host/db"):
            assert secret not in resp.text

    def test_file_not_found(self):
        with patch(_MOCK_ENV_PATH, _mock_path(exists=False)):
            resp = client.get("/api/env-config")
        assert resp.status_code == 404

    def test_empty_file(self):
        with patch(_MOCK_ENV_PATH, _mock_path("")):
            resp = client.get("/api/env-config")
        assert resp.status_code == 200
        assert resp.json()["entries"] == []

    def test_comments_and_blank_lines_excluded(self):
        with patch(_MOCK_ENV_PATH, _mock_path("# comment\n\nKEY=val\n")):
            resp = client.get("/api/env-config")

        assert resp.status_code == 200
        entries = resp.json()["entries"]
        assert len(entries) == 1
        assert entries[0]["key"] == "KEY"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# PUT /api/env-config
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestUpdateEnvConfig:
    """PUT /api/env-config: merge submitted entries into the .env file."""

    @staticmethod
    def _written(env_path):
        """Return the text passed to the mocked ``write_text``.

        Fails with a readable message when nothing was written, instead of the
        opaque ``TypeError`` from subscripting ``call_args`` when it is ``None``.
        """
        assert env_path.write_text.called, "expected the router to write the .env file"
        return env_path.write_text.call_args[0][0]

    def test_update_existing_key(self):
        env_path = _mock_path("DB_HOST=localhost\nDB_PORT=5432\n")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [
                    {"key": "DB_HOST", "value": "192.168.1.1"},
                    {"key": "DB_PORT", "value": "5433"},
                ]
            })
        assert resp.status_code == 200

        written = self._written(env_path)
        assert "DB_HOST=192.168.1.1" in written
        assert "DB_PORT=5433" in written
        # Old values are replaced in place, not kept alongside the new ones.
        assert "DB_HOST=localhost" not in written
        assert "DB_PORT=5432" not in written
        assert written.count("DB_HOST=") == 1

    def test_add_new_key(self):
        env_path = _mock_path("DB_HOST=localhost\n")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [
                    {"key": "DB_HOST", "value": "localhost"},
                    {"key": "NEW_KEY", "value": "new_value"},
                ]
            })
        assert resp.status_code == 200

        written = self._written(env_path)
        assert "NEW_KEY=new_value" in written
        assert "DB_HOST=localhost" in written

    def test_masked_value_preserves_original(self):
        """A masked value (****) must not overwrite the stored secret."""
        env_path = _mock_path("DB_PASSWORD=real_secret\nDB_HOST=localhost\n")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [
                    {"key": "DB_PASSWORD", "value": "****"},
                    {"key": "DB_HOST", "value": "newhost"},
                ]
            })
        assert resp.status_code == 200

        written = self._written(env_path)
        # The original password is kept and the mask never reaches the file.
        assert "DB_PASSWORD=real_secret" in written
        assert "DB_PASSWORD=****" not in written
        assert "DB_HOST=newhost" in written

    def test_preserves_comments(self):
        env_path = _mock_path("# 注释行\nDB_HOST=localhost\n\n# 另一个注释\n")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [{"key": "DB_HOST", "value": "newhost"}]
            })
        assert resp.status_code == 200

        written = self._written(env_path)
        assert "# 注释行" in written
        assert "# 另一个注释" in written
        assert "DB_HOST=newhost" in written

    def test_invalid_key_format(self):
        # Patched so a validation regression can never overwrite the real .env.
        env_path = _mock_path("")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [{"key": "123BAD", "value": "val"}]
            })
        assert resp.status_code == 422
        env_path.write_text.assert_not_called()

    def test_empty_key(self):
        env_path = _mock_path("")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [{"key": "", "value": "val"}]
            })
        assert resp.status_code == 422
        env_path.write_text.assert_not_called()

    def test_file_not_exists_creates_new(self):
        """When the file does not exist, a new one is written."""
        env_path = _mock_path(exists=False)
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [{"key": "NEW_KEY", "value": "value"}]
            })
        assert resp.status_code == 200

        written = self._written(env_path)
        assert "NEW_KEY=value" in written

    def test_update_sensitive_with_new_value(self):
        """An explicitly supplied new secret replaces the old one."""
        env_path = _mock_path("DB_PASSWORD=old_secret\n")
        with patch(_MOCK_ENV_PATH, env_path):
            resp = client.put("/api/env-config", json={
                "entries": [{"key": "DB_PASSWORD", "value": "new_secret"}]
            })
        assert resp.status_code == 200

        written = self._written(env_path)
        assert "DB_PASSWORD=new_secret" in written
        assert "old_secret" not in written

        # The response still masks sensitive keys.
        entries = {e["key"]: e["value"] for e in resp.json()["entries"]}
        assert entries["DB_PASSWORD"] == "****"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/env-config/export
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestExportEnvConfig:
    """GET /api/env-config/export: download the .env text with secrets masked."""

    def test_export_masks_sensitive(self):
        with patch(_MOCK_ENV_PATH, _mock_path(_SAMPLE_ENV)):
            resp = client.get("/api/env-config/export")

        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/plain")
        # Served as a file download rather than rendered inline.
        assert "attachment" in resp.headers.get("content-disposition", "")

        content = resp.text
        # Non-sensitive values are kept.
        assert "DB_HOST=localhost" in content
        assert "TIMEZONE=Asia/Shanghai" in content

        # Sensitive values are masked; none of the raw secrets may leak,
        # including the DSN, whose embedded credentials were previously unchecked.
        assert "super_secret_123" not in content
        assert "my-jwt-secret" not in content
        assert "postgresql://user:pass@host/db" not in content
        assert "DB_PASSWORD=****" in content
        assert "JWT_SECRET_KEY=****" in content
        assert "ETL_DB_DSN=****" in content

    def test_export_preserves_comments(self):
        with patch(_MOCK_ENV_PATH, _mock_path(_SAMPLE_ENV)):
            content = client.get("/api/env-config/export").text
        assert "# 数据库配置" in content
        assert "# ETL 配置" in content

    def test_export_file_not_found(self):
        with patch(_MOCK_ENV_PATH, _mock_path(exists=False)):
            resp = client.get("/api/env-config/export")
        assert resp.status_code == 404
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 认证测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestEnvConfigAuth:
    """Every env-config endpoint must reject unauthenticated requests."""

    @pytest.mark.parametrize("method, url", [
        ("GET", "/api/env-config"),
        ("PUT", "/api/env-config"),
        ("GET", "/api/env-config/export"),
    ])
    def test_requires_auth(self, method, url):
        """With the auth override removed, each endpoint must answer 401/403."""
        # Temporarily remove the override so the real auth dependency runs.
        original = app.dependency_overrides.pop(get_current_user, None)
        env_path = _mock_path()
        try:
            # Patch the .env path so that, if auth ever regresses, the request
            # still cannot read or overwrite the real file.
            with patch(_MOCK_ENV_PATH, env_path):
                resp = client.request(method, url)
        finally:
            # `is not None` rather than truthiness: restore exactly what was popped.
            if original is not None:
                app.dependency_overrides[get_current_user] = original

        assert resp.status_code in (401, 403), f"{method} {url} 应需要认证"
        env_path.write_text.assert_not_called()
|