diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..2820214c74 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: patch + +Operations over `graphql-transport-ws` now create the Context and perform validation on +the worker `Task`, thus not blocking the websocket from accepting messages. diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 1275ecf304..94a8ca0903 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -10,12 +10,13 @@ AsyncGenerator, AsyncIterator, Callable, + Coroutine, Dict, List, Optional, + Union, ) -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, GraphQLSyntaxError, parse from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -29,6 +30,7 @@ SubscribeMessage, SubscribeMessagePayload, ) +from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType from strawberry.types.unset import UNSET from strawberry.utils.debug import pretty_print_graphql_operation @@ -41,7 +43,6 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( GraphQLTransportMessage, ) - from strawberry.types import ExecutionResult class BaseGraphQLTransportWSHandler(ABC): @@ -107,7 +108,7 @@ def on_request_accepted(self) -> None: async def handle_connection_init_timeout(self) -> None: task = asyncio.current_task() - assert task + assert task is not None # for typecheckers try: delay = self.connection_init_wait_timeout.total_seconds() await asyncio.sleep(delay=delay) @@ -239,92 +240,100 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: message.payload.variables, ) - context = await self.get_context() - if isinstance(context, dict): - context["connection_params"] = self.connection_params - root_value = await self.get_root_value() - - # Get an AsyncGenerator yielding the results - if operation_type == OperationType.SUBSCRIPTION: - result_source = await self.schema.subscribe( - query=message.payload.query, - variable_values=message.payload.variables, - operation_name=message.payload.operationName, - context_value=context, - root_value=root_value, - ) - else: - # create AsyncGenerator returning a single result - async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( # type: ignore + # The method to start the operation. Will be called on worker + # thread and so may contain long running async calls without + # blocking the main websocket handler. + async def start_operation() -> Union[AsyncGenerator[Any, None], Any]: + # there is some type mismatch here which we need to gloss over with + # the use of Any. + # subscribe() returns + # Union[AsyncIterator[graphql.ExecutionResult], graphql.ExecutionResult]: + # whereas execute() returns strawberry.types.ExecutionResult. + # These execution result types are similar, but not the same. + + context = await self.get_context() + if isinstance(context, dict): + context["connection_params"] = self.connection_params + root_value = await self.get_root_value() + + if operation_type == OperationType.SUBSCRIPTION: + return await self.schema.subscribe( + query=message.payload.query, + variable_values=message.payload.variables, + operation_name=message.payload.operationName, + context_value=context, + root_value=root_value, + ) + else: + # single results behave similarly to subscriptions, + # return either a ExecutionResult or an AsyncGenerator + result = await self.schema.execute( query=message.payload.query, variable_values=message.payload.variables, context_value=context, root_value=root_value, operation_name=message.payload.operationName, ) + # Note: result may be SubscriptionExecutionResult or ExecutionResult + # now, but we don't support the former properly yet, hence the "ignore" below. - result_source = get_result_source() + # Both validation and execution errors are handled the same way. + if isinstance(result, ExecutionResult) and result.errors: + return result - operation = Operation(self, message.id, operation_type) + # create AsyncGenerator returning a single result + async def single_result() -> AsyncIterator[ExecutionResult]: + yield result # type: ignore - # Handle initial validation errors - if isinstance(result_source, GraphQLExecutionResult): - assert operation_type == OperationType.SUBSCRIPTION - assert result_source.errors - payload = [err.formatted for err in result_source.errors] - await self.send_message(ErrorMessage(id=message.id, payload=payload)) - self.schema.process_errors(result_source.errors) - return + return single_result() # Create task to handle this subscription, reserve the operation ID - operation.task = asyncio.create_task( - self.operation_task(result_source, operation) - ) + operation = Operation(self, message.id, operation_type, start_operation) + operation.task = asyncio.create_task(self.operation_task(operation)) self.operations[message.id] = operation - async def operation_task( - self, result_source: AsyncGenerator, operation: Operation - ) -> None: - """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" - # TODO: Handle errors in this method using self.handle_task_exception() + async def operation_task(self, operation: Operation) -> None: + """The operation task's top level method. + + Cleans-up and de-registers the operation once it is done. + """ + task = asyncio.current_task() + assert task is not None # for type checkers try: - await self.handle_async_results(result_source, operation) - except BaseException: # pragma: no cover - # cleanup in case of something really unexpected - # wait for generator to be closed to ensure that any existing - # 'finally' statement is called - with suppress(RuntimeError): - await result_source.aclose() - if operation.id in self.operations: - del self.operations[operation.id] + await self.handle_operation(operation) + except asyncio.CancelledError: raise - else: - await operation.send_message(CompleteMessage(id=operation.id)) + except Exception as error: + # Log any unhandled exceptions in the operation task + await self.handle_task_exception(error) finally: - # add this task to a list to be reaped later - task = asyncio.current_task() - assert task is not None + # Clenaup. Remove the operation from the list of active operations + if operation.id in self.operations: + del self.operations[operation.id] + # TODO: Stop collecting background tasks, not necessary. + # Add this task to a list to be reaped later self.completed_tasks.append(task) - async def handle_async_results( + async def handle_operation( self, - result_source: AsyncGenerator, operation: Operation, ) -> None: try: - async for result in result_source: - if ( - result.errors - and operation.operation_type != OperationType.SUBSCRIPTION - ): - error_payload = [err.formatted for err in result.errors] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - # don't need to call schema.process_errors() here because - # it was already done by schema.execute() - return - else: + result_source = await operation.start_operation() + # result_source is an ExcutionResult-like object or an AsyncGenerator + # Handle validation errors. Cannot check type directly. + if hasattr(result_source, "errors"): + assert result_source.errors + payload = [err.formatted for err in result_source.errors] + await operation.send_message( + ErrorMessage(id=operation.id, payload=payload) + ) + if operation.operation_type == OperationType.SUBSCRIPTION: + self.schema.process_errors(result_source.errors) + return + + try: + async for result in result_source: next_payload = {"data": result.data} if result.errors: self.schema.process_errors(result.errors) @@ -333,6 +342,11 @@ async def handle_async_results( ] next_message = NextMessage(id=operation.id, payload=next_payload) await operation.send_message(next_message) + await operation.send_message(CompleteMessage(id=operation.id)) + finally: + # Close the AsyncGenerator in case of errors or cancellation + await result_source.aclose() + except Exception as error: # GraphQLErrors are handled by graphql-core and included in the # ExecutionResult @@ -378,23 +392,35 @@ async def reap_completed_tasks(self) -> None: class Operation: """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" - __slots__ = ["handler", "id", "operation_type", "completed", "task"] + __slots__ = [ + "handler", + "id", + "operation_type", + "start_operation", + "completed", + "task", + ] def __init__( self, handler: BaseGraphQLTransportWSHandler, id: str, operation_type: OperationType, + start_operation: Callable[ + [], Coroutine[Any, Any, Union[Any, AsyncGenerator[Any, None]]] + ], ) -> None: self.handler = handler self.id = id self.operation_type = operation_type + self.start_operation = start_operation self.completed = False self.task: Optional[asyncio.Task] = None async def send_message(self, message: GraphQLTransportMessage) -> None: + # defensive check, should never happen if self.completed: - return + return # pragma: no cover if isinstance(message, (CompleteMessage, ErrorMessage)): self.completed = True # de-register the operation _before_ sending the final message diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..92479f8469 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .base import ( JSON, DebuggableGraphQLTransportWSMixin, @@ -50,7 +50,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return get_context(context) + return await get_context(context) async def get_root_value(self, request: web.Request) -> Query: await super().get_root_value(request) # for coverage diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..967f6adcf8 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -18,7 +18,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .base import ( JSON, DebuggableGraphQLTransportWSMixin, @@ -56,7 +56,7 @@ async def get_context( ) -> object: context = await super().get_context(request, response) - return get_context(context) + return await get_context(context) async def process_result( self, request: Request, result: ExecutionResult diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index da981403bb..34d143b01e 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -20,7 +20,7 @@ from strawberry.http.typevars import Context, RootValue from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context, get_context_async from .base import ( JSON, HttpClient, @@ -77,7 +77,7 @@ async def get_context(self, *args: str, **kwargs: Any) -> object: context["connectionInitTimeoutTask"] = getattr( self._handler, "connection_init_timeout_task", None ) - for key, val in get_context({}).items(): + for key, val in (await get_context_async({})).items(): context[key] = val return context @@ -95,7 +95,7 @@ async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue] async def get_context(self, request: ChannelsConsumer, response: Any) -> Context: context = await super().get_context(request, response) - return get_context(context) + return await get_context_async(context) async def process_result( self, request: ChannelsConsumer, result: Any diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 1a8148c136..f634bf88e1 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .asgi import AsgiWebSocketClient from .base import ( JSON, @@ -50,7 +50,7 @@ async def fastapi_get_context( ws: WebSocket = None, # type: ignore custom_value: str = Depends(custom_context_dependency), ) -> Dict[str, object]: - return get_context( + return await get_context( { "request": request or ws, "background_tasks": background_tasks, diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index ccf9999f7f..645eb425c3 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -17,7 +17,7 @@ from strawberry.types import ExecutionResult from tests.views.schema import Query, schema -from ..context import get_context +from ..context import get_context_async as get_context from .base import ( JSON, DebuggableGraphQLTransportWSMixin, @@ -30,12 +30,8 @@ ) -def custom_context_dependency() -> str: - return "Hi!" - - async def litestar_get_context(request: Request = None): - return get_context({"request": request}) + return await get_context({"request": request}) async def get_root_value(request: Request = None): diff --git a/tests/http/context.py b/tests/http/context.py index 99985b2434..c1ce5dbecf 100644 --- a/tests/http/context.py +++ b/tests/http/context.py @@ -2,6 +2,21 @@ def get_context(context: object) -> Dict[str, object]: + return get_context_inner(context) + + +# a patchable method for unittests +def get_context_inner(context: object) -> Dict[str, object]: assert isinstance(context, dict) + return {**context, "custom_value": "a value from context"} + +# async version for async frameworks +async def get_context_async(context: object) -> Dict[str, object]: + return await get_context_async_inner(context) + + +# a patchable method for unittests +async def get_context_async_inner(context: object) -> Dict[str, object]: + assert isinstance(context, dict) return {**context, "custom_value": "a value from context"} diff --git a/tests/views/schema.py b/tests/views/schema.py index b0c14bfd76..d8a4e89107 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -20,6 +20,19 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b return False +class ConditionalFailPermission(BasePermission): + @property + def message(self): + return f"failed after sleep {self.sleep}" + + async def has_permission(self, source, info, **kwargs: Any) -> bool: + self.sleep = kwargs.get("sleep", None) + self.fail = kwargs.get("fail", True) + if self.sleep is not None: + await asyncio.sleep(kwargs["sleep"]) + return not self.fail + + class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} @@ -64,7 +77,7 @@ class DebugInfo: @strawberry.type class Query: @strawberry.field - def greetings(self) -> str: + def greetings(self) -> str: # pragma: no cover return "hello" @strawberry.field @@ -78,7 +91,13 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str @strawberry.field(permission_classes=[AlwaysFailPermission]) def always_fail(self) -> Optional[str]: - return "Hey" + return "Hey" # pragma: no cover + + @strawberry.field(permission_classes=[ConditionalFailPermission]) + def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> str: + return "Hey" # pragma: no cover @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: @@ -89,7 +108,7 @@ async def exception(self, message: str) -> str: raise ValueError(message) @strawberry.field - def teapot(self, info: strawberry.Info[Any, None]) -> str: + def teapot(self, info: strawberry.Info[Any, None]) -> str: # pragma: no cover info.context["response"].status_code = 418 return "🫖" @@ -123,7 +142,7 @@ def set_header(self, info: strawberry.Info, name: str) -> str: @strawberry.type class Mutation: @strawberry.mutation - def echo(self, string_to_echo: str) -> str: + def echo(self, string_to_echo: str) -> str: # pragma: no cover return string_to_echo @strawberry.mutation @@ -143,7 +162,7 @@ def read_folder(self, folder: FolderInput) -> List[str]: return list(map(_read_file, folder.files)) @strawberry.mutation - def match_text(self, text_file: Upload, pattern: str) -> str: + def match_text(self, text_file: Upload, pattern: str) -> str: # pragma: no cover text = text_file.read().decode() return pattern if pattern in text else "" @@ -180,7 +199,7 @@ async def exception(self, message: str) -> AsyncGenerator[str, None]: raise ValueError(message) # Without this yield, the method is not recognised as an async generator - yield "Hi" + yield "Hi" # pragma: no cover @strawberry.subscription async def flavors(self) -> AsyncGenerator[Flavor, None]: @@ -262,6 +281,12 @@ async def long_finalizer( finally: await asyncio.sleep(delay) + @strawberry.subscription(permission_classes=[ConditionalFailPermission]) + async def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> AsyncGenerator[str, None]: + yield "Hey" # pragma: no cover + class Schema(strawberry.Schema): def process_errors( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index f3fd4b74b8..2b583110e1 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -6,18 +6,16 @@ import time from datetime import timedelta from typing import TYPE_CHECKING, Any, AsyncGenerator, Type -from unittest.mock import Mock, patch - -try: - from unittest.mock import AsyncMock -except ImportError: - AsyncMock = None +from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio from pytest_mock import MockerFixture from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -32,8 +30,23 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from tests.views.schema import Schema +from ..http.clients.base import WebSocketClient + +try: + from ..http.clients.fastapi import FastAPIHttpClient +except ImportError: # pragma: no cover + FastAPIHttpClient = None +try: + from ..http.clients.starlite import StarliteHttpClient +except ImportError: # pragma: no cover + StarliteHttpClient = None +try: + from ..http.clients.litestar import LitestarHttpClient +except ImportError: # pragma: no cover + LitestarHttpClient = None + if TYPE_CHECKING: - from ..http.clients.base import HttpClient, WebSocketClient + from ..http.clients.base import HttpClient @pytest_asyncio.fixture @@ -407,6 +420,28 @@ async def test_subscription_field_errors(ws: WebSocketClient): process_errors.assert_called_once() +async def test_query_field_errors(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { notASubscriptionField }", + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") is None + assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}] + assert ( + response["payload"][0]["message"] + == "Cannot query field 'notASubscriptionField' on type 'Query'." + ) + + async def test_subscription_cancellation(ws: WebSocketClient): await ws.send_json( SubscribeMessage( @@ -891,9 +926,6 @@ async def test_error_handler_for_timeout(http_client: HttpClient): if isinstance(http_client, ChannelsHttpClient): pytest.skip("Can't patch on_init for this client") - if not AsyncMock: - pytest.skip("Don't have AsyncMock") - ws = ws_raw handler = None errorhandler = AsyncMock() @@ -962,3 +994,207 @@ async def test_subscription_errors_continue(ws: WebSocketClient): response = await ws.receive_json() assert response["type"] == CompleteMessage.type assert response["id"] == "sub1" + + +async def test_validation_query(ws: WebSocketClient): + """ + Test validation for query + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_validation_subscription(ws: WebSocketClient): + """ + Test validation for subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_long_validation_concurrent_query(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + single-result-operation + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first query is stuck in validation + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub2", payload={"data": {"conditionalFail": "Hey"}} + ).as_dict() + ) + + +async def test_long_validation_concurrent_subscription(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck in validation + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub2", payload={"data": {"conditionalFail": "Hey"}} + ).as_dict() + ) + + +async def test_long_custom_context( + ws: WebSocketClient, http_client_class: Type[HttpClient] +): + """ + Test that the websocket is not blocked evaluating the context + """ + if http_client_class in (FastAPIHttpClient, StarliteHttpClient, LitestarHttpClient): + pytest.skip("Client evaluates the context only once per connection") + + counter = 0 + + async def slow_get_context(ctxt): + nonlocal counter + old = counter + counter += 1 + if old == 0: + await asyncio.sleep(0.1) + ctxt["custom_value"] = "slow" + else: + ctxt["custom_value"] = "fast" + return ctxt + + with patch("tests.http.context.get_context_async_inner", slow_get_context): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload(query="query { valueFromContext }"), + ).as_dict() + ) + + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { valueFromContext }", + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck getting context + response = await ws.receive_json() + assert ( + response + == NextMessage( + id="sub2", payload={"data": {"valueFromContext": "fast"}} + ).as_dict() + ) + + response = await ws.receive_json() + if response == CompleteMessage(id="sub2").as_dict(): + response = await ws.receive_json() # ignore the complete message + assert ( + response + == NextMessage( + id="sub1", payload={"data": {"valueFromContext": "slow"}} + ).as_dict() + ) + + +async def test_task_error_handler(ws: WebSocketClient): + """ + Test that error handling works + """ + # can't use a simple Event here, because the handler may run + # on a different thread + wakeup = False + + # a replacement method which causes an error in th eTask + async def op(*args: Any, **kwargs: Any): + nonlocal wakeup + wakeup = True + raise ZeroDivisionError("test") + + with patch.object(BaseGraphQLTransportWSHandler, "task_logger") as logger: + with patch.object(BaseGraphQLTransportWSHandler, "handle_operation", op): + # send any old subscription request. It will raise an error + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0) }" + ), + ).as_dict() + ) + + # wait for the error to be logged. Must use timed loop and not event. + while not wakeup: # noqa: ASYNC110 + await asyncio.sleep(0.01) + # and another little bit, for the thread to finish + await asyncio.sleep(0.01) + assert logger.exception.called