diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..4b26b9ac05 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +Await cancelled subscription tasks during WebSocket shutdown so their `finally` blocks run before shared state (DB pools, event loop) is torn down. diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 402577a5b9..1f5dc327a1 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -67,6 +67,7 @@ def __init__( self.connection_timed_out = False self.operations: dict[str, Operation[Context, RootValue]] = {} self.completed_tasks: list[asyncio.Task] = [] + self.cancelled_tasks: list[asyncio.Task] = [] async def handle(self) -> None: self.on_request_accepted() @@ -94,7 +95,9 @@ async def shutdown(self) -> None: for operation_id in list(self.operations.keys()): await self.cleanup_operation(operation_id) + await self.reap_completed_tasks() + await self.reap_cancelled_tasks() def on_request_accepted(self) -> None: # handle_request should call this once it has sent the @@ -329,6 +332,7 @@ async def cleanup_operation(self, operation_id: str) -> None: operation = self.operations.pop(operation_id) assert operation.task operation.task.cancel() + self.cancelled_tasks.append(operation.task) # do not await the task here, lest we block the main # websocket handler Task. @@ -339,6 +343,13 @@ async def reap_completed_tasks(self) -> None: with suppress(BaseException): await task + async def reap_cancelled_tasks(self) -> None: + """Await tasks that have been cancelled.""" + tasks, self.cancelled_tasks = self.cancelled_tasks, [] + for task in tasks: + with suppress(BaseException): + await task + class Operation(Generic[Context, RootValue]): """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index d5a212d4d2..ed19d1658c 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -1216,3 +1216,54 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( assert not process_errors.called assert Subscription.active_infinity_subscriptions == 0 + + +async def test_shutdown_awaits_cancelled_subscription_tasks( + http_client: HttpClient, +): + with contextlib.suppress(ImportError): + from tests.http.clients.channels import ChannelsHttpClient + + if isinstance(http_client, ChannelsHttpClient): + pytest.skip("Can't patch on_init for this client") + + handler = None + cleanup_done_at_shutdown_end = None + + def on_init(_handler): + nonlocal handler + if handler: + return + handler = _handler + original_shutdown = _handler.shutdown + + async def tracked_shutdown(): + nonlocal cleanup_done_at_shutdown_end + await original_shutdown() + cleanup_done_at_shutdown_end = Subscription.active_infinity_subscriptions + + _handler.shutdown = tracked_shutdown + + with patch.object(DebuggableGraphQLTransportWSHandler, "on_init", on_init): + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.send_message({"type": "connection_init"}) + ack: ConnectionAckMessage = await ws.receive_json() + assert ack == {"type": "connection_ack"} + + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { infinity(message: "Hi") }'}, + } + ) + await ws.receive(timeout=2) + assert Subscription.active_infinity_subscriptions == 1 + + await ws.close() + + await asyncio.sleep(0.5) + assert handler is not None + assert cleanup_done_at_shutdown_end == 0