"""Unit tests for the authentication router.

Covers: login success/failure, token refresh, and disabled-account scenarios.
The database connection is mocked so no real PostgreSQL instance is needed.
"""

import os

# Set before importing the app/JWT modules below, presumably because they read
# JWT_SECRET_KEY at import time (verify in app.auth.jwt). setdefault keeps any
# value already provided by the environment.
os.environ.setdefault("JWT_SECRET_KEY", "test-secret-key-for-unit-tests")

from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from app.auth.jwt import (
    create_access_token,
    create_refresh_token,
    decode_access_token,
    decode_refresh_token,
    hash_password,
)
from app.main import app
from app.routers.auth import router

# Mount the auth router on the app unless app.main has already done so.
# include_router copies the router's routes into app.routes, so the APIRouter
# object itself never appears there (an identity/membership check on `router`
# is always False). Compare route paths instead to keep this idempotent.
_app_route_paths = {route.path for route in app.routes if hasattr(route, "path")}
if not any(
    route.path in _app_route_paths
    for route in router.routes
    if hasattr(route, "path")
):
    app.include_router(router)

client = TestClient(app)

# Fixed test data.
_TEST_PASSWORD = "correct_password"
_TEST_HASH = hash_password(_TEST_PASSWORD)
# Row layout: (id, password_hash, site_id, is_active)
_TEST_USER_ROW = (1, _TEST_HASH, 100, True)
_DISABLED_USER_ROW = (2, _TEST_HASH, 200, False)


def _mock_db_returning(row):
    """Build a mock connection whose cursor's ``fetchone()`` returns *row*.

    The cursor is reached via ``conn.cursor()`` used as a context manager
    (``with conn.cursor() as cur``). The connection also returns itself from
    ``__enter__`` so the mock works whether or not the caller wraps
    ``get_connection()`` in a ``with`` statement.

    Args:
        row: The value (a row tuple, or ``None`` for "no match") that
            ``cursor.fetchone()`` should return.

    Returns:
        A ``MagicMock`` intended as the return value of a patched
        ``get_connection``.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_cursor.fetchone.return_value = row

    cursor_cm = mock_conn.cursor.return_value
    cursor_cm.__enter__.return_value = mock_cursor
    # Returning False from __exit__ means exceptions are not suppressed.
    cursor_cm.__exit__.return_value = False

    # Let `with get_connection() as conn:` yield this same configured mock.
    mock_conn.__enter__.return_value = mock_conn
    return mock_conn


# ---------------------------------------------------------------------------
# POST /api/auth/login
# ---------------------------------------------------------------------------


class TestLogin:
    """Tests for the login endpoint."""

    @patch("app.routers.auth.get_connection")
    def test_login_success(self, mock_get_conn):
        """Valid credentials return an access/refresh token pair."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": _TEST_PASSWORD},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"

        # The access_token payload carries the user's id and site_id.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "1"
        assert payload["site_id"] == 100

        # The refresh_token must decode as a refresh token for the same user.
        refresh_payload = decode_refresh_token(data["refresh_token"])
        assert refresh_payload["sub"] == "1"
        assert refresh_payload["site_id"] == 100

    @patch("app.routers.auth.get_connection")
    def test_login_user_not_found(self, mock_get_conn):
        """An unknown user yields 401."""
        mock_get_conn.return_value = _mock_db_returning(None)

        resp = client.post(
            "/api/auth/login",
            json={"username": "nonexistent", "password": "whatever"},
        )

        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_wrong_password(self, mock_get_conn):
        """A wrong password yields 401."""
        mock_get_conn.return_value = _mock_db_returning(_TEST_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "admin", "password": "wrong_password"},
        )

        assert resp.status_code == 401
        assert "用户名或密码错误" in resp.json()["detail"]

    @patch("app.routers.auth.get_connection")
    def test_login_disabled_account(self, mock_get_conn):
        """A disabled account yields 401 even with the correct password."""
        mock_get_conn.return_value = _mock_db_returning(_DISABLED_USER_ROW)

        resp = client.post(
            "/api/auth/login",
            json={"username": "disabled_user", "password": _TEST_PASSWORD},
        )

        assert resp.status_code == 401
        assert "禁用" in resp.json()["detail"]

    def test_login_missing_username(self):
        """A missing username field yields 422."""
        resp = client.post("/api/auth/login", json={"password": "test"})
        assert resp.status_code == 422

    def test_login_empty_password(self):
        """An empty password yields 422."""
        resp = client.post(
            "/api/auth/login", json={"username": "admin", "password": ""}
        )
        assert resp.status_code == 422


# ---------------------------------------------------------------------------
# POST /api/auth/refresh
# ---------------------------------------------------------------------------


class TestRefresh:
    """Tests for the token-refresh endpoint."""

    def test_refresh_success(self):
        """A valid refresh_token is exchanged for a new access_token."""
        refresh = create_refresh_token(user_id=5, site_id=50)

        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": refresh}
        )

        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        # The refresh_token is returned unchanged.
        assert data["refresh_token"] == refresh
        assert data["token_type"] == "bearer"

        # The new access_token carries the same user id and site_id.
        payload = decode_access_token(data["access_token"])
        assert payload["sub"] == "5"
        assert payload["site_id"] == 50

    def test_refresh_with_invalid_token(self):
        """A malformed token yields 401."""
        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": "garbage.token.here"}
        )
        assert resp.status_code == 401
        assert "无效的刷新令牌" in resp.json()["detail"]

    def test_refresh_with_access_token_rejected(self):
        """Using an access_token in place of a refresh_token is rejected."""
        access = create_access_token(user_id=1, site_id=1)

        resp = client.post(
            "/api/auth/refresh", json={"refresh_token": access}
        )
        assert resp.status_code == 401

    def test_refresh_missing_token(self):
        """A missing refresh_token field yields 422."""
        resp = client.post("/api/auth/refresh", json={})
        assert resp.status_code == 422