Skip to content
Open
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Release type: minor

Adds a new `on_subscription_result` hook to `SchemaExtension` that allows extensions to interact with and mutate the stream of events yielded by GraphQL subscriptions.

Previously, extensions were only triggered during the initial setup phase of a subscription, meaning transport layers (like WebSockets) bypassed them during the actual data streaming phase. This new hook solves this by executing right before each result is yielded to the client.

As part of this architectural update, the built-in `MaskErrors` extension has been updated to use this new hook, ensuring that sensitive exceptions are now correctly masked during WebSocket subscriptions.
10 changes: 9 additions & 1 deletion strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from graphql import GraphQLResolveInfo

from strawberry.types import ExecutionContext
from strawberry.types import ExecutionContext, ExecutionResult


class LifecycleStep(Enum):
Expand Down Expand Up @@ -51,6 +51,14 @@ def on_execute( # type: ignore
"""Called before and after the execution step."""
yield None

def on_subscription_result(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should follow the same patterns as the others here, using yield None, which would allow extensions to do something before and after

Wdyt?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi.
I agree. I think it makes sense to follow the patterns as the other methods, and this change allows for more flexibility for the extensions.
I have just made a new commit and pushed the updates to implement this. Thank you for your review.

self, result: ExecutionResult
) -> None | AwaitableOrValue[None]:
"""Called exactly once for every event/result yielded by a GraphQL subscription.

Extensions can mutate the `result` object directly (e.g., masking errors).
"""
Copy link

Copilot AI Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type for on_subscription_result is currently None | AwaitableOrValue[None], but AwaitableOrValue[None] already includes None. Consider simplifying this to AwaitableOrValue[None] to avoid redundant typing and match the pattern used by other extension methods like resolve/get_results.

Copilot uses AI. Check for mistakes.

def resolve(
self,
_next: Callable,
Expand Down
4 changes: 4 additions & 0 deletions strawberry/extensions/mask_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ def on_operation(self) -> Iterator[None]:
self._process_result(result)
elif result:
self._process_result(result.initial_result)

def on_subscription_result(self, result: StrawberryExecutionResult) -> None:
Copy link

Copilot AI Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaskErrors.on_subscription_result narrows the parameter type to StrawberryExecutionResult, but the base SchemaExtension.on_subscription_result is typed as taking ExecutionResult. This narrower override can trigger type-checking issues (it’s not substitutable if the runner passes a base ExecutionResult). Prefer matching the base signature (ExecutionResult) here.

Suggested change
def on_subscription_result(self, result: StrawberryExecutionResult) -> None:
def on_subscription_result(self, result: GraphQLExecutionResult) -> None:

Copilot uses AI. Check for mistakes.
"""Mask errors on streaming subscription results."""
self._process_result(result)
10 changes: 9 additions & 1 deletion strawberry/extensions/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from strawberry.utils.await_maybe import await_maybe

if TYPE_CHECKING:
from strawberry.types import ExecutionContext
from strawberry.types import ExecutionContext, ExecutionResult

from . import SchemaExtension

Expand Down Expand Up @@ -40,6 +40,14 @@ def parsing(self) -> ParsingContextManager:
def executing(self) -> ExecutingContextManager:
return ExecutingContextManager(self.extensions)

async def on_subscription_result(self, result: ExecutionResult) -> None:
"""Run the subscription result hook across all active extensions."""
for extension in self.extensions:
# Check if the extension implemented the new hook
hook = getattr(extension, "on_subscription_result", None)
if hook:
await await_maybe(hook(result))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 getattr guard is always truthy

The comment says "Check if the extension implemented the new hook", but since SchemaExtension now defines on_subscription_result with a default no-op body, every extension will always carry this attribute. The if hook: branch is therefore always taken, making the check misleading without being harmful (the base-class call simply returns None which await_maybe handles fine).

Consider simplifying to remove the dead guard:

Suggested change
async def on_subscription_result(self, result: ExecutionResult) -> None:
"""Run the subscription result hook across all active extensions."""
for extension in self.extensions:
# Check if the extension implemented the new hook
hook = getattr(extension, "on_subscription_result", None)
if hook:
await await_maybe(hook(result))
async def on_subscription_result(self, result: ExecutionResult) -> None:
"""Run the subscription result hook across all active extensions."""
for extension in self.extensions:
await await_maybe(extension.on_subscription_result(result))


def get_extensions_results_sync(self) -> dict[str, Any]:
data: dict[str, Any] = {}
for extension in self.extensions:
Expand Down
13 changes: 11 additions & 2 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,20 +875,29 @@ async def _subscribe(
try:
async with aclosing(aiter_or_result):
async for result in aiter_or_result:
yield await self._handle_execution_result(
extension_result = await self._handle_execution_result(
execution_context,
result,
extensions_runner,
)

await extensions_runner.on_subscription_result(
extension_result
)

yield extension_result

# graphql-core doesn't handle exceptions raised while executing.
except Exception as exc: # noqa: BLE001
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
OriginalExecutionResult(
data=None, errors=[_coerce_error(exc)]
),
extensions_runner,
)
await extensions_runner.on_subscription_result(execution_result)
yield execution_result
# catch exceptions raised in `on_execute` hook.
except Exception as exc: # noqa: BLE001
origin_result = OriginalExecutionResult(
Expand Down
87 changes: 87 additions & 0 deletions tests/extensions/test_subscription_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from collections.abc import AsyncGenerator

import pytest

import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.mask_errors import MaskErrors
from strawberry.types import ExecutionResult


# Dummy extension that uses the new hook
class StreamModifierExtension(SchemaExtension):
def on_subscription_result(self, result: ExecutionResult) -> None:
if result.data and "count" in result.data:
# Mutate the outgoing data stream
result.data["count"] = f"Modified: {result.data['count']}"
Comment on lines +14 to +18
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a test that covers an async on_subscription_result hook implementation.

Current tests only cover the synchronous implementation. Since extensions_runner.on_subscription_result uses await_maybe, please add a test with an extension whose on_subscription_result is async def and awaits a side effect before mutating the result, to confirm async hooks are correctly awaited and safely integrate with async dependencies.

Suggested implementation:

import asyncio

import pytest

import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.mask_errors import MaskErrors
from strawberry.types import ExecutionResult


# Dummy extension that uses the new hook
class StreamModifierExtension(SchemaExtension):
    def on_subscription_result(self, result: ExecutionResult) -> None:
        if result.data and "count" in result.data:
            # Mutate the outgoing data stream
            result.data["count"] = f"Modified: {result.data['count']}"


class AsyncStreamModifierExtension(SchemaExtension):
    side_effect_ran = False

    async def _side_effect(self) -> None:
        # Simulate an async dependency / side effect
        await asyncio.sleep(0)
        AsyncStreamModifierExtension.side_effect_ran = True

    async def on_subscription_result(self, result: ExecutionResult) -> None:
        # This should be awaited by extensions_runner.on_subscription_result
        await self._side_effect()

        if result.data and "count" in result.data:
            # Mutate the outgoing data stream after the async side effect
            result.data["count"] = f"Modified: {result.data['count']}"


@strawberry.type
class Query:
    # Minimal Query type; field is unused but required by Strawberry
    example: str = "example"


@strawberry.type
class Subscription:
    @strawberry.subscription
    async def count(self) -> int:
        for i in range(3):
            yield i


@pytest.mark.asyncio
async def test_async_on_subscription_result_is_awaited() -> None:
    # Reset class-level flag before running the subscription
    AsyncStreamModifierExtension.side_effect_ran = False

    schema = strawberry.Schema(
        query=Query,
        subscription=Subscription,
        extensions=[AsyncStreamModifierExtension],
    )

    results = schema.subscribe("subscription { count }")

    # Consume first result from the async iterator
    first_result = await results.__anext__()

    assert first_result.errors is None
    assert first_result.data == {"count": "Modified: 0"}
    assert AsyncStreamModifierExtension.side_effect_ran is True

If this file already defines Query / Subscription types or other subscription tests, you may want to:

  1. Reuse existing Query/Subscription instead of defining new ones here, to keep the test suite DRY.
  2. Adjust the subscription field name or expected payload in the test to match any shared schema definitions (e.g., if an existing subscription already yields a count field).



# Create a basic schema with a subscription
@strawberry.type
class Query:
hello: str = "world"


@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> AsyncGenerator[int, None]:
yield 1
yield 2

@strawberry.subscription
async def dangerous_stream(self) -> AsyncGenerator[int, None]:
yield 1
raise ValueError("Secret database credentials leaked!")


# Tests
@pytest.mark.asyncio
async def test_extension_modifies_subscription_stream():
schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[StreamModifierExtension]
)

query = "subscription { count }"

sub_generator = await schema.subscribe(query)

# Get all yielded results
results = [result async for result in sub_generator]

assert len(results) == 2

# Assert that the extension successfully intercepted and modified the stream
assert results[0].data["count"] == "Modified: 1"
assert results[1].data["count"] == "Modified: 2"


@pytest.mark.asyncio
async def test_mask_errors_scrubs_subscription_exceptions():
# Initialize schema with the MaskErrors extension
schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[MaskErrors()]
)

query = "subscription { dangerousStream }"

sub_generator = await schema.subscribe(query)

# Get all yielded results
results = [result async for result in sub_generator]

# We expect 2 results: the successful yield, and the error
assert len(results) == 2

# Assert the first yield worked normally
assert results[0].data["dangerousStream"] == 1
assert not results[0].errors

# Assert the error was caught and MASKED
assert results[1].data is None
assert len(results[1].errors) == 1

# The crucial check: The raw exception message MUST NOT be exposed
error_message = results[1].errors[0].message
assert error_message == "Unexpected error."
assert "Secret database credentials" not in error_message
Loading