diff --git a/.gitignore b/.gitignore index e0280c51b9..b07f42fc8c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ /env/ MANIFEST coverage.* +venv/ .coverage .cache/ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5f34b00194..b60ebe75b5 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -664,7 +664,35 @@ def run_child_validation(self, data): self.child.initial_data = data return super().run_child_validation(data) """ - return self.child.run_validation(data) + if not hasattr(self.child, 'instance'): + return self.child.run_validation(data) + + if not ( + hasattr(self, '_list_serializer_instance_map') and + isinstance(data, Mapping) + ): + return self.child.run_validation(data) + + lookup_field = getattr(getattr(self.child, 'Meta', None), 'lookup_field', None) + data_pk = data.get(lookup_field) + if data_pk is None: + data_pk = data.get('id') + if data_pk is None: + data_pk = data.get('pk') + + if data_pk is None: + return self.child.run_validation(data) + + child_instance = self._list_serializer_instance_map.get(str(data_pk)) + if child_instance is None: + return self.child.run_validation(data) + + original_instance = self.child.instance + try: + self.child.instance = child_instance + return self.child.run_validation(data) + finally: + self.child.instance = original_instance def to_internal_value(self, data): """ @@ -674,12 +702,16 @@ def to_internal_value(self, data): data = html.parse_html_list(data, default=[]) if not isinstance(data, list): - message = self.error_messages['not_a_list'].format( - input_type=type(data).__name__ - ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] - }, code='not_a_list') + api_settings.NON_FIELD_ERRORS_KEY: [ + ErrorDetail( + self.error_messages['not_a_list'].format( + input_type=type(data).__name__ + ), + code='not_a_list' + ) + ] + }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] @@ -702,19 +734,49 @@ def to_internal_value(self, data): ret = [] errors = [] - for item in data: - try: - validated = self.run_child_validation(item) - except ValidationError as exc: - errors.append(exc.detail) - else: - ret.append(validated) - errors.append({}) + # Build a primary key lookup for instance matching in many=True updates. + instance_map = None + if self.instance is not None: + if isinstance(self.instance, Mapping): + instance_map = {str(k): v for k, v in self.instance.items()} + elif isinstance(self.instance, (list, tuple, models.query.QuerySet)): + instance_map = {} + lookup_field = getattr(getattr(self.child, 'Meta', None), 'lookup_field', None) + + for obj in self.instance: + if lookup_field is not None: + pk = getattr(obj, lookup_field, None) + else: + pk = getattr(obj, 'pk', None) + if pk is None: + pk = getattr(obj, 'id', None) - if any(errors): - raise ValidationError(errors) + if pk is not None: + key = str(pk) + # If duplicate keys are present, keep the last value, + # matching standard mapping assignment behavior. + instance_map[key] = obj - return ret + if instance_map is not None: + self._list_serializer_instance_map = instance_map + + try: + for item in data: + try: + validated = self.run_child_validation(item) + except ValidationError as exc: + errors.append(exc.detail) + else: + ret.append(validated) + errors.append({}) + + if any(errors): + raise ValidationError(errors) + + return ret + finally: + if hasattr(self, '_list_serializer_instance_map'): + delattr(self, '_list_serializer_instance_map') def to_representation(self, data): """ @@ -749,6 +811,13 @@ def save(self, **kwargs): """ Save and return a list of object instances. """ + assert hasattr(self, '_errors'), ( + 'You must call `.is_valid()` before calling `.save()`.' + ) + assert not self.errors, ( + 'You cannot call `.save()` on a serializer with invalid data.' + ) + # Guard against incorrect use of `serializer.save(commit=False)` assert 'commit' not in kwargs, ( "'commit' is not a valid keyword argument to the 'save()' method. " @@ -756,9 +825,13 @@ def save(self, **kwargs): "inspect 'serializer.validated_data' instead. " "You can also pass additional keyword arguments to 'save()' if you " "need to set extra attributes on the saved model instance. " - "For example: 'serializer.save(owner=request.user)'.'" + "For example: 'serializer.save(owner=request.user)'." + ) + assert not hasattr(self, '_data'), ( + "You cannot call `.save()` after accessing `serializer.data`." + "If you need to access data before committing to the database then " + "inspect 'serializer.validated_data' instead. " ) - validated_data = [ {**attrs, **kwargs} for attrs in self.validated_data ] diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index f690559a8a..1066cd49bf 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -203,6 +203,155 @@ def update(self, instance, validated_data): assert updated_instances == expected_output +class TestListSerializerInstanceMatching: + def test_matching_with_id(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField() + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + instance = [ + BasicObject(id=1), + BasicObject(id=2), + ] + input_data = [ + {'id': 1}, + {'id': 2}, + ] + + serializer = TestSerializer(instance, data=input_data, many=True) + assert serializer.is_valid() + assert seen_instances == instance + + def test_matching_with_pk(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + pk = serializers.IntegerField() + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + instance = [ + BasicObject(pk=1), + BasicObject(pk=2), + ] + input_data = [ + {'pk': 1}, + {'pk': 2}, + ] + + serializer = TestSerializer(instance, data=input_data, many=True) + assert serializer.is_valid() + assert seen_instances == instance + + def test_matching_with_id_against_object_with_pk_only(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField() + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + instance = [BasicObject(pk=1)] + input_data = [{'id': 1}] + + serializer = TestSerializer(instance, data=input_data, many=True) + assert serializer.is_valid() + assert seen_instances == instance + + def test_mapping_instance_matching(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField() + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + obj1 = BasicObject(id=1) + obj2 = BasicObject(id=2) + instance = { + '1': obj1, + '2': obj2, + } + input_data = [ + {'id': 1}, + {'id': 2}, + ] + + serializer = TestSerializer(instance, data=input_data, many=True) + assert serializer.is_valid() + assert seen_instances == [obj1, obj2] + + def test_unsupported_instance_type_preserves_original_behavior(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField() + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + serializer = TestSerializer(instance=123, data=[{'id': 1}], many=True) + assert serializer.is_valid() + assert seen_instances == [123] + + def test_missing_lookup_field_in_data_does_not_assign_instance(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField(required=False) + + class Meta: + lookup_field = 'uuid' + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + class TestListSerializer(serializers.ListSerializer): + child = TestSerializer() + + serializer = TestListSerializer( + instance=[BasicObject(id=1, uuid='uuid-1')], + data=[{'id': 1}], + ) + assert serializer.is_valid() + assert seen_instances == [None] + + def test_matching_with_configurable_lookup_field(self): + seen_instances = [] + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField(required=False) + uuid = serializers.CharField() + + class Meta: + lookup_field = 'uuid' + + def validate(self, attrs): + seen_instances.append(self.instance) + return attrs + + obj1 = BasicObject(id=1, uuid='uuid-1') + obj2 = BasicObject(id=2, uuid='uuid-2') + input_data = [{'id': 1, 'uuid': 'uuid-2'}] + + serializer = TestSerializer([obj1, obj2], data=input_data, many=True) + assert serializer.is_valid() + assert seen_instances == [obj2] + + class TestNestedListSerializer: """ Tests for using a ListSerializer as a field. @@ -883,3 +1032,32 @@ def test(self): queryset = NullableOneToOneSource.objects.all() serializer = self.serializer(queryset, many=True) assert serializer.data + + +def test_many_true_instance_level_validation_uses_matched_instance(): + class Obj: + def __init__(self, id, valid): + self.id = id + self.valid = valid + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField() + status = serializers.CharField() + + def validate_status(self, value): + if self.instance is None: + raise serializers.ValidationError("Instance not matched") + if not self.instance.valid: + raise serializers.ValidationError("Invalid instance") + return value + + objs = [Obj(1, True), Obj(2, False)] + serializer = TestSerializer( + instance=objs, + data=[{"id": 1, "status": "ok"}, {"id": 2, "status": "fail"}], + many=True, + partial=True, + ) + + assert not serializer.is_valid() + assert serializer.errors == [{}, {'status': ['Invalid instance']}]