"""
Unit tests for the authentication router.

Covers: login success/failure, token refresh, and disabled-account scenarios.

The database connection is mocked so the tests do not depend on a real
PostgreSQL instance.
"""

import os

# Provide a deterministic signing secret *before* anything from ``app`` is
# imported: the placement of this line suggests ``app.auth.jwt`` reads
# JWT_SECRET_KEY at import time (assumption — confirm in app/auth/jwt.py).
# ``setdefault`` keeps any value already exported in the environment.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import MagicMock, patch  # noqa: E402

import pytest  # noqa: E402
from fastapi.testclient import TestClient  # noqa: E402

from app.auth.jwt import (  # noqa: E402
    create_refresh_token,
    decode_access_token,
    decode_refresh_token,
    hash_password,
)
from app.main import app  # noqa: E402
from app.routers.auth import router  # noqa: E402

# Mount the auth router unless ``app.main`` has already done so.
# ``app.routes`` contains individual route objects, never the ``APIRouter``
# itself, so testing ``router in app.routes`` can never match. Compare the
# router's route paths against the paths already registered on the app.
_app_route_paths = {getattr(route, "path", None) for route in app.routes}
_auth_route_paths = {getattr(route, "path", None) for route in router.routes}
if not _auth_route_paths <= _app_route_paths:
    app.include_router(router)

client = TestClient(app)

# Fixed test data.
# Rows returned by the mocked ``cursor.fetchone()`` are tuples laid out as
# (id, password_hash, site_id, is_active).
_TEST_PASSWORD = "correct_password"
_TEST_HASH = hash_password(_TEST_PASSWORD)

# Active account.
_TEST_USER_ROW = (
    1,           # id
    _TEST_HASH,  # password_hash
    100,         # site_id
    True,        # is_active
)

# Same password as above, but the account is disabled.
_DISABLED_USER_ROW = (
    2,           # id
    _TEST_HASH,  # password_hash
    200,         # site_id
    False,       # is_active
)

def _mock_db_returning(row):
|
||
"""构造一个 mock get_connection,cursor.fetchone() 返回指定行。"""
|
||
mock_conn = MagicMock()
|
||
mock_cursor = MagicMock()
|
||
mock_cursor.fetchone.return_value = row
|
||
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
|
||
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
|
||
return mock_conn


# ---------------------------------------------------------------------------
# POST /api/auth/login
# ---------------------------------------------------------------------------


class TestLogin:
    """Tests for ``POST /api/auth/login``.

    ``get_connection`` is patched in ``app.routers.auth`` so each test decides
    which user row the login lookup sees.
    """

    @patch("app.routers.auth.get_connection")
    def test_login_success(self, mock_get_conn):
        """Correct credentials return 200 with an access/refresh token pair."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": _TEST_PASSWORD},
        )
        assert resp.status_code == 200
        # The user row must have come from the (mocked) database.
        assert mock_get_conn.called
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"

        # The access_token payload carries the correct user_id and site_id.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "1"
        assert payload["site_id"] == 100

        # The refresh_token must decode as a genuine refresh token for the
        # same user/site, not merely be present in the response.
        refresh_payload = decode_refresh_token(data["refresh_token"])
        assert refresh_payload["sub"] == "1"
        assert refresh_payload["site_id"] == 100

    @patch("app.routers.auth.get_connection")
    def test_login_user_not_found(self, mock_get_conn):
        """An unknown username returns 401."""
        mock_get_conn.return_value = _mock_db_returning(None)

        resp = client.post(
            "/api/auth/login",
            json={"username": "nonexistent", "password": "whatever"},
        )
        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_wrong_password(self, mock_get_conn):
        """A wrong password returns 401."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": "wrong_password"},
        )
        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_disabled_account(self, mock_get_conn):
        """A disabled account returns 401 even with the correct password."""
        mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "disabled_user", "password": _TEST_PASSWORD},
        )
        assert resp.status_code == 401
        assert "禁用" in resp.json()["detail"]

    def test_login_missing_username(self):
        """A request without ``username`` fails validation with 422."""
        resp = client.post("/api/auth/login", json={"password": "test"})
        assert resp.status_code == 422

    def test_login_empty_password(self):
        """An empty password fails validation with 422."""
        resp = client.post(
            "/api/auth/login", json={"username": "admin", "password": ""}
        )
        assert resp.status_code == 422


# ---------------------------------------------------------------------------
# POST /api/auth/refresh
# ---------------------------------------------------------------------------


class TestRefresh:
    """Tests for ``POST /api/auth/refresh``."""

    @staticmethod
    def _refresh(body):
        """POST *body* as JSON to the refresh endpoint and return the response."""
        return client.post("/api/auth/refresh", json=body)

    def test_refresh_success(self):
        """A valid refresh_token is exchanged for a new access_token."""
        token = create_refresh_token(user_id=5, site_id=50)

        response = self._refresh({"refresh_token": token})
        assert response.status_code == 200

        body = response.json()
        assert "access_token" in body
        # The refresh_token is echoed back unchanged.
        assert body["refresh_token"] == token
        assert body["token_type"] == "bearer"

        # The new access_token carries the same user and site.
        claims = decode_access_token(body["access_token"])
        assert claims["sub"] == "5"
        assert claims["site_id"] == 50

    def test_refresh_with_invalid_token(self):
        """A malformed token is rejected with 401."""
        response = self._refresh({"refresh_token": "garbage.token.here"})
        assert response.status_code == 401
        assert "无效的刷新令牌" in response.json()["detail"]

    def test_refresh_with_access_token_rejected(self):
        """Presenting an access_token to the refresh endpoint is rejected."""
        from app.auth.jwt import create_access_token

        wrong_kind = create_access_token(user_id=1, site_id=1)
        assert self._refresh({"refresh_token": wrong_kind}).status_code == 401

    def test_refresh_missing_token(self):
        """Omitting ``refresh_token`` fails validation with 422."""
        assert self._refresh({}).status_code == 422