# -*- coding: utf-8 -*- """WebSocket 日志推送端点测试 测试 /ws/logs/{execution_id} 端点的连接、日志回放、实时推送和断开行为。 利用 TaskExecutor 已有的 subscribe/broadcast 机制进行验证。 """ from __future__ import annotations import asyncio import pytest from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect from app.main import app from app.services.task_executor import task_executor @pytest.fixture(autouse=True) def _cleanup_executor(): """每个测试前后清理 TaskExecutor 内部状态。""" yield # 清理所有残留的缓冲区和订阅者 for eid in list(task_executor._log_buffers.keys()): task_executor.cleanup(eid) task_executor._subscribers.clear() task_executor._log_buffers.clear() class TestWebSocketConnection: """WebSocket 连接/断开基本行为""" def test_connect_and_disconnect(self): """客户端能成功建立和关闭 WebSocket 连接。""" client = TestClient(app) with client.websocket_connect("/ws/logs/test-exec-001") as ws: # 连接成功,直接关闭 pass # __exit__ 会关闭连接 def test_connect_registers_subscriber(self): """连接后 TaskExecutor 应注册订阅者。""" client = TestClient(app) # 预先初始化缓冲区(模拟有任务在运行) task_executor._log_buffers["test-exec-002"] = [] with client.websocket_connect("/ws/logs/test-exec-002"): # 连接期间应有订阅者 assert "test-exec-002" in task_executor._subscribers assert len(task_executor._subscribers["test-exec-002"]) >= 1 class TestLogReplay: """历史日志回放""" def test_replay_existing_logs(self): """连接时应先收到内存缓冲区中已有的日志行。""" eid = "test-exec-replay" # 预填充日志缓冲区 task_executor._log_buffers[eid] = [ "[stdout] 第一行", "[stdout] 第二行", "[stderr] 警告信息", ] client = TestClient(app) with client.websocket_connect(f"/ws/logs/{eid}") as ws: # 应按顺序收到 3 条历史日志 msg1 = ws.receive_text() msg2 = ws.receive_text() msg3 = ws.receive_text() assert msg1 == "[stdout] 第一行" assert msg2 == "[stdout] 第二行" assert msg3 == "[stderr] 警告信息" def test_no_logs_no_replay(self): """没有历史日志时不应收到回放消息。""" eid = "test-exec-empty" task_executor._log_buffers[eid] = [] client = TestClient(app) with client.websocket_connect(f"/ws/logs/{eid}") as ws: # 发送结束信号让连接正常关闭 task_executor._broadcast_end(eid) # 不应有回放消息,直接收到的是后续的结束信号处理 class TestBroadcastReceive: """实时日志广播接收""" def test_receive_broadcast_messages(self): """连接后应能收到 TaskExecutor 广播的实时日志。""" eid = "test-exec-broadcast" task_executor._log_buffers[eid] = [] client = TestClient(app) with client.websocket_connect(f"/ws/logs/{eid}") as ws: # 模拟 TaskExecutor 广播日志 task_executor._broadcast(eid, "[stdout] 实时日志行1") task_executor._broadcast(eid, "[stderr] 实时错误行") msg1 = ws.receive_text() msg2 = ws.receive_text() assert msg1 == "[stdout] 实时日志行1" assert msg2 == "[stderr] 实时错误行" def test_end_signal_closes_connection(self): """收到 None 结束信号后 WebSocket 应正常关闭。""" eid = "test-exec-end" task_executor._log_buffers[eid] = [] client = TestClient(app) with client.websocket_connect(f"/ws/logs/{eid}") as ws: # 广播一条日志后发送结束信号 task_executor._broadcast(eid, "[stdout] 最后一行") task_executor._broadcast_end(eid) msg = ws.receive_text() assert msg == "[stdout] 最后一行" def test_replay_then_broadcast(self): """先回放历史日志,再接收实时广播。""" eid = "test-exec-mixed" task_executor._log_buffers[eid] = ["[stdout] 历史行"] client = TestClient(app) with client.websocket_connect(f"/ws/logs/{eid}") as ws: # 先收到历史回放 replay = ws.receive_text() assert replay == "[stdout] 历史行" # 再收到实时广播 task_executor._broadcast(eid, "[stdout] 新行") task_executor._broadcast_end(eid) live = ws.receive_text() assert live == "[stdout] 新行" class TestLogBroadcasterUnit: """直接测试 TaskExecutor 的 subscribe/unsubscribe/broadcast 方法 (作为 LogBroadcaster 功能的单元测试)。 """ def test_subscribe_creates_queue(self): eid = "unit-sub" q = task_executor.subscribe(eid) assert isinstance(q, asyncio.Queue) assert eid in task_executor._subscribers task_executor.unsubscribe(eid, q) def test_unsubscribe_removes_queue(self): eid = "unit-unsub" q = task_executor.subscribe(eid) task_executor.unsubscribe(eid, q) # 最后一个订阅者移除后,key 也应被清理 assert eid not in task_executor._subscribers def test_broadcast_delivers_to_all_subscribers(self): eid = "unit-multi" q1 = task_executor.subscribe(eid) q2 = task_executor.subscribe(eid) task_executor._broadcast(eid, "测试消息") assert q1.get_nowait() == "测试消息" assert q2.get_nowait() == "测试消息" task_executor.unsubscribe(eid, q1) task_executor.unsubscribe(eid, q2) def test_broadcast_end_sends_none(self): eid = "unit-end" q = task_executor.subscribe(eid) task_executor._broadcast_end(eid) assert q.get_nowait() is None task_executor.unsubscribe(eid, q) def test_broadcast_no_subscribers_is_safe(self): """没有订阅者时广播不应报错。""" task_executor._broadcast("nonexistent", "无人接收") task_executor._broadcast_end("nonexistent")