From b4078b8415b624f2ac511d6a1bb341ffa489330e Mon Sep 17 00:00:00 2001 From: Adolfo Monteiro Date: Mon, 9 Mar 2026 18:50:44 +0000 Subject: [PATCH 1/6] Fix #3680: `MaskErrors` does not mask errors for subscriptions Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client. --- RELEASE.md | 25 +++++++ .../graphql_transport_ws/handlers.py | 4 + .../protocols/graphql_ws/handlers.py | 4 + strawberry/subscriptions/utils.py | 14 ++++ tests/websockets/test_graphql_transport_ws.py | 66 +++++++++++++++++ tests/websockets/test_graphql_ws.py | 74 +++++++++++++++++++ 6 files changed, 187 insertions(+) create mode 100644 RELEASE.md create mode 100644 strawberry/subscriptions/utils.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..3113c6b6c9 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,25 @@ +Release type: patch + +Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client. + +### Description +Fixes an issue where schema extensions (such as `MaskErrors`) were being bypassed when streaming data over WebSockets. + +Previously, standard Queries and Mutations would pass their results through the extension pipeline, but Subscriptions would send raw `ExecutionResult` objects directly over the WebSocket. This caused internal/unmasked errors to leak to the client. This PR manually triggers `_process_result` on active extensions right before `send_next` and `send_data_message` dispatch the payload. + +### Migration guide +No migration required. + +### Types of Changes +- [ ] Core +- [x] Bugfix +- [ ] New feature +- [ ] Enhancement/optimization +- [ ] Documentation + +### Checklist +- [x] My code follows the code style of this project. +- [ ] My change requires a change to the documentation. +- [x] I have read the CONTRIBUTING document. +- [x] I have added tests to cover my changes. +- [x] I have tested the changes and verified that they work and don't break anything (as well as I can manage). diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 402577a5b9..04403b8dc7 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -29,6 +29,7 @@ PongMessage, SubscribeMessage, ) +from strawberry.subscriptions.utils import process_extensions from strawberry.types import ExecutionResult from strawberry.types.execution import PreExecutionError from strawberry.types.graphql import OperationType @@ -394,6 +395,9 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None: ) async def send_next(self, execution_result: ExecutionResult) -> None: + extensions = getattr(self.handler.schema, "extensions", []) + process_extensions(execution_result, extensions) + next_payload: NextMessagePayload = {"data": execution_result.data} if execution_result.errors: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 21979ff23d..1330c96d44 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -24,6 +24,7 @@ StartMessage, StopMessage, ) +from strawberry.subscriptions.utils import process_extensions from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.types.unset import UnsetType @@ -216,6 +217,9 @@ async def cleanup(self) -> None: async def send_data_message( self, execution_result: ExecutionResult, operation_id: str ) -> None: + extensions = getattr(self.schema, "extensions", []) + process_extensions(execution_result, extensions) + data_message: DataMessage = { "type": "data", "id": operation_id, diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py new file mode 100644 index 0000000000..c52db15a81 --- /dev/null +++ b/strawberry/subscriptions/utils.py @@ -0,0 +1,14 @@ +from typing import Any + +from strawberry.types import ExecutionResult + + +def process_extensions( + execution_result: ExecutionResult, extensions: list[Any] +) -> None: + """Run the execution result through active schema extensions.""" + for ext in extensions: + extension_instance = ext() if isinstance(ext, type) else ext + + if hasattr(extension_instance, "_process_result"): + extension_instance._process_result(execution_result) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index d5a212d4d2..7783012da7 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1216,3 +1216,69 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( assert not process_errors.called assert Subscription.active_infinity_subscriptions == 0 + + +@patch.object(MyExtension, "_process_result", create=True) +async def test_subscription_errors_trigger_extension_process_result( + mock: Mock, ws: WebSocketClient +): + """Test that schema extensions are called to process results when a subscription yields an error.""" + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "TEST EXC") }', + }, + } + ) + + next_message: NextMessage = await ws.receive_json() + + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + assert "errors" in next_message["payload"] + + # Error intercepted and extension called + mock.assert_called_once() + + +async def test_subscription_error_masking_end_to_end( + http_client_class: type[HttpClient], +): + """Test that the real MaskErrors extension successfully masks the payload text.""" + import strawberry + from strawberry.extensions import MaskErrors + from tests.views.schema import Query, Subscription + + # Create a custom schema with the real MaskErrors extension attached + custom_schema = strawberry.Schema( + query=Query, + subscription=Subscription, + extensions=[MaskErrors(error_message="Unexpected error.")], + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.send_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "Super secret database error") }', + }, + } + ) + + next_message: NextMessage = await ws.receive_json() + + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + assert "errors" in next_message["payload"] + + assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index b41a127d3b..0814fc0e9f 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -865,3 +865,77 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( assert not process_errors.called assert Subscription.active_infinity_subscriptions == 0 + + +@mock.patch.object(MyExtension, "_process_result", create=True) +async def test_subscription_errors_trigger_extension_process_result( + mock: mock.MagicMock, ws: WebSocketClient +): + """Test that schema extensions are called to process results when a subscription yields an error.""" + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { exception(message: "TEST EXC") }', + }, + } + ) + + data_message: DataMessage = await ws.receive_json() + + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert "errors" in data_message["payload"] + + # Error intercepted and extension called + mock.assert_called_once() + + await ws.send_legacy_message({"type": "stop", "id": "demo"}) + complete_message = await ws.receive_json() + assert complete_message["type"] == "complete" + + +async def test_subscription_error_masking_end_to_end( + http_client_class: type[HttpClient], +): + """Test that the real MaskErrors extension successfully masks the legacy payload text.""" + import strawberry + from strawberry.extensions import MaskErrors + from tests.views.schema import Query, Subscription + + # Create a custom schema with the real MaskErrors extension attached + custom_schema = strawberry.Schema( + query=Query, + subscription=Subscription, + extensions=[MaskErrors(error_message="Unexpected error.")], + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.send_legacy_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { exception(message: "Super secret database error") }', + }, + } + ) + + data_message: DataMessage = await ws.receive_json() + + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert "errors" in data_message["payload"] + + assert data_message["payload"]["errors"][0]["message"] == "Unexpected error." + + # Clean up the socket + await ws.send_legacy_message({"type": "stop", "id": "demo"}) + await ws.receive_json() From 2e7af1e24651db1b30feb45b8e8518b522c5d330 Mon Sep 17 00:00:00 2001 From: Adolfo Monteiro Date: Mon, 23 Mar 2026 21:57:54 +0000 Subject: [PATCH 2/6] fix: pass execution_context to subscription extension constructors --- strawberry/subscriptions/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py index c52db15a81..cbd65c90c9 100644 --- a/strawberry/subscriptions/utils.py +++ b/strawberry/subscriptions/utils.py @@ -8,7 +8,7 @@ def process_extensions( ) -> None: """Run the execution result through active schema extensions.""" for ext in extensions: - extension_instance = ext() if isinstance(ext, type) else ext + extension_instance = ext(execution_context=None) if isinstance(ext, type) else ext if hasattr(extension_instance, "_process_result"): extension_instance._process_result(execution_result) From ffcb5529fec5b1cc2642d780ef7d47bdb54b5f62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:58:15 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry/subscriptions/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py index cbd65c90c9..1ab8ef07c0 100644 --- a/strawberry/subscriptions/utils.py +++ b/strawberry/subscriptions/utils.py @@ -8,7 +8,9 @@ def process_extensions( ) -> None: """Run the execution result through active schema extensions.""" for ext in extensions: - extension_instance = ext(execution_context=None) if isinstance(ext, type) else ext + extension_instance = ( + ext(execution_context=None) if isinstance(ext, type) else ext + ) if hasattr(extension_instance, "_process_result"): extension_instance._process_result(execution_result) From 49063704f6e35aa36b3b0958186ffc8e43dce877 Mon Sep 17 00:00:00 2001 From: Adolfo Monteiro Date: Mon, 23 Mar 2026 23:15:18 +0000 Subject: [PATCH 4/6] fix: support both extension class styles in subscription process_extensions --- strawberry/subscriptions/utils.py | 15 ++++-- tests/websockets/test_graphql_transport_ws.py | 47 +++++++++++++++++ tests/websockets/test_graphql_ws.py | 51 +++++++++++++++++++ 3 files changed, 110 insertions(+), 3 deletions(-) diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py index 1ab8ef07c0..9d0f33c1b5 100644 --- a/strawberry/subscriptions/utils.py +++ b/strawberry/subscriptions/utils.py @@ -8,9 +8,18 @@ def process_extensions( ) -> None: """Run the execution result through active schema extensions.""" for ext in extensions: - extension_instance = ( - ext(execution_context=None) if isinstance(ext, type) else ext - ) + if isinstance(ext, type): + try: + # Try passing the context for extensions like ApolloTracing + extension_instance = ext(execution_context=None) + except TypeError: + # Fallback for extensions like MaskErrors that don't want it + extension_instance = ext() + + # Explicitly set this ONLY for newly constructed instances + extension_instance.execution_context = None + else: + extension_instance = ext if hasattr(extension_instance, "_process_result"): extension_instance._process_result(execution_result) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 7783012da7..413021e5c6 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1282,3 +1282,50 @@ async def test_subscription_error_masking_end_to_end( assert "errors" in next_message["payload"] assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." + + +async def test_subscription_masking_with_class_extension( + http_client_class: type[HttpClient], +): + """Test that passing an extension as a class (not instance) successfully masks errors.""" + import strawberry + from strawberry.extensions import SchemaExtension + from tests.views.schema import Query, Subscription + + class CustomClassExtension(SchemaExtension): + def __init__(self, execution_context=None): + self.execution_context = execution_context # type: ignore + + def _process_result(self, execution_result): + if execution_result.errors: + execution_result.errors[0].message = "Unexpected error." + + # Create a custom schema passing the CLASS, not an instance + custom_schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[CustomClassExtension] + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.send_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_message( + { + "id": "sub_class", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "Secret database error") }', + }, + } + ) + + next_message: NextMessage = await ws.receive_json() + + assert next_message["type"] == "next" + assert next_message["id"] == "sub_class" + assert "errors" in next_message["payload"] + + assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 0814fc0e9f..19754d2c47 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -939,3 +939,54 @@ async def test_subscription_error_masking_end_to_end( # Clean up the socket await ws.send_legacy_message({"type": "stop", "id": "demo"}) await ws.receive_json() + + +async def test_subscription_masking_with_class_extension( + http_client_class: type[HttpClient], +): + """Test that passing an extension as a class (not instance) successfully masks the legacy payload text.""" + import strawberry + from strawberry.extensions import SchemaExtension + from tests.views.schema import Query, Subscription + + class CustomClassExtension(SchemaExtension): + def __init__(self, execution_context=None): + self.execution_context = execution_context # type: ignore + + def _process_result(self, execution_result): + if execution_result.errors: + execution_result.errors[0].message = "Unexpected error." + + # Create a custom schema passing the CLASS, not an instance + custom_schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[CustomClassExtension] + ) + test_client = http_client_class(custom_schema) + + async with test_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.send_legacy_message({"type": "connection_init"}) + await ws.receive_json() + + await ws.send_legacy_message( + { + "type": "start", + "id": "sub_class", + "payload": { + "query": 'subscription { exception(message: "Secret database error") }', + }, + } + ) + + data_message: DataMessage = await ws.receive_json() + + assert data_message["type"] == "data" + assert data_message["id"] == "sub_class" + assert "errors" in data_message["payload"] + + assert data_message["payload"]["errors"][0]["message"] == "Unexpected error." + + # Clean up the socket + await ws.send_legacy_message({"type": "stop", "id": "sub_class"}) + await ws.receive_json() From d5e18b1ac2ad9a184f6a85324e7affd5364cd6bf Mon Sep 17 00:00:00 2001 From: Adolfo Monteiro Date: Tue, 24 Mar 2026 12:54:27 +0000 Subject: [PATCH 5/6] fix: inspect signature before calling --- strawberry/subscriptions/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py index 9d0f33c1b5..85e63df6e9 100644 --- a/strawberry/subscriptions/utils.py +++ b/strawberry/subscriptions/utils.py @@ -1,3 +1,4 @@ +import inspect from typing import Any from strawberry.types import ExecutionResult @@ -9,11 +10,11 @@ def process_extensions( """Run the execution result through active schema extensions.""" for ext in extensions: if isinstance(ext, type): - try: - # Try passing the context for extensions like ApolloTracing + # Inspect the constructor to see if it requires execution_context + sig = inspect.signature(ext.__init__) + if "execution_context" in sig.parameters: extension_instance = ext(execution_context=None) - except TypeError: - # Fallback for extensions like MaskErrors that don't want it + else: extension_instance = ext() # Explicitly set this ONLY for newly constructed instances From c225927c065c8209d1021dcd4b1590ef9f1f475d Mon Sep 17 00:00:00 2001 From: Adolfo Monteiro Date: Tue, 24 Mar 2026 18:27:47 +0000 Subject: [PATCH 6/6] refactor: scope extension instantiation to individual subscription operations Extensions are now built once per operation and stored on the operation itself. --- .../graphql_transport_ws/handlers.py | 10 +++++--- .../protocols/graphql_ws/handlers.py | 14 ++++++++--- strawberry/subscriptions/utils.py | 25 +++++++++++-------- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 04403b8dc7..88b378bb7b 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -29,7 +29,7 @@ PongMessage, SubscribeMessage, ) -from strawberry.subscriptions.utils import process_extensions +from strawberry.subscriptions.utils import build_operation_extensions from strawberry.types import ExecutionResult from strawberry.types.execution import PreExecutionError from strawberry.types.graphql import OperationType @@ -348,6 +348,7 @@ class Operation(Generic[Context, RootValue]): "completed", "handler", "id", + "operation_extensions", "operation_name", "operation_type", "query", @@ -372,6 +373,8 @@ def __init__( self.operation_name = operation_name self.completed = False self.task: asyncio.Task | None = None + schema_extensions = getattr(self.handler.schema, "extensions", []) + self.operation_extensions = build_operation_extensions(schema_extensions) async def send_operation_message(self, message: Message) -> None: if self.completed: @@ -395,8 +398,9 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None: ) async def send_next(self, execution_result: ExecutionResult) -> None: - extensions = getattr(self.handler.schema, "extensions", []) - process_extensions(execution_result, extensions) + for ext in self.operation_extensions: + if hasattr(ext, "_process_result"): + ext._process_result(execution_result) next_payload: NextMessagePayload = {"data": execution_result.data} diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 1330c96d44..ee52b9f286 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -24,7 +24,7 @@ StartMessage, StopMessage, ) -from strawberry.subscriptions.utils import process_extensions +from strawberry.subscriptions.utils import build_operation_extensions from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.types.unset import UnsetType @@ -56,6 +56,7 @@ def __init__( self.keep_alive_task: asyncio.Task | None = None self.subscriptions: dict[str, AsyncGenerator] = {} self.tasks: dict[str, asyncio.Task] = {} + self.operation_extensions: dict[str, list[Any]] = {} async def handle(self) -> None: try: @@ -140,6 +141,10 @@ async def handle_start(self, message: StartMessage) -> None: operation_id, query, operation_name, variables ) self.tasks[operation_id] = asyncio.create_task(result_handler) + schema_extensions = getattr(self.schema, "extensions", []) + self.operation_extensions[operation_id] = build_operation_extensions( + schema_extensions + ) async def handle_stop(self, message: StopMessage) -> None: operation_id = message["id"] @@ -210,6 +215,8 @@ async def cleanup_operation(self, operation_id: str) -> None: await self.tasks[operation_id] del self.tasks[operation_id] + self.operation_extensions.pop(operation_id, None) + async def cleanup(self) -> None: for operation_id in list(self.tasks.keys()): await self.cleanup_operation(operation_id) @@ -217,8 +224,9 @@ async def cleanup(self) -> None: async def send_data_message( self, execution_result: ExecutionResult, operation_id: str ) -> None: - extensions = getattr(self.schema, "extensions", []) - process_extensions(execution_result, extensions) + for ext in self.operation_extensions.get(operation_id, []): + if hasattr(ext, "_process_result"): + ext._process_result(execution_result) data_message: DataMessage = { "type": "data", diff --git a/strawberry/subscriptions/utils.py b/strawberry/subscriptions/utils.py index 85e63df6e9..81429a6711 100644 --- a/strawberry/subscriptions/utils.py +++ b/strawberry/subscriptions/utils.py @@ -1,26 +1,29 @@ import inspect from typing import Any -from strawberry.types import ExecutionResult +def build_operation_extensions(extensions: list[Any]) -> list[Any]: + """Build a fresh, isolated set of extensions for a single operation. -def process_extensions( - execution_result: ExecutionResult, extensions: list[Any] -) -> None: - """Run the execution result through active schema extensions.""" + Only extensions that are either pre-constructed instances or + zero-argument classes (optionally accepting execution_context) are + fully supported. Class-based extensions requiring custom constructor + arguments should be passed as pre-constructed instances. + + Full extension lifecycle support (on_operation, on_execute, etc.) + for subscriptions should be tracked as a separate issue. + """ + instances = [] for ext in extensions: if isinstance(ext, type): - # Inspect the constructor to see if it requires execution_context sig = inspect.signature(ext.__init__) if "execution_context" in sig.parameters: extension_instance = ext(execution_context=None) else: extension_instance = ext() - # Explicitly set this ONLY for newly constructed instances extension_instance.execution_context = None + instances.append(extension_instance) else: - extension_instance = ext - - if hasattr(extension_instance, "_process_result"): - extension_instance._process_result(execution_result) + instances.append(ext) + return instances