diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..582b1d8fcf --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,53 @@ +Release type: minor + +Remove deprecated `strawberry.scalar(cls, ...)` wrapper pattern and `ScalarWrapper`, deprecated since [0.288.0](https://github.com/strawberry-graphql/strawberry/releases/tag/0.288.0). + +You can run `strawberry upgrade replace-scalar-wrappers ` to automatically replace built-in scalar wrapper imports. + +### Migration guide + +**Before (deprecated):** +```python +import strawberry +from datetime import datetime + +EpochDateTime = strawberry.scalar( + datetime, + serialize=lambda v: int(v.timestamp()), + parse_value=lambda v: datetime.fromtimestamp(v), +) + + +@strawberry.type +class Query: + created: EpochDateTime +``` + +**After:** +```python +import strawberry +from typing import NewType +from datetime import datetime +from strawberry.schema.config import StrawberryConfig + +EpochDateTime = NewType("EpochDateTime", datetime) + + +@strawberry.type +class Query: + created: EpochDateTime + + +schema = strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + EpochDateTime: strawberry.scalar( + name="EpochDateTime", + serialize=lambda v: int(v.timestamp()), + parse_value=lambda v: datetime.fromtimestamp(v), + ) + } + ), +) +``` diff --git a/docs/errors/scalar-already-registered.md b/docs/errors/scalar-already-registered.md index a7a283bf1d..830751b020 100644 --- a/docs/errors/scalar-already-registered.md +++ b/docs/errors/scalar-already-registered.md @@ -12,16 +12,11 @@ the following code will throw this error: ```python import strawberry +from typing import NewType +from strawberry.schema.config import StrawberryConfig -MyCustomScalar = strawberry.scalar( - str, - name="MyCustomScalar", -) - -MyCustomScalar2 = strawberry.scalar( - int, - name="MyCustomScalar", -) +MyCustomScalar = NewType("MyCustomScalar", str) +MyCustomScalar2 = NewType("MyCustomScalar2", int) @strawberry.type @@ -30,7 +25,19 @@ class Query: scalar_2: MyCustomScalar2 -strawberry.Schema(Query) +strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + MyCustomScalar: strawberry.scalar( + name="MyCustomScalar", serialize=str, parse_value=str + ), + MyCustomScalar2: strawberry.scalar( + name="MyCustomScalar", serialize=int, parse_value=int + ), + } + ), +) ``` This happens because different types in Strawberry (and GraphQL) cannot have the @@ -48,16 +55,11 @@ name of one of them, for example in this code we renamed the second scalar: ```python import strawberry +from typing import NewType +from strawberry.schema.config import StrawberryConfig -MyCustomScalar = strawberry.scalar( - str, - name="MyCustomScalar", -) - -MyCustomScalar2 = strawberry.scalar( - int, - name="MyCustomScalar2", -) +MyCustomScalar = NewType("MyCustomScalar", str) +MyCustomScalar2 = NewType("MyCustomScalar2", int) @strawberry.type @@ -66,5 +68,17 @@ class Query: scalar_2: MyCustomScalar2 -strawberry.Schema(Query) +strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + MyCustomScalar: strawberry.scalar( + name="MyCustomScalar", serialize=str, parse_value=str + ), + MyCustomScalar2: strawberry.scalar( + name="MyCustomScalar2", serialize=int, parse_value=int + ), + } + ), +) ``` diff --git a/docs/general/queries.md b/docs/general/queries.md index 1d6edb9dd4..d3bf9b0eeb 100644 --- a/docs/general/queries.md +++ b/docs/general/queries.md @@ -130,9 +130,11 @@ This can be useful for static typing, as custom scalars are not valid type annotations. ```python -BigInt = strawberry.scalar( - int, name="BigInt", serialize=lambda v: str(v), parse_value=lambda v: int(v) -) +import strawberry +from typing import Annotated, NewType +from strawberry.schema.config import StrawberryConfig + +BigInt = NewType("BigInt", int) @strawberry.type @@ -142,4 +144,18 @@ class Query: self, ids: Annotated[list[int], strawberry.argument(graphql_type=list[BigInt])], ) -> list[User]: ... + + +schema = strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + BigInt: strawberry.scalar( + name="BigInt", + serialize=lambda v: str(v), + parse_value=lambda v: int(v), + ) + } + ), +) ``` diff --git a/docs/guides/file-upload.md b/docs/guides/file-upload.md index 26e35af902..957ec9b76f 100644 --- a/docs/guides/file-upload.md +++ b/docs/guides/file-upload.md @@ -46,11 +46,12 @@ override based on the integrations above. For example with Starlette: import strawberry from starlette.datastructures import UploadFile from strawberry.file_uploads import UploadDefinition +from strawberry.schema.config import StrawberryConfig schema = strawberry.Schema( query=Query, mutation=Mutation, - scalar_overrides={UploadFile: UploadDefinition}, + config=StrawberryConfig(scalar_map={UploadFile: UploadDefinition}), ) ``` diff --git a/docs/integrations/pydantic.md b/docs/integrations/pydantic.md index e1c1967772..b753be8541 100644 --- a/docs/integrations/pydantic.md +++ b/docs/integrations/pydantic.md @@ -357,12 +357,13 @@ type Query { Pydantic BaseModels may define a custom type with [`__get_validators__`](https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__) -logic. You will need to add a scalar type and add the mapping to the -`scalar_overrides` argument in the Schema class. +logic. You will need to add a scalar definition and add the mapping to +`scalar_map` in `StrawberryConfig`. ```python import strawberry from pydantic import BaseModel +from strawberry.schema.config import StrawberryConfig class MyCustomType: @@ -383,14 +384,6 @@ class Example(BaseModel): class ExampleGQL: ... -MyScalarType = strawberry.scalar( - MyCustomType, - # or another function describing how to represent MyCustomType in the response - serialize=str, - parse_value=lambda v: MyCustomType(), -) - - @strawberry.type class Query: @strawberry.field() @@ -398,8 +391,18 @@ class Query: return Example(custom=MyCustomType()) -# Tells strawberry to convert MyCustomType into MyScalarType -schema = strawberry.Schema(query=Query, scalar_overrides={MyCustomType: MyScalarType}) +schema = strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + MyCustomType: strawberry.scalar( + name="MyScalarType", + serialize=str, + parse_value=lambda v: MyCustomType(), + ) + } + ), +) ``` ## Custom Conversion Logic @@ -422,6 +425,7 @@ import base64 import strawberry from pydantic import BaseModel from typing import Union, NewType +from strawberry.schema.config import StrawberryConfig class User(BaseModel): @@ -429,11 +433,7 @@ class User(BaseModel): hash: bytes -Base64 = strawberry.scalar( - NewType("Base64", bytes), - serialize=lambda v: base64.b64encode(v).decode("utf-8"), - parse_value=lambda v: base64.b64decode(v.encode("utf-8")), -) +Base64 = NewType("Base64", bytes) @strawberry.experimental.pydantic.type(model=User) @@ -449,7 +449,18 @@ class Query: return UserType.from_pydantic(User(id=123, hash=b"abcd")) -schema = strawberry.Schema(query=Query) +schema = strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + Base64: strawberry.scalar( + name="Base64", + serialize=lambda v: base64.b64encode(v).decode("utf-8"), + parse_value=lambda v: base64.b64decode(v.encode("utf-8")), + ) + } + ), +) print(schema.execute_sync("query { test { id, hash } }").data) # {"test": {"id": "123", "hash": "YWJjZA=="}} diff --git a/docs/types/schema.md b/docs/types/schema.md index a0d14e8b51..7379ad0bf3 100644 --- a/docs/types/schema.md +++ b/docs/types/schema.md @@ -108,9 +108,9 @@ schema = strawberry.Schema(Query, types=[Individual, Company]) List of [extensions](/docs/extensions) to add to your Schema. -#### `scalar_overrides: Optional[Dict[object, ScalarWrapper]] = None` +#### `scalar_overrides: Optional[Dict[object, ScalarDefinition]] = None` -Override the implementation of the built in scalars. +Override the implementation of the built-in scalars. [More information](/docs/types/scalars#overriding-built-in-scalars). --- diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 7b933555e4..c6ad97b7a7 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -44,7 +44,7 @@ ) from strawberry.types.enum import StrawberryEnumDefinition from strawberry.types.lazy_type import LazyType -from strawberry.types.scalar import ScalarDefinition, ScalarWrapper +from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion from strawberry.types.unset import UNSET from strawberry.utils.str_converters import capitalize_first, to_camel_case @@ -541,14 +541,18 @@ def _get_field_type( not isinstance(field_type, StrawberryType) and field_type in self.schema.schema_converter.scalar_registry ): - field_type = self.schema.schema_converter.scalar_registry[field_type] # type: ignore - - if isinstance(field_type, ScalarWrapper): - python_type = field_type.wrap - if hasattr(python_type, "__supertype__"): - python_type = python_type.__supertype__ - - return self._collect_scalar(field_type._scalar_definition, python_type) # type: ignore + # Store the original Python type (could be a type or NewType) + # before replacing with the ScalarDefinition + original_python_type = field_type + # For NewTypes, get the underlying type for the codegen + if hasattr(original_python_type, "__supertype__"): + python_type = original_python_type.__supertype__ + elif isinstance(original_python_type, type): + python_type = original_python_type + else: + python_type = None + field_type = self.schema.schema_converter.scalar_registry[field_type] + return self._collect_scalar(field_type, python_type) if isinstance(field_type, ScalarDefinition): return self._collect_scalar(field_type, None) diff --git a/strawberry/exceptions/invalid_union_type.py b/strawberry/exceptions/invalid_union_type.py index 5af48f4145..a844d5ec07 100644 --- a/strawberry/exceptions/invalid_union_type.py +++ b/strawberry/exceptions/invalid_union_type.py @@ -27,7 +27,6 @@ def __init__( union_definition: StrawberryUnion | None = None, ) -> None: from strawberry.types.base import StrawberryList - from strawberry.types.scalar import ScalarWrapper self.union_name = union_name self.invalid_type = invalid_type @@ -37,9 +36,7 @@ def __init__( # one is our code checking for invalid types, the other is the caller self.frame = getframeinfo(stack()[2][0]) - if isinstance(invalid_type, ScalarWrapper): - type_name = invalid_type.wrap.__name__ - elif isinstance(invalid_type, StrawberryList): + if isinstance(invalid_type, StrawberryList): type_name = "list[...]" else: try: diff --git a/strawberry/federation/scalar.py b/strawberry/federation/scalar.py index 95c3f0d459..c21f787f67 100644 --- a/strawberry/federation/scalar.py +++ b/strawberry/federation/scalar.py @@ -1,43 +1,12 @@ +import sys from collections.abc import Callable, Iterable -from typing import ( - Any, - NewType, - TypeVar, - overload, -) -from strawberry.types.scalar import ScalarWrapper, _process_scalar +from strawberry.types.scalar import ScalarDefinition, identity -_T = TypeVar("_T", bound=type | NewType) - -def identity(x: _T) -> _T: # pragma: no cover - return x - - -@overload -def scalar( - *, - name: str | None = None, - description: str | None = None, - specified_by_url: str | None = None, - serialize: Callable = identity, - parse_value: Callable | None = None, - parse_literal: Callable | None = None, - directives: Iterable[object] = (), - authenticated: bool = False, - inaccessible: bool = False, - policy: list[list[str]] | None = None, - requires_scopes: list[list[str]] | None = None, - tags: Iterable[str] | None = (), -) -> Callable[[_T], _T]: ... - - -@overload def scalar( - cls: _T, *, - name: str | None = None, + name: str, description: str | None = None, specified_by_url: str | None = None, serialize: Callable = identity, @@ -49,68 +18,53 @@ def scalar( policy: list[list[str]] | None = None, requires_scopes: list[list[str]] | None = None, tags: Iterable[str] | None = (), -) -> _T: ... +) -> ScalarDefinition: + """Creates a GraphQL custom scalar definition with federation support. - -def scalar( - cls: _T | None = None, - *, - name: str | None = None, - description: str | None = None, - specified_by_url: str | None = None, - serialize: Callable = identity, - parse_value: Callable | None = None, - parse_literal: Callable | None = None, - directives: Iterable[object] = (), - authenticated: bool = False, - inaccessible: bool = False, - policy: list[list[str]] | None = None, - requires_scopes: list[list[str]] | None = None, - tags: Iterable[str] | None = (), -) -> Any: - """Annotates a class or type as a GraphQL custom scalar. + Returns a `ScalarDefinition` for use in `StrawberryConfig.scalar_map` + or `Schema(scalar_overrides=...)`. Args: - cls: The class or type to annotate - name: The GraphQL name of the scalar - description: The description of the scalar - specified_by_url: The URL of the specification - serialize: The function to serialize the scalar - parse_value: The function to parse the value - parse_literal: The function to parse the literal - directives: The directives to apply to the scalar - authenticated: Whether to add the @authenticated directive - inaccessible: Whether to add the @inaccessible directive - policy: The list of policy names to add to the @policy directive - requires_scopes: The list of scopes to add to the @requires directive - tags: The list of tags to add to the @tag directive + name: The GraphQL name of the scalar. + description: The description of the scalar. + specified_by_url: The URL of the specification. + serialize: The function to serialize the scalar. + parse_value: The function to parse the value. + parse_literal: The function to parse the literal. + directives: The directives to apply to the scalar. + authenticated: Whether to add the @authenticated directive. + inaccessible: Whether to add the @inaccessible directive. + policy: The list of policy names to add to the @policy directive. + requires_scopes: The list of scopes to add to the @requires directive. + tags: The list of tags to add to the @tag directive. Returns: - The decorated class or type + A `ScalarDefinition`. - Example usages: + Example usage: ```python - strawberry.federation.scalar( - datetime.date, - serialize=lambda value: value.isoformat(), - parse_value=datetime.parse_date, + from typing import NewType + import strawberry + from strawberry.schema.config import StrawberryConfig + + # Define the type + Base64 = NewType("Base64", bytes) + + # Configure the scalar with federation directives + schema = strawberry.federation.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + Base64: strawberry.federation.scalar( + name="Base64", + serialize=lambda v: base64.b64encode(v).decode(), + parse_value=lambda v: base64.b64decode(v), + authenticated=True, + ) + } + ), ) - - Base64Encoded = strawberry.federation.scalar( - NewType("Base64Encoded", bytes), - serialize=base64.b64encode, - parse_value=base64.b64decode, - ) - - - @strawberry.federation.scalar( - serialize=lambda value: ",".join(value.items), - parse_value=lambda value: CustomList(value.split(",")), - ) - class CustomList: - def __init__(self, items): - self.items = items ``` """ from strawberry.federation.schema_directives import ( @@ -121,42 +75,45 @@ def __init__(self, items): Tag, ) - if parse_value is None: - parse_value = cls - - directives = list(directives) + all_directives = list(directives) if authenticated: - directives.append(Authenticated()) + all_directives.append(Authenticated()) if inaccessible: - directives.append(Inaccessible()) + all_directives.append(Inaccessible()) if policy: - directives.append(Policy(policies=policy)) + all_directives.append(Policy(policies=policy)) if requires_scopes: - directives.append(RequiresScopes(scopes=requires_scopes)) + all_directives.append(RequiresScopes(scopes=requires_scopes)) if tags: - directives.extend(Tag(name=tag) for tag in tags) - - def wrap(cls: _T) -> ScalarWrapper: - return _process_scalar( - cls, - name=name, - description=description, - specified_by_url=specified_by_url, - serialize=serialize, - parse_value=parse_value, - parse_literal=parse_literal, - directives=directives, - ) - - if cls is None: - return wrap - - return wrap(cls) + all_directives.extend(Tag(name=tag) for tag in tags) + + from strawberry.exceptions.handler import should_use_rich_exceptions + + _source_file = None + _source_line = None + + if should_use_rich_exceptions(): + frame = sys._getframe(1) + _source_file = frame.f_code.co_filename + _source_line = frame.f_lineno + + return ScalarDefinition( + name=name, + description=description, + specified_by_url=specified_by_url, + serialize=serialize, + parse_literal=parse_literal, + parse_value=parse_value, + directives=tuple(all_directives), + origin=None, + _source_file=_source_file, + _source_line=_source_line, + ) __all__ = ["scalar"] diff --git a/strawberry/federation/schema.py b/strawberry/federation/schema.py index ebcbd013a1..8b218cc2b8 100644 --- a/strawberry/federation/schema.py +++ b/strawberry/federation/schema.py @@ -22,7 +22,7 @@ get_object_definition, ) from strawberry.types.info import Info -from strawberry.types.scalar import ScalarDefinition, ScalarWrapper, scalar +from strawberry.types.scalar import ScalarDefinition, scalar from strawberry.types.union import StrawberryUnion from strawberry.utils.inspect import get_func_args @@ -56,8 +56,7 @@ def __init__( extensions: Iterable[Union[type["SchemaExtension"], "SchemaExtension"]] = (), execution_context_class: type["GraphQLExecutionContext"] | None = None, config: Optional["StrawberryConfig"] = None, - scalar_overrides: dict[object, Union[type, "ScalarWrapper", "ScalarDefinition"]] - | None = None, + scalar_overrides: dict[object, Union[type, "ScalarDefinition"]] | None = None, schema_directives: Iterable[object] = (), federation_version: Literal[ "2.0", @@ -83,9 +82,7 @@ def __init__( types = [*types, FederationAny] # Add federation scalars to scalar_overrides so they can be recognized - federation_scalar_overrides: dict[ - object, type | ScalarDefinition | ScalarWrapper - ] = { + federation_scalar_overrides: dict[object, type | ScalarDefinition] = { FederationAny: scalar( name="_Any", serialize=lambda v: v, parse_value=lambda v: v ), diff --git a/strawberry/printer/printer.py b/strawberry/printer/printer.py index c8e56ab0ef..301b17407c 100644 --- a/strawberry/printer/printer.py +++ b/strawberry/printer/printer.py @@ -37,7 +37,6 @@ has_object_definition, ) from strawberry.types.enum import StrawberryEnumDefinition -from strawberry.types.scalar import ScalarWrapper from strawberry.types.unset import UNSET from .ast_from_value import ast_from_value @@ -627,8 +626,6 @@ def print_schema(schema: BaseSchema) -> str: def _name_getter(type_: Any) -> str: if hasattr(type_, "name"): return type_.name - if isinstance(type_, ScalarWrapper): - return type_._scalar_definition.name return type_.__name__ return "\n\n".join( diff --git a/strawberry/scalars.py b/strawberry/scalars.py index 8250979723..a3a3ed1491 100644 --- a/strawberry/scalars.py +++ b/strawberry/scalars.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from collections.abc import Mapping - from strawberry.types.scalar import ScalarDefinition, ScalarWrapper + from strawberry.types.scalar import ScalarDefinition ID = NewType("ID", str) @@ -26,7 +26,7 @@ def is_scalar( annotation: Any, - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], ) -> bool: if annotation in scalar_registry: return True diff --git a/strawberry/schema/compat.py b/strawberry/schema/compat.py index 3ed4eaefd5..06b17c4b68 100644 --- a/strawberry/schema/compat.py +++ b/strawberry/schema/compat.py @@ -16,7 +16,7 @@ from collections.abc import Mapping from typing import TypeGuard - from strawberry.types.scalar import ScalarDefinition, ScalarWrapper + from strawberry.types.scalar import ScalarDefinition def is_input_type(type_: StrawberryType | type) -> TypeGuard[type]: @@ -33,7 +33,7 @@ def is_interface_type(type_: StrawberryType | type) -> TypeGuard[type]: def is_scalar( type_: StrawberryType | type, - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], ) -> TypeGuard[type]: return is_strawberry_scalar(type_, scalar_registry) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index c2f81489f1..256abd22f9 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -92,7 +92,7 @@ from strawberry.types.base import StrawberryType from strawberry.types.enum import StrawberryEnumDefinition from strawberry.types.field import StrawberryField - from strawberry.types.scalar import ScalarDefinition, ScalarWrapper + from strawberry.types.scalar import ScalarDefinition from strawberry.types.union import StrawberryUnion SubscriptionResult: TypeAlias = AsyncGenerator[ @@ -216,9 +216,7 @@ def __init__( extensions: Iterable[type[SchemaExtension] | SchemaExtension] = (), execution_context_class: type[GraphQLExecutionContext] | None = None, config: StrawberryConfig | None = None, - scalar_overrides: ( - Mapping[object, type | ScalarWrapper | ScalarDefinition] | None - ) = None, + scalar_overrides: Mapping[object, type | ScalarDefinition] | None = None, schema_directives: Iterable[object] = (), ) -> None: """Default Schema to be used in a Strawberry application. diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 19fa3266a5..5855e97383 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -69,7 +69,7 @@ from strawberry.types.field import UNRESOLVED from strawberry.types.lazy_type import LazyType from strawberry.types.private import is_private -from strawberry.types.scalar import ScalarWrapper, scalar +from strawberry.types.scalar import ScalarDefinition, scalar from strawberry.types.union import StrawberryUnion from strawberry.types.unset import UNSET from strawberry.utils.await_maybe import await_maybe @@ -93,7 +93,6 @@ from strawberry.types.enum import EnumValue from strawberry.types.field import StrawberryField from strawberry.types.info import Info - from strawberry.types.scalar import ScalarDefinition FieldType = TypeVar( @@ -190,7 +189,7 @@ def get_arguments( info: Info, kwargs: Any, config: StrawberryConfig, - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], ) -> tuple[list[Any], dict[str, Any]]: # TODO: An extension might have changed the resolver arguments, # but we need them here since we are calling it. @@ -247,7 +246,7 @@ class GraphQLCoreConverter: def __init__( self, config: StrawberryConfig, - scalar_overrides: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_overrides: Mapping[object, ScalarDefinition], scalar_map: Mapping[object, ScalarDefinition], get_fields: Callable[[StrawberryObjectDefinition], list[StrawberryField]], ) -> None: @@ -258,12 +257,10 @@ def __init__( def _get_scalar_registry( self, - scalar_overrides: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_overrides: Mapping[object, ScalarDefinition], scalar_map: Mapping[object, ScalarDefinition], - ) -> Mapping[object, ScalarWrapper | ScalarDefinition]: - scalar_registry: dict[object, ScalarWrapper | ScalarDefinition] = { - **DEFAULT_SCALAR_REGISTRY - } + ) -> Mapping[object, ScalarDefinition]: + scalar_registry: dict[object, ScalarDefinition] = {**DEFAULT_SCALAR_REGISTRY} global_id_name = "GlobalID" if self.config.relay_use_legacy_global_id else "ID" @@ -811,12 +808,7 @@ def from_scalar(self, scalar: type) -> GraphQLScalarType: scalar_definition: ScalarDefinition if scalar in self.scalar_registry: - _scalar_definition = self.scalar_registry[scalar] - # TODO: check why we need the cast and we are not trying with getattr first - if isinstance(_scalar_definition, ScalarWrapper): - scalar_definition = _scalar_definition._scalar_definition - else: - scalar_definition = _scalar_definition + scalar_definition = self.scalar_registry[scalar] else: scalar_definition = scalar._scalar_definition # type: ignore[attr-defined] diff --git a/strawberry/schema_codegen/__init__.py b/strawberry/schema_codegen/__init__.py index d5fb540ee5..3081c3954c 100644 --- a/strawberry/schema_codegen/__init__.py +++ b/strawberry/schema_codegen/__init__.py @@ -632,6 +632,8 @@ def _get_schema_definition( root_mutation_name: str | None, root_subscription_name: str | None, is_apollo_federation: bool, + scalar_infos: list[ScalarInfo], + imports: set[Import], ) -> cst.SimpleStatementLine | None: if not any([root_query_name, root_mutation_name, root_subscription_name]): return None @@ -654,6 +656,82 @@ def _get_arg(name: str, value: str) -> cst.Arg: if root_subscription_name: args.append(_get_arg("subscription", root_subscription_name)) + # Generate scalar_map config if there are custom scalars + if scalar_infos: + imports.add( + Import(module="strawberry.schema.config", imports=("StrawberryConfig",)) + ) + + identity_lambda = cst.Lambda( + body=cst.Name("v"), + params=cst.Parameters( + params=[cst.Param(cst.Name("v"))], + ), + ) + + scalar_map_elements = [] + for scalar_info in scalar_infos: + scalar_args: list[cst.Arg] = [ + _get_argument("name", scalar_info.name), + ] + if scalar_info.description: + scalar_args.append( + _get_argument("description", scalar_info.description) + ) + if scalar_info.specified_by_url: + scalar_args.append( + _get_argument("specified_by_url", scalar_info.specified_by_url) + ) + scalar_args.extend( + [ + cst.Arg( + keyword=cst.Name("serialize"), + value=identity_lambda, + equal=cst.AssignEqual( + cst.SimpleWhitespace(""), cst.SimpleWhitespace("") + ), + ), + cst.Arg( + keyword=cst.Name("parse_value"), + value=identity_lambda, + equal=cst.AssignEqual( + cst.SimpleWhitespace(""), cst.SimpleWhitespace("") + ), + ), + ] + ) + + scalar_map_elements.append( + cst.DictElement( + key=cst.Name(scalar_info.name), + value=cst.Call( + func=cst.Attribute( + value=cst.Name("strawberry"), + attr=cst.Name("scalar"), + ), + args=scalar_args, + ), + ) + ) + + config_arg = cst.Arg( + keyword=cst.Name("config"), + value=cst.Call( + func=cst.Name("StrawberryConfig"), + args=[ + cst.Arg( + keyword=cst.Name("scalar_map"), + value=cst.Dict(elements=scalar_map_elements), + equal=cst.AssignEqual( + cst.SimpleWhitespace(""), cst.SimpleWhitespace("") + ), + ), + ], + ), + equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), + ) + args.append(config_arg) + # Federation 2 is now always enabled for federation schemas if is_apollo_federation: schema_call = cst.Call( @@ -692,6 +770,13 @@ class Definition: name: str +@dataclasses.dataclass(frozen=True) +class ScalarInfo: + name: str + description: str | None + specified_by_url: str | None + + def _get_union_definition(definition: UnionTypeDefinitionNode) -> Definition: name = definition.name.value @@ -732,26 +817,26 @@ def _get_union_definition(definition: UnionTypeDefinitionNode) -> Definition: def _get_scalar_definition( definition: ScalarTypeDefinitionNode, imports: set[Import] -) -> Definition | None: +) -> tuple[Definition | None, ScalarInfo | None]: name = definition.name.value if name == "Date": imports.add(Import(module="datetime", imports=("date",))) - return None + return None, None if name == "Time": imports.add(Import(module="datetime", imports=("time",))) - return None + return None, None if name == "DateTime": imports.add(Import(module="datetime", imports=("datetime",))) - return None + return None, None if name == "Decimal": imports.add(Import(module="decimal", imports=("Decimal",))) - return None + return None, None if name == "UUID": imports.add(Import(module="uuid", imports=("UUID",))) - return None + return None, None if name == "JSON": - return None + return None, None description = definition.description.value if definition.description else None @@ -767,62 +852,36 @@ def _get_scalar_definition( imports.add(Import(module="typing", imports=("NewType",))) - identity_lambda = cst.Lambda( - body=cst.Name("v"), - params=cst.Parameters( - params=[cst.Param(cst.Name("v"))], - ), - ) - - additional_args: list[cst.Arg | None] = [ - _get_argument("description", description) if description else None, - _get_argument("specified_by_url", specified_by_url) - if specified_by_url - else None, - cst.Arg( - keyword=cst.Name("serialize"), - value=identity_lambda, - equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), - ), - cst.Arg( - keyword=cst.Name("parse_value"), - value=identity_lambda, - equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), - ), - ] - + # Generate just the NewType definition statement_definition = cst.SimpleStatementLine( body=[ cst.Assign( targets=[cst.AssignTarget(cst.Name(name))], value=cst.Call( - func=cst.Attribute( - value=cst.Name("strawberry"), - attr=cst.Name("scalar"), - ), + func=cst.Name("NewType"), args=[ - cst.Arg( - cst.Call( - func=cst.Name("NewType"), - args=[ - cst.Arg(cst.SimpleString(f'"{name}"')), - cst.Arg(cst.Name("object")), - ], - ) - ), - *filter(None, additional_args), + cst.Arg(cst.SimpleString(f'"{name}"')), + cst.Arg(cst.Name("object")), ], ), ) ] ) - return Definition(statement_definition, [], name=definition.name.value) + + scalar_info = ScalarInfo( + name=name, + description=description, + specified_by_url=specified_by_url, + ) + + return Definition(statement_definition, [], name=definition.name.value), scalar_info def codegen(schema: str) -> str: document = parse(schema) definitions: dict[str, Definition] = {} + scalar_infos: list[ScalarInfo] = [] root_query_name: str | None = None root_mutation_name: str | None = None @@ -876,7 +935,11 @@ def codegen(schema: str) -> str: definition = _get_union_definition(graphql_definition) elif isinstance(graphql_definition, ScalarTypeDefinitionNode): - definition = _get_scalar_definition(graphql_definition, imports) + definition, scalar_info = _get_scalar_definition( + graphql_definition, imports + ) + if scalar_info is not None: + scalar_infos.append(scalar_info) elif isinstance(graphql_definition, SchemaExtensionNode): is_apollo_federation = any( @@ -905,6 +968,8 @@ def codegen(schema: str) -> str: root_mutation_name=root_mutation_name, root_subscription_name=root_subscription_name, is_apollo_federation=is_apollo_federation, + scalar_infos=scalar_infos, + imports=imports, ) if schema_definition: diff --git a/strawberry/types/arguments.py b/strawberry/types/arguments.py index c6a5f38da1..bed0c9f2bc 100644 --- a/strawberry/types/arguments.py +++ b/strawberry/types/arguments.py @@ -33,7 +33,7 @@ from strawberry.schema.config import StrawberryConfig from strawberry.types.base import StrawberryType - from strawberry.types.scalar import ScalarDefinition, ScalarWrapper + from strawberry.types.scalar import ScalarDefinition DEPRECATED_NAMES: dict[str, str] = { @@ -152,7 +152,7 @@ def is_maybe(self) -> bool: def _is_leaf_type( type_: StrawberryType | type, - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], skip_classes: tuple[type, ...] = (), ) -> bool: if type_ in skip_classes: @@ -172,7 +172,7 @@ def _is_leaf_type( def _is_optional_leaf_type( type_: StrawberryType | type, - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], skip_classes: tuple[type, ...] = (), ) -> bool: if type_ in skip_classes: @@ -187,7 +187,7 @@ def _is_optional_leaf_type( def convert_argument( value: object, type_: StrawberryType | type, - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], config: StrawberryConfig, ) -> object: from strawberry.relay.types import GlobalID @@ -283,7 +283,7 @@ def convert_argument( def convert_arguments( value: dict[str, Any], arguments: list[StrawberryArgument], - scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], + scalar_registry: Mapping[object, ScalarDefinition], config: StrawberryConfig, ) -> dict[str, Any]: """Converts a nested dictionary to a dictionary of actual types. diff --git a/strawberry/types/scalar.py b/strawberry/types/scalar.py index f242fc3b14..2d999af6d2 100644 --- a/strawberry/types/scalar.py +++ b/strawberry/types/scalar.py @@ -5,15 +5,9 @@ from typing import ( TYPE_CHECKING, Any, - NewType, - Optional, - TypeVar, - overload, ) -from strawberry.exceptions import InvalidUnionTypeError from strawberry.types.base import StrawberryType -from strawberry.utils.str_converters import to_camel_case if TYPE_CHECKING: from collections.abc import Callable, Iterable, Mapping @@ -21,10 +15,7 @@ from graphql import GraphQLScalarType -_T = TypeVar("_T", bound=type | NewType) - - -def identity(x: _T) -> _T: +def identity(x: Any) -> Any: return x @@ -57,70 +48,7 @@ def is_graphql_generic(self) -> bool: return False -class ScalarWrapper: - _scalar_definition: ScalarDefinition - - def __init__(self, wrap: Callable[[Any], Any]) -> None: - self.wrap = wrap - - def __call__(self, *args: str, **kwargs: Any) -> Any: - return self.wrap(*args, **kwargs) - - def __or__(self, other: StrawberryType | type) -> StrawberryType: - if other is None: - # Return the correct notation when using `StrawberryUnion | None`. - return Optional[self] # noqa: UP045 - - # Raise an error in any other case. - # There is Work in progress to deal with more merging cases, see: - # https://github.com/strawberry-graphql/strawberry/pull/1455 - raise InvalidUnionTypeError(str(other), self.wrap) - - -def _process_scalar( - cls: _T, - *, - name: str | None = None, - description: str | None = None, - specified_by_url: str | None = None, - serialize: Callable | None = None, - parse_value: Callable | None = None, - parse_literal: Callable | None = None, - directives: Iterable[object] = (), -) -> ScalarWrapper: - from strawberry.exceptions.handler import should_use_rich_exceptions - - name = name or to_camel_case(cls.__name__) # type: ignore[union-attr] - - _source_file = None - _source_line = None - - if should_use_rich_exceptions(): - frame = sys._getframe(3) - - _source_file = frame.f_code.co_filename - _source_line = frame.f_lineno - - wrapper = ScalarWrapper(cls) - wrapper._scalar_definition = ScalarDefinition( - name=name, - description=description, - specified_by_url=specified_by_url, - serialize=serialize, - parse_literal=parse_literal, - parse_value=parse_value, - directives=directives, - origin=cls, # type: ignore[arg-type] - _source_file=_source_file, - _source_line=_source_line, - ) - - return wrapper - - -@overload def scalar( - cls: None = None, *, name: str, description: str | None = None, @@ -129,69 +57,13 @@ def scalar( parse_value: Callable | None = None, parse_literal: Callable | None = None, directives: Iterable[object] = (), -) -> ScalarDefinition: ... - - -@overload -def scalar( - cls: None = None, - *, - name: None = None, - description: str | None = None, - specified_by_url: str | None = None, - serialize: Callable = identity, - parse_value: Callable | None = None, - parse_literal: Callable | None = None, - directives: Iterable[object] = (), -) -> Callable[[_T], _T]: ... - - -@overload -def scalar( - cls: _T, - *, - name: str | None = None, - description: str | None = None, - specified_by_url: str | None = None, - serialize: Callable = identity, - parse_value: Callable | None = None, - parse_literal: Callable | None = None, - directives: Iterable[object] = (), -) -> _T: ... - - -# TODO: We are tricking pyright into thinking that we are returning the given type -# here or else it won't let us use any custom scalar to annotate attributes in -# dataclasses/types. This should be properly solved when implementing StrawberryScalar -def scalar( - cls: _T | None = None, - *, - name: str | None = None, - description: str | None = None, - specified_by_url: str | None = None, - serialize: Callable = identity, - parse_value: Callable | None = None, - parse_literal: Callable | None = None, - directives: Iterable[object] = (), -) -> Any: - """Annotates a class or type as a GraphQL custom scalar. - - This function can be used in three ways: +) -> ScalarDefinition: + """Creates a GraphQL custom scalar definition. - 1. With a `name` but no `cls`: Returns a `ScalarDefinition` for use in - `StrawberryConfig.scalar_map`. This is the recommended approach as it - provides proper type checking support. - - 2. As a decorator (no `cls`): Returns a decorator function. When the `cls` - argument is provided inline, this is deprecated in favor of using - `scalar_map`. - - 3. With a `cls` argument (deprecated): Wraps the class/type directly. - This approach is deprecated because it causes type checker issues. - Use `scalar_map` in `StrawberryConfig` instead. + Returns a `ScalarDefinition` for use in `StrawberryConfig.scalar_map` + or `Schema(scalar_overrides=...)`. Args: - cls: The class or type to annotate (deprecated, use scalar_map instead). name: The GraphQL name of the scalar. description: The description of the scalar. specified_by_url: The URL of the specification. @@ -201,12 +73,9 @@ def scalar( directives: The directives to apply to the scalar. Returns: - A `ScalarDefinition` when called with `name` only, a decorator function - when called without arguments, or the wrapped type when called with `cls`. - - Example usages: + A `ScalarDefinition`. - Recommended approach using scalar_map: + Example usage: ```python from typing import NewType @@ -230,16 +99,6 @@ def scalar( ), ) ``` - - Legacy approach (deprecated): - - ```python - Base64Encoded = strawberry.scalar( - NewType("Base64Encoded", bytes), - serialize=base64.b64encode, - parse_value=base64.b64decode, - ) - ``` """ from strawberry.exceptions.handler import should_use_rich_exceptions @@ -251,48 +110,18 @@ def scalar( _source_file = frame.f_code.co_filename _source_line = frame.f_lineno - if cls is None and name is not None: - return ScalarDefinition( - name=name, - description=description, - specified_by_url=specified_by_url, - serialize=serialize, - parse_literal=parse_literal, - parse_value=parse_value, - directives=directives, - origin=None, - _source_file=_source_file, - _source_line=_source_line, - ) - - if parse_value is None: - parse_value = cls - - def wrap(cls: _T) -> ScalarWrapper: - import warnings - - warnings.warn( - "Passing a class to strawberry.scalar() is deprecated. " - "Use StrawberryConfig.scalar_map instead for better type checking support. " - "See: https://strawberry.rocks/docs/types/scalars", - DeprecationWarning, - stacklevel=3, - ) - return _process_scalar( - cls, - name=name, - description=description, - specified_by_url=specified_by_url, - serialize=serialize, - parse_value=parse_value, - parse_literal=parse_literal, - directives=directives, - ) - - if cls is None: - return wrap - - return wrap(cls) + return ScalarDefinition( + name=name, + description=description, + specified_by_url=specified_by_url, + serialize=serialize, + parse_literal=parse_literal, + parse_value=parse_value, + directives=directives, + origin=None, + _source_file=_source_file, + _source_line=_source_line, + ) -__all__ = ["ScalarDefinition", "ScalarWrapper", "scalar"] +__all__ = ["ScalarDefinition", "identity", "scalar"] diff --git a/tests/cli/test_locate_definition.py b/tests/cli/test_locate_definition.py index c8f6d8e223..fc1fcceca9 100644 --- a/tests/cli/test_locate_definition.py +++ b/tests/cli/test_locate_definition.py @@ -23,7 +23,7 @@ def test_find_model_name(cli_app: Typer, cli_runner: CliRunner): assert result.exit_code == 0 assert _simplify_path(result.stdout.strip()) == snapshot( - "fixtures/sample_package/sample_module.py:38:7" + "fixtures/sample_package/sample_module.py:35:7" ) @@ -33,7 +33,7 @@ def test_find_model_field(cli_app: Typer, cli_runner: CliRunner): assert result.exit_code == 0 assert _simplify_path(result.stdout.strip()) == snapshot( - "fixtures/sample_package/sample_module.py:39:5" + "fixtures/sample_package/sample_module.py:36:5" ) diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py index 2893766e79..a21b24c14e 100644 --- a/tests/codegen/conftest.py +++ b/tests/codegen/conftest.py @@ -14,11 +14,12 @@ import pytest import strawberry +from strawberry.schema.config import StrawberryConfig if TYPE_CHECKING: from .lazy_type import LaziestType -JSON = strawberry.scalar(NewType("JSON", str)) +JSON = NewType("JSON", str) @strawberry.enum @@ -171,4 +172,9 @@ def add_blog_posts(self, input: AddBlogPostsInput) -> AddBlogPostsOutput: @pytest.fixture def schema() -> strawberry.Schema: - return strawberry.Schema(query=Query, mutation=Mutation, types=[BlogPost, Image]) + return strawberry.Schema( + query=Query, + mutation=Mutation, + types=[BlogPost, Image], + config=StrawberryConfig(scalar_map={JSON: strawberry.scalar(name="JSON")}), + ) diff --git a/tests/federation/printer/test_authenticated.py b/tests/federation/printer/test_authenticated.py index e91b1adefa..cbc9c11e10 100644 --- a/tests/federation/printer/test_authenticated.py +++ b/tests/federation/printer/test_authenticated.py @@ -1,8 +1,9 @@ import textwrap from enum import Enum -from typing import Annotated +from typing import Annotated, NewType import strawberry +from strawberry.schema.config import StrawberryConfig def test_field_authenticated_printed_correctly(): @@ -54,15 +55,22 @@ def top_products( def test_field_authenticated_printed_correctly_on_scalar(): - @strawberry.federation.scalar(authenticated=True) - class SomeScalar(str): - __slots__ = () + SomeScalar = NewType("SomeScalar", str) @strawberry.federation.type class Query: hello: SomeScalar - schema = strawberry.federation.Schema(query=Query) + schema = strawberry.federation.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + SomeScalar: strawberry.federation.scalar( + name="SomeScalar", authenticated=True + ) + } + ), + ) expected = """ schema @link(url: "https://specs.apollo.dev/federation/v2.11", import: ["@authenticated"]) { diff --git a/tests/federation/printer/test_inaccessible.py b/tests/federation/printer/test_inaccessible.py index 408b62b47c..f16a881ca7 100644 --- a/tests/federation/printer/test_inaccessible.py +++ b/tests/federation/printer/test_inaccessible.py @@ -1,8 +1,9 @@ import textwrap from enum import Enum -from typing import Annotated +from typing import Annotated, NewType import strawberry +from strawberry.schema.config import StrawberryConfig def test_field_inaccessible_printed_correctly(): @@ -129,7 +130,7 @@ def hello(self) -> str: # pragma: no cover def test_inaccessible_on_scalar(): - SomeScalar = strawberry.federation.scalar(str, name="SomeScalar", inaccessible=True) + SomeScalar = NewType("SomeScalar", str) @strawberry.type class Query: @@ -137,6 +138,13 @@ class Query: schema = strawberry.federation.Schema( query=Query, + config=StrawberryConfig( + scalar_map={ + SomeScalar: strawberry.federation.scalar( + name="SomeScalar", inaccessible=True + ) + } + ), ) expected = """ diff --git a/tests/federation/printer/test_link.py b/tests/federation/printer/test_link.py index 6151a133be..94fcf434bf 100644 --- a/tests/federation/printer/test_link.py +++ b/tests/federation/printer/test_link.py @@ -1,7 +1,9 @@ import textwrap +from typing import NewType import strawberry from strawberry.federation.schema_directives import Link +from strawberry.schema.config import StrawberryConfig from tests.conftest import skip_if_gql_32 @@ -298,10 +300,7 @@ class Query: def test_adds_link_directive_automatically_from_scalar(): - # TODO: Federation scalar - @strawberry.scalar - class X: - pass + X = NewType("X", str) @strawberry.federation.type(keys=["id"]) class User: @@ -312,7 +311,10 @@ class User: class Query: user: User - schema = strawberry.federation.Schema(query=Query) + schema = strawberry.federation.Schema( + query=Query, + config=StrawberryConfig(scalar_map={X: strawberry.scalar(name="X")}), + ) expected = """ schema @link(url: "https://specs.apollo.dev/federation/v2.11", import: ["@key"]) { diff --git a/tests/federation/printer/test_policy.py b/tests/federation/printer/test_policy.py index a66a6bbd14..f22eeb4453 100644 --- a/tests/federation/printer/test_policy.py +++ b/tests/federation/printer/test_policy.py @@ -1,8 +1,9 @@ import textwrap from enum import Enum -from typing import Annotated +from typing import Annotated, NewType import strawberry +from strawberry.schema.config import StrawberryConfig def test_field_policy_printed_correctly(): @@ -60,17 +61,23 @@ def top_products( def test_field_policy_printed_correctly_on_scalar(): - @strawberry.federation.scalar( - policy=[["client", "poweruser"], ["admin"], ["productowner"]] - ) - class SomeScalar(str): - __slots__ = () + SomeScalar = NewType("SomeScalar", str) @strawberry.federation.type class Query: hello: SomeScalar - schema = strawberry.federation.Schema(query=Query) + schema = strawberry.federation.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + SomeScalar: strawberry.federation.scalar( + name="SomeScalar", + policy=[["client", "poweruser"], ["admin"], ["productowner"]], + ) + } + ), + ) expected = """ schema @link(url: "https://specs.apollo.dev/federation/v2.11", import: ["@policy"]) { diff --git a/tests/federation/printer/test_requires_scopes.py b/tests/federation/printer/test_requires_scopes.py index 4aba5695ef..3e30fa8f72 100644 --- a/tests/federation/printer/test_requires_scopes.py +++ b/tests/federation/printer/test_requires_scopes.py @@ -1,8 +1,9 @@ import textwrap from enum import Enum -from typing import Annotated +from typing import Annotated, NewType import strawberry +from strawberry.schema.config import StrawberryConfig def test_field_requires_scopes_printed_correctly(): @@ -60,17 +61,27 @@ def top_products( def test_field_requires_scopes_printed_correctly_on_scalar(): - @strawberry.federation.scalar( - requires_scopes=[["client", "poweruser"], ["admin"], ["productowner"]] - ) - class SomeScalar(str): - __slots__ = () + SomeScalar = NewType("SomeScalar", str) @strawberry.federation.type class Query: hello: SomeScalar - schema = strawberry.federation.Schema(query=Query) + schema = strawberry.federation.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + SomeScalar: strawberry.federation.scalar( + name="SomeScalar", + requires_scopes=[ + ["client", "poweruser"], + ["admin"], + ["productowner"], + ], + ) + } + ), + ) expected = """ schema @link(url: "https://specs.apollo.dev/federation/v2.11", import: ["@requiresScopes"]) { diff --git a/tests/federation/printer/test_tag.py b/tests/federation/printer/test_tag.py index a749bd0110..223fdafc51 100644 --- a/tests/federation/printer/test_tag.py +++ b/tests/federation/printer/test_tag.py @@ -1,8 +1,9 @@ import textwrap from enum import Enum -from typing import Annotated +from typing import Annotated, NewType import strawberry +from strawberry.schema.config import StrawberryConfig def test_field_tag_printed_correctly(): @@ -56,15 +57,22 @@ def top_products( def test_field_tag_printed_correctly_on_scalar(): - @strawberry.federation.scalar(tags=["myTag", "anotherTag"]) - class SomeScalar(str): - __slots__ = () + SomeScalar = NewType("SomeScalar", str) @strawberry.federation.type class Query: hello: SomeScalar - schema = strawberry.federation.Schema(query=Query) + schema = strawberry.federation.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + SomeScalar: strawberry.federation.scalar( + name="SomeScalar", tags=["myTag", "anotherTag"] + ) + } + ), + ) expected = """ schema @link(url: "https://specs.apollo.dev/federation/v2.11", import: ["@tag"]) { diff --git a/tests/fields/test_arguments.py b/tests/fields/test_arguments.py index 45e027248a..11487ce1ca 100644 --- a/tests/fields/test_arguments.py +++ b/tests/fields/test_arguments.py @@ -356,7 +356,9 @@ def name( def test_annotated_argument_with_graphql_type_override(): - BigInt = strawberry.scalar(int, name="BigInt", serialize=str, parse_value=int) + from typing import NewType + + BigInt = NewType("BigInt", int) @strawberry.type class Query: diff --git a/tests/fixtures/sample_package/sample_module.py b/tests/fixtures/sample_package/sample_module.py index 0de153ab0b..3ee47d7558 100644 --- a/tests/fixtures/sample_package/sample_module.py +++ b/tests/fixtures/sample_package/sample_module.py @@ -2,12 +2,9 @@ from typing import Annotated, NewType import strawberry +from strawberry.schema.config import StrawberryConfig -ExampleScalar = strawberry.scalar( - NewType("ExampleScalar", object), - serialize=lambda v: v, - parse_value=lambda v: v, -) +ExampleScalar = NewType("ExampleScalar", object) @strawberry.type @@ -52,7 +49,18 @@ def user(self) -> User: def create_schema(): - return strawberry.Schema(query=Query) + return strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + ExampleScalar: strawberry.scalar( + name="ExampleScalar", + serialize=lambda v: v, + parse_value=lambda v: v, + ) + } + ), + ) schema = create_schema() diff --git a/tests/objects/generics/test_names.py b/tests/objects/generics/test_names.py index 5840d7a53f..66fe323ad8 100644 --- a/tests/objects/generics/test_names.py +++ b/tests/objects/generics/test_names.py @@ -16,7 +16,7 @@ Enum = StrawberryEnumDefinition(None, name="Enum", values=[], description=None) # type: ignore -CustomInt = strawberry.scalar(NewType("CustomInt", int)) +CustomInt = NewType("CustomInt", int) @strawberry.type diff --git a/tests/schema/test_custom_scalar.py b/tests/schema/test_custom_scalar.py index 848fb01754..69d30752d0 100644 --- a/tests/schema/test_custom_scalar.py +++ b/tests/schema/test_custom_scalar.py @@ -2,20 +2,17 @@ from typing import NewType import strawberry +from strawberry.schema.config import StrawberryConfig -Base64Encoded = strawberry.scalar( - NewType("Base64Encoded", bytes), - serialize=base64.b64encode, - parse_value=base64.b64decode, -) +# Define the types +Base64Encoded = NewType("Base64Encoded", bytes) -@strawberry.scalar(serialize=lambda x: 42, parse_value=lambda x: Always42()) class Always42: pass -MyStr = strawberry.scalar(NewType("MyStr", str)) +MyStr = NewType("MyStr", str) def test_custom_scalar_serialization(): @@ -25,7 +22,18 @@ class Query: def custom_scalar_field(self) -> Base64Encoded: return Base64Encoded(b"decoded value") - schema = strawberry.Schema(Query) + schema = strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + Base64Encoded: strawberry.scalar( + name="Base64Encoded", + serialize=base64.b64encode, + parse_value=base64.b64decode, + ) + } + ), + ) result = schema.execute_sync("{ customScalarField }") @@ -40,7 +48,18 @@ class Query: def decode_base64(self, encoded: Base64Encoded) -> str: return bytes(encoded).decode("ascii") - schema = strawberry.Schema(Query) + schema = strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + Base64Encoded: strawberry.scalar( + name="Base64Encoded", + serialize=base64.b64encode, + parse_value=base64.b64decode, + ) + } + ), + ) encoded = Base64Encoded(base64.b64encode(b"decoded")) query = """query decode($encoded: Base64Encoded!) { @@ -59,7 +78,18 @@ class Query: def answer(self) -> Always42: return Always42() - schema = strawberry.Schema(Query) + schema = strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + Always42: strawberry.scalar( + name="Always42", + serialize=lambda x: 42, + parse_value=lambda x: Always42(), + ) + } + ), + ) result = schema.execute_sync("{ answer }") @@ -74,7 +104,16 @@ class Query: def my_str(self, arg: MyStr) -> MyStr: return MyStr(str(arg) + "Suffix") - schema = strawberry.Schema(Query) + schema = strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + MyStr: strawberry.scalar( + name="MyStr", + ) + } + ), + ) result = schema.execute_sync('{ myStr(arg: "value") }') diff --git a/tests/schema/test_directives.py b/tests/schema/test_directives.py index 3bc99782f7..c46797edda 100644 --- a/tests/schema/test_directives.py +++ b/tests/schema/test_directives.py @@ -636,7 +636,11 @@ def uppercase(value: DirectiveValue[str], input: DirectiveInput): def test_directives_with_scalar(): - DirectiveInput = strawberry.scalar(str, name="DirectiveInput") + from typing import NewType + + from strawberry.schema.config import StrawberryConfig + + DirectiveInput = NewType("DirectiveInput", str) @strawberry.type class Query: @@ -650,7 +654,13 @@ def cake(self) -> str: def uppercase(value: DirectiveValue[str], input: DirectiveInput): return value.upper() - schema = strawberry.Schema(query=Query, directives=[uppercase]) + schema = strawberry.Schema( + query=Query, + directives=[uppercase], + config=StrawberryConfig( + scalar_map={DirectiveInput: strawberry.scalar(name="DirectiveInput")} + ), + ) expected_schema = ''' """Make string uppercase""" diff --git a/tests/schema/test_name_converter.py b/tests/schema/test_name_converter.py index 7b4f4783d4..82c9df05dd 100644 --- a/tests/schema/test_name_converter.py +++ b/tests/schema/test_name_converter.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Annotated, Generic, TypeVar +from typing import Annotated, Generic, NewType, TypeVar import strawberry from strawberry.directive import StrawberryDirective @@ -66,7 +66,7 @@ def from_enum_value( T = TypeVar("T") -MyScalar = strawberry.scalar(str, name="SensitiveConfiguration") +SensitiveConfiguration = NewType("SensitiveConfiguration", str) @strawberry.enum @@ -136,8 +136,13 @@ def print(self, enum: MyEnum) -> str: schema = strawberry.Schema( query=Query, - types=[MyScalar, Node], - config=StrawberryConfig(name_converter=AppendsNameConverter("X")), + types=[SensitiveConfiguration, Node], + config=StrawberryConfig( + name_converter=AppendsNameConverter("X"), + scalar_map={ + SensitiveConfiguration: strawberry.scalar(name="SensitiveConfiguration") + }, + ), ) diff --git a/tests/schema/test_scalars.py b/tests/schema/test_scalars.py index 9208134ba4..0df359d904 100644 --- a/tests/schema/test_scalars.py +++ b/tests/schema/test_scalars.py @@ -306,11 +306,7 @@ def base64_decode(data: Base64) -> str: def test_override_built_in_scalars(): - EpochDateTime = strawberry.scalar( - datetime, - serialize=lambda value: int(value.timestamp()), - parse_value=lambda value: datetime.fromtimestamp(int(value), timezone.utc), - ) + from strawberry.schema.config import StrawberryConfig @strawberry.type class Query: @@ -324,9 +320,17 @@ def isoformat(self, input_datetime: datetime) -> str: schema = strawberry.Schema( Query, - scalar_overrides={ - datetime: EpochDateTime, - }, + config=StrawberryConfig( + scalar_map={ + datetime: strawberry.scalar( + name="DateTime", + serialize=lambda value: int(value.timestamp()), + parse_value=lambda value: datetime.fromtimestamp( + int(value), timezone.utc + ), + ) + } + ), ) result = schema.execute_sync( @@ -344,12 +348,7 @@ def isoformat(self, input_datetime: datetime) -> str: def test_override_unknown_scalars(): - Duration = strawberry.scalar( - timedelta, - name="Duration", - serialize=timedelta.total_seconds, - parse_value=lambda s: timedelta(seconds=s), - ) + from strawberry.schema.config import StrawberryConfig @strawberry.type class Query: @@ -357,7 +356,18 @@ class Query: def duration(self, value: timedelta) -> timedelta: return value - schema = strawberry.Schema(Query, scalar_overrides={timedelta: Duration}) + schema = strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + timedelta: strawberry.scalar( + name="Duration", + serialize=timedelta.total_seconds, + parse_value=lambda s: timedelta(seconds=s), + ) + } + ), + ) result = schema.execute_sync("{ duration(value: 10) }") @@ -401,22 +411,27 @@ def decimal(value: Decimal) -> Decimal: match="Scalar `MyCustomScalar` has already been registered", ) def test_duplicate_scalars_raises_exception(): - MyCustomScalar = strawberry.scalar( - str, - name="MyCustomScalar", - ) + from typing import NewType - MyCustomScalar2 = strawberry.scalar( - int, - name="MyCustomScalar", - ) + from strawberry.schema.config import StrawberryConfig + + MyStr = NewType("MyStr", str) + MyInt = NewType("MyInt", int) @strawberry.type class Query: - scalar_1: MyCustomScalar - scalar_2: MyCustomScalar2 + scalar_1: MyStr + scalar_2: MyInt - strawberry.Schema(Query) + strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + MyStr: strawberry.scalar(name="MyCustomScalar"), + MyInt: strawberry.scalar(name="MyCustomScalar"), + } + ), + ) @pytest.mark.raises_strawberry_exception( @@ -424,22 +439,27 @@ class Query: match="Scalar `MyCustomScalar` has already been registered", ) def test_duplicate_scalars_raises_exception_using_alias(): - MyCustomScalar = scalar( - str, - name="MyCustomScalar", - ) + from typing import NewType - MyCustomScalar2 = scalar( - int, - name="MyCustomScalar", - ) + from strawberry.schema.config import StrawberryConfig + + MyStr = NewType("MyStr", str) + MyInt = NewType("MyInt", int) @strawberry.type class Query: - scalar_1: MyCustomScalar - scalar_2: MyCustomScalar2 + scalar_1: MyStr + scalar_2: MyInt - strawberry.Schema(Query) + strawberry.Schema( + Query, + config=StrawberryConfig( + scalar_map={ + MyStr: scalar(name="MyCustomScalar"), + MyInt: scalar(name="MyCustomScalar"), + } + ), + ) def test_optional_scalar_with_or_operator(): @@ -583,25 +603,16 @@ def email(self) -> Email: def test_scalar_map_combined_with_scalar_overrides(): """Test that scalar_map and scalar_overrides work together.""" - import warnings from typing import NewType from strawberry.schema.config import StrawberryConfig MyInt = NewType("MyInt", int) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - MyFloat = strawberry.scalar( - float, - name="MyFloat", - serialize=lambda v: round(v, 2), - ) - @strawberry.type class Query: my_int: MyInt - my_float: MyFloat + my_float: float schema = strawberry.Schema( query=Query, @@ -610,12 +621,13 @@ class Query: MyInt: strawberry.scalar( name="MyInt", serialize=lambda v: v * 2, - ) + ), + float: strawberry.scalar( + name="MyFloat", + serialize=lambda v: round(v, 2), + ), } ), - scalar_overrides={ - float: MyFloat, - }, ) result = schema.execute_sync( diff --git a/tests/schema/test_union.py b/tests/schema/test_union.py index ce1d692538..b2e1509dd2 100644 --- a/tests/schema/test_union.py +++ b/tests/schema/test_union.py @@ -816,13 +816,15 @@ class Query: InvalidUnionTypeError, match="Type `Always42` cannot be used in a GraphQL Union" ) def test_raises_on_union_of_custom_scalar(): + from typing import NewType + + from strawberry.schema.config import StrawberryConfig + @strawberry.type class ICanBeInUnion: foo: str - @strawberry.scalar(serialize=lambda x: 42, parse_value=lambda x: Always42()) - class Always42: - pass + Always42 = NewType("Always42", int) @strawberry.type class Query: @@ -830,7 +832,18 @@ class Query: Always42 | ICanBeInUnion, strawberry.union(name="ExampleUnion") ] - strawberry.Schema(query=Query) + strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + Always42: strawberry.scalar( + name="Always42", + serialize=lambda x: 42, + parse_value=lambda x: x, + ) + } + ), + ) def test_union_of_unions(): diff --git a/tests/schema_codegen/test_scalar.py b/tests/schema_codegen/test_scalar.py index 2005f99e35..d1ee92426c 100644 --- a/tests/schema_codegen/test_scalar.py +++ b/tests/schema_codegen/test_scalar.py @@ -6,14 +6,25 @@ def test_scalar(): schema = """ scalar LocalDate @specifiedBy(url: "https://scalars.graphql.org/andimarek/local-date.html") + + type Query { + date: LocalDate + } """ expected = textwrap.dedent( """ import strawberry + from strawberry.schema.config import StrawberryConfig from typing import NewType - LocalDate = strawberry.scalar(NewType("LocalDate", object), specified_by_url="https://scalars.graphql.org/andimarek/local-date.html", serialize=lambda v: v, parse_value=lambda v: v) + LocalDate = NewType("LocalDate", object) + + @strawberry.type + class Query: + date: LocalDate | None + + schema = strawberry.Schema(query=Query, config=StrawberryConfig(scalar_map={LocalDate: strawberry.scalar(name="LocalDate", specified_by_url="https://scalars.graphql.org/andimarek/local-date.html", serialize=lambda v: v, parse_value=lambda v: v)})) """ ).strip() @@ -24,14 +35,25 @@ def test_scalar_with_description(): schema = """ "A date without a time-zone in the ISO-8601 calendar system, such as 2007-12-03." scalar LocalDate + + type Query { + date: LocalDate + } """ expected = textwrap.dedent( """ import strawberry + from strawberry.schema.config import StrawberryConfig from typing import NewType - LocalDate = strawberry.scalar(NewType("LocalDate", object), description="A date without a time-zone in the ISO-8601 calendar system, such as 2007-12-03.", serialize=lambda v: v, parse_value=lambda v: v) + LocalDate = NewType("LocalDate", object) + + @strawberry.type + class Query: + date: LocalDate | None + + schema = strawberry.Schema(query=Query, config=StrawberryConfig(scalar_map={LocalDate: strawberry.scalar(name="LocalDate", description="A date without a time-zone in the ISO-8601 calendar system, such as 2007-12-03.", serialize=lambda v: v, parse_value=lambda v: v)})) """ ).strip() diff --git a/tests/test_printer/test_schema_directives.py b/tests/test_printer/test_schema_directives.py index 4a4c2cc6ae..e95524b3b5 100644 --- a/tests/test_printer/test_schema_directives.py +++ b/tests/test_printer/test_schema_directives.py @@ -360,7 +360,9 @@ class Query: def test_prints_with_scalar(): - SensitiveConfiguration = strawberry.scalar(str, name="SensitiveConfiguration") + from typing import NewType + + SensitiveConfiguration = NewType("SensitiveConfiguration", str) @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) class Sensitive: @@ -380,7 +382,15 @@ class Query: scalar SensitiveConfiguration """ - schema = strawberry.Schema(query=Query) + schema = strawberry.Schema( + query=Query, + types=[SensitiveConfiguration], + config=StrawberryConfig( + scalar_map={ + SensitiveConfiguration: strawberry.scalar(name="SensitiveConfiguration") + } + ), + ) assert print_schema(schema) == textwrap.dedent(expected_output).strip() @@ -442,13 +452,13 @@ class Query: def test_print_directive_on_scalar(): + from typing import NewType + @strawberry.schema_directive(locations=[Location.SCALAR]) class Sensitive: reason: str - SensitiveString = strawberry.scalar( - str, name="SensitiveString", directives=[Sensitive(reason="example")] - ) + SensitiveString = NewType("SensitiveString", str) @strawberry.type class Query: @@ -464,7 +474,16 @@ class Query: scalar SensitiveString @sensitive(reason: "example") """ - schema = strawberry.Schema(query=Query) + schema = strawberry.Schema( + query=Query, + config=StrawberryConfig( + scalar_map={ + SensitiveString: strawberry.scalar( + name="SensitiveString", directives=[Sensitive(reason="example")] + ) + } + ), + ) assert print_schema(schema) == textwrap.dedent(expected_output).strip() diff --git a/tests/typecheckers/test_scalars.py b/tests/typecheckers/test_scalars.py index 7bb5c8521b..de8fa28da3 100644 --- a/tests/typecheckers/test_scalars.py +++ b/tests/typecheckers/test_scalars.py @@ -149,10 +149,10 @@ def test(): CODE_SCHEMA_OVERRIDES = """ import strawberry -from datetime import datetime, timezone +from datetime import datetime EpochDateTime = strawberry.scalar( - datetime, + name="EpochDateTime", ) @strawberry.type @@ -175,7 +175,7 @@ def test_schema_overrides(): [ Result( type="information", - message='Type of "EpochDateTime" is "type[datetime]"', + message='Type of "EpochDateTime" is "ScalarDefinition"', line=16, column=13, ) @@ -185,7 +185,7 @@ def test_schema_overrides(): [ Result( type="note", - message='Revealed type is "def (year: typing.SupportsIndex, month: typing.SupportsIndex, day: typing.SupportsIndex, hour: typing.SupportsIndex =, minute: typing.SupportsIndex =, second: typing.SupportsIndex =, microsecond: typing.SupportsIndex =, tzinfo: datetime.tzinfo | None =, *, fold: builtins.int =) -> datetime.datetime"', + message='Revealed type is "strawberry.types.scalar.ScalarDefinition"', line=17, column=13, ) @@ -195,7 +195,7 @@ def test_schema_overrides(): [ Result( type="information", - message="Revealed type: ``", + message="Revealed type: `ScalarDefinition`", line=17, column=13, ), diff --git a/tests/utils/test_locate_definition.py b/tests/utils/test_locate_definition.py index 18f197c8f3..262630b5ff 100644 --- a/tests/utils/test_locate_definition.py +++ b/tests/utils/test_locate_definition.py @@ -24,7 +24,7 @@ def test_find_model_name() -> None: result = locate_definition(schema, "User") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:38:7" + "fixtures/sample_package/sample_module.py:35:7" ) @@ -35,7 +35,7 @@ def test_find_model_name_enum() -> None: result = locate_definition(schema, "Role") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:32:7" + "fixtures/sample_package/sample_module.py:29:7" ) @@ -46,7 +46,7 @@ def test_find_model_name_scalar() -> None: result = locate_definition(schema, "ExampleScalar") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:7:13" + "fixtures/sample_package/sample_module.py:57:21" ) @@ -57,7 +57,7 @@ def test_find_model_field() -> None: result = locate_definition(schema, "User.name") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:39:5" + "fixtures/sample_package/sample_module.py:36:5" ) @@ -68,7 +68,7 @@ def test_find_model_field_scalar() -> None: result = locate_definition(schema, "User.example_scalar") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:42:5" + "fixtures/sample_package/sample_module.py:39:5" ) @@ -79,7 +79,7 @@ def test_find_model_field_with_resolver() -> None: result = locate_definition(schema, "Query.user") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:50:5" + "fixtures/sample_package/sample_module.py:47:5" ) @@ -108,7 +108,7 @@ def test_find_union() -> None: result = locate_definition(schema, "UnionExample") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:23:16" + "fixtures/sample_package/sample_module.py:20:16" ) @@ -119,5 +119,5 @@ def test_find_inline_union() -> None: result = locate_definition(schema, "InlineUnion") assert _simplify_path(result) == snapshot( - "fixtures/sample_package/sample_module.py:44:19" + "fixtures/sample_package/sample_module.py:41:19" )