Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Await cancelled subscription tasks during WebSocket shutdown so their `finally` blocks run before shared state (DB pools, event loop) is torn down.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.connection_timed_out = False
self.operations: dict[str, Operation[Context, RootValue]] = {}
self.completed_tasks: list[asyncio.Task] = []
self.cancelled_tasks: list[asyncio.Task] = []

async def handle(self) -> None:
self.on_request_accepted()
Expand Down Expand Up @@ -94,7 +95,9 @@ async def shutdown(self) -> None:

for operation_id in list(self.operations.keys()):
await self.cleanup_operation(operation_id)

await self.reap_completed_tasks()
await self.reap_cancelled_tasks()

def on_request_accepted(self) -> None:
# handle_request should call this once it has sent the
Expand Down Expand Up @@ -329,6 +332,7 @@ async def cleanup_operation(self, operation_id: str) -> None:
operation = self.operations.pop(operation_id)
assert operation.task
operation.task.cancel()
self.cancelled_tasks.append(operation.task)
# do not await the task here, lest we block the main
# websocket handler Task.

Expand All @@ -339,6 +343,13 @@ async def reap_completed_tasks(self) -> None:
with suppress(BaseException):
await task

async def reap_cancelled_tasks(self) -> None:
"""Await tasks that have been cancelled."""
tasks, self.cancelled_tasks = self.cancelled_tasks, []
for task in tasks:
with suppress(BaseException):
await task


class Operation(Generic[Context, RootValue]):
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""
Expand Down
51 changes: 51 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,3 +1216,54 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0


async def test_shutdown_awaits_cancelled_subscription_tasks(
http_client: HttpClient,
):
with contextlib.suppress(ImportError):
from tests.http.clients.channels import ChannelsHttpClient

if isinstance(http_client, ChannelsHttpClient):
pytest.skip("Can't patch on_init for this client")

handler = None
cleanup_done_at_shutdown_end = None

def on_init(_handler):
nonlocal handler
if handler:
return
handler = _handler
original_shutdown = _handler.shutdown

async def tracked_shutdown():
nonlocal cleanup_done_at_shutdown_end
await original_shutdown()
cleanup_done_at_shutdown_end = Subscription.active_infinity_subscriptions

_handler.shutdown = tracked_shutdown

with patch.object(DebuggableGraphQLTransportWSHandler, "on_init", on_init):
async with http_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
await ws.send_message({"type": "connection_init"})
ack: ConnectionAckMessage = await ws.receive_json()
assert ack == {"type": "connection_ack"}

await ws.send_message(
{
"id": "sub1",
"type": "subscribe",
"payload": {"query": 'subscription { infinity(message: "Hi") }'},
}
)
await ws.receive(timeout=2)
assert Subscription.active_infinity_subscriptions == 1

await ws.close()

await asyncio.sleep(0.5)
assert handler is not None
assert cleanup_done_at_shutdown_end == 0
Loading