Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ build-backend = "setuptools.build_meta"
authors = [{ name = "chatos@alibaba" }]
requires-python = "<4.0,>=3.10"
name = "rl-rock"
version = "1.4.7"
version = "1.4.6"
version = "1.4.7"
description = "ROCK-Reinforcement Open Construction Kit"
readme = "README.md"
dependencies = [
Expand Down
73 changes: 37 additions & 36 deletions tests/unit/sandbox/test_websocket_proxy_subprotocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for WebSocket proxy subprotocol forwarding and performance fixes."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch, call

Expand Down Expand Up @@ -97,11 +98,11 @@ def fake_connect(url, **kwargs):
assert "subprotocols" in connect_kwargs
assert "binary" in connect_kwargs["subprotocols"]

async def test_no_subprotocol_falls_back_to_binary_base64(self):
"""When client sends no subprotocols, connect() should fall back to ['binary', 'base64'].
async def test_no_subprotocol_falls_back_to_binary(self):
"""When client sends no subprotocols, connect() should fall back to ['binary'].

websockify rejects connections without a Sec-WebSocket-Protocol header,
so we always send the default subprotocols when the client doesn't declare any.
so we always send the default subprotocol when the client doesn't declare any.
"""
service = _make_service()
client_ws = _make_client_ws(subprotocols=[])
Expand All @@ -121,7 +122,7 @@ def fake_connect(url, **kwargs):
with patch("rock.sandbox.service.sandbox_proxy_service.websockets.connect", side_effect=fake_connect):
await SandboxProxyService.websocket_proxy(service, client_ws, "sb1", None, port=8006)

assert connect_kwargs.get("subprotocols") == ["binary", "base64"]
assert connect_kwargs.get("subprotocols") == ["binary"]

async def test_negotiated_subprotocol_passed_to_client_accept(self):
"""After upstream negotiates subprotocol, client accept() must be called with it."""
Expand All @@ -140,9 +141,7 @@ async def test_negotiated_subprotocol_passed_to_client_accept(self):
# accept 必须带上协商好的子协议
client_ws.accept.assert_called_once()
call_kwargs = client_ws.accept.call_args
subprotocol = call_kwargs.kwargs.get("subprotocol") or (
call_kwargs.args[0] if call_kwargs.args else None
)
subprotocol = call_kwargs.kwargs.get("subprotocol") or (call_kwargs.args[0] if call_kwargs.args else None)
assert subprotocol == "binary"


Expand Down Expand Up @@ -175,10 +174,12 @@ async def test_binary_message_forwarded_without_sleep(self):
"""Binary messages should be forwarded directly without any sleep."""
service = MagicMock(spec=SandboxProxyService)

source_ws = FakeStarletteWebSocket([
{"type": "websocket.receive", "bytes": b"\x00\x01\x02"},
{"type": "websocket.disconnect", "code": 1000},
])
source_ws = FakeStarletteWebSocket(
[
{"type": "websocket.receive", "bytes": b"\x00\x01\x02"},
{"type": "websocket.disconnect", "code": 1000},
]
)

target_ws = MagicMock(spec=["recv", "send"])
target_ws.send = AsyncMock()
Expand Down Expand Up @@ -208,19 +209,19 @@ async def test_all_binary_frames_forwarded_to_target(self):
"""Every binary frame from client must reach the upstream target."""
service = MagicMock(spec=SandboxProxyService)

source_ws = FakeStarletteWebSocket([
{"type": "websocket.receive", "bytes": b"vnc_frame_1"},
{"type": "websocket.receive", "bytes": b"vnc_frame_2"},
{"type": "websocket.receive", "bytes": b"vnc_frame_3"},
{"type": "websocket.disconnect", "code": 1000},
])
source_ws = FakeStarletteWebSocket(
[
{"type": "websocket.receive", "bytes": b"vnc_frame_1"},
{"type": "websocket.receive", "bytes": b"vnc_frame_2"},
{"type": "websocket.receive", "bytes": b"vnc_frame_3"},
{"type": "websocket.disconnect", "code": 1000},
]
)

target_ws = MagicMock(spec=["recv", "send"])
target_ws.send = AsyncMock()

await SandboxProxyService._forward_messages(
service, source_ws, target_ws, "client->target"
)
await SandboxProxyService._forward_messages(service, source_ws, target_ws, "client->target")

forwarded = [c.args[0] for c in target_ws.send.call_args_list]
assert forwarded == [b"vnc_frame_1", b"vnc_frame_2", b"vnc_frame_3"]
Expand All @@ -229,19 +230,19 @@ async def test_mixed_text_and_binary_frames_all_forwarded(self):
"""Interleaved text and binary frames must all be forwarded correctly."""
service = MagicMock(spec=SandboxProxyService)

source_ws = FakeStarletteWebSocket([
{"type": "websocket.receive", "text": "hello"},
{"type": "websocket.receive", "bytes": b"\x00\x01"},
{"type": "websocket.receive", "text": "world"},
{"type": "websocket.disconnect", "code": 1000},
])
source_ws = FakeStarletteWebSocket(
[
{"type": "websocket.receive", "text": "hello"},
{"type": "websocket.receive", "bytes": b"\x00\x01"},
{"type": "websocket.receive", "text": "world"},
{"type": "websocket.disconnect", "code": 1000},
]
)

target_ws = MagicMock(spec=["recv", "send"])
target_ws.send = AsyncMock()

await SandboxProxyService._forward_messages(
service, source_ws, target_ws, "client->target"
)
await SandboxProxyService._forward_messages(service, source_ws, target_ws, "client->target")

forwarded = [c.args[0] for c in target_ws.send.call_args_list]
assert forwarded == ["hello", b"\x00\x01", "world"]
Expand All @@ -250,16 +251,16 @@ async def test_single_binary_frame_not_lost(self):
"""Even a single binary frame must not be silently dropped."""
service = MagicMock(spec=SandboxProxyService)

source_ws = FakeStarletteWebSocket([
{"type": "websocket.receive", "bytes": b"rfb_handshake"},
{"type": "websocket.disconnect", "code": 1000},
])
source_ws = FakeStarletteWebSocket(
[
{"type": "websocket.receive", "bytes": b"rfb_handshake"},
{"type": "websocket.disconnect", "code": 1000},
]
)

target_ws = MagicMock(spec=["recv", "send"])
target_ws.send = AsyncMock()

await SandboxProxyService._forward_messages(
service, source_ws, target_ws, "client->target"
)
await SandboxProxyService._forward_messages(service, source_ws, target_ws, "client->target")

target_ws.send.assert_called_once_with(b"rfb_handshake")
Loading