""" FastAPI 依赖注入:从 JWT 提取当前用户信息。 用法: @router.get("/protected") async def protected_endpoint(user: CurrentUser = Depends(get_current_user)): print(user.user_id, user.site_id) # 允许 pending 用户(受限令牌)访问 @router.get("/apply") async def apply_endpoint(user: CurrentUser = Depends(get_current_user_or_limited)): if user.limited: ... # 受限逻辑 """ from __future__ import annotations import time from dataclasses import dataclass, field from datetime import datetime from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from jose import JWTError from app.auth.jwt import decode_access_token from app.trace.context import SpanType, TraceSpan, get_current_trace from app.trace.decorators import truncate_token # Bearer token 提取器 _bearer_scheme = HTTPBearer(auto_error=True) # ── 鉴权失败原因分类常量 ── AUTH_EXPIRED = "AUTH_EXPIRED" AUTH_INVALID = "AUTH_INVALID" AUTH_MALFORMED = "AUTH_MALFORMED" AUTH_LIMITED = "AUTH_LIMITED" AUTH_FORBIDDEN = "AUTH_FORBIDDEN" def _record_auth_span( *, token: str, success: bool, user_id: int | None = None, site_id: int | None = None, roles: list[str] | None = None, user_status: str = "", failure_reason: str = "", detail: str = "", duration_ms: float = 0.0, ) -> None: """向当前 TraceContext 添加 AUTH span(无 trace 时静默跳过)。""" ctx = get_current_trace() if ctx is None: return token_prefix = truncate_token(token) if success: desc_zh = f"JWT 鉴权通过:user_id={user_id}, site_id={site_id}, roles={roles}" desc_en = f"JWT auth passed: user_id={user_id}, site_id={site_id}, roles={roles}" result_summary = "approved" else: desc_zh = f"JWT 鉴权失败:{failure_reason} — {detail}" desc_en = f"JWT auth failed: {failure_reason} — {detail}" result_summary = failure_reason extra: dict = {} if failure_reason: extra["failure_reason"] = failure_reason ctx.add_span(TraceSpan( span_type=SpanType.AUTH, module="auth.dependencies", function="get_current_user", description_zh=desc_zh, description_en=desc_en, params={"token_prefix": token_prefix}, result_summary=result_summary, duration_ms=duration_ms, timestamp=datetime.now().isoformat(), extra=extra, )) # 鉴权成功时将 user_id / site_id 写入 TraceContext if success and user_id is not None: ctx.user_id = user_id if site_id is not None: ctx.site_id = site_id def _classify_jwt_error(exc: JWTError) -> str: """根据 JWTError 消息分类失败原因。""" msg = str(exc).lower() if "expired" in msg or "exp" in msg: return AUTH_EXPIRED return AUTH_INVALID @dataclass(frozen=True) class CurrentUser: """从 JWT 解析出的当前用户上下文。""" user_id: int site_id: int = 0 roles: list[str] = field(default_factory=list) status: str = "pending" limited: bool = False async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme), ) -> CurrentUser: """ FastAPI 依赖:从 Authorization header 提取 JWT,验证后返回用户信息。 要求完整令牌(非 limited),失败时抛出 401。 """ token = credentials.credentials start = time.perf_counter() try: payload = decode_access_token(token) except JWTError as exc: elapsed = (time.perf_counter() - start) * 1000 reason = _classify_jwt_error(exc) _record_auth_span( token=token, success=False, failure_reason=reason, detail="无效的令牌", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的令牌", headers={"WWW-Authenticate": "Bearer"}, ) # 受限令牌不允许通过此依赖 if payload.get("limited"): elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=False, failure_reason=AUTH_LIMITED, detail="受限令牌无法访问此端点", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="受限令牌无法访问此端点", headers={"WWW-Authenticate": "Bearer"}, ) user_id_raw = payload.get("sub") site_id = payload.get("site_id") if user_id_raw is None or site_id is None: elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=False, failure_reason=AUTH_MALFORMED, detail="令牌缺少必要字段", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌缺少必要字段", headers={"WWW-Authenticate": "Bearer"}, ) try: user_id = int(user_id_raw) except (TypeError, ValueError): elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=False, failure_reason=AUTH_MALFORMED, detail="令牌中 user_id 格式无效", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌中 user_id 格式无效", headers={"WWW-Authenticate": "Bearer"}, ) roles = payload.get("roles", []) elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=True, user_id=user_id, site_id=site_id, roles=roles, user_status="approved", duration_ms=elapsed, ) return CurrentUser( user_id=user_id, site_id=site_id, roles=roles, status="approved", limited=False, ) async def get_current_user_or_limited( credentials: HTTPAuthorizationCredentials = Depends(_bearer_scheme), ) -> CurrentUser: """ FastAPI 依赖:允许 pending 用户(受限令牌)访问。 - 受限令牌(limited=True):返回 CurrentUser(limited=True, roles=[], status="pending") - 完整令牌:正常返回 CurrentUser """ token = credentials.credentials start = time.perf_counter() try: payload = decode_access_token(token) except JWTError as exc: elapsed = (time.perf_counter() - start) * 1000 reason = _classify_jwt_error(exc) _record_auth_span( token=token, success=False, failure_reason=reason, detail="无效的令牌", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的令牌", headers={"WWW-Authenticate": "Bearer"}, ) user_id_raw = payload.get("sub") if user_id_raw is None: elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=False, failure_reason=AUTH_MALFORMED, detail="令牌缺少必要字段", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌缺少必要字段", headers={"WWW-Authenticate": "Bearer"}, ) try: user_id = int(user_id_raw) except (TypeError, ValueError): elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=False, failure_reason=AUTH_MALFORMED, detail="令牌中 user_id 格式无效", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌中 user_id 格式无效", headers={"WWW-Authenticate": "Bearer"}, ) # 受限令牌:pending 用户 if payload.get("limited"): elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=True, user_id=user_id, site_id=0, roles=[], user_status="pending", duration_ms=elapsed, ) return CurrentUser( user_id=user_id, site_id=0, roles=[], status="pending", limited=True, ) # 完整令牌:要求 site_id site_id = payload.get("site_id") if site_id is None: elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=False, failure_reason=AUTH_MALFORMED, detail="令牌缺少必要字段", duration_ms=elapsed, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="令牌缺少必要字段", headers={"WWW-Authenticate": "Bearer"}, ) roles = payload.get("roles", []) elapsed = (time.perf_counter() - start) * 1000 _record_auth_span( token=token, success=True, user_id=user_id, site_id=site_id, roles=roles, user_status="approved", duration_ms=elapsed, ) return CurrentUser( user_id=user_id, site_id=site_id, roles=roles, status="approved", limited=False, )