"""AI 事件 WebSocket 推送端点。 提供: - /ws/ai-cache/{site_id}?token=xxx — 缓存更新 / 失效事件 - /ws/ai-alerts/{site_id}?token=xxx — AI 告警事件(Phase 3.1) 协议: - 客户端连接(必须带 token query)→ 服务端校验 JWT → accept → 订阅 EventBus → 持续 send_json 事件 - 事件格式:{"type": "cache_updated|alert_created|...", "site_id": int, "payload": {...}} - 服务端关闭或客户端断开时清理订阅 W1-AI-CLOSURE 组 7 安全加固(P0-9): - 新增 token query 参数,JWT 解码失败 → 关闭(4401) - site_id 必须与 token 中的 site_id 一致;site_id=-1 全局订阅要求 token 角色含 super_admin - 不再允许任何用户监听任意门店的 AI 事件流 """ from __future__ import annotations import logging from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect, status from jose import JWTError from ..ai.event_bus import AIEvent, get_event_bus from ..auth.jwt import decode_access_token logger = logging.getLogger(__name__) ws_router = APIRouter() _WS_CLOSE_AUTH_FAILED = 4401 # 自定义 close code(WebSocket close codes 4000-4999 给应用使用) def _authorize_ws(token: str | None, site_id: int) -> bool: """校验 WS 连接的 JWT token 与 site_id。 返回 True 表示通过,False 表示拒绝。 site_id == -1 时要求 token roles 含 super_admin。 """ if not token: return False try: payload = decode_access_token(token) except JWTError: return False token_site_id = payload.get("site_id") roles = payload.get("roles") or [] if not isinstance(roles, list): roles = [] if site_id == -1: # 全局订阅(admin-web 全局监控):仅 super_admin return "super_admin" in roles # 站点订阅:token 的 site_id 必须等于 URL 中的 site_id # super_admin 跨站点放行 return token_site_id == site_id or "super_admin" in roles @ws_router.websocket("/ws/ai-cache/{site_id}") async def ws_ai_cache( websocket: WebSocket, site_id: int, token: str | None = Query(default=None), ) -> None: """AI 缓存事件推送。 site_id=-1 表示订阅全局(收所有门店的 cache_updated / cache_invalidated), 需 super_admin token。 """ await _serve_event_stream(websocket, site_id, token=token, endpoint="ai-cache") @ws_router.websocket("/ws/ai-alerts/{site_id}") async def ws_ai_alerts( websocket: WebSocket, site_id: int, token: str | None = Query(default=None), ) -> None: """AI 告警事件推送(Phase 3.1)。 site_id=-1 表示订阅全局告警,需 super_admin token。 事件 type: alert_created / alert_updated / budget_exceeded / circuit_opened。 """ await _serve_event_stream(websocket, site_id, token=token, endpoint="ai-alerts") async def _serve_event_stream( websocket: WebSocket, site_id: int, *, token: str | None, endpoint: str, ) -> None: """共享事件流处理逻辑(含 token 鉴权)。""" if not _authorize_ws(token, site_id): await websocket.close(code=_WS_CLOSE_AUTH_FAILED, reason="auth required") logger.warning( "WS %s 鉴权失败: site_id=%s token_present=%s", endpoint, site_id, bool(token), ) return await websocket.accept() # -1 映射为全局订阅(None) subscribe_key: int | None = None if site_id == -1 else site_id logger.info( "WS %s 连接建立: site_id=%s", endpoint, subscribe_key if subscribe_key else "ALL", ) bus = get_event_bus() queue = await bus.subscribe(subscribe_key) try: while True: event = await queue.get() if event is None: break await websocket.send_json({ "type": event.type, "site_id": event.site_id, "payload": event.payload, }) except WebSocketDisconnect: logger.info("WS %s 客户端断开: site_id=%s", endpoint, subscribe_key) except Exception: logger.exception("WS %s 异常: site_id=%s", endpoint, subscribe_key) finally: await bus.unsubscribe(subscribe_key, queue) try: await websocket.close() except Exception: pass