148 lines
4.9 KiB
Python
148 lines
4.9 KiB
Python
"""
|
||
JWT 认证模块单元测试。
|
||
|
||
覆盖:令牌生成、验证、过期、类型校验、密码哈希、依赖注入。
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
|
||
import pytest
|
||
from jose import jwt as jose_jwt
|
||
|
||
# Ensure a non-empty JWT_SECRET_KEY is present before the app modules are
# imported further down, so the tests never run with an empty signing key.
# NOTE(review): presumably app.config reads this variable at import time —
# confirm. ``os.environ.setdefault`` alone would keep a variable that is
# exported but empty, so an empty string is treated the same as "unset".
if not os.environ.get("JWT_SECRET_KEY"):
    os.environ["JWT_SECRET_KEY"] = "test-secret-key-for-unit-tests"
|
||
|
||
from app.auth.jwt import (
|
||
create_access_token,
|
||
create_refresh_token,
|
||
create_token_pair,
|
||
decode_access_token,
|
||
decode_refresh_token,
|
||
decode_token,
|
||
hash_password,
|
||
verify_password,
|
||
)
|
||
from app import config
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 密码哈希
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestPasswordHashing:
    """Password hashing: round-trip verification and rejection of mismatches."""

    def test_hash_and_verify(self):
        """A password verifies against the hash produced from it."""
        secret = "my_secure_password"
        assert verify_password(secret, hash_password(secret))

    def test_wrong_password_rejected(self):
        """A different password does not verify against a stored hash."""
        stored = hash_password("correct")
        assert not verify_password("wrong", stored)

    def test_hash_is_not_plaintext(self):
        """The value returned by hash_password is not the raw password."""
        plain = "plaintext123"
        digest = hash_password(plain)
        assert digest != plain
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 令牌生成与解码
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestTokenCreation:
    """Token creation: claims carried by access, refresh and paired tokens."""

    @staticmethod
    def _assert_claims(claims, expected_sub, expected_site_id, expected_type):
        """Check subject, site id and type on *claims*, and that ``exp`` is present."""
        assert claims["sub"] == expected_sub
        assert claims["site_id"] == expected_site_id
        assert claims["type"] == expected_type
        assert "exp" in claims

    def test_access_token_contains_expected_fields(self):
        """An access token's subject is the user id rendered as a string."""
        claims = decode_token(create_access_token(user_id=1, site_id=100))
        self._assert_claims(claims, "1", 100, "access")

    def test_refresh_token_contains_expected_fields(self):
        """A refresh token carries the same claims, typed as ``refresh``."""
        claims = decode_token(create_refresh_token(user_id=2, site_id=200))
        self._assert_claims(claims, "2", 200, "refresh")

    def test_token_pair_returns_both_tokens(self):
        """A token pair holds one token of each kind plus the bearer scheme."""
        pair = create_token_pair(user_id=3, site_id=300)
        assert "access_token" in pair
        assert "refresh_token" in pair
        assert pair["token_type"] == "bearer"

        # Each half of the pair must decode to its own token type.
        access_claims, refresh_claims = (
            decode_token(pair[key]) for key in ("access_token", "refresh_token")
        )
        assert access_claims["type"] == "access"
        assert refresh_claims["type"] == "refresh"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 令牌类型校验
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestTokenTypeValidation:
    """The typed decoders accept their own token kind and refuse the other."""

    def test_decode_access_token_rejects_refresh(self):
        """decode_access_token refuses a refresh token."""
        wrong_kind = create_refresh_token(user_id=1, site_id=1)
        # NOTE(review): ``Exception`` is very broad. Any error, even a
        # TypeError from an unrelated bug, satisfies it. Narrow it once the
        # decoder's actual rejection exception is confirmed.
        with pytest.raises(Exception):
            decode_access_token(wrong_kind)

    def test_decode_refresh_token_rejects_access(self):
        """decode_refresh_token refuses an access token."""
        wrong_kind = create_access_token(user_id=1, site_id=1)
        with pytest.raises(Exception):
            decode_refresh_token(wrong_kind)

    def test_decode_access_token_accepts_access(self):
        """decode_access_token returns the claims of a genuine access token."""
        claims = decode_access_token(create_access_token(user_id=5, site_id=50))
        assert claims["sub"] == "5"
        assert claims["site_id"] == 50

    def test_decode_refresh_token_accepts_refresh(self):
        """decode_refresh_token returns the claims of a genuine refresh token."""
        claims = decode_refresh_token(create_refresh_token(user_id=6, site_id=60))
        assert claims["sub"] == "6"
        assert claims["site_id"] == 60
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 令牌过期
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestTokenExpiry:
    """Expired tokens are refused by the generic decoder."""

    def test_expired_token_rejected(self):
        """A hand-built token whose ``exp`` is 10 seconds in the past fails to decode."""
        # NOTE(review): assumes decode_token allows no clock leeway of 10 s or
        # more. Confirm if this test ever starts passing tokens through.
        expired_at = int(time.time()) - 10
        claims = dict(sub="1", site_id=1, type="access", exp=expired_at)
        stale_token = jose_jwt.encode(
            claims,
            config.JWT_SECRET_KEY,
            algorithm=config.JWT_ALGORITHM,
        )
        with pytest.raises(Exception):
            decode_token(stale_token)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 无效令牌
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestInvalidToken:
    """Malformed tokens and tokens signed with a foreign key are refused."""

    def test_garbage_token_rejected(self):
        """A string that is not a well-formed JWT fails to decode."""
        with pytest.raises(Exception):
            decode_token("not.a.valid.jwt")

    def test_wrong_secret_rejected(self):
        """A token signed with a different secret is rejected.

        The claims are valid, unexpired and of the right type, and the token
        uses the configured algorithm. The only thing wrong with it is the
        signing key, so a rejection really comes from the signature check,
        not from an algorithm mismatch.
        """
        payload = {
            "sub": "1",
            "site_id": 1,
            "type": "access",
            "exp": int(time.time()) + 3600,
        }
        # Derive the foreign key from the real one, so the two can never be
        # equal whatever JWT_SECRET_KEY happens to be configured to.
        foreign_secret = f"{config.JWT_SECRET_KEY}-not-the-real-key"
        token = jose_jwt.encode(
            payload, foreign_secret, algorithm=config.JWT_ALGORITHM
        )
        with pytest.raises(Exception):
            decode_token(token)
|