diff --git a/spring-beans/src/main/java/org/springframework/beans/BeanUtils.java b/spring-beans/src/main/java/org/springframework/beans/BeanUtils.java index 152c6feaf3f9..f856a68b8e0d 100644 --- a/spring-beans/src/main/java/org/springframework/beans/BeanUtils.java +++ b/spring-beans/src/main/java/org/springframework/beans/BeanUtils.java @@ -37,6 +37,7 @@ import kotlin.reflect.KClass; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; +import kotlin.reflect.KType; import kotlin.reflect.full.KClasses; import kotlin.reflect.jvm.KCallablesJvm; import kotlin.reflect.jvm.ReflectJvmMapping; @@ -924,12 +925,36 @@ public static T instantiateClass(Constructor ctor, @Nullable Object... ar Map argParameters = CollectionUtils.newHashMap(parameters.size()); for (int i = 0 ; i < args.length ; i++) { if (!(parameters.get(i).isOptional() && args[i] == null)) { - argParameters.put(parameters.get(i), args[i]); + Object arg = args[i]; + KType type = parameters.get(i).getType(); + if (!(type.isMarkedNullable() && arg == null) && + type.getClassifier() instanceof KClass kClass && + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass)) && + !JvmClassMappingKt.getJavaClass(kClass).isInstance(arg)) { + arg = box(kClass, arg); + } + argParameters.put(parameters.get(i), arg); } } return kotlinConstructor.callBy(argParameters); } + private static Object box(KClass kClass, @Nullable Object arg) { + KFunction constructor = KClasses.getPrimaryConstructor(kClass); + Assert.state(constructor != null, + "Kotlin value classes annotated with @JvmInline are expected to have a single JVM constructor"); + KType type = constructor.getParameters().get(0).getType(); + if (!(type.isMarkedNullable() && arg == null) && + type.getClassifier() instanceof KClass parameterClass && + KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass))) { + arg = box(parameterClass, arg); + } + if (!KCallablesJvm.isAccessible(constructor)) { + KCallablesJvm.setAccessible(constructor, true); + } + return constructor.call(arg); + } + public static boolean hasDefaultConstructorMarker(Constructor ctor) { int parameterCount = ctor.getParameterCount(); return parameterCount > 0 && ctor.getParameters()[parameterCount -1].getType() == DefaultConstructorMarker.class; diff --git a/spring-beans/src/main/java/org/springframework/beans/TypeConverterDelegate.java b/spring-beans/src/main/java/org/springframework/beans/TypeConverterDelegate.java index 8455682b949f..93556b85824f 100644 --- a/spring-beans/src/main/java/org/springframework/beans/TypeConverterDelegate.java +++ b/spring-beans/src/main/java/org/springframework/beans/TypeConverterDelegate.java @@ -31,6 +31,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.CollectionFactory; +import org.springframework.core.KotlinDetector; import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.ConversionService; import org.springframework.core.convert.TypeDescriptor; @@ -238,6 +239,12 @@ else if (convertedValue instanceof Number num && Number.class.isAssignableFrom(r } } + if (!ClassUtils.isAssignableValue(requiredType, convertedValue)) { + if (convertedValue != null && KotlinDetector.isInlineClass(requiredType)) { + convertedValue = convertToInlineClass(propertyName, oldValue, convertedValue, requiredType); + } + } + if (!ClassUtils.isAssignableValue(requiredType, convertedValue)) { if (conversionAttemptEx != null) { // Original exception from former ConversionService call above... @@ -283,6 +290,18 @@ else if (conversionService != null && typeDescriptor != null) { return (T) convertedValue; } + private Object convertToInlineClass(@Nullable String propertyName, @Nullable Object oldValue, + Object newValue, Class requiredType) { + + Constructor constructor = BeanUtils.findPrimaryConstructor(requiredType); + if (constructor == null || constructor.getParameterCount() != 1) { + return newValue; + } + Object constructorArgument = convertIfNecessary( + propertyName, oldValue, newValue, constructor.getParameterTypes()[0]); + return BeanUtils.instantiateClass(constructor, constructorArgument); + } + private Object attemptToConvertStringToEnum(Class requiredType, String trimmedValue, Object currentConvertedValue) { Object convertedValue = currentConvertedValue; diff --git a/spring-context/src/main/java/org/springframework/validation/DataBinder.java b/spring-context/src/main/java/org/springframework/validation/DataBinder.java index bd804bfe6479..dae6a7bc0471 100644 --- a/spring-context/src/main/java/org/springframework/validation/DataBinder.java +++ b/spring-context/src/main/java/org/springframework/validation/DataBinder.java @@ -910,8 +910,8 @@ public void construct(ValueResolver valueResolver) { else { // A single data class constructor -> resolve constructor arguments from request parameters. @Nullable String[] paramNames = BeanUtils.getParameterNames(ctor); - Class[] paramTypes = ctor.getParameterTypes(); - @Nullable Object[] args = new Object[paramTypes.length]; + Class[] paramTypes = new Class[paramNames.length]; + @Nullable Object[] args = new Object[paramNames.length]; Set failedParamNames = new HashSet<>(4); for (int i = 0; i < paramNames.length; i++) { @@ -925,7 +925,8 @@ public void construct(ValueResolver valueResolver) { } String paramPath = nestedPath + lookupName; - Class paramType = paramTypes[i]; + Class paramType = param.getParameterType(); + paramTypes[i] = paramType; ResolvableType resolvableType = ResolvableType.forMethodParameter(param); Object value = valueResolver.resolveValue(paramPath, paramType); @@ -1008,7 +1009,8 @@ else if (paramType.isArray()) { */ protected boolean shouldConstructArgument(MethodParameter param) { Class type = param.nestedIfOptional().getNestedParameterType(); - return !BeanUtils.isSimpleValueType(type) && !type.getPackageName().startsWith("java."); + return !BeanUtils.isSimpleValueType(type) && !KotlinDetector.isInlineClass(type) && + !type.getPackageName().startsWith("java."); } private boolean hasValuesFor(String paramPath, ValueResolver resolver) { diff --git a/spring-context/src/test/kotlin/org/springframework/validation/DataBinderKotlinValueClassTests.kt b/spring-context/src/test/kotlin/org/springframework/validation/DataBinderKotlinValueClassTests.kt new file mode 100644 index 000000000000..94d6b7ba449b --- /dev/null +++ b/spring-context/src/test/kotlin/org/springframework/validation/DataBinderKotlinValueClassTests.kt @@ -0,0 +1,114 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.validation + +import org.assertj.core.api.Assertions +import org.junit.jupiter.api.Test +import org.springframework.core.ResolvableType +import org.springframework.format.support.DefaultFormattingConversionService +import java.util.UUID + +/** + * Tests for [DataBinder] constructor binding with Kotlin value classes. + */ +class DataBinderKotlinValueClassTests { + + @Test + fun constructDataClassWithStringValueClass() { + val binder = createDataBinder(StringValueClassRecord::class.java) + binder.construct(TestValueResolver(mapOf("title" to "hello"))) + + Assertions.assertThat(getTarget(binder).title).isEqualTo(Title("hello")) + } + + @Test + fun constructDataClassWithLongValueClass() { + val binder = createDataBinder(LongValueClassRecord::class.java) + binder.construct(TestValueResolver(mapOf("userId" to "1"))) + + Assertions.assertThat(getTarget(binder).userId).isEqualTo(UserId(1)) + } + + @Test + fun constructDataClassWithUuidValueClass() { + val uuid = UUID.randomUUID() + val binder = createDataBinder(UuidValueClassRecord::class.java) + binder.construct(TestValueResolver(mapOf("orderId" to uuid.toString()))) + + Assertions.assertThat(getTarget(binder).orderId).isEqualTo(OrderId(uuid)) + } + + @Test + fun constructDataClassWithNullablePrimitiveValueClass() { + val binder = createDataBinder(NullablePrimitiveValueClassRecord::class.java) + binder.construct(TestValueResolver(mapOf("userId" to "1"))) + + Assertions.assertThat(getTarget(binder).userId).isEqualTo(NullableUserId(1)) + } + + @Test + fun constructDataClassWithNestedValueClass() { + val binder = createDataBinder(NestedValueClassRecord::class.java) + binder.construct(TestValueResolver(mapOf("title" to "hello"))) + + Assertions.assertThat(getTarget(binder).title).isEqualTo(NestedTitle(Title("hello"))) + } + + private fun createDataBinder(targetType: Class<*>): DataBinder { + val binder = DataBinder(null) + binder.targetType = ResolvableType.forClass(targetType) + binder.conversionService = DefaultFormattingConversionService() + return binder + } + + private inline fun getTarget(dataBinder: DataBinder): T { + Assertions.assertThat(dataBinder.bindingResult.allErrors).isEmpty() + return dataBinder.target as T + } + + private data class StringValueClassRecord(val title: Title) + + @JvmInline + value class Title(val value: String) + + private data class LongValueClassRecord(val userId: UserId) + + @JvmInline + value class UserId(val value: Long) + + private data class UuidValueClassRecord(val orderId: OrderId) + + @JvmInline + value class OrderId(val value: UUID) + + private data class NullablePrimitiveValueClassRecord(val userId: NullableUserId?) + + @JvmInline + value class NullableUserId(val value: Long) + + private data class NestedValueClassRecord(val title: NestedTitle) + + @JvmInline + value class NestedTitle(val value: Title) + + private class TestValueResolver(private val values: Map) : DataBinder.ValueResolver { + + override fun resolveValue(name: String, type: Class<*>): Any? = this.values[name] + + override fun getNames(): Set = this.values.keys + } +} diff --git a/spring-web/src/test/kotlin/org/springframework/web/method/annotation/ModelAttributeMethodProcessorKotlinTests.kt b/spring-web/src/test/kotlin/org/springframework/web/method/annotation/ModelAttributeMethodProcessorKotlinTests.kt index a369ceda3abb..4160e1aaab73 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/method/annotation/ModelAttributeMethodProcessorKotlinTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/method/annotation/ModelAttributeMethodProcessorKotlinTests.kt @@ -27,6 +27,7 @@ import org.mockito.Mockito.mock import org.springframework.core.MethodParameter import org.springframework.core.ResolvableType import org.springframework.core.annotation.SynthesizingMethodParameter +import org.springframework.format.support.DefaultFormattingConversionService import org.springframework.web.bind.MethodArgumentNotValidException import org.springframework.web.bind.support.WebDataBinderFactory import org.springframework.web.bind.support.WebRequestDataBinder @@ -47,6 +48,8 @@ class ModelAttributeMethodProcessorKotlinTests { private lateinit var param: MethodParameter + private lateinit var valueClassParam: MethodParameter + @BeforeEach fun setup() { @@ -54,6 +57,9 @@ class ModelAttributeMethodProcessorKotlinTests { processor = ModelAttributeMethodProcessor(false) val method = ModelAttributeHandler::class.java.getDeclaredMethod("test", Param::class.java) param = SynthesizingMethodParameter(method, 0) + + val valueClassMethod = ModelAttributeHandler::class.java.getDeclaredMethod("valueClass", ValueClassParam::class.java) + valueClassParam = SynthesizingMethodParameter(valueClassMethod, 0) } @Test @@ -87,11 +93,36 @@ class ModelAttributeMethodProcessorKotlinTests { .hasMessageContaining("parameter a") } + @Test + fun resolveArgumentWithValueClass() { + val mockRequest = MockHttpServletRequest().apply { addParameter("id", "1") } + val requestWithParam = ServletWebRequest(mockRequest) + val factory = mock() + given(factory.createBinder(any(), any(), eq("valueClassParam"), any())) + .willAnswer { + val binder = WebRequestDataBinder(it.getArgument(1)) + binder.setTargetType(ResolvableType.forMethodParameter(this.valueClassParam)) + binder.conversionService = DefaultFormattingConversionService() + binder + } + + assertThat(processor.resolveArgument(this.valueClassParam, container, requestWithParam, factory)) + .isEqualTo(ValueClassParam(ValueClass(1))) + } + private data class Param(val a: String) + private data class ValueClassParam(val id: ValueClass) + + @JvmInline + value class ValueClass(val value: Long) + private class ModelAttributeHandler { @Suppress("UNUSED_PARAMETER") fun test(param: Param) { } + + @Suppress("UNUSED_PARAMETER") + fun valueClass(valueClassParam: ValueClassParam) { } } }