在准备环境前提交次全部更改。
This commit is contained in:
167
apps/backend/tests/test_auth_router.py
Normal file
167
apps/backend/tests/test_auth_router.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
认证路由单元测试。
|
||||
|
||||
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
|
||||
通过 mock 数据库连接避免依赖真实 PostgreSQL。
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.auth.jwt import (
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
hash_password,
|
||||
)
|
||||
from app.main import app
|
||||
from app.routers.auth import router
|
||||
|
||||
# 注册路由到 app(测试时确保路由已挂载)
|
||||
if router not in [r for r in app.routes]:
|
||||
app.include_router(router)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# 测试用固定数据
|
||||
_TEST_PASSWORD = "correct_password"
|
||||
_TEST_HASH = hash_password(_TEST_PASSWORD)
|
||||
_TEST_USER_ROW = (1, _TEST_HASH, 100, True) # id, password_hash, site_id, is_active
|
||||
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)
|
||||
|
||||
|
||||
def _mock_db_returning(row):
|
||||
"""构造一个 mock get_connection,cursor.fetchone() 返回指定行。"""
|
||||
mock_conn = MagicMock()
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.fetchone.return_value = row
|
||||
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
|
||||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLogin:
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_success(self, mock_get_conn):
|
||||
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "admin", "password": _TEST_PASSWORD},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# 验证 access_token payload 包含正确的 user_id 和 site_id
|
||||
payload = decode_access_token(data["access_token"])
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["site_id"] == 100
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_user_not_found(self, mock_get_conn):
|
||||
"""用户不存在时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(None)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "nonexistent", "password": "whatever"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "用户名或密码错误" in resp.json()["detail"]
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_wrong_password(self, mock_get_conn):
|
||||
"""密码错误时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "admin", "password": "wrong_password"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "用户名或密码错误" in resp.json()["detail"]
|
||||
|
||||
@patch("app.routers.auth.get_connection")
|
||||
def test_login_disabled_account(self, mock_get_conn):
|
||||
"""账号已禁用时返回 401。"""
|
||||
mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "disabled_user", "password": _TEST_PASSWORD},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "禁用" in resp.json()["detail"]
|
||||
|
||||
def test_login_missing_username(self):
|
||||
"""缺少 username 字段时返回 422。"""
|
||||
resp = client.post("/api/auth/login", json={"password": "test"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_login_empty_password(self):
|
||||
"""空密码时返回 422。"""
|
||||
resp = client.post(
|
||||
"/api/auth/login", json={"username": "admin", "password": ""}
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/auth/refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRefresh:
|
||||
def test_refresh_success(self):
|
||||
"""有效的 refresh_token 换取新的 access_token。"""
|
||||
refresh = create_refresh_token(user_id=5, site_id=50)
|
||||
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": refresh}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
# refresh_token 原样返回
|
||||
assert data["refresh_token"] == refresh
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
# 新 access_token 包含正确信息
|
||||
payload = decode_access_token(data["access_token"])
|
||||
assert payload["sub"] == "5"
|
||||
assert payload["site_id"] == 50
|
||||
|
||||
def test_refresh_with_invalid_token(self):
|
||||
"""无效令牌返回 401。"""
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "无效的刷新令牌" in resp.json()["detail"]
|
||||
|
||||
def test_refresh_with_access_token_rejected(self):
|
||||
"""用 access_token 做刷新应被拒绝。"""
|
||||
from app.auth.jwt import create_access_token
|
||||
|
||||
access = create_access_token(user_id=1, site_id=1)
|
||||
resp = client.post(
|
||||
"/api/auth/refresh", json={"refresh_token": access}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_refresh_missing_token(self):
|
||||
"""缺少 refresh_token 字段时返回 422。"""
|
||||
resp = client.post("/api/auth/refresh", json={})
|
||||
assert resp.status_code == 422
|
||||
Reference in New Issue
Block a user