Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Release type: patch

Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client.

### Description
Fixes an issue where schema extensions (such as `MaskErrors`) were being bypassed when streaming data over WebSockets.

Previously, standard Queries and Mutations would pass their results through the extension pipeline, but Subscriptions would send raw `ExecutionResult` objects directly over the WebSocket. This caused internal/unmasked errors to leak to the client. This PR manually triggers `_process_result` on active extensions right before `send_next` and `send_data_message` dispatch the payload.

### Migration guide
No migration required.

### Types of Changes
- [ ] Core
- [x] Bugfix
- [ ] New feature
- [ ] Enhancement/optimization
- [ ] Documentation

### Checklist
- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [x] I have read the CONTRIBUTING document.
- [x] I have added tests to cover my changes.
- [x] I have tested the changes and verified that they work and don't break anything (as well as I can manage).
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PongMessage,
SubscribeMessage,
)
from strawberry.subscriptions.utils import process_extensions
from strawberry.types import ExecutionResult
from strawberry.types.execution import PreExecutionError
from strawberry.types.graphql import OperationType
Expand Down Expand Up @@ -394,6 +395,9 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None:
)

async def send_next(self, execution_result: ExecutionResult) -> None:
extensions = getattr(self.handler.schema, "extensions", [])
process_extensions(execution_result, extensions)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 schema.extensions does not reflect the per-request extension runner

getattr(self.handler.schema, "extensions", []) returns the raw list that was originally passed to strawberry.Schema(extensions=[...]). However, schema.subscribe() builds its extensions_runner via schema._async_extensions, which:

  1. Instantiates class-based extensions fresh and sets execution_context on them.
  2. Also adds DirectivesExtension when schema.directives is non-empty — an extension that is completely absent from schema.extensions.

Because process_extensions bypasses the per-request runner, any extension added only inside get_extensions() (like DirectivesExtension) will never have its _process_result called, and class-based extensions receive a freshly constructed instance rather than the one already wired to the current ExecutionContext.

The structurally correct fix is to thread the SchemaExtensionsRunner that _subscribe already has through to send_next, so that _process_result is called on the same, already-configured instances that the rest of the operation lifecycle uses.


next_payload: NextMessagePayload = {"data": execution_result.data}

if execution_result.errors:
Expand Down
4 changes: 4 additions & 0 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
StartMessage,
StopMessage,
)
from strawberry.subscriptions.utils import process_extensions
from strawberry.types.execution import ExecutionResult, PreExecutionError
from strawberry.types.unset import UnsetType

Expand Down Expand Up @@ -216,6 +217,9 @@ async def cleanup(self) -> None:
async def send_data_message(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
extensions = getattr(self.schema, "extensions", [])
process_extensions(execution_result, extensions)

data_message: DataMessage = {
"type": "data",
"id": operation_id,
Expand Down
16 changes: 16 additions & 0 deletions strawberry/subscriptions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Any

from strawberry.types import ExecutionResult


def process_extensions(
execution_result: ExecutionResult, extensions: list[Any]
) -> None:
"""Run the execution result through active schema extensions."""
for ext in extensions:
extension_instance = (
ext(execution_context=None) if isinstance(ext, type) else ext
)

if hasattr(extension_instance, "_process_result"):
extension_instance._process_result(execution_result)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Class-type extensions with custom __init__ will crash at runtime

When an extension is provided as a class (rather than an instance), the code calls ext(execution_context=None). This silently fails for any extension that overrides __init__ without an execution_context parameter. A concrete example is MaskErrors itself:

class MaskErrors(SchemaExtension):
    def __init__(self, should_mask_error=..., error_message=...) -> None: ...

Calling MaskErrors(execution_context=None) raises:

TypeError: __init__() got an unexpected keyword argument 'execution_context'

The end-to-end test avoids this by always passing a pre-constructed instance (extensions=[MaskErrors(...)]), so the bug path (isinstance(ext, type) → True) is never exercised by the current tests. Any user who passes extensions=[MaskErrors] (i.e. the class, without parentheses) will get an unhandled exception on the first subscription event.

The proper fix is to skip class-type extensions entirely here and only process instance-type extensions — or, better, reuse the already-instantiated extensions from the SchemaExtensionsRunner for the current request.

66 changes: 66 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,3 +1216,69 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


@patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: Mock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]

# Error intercepted and extension called
Comment on lines +1220 to +1242
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test validates call count but not actual masking behaviour

The test patches _process_result with a no-op mock and asserts it was called once. This confirms the hook fires, but does not verify that error masking actually works end-to-end (e.g., that the error message is replaced with "Unexpected error." and that original exception details are not leaked to the client).

Consider adding a complementary integration test that uses a real MaskErrors extension and asserts the response contains the masked message rather than the raw exception text. This would catch regressions like the configuration-loss issue described in the handler comment.

Also note there is a missing blank line before the @patch.object decorator (PEP 8 E302).

Suggested change
@patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: Mock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)
next_message: NextMessage = await ws.receive_json()
assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]
# Error intercepted and extension called
@patch.object(MyExtension, "_process_result", create=True)

needs two blank lines after the previous test function body.

mock.assert_called_once()


async def test_subscription_error_masking_end_to_end(
http_client_class: type[HttpClient],
):
"""Test that the real MaskErrors extension successfully masks the payload text."""
import strawberry
from strawberry.extensions import MaskErrors
from tests.views.schema import Query, Subscription

# Create a custom schema with the real MaskErrors extension attached
custom_schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[MaskErrors(error_message="Unexpected error.")],
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "Super secret database error") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]

assert next_message["payload"]["errors"][0]["message"] == "Unexpected error."
74 changes: 74 additions & 0 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,77 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


@mock.patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: mock.MagicMock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_legacy_message(
{
"type": "start",
"id": "demo",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "demo"
assert "errors" in data_message["payload"]

# Error intercepted and extension called
mock.assert_called_once()

await ws.send_legacy_message({"type": "stop", "id": "demo"})
complete_message = await ws.receive_json()
assert complete_message["type"] == "complete"


async def test_subscription_error_masking_end_to_end(
http_client_class: type[HttpClient],
):
"""Test that the real MaskErrors extension successfully masks the legacy payload text."""
import strawberry
from strawberry.extensions import MaskErrors
from tests.views.schema import Query, Subscription

# Create a custom schema with the real MaskErrors extension attached
custom_schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[MaskErrors(error_message="Unexpected error.")],
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
) as ws:
await ws.send_legacy_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_legacy_message(
{
"type": "start",
"id": "demo",
"payload": {
"query": 'subscription { exception(message: "Super secret database error") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "demo"
assert "errors" in data_message["payload"]

assert data_message["payload"]["errors"][0]["message"] == "Unexpected error."

# Clean up the socket
await ws.send_legacy_message({"type": "stop", "id": "demo"})
await ws.receive_json()
Loading