diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..28039bb6eb --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: minor + +Adds a new `on_subscription_result` hook to `SchemaExtension` that allows extensions to interact with and mutate the stream of events yielded by GraphQL subscriptions. + +Previously, extensions were only triggered during the initial setup phase of a subscription, meaning transport layers (like WebSockets) bypassed them during the actual data streaming phase. This new hook solves this by executing right before each result is yielded to the client. + +As part of this architectural update, the built-in `MaskErrors` extension has been updated to use this new hook, ensuring that sensitive exceptions are now correctly masked during WebSocket subscriptions. diff --git a/strawberry/extensions/base_extension.py b/strawberry/extensions/base_extension.py index edf3535d55..443a25883a 100644 --- a/strawberry/extensions/base_extension.py +++ b/strawberry/extensions/base_extension.py @@ -4,6 +4,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any +from strawberry.types import ExecutionResult from strawberry.utils.await_maybe import AsyncIteratorOrIterator, AwaitableOrValue if TYPE_CHECKING: @@ -51,6 +52,15 @@ def on_execute( # type: ignore """Called before and after the execution step.""" yield None + def on_subscription_result( # type: ignore + self, result: ExecutionResult + ) -> AsyncIteratorOrIterator[None]: # pragma: no cover + """Called before and after each event/result yielded by a GraphQL subscription. + + Extensions can mutate the `result` object directly (e.g., masking errors). + """ + yield None + def resolve( self, _next: Callable, @@ -70,13 +80,17 @@ def _implements_resolve(cls) -> bool: return cls.resolve is not SchemaExtension.resolve -Hook = Callable[[SchemaExtension], AsyncIteratorOrIterator[None]] +Hook = ( + Callable[[SchemaExtension], AsyncIteratorOrIterator[None]] + | Callable[[SchemaExtension, ExecutionResult], AsyncIteratorOrIterator[None]] +) HOOK_METHODS: set[str] = { SchemaExtension.on_operation.__name__, SchemaExtension.on_validate.__name__, SchemaExtension.on_parse.__name__, SchemaExtension.on_execute.__name__, + SchemaExtension.on_subscription_result.__name__, } __all__ = ["HOOK_METHODS", "Hook", "LifecycleStep", "SchemaExtension"] diff --git a/strawberry/extensions/context.py b/strawberry/extensions/context.py index a8da3c0a0d..409eac4eec 100644 --- a/strawberry/extensions/context.py +++ b/strawberry/extensions/context.py @@ -17,6 +17,7 @@ from types import TracebackType from strawberry.extensions.base_extension import Hook + from strawberry.types import ExecutionResult from strawberry.utils.await_maybe import AwaitableOrValue @@ -155,3 +156,56 @@ class ParsingContextManager(ExtensionContextManagerBase): class ExecutingContextManager(ExtensionContextManagerBase): HOOK_NAME = SchemaExtension.on_execute.__name__ + + +class SubscriptionResultContextManager(ExtensionContextManagerBase): + HOOK_NAME = SchemaExtension.on_subscription_result.__name__ + + def __init__( + self, extensions: list[SchemaExtension], result: ExecutionResult + ) -> None: + self.result = result + super().__init__(extensions) + + def from_callable( # type: ignore[override] + self, + extension: SchemaExtension, + func: Callable[..., AwaitableOrValue[Any]], + ) -> WrappedHook: + if iscoroutinefunction(func): + + @contextlib.asynccontextmanager + async def iterator(result: ExecutionResult) -> AsyncIterator[None]: + await func(extension, result) + yield + + return WrappedHook(extension=extension, hook=iterator, is_async=True) + + @contextlib.contextmanager + def iterator(result: ExecutionResult) -> Iterator[None]: + func(extension, result) + yield + + return WrappedHook(extension=extension, hook=iterator, is_async=False) + + def __enter__(self) -> None: + self.exit_stack = contextlib.ExitStack() + self.exit_stack.__enter__() + + for hook in self.hooks: + if hook.is_async: + raise RuntimeError( + f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} " + "failed to complete synchronously." + ) + self.exit_stack.enter_context(hook.hook(self.result)) # type: ignore + + async def __aenter__(self) -> None: + self.async_exit_stack = contextlib.AsyncExitStack() + await self.async_exit_stack.__aenter__() + + for hook in self.hooks: + if hook.is_async: + await self.async_exit_stack.enter_async_context(hook.hook(self.result)) # type: ignore + else: + self.async_exit_stack.enter_context(hook.hook(self.result)) # type: ignore diff --git a/strawberry/extensions/mask_errors.py b/strawberry/extensions/mask_errors.py index 481f515979..1dbe82417a 100644 --- a/strawberry/extensions/mask_errors.py +++ b/strawberry/extensions/mask_errors.py @@ -6,6 +6,7 @@ from strawberry.extensions.base_extension import SchemaExtension from strawberry.types.execution import ExecutionResult as StrawberryExecutionResult +from strawberry.types.graphql import OperationType def default_should_mask_error(_: GraphQLError) -> bool: @@ -53,9 +54,26 @@ def _process_result(self, result: Any) -> None: def on_operation(self) -> Iterator[None]: yield + # Subscriptions are handled event-by-event in on_subscription_result + try: + if self.execution_context.operation_type == OperationType.SUBSCRIPTION: + return + except RuntimeError: + # If the query fails to parse early on, operation_type throws a RuntimeError. + # We must PASS here to ensure standard Queries/Mutations get masked, + # even if it means a malformed Subscription gets double-masked. + pass + result = self.execution_context.result if isinstance(result, (GraphQLExecutionResult, StrawberryExecutionResult)): self._process_result(result) elif result: self._process_result(result.initial_result) + + def on_subscription_result( + self, result: StrawberryExecutionResult + ) -> Iterator[None]: + """Mask errors on streaming subscription results.""" + self._process_result(result) + yield None diff --git a/strawberry/extensions/runner.py b/strawberry/extensions/runner.py index fbe3b40cf9..963ed6b165 100644 --- a/strawberry/extensions/runner.py +++ b/strawberry/extensions/runner.py @@ -7,12 +7,13 @@ ExecutingContextManager, OperationContextManager, ParsingContextManager, + SubscriptionResultContextManager, ValidationContextManager, ) from strawberry.utils.await_maybe import await_maybe if TYPE_CHECKING: - from strawberry.types import ExecutionContext + from strawberry.types import ExecutionContext, ExecutionResult from . import SchemaExtension @@ -40,6 +41,11 @@ def parsing(self) -> ParsingContextManager: def executing(self) -> ExecutingContextManager: return ExecutingContextManager(self.extensions) + def on_subscription_result( + self, result: ExecutionResult + ) -> SubscriptionResultContextManager: + return SubscriptionResultContextManager(self.extensions, result) + def get_extensions_results_sync(self) -> dict[str, Any]: data: dict[str, Any] = {} for extension in self.extensions: diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 3cdfea1538..2a65d993c7 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -833,9 +833,12 @@ async def _subscribe( initial_error.extensions = ( await extensions_runner.get_extensions_results(execution_context) ) - yield await self._handle_execution_result( + execution_result = await self._handle_execution_result( execution_context, initial_error, extensions_runner ) + async with extensions_runner.on_subscription_result(execution_result): + yield execution_result + return # do not fall through to subscribe() after a pre-execution error try: async with extensions_runner.executing(): assert execution_context.graphql_document is not None @@ -866,39 +869,54 @@ async def _subscribe( # Handle pre-execution errors. if isinstance(aiter_or_result, OriginalExecutionResult): - yield await self._handle_execution_result( + execution_result = await self._handle_execution_result( execution_context, PreExecutionError(data=None, errors=aiter_or_result.errors), extensions_runner, ) + async with extensions_runner.on_subscription_result( + execution_result + ): + yield execution_result else: try: async with aclosing(aiter_or_result): async for result in aiter_or_result: - yield await self._handle_execution_result( + extension_result = await self._handle_execution_result( execution_context, result, extensions_runner, ) + + async with extensions_runner.on_subscription_result( + extension_result + ): + yield extension_result # graphql-core doesn't handle exceptions raised while executing. except Exception as exc: # noqa: BLE001 - yield await self._handle_execution_result( + execution_result = await self._handle_execution_result( execution_context, OriginalExecutionResult( data=None, errors=[_coerce_error(exc)] ), extensions_runner, ) + async with extensions_runner.on_subscription_result( + execution_result + ): + yield execution_result # catch exceptions raised in `on_execute` hook. except Exception as exc: # noqa: BLE001 origin_result = OriginalExecutionResult( data=None, errors=[_coerce_error(exc)] ) - yield await self._handle_execution_result( + execution_result = await self._handle_execution_result( execution_context, origin_result, extensions_runner, ) + async with extensions_runner.on_subscription_result(execution_result): + yield execution_result async def subscribe( self, diff --git a/tests/extensions/test_subscription_hook.py b/tests/extensions/test_subscription_hook.py new file mode 100644 index 0000000000..85f60973b2 --- /dev/null +++ b/tests/extensions/test_subscription_hook.py @@ -0,0 +1,191 @@ +import asyncio +from collections.abc import AsyncGenerator, AsyncIterator, Iterator +from typing import Any + +import pytest + +import strawberry +from strawberry.extensions import SchemaExtension +from strawberry.extensions.mask_errors import MaskErrors +from strawberry.types import ExecutionResult + + +# Dummy extension that uses the new hook +class StreamModifierExtension(SchemaExtension): + def on_subscription_result(self, result: ExecutionResult) -> Iterator[None]: + if result.data and "count" in result.data: + # Mutate the outgoing data stream + result.data["count"] = f"Modified: {result.data['count']}" + yield None + + +# Create a basic schema with a subscription +@strawberry.type +class Query: + hello: str = "world" + + +@strawberry.type +class Subscription: + @strawberry.subscription + async def count(self) -> AsyncGenerator[int, None]: + yield 1 + yield 2 + + @strawberry.subscription + async def dangerous_stream(self) -> AsyncGenerator[int, None]: + yield 1 + raise ValueError("Secret database credentials leaked!") + + +# Tests +@pytest.mark.asyncio +async def test_extension_modifies_subscription_stream(): + schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[StreamModifierExtension] + ) + + query = "subscription { count }" + + sub_generator = await schema.subscribe(query) + + # Get all yielded results + results = [result async for result in sub_generator] + + assert len(results) == 2 + + # Assert that the extension successfully intercepted and modified the stream + assert results[0].data["count"] == "Modified: 1" + assert results[1].data["count"] == "Modified: 2" + + +@pytest.mark.asyncio +async def test_mask_errors_scrubs_subscription_exceptions(): + # Initialize schema with the MaskErrors extension + schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[MaskErrors()] + ) + + query = "subscription { dangerousStream }" + + sub_generator = await schema.subscribe(query) + + # Get all yielded results + results = [result async for result in sub_generator] + + # We expect 2 results: the successful yield, and the error + assert len(results) == 2 + + # Assert the first yield worked normally + assert results[0].data["dangerousStream"] == 1 + assert not results[0].errors + + # Assert the error was caught and MASKED + assert results[1].data is None + assert len(results[1].errors) == 1 + + # The crucial check: The raw exception message MUST NOT be exposed + error_message = results[1].errors[0].message + assert error_message == "Unexpected error." + assert "Secret database credentials" not in error_message + + +class AsyncStreamModifierExtension(SchemaExtension): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.side_effect_ran = False # Attached to the instance + + async def _side_effect(self) -> None: + # Simulate an async dependency / side effect + await asyncio.sleep(0) + self.side_effect_ran = True + + async def on_subscription_result( + self, result: ExecutionResult + ) -> AsyncIterator[None]: + # This should be awaited by extensions_runner.on_subscription_result + await self._side_effect() + + if result.data and "count" in result.data: + # Mutate the outgoing data stream after the async side effect + result.data["count"] = f"Modified: {result.data['count']}" + yield None + + +@pytest.mark.asyncio +async def test_async_on_subscription_result_is_awaited() -> None: + extension = AsyncStreamModifierExtension() + + schema = strawberry.Schema( + query=Query, + subscription=Subscription, + extensions=[extension], + ) + + query = "subscription { count }" + sub_generator = await schema.subscribe(query) + + # Consume first result from the async iterator + first_result = await sub_generator.__anext__() + + assert first_result.errors is None + assert first_result.data == {"count": "Modified: 1"} + assert extension.side_effect_ran is True + + +@pytest.mark.asyncio +async def test_mask_errors_scrubs_pre_execution_errors(): + # Initialize schema with MaskErrors + schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[MaskErrors()] + ) + + # Querying a field that doesn't exist triggers Validation errors BEFORE execution + query = "subscription { fieldThatDoesNotExist }" + + # Run the subscription + sub_generator = await schema.subscribe(query) + + # Exhaust the generator + results = [result async for result in sub_generator] + + # Pre-execution errors immediately yield exactly 1 result containing the error + assert len(results) == 1 + + for result in results: + assert result.data is None + assert len(result.errors) == 1 + + # The crucial check: MaskErrors successfully intercepted and masked it! + error_message = result.errors[0].message + assert error_message == "Unexpected error." + assert "fieldThatDoesNotExist" not in error_message + + +@pytest.mark.asyncio +async def test_mask_errors_scrubs_subscription_parse_errors(): + # Initialize schema with MaskErrors + schema = strawberry.Schema( + query=Query, subscription=Subscription, extensions=[MaskErrors()] + ) + + # A syntactically invalid query (missing closing brace) triggers a Parse error BEFORE execution + query = "subscription { count" + + # Run the subscription + sub_generator = await schema.subscribe(query) + + # Exhaust the generator + results = [result async for result in sub_generator] + + # Parse errors immediately yield exactly 1 result containing the error + assert len(results) == 1 + + for result in results: + assert result.data is None + assert len(result.errors) == 1 + + # The crucial check: MaskErrors successfully intercepted and masked it! + error_message = result.errors[0].message + assert error_message == "Unexpected error." + assert "Syntax Error" not in error_message