Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import kotlin.reflect.KClass;
import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter;
import kotlin.reflect.KType;
import kotlin.reflect.full.KClasses;
import kotlin.reflect.jvm.KCallablesJvm;
import kotlin.reflect.jvm.ReflectJvmMapping;
Expand Down Expand Up @@ -924,12 +925,36 @@ public static <T> T instantiateClass(Constructor<T> ctor, @Nullable Object... ar
Map<KParameter, Object> argParameters = CollectionUtils.newHashMap(parameters.size());
for (int i = 0 ; i < args.length ; i++) {
if (!(parameters.get(i).isOptional() && args[i] == null)) {
argParameters.put(parameters.get(i), args[i]);
Object arg = args[i];
KType type = parameters.get(i).getType();
if (!(type.isMarkedNullable() && arg == null) &&
type.getClassifier() instanceof KClass<?> kClass &&
KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(kClass)) &&
!JvmClassMappingKt.getJavaClass(kClass).isInstance(arg)) {
arg = box(kClass, arg);
}
argParameters.put(parameters.get(i), arg);
}
}
return kotlinConstructor.callBy(argParameters);
}

private static Object box(KClass<?> kClass, @Nullable Object arg) {
KFunction<?> constructor = KClasses.getPrimaryConstructor(kClass);
Assert.state(constructor != null,
"Kotlin value classes annotated with @JvmInline are expected to have a single JVM constructor");
KType type = constructor.getParameters().get(0).getType();
if (!(type.isMarkedNullable() && arg == null) &&
type.getClassifier() instanceof KClass<?> parameterClass &&
KotlinDetector.isInlineClass(JvmClassMappingKt.getJavaClass(parameterClass))) {
arg = box(parameterClass, arg);
}
if (!KCallablesJvm.isAccessible(constructor)) {
KCallablesJvm.setAccessible(constructor, true);
}
return constructor.call(arg);
}

