diff --git a/docs/integrations/pydantic.md b/docs/integrations/pydantic.md index e1c1967772..7db8a0cba1 100644 --- a/docs/integrations/pydantic.md +++ b/docs/integrations/pydantic.md @@ -97,6 +97,60 @@ class UserInput: friends: strawberry.auto ``` +## Controlling omission semantics with `use_pydantic_default` + +`use_pydantic_default` on `@strawberry.experimental.pydantic.input` determines +how omitted GraphQL fields are represented. + +- `True` (default): omitted fields use the Pydantic model's default or + `default_factory`. +- `False`: omitted fields become `strawberry.UNSET`, allowing omission to be + distinguished from `null` and explicit values. + +| GraphQL input | True (default) | False | +| -------------- | ------------------------ | --------- | +| omitted | pydantic default applied | `UNSET` | +| provided value | unchanged | unchanged | +| null | `None` | `None` | + +When `False`, `UNSET` values remain on the Strawberry input and are not passed +to the Pydantic constructor, enabling patch-style updates. + +```python +from __future__ import annotations + +import pydantic +import strawberry +from strawberry import UNSET + + +class UserModel(pydantic.BaseModel): + name: str + interests: list[str] | None = pydantic.Field(default_factory=list) + + +@strawberry.experimental.pydantic.input(model=UserModel, use_pydantic_default=False) +class UpdateUserInput: + name: strawberry.auto + interests: strawberry.auto + + +@strawberry.type +class Mutation: + @strawberry.mutation + async def update_user(self, user_data: UpdateUserInput) -> str: + changes: dict[str, object] = {} + if user_data.name is not UNSET: + changes["name"] = user_data.name + if user_data.interests is not UNSET: + changes["interests"] = user_data.interests + + current = UserModel(name="Alice", interests=["games"]) + updated = current.model_copy(update=changes) + + return f"changes={changes} before={current.model_dump()} after={updated.model_dump()}" +``` + ## Interface types Interface types are similar to normal types; we can create one by using the diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index e529caa7eb..be5f659a66 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -32,6 +32,7 @@ from strawberry.types.field import StrawberryField from strawberry.types.object_type import _process_type, _wrap_dataclass from strawberry.types.type_resolver import _get_fields +from strawberry.types.unset import UNSET if TYPE_CHECKING: import builtins @@ -61,6 +62,8 @@ def _build_dataclass_creation_fields( auto_fields_set: set[str], use_pydantic_alias: bool, compat: PydanticCompat, + *, + use_pydantic_default: bool, ) -> DataclassCreationFields: field_type = ( get_type_for_field(field, is_input, compat=compat) @@ -83,12 +86,18 @@ def _build_dataclass_creation_fields( elif field.has_alias and use_pydantic_alias: graphql_name = field.alias + # for inputs with use_pydantic_default, default_factory should be used + if is_input and not use_pydantic_default: + default_factory = UNSET + else: + default_factory = get_default_factory_for_field(field, compat=compat) + strawberry_field = StrawberryField( python_name=field.name, graphql_name=graphql_name, # always unset because we use default_factory instead default=dataclasses.MISSING, - default_factory=get_default_factory_for_field(field, compat=compat), + default_factory=default_factory, type_annotation=StrawberryAnnotation.from_annotation(field_type), description=field.description, deprecation_reason=( @@ -127,6 +136,7 @@ def type( all_fields: bool = False, include_computed: bool = False, use_pydantic_alias: bool = True, + use_pydantic_default: bool = True, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: compat = PydanticCompat.from_model(model) @@ -192,6 +202,7 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: auto_fields_set, use_pydantic_alias, compat=compat, + use_pydantic_default=use_pydantic_default, ) for field_name, field in model_fields.items() if field_name in fields_set @@ -280,12 +291,15 @@ def from_pydantic_default( return ret def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel: - instance_kwargs = { - f.name: convert_strawberry_class_to_pydantic_model( - getattr(self, f.name) + # when preserving omission on inputs, drop UNSET fields + instance_kwargs = {} + for f in dataclasses.fields(self): + value = getattr(self, f.name) + if is_input and value is UNSET and not use_pydantic_default: + continue + instance_kwargs[f.name] = convert_strawberry_class_to_pydantic_model( + value ) - for f in dataclasses.fields(self) - } instance_kwargs.update(kwargs) return model(**instance_kwargs) @@ -309,12 +323,21 @@ def input( directives: Sequence[object] | None = (), all_fields: bool = False, use_pydantic_alias: bool = True, + use_pydantic_default: bool = True, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an input type from a Pydantic model. Equal to `partial(type, is_input=True)` See https://github.com/strawberry-graphql/strawberry/issues/1830. + + Parameters + ---------- + use_pydantic_default: + When False, fields omitted by the GraphQL client are represented as + :data:`strawberry.UNSET` on the generated input class, instead of + materialising the Pydantic default or default_factory. This enables + true tri-state semantics (omitted vs. null vs. value) for inputs. """ return type( model=model, @@ -326,6 +349,7 @@ def input( directives=directives, all_fields=all_fields, use_pydantic_alias=use_pydantic_alias, + use_pydantic_default=use_pydantic_default, ) diff --git a/strawberry/types/unset.py b/strawberry/types/unset.py index 2f28d65ec0..24534a26d2 100644 --- a/strawberry/types/unset.py +++ b/strawberry/types/unset.py @@ -16,6 +16,9 @@ def __new__(cls: type["UnsetType"]) -> "UnsetType": return ret return cls.__instance + def __call__(self) -> "UnsetType": + return self + def __str__(self) -> str: return "" diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 1772bc8a63..5e5b4631df 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -1331,3 +1331,85 @@ def user(self) -> User: assert not result.errors assert result.data["user"] == {"age": 20, "location": "earth"} + + +@pytest.mark.parametrize( + "use_default, provided_interests, expected_raw, expected_pydantic", + [ + # use_pydantic_default=False: omitted results in UNSET, no pydantic default + ( + False, + strawberry.UNSET, + {"name": "John", "interests": strawberry.UNSET}, + {"name": "John", "interests": []}, + ), + # use_pydantic_default=False: provided list passed through + ( + False, + ["games"], + {"name": "John", "interests": ["games"]}, + {"name": "John", "interests": ["games"]}, + ), + # use_pydantic_default=False: provided None passed as None + ( + False, + None, + {"name": "John", "interests": None}, + {"name": "John", "interests": None}, + ), + # use_pydantic_default=True: omitted, default_factory=list is applied (not UNSET) + ( + True, + strawberry.UNSET, + {"name": "John", "interests": []}, + {"name": "John", "interests": []}, + ), + # use_pydantic_default=True: provided list unchanged + ( + True, + ["games"], + {"name": "John", "interests": ["games"]}, + {"name": "John", "interests": ["games"]}, + ), + # use_pydantic_default=True: provided None unchanged + ( + True, + None, + {"name": "John", "interests": None}, + {"name": "John", "interests": None}, + ), + ], +) +def test_input_use_pydantic_default_parameterized( + use_default, + provided_interests, + expected_raw, + expected_pydantic, +): + class UserModel(BaseModel): + name: str + interests: list[str] | None = Field(default_factory=list) + + @strawberry.experimental.pydantic.input( + UserModel, + use_pydantic_default=use_default, + ) + class UpdateUserInput: + name: strawberry.auto + interests: strawberry.auto + + if provided_interests is strawberry.UNSET: + data = UpdateUserInput(name="John") + else: + data = UpdateUserInput(name="John", interests=provided_interests) + + raw = strawberry.asdict(data) + assert raw["name"] == expected_raw["name"] + + if expected_raw["interests"] is strawberry.UNSET: + assert raw["interests"] is strawberry.UNSET + else: + assert raw["interests"] == expected_raw["interests"] + + p = data.to_pydantic().model_dump() + assert p == expected_pydantic diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index b5af1ad396..c6aed4cb82 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -229,3 +229,19 @@ class Type: assert field.python_name == "field" assert field.type == Literal["field"] + + +def test_input_use_pydantic_default_false_field_types(): + class UserModel(pydantic.BaseModel): + name: str + interests: list[str] | None = pydantic.Field(default_factory=list) + + @strawberry.experimental.pydantic.input(UserModel, use_pydantic_default=False) + class UpdateUserInput: + name: strawberry.auto + interests: strawberry.auto + + definition = UpdateUserInput.__strawberry_definition__ + fields = {f.python_name: f for f in definition.fields} + + assert isinstance(fields["interests"].type, StrawberryOptional) diff --git a/tests/schema/test_pydantic.py b/tests/schema/test_pydantic.py index 6e905b2403..e5c64acbcf 100644 --- a/tests/schema/test_pydantic.py +++ b/tests/schema/test_pydantic.py @@ -67,3 +67,94 @@ class Query: assert not result.errors assert result.data["user"] == {"__typename": "User", "age_": 5} + + +@pytest.mark.parametrize( + "use_pydantic_default, expected_raw_interests, expected_pydantic", + [ + (False, "UNSET", {"name": "John", "interests": []}), + (True, [], {"name": "John", "interests": []}), + ], +) +def test_graphql_input_use_pydantic_default_integration( + use_pydantic_default, + expected_raw_interests, + expected_pydantic, +): + from pydantic import BaseModel, Field + + class UserModel(BaseModel): + name: str + interests: list[str] | None = Field(default_factory=list) + + @strawberry.experimental.pydantic.input( + UserModel, + use_pydantic_default=use_pydantic_default, + ) + class UpdateUserInput: + name: strawberry.auto + interests: strawberry.auto + + @strawberry.type + class RawResult: + name: str + interests: strawberry.scalars.JSON + + @strawberry.type + class PydanticResult: + name: str + interests: strawberry.scalars.JSON | None + + @strawberry.type + class UpdateResult: + raw: RawResult + pydantic: PydanticResult + + @strawberry.type + class Mutation: + @strawberry.mutation + def update_user(self, user_data: UpdateUserInput) -> UpdateResult: + raw_dict = strawberry.asdict(user_data) + p_dict = user_data.to_pydantic().model_dump() + + # JSON-friendly representation + raw_interests = raw_dict["interests"] + if raw_interests is strawberry.UNSET: + raw_interests = "UNSET" + + return UpdateResult( + raw=RawResult( + name=raw_dict["name"], + interests=raw_interests, + ), + pydantic=PydanticResult( + name=p_dict["name"], + interests=p_dict.get("interests"), + ), + ) + + @strawberry.type + class Query: + ok: bool + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + query = """ + mutation { + updateUser(userData: { name: "John" }) { + raw { name interests } + pydantic { name interests } + } + } + """ + + result = schema.execute_sync(query) + assert not result.errors + + raw = result.data["updateUser"]["raw"] + pyd = result.data["updateUser"]["pydantic"] + + assert raw["name"] == "John" + assert raw["interests"] == expected_raw_interests + + assert pyd == expected_pydantic