125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
权限中间件 —— 基于 FastAPI 依赖注入的权限检查。
|
||
|
||
提供两个依赖工厂:
|
||
- require_permission(*codes):检查用户 status=approved 且拥有指定权限
|
||
- require_approved():仅检查用户 status=approved(不检查具体权限)
|
||
|
||
用法:
|
||
@router.get("/finance")
|
||
async def get_finance(
|
||
user: CurrentUser = Depends(require_permission("view_board_finance"))
|
||
):
|
||
...
|
||
|
||
@router.get("/tasks")
|
||
async def get_tasks(
|
||
user: CurrentUser = Depends(require_approved())
|
||
):
|
||
...
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
|
||
from fastapi import Depends, HTTPException, status
|
||
|
||
from app.auth.dependencies import CurrentUser, get_current_user
|
||
from app.database import get_connection
|
||
from app.services.role import get_user_permissions
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _get_user_status(user_id: int) -> str | None:
|
||
"""从数据库查询用户当前 status。返回 None 表示用户不存在。"""
|
||
conn = get_connection()
|
||
try:
|
||
with conn.cursor() as cur:
|
||
cur.execute(
|
||
"SELECT status FROM auth.users WHERE id = %s",
|
||
(user_id,),
|
||
)
|
||
row = cur.fetchone()
|
||
finally:
|
||
conn.close()
|
||
return row[0] if row else None
|
||
|
||
|
||
def require_permission(*permission_codes: str):
|
||
"""
|
||
权限依赖工厂:要求用户 status=approved 且拥有全部指定权限。
|
||
|
||
流程:
|
||
1. 通过 get_current_user 从 JWT 提取 user_id + site_id
|
||
2. 查询 auth.users.status —— 非 approved 则 403
|
||
3. 查询 user_site_roles + role_permissions 获取权限列表
|
||
4. 检查所需权限是否全部在列表中 —— 缺失则 403
|
||
5. 返回 CurrentUser 对象
|
||
"""
|
||
|
||
async def _dependency(
|
||
user: CurrentUser = Depends(get_current_user),
|
||
) -> CurrentUser:
|
||
# 查询数据库中的实时 status
|
||
db_status = _get_user_status(user.user_id)
|
||
if db_status is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="用户不存在",
|
||
)
|
||
if db_status != "approved":
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="用户未通过审核,无法访问此资源",
|
||
)
|
||
|
||
# 检查具体权限
|
||
if permission_codes:
|
||
user_perms = await get_user_permissions(user.user_id, user.site_id)
|
||
missing = set(permission_codes) - set(user_perms)
|
||
if missing:
|
||
logger.warning(
|
||
"用户 %s 在 site_id=%s 下缺少权限: %s",
|
||
user.user_id,
|
||
user.site_id,
|
||
missing,
|
||
)
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="权限不足",
|
||
)
|
||
|
||
return user
|
||
|
||
return _dependency
|
||
|
||
|
||
def require_approved():
|
||
"""
|
||
审核状态依赖工厂:仅检查用户 status=approved,不检查具体权限。
|
||
|
||
用于通用的已认证端点,只需确认用户已通过审核即可访问。
|
||
"""
|
||
|
||
async def _dependency(
|
||
user: CurrentUser = Depends(get_current_user),
|
||
) -> CurrentUser:
|
||
db_status = _get_user_status(user.user_id)
|
||
if db_status is None:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="用户不存在",
|
||
)
|
||
if db_status != "approved":
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="用户未通过审核,无法访问此资源",
|
||
)
|
||
|
||
return user
|
||
|
||
return _dependency
|