# -*- coding: utf-8 -*- """ 全局响应包装中间件 + 异常处理器。 ResponseWrapperMiddleware(ASGI 中间件): 对 JSON 成功响应(2xx + application/json)自动包装为 { "code": 0, "data": <原始body> }。 跳过条件:text/event-stream(SSE)、非 application/json、非 2xx 状态码。 ExceptionHandler 函数(http_exception_handler / unhandled_exception_handler): 统一格式化错误响应为 { "code": , "message": }。 """ from __future__ import annotations import json import logging from typing import Any from fastapi import Request from fastapi.responses import JSONResponse from starlette.exceptions import HTTPException from starlette.types import ASGIApp, Message, Receive, Scope, Send logger = logging.getLogger(__name__) class ResponseWrapperMiddleware: """ASGI 中间件:全局响应包装。 拦截 http.response.start / http.response.body,对满足条件的响应 包装为 { "code": 0, "data": <原始body> }。 跳过条件(透传原始响应): 1. content-type 为 text/event-stream(SSE 端点) 2. content-type 不包含 application/json(文件下载等) 3. HTTP 状态码非 2xx(错误响应已由 ExceptionHandler 格式化) """ def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # 仅处理 HTTP 请求,WebSocket / lifespan 等直接透传 if scope["type"] != "http": await self.app(scope, receive, send) return # 用于在 send 回调间共享状态 should_wrap = False status_code = 0 # 缓存 start message,包装时需要修改 content-length cached_start: Message | None = None # 收集所有 body 分片(more_body 场景) body_parts: list[bytes] = [] async def send_wrapper(message: Message) -> None: nonlocal should_wrap, status_code, cached_start if message["type"] == "http.response.start": status_code = message.get("status", 200) headers = dict( (k.lower(), v) for k, v in ( (k if isinstance(k, bytes) else k.encode(), v if isinstance(v, bytes) else v.encode()) for k, v in message.get("headers", []) ) ) content_type = headers.get(b"content-type", b"").decode("latin-1").lower() # 判断是否需要包装 is_2xx = 200 <= status_code < 300 is_json = "application/json" in content_type is_sse = "text/event-stream" in content_type if is_2xx and is_json and not is_sse: should_wrap = True # 缓存 start message,等 body 完整后再发送(需要更新 content-length) cached_start = message else: # 不包装,直接透传 should_wrap = False await send(message) return if message["type"] == "http.response.body": if not should_wrap: # 不包装,直接透传 await send(message) return # 收集 body 分片 body_parts.append(message.get("body", b"")) more_body = message.get("more_body", False) if not more_body: # body 完整,执行包装 original_body = b"".join(body_parts) try: wrapped = _wrap_success_body(original_body) except Exception: # 包装失败(如 JSON 解析错误),透传原始响应 logger.debug( "响应包装失败,透传原始响应", exc_info=True, ) if cached_start is not None: await send(cached_start) await send({ "type": "http.response.body", "body": original_body, }) return # 更新 content-length 并发送 if cached_start is not None: new_headers = _update_content_length( cached_start.get("headers", []), len(wrapped), ) await send({ "type": "http.response.start", "status": cached_start.get("status", 200), "headers": new_headers, }) await send({ "type": "http.response.body", "body": wrapped, }) # 如果 more_body=True,继续收集,不发送 return try: await self.app(scope, receive, send_wrapper) except Exception: # 中间件自身异常不应阻塞请求,但这里是 app 内部异常, # 正常情况下由 ExceptionHandler 处理,此处仅做兜底日志 logger.exception("ResponseWrapperMiddleware 捕获到未处理异常") raise def _wrap_success_body(original_body: bytes) -> bytes: """将原始 JSON body 包装为 { "code": 0, "data": }。 如果原始 body 为空,data 设为 null。 """ if not original_body or original_body.strip() == b"": data: Any = None else: data = json.loads(original_body) wrapped = {"code": 0, "data": data} return json.dumps(wrapped, ensure_ascii=False, separators=(",", ":")).encode("utf-8") def _update_content_length( headers: list[tuple[bytes, bytes] | list], new_length: int, ) -> list[list[bytes]]: """替换 headers 中的 content-length 为新值。""" new_headers: list[list[bytes]] = [] found = False for pair in headers: k = pair[0] if isinstance(pair[0], bytes) else pair[0].encode() v = pair[1] if isinstance(pair[1], bytes) else pair[1].encode() if k.lower() == b"content-length": new_headers.append([k, str(new_length).encode()]) found = True else: new_headers.append([k, v]) if not found: new_headers.append([b"content-length", str(new_length).encode()]) return new_headers # --------------------------------------------------------------------------- # Exception Handlers # --------------------------------------------------------------------------- async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: """HTTPException → { code: , message: }""" return JSONResponse( status_code=exc.status_code, content={"code": exc.status_code, "message": exc.detail}, ) async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: """未捕获异常 → { code: 500, message: "Internal Server Error" } 完整堆栈写入服务端日志。""" logger.exception("未捕获异常: %s", exc) return JSONResponse( status_code=500, content={"code": 500, "message": "Internal Server Error"}, )