Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Release type: patch

Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client.

### Description
Fixes an issue where schema extensions (such as `MaskErrors`) were being bypassed when streaming data over WebSockets.

Previously, standard Queries and Mutations would pass their results through the extension pipeline, but Subscriptions would send raw `ExecutionResult` objects directly over the WebSocket. This caused internal/unmasked errors to leak to the client. This PR manually triggers `_process_result` on active extensions right before `send_next` and `send_data_message` dispatch the payload.

### Migration guide
No migration required.

### Types of Changes
- [ ] Core
- [x] Bugfix
- [ ] New feature
- [ ] Enhancement/optimization
- [ ] Documentation

### Checklist
- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [x] I have read the CONTRIBUTING document.
- [x] I have added tests to cover my changes.
- [x] I have tested the changes and verified that they work and don't break anything (as well as I can manage).
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PongMessage,
SubscribeMessage,
)
from strawberry.subscriptions.utils import process_extensions
from strawberry.types import ExecutionResult
from strawberry.types.execution import PreExecutionError
from strawberry.types.graphql import OperationType
Expand Down Expand Up @@ -394,6 +395,9 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None:
)

async def send_next(self, execution_result: ExecutionResult) -> None:
extensions = getattr(self.handler.schema, "extensions", [])
process_extensions(execution_result, extensions)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 schema.extensions does not reflect the per-request extension runner

getattr(self.handler.schema, "extensions", []) returns the raw list that was originally passed to strawberry.Schema(extensions=[...]). However, schema.subscribe() builds its extensions_runner via schema._async_extensions, which:

  1. Instantiates class-based extensions fresh and sets execution_context on them.
  2. Also adds DirectivesExtension when schema.directives is non-empty — an extension that is completely absent from schema.extensions.

Because process_extensions bypasses the per-request runner, any extension added only inside get_extensions() (like DirectivesExtension) will never have its _process_result called, and class-based extensions receive a freshly constructed instance rather than the one already wired to the current ExecutionContext.

The structurally correct fix is to thread the SchemaExtensionsRunner that _subscribe already has through to send_next, so that _process_result is called on the same, already-configured instances that the rest of the operation lifecycle uses.


next_payload: NextMessagePayload = {"data": execution_result.data}

if execution_result.errors:
Expand Down
4 changes: 4 additions & 0 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
StartMessage,
StopMessage,
)
from strawberry.subscriptions.utils import process_extensions
from strawberry.types.execution import ExecutionResult, PreExecutionError
from strawberry.types.unset import UnsetType

Expand Down Expand Up @@ -216,6 +217,9 @@ async def cleanup(self) -> None:
async def send_data_message(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
extensions = getattr(self.schema, "extensions", [])
process_extensions(execution_result, extensions)

data_message: DataMessage = {
"type": "data",
"id": operation_id,
Expand Down
26 changes: 26 additions & 0 deletions strawberry/subscriptions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import inspect
from typing import Any

from strawberry.types import ExecutionResult


def process_extensions(
execution_result: ExecutionResult, extensions: list[Any]
) -> None:
"""Run the execution result through active schema extensions."""
for ext in extensions:
if isinstance(ext, type):
# Inspect the constructor to see if it requires execution_context
sig = inspect.signature(ext.__init__)
if "execution_context" in sig.parameters:
extension_instance = ext(execution_context=None)
else:
extension_instance = ext()
Comment on lines +18 to +23
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Broad except TypeError can swallow unrelated constructor failures

The outer try/except TypeError is intended only to detect extensions whose __init__ does not accept an execution_context keyword argument. However, it also silently catches any TypeError raised inside the constructor body (e.g. a type mismatch in dependency injection logic, a wrong internal call, etc.).

When that happens the code falls back to ext(), which may also raise — or worse, may succeed but silently produce a misconfigured instance. The real constructor error is never surfaced.

A tighter guard would inspect the signature before calling, rather than catching the exception after:

import inspect

sig = inspect.signature(ext.__init__)
if "execution_context" in sig.parameters:
    extension_instance = ext(execution_context=None)
else:
    extension_instance = ext()

This keeps the intent explicit and lets any other TypeError propagate normally.

Comment on lines +22 to +23
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Class-based extensions with constructor args silently drop configuration

When a class-based extension is instantiated here via ext() (the branch where execution_context is absent from its signature), any constructor parameters that the extension normally accepts are silently omitted.

A concrete case: suppose a user registers a custom extension as a class with a default-valued argument that controls behaviour:

class RateLimitExtension(SchemaExtension):
    def __init__(self, max_errors: int = 10) -> None:
        self.max_errors = max_errors

build_operation_extensions will call RateLimitExtension() and get an instance configured with max_errors=10, regardless of what the original entry in schema.extensions was. More critically, schema.extensions stores the class itself — not a configured instance — so there is no way to recover the user-intended configuration from the class alone.

The safest fix for the general case is to avoid re-instantiation entirely: run schema.get_extensions() once per operation (as the normal query/mutation path does) so that Strawberry's own instantiation logic is reused, and then call _process_result on those instances instead of creating a separate set. If per-operation instantiation must stay, a clarifying note that only zero-argument (or execution_context-only) class-based extensions are supported would prevent silent misconfiguration.


# Explicitly set this ONLY for newly constructed instances
extension_instance.execution_context = None
Comment on lines +18 to +25
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Extension instantiated on every subscription event

For class-based extensions, inspect.signature(ext.__init__) is called and a fresh instance is constructed on every invocation of process_extensions — i.e., for every individual subscription result. This means:

  1. inspect.signature introspection runs on each subscription event, which is non-trivial for high-frequency streams.
  2. Any class-based extension that accumulates mutable state (e.g., a counter, a timer, a deduplication set) will silently lose that state between events because a fresh object is handed each event.

The resolved extension instances (with proper execution_context) are already available inside schema._subscribe via the extensions_runner. A more robust and efficient fix would be to call _process_result on those already-configured instances rather than re-resolving and re-constructing from schema.extensions on every event.

If the current approach is kept, consider at minimum caching the per-schema extension list once (not per event) to avoid repeated signature inspection:

# Computed once per subscription start, not once per event
extension_instances = _build_extension_instances(schema.extensions)

else:
extension_instance = ext

if hasattr(extension_instance, "_process_result"):
extension_instance._process_result(execution_result)
113 changes: 113 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,3 +1216,116 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


@patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: Mock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]

# Error intercepted and extension called
Comment on lines +1220 to +1242
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test validates call count but not actual masking behaviour

The test patches _process_result with a no-op mock and asserts it was called once. This confirms the hook fires, but does not verify that error masking actually works end-to-end (e.g., that the error message is replaced with "Unexpected error." and that original exception details are not leaked to the client).

Consider adding a complementary integration test that uses a real MaskErrors extension and asserts the response contains the masked message rather than the raw exception text. This would catch regressions like the configuration-loss issue described in the handler comment.

Also note there is a missing blank line before the @patch.object decorator (PEP 8 E302).

Suggested change
@patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: Mock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)
next_message: NextMessage = await ws.receive_json()
assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]
# Error intercepted and extension called
@patch.object(MyExtension, "_process_result", create=True)

