# -*- coding: utf-8 -*- """ TraceMiddleware — ASGI 中间件 拦截 xcx_* 路由前缀(/api/xcx/)的请求,创建 TraceContext 并记录全链路 span。 非 xcx 路由直接跳过,不创建 TraceContext。 DEV_TRACE_ENABLED 关闭时跳过所有采集。 记录的 span 类型: - HTTP_IN: 请求进入(method, path, query_params, body_preview) - HTTP_OUT: 请求结束(status_code, duration, body_size) - MIDDLEWARE: ResponseWrapperMiddleware 执行耗时 - MIDDLEWARE_ERROR: 响应包装失败时记录 """ from __future__ import annotations import logging import time from datetime import datetime from urllib.parse import unquote from starlette.types import ASGIApp, Message, Receive, Scope, Send from app.trace.config import get_trace_config from app.trace.context import ( SpanType, TraceContext, TraceSpan, create_http_trace, get_current_trace, set_current_trace, trace_context_var, ) from app.trace.writer import get_trace_writer logger = logging.getLogger(__name__) # xcx 路由前缀——仅匹配此前缀的请求才采集 trace XCX_PATH_PREFIX = "/api/xcx/" def _should_trace(path: str) -> bool: """判断请求路径是否属于 xcx_* 路由前缀,需要采集 trace。""" return path.startswith(XCX_PATH_PREFIX) class TraceMiddleware: """ASGI 中间件:全链路请求追踪。 执行顺序(最外层,最先执行): TraceMiddleware → CORSMiddleware → ResponseWrapperMiddleware → 路由处理 职责: 1. 检查 DEV_TRACE_ENABLED 开关(运行时检查,支持动态切换) 2. 仅拦截 /api/xcx/ 前缀的请求 3. 创建 TraceContext 存入 contextvars 4. 记录 HTTP_IN / HTTP_OUT / MIDDLEWARE / MIDDLEWARE_ERROR span 5. 响应头写入 X-Request-ID, X-Process-Time, X-DB-Queries, X-DB-Time 6. 调用 TraceWriter 写入完整 trace """ def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # 仅处理 HTTP 请求 if scope["type"] != "http": await self.app(scope, receive, send) return path = scope.get("path", "") config = get_trace_config() # 开关关闭 或 非 xcx 路由 → 直接透传 if not config.enabled or not _should_trace(path): await self.app(scope, receive, send) return # ── 创建 TraceContext ── method = scope.get("method", "UNKNOWN") query_string = scope.get("query_string", b"").decode("latin-1", errors="replace") ctx = create_http_trace(method, path) token = set_current_trace(ctx) # 记录 HTTP_IN span query_params = {} if query_string: # 简单解析 query string 为 dict for pair in query_string.split("&"): if "=" in pair: k, v = pair.split("=", 1) query_params[unquote(k)] = unquote(v) ctx.add_span(TraceSpan( span_type=SpanType.HTTP_IN, module="trace.middleware", function="TraceMiddleware.__call__", description_zh=f"接收请求 {method} {path}", description_en=f"Received request {method} {path}", params={"query": query_params}, result_summary="", duration_ms=0.0, timestamp=datetime.now().isoformat(), )) start_time = time.perf_counter() # ── 拦截响应,注入 trace 头和采集 HTTP_OUT ── status_code = 200 response_body_parts: list[bytes] = [] middleware_start = time.perf_counter() middleware_error: str | None = None async def send_wrapper(message: Message) -> None: nonlocal status_code, middleware_error if message["type"] == "http.response.start": status_code = message.get("status", 200) # 计算 trace 统计数据 elapsed_ms = (time.perf_counter() - start_time) * 1000 current_ctx = get_current_trace() db_queries = 0 db_time_ms = 0.0 if current_ctx: for s in current_ctx.spans: if s.span_type == SpanType.DB_QUERY: db_queries += 1 db_time_ms += s.duration_ms # 注入响应头 headers = list(message.get("headers", [])) headers.append([b"x-request-id", ctx.request_id.encode()]) headers.append([b"x-process-time", f"{elapsed_ms:.1f}ms".encode()]) headers.append([b"x-db-queries", str(db_queries).encode()]) headers.append([b"x-db-time", f"{db_time_ms:.1f}ms".encode()]) message = {**message, "headers": headers} await send(message) return if message["type"] == "http.response.body": body = message.get("body", b"") response_body_parts.append(body) await send(message) return await send(message) try: await self.app(scope, receive, send_wrapper) except Exception as exc: # 即使内层异常,也要记录 trace middleware_error = f"{type(exc).__name__}: {exc}" raise finally: elapsed_ms = (time.perf_counter() - start_time) * 1000 middleware_elapsed_ms = (time.perf_counter() - middleware_start) * 1000 body_size = sum(len(p) for p in response_body_parts) # 记录 MIDDLEWARE span(ResponseWrapperMiddleware 执行耗时) if middleware_error: ctx.add_span(TraceSpan( span_type=SpanType.MIDDLEWARE_ERROR, module="middleware.response_wrapper", function="ResponseWrapperMiddleware.__call__", description_zh=f"响应包装失败: {middleware_error}", description_en=f"Response wrapping failed: {middleware_error}", params={}, result_summary=middleware_error, duration_ms=middleware_elapsed_ms, timestamp=datetime.now().isoformat(), )) else: ctx.add_span(TraceSpan( span_type=SpanType.MIDDLEWARE, module="middleware.response_wrapper", function="ResponseWrapperMiddleware.__call__", description_zh="响应包装中间件执行完成", description_en="Response wrapper middleware completed", params={}, result_summary=f"body_size={body_size}", duration_ms=middleware_elapsed_ms, timestamp=datetime.now().isoformat(), extra={"body_size": body_size}, )) # 记录 HTTP_OUT span ctx.add_span(TraceSpan( span_type=SpanType.HTTP_OUT, module="trace.middleware", function="TraceMiddleware.__call__", description_zh=f"响应返回 {status_code},耗时 {elapsed_ms:.0f}ms", description_en=f"Response sent {status_code}, took {elapsed_ms:.0f}ms", params={}, result_summary=f"{status_code}, {body_size}B body", duration_ms=elapsed_ms, timestamp=datetime.now().isoformat(), extra={"status_code": status_code, "body_size": body_size}, )) # 写入 trace 日志 try: writer = get_trace_writer() await writer.write_trace(ctx) except Exception: logger.warning("Trace 日志写入失败", exc_info=True) # 恢复 contextvars trace_context_var.reset(token)