Skip to content
Open
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Release type: minor

Adds a new `on_subscription_result` hook to `SchemaExtension` that allows extensions to interact with and mutate the stream of events yielded by GraphQL subscriptions.

Previously, extensions were only triggered during the initial setup phase of a subscription, meaning transport layers (like WebSockets) bypassed them during the actual data streaming phase. This new hook solves this by executing right before each result is yielded to the client.

As part of this architectural update, the built-in `MaskErrors` extension has been updated to use this new hook, ensuring that sensitive exceptions are now correctly masked during WebSocket subscriptions.
16 changes: 15 additions & 1 deletion strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any

from strawberry.types import ExecutionResult
from strawberry.utils.await_maybe import AsyncIteratorOrIterator, AwaitableOrValue

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,6 +52,15 @@ def on_execute( # type: ignore
"""Called before and after the execution step."""
yield None

def on_subscription_result( # type: ignore
self, result: ExecutionResult
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
"""Called before and after each event/result yielded by a GraphQL subscription.

Extensions can mutate the `result` object directly (e.g., masking errors).
"""
yield None

def resolve(
self,
_next: Callable,
Expand All @@ -70,13 +80,17 @@ def _implements_resolve(cls) -> bool:
return cls.resolve is not SchemaExtension.resolve


Hook = Callable[[SchemaExtension], AsyncIteratorOrIterator[None]]
Hook = (
Callable[[SchemaExtension], AsyncIteratorOrIterator[None]]
| Callable[[SchemaExtension, ExecutionResult], AsyncIteratorOrIterator[None]]
)

HOOK_METHODS: set[str] = {
SchemaExtension.on_operation.__name__,
SchemaExtension.on_validate.__name__,
SchemaExtension.on_parse.__name__,
SchemaExtension.on_execute.__name__,
SchemaExtension.on_subscription_result.__name__,
}

__all__ = ["HOOK_METHODS", "Hook", "LifecycleStep", "SchemaExtension"]
54 changes: 54 additions & 0 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from types import TracebackType

from strawberry.extensions.base_extension import Hook
from strawberry.types import ExecutionResult
from strawberry.utils.await_maybe import AwaitableOrValue


Expand Down Expand Up @@ -155,3 +156,56 @@ class ParsingContextManager(ExtensionContextManagerBase):

class ExecutingContextManager(ExtensionContextManagerBase):
HOOK_NAME = SchemaExtension.on_execute.__name__


class SubscriptionResultContextManager(ExtensionContextManagerBase):
HOOK_NAME = SchemaExtension.on_subscription_result.__name__

def __init__(
self, extensions: list[SchemaExtension], result: ExecutionResult
) -> None:
self.result = result
super().__init__(extensions)

def from_callable( # type: ignore[override]
self,
extension: SchemaExtension,
func: Callable[..., AwaitableOrValue[Any]],
) -> WrappedHook:
if iscoroutinefunction(func):

@contextlib.asynccontextmanager
async def iterator(result: ExecutionResult) -> AsyncIterator[None]:
await func(extension, result)
yield

return WrappedHook(extension=extension, hook=iterator, is_async=True)

@contextlib.contextmanager
def iterator(result: ExecutionResult) -> Iterator[None]:
func(extension, result)
yield

return WrappedHook(extension=extension, hook=iterator, is_async=False)

def __enter__(self) -> None:
self.exit_stack = contextlib.ExitStack()
self.exit_stack.__enter__()

for hook in self.hooks:
if hook.is_async:
raise RuntimeError(
f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} "
"failed to complete synchronously."
)
self.exit_stack.enter_context(hook.hook(self.result)) # type: ignore

async def __aenter__(self) -> None:
self.async_exit_stack = contextlib.AsyncExitStack()
await self.async_exit_stack.__aenter__()

for hook in self.hooks:
if hook.is_async:
await self.async_exit_stack.enter_async_context(hook.hook(self.result)) # type: ignore
else:
self.async_exit_stack.enter_context(hook.hook(self.result)) # type: ignore
18 changes: 18 additions & 0 deletions strawberry/extensions/mask_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from strawberry.extensions.base_extension import SchemaExtension
from strawberry.types.execution import ExecutionResult as StrawberryExecutionResult
from strawberry.types.graphql import OperationType


def default_should_mask_error(_: GraphQLError) -> bool:
Expand Down Expand Up @@ -53,9 +54,26 @@ def _process_result(self, result: Any) -> None:
def on_operation(self) -> Iterator[None]:
yield

# Subscriptions are handled event-by-event in on_subscription_result
try:
if self.execution_context.operation_type == OperationType.SUBSCRIPTION:
return
except RuntimeError:
# If the query fails to parse early on, operation_type throws a RuntimeError.
# We must PASS here to ensure standard Queries/Mutations get masked,
# even if it means a malformed Subscription gets double-masked.
pass

result = self.execution_context.result

if isinstance(result, (GraphQLExecutionResult, StrawberryExecutionResult)):
self._process_result(result)
elif result:
self._process_result(result.initial_result)

def on_subscription_result(
self, result: StrawberryExecutionResult
) -> Iterator[None]:
"""Mask errors on streaming subscription results."""
self._process_result(result)
yield None
8 changes: 7 additions & 1 deletion strawberry/extensions/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
ExecutingContextManager,
OperationContextManager,
ParsingContextManager,
SubscriptionResultContextManager,
ValidationContextManager,
)
from strawberry.utils.await_maybe import await_maybe

if TYPE_CHECKING:
from strawberry.types import ExecutionContext
from strawberry.types import ExecutionContext, ExecutionResult

from . import SchemaExtension

Expand Down Expand Up @@ -40,6 +41,11 @@ def parsing(self) -> ParsingContextManager:
def executing(self) -> ExecutingContextManager:
return ExecutingContextManager(self.extensions)

def on_subscription_result(
self, result: ExecutionResult
) -> SubscriptionResultContextManager:
return SubscriptionResultContextManager(self.extensions, result)

def get_extensions_results_sync(self) -> dict[str, Any]:
data: dict[str, Any] = {}
for extension in self.extensions:
Expand Down
28 changes: 23 additions & 5 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,9 +833,12 @@ async def _subscribe(
initial_error.extensions = (
await extensions_runner.get_extensions_results(execution_context)
)
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context, initial_error, extensions_runner
)
async with extensions_runner.on_subscription_result(execution_result):
yield execution_result
return # do not fall through to subscribe() after a pre-execution error
try:
async with extensions_runner.executing():
assert execution_context.graphql_document is not None
Expand Down Expand Up @@ -866,39 +869,54 @@ async def _subscribe(

# Handle pre-execution errors.
if isinstance(aiter_or_result, OriginalExecutionResult):
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
PreExecutionError(data=None, errors=aiter_or_result.errors),
extensions_runner,
)
async with extensions_runner.on_subscription_result(
execution_result
):
yield execution_result
else:
try:
async with aclosing(aiter_or_result):
async for result in aiter_or_result:
yield await self._handle_execution_result(
extension_result = await self._handle_execution_result(
execution_context,
result,
extensions_runner,
)

async with extensions_runner.on_subscription_result(
extension_result
):
yield extension_result
# graphql-core doesn't handle exceptions raised while executing.
except Exception as exc: # noqa: BLE001
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
OriginalExecutionResult(
data=None, errors=[_coerce_error(exc)]
),
extensions_runner,
)
async with extensions_runner.on_subscription_result(
execution_result
):
yield execution_result
# catch exceptions raised in `on_execute` hook.
except Exception as exc: # noqa: BLE001
origin_result = OriginalExecutionResult(
data=None, errors=[_coerce_error(exc)]
)
yield await self._handle_execution_result(
execution_result = await self._handle_execution_result(
execution_context,
origin_result,
extensions_runner,
)
async with extensions_runner.on_subscription_result(execution_result):
yield execution_result

async def subscribe(
self,
Expand Down
Loading
Loading