Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Release type: patch

Fixes an issue where schema extensions (like `MaskErrors`) were bypassed during WebSocket subscriptions. The extensions' `_process_result` hooks are now properly triggered for each yielded result in both `graphql-transport-ws` and `graphql-ws` protocols, ensuring errors are correctly formatted before being sent to the client.

### Description
Fixes an issue where schema extensions (such as `MaskErrors`) were being bypassed when streaming data over WebSockets.

Previously, standard Queries and Mutations would pass their results through the extension pipeline, but Subscriptions would send raw `ExecutionResult` objects directly over the WebSocket. This caused internal/unmasked errors to leak to the client. This PR manually triggers `_process_result` on active extensions right before `send_next` and `send_data_message` dispatch the payload.

### Migration guide
No migration required.

### Types of Changes
- [ ] Core
- [x] Bugfix
- [ ] New feature
- [ ] Enhancement/optimization
- [ ] Documentation

### Checklist
- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [x] I have read the CONTRIBUTING document.
- [x] I have added tests to cover my changes.
- [x] I have tested the changes and verified that they work and don't break anything (as well as I can manage).
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PongMessage,
SubscribeMessage,
)
from strawberry.subscriptions.utils import build_operation_extensions
from strawberry.types import ExecutionResult
from strawberry.types.execution import PreExecutionError
from strawberry.types.graphql import OperationType
Expand Down Expand Up @@ -347,6 +348,7 @@ class Operation(Generic[Context, RootValue]):
"completed",
"handler",
"id",
"operation_extensions",
"operation_name",
"operation_type",
"query",
Expand All @@ -371,6 +373,8 @@ def __init__(
self.operation_name = operation_name
self.completed = False
self.task: asyncio.Task | None = None
schema_extensions = getattr(self.handler.schema, "extensions", [])
self.operation_extensions = build_operation_extensions(schema_extensions)

async def send_operation_message(self, message: Message) -> None:
if self.completed:
Expand All @@ -394,6 +398,10 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None:
)

async def send_next(self, execution_result: ExecutionResult) -> None:
for ext in self.operation_extensions:
if hasattr(ext, "_process_result"):
ext._process_result(execution_result)

next_payload: NextMessagePayload = {"data": execution_result.data}

if execution_result.errors:
Expand Down
12 changes: 12 additions & 0 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
StartMessage,
StopMessage,
)
from strawberry.subscriptions.utils import build_operation_extensions
from strawberry.types.execution import ExecutionResult, PreExecutionError
from strawberry.types.unset import UnsetType

Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
self.keep_alive_task: asyncio.Task | None = None
self.subscriptions: dict[str, AsyncGenerator] = {}
self.tasks: dict[str, asyncio.Task] = {}
self.operation_extensions: dict[str, list[Any]] = {}

async def handle(self) -> None:
try:
Expand Down Expand Up @@ -139,6 +141,10 @@ async def handle_start(self, message: StartMessage) -> None:
operation_id, query, operation_name, variables
)
self.tasks[operation_id] = asyncio.create_task(result_handler)
schema_extensions = getattr(self.schema, "extensions", [])
self.operation_extensions[operation_id] = build_operation_extensions(
schema_extensions
)

async def handle_stop(self, message: StopMessage) -> None:
operation_id = message["id"]
Expand Down Expand Up @@ -209,13 +215,19 @@ async def cleanup_operation(self, operation_id: str) -> None:
await self.tasks[operation_id]
del self.tasks[operation_id]

self.operation_extensions.pop(operation_id, None)

async def cleanup(self) -> None:
for operation_id in list(self.tasks.keys()):
await self.cleanup_operation(operation_id)

