diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index e19256f7e4..33de25b2a0 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -219,7 +219,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon "index", choice.index() != null ? choice.index() : 0, "finishReason", getFinishReasonJson(choice.finishReason()), "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", - "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of())); + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of()), + "reasoningContent", choice.message().reasoningContent() != null ? choice.message().reasoningContent() : ""); return buildGeneration(choice, metadata, request); }).toList(); // @formatter:on diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java index a860307d54..143faec491 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java @@ -124,6 +124,7 @@ void aiResponseContainsAiMetadata() { assertThat(chatGenerationMetadata).isNotNull(); assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP"); assertThat(chatGenerationMetadata.getContentFilters()).isEmpty(); + assertThat(generation.getOutput().getMetadata().get("reasoningContent")).isEqualTo(""); }); } @@ -144,7 +145,26 @@ void aiResponseContainsAiLogprobsMetadata() { assertThat(logprobs).isNotNull().isInstanceOf(OpenAiApi.LogProbs.class); } + @Test + void aiResponseContainsReasoningContent() { + + prepareMock(getJsonWithReasoningContent()); + + Prompt prompt = new Prompt("Reach for the sky."); + + ChatResponse response = this.openAiChatClient.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput().getMetadata().get("reasoningContent")) + .isEqualTo("Let me think step by step..."); + } + private void prepareMock(boolean includeLogprobs) { + prepareMock(getJson(includeLogprobs)); + } + + private void prepareMock(String json) { HttpHeaders httpHeaders = new HttpHeaders(); httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000"); @@ -157,7 +177,7 @@ private void prepareMock(boolean includeLogprobs) { this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions"))) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) - .andRespond(withSuccess(getJson(includeLogprobs), MediaType.APPLICATION_JSON).headers(httpHeaders)); + .andRespond(withSuccess(json, MediaType.APPLICATION_JSON).headers(httpHeaders)); } @@ -209,6 +229,31 @@ private String getJson(boolean includeLogprobs) { return String.format(getBaseJson(), ""); } + private String getJsonWithReasoningContent() { + return """ + { + "id": "chatcmpl-456", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "I surrender!", + "reasoning_content": "Let me think step by step..." + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + } + @SpringBootConfiguration static class Config {