-
-
Notifications
You must be signed in to change notification settings - Fork 627
Feat #4329: Add on_subscription_result hook to SchemaExtension
#4330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
c77c272
d511b9d
c2ffc67
d61feef
0faacb2
173eb92
1756b1f
afe84b6
29754da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| Release type: minor | ||
|
|
||
| Adds a new `on_subscription_result` hook to `SchemaExtension` that allows extensions to interact with and mutate the stream of events yielded by GraphQL subscriptions. | ||
|
|
||
| Previously, extensions were only triggered during the initial setup phase of a subscription, meaning transport layers (like WebSockets) bypassed them during the actual data streaming phase. This new hook solves this by executing right before each result is yielded to the client. | ||
|
|
||
| As part of this architectural update, the built-in `MaskErrors` extension has been updated to use this new hook, ensuring that sensitive exceptions are now correctly masked during WebSocket subscriptions. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
| if TYPE_CHECKING: | ||
| from graphql import GraphQLResolveInfo | ||
|
|
||
| from strawberry.types import ExecutionContext | ||
| from strawberry.types import ExecutionContext, ExecutionResult | ||
|
|
||
|
|
||
| class LifecycleStep(Enum): | ||
|
|
@@ -51,6 +51,14 @@ def on_execute( # type: ignore | |
| """Called before and after the execution step.""" | ||
| yield None | ||
|
|
||
| def on_subscription_result( | ||
| self, result: ExecutionResult | ||
| ) -> None | AwaitableOrValue[None]: | ||
| """Called exactly once for every event/result yielded by a GraphQL subscription. | ||
|
|
||
| Extensions can mutate the `result` object directly (e.g., masking errors). | ||
| """ | ||
|
||
|
|
||
| def resolve( | ||
| self, | ||
| _next: Callable, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -59,3 +59,7 @@ def on_operation(self) -> Iterator[None]: | |||||
| self._process_result(result) | ||||||
| elif result: | ||||||
| self._process_result(result.initial_result) | ||||||
|
|
||||||
| def on_subscription_result(self, result: StrawberryExecutionResult) -> None: | ||||||
|
||||||
| def on_subscription_result(self, result: StrawberryExecutionResult) -> None: | |
| def on_subscription_result(self, result: GraphQLExecutionResult) -> None: |
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -833,9 +833,11 @@ async def _subscribe( | |||||||
| initial_error.extensions = ( | ||||||||
| await extensions_runner.get_extensions_results(execution_context) | ||||||||
| ) | ||||||||
| yield await self._handle_execution_result( | ||||||||
| execution_result = await self._handle_execution_result( | ||||||||
| execution_context, initial_error, extensions_runner | ||||||||
| ) | ||||||||
| await extensions_runner.on_subscription_result(execution_result) | ||||||||
| yield execution_result | ||||||||
|
||||||||
| yield execution_result | |
| yield execution_result | |
| return |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,157 @@ | ||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||
| from collections.abc import AsyncGenerator | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import strawberry | ||||||||||||||||||||||||||||
| from strawberry.extensions import SchemaExtension | ||||||||||||||||||||||||||||
| from strawberry.extensions.mask_errors import MaskErrors | ||||||||||||||||||||||||||||
| from strawberry.types import ExecutionResult | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Dummy extension that uses the new hook | ||||||||||||||||||||||||||||
| class StreamModifierExtension(SchemaExtension): | ||||||||||||||||||||||||||||
| def on_subscription_result(self, result: ExecutionResult) -> None: | ||||||||||||||||||||||||||||
| if result.data and "count" in result.data: | ||||||||||||||||||||||||||||
| # Mutate the outgoing data stream | ||||||||||||||||||||||||||||
| result.data["count"] = f"Modified: {result.data['count']}" | ||||||||||||||||||||||||||||
|
Comment on lines
+14
to
+18
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Consider adding a test that covers an async Current tests only cover the synchronous implementation. Since Suggested implementation: import asyncio
import pytest
import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.mask_errors import MaskErrors
from strawberry.types import ExecutionResult
# Dummy extension that uses the new hook
class StreamModifierExtension(SchemaExtension):
def on_subscription_result(self, result: ExecutionResult) -> None:
if result.data and "count" in result.data:
# Mutate the outgoing data stream
result.data["count"] = f"Modified: {result.data['count']}"
class AsyncStreamModifierExtension(SchemaExtension):
side_effect_ran = False
async def _side_effect(self) -> None:
# Simulate an async dependency / side effect
await asyncio.sleep(0)
AsyncStreamModifierExtension.side_effect_ran = True
async def on_subscription_result(self, result: ExecutionResult) -> None:
# This should be awaited by extensions_runner.on_subscription_result
await self._side_effect()
if result.data and "count" in result.data:
# Mutate the outgoing data stream after the async side effect
result.data["count"] = f"Modified: {result.data['count']}"
@strawberry.type
class Query:
# Minimal Query type; field is unused but required by Strawberry
example: str = "example"
@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> int:
for i in range(3):
yield i
@pytest.mark.asyncio
async def test_async_on_subscription_result_is_awaited() -> None:
# Reset class-level flag before running the subscription
AsyncStreamModifierExtension.side_effect_ran = False
schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[AsyncStreamModifierExtension],
)
results = schema.subscribe("subscription { count }")
# Consume first result from the async iterator
first_result = await results.__anext__()
assert first_result.errors is None
assert first_result.data == {"count": "Modified: 0"}
assert AsyncStreamModifierExtension.side_effect_ran is TrueIf this file already defines
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Create a basic schema with a subscription | ||||||||||||||||||||||||||||
| @strawberry.type | ||||||||||||||||||||||||||||
| class Query: | ||||||||||||||||||||||||||||
| hello: str = "world" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @strawberry.type | ||||||||||||||||||||||||||||
| class Subscription: | ||||||||||||||||||||||||||||
| @strawberry.subscription | ||||||||||||||||||||||||||||
| async def count(self) -> AsyncGenerator[int, None]: | ||||||||||||||||||||||||||||
| yield 1 | ||||||||||||||||||||||||||||
| yield 2 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @strawberry.subscription | ||||||||||||||||||||||||||||
| async def dangerous_stream(self) -> AsyncGenerator[int, None]: | ||||||||||||||||||||||||||||
| yield 1 | ||||||||||||||||||||||||||||
| raise ValueError("Secret database credentials leaked!") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Tests | ||||||||||||||||||||||||||||
| @pytest.mark.asyncio | ||||||||||||||||||||||||||||
| async def test_extension_modifies_subscription_stream(): | ||||||||||||||||||||||||||||
| schema = strawberry.Schema( | ||||||||||||||||||||||||||||
| query=Query, subscription=Subscription, extensions=[StreamModifierExtension] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| query = "subscription { count }" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| sub_generator = await schema.subscribe(query) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Get all yielded results | ||||||||||||||||||||||||||||
| results = [result async for result in sub_generator] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert len(results) == 2 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Assert that the extension successfully intercepted and modified the stream | ||||||||||||||||||||||||||||
| assert results[0].data["count"] == "Modified: 1" | ||||||||||||||||||||||||||||
| assert results[1].data["count"] == "Modified: 2" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pytest.mark.asyncio | ||||||||||||||||||||||||||||
| async def test_mask_errors_scrubs_subscription_exceptions(): | ||||||||||||||||||||||||||||
| # Initialize schema with the MaskErrors extension | ||||||||||||||||||||||||||||
| schema = strawberry.Schema( | ||||||||||||||||||||||||||||
| query=Query, subscription=Subscription, extensions=[MaskErrors()] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| query = "subscription { dangerousStream }" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| sub_generator = await schema.subscribe(query) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Get all yielded results | ||||||||||||||||||||||||||||
| results = [result async for result in sub_generator] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # We expect 2 results: the successful yield, and the error | ||||||||||||||||||||||||||||
| assert len(results) == 2 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Assert the first yield worked normally | ||||||||||||||||||||||||||||
| assert results[0].data["dangerousStream"] == 1 | ||||||||||||||||||||||||||||
| assert not results[0].errors | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Assert the error was caught and MASKED | ||||||||||||||||||||||||||||
| assert results[1].data is None | ||||||||||||||||||||||||||||
| assert len(results[1].errors) == 1 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # The crucial check: The raw exception message MUST NOT be exposed | ||||||||||||||||||||||||||||
| error_message = results[1].errors[0].message | ||||||||||||||||||||||||||||
| assert error_message == "Unexpected error." | ||||||||||||||||||||||||||||
| assert "Secret database credentials" not in error_message | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class AsyncStreamModifierExtension(SchemaExtension): | ||||||||||||||||||||||||||||
| side_effect_ran = False | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| async def _side_effect(self) -> None: | ||||||||||||||||||||||||||||
| # Simulate an async dependency / side effect | ||||||||||||||||||||||||||||
| await asyncio.sleep(0) | ||||||||||||||||||||||||||||
| AsyncStreamModifierExtension.side_effect_ran = True | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| async def on_subscription_result(self, result: ExecutionResult) -> None: | ||||||||||||||||||||||||||||
| # This should be awaited by extensions_runner.on_subscription_result | ||||||||||||||||||||||||||||
| await self._side_effect() | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if result.data and "count" in result.data: | ||||||||||||||||||||||||||||
| # Mutate the outgoing data stream after the async side effect | ||||||||||||||||||||||||||||
| result.data["count"] = f"Modified: {result.data['count']}" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pytest.mark.asyncio | ||||||||||||||||||||||||||||
| async def test_async_on_subscription_result_is_awaited() -> None: | ||||||||||||||||||||||||||||
| # Reset class-level flag before running the subscription | ||||||||||||||||||||||||||||
| AsyncStreamModifierExtension.side_effect_ran = False | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| schema = strawberry.Schema( | ||||||||||||||||||||||||||||
| query=Query, | ||||||||||||||||||||||||||||
| subscription=Subscription, | ||||||||||||||||||||||||||||
| extensions=[AsyncStreamModifierExtension()], | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| query = "subscription { count }" | ||||||||||||||||||||||||||||
| sub_generator = await schema.subscribe(query) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Consume first result from the async iterator | ||||||||||||||||||||||||||||
| first_result = await sub_generator.__anext__() | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
Comment on lines
+125
to
+130
|
||||||||||||||||||||||||||||
| assert first_result.errors is None | ||||||||||||||||||||||||||||
| assert first_result.data == {"count": "Modified: 1"} | ||||||||||||||||||||||||||||
| assert AsyncStreamModifierExtension.side_effect_ran is True | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pytest.mark.asyncio | ||||||||||||||||||||||||||||
| async def test_mask_errors_scrubs_pre_execution_errors(): | ||||||||||||||||||||||||||||
| # Initialize schema with MaskErrors | ||||||||||||||||||||||||||||
| schema = strawberry.Schema( | ||||||||||||||||||||||||||||
| query=Query, subscription=Subscription, extensions=[MaskErrors()] | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Querying a field that doesn't exist triggers Validation errors BEFORE execution | ||||||||||||||||||||||||||||
| query = "subscription { fieldThatDoesNotExist }" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Run the subscription | ||||||||||||||||||||||||||||
| sub_generator = await schema.subscribe(query) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Exhaust the generator | ||||||||||||||||||||||||||||
| results = [result async for result in sub_generator] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Strawberry yields 1 or more errors depending on the validation layers it hits. | ||||||||||||||||||||||||||||
| # We must ensure that EVERY result yielded was successfully intercepted and masked. | ||||||||||||||||||||||||||||
| assert len(results) > 0 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for result in results: | ||||||||||||||||||||||||||||
| assert result.data is None | ||||||||||||||||||||||||||||
| assert len(result.errors) == 1 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # The crucial check: MaskErrors successfully intercepted and masked it! | ||||||||||||||||||||||||||||
| error_message = result.errors[0].message | ||||||||||||||||||||||||||||
| assert error_message == "Unexpected error." | ||||||||||||||||||||||||||||
| assert "fieldThatDoesNotExist" not in error_message | ||||||||||||||||||||||||||||
|
Comment on lines
+157
to
+162
|
||||||||||||||||||||||||||||
| assert len(result.errors) == 1 | |
| # The crucial check: MaskErrors successfully intercepted and masked it! | |
| error_message = result.errors[0].message | |
| assert error_message == "Unexpected error." | |
| assert "fieldThatDoesNotExist" not in error_message | |
| assert result.errors | |
| # The crucial check: MaskErrors successfully intercepted and masked every error! | |
| for error in result.errors: | |
| error_message = error.message | |
| assert error_message == "Unexpected error." | |
| assert "fieldThatDoesNotExist" not in error_message |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should follow the same patterns as the others here, using
yield None, which would allow extensions to do something before and afterWdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi.
I agree. I think it makes sense to follow the patterns as the other methods, and this change allows for more flexibility for the extensions.
I have just made a new commit and pushed the updates to implement this. Thank you for your review.