async def send_data_message(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
for ext in self.operation_extensions.get(operation_id, []):
if hasattr(ext, "_process_result"):
ext._process_result(execution_result)

data_message: DataMessage = {
"type": "data",
"id": operation_id,
Expand Down
29 changes: 29 additions & 0 deletions strawberry/subscriptions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import inspect
from typing import Any


def build_operation_extensions(extensions: list[Any]) -> list[Any]:
"""Build a fresh, isolated set of extensions for a single operation.

Only extensions that are either pre-constructed instances or
zero-argument classes (optionally accepting execution_context) are
fully supported. Class-based extensions requiring custom constructor
arguments should be passed as pre-constructed instances.

Full extension lifecycle support (on_operation, on_execute, etc.)
for subscriptions should be tracked as a separate issue.
"""
instances = []
for ext in extensions:
if isinstance(ext, type):
sig = inspect.signature(ext.__init__)
if "execution_context" in sig.parameters:
extension_instance = ext(execution_context=None)
else:
extension_instance = ext()
Comment on lines +18 to +23
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Broad except TypeError can swallow unrelated constructor failures

The outer try/except TypeError is intended only to detect extensions whose __init__ does not accept an execution_context keyword argument. However, it also silently catches any TypeError raised inside the constructor body (e.g. a type mismatch in dependency injection logic, a wrong internal call, etc.).

When that happens the code falls back to ext(), which may also raise — or worse, may succeed but silently produce a misconfigured instance. The real constructor error is never surfaced.

A tighter guard would inspect the signature before calling, rather than catching the exception after:

import inspect

sig = inspect.signature(ext.__init__)
if "execution_context" in sig.parameters:
    extension_instance = ext(execution_context=None)
else:
    extension_instance = ext()

This keeps the intent explicit and lets any other TypeError propagate normally.

Comment on lines +22 to +23
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Class-based extensions with constructor args silently drop configuration

When a class-based extension is instantiated here via ext() (the branch where execution_context is absent from its signature), any constructor parameters that the extension normally accepts are silently omitted.

A concrete case: suppose a user registers a custom extension as a class with a default-valued argument that controls behaviour:

class RateLimitExtension(SchemaExtension):
    def __init__(self, max_errors: int = 10) -> None:
        self.max_errors = max_errors

build_operation_extensions will call RateLimitExtension() and get an instance configured with max_errors=10, regardless of what the original entry in schema.extensions was. More critically, schema.extensions stores the class itself — not a configured instance — so there is no way to recover the user-intended configuration from the class alone.

The safest fix for the general case is to avoid re-instantiation entirely: run schema.get_extensions() once per operation (as the normal query/mutation path does) so that Strawberry's own instantiation logic is reused, and then call _process_result on those instances instead of creating a separate set. If per-operation instantiation must stay, a clarifying note that only zero-argument (or execution_context-only) class-based extensions are supported would prevent silent misconfiguration.


extension_instance.execution_context = None
Comment on lines +18 to +25
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Extension instantiated on every subscription event

For class-based extensions, inspect.signature(ext.__init__) is called and a fresh instance is constructed on every invocation of process_extensions — i.e., for every individual subscription result. This means:

  1. inspect.signature introspection runs on each subscription event, which is non-trivial for high-frequency streams.
  2. Any class-based extension that accumulates mutable state (e.g., a counter, a timer, a deduplication set) will silently lose that state between events because a fresh object is handed each event.

The resolved extension instances (with proper execution_context) are already available inside schema._subscribe via the extensions_runner. A more robust and efficient fix would be to call _process_result on those already-configured instances rather than re-resolving and re-constructing from schema.extensions on every event.

If the current approach is kept, consider at minimum caching the per-schema extension list once (not per event) to avoid repeated signature inspection:

# Computed once per subscription start, not once per event
extension_instances = _build_extension_instances(schema.extensions)

instances.append(extension_instance)
else:
instances.append(ext)
Comment on lines +27 to +28
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Instance-based extensions get _process_result called from two places

When an extension is registered as a pre-constructed instance (e.g. extensions=[MaskErrors(error_message="Custom")]), build_operation_extensions returns the same object that schema._async_extensions also returns. That same object is passed to SchemaExtensionsRunner inside schema.subscribe(), whose operation() context manager's post-yield code calls _process_result again (via MaskErrors.on_operation) once the subscription generator is exhausted.

The result:

  • For every subscription event except the last, only the new code path calls _process_result — correct.
  • For the last event, _process_result is called once here (before the result is sent — correct) and then a second time by on_operation's post-yield after the result has already been transmitted — harmless for MaskErrors (idempotent masking), but incorrect for any extension that performs a non-idempotent transformation or side-effect (e.g. incrementing a counter, writing to a log with the original message).

Because MaskErrors is stateless this passes today's tests, but it creates a latent correctness trap for custom extensions. The cleanest resolution is to reuse the SchemaExtensionsRunner that schema.subscribe() already holds, so there is a single, authoritative call site for _process_result.

return instances
113 changes: 113 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,3 +1216,116 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


@patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: Mock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]

# Error intercepted and extension called
Comment on lines +1220 to +1242
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test validates call count but not actual masking behaviour

The test patches _process_result with a no-op mock and asserts it was called once. This confirms the hook fires, but does not verify that error masking actually works end-to-end (e.g., that the error message is replaced with "Unexpected error." and that original exception details are not leaked to the client).

Consider adding a complementary integration test that uses a real MaskErrors extension and asserts the response contains the masked message rather than the raw exception text. This would catch regressions like the configuration-loss issue described in the handler comment.

Also note there is a missing blank line before the @patch.object decorator (PEP 8 E302).

Suggested change
@patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: Mock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)
next_message: NextMessage = await ws.receive_json()
assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]
# Error intercepted and extension called
@patch.object(MyExtension, "_process_result", create=True)