public static boolean hasDefaultConstructorMarker(Constructor<?> ctor) {
int parameterCount = ctor.getParameterCount();
return parameterCount > 0 && ctor.getParameters()[parameterCount -1].getType() == DefaultConstructorMarker.class;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.jspecify.annotations.Nullable;

import org.springframework.core.CollectionFactory;
import org.springframework.core.KotlinDetector;
import org.springframework.core.convert.ConversionFailedException;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.TypeDescriptor;
Expand Down Expand Up @@ -238,6 +239,12 @@ else if (convertedValue instanceof Number num && Number.class.isAssignableFrom(r
}
}

if (!ClassUtils.isAssignableValue(requiredType, convertedValue)) {
if (convertedValue != null && KotlinDetector.isInlineClass(requiredType)) {
convertedValue = convertToInlineClass(propertyName, oldValue, convertedValue, requiredType);
}
}

if (!ClassUtils.isAssignableValue(requiredType, convertedValue)) {
if (conversionAttemptEx != null) {
// Original exception from former ConversionService call above...
Expand Down Expand Up @@ -283,6 +290,18 @@ else if (conversionService != null && typeDescriptor != null) {
return (T) convertedValue;
}

private <T> Object convertToInlineClass(@Nullable String propertyName, @Nullable Object oldValue,
Object newValue, Class<T> requiredType) {

Constructor<T> constructor = BeanUtils.findPrimaryConstructor(requiredType);
if (constructor == null || constructor.getParameterCount() != 1) {
return newValue;
}
Object constructorArgument = convertIfNecessary(
propertyName, oldValue, newValue, constructor.getParameterTypes()[0]);
return BeanUtils.instantiateClass(constructor, constructorArgument);
}

private Object attemptToConvertStringToEnum(Class<?> requiredType, String trimmedValue, Object currentConvertedValue) {
Object convertedValue = currentConvertedValue;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -910,8 +910,8 @@ public void construct(ValueResolver valueResolver) {
else {
// A single data class constructor -> resolve constructor arguments from request parameters.
@Nullable String[] paramNames = BeanUtils.getParameterNames(ctor);
Class<?>[] paramTypes = ctor.getParameterTypes();
@Nullable Object[] args = new Object[paramTypes.length];
Class<?>[] paramTypes = new Class<?>[paramNames.length];
@Nullable Object[] args = new Object[paramNames.length];
Set<String> failedParamNames = new HashSet<>(4);

for (int i = 0; i < paramNames.length; i++) {
Expand All @@ -925,7 +925,8 @@ public void construct(ValueResolver valueResolver) {
}

String paramPath = nestedPath + lookupName;
Class<?> paramType = paramTypes[i];
Class<?> paramType = param.getParameterType();
paramTypes[i] = paramType;
ResolvableType resolvableType = ResolvableType.forMethodParameter(param);

Object value = valueResolver.resolveValue(paramPath, paramType);
Expand Down Expand Up @@ -1008,7 +1009,8 @@ else if (paramType.isArray()) {
*/
protected boolean shouldConstructArgument(MethodParameter param) {
Class<?> type = param.nestedIfOptional().getNestedParameterType();
return !BeanUtils.isSimpleValueType(type) && !type.getPackageName().startsWith("java.");
return !BeanUtils.isSimpleValueType(type) && !KotlinDetector.isInlineClass(type) &&
!type.getPackageName().startsWith("java.");
}

private boolean hasValuesFor(String paramPath, ValueResolver resolver) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2002-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.validation

import org.assertj.core.api.Assertions
import org.junit.jupiter.api.Test
import org.springframework.core.ResolvableType
import org.springframework.format.support.DefaultFormattingConversionService
import java.util.UUID

/**
* Tests for [DataBinder] constructor binding with Kotlin value classes.
*/
class DataBinderKotlinValueClassTests {

@Test
fun constructDataClassWithStringValueClass() {
val binder = createDataBinder(StringValueClassRecord::class.java)
binder.construct(TestValueResolver(mapOf("title" to "hello")))

Assertions.assertThat(getTarget<StringValueClassRecord>(binder).title).isEqualTo(Title("hello"))
}

@Test
fun constructDataClassWithLongValueClass() {
val binder = createDataBinder(LongValueClassRecord::class.java)
binder.construct(TestValueResolver(mapOf("userId" to "1")))

Assertions.assertThat(getTarget<LongValueClassRecord>(binder).userId).isEqualTo(UserId(1))
}

@Test
fun constructDataClassWithUuidValueClass() {
val uuid = UUID.randomUUID()
val binder = createDataBinder(UuidValueClassRecord::class.java)
binder.construct(TestValueResolver(mapOf("orderId" to uuid.toString())))

Assertions.assertThat(getTarget<UuidValueClassRecord>(binder).orderId).isEqualTo(OrderId(uuid))
}

@Test
fun constructDataClassWithNullablePrimitiveValueClass() {
val binder = createDataBinder(NullablePrimitiveValueClassRecord::class.java)
binder.construct(TestValueResolver(mapOf("userId" to "1")))

Assertions.assertThat(getTarget<NullablePrimitiveValueClassRecord>(binder).userId).isEqualTo(NullableUserId(1))
}

@Test
fun constructDataClassWithNestedValueClass() {
val binder = createDataBinder(NestedValueClassRecord::class.java)
binder.construct(TestValueResolver(mapOf("title" to "hello")))

Assertions.assertThat(getTarget<NestedValueClassRecord>(binder).title).isEqualTo(NestedTitle(Title("hello")))
}

private fun createDataBinder(targetType: Class<*>): DataBinder {
val binder = DataBinder(null)
binder.targetType = ResolvableType.forClass(targetType)
binder.conversionService = DefaultFormattingConversionService()
return binder
}

private inline fun <reified T> getTarget(dataBinder: DataBinder): T {
Assertions.assertThat(dataBinder.bindingResult.allErrors).isEmpty()
return dataBinder.target as T
}

private data class StringValueClassRecord(val title: Title)

@JvmInline
value class Title(val value: String)

private data class LongValueClassRecord(val userId: UserId)

@JvmInline
value class UserId(val value: Long)

private data class UuidValueClassRecord(val orderId: OrderId)

@JvmInline
value class OrderId(val value: UUID)

private data class NullablePrimitiveValueClassRecord(val userId: NullableUserId?)

@JvmInline
value class NullableUserId(val value: Long)

private data class NestedValueClassRecord(val title: NestedTitle)

@JvmInline
value class NestedTitle(val value: Title)

private class TestValueResolver(private val values: Map<String, Any>) : DataBinder.ValueResolver {

override fun resolveValue(name: String, type: Class<*>): Any? = this.values[name]

override fun getNames(): Set<String> = this.values.keys
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.mockito.Mockito.mock
import org.springframework.core.MethodParameter
import org.springframework.core.ResolvableType
import org.springframework.core.annotation.SynthesizingMethodParameter
import org.springframework.format.support.DefaultFormattingConversionService
import org.springframework.web.bind.MethodArgumentNotValidException
import org.springframework.web.bind.support.WebDataBinderFactory
import org.springframework.web.bind.support.WebRequestDataBinder
Expand All @@ -47,13 +48,18 @@ class ModelAttributeMethodProcessorKotlinTests {

private lateinit var param: MethodParameter

private lateinit var valueClassParam: MethodParameter


@BeforeEach
fun setup() {
container = ModelAndViewContainer()
processor = ModelAttributeMethodProcessor(false)
val method = ModelAttributeHandler::class.java.getDeclaredMethod("test", Param::class.java)
param = SynthesizingMethodParameter(method, 0)

val valueClassMethod = ModelAttributeHandler::class.java.getDeclaredMethod("valueClass", ValueClassParam::class.java)
valueClassParam = SynthesizingMethodParameter(valueClassMethod, 0)
}

@Test
Expand Down Expand Up @@ -87,11 +93,36 @@ class ModelAttributeMethodProcessorKotlinTests {
.hasMessageContaining("parameter a")
}

@Test
fun resolveArgumentWithValueClass() {
val mockRequest = MockHttpServletRequest().apply { addParameter("id", "1") }
val requestWithParam = ServletWebRequest(mockRequest)
val factory = mock<WebDataBinderFactory>()
given(factory.createBinder(any(), any(), eq("valueClassParam"), any()))
.willAnswer {
val binder = WebRequestDataBinder(it.getArgument(1))
binder.setTargetType(ResolvableType.forMethodParameter(this.valueClassParam))
binder.conversionService = DefaultFormattingConversionService()
binder
}

assertThat(processor.resolveArgument(this.valueClassParam, container, requestWithParam, factory))
.isEqualTo(ValueClassParam(ValueClass(1)))
}

private data class Param(val a: String)

private data class ValueClassParam(val id: ValueClass)

@JvmInline
value class ValueClass(val value: Long)

private class ModelAttributeHandler {
@Suppress("UNUSED_PARAMETER")
fun test(param: Param) { }

@Suppress("UNUSED_PARAMETER")
fun valueClass(valueClassParam: ValueClassParam) { }
}

}