-
-
Notifications
You must be signed in to change notification settings - Fork 627
Fix #3680: MaskErrors does not mask errors for subscriptions
#4301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
b4078b8
2e7af1e
ffcb552
4906370
d5e18b1
c225927
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| Release type: patch | ||
|
|
||
| Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client. | ||
|
|
||
| ### Description | ||
| Fixes an issue where schema extensions (such as `MaskErrors`) were being bypassed when streaming data over WebSockets. | ||
|
|
||
| Previously, standard Queries and Mutations would pass their results through the extension pipeline, but Subscriptions would send raw `ExecutionResult` objects directly over the WebSocket. This caused internal/unmasked errors to leak to the client. This PR manually triggers `_process_result` on active extensions right before `send_next` and `send_data_message` dispatch the payload. | ||
|
|
||
| ### Migration guide | ||
| No migration required. | ||
|
|
||
| ### Types of Changes | ||
| - [ ] Core | ||
| - [x] Bugfix | ||
| - [ ] New feature | ||
| - [ ] Enhancement/optimization | ||
| - [ ] Documentation | ||
|
|
||
| ### Checklist | ||
| - [x] My code follows the code style of this project. | ||
| - [ ] My change requires a change to the documentation. | ||
| - [x] I have read the CONTRIBUTING document. | ||
| - [x] I have added tests to cover my changes. | ||
| - [x] I have tested the changes and verified that they work and don't break anything (as well as I can manage). |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| import inspect | ||
| from typing import Any | ||
|
|
||
| from strawberry.types import ExecutionResult | ||
|
|
||
|
|
||
| def process_extensions( | ||
| execution_result: ExecutionResult, extensions: list[Any] | ||
| ) -> None: | ||
| """Run the execution result through active schema extensions.""" | ||
| for ext in extensions: | ||
| if isinstance(ext, type): | ||
| # Inspect the constructor to see if it requires execution_context | ||
| sig = inspect.signature(ext.__init__) | ||
| if "execution_context" in sig.parameters: | ||
| extension_instance = ext(execution_context=None) | ||
| else: | ||
| extension_instance = ext() | ||
|
Comment on lines
+18
to
+23
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The outer When that happens the code falls back to A tighter guard would inspect the signature before calling, rather than catching the exception after: import inspect
sig = inspect.signature(ext.__init__)
if "execution_context" in sig.parameters:
extension_instance = ext(execution_context=None)
else:
extension_instance = ext()This keeps the intent explicit and lets any other
Comment on lines
+22
to
+23
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When a class-based extension is instantiated here via A concrete case: suppose a user registers a custom extension as a class with a default-valued argument that controls behaviour: class RateLimitExtension(SchemaExtension):
def __init__(self, max_errors: int = 10) -> None:
self.max_errors = max_errors
The safest fix for the general case is to avoid re-instantiation entirely: run |
||
|
|
||
| # Explicitly set this ONLY for newly constructed instances | ||
| extension_instance.execution_context = None | ||
|
Comment on lines
+18
to
+25
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For class-based extensions,
The resolved extension instances (with proper If the current approach is kept, consider at minimum caching the per-schema extension list once (not per event) to avoid repeated signature inspection: # Computed once per subscription start, not once per event
extension_instances = _build_extension_instances(schema.extensions) |
||
| else: | ||
| extension_instance = ext | ||
|
|
||
| if hasattr(extension_instance, "_process_result"): | ||
| extension_instance._process_result(execution_result) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1216,3 +1216,116 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( | |||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert not process_errors.called | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert Subscription.active_infinity_subscriptions == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| @patch.object(MyExtension, "_process_result", create=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| async def test_subscription_errors_trigger_extension_process_result( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| mock: Mock, ws: WebSocketClient | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that schema extensions are called to process results when a subscription yields an error.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.send_message( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "id": "sub1", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "subscribe", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "payload": { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "query": 'subscription { exception(message: "TEST EXC") }', | ||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| next_message: NextMessage = await ws.receive_json() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["type"] == "next" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["id"] == "sub1" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "errors" in next_message["payload"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Error intercepted and extension called | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1220
to
+1242
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test validates call count but not actual masking behaviour The test patches Consider adding a complementary integration test that uses a real Also note there is a missing blank line before the
Suggested change
needs two blank lines after the previous test function body. |
||||||||||||||||||||||||||||||||||||||||||||||||||
| mock.assert_called_once() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| async def test_subscription_error_masking_end_to_end( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| http_client_class: type[HttpClient], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that the real MaskErrors extension successfully masks the payload text.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import strawberry | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from strawberry.extensions import MaskErrors | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from tests.views.schema import Query, Subscription | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create a custom schema with the real MaskErrors extension attached | ||||||||||||||||||||||||||||||||||||||||||||||||||
| custom_schema = strawberry.Schema( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| query=Query, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| subscription=Subscription, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| extensions=[MaskErrors(error_message="Unexpected error.")], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| test_client = http_client_class(custom_schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| async with test_client.ws_connect( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) as ws: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.send_message({"type": "connection_init"}) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.receive_json() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.send_message( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "id": "sub1", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "subscribe", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "payload": { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "query": 'subscription { exception(message: "Super secret database error") }', | ||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| next_message: NextMessage = await ws.receive_json() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["type"] == "next" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["id"] == "sub1" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "errors" in next_message["payload"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| async def test_subscription_masking_with_class_extension( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| http_client_class: type[HttpClient], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that passing an extension as a class (not instance) successfully masks errors.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import strawberry | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from strawberry.extensions import SchemaExtension | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from tests.views.schema import Query, Subscription | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class CustomClassExtension(SchemaExtension): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, execution_context=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.execution_context = execution_context # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def _process_result(self, execution_result): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if execution_result.errors: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| execution_result.errors[0].message = "Unexpected error." | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create a custom schema passing the CLASS, not an instance | ||||||||||||||||||||||||||||||||||||||||||||||||||
| custom_schema = strawberry.Schema( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| query=Query, subscription=Subscription, extensions=[CustomClassExtension] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| test_client = http_client_class(custom_schema) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| async with test_client.ws_connect( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) as ws: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.send_message({"type": "connection_init"}) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.receive_json() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| await ws.send_message( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "id": "sub_class", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "subscribe", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "payload": { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "query": 'subscription { exception(message: "Secret database error") }', | ||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| next_message: NextMessage = await ws.receive_json() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["type"] == "next" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["id"] == "sub_class" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "errors" in next_message["payload"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| assert next_message["payload"]["errors"][0]["message"] == "Unexpected error." | ||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
schema.extensionsdoes not reflect the per-request extension runnergetattr(self.handler.schema, "extensions", [])returns the raw list that was originally passed tostrawberry.Schema(extensions=[...]). However,schema.subscribe()builds itsextensions_runnerviaschema._async_extensions, which:execution_contexton them.DirectivesExtensionwhenschema.directivesis non-empty — an extension that is completely absent fromschema.extensions.Because
process_extensionsbypasses the per-request runner, any extension added only insideget_extensions()(likeDirectivesExtension) will never have its_process_resultcalled, and class-based extensions receive a freshly constructed instance rather than the one already wired to the currentExecutionContext.The structurally correct fix is to thread the
SchemaExtensionsRunnerthat_subscribealready has through tosend_next, so that_process_resultis called on the same, already-configured instances that the rest of the operation lifecycle uses.