needs two blank lines after the previous test function body.

mock.assert_called_once()


async def test_subscription_error_masking_end_to_end(
http_client_class: type[HttpClient],
):
"""Test that the real MaskErrors extension successfully masks the payload text."""
import strawberry
from strawberry.extensions import MaskErrors
from tests.views.schema import Query, Subscription

# Create a custom schema with the real MaskErrors extension attached
custom_schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[MaskErrors(error_message="Unexpected error.")],
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "Super secret database error") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub1"
assert "errors" in next_message["payload"]

assert next_message["payload"]["errors"][0]["message"] == "Unexpected error."


async def test_subscription_masking_with_class_extension(
http_client_class: type[HttpClient],
):
"""Test that passing an extension as a class (not instance) successfully masks errors."""
import strawberry
from strawberry.extensions import SchemaExtension
from tests.views.schema import Query, Subscription

class CustomClassExtension(SchemaExtension):
def __init__(self, execution_context=None):
self.execution_context = execution_context # type: ignore

def _process_result(self, execution_result):
if execution_result.errors:
execution_result.errors[0].message = "Unexpected error."

# Create a custom schema passing the CLASS, not an instance
custom_schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[CustomClassExtension]
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_message(
{
"id": "sub_class",
"type": "subscribe",
"payload": {
"query": 'subscription { exception(message: "Secret database error") }',
},
}
)

next_message: NextMessage = await ws.receive_json()

assert next_message["type"] == "next"
assert next_message["id"] == "sub_class"
assert "errors" in next_message["payload"]

assert next_message["payload"]["errors"][0]["message"] == "Unexpected error."
125 changes: 125 additions & 0 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,128 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


@mock.patch.object(MyExtension, "_process_result", create=True)
async def test_subscription_errors_trigger_extension_process_result(
mock: mock.MagicMock, ws: WebSocketClient
):
"""Test that schema extensions are called to process results when a subscription yields an error."""
await ws.send_legacy_message(
{
"type": "start",
"id": "demo",
"payload": {
"query": 'subscription { exception(message: "TEST EXC") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "demo"
assert "errors" in data_message["payload"]

# Error intercepted and extension called
mock.assert_called_once()

await ws.send_legacy_message({"type": "stop", "id": "demo"})
complete_message = await ws.receive_json()
assert complete_message["type"] == "complete"


async def test_subscription_error_masking_end_to_end(
http_client_class: type[HttpClient],
):
"""Test that the real MaskErrors extension successfully masks the legacy payload text."""
import strawberry
from strawberry.extensions import MaskErrors
from tests.views.schema import Query, Subscription

# Create a custom schema with the real MaskErrors extension attached
custom_schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[MaskErrors(error_message="Unexpected error.")],
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
) as ws:
await ws.send_legacy_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_legacy_message(
{
"type": "start",
"id": "demo",
"payload": {
"query": 'subscription { exception(message: "Super secret database error") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "demo"
assert "errors" in data_message["payload"]

assert data_message["payload"]["errors"][0]["message"] == "Unexpected error."

# Clean up the socket
await ws.send_legacy_message({"type": "stop", "id": "demo"})
await ws.receive_json()


async def test_subscription_masking_with_class_extension(
http_client_class: type[HttpClient],
):
"""Test that passing an extension as a class (not instance) successfully masks the legacy payload text."""
import strawberry
from strawberry.extensions import SchemaExtension
from tests.views.schema import Query, Subscription

class CustomClassExtension(SchemaExtension):
def __init__(self, execution_context=None):
self.execution_context = execution_context # type: ignore

def _process_result(self, execution_result):
if execution_result.errors:
execution_result.errors[0].message = "Unexpected error."

# Create a custom schema passing the CLASS, not an instance
custom_schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[CustomClassExtension]
)
test_client = http_client_class(custom_schema)

async with test_client.ws_connect(
"/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
) as ws:
await ws.send_legacy_message({"type": "connection_init"})
await ws.receive_json()

await ws.send_legacy_message(
{
"type": "start",
"id": "sub_class",
"payload": {
"query": 'subscription { exception(message: "Secret database error") }',
},
}
)

data_message: DataMessage = await ws.receive_json()

assert data_message["type"] == "data"
assert data_message["id"] == "sub_class"
assert "errors" in data_message["payload"]

assert data_message["payload"]["errors"][0]["message"] == "Unexpected error."

# Clean up the socket
await ws.send_legacy_message({"type": "stop", "id": "sub_class"})
await ws.receive_json()
Loading