Feat #4329: Add on_subscription_result hook to SchemaExtension#4330
Feat #4329: Add on_subscription_result hook to SchemaExtension#4330Ladol wants to merge 9 commits intostrawberry-graphql:mainfrom
on_subscription_result hook to SchemaExtension#4330Conversation
…hemaExtension Fixes strawberry-graphql#4329 Fixes strawberry-graphql#3680 Previously, `SchemaExtension` hooks only wrapped the initial setup phase of a GraphQL subscription, leaving extensions completely disconnected from the actual stream of yielded events. This commit introduces the `on_subscription_result` hook to the base `SchemaExtension` class and triggers it inside the `schema._subscribe` generator. This allows extensions to safely mutate streamed data before it reaches the transport layer. Additionally, the `MaskErrors` extension has been updated to use this new hook, fixing an issue where sensitive errors were leaking unmasked over WebSocket connections.
Reviewer's GuideAdds a new File-Level Changes
Possibly linked issues
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
Thanks for adding the Here's a preview of the changelog: Adds a new Previously, extensions were only triggered during the initial setup phase of a subscription, meaning transport layers (like WebSockets) bypassed them during the actual data streaming phase. This new hook solves this by executing right before each result is yielded to the client. As part of this architectural update, the built-in Here's the tweet text: |
There was a problem hiding this comment.
Hey - I've found 1 issue, and left some high level feedback:
- In
ExtensionsRunner.on_subscription_result, thegetattr(extension, "on_subscription_result", None)check is redundant becauseSchemaExtensionalready defines this method; you can simplify the loop to callextension.on_subscription_result(result)directly and rely onawait_maybeto handle sync vs async implementations.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `ExtensionsRunner.on_subscription_result`, the `getattr(extension, "on_subscription_result", None)` check is redundant because `SchemaExtension` already defines this method; you can simplify the loop to call `extension.on_subscription_result(result)` directly and rely on `await_maybe` to handle sync vs async implementations.
## Individual Comments
### Comment 1
<location path="tests/extensions/test_subscription_hook.py" line_range="12-16" />
<code_context>
+
+
+# Dummy extension that uses the new hook
+class StreamModifierExtension(SchemaExtension):
+ def on_subscription_result(self, result: ExecutionResult) -> None:
+ if result.data and "count" in result.data:
+ # Mutate the outgoing data stream
+ result.data["count"] = f"Modified: {result.data['count']}"
+
+
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding a test that covers an async `on_subscription_result` hook implementation.
Current tests only cover the synchronous implementation. Since `extensions_runner.on_subscription_result` uses `await_maybe`, please add a test with an extension whose `on_subscription_result` is `async def` and awaits a side effect before mutating the result, to confirm async hooks are correctly awaited and safely integrate with async dependencies.
Suggested implementation:
```python
import asyncio
import pytest
import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.mask_errors import MaskErrors
from strawberry.types import ExecutionResult
# Dummy extension that uses the new hook
class StreamModifierExtension(SchemaExtension):
def on_subscription_result(self, result: ExecutionResult) -> None:
if result.data and "count" in result.data:
# Mutate the outgoing data stream
result.data["count"] = f"Modified: {result.data['count']}"
class AsyncStreamModifierExtension(SchemaExtension):
side_effect_ran = False
async def _side_effect(self) -> None:
# Simulate an async dependency / side effect
await asyncio.sleep(0)
AsyncStreamModifierExtension.side_effect_ran = True
async def on_subscription_result(self, result: ExecutionResult) -> None:
# This should be awaited by extensions_runner.on_subscription_result
await self._side_effect()
if result.data and "count" in result.data:
# Mutate the outgoing data stream after the async side effect
result.data["count"] = f"Modified: {result.data['count']}"
@strawberry.type
class Query:
# Minimal Query type; field is unused but required by Strawberry
example: str = "example"
@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> int:
for i in range(3):
yield i
@pytest.mark.asyncio
async def test_async_on_subscription_result_is_awaited() -> None:
# Reset class-level flag before running the subscription
AsyncStreamModifierExtension.side_effect_ran = False
schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[AsyncStreamModifierExtension],
)
results = schema.subscribe("subscription { count }")
# Consume first result from the async iterator
first_result = await results.__anext__()
assert first_result.errors is None
assert first_result.data == {"count": "Modified: 0"}
assert AsyncStreamModifierExtension.side_effect_ran is True
```
If this file already defines `Query` / `Subscription` types or other subscription tests, you may want to:
1. Reuse existing `Query`/`Subscription` instead of defining new ones here, to keep the test suite DRY.
2. Adjust the subscription field name or expected payload in the test to match any shared schema definitions (e.g., if an existing subscription already yields a `count` field).
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| class StreamModifierExtension(SchemaExtension): | ||
| def on_subscription_result(self, result: ExecutionResult) -> None: | ||
| if result.data and "count" in result.data: | ||
| # Mutate the outgoing data stream | ||
| result.data["count"] = f"Modified: {result.data['count']}" |
There was a problem hiding this comment.
suggestion (testing): Consider adding a test that covers an async on_subscription_result hook implementation.
Current tests only cover the synchronous implementation. Since extensions_runner.on_subscription_result uses await_maybe, please add a test with an extension whose on_subscription_result is async def and awaits a side effect before mutating the result, to confirm async hooks are correctly awaited and safely integrate with async dependencies.
Suggested implementation:
import asyncio
import pytest
import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.mask_errors import MaskErrors
from strawberry.types import ExecutionResult
# Dummy extension that uses the new hook
class StreamModifierExtension(SchemaExtension):
def on_subscription_result(self, result: ExecutionResult) -> None:
if result.data and "count" in result.data:
# Mutate the outgoing data stream
result.data["count"] = f"Modified: {result.data['count']}"
class AsyncStreamModifierExtension(SchemaExtension):
side_effect_ran = False
async def _side_effect(self) -> None:
# Simulate an async dependency / side effect
await asyncio.sleep(0)
AsyncStreamModifierExtension.side_effect_ran = True
async def on_subscription_result(self, result: ExecutionResult) -> None:
# This should be awaited by extensions_runner.on_subscription_result
await self._side_effect()
if result.data and "count" in result.data:
# Mutate the outgoing data stream after the async side effect
result.data["count"] = f"Modified: {result.data['count']}"
@strawberry.type
class Query:
# Minimal Query type; field is unused but required by Strawberry
example: str = "example"
@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> int:
for i in range(3):
yield i
@pytest.mark.asyncio
async def test_async_on_subscription_result_is_awaited() -> None:
# Reset class-level flag before running the subscription
AsyncStreamModifierExtension.side_effect_ran = False
schema = strawberry.Schema(
query=Query,
subscription=Subscription,
extensions=[AsyncStreamModifierExtension],
)
results = schema.subscribe("subscription { count }")
# Consume first result from the async iterator
first_result = await results.__anext__()
assert first_result.errors is None
assert first_result.data == {"count": "Modified: 0"}
assert AsyncStreamModifierExtension.side_effect_ran is TrueIf this file already defines Query / Subscription types or other subscription tests, you may want to:
- Reuse existing
Query/Subscriptioninstead of defining new ones here, to keep the test suite DRY. - Adjust the subscription field name or expected payload in the test to match any shared schema definitions (e.g., if an existing subscription already yields a
countfield).
Greptile SummaryThis PR introduces the The architecture is sound:
Confidence Score: 4/5Safe to merge after fixing the MaskErrors parse-error regression in on_operation. The new hook and all its plumbing are correct and well-tested for subscription paths. One P1 regression in MaskErrors.on_operation would cause query and mutation parse errors to be returned unmasked to callers relying on MaskErrors for error sanitisation. strawberry/extensions/mask_errors.py lines 61-65 — the Important Files Changed
Sequence DiagramsequenceDiagram
participant Client
participant Schema
participant ExtRunner as ExtensionsRunner
participant MaskErrors
participant GQLCore as graphql-core
Client->>Schema: subscribe(query, ...)
Schema->>Schema: _create_execution_context()
Schema->>ExtRunner: create_extensions_runner(ctx, extensions)
Schema->>Schema: _subscribe(ctx, extensions_runner, ...)
activate Schema
ExtRunner->>MaskErrors: on_operation() [enter]
Schema->>Schema: _parse_and_validate_async(ctx, extensions_runner)
alt Parse or validation error
Schema-->>ExtRunner: initial_error (PreExecutionError)
Schema->>Schema: _handle_execution_result(ctx, initial_error)
ExtRunner->>MaskErrors: on_subscription_result(result) [enter]
MaskErrors->>MaskErrors: _process_result(result)
MaskErrors-->>ExtRunner: yield
Schema-->>Client: yield execution_result
Schema-->>Schema: return
else Successful parse and validation
Schema->>GQLCore: subscribe(schema, document, ...)
GQLCore-->>Schema: async iterator
loop For each streamed result
GQLCore-->>Schema: GraphQLExecutionResult
Schema->>Schema: _handle_execution_result(ctx, result)
ExtRunner->>MaskErrors: on_subscription_result(extension_result) [enter]
MaskErrors->>MaskErrors: _process_result(result)
MaskErrors-->>ExtRunner: yield
Schema-->>Client: yield extension_result
end
end
ExtRunner->>MaskErrors: on_operation() [exit]
note over ExtRunner,MaskErrors: on_operation skips masking for subscriptions
deactivate Schema
Reviews (8): Last reviewed commit: "Fix: Return on RuntimeError in MaskError..." | Re-trigger Greptile |
strawberry/extensions/runner.py
Outdated
| async def on_subscription_result(self, result: ExecutionResult) -> None: | ||
| """Run the subscription result hook across all active extensions.""" | ||
| for extension in self.extensions: | ||
| # Check if the extension implemented the new hook | ||
| hook = getattr(extension, "on_subscription_result", None) | ||
| if hook: | ||
| await await_maybe(hook(result)) |
There was a problem hiding this comment.
getattr guard is always truthy
The comment says "Check if the extension implemented the new hook", but since SchemaExtension now defines on_subscription_result with a default no-op body, every extension will always carry this attribute. The if hook: branch is therefore always taken, making the check misleading without being harmful (the base-class call simply returns None which await_maybe handles fine).
Consider simplifying to remove the dead guard:
| async def on_subscription_result(self, result: ExecutionResult) -> None: | |
| """Run the subscription result hook across all active extensions.""" | |
| for extension in self.extensions: | |
| # Check if the extension implemented the new hook | |
| hook = getattr(extension, "on_subscription_result", None) | |
| if hook: | |
| await await_maybe(hook(result)) | |
| async def on_subscription_result(self, result: ExecutionResult) -> None: | |
| """Run the subscription result hook across all active extensions.""" | |
| for extension in self.extensions: | |
| await await_maybe(extension.on_subscription_result(result)) |
|
I've pushed updates addressing the feedback from the bots automatic reviews. The bot suggested adding the new hook to the HOOK_METHODS set in base_extension.py, but when this hook was there it recommended removing it, so I'm unsure of what the best practice would be. Looking forward for a review. |
| """Called before and after the execution step.""" | ||
| yield None | ||
|
|
||
| def on_subscription_result( |
There was a problem hiding this comment.
I think we should follow the same patterns as the others here, using yield None, which would allow extensions to do something before and after
Wdyt?
There was a problem hiding this comment.
Hi.
I agree. I think it makes sense to follow the patterns as the other methods, and this change allows for more flexibility for the extensions.
I have just made a new commit and pushed the updates to implement this. Thank you for your review.
There was a problem hiding this comment.
Pull request overview
Adds a new subscription lifecycle hook so SchemaExtension implementations can inspect/mutate each streamed subscription ExecutionResult, fixing MaskErrors not masking errors yielded during subscription streaming (e.g., over WebSockets).
Changes:
- Introduces
SchemaExtension.on_subscription_result(result)(sync or async) and wires it throughSchemaExtensionsRunner. - Invokes the new hook for each yielded subscription result inside
Schema._subscribe. - Updates
MaskErrorsto mask errors for streamed subscription results and adds targeted subscription tests + release note.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
strawberry/extensions/base_extension.py |
Adds the on_subscription_result hook to the base extension API. |
strawberry/extensions/runner.py |
Adds a runner method to invoke the hook across all active extensions (awaiting async hooks). |
strawberry/schema/schema.py |
Calls on_subscription_result for initial, pre-execution, per-event, and exception-derived subscription results. |
strawberry/extensions/mask_errors.py |
Uses the new hook to mask subscription-stream errors. |
tests/extensions/test_subscription_hook.py |
Adds tests for stream mutation, async hook awaiting, and MaskErrors masking on subscriptions. |
RELEASE.md |
Documents the new hook and the MaskErrors behavior change. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
strawberry/schema/schema.py
Outdated
| execution_context, initial_error, extensions_runner | ||
| ) | ||
| await extensions_runner.on_subscription_result(execution_result) | ||
| yield execution_result |
There was a problem hiding this comment.
When _parse_and_validate_async returns an initial_error, _subscribe yields it but then continues into the execution path. For parse errors this can hit the assert execution_context.graphql_document is not None and yield an additional (unrelated) error; for validation errors it can also lead to duplicate error results from subscribe(...). Consider returning immediately after yielding the initial error (or wrapping the rest of the function in an else) so subscription setup stops cleanly on pre-execution failures.
| yield execution_result | |
| yield execution_result | |
| return |
strawberry/extensions/mask_errors.py
Outdated
| elif result: | ||
| self._process_result(result.initial_result) | ||
|
|
||
| def on_subscription_result(self, result: StrawberryExecutionResult) -> None: |
There was a problem hiding this comment.
MaskErrors.on_subscription_result narrows the parameter type to StrawberryExecutionResult, but the base SchemaExtension.on_subscription_result is typed as taking ExecutionResult. This narrower override can trigger type-checking issues (it’s not substitutable if the runner passes a base ExecutionResult). Prefer matching the base signature (ExecutionResult) here.
| def on_subscription_result(self, result: StrawberryExecutionResult) -> None: | |
| def on_subscription_result(self, result: GraphQLExecutionResult) -> None: |
| assert len(result.errors) == 1 | ||
|
|
||
| # The crucial check: MaskErrors successfully intercepted and masked it! | ||
| error_message = result.errors[0].message | ||
| assert error_message == "Unexpected error." | ||
| assert "fieldThatDoesNotExist" not in error_message |
There was a problem hiding this comment.
This test asserts len(result.errors) == 1, but earlier comment notes Strawberry may yield 1+ validation errors. Since MaskErrors masks messages without changing the number of errors, this assertion will be flaky depending on validation output. Consider asserting that result.errors is non-empty and that all error messages are masked (iterate over result.errors) rather than enforcing a single error.
| assert len(result.errors) == 1 | |
| # The crucial check: MaskErrors successfully intercepted and masked it! | |
| error_message = result.errors[0].message | |
| assert error_message == "Unexpected error." | |
| assert "fieldThatDoesNotExist" not in error_message | |
| assert result.errors | |
| # The crucial check: MaskErrors successfully intercepted and masked every error! | |
| for error in result.errors: | |
| error_message = error.message | |
| assert error_message == "Unexpected error." | |
| assert "fieldThatDoesNotExist" not in error_message |
| query = "subscription { count }" | ||
| sub_generator = await schema.subscribe(query) | ||
|
|
||
| # Consume first result from the async iterator | ||
| first_result = await sub_generator.__anext__() | ||
|
|
There was a problem hiding this comment.
The test consumes only the first item from the subscription iterator and never closes it. Other subscription tests in the repo wrap schema.subscribe(...) in contextlib.aclosing(...) to ensure the async generator is closed even when not fully exhausted. Consider using aclosing (or explicitly calling await sub_generator.aclose()) to avoid leaking an open subscription iterator.
| def on_subscription_result( | ||
| self, result: ExecutionResult | ||
| ) -> None | AwaitableOrValue[None]: | ||
| """Called exactly once for every event/result yielded by a GraphQL subscription. | ||
|
|
||
| Extensions can mutate the `result` object directly (e.g., masking errors). | ||
| """ |
There was a problem hiding this comment.
The return type for on_subscription_result is currently None | AwaitableOrValue[None], but AwaitableOrValue[None] already includes None. Consider simplifying this to AwaitableOrValue[None] to avoid redundant typing and match the pattern used by other extension methods like resolve/get_results.
bellini666
left a comment
There was a problem hiding this comment.
This looks good to me. But would like @patrick91 's opinion here as well 🙏🏼

Description
Previously,
SchemaExtensionhooks only wrapped the initial setup phase of a GraphQL subscription, leaving extensions completely disconnected from the actual stream of yielded events.This commit introduces the
on_subscription_resulthook to the baseSchemaExtensionclass and triggers it inside theschema._subscribegenerator. This allows extensions to safely mutate streamed data before it reaches the transport layer.Additionally, the
MaskErrorsextension has been updated to use this new hook, fixing an issue where sensitive errors were leaking unmasked over WebSocket connections.Types of Changes
Issues Fixed or Closed by This PR
#3680
#4329
Checklist
Summary by Sourcery
Add a subscription-result lifecycle hook to schema extensions and ensure subscription errors are properly masked.
New Features:
on_subscription_resulthook onSchemaExtensionthat runs for every yielded subscription event, allowing extensions to inspect or mutate each result.Bug Fixes:
MaskErrorsextension masks exceptions emitted from subscription streams so sensitive error details are not leaked over WebSocket connections.Enhancements:
Documentation:
Tests:
MaskErrorscorrectly masking subscription errors.