needs two blank lines after the previous test function body.

mock.assert_called_once()


async def test_subscription_error_masking_end_to_end(
http_client_class: type[HttpClient],
):
"""Test that the real MaskErrors extension successfully masks the payload text."""
import strawberry
from strawberry.extensions import MaskErrors
from tests.views.schema import Query, Subscription

# Create a custom schema with the real MaskErrors extension attached
custom_schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[MaskErrors(error_message="Unexpected error.")],
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "Super secret database error") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]

assert next_message["payload"]["errors"][0]["message"] == "Unexpected error."


async def test_subscription_masking_with_class_extension(
http_client_class: type[HttpClient],
):
"""Test that passing an extension as a class (not instance) successfully masks errors."""
import strawberry
from strawberry.extensions import SchemaExtension
from tests.views.schema import Query, Subscription

class CustomClassExtension(SchemaExtension):
def __init__(self, execution_context=None):
self.execution_context = execution_context # type: ignore

def _process_result(self, execution_result):
if execution_result.errors:
execution_result.errors[0].message = "Unexpected error."

# Create a custom schema passing the CLASS, not an instance
custom_schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[CustomClassExtension]
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_message(
{
"id": "sub_class",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "Secret database error") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub_class"
assert "errors" in next_message["payload"]

assert next_message["payload"]["errors"][0]["message"] == "Unexpected error."
125 changes: 125 additions & 0 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,128 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


@mock.patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: mock.MagicMock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_legacy_message(
{
"type": "start",
"id": "demo",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "demo"
assert "errors" in data_message["payload"]

# Error intercepted and extension called
mock.assert_called_once()

await ws.send_legacy_message({"type": "stop", "id": "demo"})
complete_message = await ws.receive_json()
assert complete_message["type"] == "complete"


async def test_subscription_error_masking_end_to_end(
http_client_class: type[HttpClient],
):
"""Test that the real MaskErrors extension successfully masks the legacy payload text."""
import strawberry
from strawberry.extensions import MaskErrors
from tests.views.schema import Query, Subscription

# Create a custom schema with the real MaskErrors extension attached
custom_schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[MaskErrors(error_message="Unexpected error.")],
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
) as ws:
await ws.send_legacy_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_legacy_message(
{
"type": "start",
"id": "demo",
"payload": {
"query": 'subscription { exception(message: "Super secret database error") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "demo"
assert "errors" in data_message["payload"]

assert data_message["payload"]["errors"][0]["message"] == "Unexpected error."

# Clean up the socket
await ws.send_legacy_message({"type": "stop", "id": "demo"})
await ws.receive_json()


async def test_subscription_masking_with_class_extension(
http_client_class: type[HttpClient],
):
"""Test that passing an extension as a class (not instance) successfully masks the legacy payload text."""
import strawberry
from strawberry.extensions import SchemaExtension
from tests.views.schema import Query, Subscription

class CustomClassExtension(SchemaExtension):
def __init__(self, execution_context=None):
self.execution_context = execution_context # type: ignore

def _process_result(self, execution_result):
if execution_result.errors:
execution_result.errors[0].message = "Unexpected error."

# Create a custom schema passing the CLASS, not an instance
custom_schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[CustomClassExtension]
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
) as ws:
await ws.send_legacy_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_legacy_message(
{
"type": "start",
"id": "sub_class",
"payload": {
"query": 'subscription { exception(message: "Secret database error") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "sub_class"
assert "errors" in data_message["payload"]

assert data_message["payload"]["errors"][0]["message"] == "Unexpected error."

# Clean up the socket
await ws.send_legacy_message({"type": "stop", "id": "sub_class"})
await ws.receive_json()
Loading