7 changes: 7 additions & 0 deletions RELEASE.md
@@ -0,0 +1,7 @@
Release type: minor

Adds a new `on_subscription_result` hook to `SchemaExtension` that allows extensions to interact with and mutate the stream of events yielded by GraphQL subscriptions.

Previously, extensions were only triggered during the initial setup phase of a subscription, so transport layers (like WebSockets) bypassed them during the actual data streaming phase. The new hook closes this gap by running right before each result is yielded to the client.

As part of this architectural update, the built-in `MaskErrors` extension has been updated to use this new hook, ensuring that sensitive exceptions are now correctly masked during WebSocket subscriptions.
10 changes: 9 additions & 1 deletion strawberry/extensions/base_extension.py
@@ -9,7 +9,7 @@
if TYPE_CHECKING:
from graphql import GraphQLResolveInfo

from strawberry.types import ExecutionContext
from strawberry.types import ExecutionContext, ExecutionResult


class LifecycleStep(Enum):
@@ -51,6 +51,14 @@ def on_execute( # type: ignore
"""Called before and after the execution step."""
yield None

def on_subscription_result(
self, result: ExecutionResult
) -> None | AwaitableOrValue[None]:
"""Called exactly once for every event/result yielded by a GraphQL subscription.

Extensions can mutate the `result` object directly (e.g., masking errors).
"""

Member: I think we should follow the same pattern as the other hooks here, using `yield None`, which would allow extensions to do something both before and after. Wdyt?

Author: Hi. I agree — it makes sense to follow the same pattern as the other methods, and this change allows more flexibility for extensions. I have just made a new commit and pushed the updates to implement this. Thank you for your review.
Copilot AI (Apr 5, 2026): The return type for `on_subscription_result` is currently `None | AwaitableOrValue[None]`, but `AwaitableOrValue[None]` already includes `None`. Consider simplifying this to `AwaitableOrValue[None]` to avoid redundant typing and match the pattern used by other extension methods like `resolve`/`get_results`.

def resolve(
self,
_next: Callable,
4 changes: 4 additions & 0 deletions strawberry/extensions/mask_errors.py
@@ -59,3 +59,7 @@ def on_operation(self) -> Iterator[None]:
self._process_result(result)
elif result:
self._process_result(result.initial_result)

def on_subscription_result(self, result: StrawberryExecutionResult) -> None:
"""Mask errors on streaming subscription results."""
self._process_result(result)

Copilot AI (Apr 5, 2026): `MaskErrors.on_subscription_result` narrows the parameter type to `StrawberryExecutionResult`, but the base `SchemaExtension.on_subscription_result` is typed as taking `ExecutionResult`. This narrower override can trigger type-checking issues (it's not substitutable if the runner passes a base `ExecutionResult`). Prefer matching the base signature (`ExecutionResult`) here.

Suggested change:
- def on_subscription_result(self, result: StrawberryExecutionResult) -> None:
+ def on_subscription_result(self, result: GraphQLExecutionResult) -> None:
7 changes: 6 additions & 1 deletion strawberry/extensions/runner.py
@@ -12,7 +12,7 @@
from strawberry.utils.await_maybe import await_maybe

if TYPE_CHECKING:
from strawberry.types import ExecutionContext
from strawberry.types import ExecutionContext, ExecutionResult

from . import SchemaExtension

@@ -40,6 +40,11 @@ def parsing(self) -> ParsingContextManager:
def executing(self) -> ExecutingContextManager:
return ExecutingContextManager(self.extensions)

async def on_subscription_result(self, result: ExecutionResult) -> None:
"""Run the subscription result hook across all active extensions."""
for extension in self.extensions:
await await_maybe(extension.on_subscription_result(result))
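Because the runner wraps each hook call in `await_maybe`, extensions may implement the hook either as a plain method or as `async def`. A self-contained sketch of this dispatch (the `await_maybe` below is a stand-in, not the real `strawberry.utils.await_maybe` import):

```python
import asyncio
import inspect

async def await_maybe(value):
    """Stand-in for strawberry.utils.await_maybe: await only if awaitable."""
    if inspect.isawaitable(value):
        return await value
    return value

class SyncExtension:
    def on_subscription_result(self, result: list) -> None:
        result.append("sync")

class AsyncExtension:
    async def on_subscription_result(self, result: list) -> None:
        await asyncio.sleep(0)  # e.g. an async dependency
        result.append("async")

async def run_hooks(extensions, result):
    # Mirrors the runner loop: both hook styles are supported.
    for extension in extensions:
        await await_maybe(extension.on_subscription_result(result))

calls: list = []
asyncio.run(run_hooks([SyncExtension(), AsyncExtension()], calls))
print(calls)  # ['sync', 'async']
```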

def get_extensions_results_sync(self) -> dict[str, Any]:
data: dict[str, Any] = {}
for extension in self.extensions:
25 changes: 20 additions & 5 deletions strawberry/schema/schema.py
@@ -833,9 +833,11 @@ async def _subscribe(
initial_error.extensions = (
await extensions_runner.get_extensions_results(execution_context)
)
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context, initial_error, extensions_runner
)
await extensions_runner.on_subscription_result(execution_result)
yield execution_result
Copilot AI (Apr 5, 2026): When `_parse_and_validate_async` returns an `initial_error`, `_subscribe` yields it but then continues into the execution path. For parse errors this can hit the `assert execution_context.graphql_document is not None` and yield an additional (unrelated) error; for validation errors it can also lead to duplicate error results from `subscribe(...)`. Consider returning immediately after yielding the initial error (or wrapping the rest of the function in an `else`) so subscription setup stops cleanly on pre-execution failures.

Suggested change (add a `return` after the yield):
yield execution_result
return
try:
async with extensions_runner.executing():
assert execution_context.graphql_document is not None
@@ -866,39 +868,52 @@

# Handle pre-execution errors.
if isinstance(aiter_or_result, OriginalExecutionResult):
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
PreExecutionError(data=None, errors=aiter_or_result.errors),
extensions_runner,
)
await extensions_runner.on_subscription_result(execution_result)
yield execution_result
else:
try:
async with aclosing(aiter_or_result):
async for result in aiter_or_result:
yield await self._handle_execution_result(
extension_result = await self._handle_execution_result(
execution_context,
result,
extensions_runner,
)

await extensions_runner.on_subscription_result(
extension_result
)

yield extension_result

# graphql-core doesn't handle exceptions raised while executing.
except Exception as exc: # noqa: BLE001
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
OriginalExecutionResult(
data=None, errors=[_coerce_error(exc)]
),
extensions_runner,
)
await extensions_runner.on_subscription_result(execution_result)
yield execution_result
# catch exceptions raised in `on_execute` hook.
except Exception as exc: # noqa: BLE001
origin_result = OriginalExecutionResult(
data=None, errors=[_coerce_error(exc)]
)
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
origin_result,
extensions_runner,
)
await extensions_runner.on_subscription_result(execution_result)
yield execution_result

async def subscribe(
self,
157 changes: 157 additions & 0 deletions tests/extensions/test_subscription_hook.py
@@ -0,0 +1,157 @@
import asyncio
from collections.abc import AsyncGenerator

import pytest

import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.mask_errors import MaskErrors
from strawberry.types import ExecutionResult


# Dummy extension that uses the new hook
class StreamModifierExtension(SchemaExtension):
def on_subscription_result(self, result: ExecutionResult) -> None:
if result.data and "count" in result.data:
# Mutate the outgoing data stream
result.data["count"] = f"Modified: {result.data['count']}"
Comment on lines +14 to +18

Contributor: suggestion (testing): Consider adding a test that covers an async `on_subscription_result` hook implementation.

Current tests only cover the synchronous implementation. Since `extensions_runner.on_subscription_result` uses `await_maybe`, please add a test with an extension whose `on_subscription_result` is `async def` and awaits a side effect before mutating the result, to confirm async hooks are correctly awaited and safely integrate with async dependencies.

# Create a basic schema with a subscription
@strawberry.type
class Query:
hello: str = "world"


@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> AsyncGenerator[int, None]:
yield 1
yield 2

@strawberry.subscription
async def dangerous_stream(self) -> AsyncGenerator[int, None]:
yield 1
raise ValueError("Secret database credentials leaked!")


# Tests
@pytest.mark.asyncio
async def test_extension_modifies_subscription_stream():
schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[StreamModifierExtension]
)

query = "subscription { count }"

sub_generator = await schema.subscribe(query)

# Get all yielded results
results = [result async for result in sub_generator]

assert len(results) == 2

# Assert that the extension successfully intercepted and modified the stream
assert results[0].data["count"] == "Modified: 1"
assert results[1].data["count"] == "Modified: 2"


@pytest.mark.asyncio
async def test_mask_errors_scrubs_subscription_exceptions():
# Initialize schema with the MaskErrors extension
schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[MaskErrors()]
)

query = "subscription { dangerousStream }"

sub_generator = await schema.subscribe(query)

# Get all yielded results
results = [result async for result in sub_generator]

# We expect 2 results: the successful yield, and the error
assert len(results) == 2

# Assert the first yield worked normally
assert results[0].data["dangerousStream"] == 1
assert not results[0].errors

# Assert the error was caught and MASKED
assert results[1].data is None
assert len(results[1].errors) == 1

# The crucial check: The raw exception message MUST NOT be exposed
error_message = results[1].errors[0].message
assert error_message == "Unexpected error."
assert "Secret database credentials" not in error_message


class AsyncStreamModifierExtension(SchemaExtension):
side_effect_ran = False

async def _side_effect(self) -> None:
# Simulate an async dependency / side effect
await asyncio.sleep(0)
AsyncStreamModifierExtension.side_effect_ran = True

async def on_subscription_result(self, result: ExecutionResult) -> None:
# This should be awaited by extensions_runner.on_subscription_result
await self._side_effect()

if result.data and "count" in result.data:
# Mutate the outgoing data stream after the async side effect
result.data["count"] = f"Modified: {result.data['count']}"


@pytest.mark.asyncio
async def test_async_on_subscription_result_is_awaited() -> None:
# Reset class-level flag before running the subscription
AsyncStreamModifierExtension.side_effect_ran = False

schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[AsyncStreamModifierExtension()],
)

query = "subscription { count }"
sub_generator = await schema.subscribe(query)

# Consume first result from the async iterator
first_result = await sub_generator.__anext__()

assert first_result.errors is None
assert first_result.data == {"count": "Modified: 1"}
assert AsyncStreamModifierExtension.side_effect_ran is True

Comment on lines +125 to +130

Copilot AI (Apr 5, 2026): The test consumes only the first item from the subscription iterator and never closes it. Other subscription tests in the repo wrap `schema.subscribe(...)` in `contextlib.aclosing(...)` to ensure the async generator is closed even when not fully exhausted. Consider using `aclosing` (or explicitly calling `await sub_generator.aclose()`) to avoid leaking an open subscription iterator.


@pytest.mark.asyncio
async def test_mask_errors_scrubs_pre_execution_errors():
# Initialize schema with MaskErrors
schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[MaskErrors()]
)

# Querying a field that doesn't exist triggers Validation errors BEFORE execution
query = "subscription { fieldThatDoesNotExist }"

# Run the subscription
sub_generator = await schema.subscribe(query)

# Exhaust the generator
results = [result async for result in sub_generator]

# Strawberry yields 1 or more errors depending on the validation layers it hits.
# We must ensure that EVERY result yielded was successfully intercepted and masked.
assert len(results) > 0

for result in results:
assert result.data is None
assert len(result.errors) == 1

# The crucial check: MaskErrors successfully intercepted and masked it!
error_message = result.errors[0].message
assert error_message == "Unexpected error."
assert "fieldThatDoesNotExist" not in error_message
Comment on lines +157 to +162

Copilot AI (Apr 5, 2026): This test asserts `len(result.errors) == 1`, but the earlier comment notes Strawberry may yield 1+ validation errors. Since `MaskErrors` masks messages without changing the number of errors, this assertion will be flaky depending on validation output. Consider asserting that `result.errors` is non-empty and that all error messages are masked (iterate over `result.errors`) rather than enforcing a single error.

Suggested change:
assert result.errors
# The crucial check: MaskErrors successfully intercepted and masked every error!
for error in result.errors:
    error_message = error.message
    assert error_message == "Unexpected error."
    assert "fieldThatDoesNotExist" not in error_message