Files
Neo-ZQYY/apps/backend/tests/test_auth_router.py

168 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
认证路由单元测试。
覆盖:登录成功/失败、刷新令牌、账号禁用等场景。
通过 mock 数据库连接避免依赖真实 PostgreSQL。
"""
import os
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from app.auth.jwt import (
create_refresh_token,
decode_access_token,
decode_refresh_token,
hash_password,
)
from app.main import app
from app.routers.auth import router
# Mount the auth router on the app only when its endpoints are missing.
# ``app.routes`` holds route objects rather than the APIRouter instances that
# were included, so testing ``router in app.routes`` can never succeed and the
# router would be registered again on every import. Compare route paths
# instead (path-only comparison; HTTP methods are not considered).
_mounted_paths = {getattr(route, "path", None) for route in app.routes}
if not all(
    getattr(route, "path", None) in _mounted_paths for route in router.routes
):
    app.include_router(router)

# Shared test client used by every test in this module.
client = TestClient(app)
# Fixed fixtures shared by the tests below.
_TEST_PASSWORD = "correct_password"
_TEST_HASH = hash_password(_TEST_PASSWORD)

# Rows are laid out as (id, password_hash, site_id, is_active).
_TEST_USER_ROW = (
    1,  # id
    _TEST_HASH,  # password_hash
    100,  # site_id
    True,  # is_active
)
_DISABLED_USER_ROW = (
    2,  # id
    _TEST_HASH,  # password_hash
    200,  # site_id
    False,  # is_active
)
def _mock_db_returning(row):
"""构造一个 mock get_connectioncursor.fetchone() 返回指定行。"""
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.fetchone.return_value = row
mock_conn.cursor.return_value.__enter__ = lambda _: mock_cursor
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn
# ---------------------------------------------------------------------------
# POST /api/auth/login
# ---------------------------------------------------------------------------
class TestLogin:
    """Tests for ``POST /api/auth/login`` with the database connection mocked."""

    @patch("app.routers.auth.get_connection")
    def test_login_success(self, mock_get_conn):
        """Valid credentials return 200 with bearer tokens for the user."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
        credentials = {"username": "admin", "password": _TEST_PASSWORD}

        response = client.post("/api/auth/login", json=credentials)

        assert response.status_code == 200
        body = response.json()
        assert "access_token" in body
        assert "refresh_token" in body
        assert body["token_type"] == "bearer"
        # The access token payload must carry the user id and site id
        # taken from the mocked row.
        claims = decode_access_token(body["access_token"])
        assert claims["sub"] == "1"
        assert claims["site_id"] == 100

    @patch("app.routers.auth.get_connection")
    def test_login_user_not_found(self, mock_get_conn):
        """An unknown user (no row returned) yields 401."""
        mock_get_conn.return_value = _mock_db_returning(None)
        credentials = {"username": "nonexistent", "password": "whatever"}

        response = client.post("/api/auth/login", json=credentials)

        assert response.status_code == 401
        assert "用户名或密码错误" in response.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_wrong_password(self, mock_get_conn):
        """A wrong password for an existing user yields 401."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)
        credentials = {"username": "admin", "password": "wrong_password"}

        response = client.post("/api/auth/login", json=credentials)

        assert response.status_code == 401
        assert "用户名或密码错误" in response.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_disabled_account(self, mock_get_conn):
        """A disabled account yields 401 even with the correct password."""
        mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)
        credentials = {"username": "disabled_user", "password": _TEST_PASSWORD}

        response = client.post("/api/auth/login", json=credentials)

        assert response.status_code == 401
        assert "禁用" in response.json()["detail"]

    def test_login_missing_username(self):
        """A request without the ``username`` field yields 422."""
        response = client.post("/api/auth/login", json={"password": "test"})

        assert response.status_code == 422

    def test_login_empty_password(self):
        """An empty password yields 422."""
        credentials = {"username": "admin", "password": ""}

        response = client.post("/api/auth/login", json=credentials)

        assert response.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/auth/refresh
# ---------------------------------------------------------------------------
class TestRefresh:
    """Tests for ``POST /api/auth/refresh``."""

    def test_refresh_success(self):
        """A valid refresh token is exchanged for a new access token."""
        refresh_token = create_refresh_token(user_id=5, site_id=50)

        response = client.post(
            "/api/auth/refresh", json={"refresh_token": refresh_token}
        )

        assert response.status_code == 200
        body = response.json()
        assert "access_token" in body
        # The refresh token is echoed back unchanged.
        assert body["refresh_token"] == refresh_token
        assert body["token_type"] == "bearer"
        # The new access token carries the same user id and site id.
        claims = decode_access_token(body["access_token"])
        assert claims["sub"] == "5"
        assert claims["site_id"] == 50

    def test_refresh_with_invalid_token(self):
        """A malformed token yields 401."""
        request_body = {"refresh_token": "garbage.token.here"}

        response = client.post("/api/auth/refresh", json=request_body)

        assert response.status_code == 401
        assert "无效的刷新令牌" in response.json()["detail"]

    def test_refresh_with_access_token_rejected(self):
        """Presenting an access token as a refresh token is rejected."""
        from app.auth.jwt import create_access_token

        access_token = create_access_token(user_id=1, site_id=1)

        response = client.post(
            "/api/auth/refresh", json={"refresh_token": access_token}
        )

        assert response.status_code == 401

    def test_refresh_missing_token(self):
        """A request without the ``refresh_token`` field yields 422."""
        response = client.post("/api/auth/refresh", json={})

        assert response.status_code == 422