diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..3113c6b6c9 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,25 @@ +Release type: patch + +Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client. + +### Description +Fixes an issue where schema extensions (such as `MaskErrors`) were being bypassed when streaming data over WebSockets. + +Previously, standard Queries and Mutations would pass their results through the extension pipeline, but Subscriptions would send raw `ExecutionResult` objects directly over the WebSocket. This caused internal/unmasked errors to leak to the client. This PR manually triggers `_process_result` on active extensions right before `send_next` and `send_data_message` dispatch the payload. + +### Migration guide +No migration required. + +### Types of Changes +- [ ] Core +- [x] Bugfix +- [ ] New feature +- [ ] Enhancement/optimization +- [ ] Documentation + +### Checklist +- [x] My code follows the code style of this project. +- [ ] My change requires a change to the documentation. +- [x] I have read the CONTRIBUTING document. +- [x] I have added tests to cover my changes. +- [x] I have tested the changes and verified that they work and don't break anything (as well as I can manage). diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 402577a5b9..88b378bb7b 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -29,6 +29,7 @@ PongMessage, SubscribeMessage, ) +from strawberry.subscriptions.utils import build_operation_extensions from strawberry.types import ExecutionResult from strawberry.types.execution import PreExecutionError from strawberry.types.graphql import OperationType @@ -347,6 +348,7 @@ class Operation(Generic[Context, RootValue]): "completed", "handler", "id", + "operation_extensions", "operation_name", "operation_type", "query", @@ -371,6 +373,8 @@ def __init__( self.operation_name = operation_name self.completed = False self.task: asyncio.Task | None = None + schema_extensions = getattr(self.handler.schema, "extensions", []) + self.operation_extensions = build_operation_extensions(schema_extensions) async def send_operation_message(self, message: Message) -> None: if self.completed: @@ -394,6 +398,10 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None: ) async def send_next(self, execution_result: ExecutionResult) -> None: + for ext in self.operation_extensions: + if hasattr(ext, "_process_result"): + ext._process_result(execution_result) + next_payload: NextMessagePayload = {"data": execution_result.data} if execution_result.errors: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 21979ff23d..ee52b9f286 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -24,6 +24,7 @@ StartMessage, StopMessage, ) +from strawberry.subscriptions.utils import build_operation_extensions from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.types.unset import UnsetType @@ -55,6 +56,7 @@ def __init__( self.keep_alive_task: asyncio.Task | None = None self.subscriptions: dict[str, AsyncGenerator] = {} self.tasks: dict[str, asyncio.Task] = {} + self.operation_extensions: dict[str, list[Any]] = {} async def handle(self) -> None: try: @@ -139,6 +141,10 @@ async def handle_start(self, message: StartMessage) -> None: operation_id, query, operation_name, variables ) self.tasks[operation_id] = asyncio.create_task(result_handler) + schema_extensions = getattr(self.schema, "extensions", []) + self.operation_extensions[operation_id] = build_operation_extensions( + schema_extensions + ) async def handle_stop(self, message: StopMessage) -> None: operation_id = message["id"] @@ -209,6 +215,8 @@ async def cleanup_operation(self, operation_id: str) -> None: await self.tasks[operation_id] del self.tasks[operation_id] + self.operation_extensions.pop(operation_id, None) + async def cleanup(self) -> None: for operation_id in list(self.tasks.keys()): await self.cleanup_operation(operation_id) @@ -216,6 +224,10 @@ async def cleanup(self) -> None: async def send_data_message( self, execution_result: ExecutionResult, operation_id: str ) -> None: + for ext in self.operation_extensions.get(operation_id, []): + if hasattr(ext, "_process_result"): + ext._process_result(execution_result) + data_message: DataMessage = { "type": "data", "id": operation_id, diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py new file mode 100644 index 0000000000..81429a6711 --- /dev/null +++ b/strawberry/subscriptions/utils.py @@ -0,0 +1,29 @@ +import inspect +from typing import Any + + +def build_operation_extensions(extensions: list[Any]) -> list[Any]: + """Build a fresh, isolated set of extensions for a single operation. + + Only extensions that are either pre-constructed instances or + zero-argument classes (optionally accepting execution_context) are + fully supported. Class-based extensions requiring custom constructor + arguments should be passed as pre-constructed instances. + + Full extension lifecycle support (on_operation, on_execute, etc.) + for subscriptions should be tracked as a separate issue. + """ + instances = [] + for ext in extensions: + if isinstance(ext, type): + sig = inspect.signature(ext.__init__) + if "execution_context" in sig.parameters: + extension_instance = ext(execution_context=None) + else: + extension_instance = ext() + + extension_instance.execution_context = None + instances.append(extension_instance) + else: + instances.append(ext) + return instances diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index d5a212d4d2..413021e5c6 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1216,3 +1216,116 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( assert not process_errors.called assert Subscription.active_infinity_subscriptions == 0 + + +@patch.object(MyExtension, "_process_result", create=True) +async def test_subscription_errors_trigger_extension_process_result( + mock: Mock, ws: WebSocketClient +): + """Test that schema extensions are called to process results when a subscription yields an error.""" + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "TEST EXC") }', + }, + } + ) + + next_message: NextMessage = await ws.receive_json() + + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + assert "errors" in next_message["payload"] + + # Error intercepted and extension called + mock.assert_called_once() + + +async def test_subscription_error_masking_end_to_end( + http_client_class: type[HttpClient], +): + """Test that the real MaskErrors extension successfully masks the payload text.""" + import strawberry + from strawberry.extensions import MaskErrors + from tests.views.schema import Query, Subscription + + # Create a custom schema with the real MaskErrors extension attached + custom_schema = strawberry.Schema( + query=Query, + subscription=Subscription, + extensions=[MaskErrors(error_message="Unexpected error.")], + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.send_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "Super secret database error") }', + }, + } + ) + + next_message: NextMessage = await ws.receive_json() + + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + assert "errors" in next_message["payload"] + + assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." + + +async def test_subscription_masking_with_class_extension( + http_client_class: type[HttpClient], +): + """Test that passing an extension as a class (not instance) successfully masks errors.""" + import strawberry + from strawberry.extensions import SchemaExtension + from tests.views.schema import Query, Subscription + + class CustomClassExtension(SchemaExtension): + def __init__(self, execution_context=None): + self.execution_context = execution_context # type: ignore + + def _process_result(self, execution_result): + if execution_result.errors: + execution_result.errors[0].message = "Unexpected error." + + # Create a custom schema passing the CLASS, not an instance + custom_schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[CustomClassExtension] + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.send_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_message( + { + "id": "sub_class", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "Secret database error") }', + }, + } + ) + + next_message: NextMessage = await ws.receive_json() + + assert next_message["type"] == "next" + assert next_message["id"] == "sub_class" + assert "errors" in next_message["payload"] + + assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index b41a127d3b..19754d2c47 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -865,3 +865,128 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( assert not process_errors.called assert Subscription.active_infinity_subscriptions == 0 + + +@mock.patch.object(MyExtension, "_process_result", create=True) +async def test_subscription_errors_trigger_extension_process_result( + mock: mock.MagicMock, ws: WebSocketClient +): + """Test that schema extensions are called to process results when a subscription yields an error.""" + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { exception(message: "TEST EXC") }', + }, + } + ) + + data_message: DataMessage = await ws.receive_json() + + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert "errors" in data_message["payload"] + + # Error intercepted and extension called + mock.assert_called_once() + + await ws.send_legacy_message({"type": "stop", "id": "demo"}) + complete_message = await ws.receive_json() + assert complete_message["type"] == "complete" + + +async def test_subscription_error_masking_end_to_end( + http_client_class: type[HttpClient], +): + """Test that the real MaskErrors extension successfully masks the legacy payload text.""" + import strawberry + from strawberry.extensions import MaskErrors + from tests.views.schema import Query, Subscription + + # Create a custom schema with the real MaskErrors extension attached + custom_schema = strawberry.Schema( + query=Query, + subscription=Subscription, + extensions=[MaskErrors(error_message="Unexpected error.")], + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.send_legacy_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { exception(message: "Super secret database error") }', + }, + } + ) + + data_message: DataMessage = await ws.receive_json() + + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert "errors" in data_message["payload"] + + assert data_message["payload"]["errors"][0]["message"] == "Unexpected error." + + # Clean up the socket + await ws.send_legacy_message({"type": "stop", "id": "demo"}) + await ws.receive_json() + + +async def test_subscription_masking_with_class_extension( + http_client_class: type[HttpClient], +): + """Test that passing an extension as a class (not instance) successfully masks the legacy payload text.""" + import strawberry + from strawberry.extensions import SchemaExtension + from tests.views.schema import Query, Subscription + + class CustomClassExtension(SchemaExtension): + def __init__(self, execution_context=None): + self.execution_context = execution_context # type: ignore + + def _process_result(self, execution_result): + if execution_result.errors: + execution_result.errors[0].message = "Unexpected error." + + # Create a custom schema passing the CLASS, not an instance + custom_schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[CustomClassExtension] + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.send_legacy_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_legacy_message( + { + "type": "start", + "id": "sub_class", + "payload": { + "query": 'subscription { exception(message: "Secret database error") }', + }, + } + ) + + data_message: DataMessage = await ws.receive_json() + + assert data_message["type"] == "data" + assert data_message["id"] == "sub_class" + assert "errors" in data_message["payload"] + + assert data_message["payload"]["errors"][0]["message"] == "Unexpected error." + + # Clean up the socket + await ws.send_legacy_message({"type": "stop", "id": "sub_class"}) + await ws.receive_json()