197 lines
7.4 KiB
Python
197 lines
7.4 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
全局响应包装中间件 + 异常处理器。
|
||
|
||
ResponseWrapperMiddleware(ASGI 中间件):
|
||
对 JSON 成功响应(2xx + application/json)自动包装为 { "code": 0, "data": <原始body> }。
|
||
跳过条件:text/event-stream(SSE)、非 application/json、非 2xx 状态码。
|
||
|
||
ExceptionHandler 函数(http_exception_handler / unhandled_exception_handler):
|
||
统一格式化错误响应为 { "code": <status_code>, "message": <detail> }。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from typing import Any
|
||
|
||
from fastapi import Request
|
||
from fastapi.responses import JSONResponse
|
||
from starlette.exceptions import HTTPException
|
||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||
|
||
# Module-level logger, used by the middleware's fallback/diagnostic logging
# and by the unhandled-exception handler in this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ResponseWrapperMiddleware:
    """ASGI middleware that wraps successful JSON responses.

    Intercepts ``http.response.start`` / ``http.response.body`` messages and
    rewrites qualifying responses into ``{"code": 0, "data": <original body>}``.

    The original response is passed through untouched when any of these holds:

    1. the content-type is ``text/event-stream`` (SSE endpoints);
    2. the content-type does not contain ``application/json`` (file downloads, etc.);
    3. the status code is not 2xx (error bodies are already formatted by the
       exception handlers);
    4. the status code forbids a response body (204 No Content,
       205 Reset Content) -- injecting an envelope there would violate HTTP;
    5. the body has a non-identity ``content-encoding`` (e.g. gzip applied by
       an inner middleware), so the bytes are not parseable JSON.

    Wrapped responses are buffered in full (all ``more_body`` chunks) and then
    re-sent as a single body with an updated ``content-length``. Any other ASGI
    message type (extensions such as ``http.response.debug``,
    ``http.response.pathsend`` or trailers) is forwarded unchanged.
    """

    # Statuses that must not carry content (RFC 9110 sections 15.3.5 / 15.3.6).
    _NO_BODY_STATUSES = frozenset({204, 205})

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    def _normalize_headers(raw_headers: Any) -> dict[bytes, bytes]:
        """Return ASGI headers as a dict keyed by lower-cased header name.

        ``str`` names/values are encoded (UTF-8) so both bytes and str pairs
        are accepted. For repeated headers the last value wins.
        """
        normalized: dict[bytes, bytes] = {}
        for key, value in raw_headers:
            key_bytes = key if isinstance(key, bytes) else key.encode()
            value_bytes = value if isinstance(value, bytes) else value.encode()
            normalized[key_bytes.lower()] = value_bytes
        return normalized

    @classmethod
    def _is_wrappable(cls, status: int, headers: dict[bytes, bytes]) -> bool:
        """Decide whether a response with this status and these headers gets wrapped."""
        content_type = headers.get(b"content-type", b"").decode("latin-1").lower()
        content_encoding = (
            headers.get(b"content-encoding", b"").decode("latin-1").strip().lower()
        )
        return (
            200 <= status < 300
            and status not in cls._NO_BODY_STATUSES
            and "application/json" in content_type
            and "text/event-stream" not in content_type
            and content_encoding in ("", "identity")
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only HTTP requests are handled; WebSocket / lifespan pass straight through.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Per-request state shared across send_wrapper invocations.
        should_wrap = False
        # Held-back start message; sent once the wrapped body length is known.
        cached_start: Message | None = None
        # Body chunks collected across ``more_body`` messages.
        body_parts: list[bytes] = []

        async def send_wrapper(message: Message) -> None:
            nonlocal should_wrap, cached_start
            message_type = message["type"]

            if message_type == "http.response.start":
                headers = self._normalize_headers(message.get("headers", []))
                should_wrap = self._is_wrappable(message.get("status", 200), headers)
                if should_wrap:
                    # Defer the start message: content-length must be rewritten
                    # after the body has been wrapped.
                    cached_start = message
                else:
                    await send(message)
                return

            if message_type != "http.response.body" or not should_wrap:
                # Forward unchanged: bodies we are not wrapping, plus any other
                # message type (ASGI extensions, trailers) the app emits.
                await send(message)
                return

            body_parts.append(message.get("body", b""))
            if message.get("more_body", False):
                # More chunks are coming; keep buffering and send nothing yet.
                return

            original_body = b"".join(body_parts)
            try:
                wrapped = _wrap_success_body(original_body)
            except Exception:
                # Body is not valid JSON despite its content-type (best effort):
                # fall back to the original response, whose headers still match.
                logger.debug(
                    "响应包装失败,透传原始响应",
                    exc_info=True,
                )
                if cached_start is not None:
                    await send(cached_start)
                await send({
                    "type": "http.response.body",
                    "body": original_body,
                })
                return

            if cached_start is not None:
                # Keep any extra keys of the original start message (e.g.
                # "trailers") and only replace status/headers.
                await send({
                    **cached_start,
                    "status": cached_start.get("status", 200),
                    "headers": _update_content_length(
                        cached_start.get("headers", []),
                        len(wrapped),
                    ),
                })
            await send({
                "type": "http.response.body",
                "body": wrapped,
            })

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception:
            # Exceptions from the wrapped app are normally handled by the
            # registered exception handlers; this is only a fallback log
            # before re-raising.
            logger.exception("ResponseWrapperMiddleware 捕获到未处理异常")
            raise
|
||
|
||
|
||
def _wrap_success_body(original_body: bytes) -> bytes:
|
||
"""将原始 JSON body 包装为 { "code": 0, "data": <parsed_body> }。
|
||
|
||
如果原始 body 为空,data 设为 null。
|
||
"""
|
||
if not original_body or original_body.strip() == b"":
|
||
data: Any = None
|
||
else:
|
||
data = json.loads(original_body)
|
||
|
||
wrapped = {"code": 0, "data": data}
|
||
return json.dumps(wrapped, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||
|
||
|
||
def _update_content_length(
|
||
headers: list[tuple[bytes, bytes] | list],
|
||
new_length: int,
|
||
) -> list[list[bytes]]:
|
||
"""替换 headers 中的 content-length 为新值。"""
|
||
new_headers: list[list[bytes]] = []
|
||
found = False
|
||
for pair in headers:
|
||
k = pair[0] if isinstance(pair[0], bytes) else pair[0].encode()
|
||
v = pair[1] if isinstance(pair[1], bytes) else pair[1].encode()
|
||
if k.lower() == b"content-length":
|
||
new_headers.append([k, str(new_length).encode()])
|
||
found = True
|
||
else:
|
||
new_headers.append([k, v])
|
||
if not found:
|
||
new_headers.append([b"content-length", str(new_length).encode()])
|
||
return new_headers
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Exception Handlers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Format an ``HTTPException`` as ``{"code": <status_code>, "message": <detail>}``.

    Headers attached to the exception (for example ``WWW-Authenticate`` on a
    401 or ``Allow`` on a 405) are forwarded to the response, as Starlette's
    default handler does. Dropping them breaks auth challenges and method
    negotiation.

    Args:
        request: The incoming request (unused; required by the handler signature).
        exc: The raised HTTP exception.

    Returns:
        A JSON response whose status code equals ``exc.status_code``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.status_code, "message": exc.detail},
        # getattr keeps this safe for exception objects lacking a headers attribute.
        headers=getattr(exc, "headers", None),
    )
|
||
|
||
|
||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Format any uncaught exception as ``{"code": 500, "message": "Internal Server Error"}``.

    The full traceback is written to the server log. The client only receives
    a generic message.

    Args:
        request: The incoming request (unused; required by the handler signature).
        exc: The uncaught exception.

    Returns:
        A 500 JSON response.
    """
    # Pass the exception explicitly: logger.exception() relies on the
    # "currently handled" exception, which is missing if this handler is ever
    # invoked outside an ``except`` block; exc_info=exc always logs exc's traceback.
    logger.error("未捕获异常: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"code": 500, "message": "Internal Server Error"},
    )
|