Files
Neo-ZQYY/apps/backend/app/middleware/response_wrapper.py

197 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
全局响应包装中间件 + 异常处理器。
ResponseWrapperMiddlewareASGI 中间件):
对 JSON 成功响应2xx + application/json自动包装为 { "code": 0, "data": <原始body> }。
跳过条件text/event-streamSSE、非 application/json、非 2xx 状态码。
ExceptionHandler 函数http_exception_handler / unhandled_exception_handler
统一格式化错误响应为 { "code": <status_code>, "message": <detail> }。
"""
from __future__ import annotations
import json
import logging
from typing import Any
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException
from starlette.types import ASGIApp, Message, Receive, Scope, Send
logger = logging.getLogger(__name__)
class ResponseWrapperMiddleware:
    """ASGI middleware that wraps successful JSON responses in a uniform envelope.

    Intercepts ``http.response.start`` / ``http.response.body`` and rewrites
    eligible responses to ``{"code": 0, "data": <original body>}``.

    A response is passed through untouched when any of the following holds:

    1. The content-type is ``text/event-stream`` (SSE endpoints).
    2. The content-type does not contain ``application/json`` (file
       downloads and the like).
    3. The HTTP status code is not 2xx (error responses are expected to be
       formatted by the exception handlers).
    4. The response must not carry a body: status 204 / 205, or any response
       to a ``HEAD`` request. Injecting an envelope there would produce an
       invalid HTTP response.

    ASGI messages other than ``http.response.start`` / ``http.response.body``
    (for example ``http.response.pathsend`` or ``http.response.trailers``)
    are forwarded unchanged.
    """

    # 2xx statuses whose responses must not carry a message body.
    _BODYLESS_STATUSES = frozenset({204, 205})

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    @staticmethod
    def _content_type(message: Message) -> str:
        """Return the lower-cased content-type of a start message ('' if absent).

        Header names/values may arrive as ``bytes`` or ``str``; both are
        accepted. When the header is repeated, the last occurrence wins.
        """
        raw_value = b""
        for name, value in message.get("headers", []):
            raw_name = name if isinstance(name, bytes) else name.encode()
            if raw_name.lower() == b"content-type":
                raw_value = value if isinstance(value, bytes) else value.encode()
        return raw_value.decode("latin-1").lower()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, rewriting eligible JSON responses on the way out.

        Exceptions raised by the wrapped app are logged and re-raised.
        """
        # Only HTTP requests are handled; WebSocket / lifespan pass straight through.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # A HEAD response consists of headers only, so it must never be wrapped.
        is_head = scope.get("method") == "HEAD"

        # State shared between successive send() callbacks.
        should_wrap = False
        # The start message is held back while wrapping because the
        # content-length has to be rewritten once the full body is known.
        cached_start: Message | None = None
        # Body chunks collected across more_body=True messages.
        body_parts: list[bytes] = []

        async def send_wrapper(message: Message) -> None:
            nonlocal should_wrap, cached_start

            if message["type"] == "http.response.start":
                status_code = message.get("status", 200)
                content_type = self._content_type(message)
                body_allowed = (
                    not is_head and status_code not in self._BODYLESS_STATUSES
                )
                should_wrap = (
                    200 <= status_code < 300
                    and body_allowed
                    and "application/json" in content_type
                    and "text/event-stream" not in content_type
                )
                if should_wrap:
                    # Defer sending until the body is complete.
                    cached_start = message
                else:
                    await send(message)
                return

            if message["type"] == "http.response.body":
                if not should_wrap:
                    await send(message)
                    return

                body_parts.append(message.get("body", b""))
                if message.get("more_body", False):
                    # More chunks are coming: keep buffering, send nothing yet.
                    return

                original_body = b"".join(body_parts)
                try:
                    wrapped = _wrap_success_body(original_body)
                except Exception:
                    # Wrapping failed (e.g. the body is not valid JSON):
                    # fall back to sending the original response unchanged.
                    logger.debug(
                        "响应包装失败,透传原始响应",
                        exc_info=True,
                    )
                    if cached_start is not None:
                        await send(cached_start)
                    await send({
                        "type": "http.response.body",
                        "body": original_body,
                    })
                    return

                # Send the start message with the content-length of the envelope.
                if cached_start is not None:
                    new_headers = _update_content_length(
                        cached_start.get("headers", []),
                        len(wrapped),
                    )
                    await send({
                        "type": "http.response.start",
                        "status": cached_start.get("status", 200),
                        "headers": new_headers,
                    })
                await send({
                    "type": "http.response.body",
                    "body": wrapped,
                })
                return

            # Any other message type is not ours to interpret; forward it
            # so extensions such as pathsend / trailers keep working.
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception:
            # Exceptions from the wrapped app are normally handled by the
            # exception handlers; this is only a last-resort log before
            # propagating the error.
            logger.exception("ResponseWrapperMiddleware 捕获到未处理异常")
            raise
def _wrap_success_body(original_body: bytes) -> bytes:
    """Wrap a raw JSON body as ``{"code": 0, "data": <parsed body>}``.

    An empty or whitespace-only body yields ``"data": null``. The envelope is
    serialized compactly (no spaces after separators), keeps non-ASCII
    characters unescaped, and is returned UTF-8 encoded.

    Raises whatever ``json.loads`` raises when the body is not valid JSON.
    """
    has_content = bool(original_body) and original_body.strip() != b""
    payload: Any = json.loads(original_body) if has_content else None
    envelope = {"code": 0, "data": payload}
    serialized = json.dumps(envelope, ensure_ascii=False, separators=(",", ":"))
    return serialized.encode("utf-8")
def _update_content_length(
    headers: list[tuple[bytes, bytes] | list],
    new_length: int,
) -> list[list[bytes]]:
    """Return a copy of ``headers`` whose content-length equals ``new_length``.

    Header names and values given as ``str`` are encoded to ``bytes``. Every
    existing content-length entry (matched case-insensitively) has its value
    replaced while keeping the original name; if none exists, a
    ``content-length`` entry is appended at the end.
    """

    def as_bytes(item):
        return item if isinstance(item, bytes) else item.encode()

    length_value = str(new_length).encode()
    normalized = [(as_bytes(pair[0]), as_bytes(pair[1])) for pair in headers]

    rebuilt: list[list[bytes]] = [
        [name, length_value if name.lower() == b"content-length" else value]
        for name, value in normalized
    ]
    if all(name.lower() != b"content-length" for name, _ in normalized):
        rebuilt.append([b"content-length", length_value])
    return rebuilt
# ---------------------------------------------------------------------------
# Exception Handlers
# ---------------------------------------------------------------------------
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Format an ``HTTPException`` as ``{"code": <status_code>, "message": <detail>}``.

    The response uses the exception's status code and forwards any headers
    attached to the exception (for example ``WWW-Authenticate`` on a 401), so
    raising ``HTTPException(..., headers=...)`` keeps working with this handler.

    Args:
        request: The request being handled (unused).
        exc: The raised HTTP exception.

    Returns:
        A ``JSONResponse`` carrying the error envelope.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.status_code, "message": exc.detail},
        # getattr keeps this safe for exception objects without a headers attribute.
        headers=getattr(exc, "headers", None),
    )
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Turn an uncaught exception into ``{"code": 500, "message": "Internal Server Error"}``.

    The exception is logged server-side via ``logger.exception``; the client
    only receives the generic message, never the exception text.
    """
    logger.exception("未捕获异常: %s", exc)
    error_status = 500
    payload = {"code": error_status, "message": "Internal Server Error"}
    return JSONResponse(content=payload, status_code=error_status)