Skip to content

Commit f61a99f

Browse files
committed
feat: Add support for the think flag at the ChatMode level.
Signed-off-by: Sun Yuhan <sunyuhan1998@users.noreply.github.com>
1 parent abfc9e1 commit f61a99f

File tree

9 files changed

+66
-3
lines changed

9 files changed

+66
-3
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
507507
.stream(stream)
508508
.messages(ollamaMessages)
509509
.options(requestOptions)
510-
.think(requestOptions.getThink());
510+
.think(requestOptions.isThink());
511511

512512
if (requestOptions.getFormat() != null) {
513513
requestBuilder.format(requestOptions.getFormat());

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
409409
.format(fromOptions.getFormat())
410410
.keepAlive(fromOptions.getKeepAlive())
411411
.truncate(fromOptions.getTruncate())
412-
.think(fromOptions.getThink())
412+
.think(fromOptions.isThink())
413413
.useNUMA(fromOptions.getUseNUMA())
414414
.numCtx(fromOptions.getNumCtx())
415415
.numBatch(fromOptions.getNumBatch())
@@ -837,7 +837,8 @@ public void setTruncate(Boolean truncate) {
837837
this.truncate = truncate;
838838
}
839839

840-
public Boolean getThink() {
840+
@Override
841+
public Boolean isThink() {
841842
return this.think;
842843
}
843844

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ public interface ChatOptions extends ModelOptions {
8383
@Nullable
8484
Double getTopP();
8585

86+
/**
87+
* Returns the think flag to use for the chat.
88+
* @return the think flag to use for the chat
89+
*/
90+
@Nullable
91+
default Boolean isThink() {
92+
return false;
93+
}
94+
8695
/**
8796
* Returns a copy of this {@link ChatOptions}.
8897
* @return a copy of this {@link ChatOptions}
@@ -158,6 +167,13 @@ interface Builder {
158167
*/
159168
Builder topP(Double topP);
160169

170+
/**
171+
* Builds with the think to use for the chat.
172+
* @param think Whether to enable thinking mode
173+
* @return the builder.
174+
*/
175+
Builder think(Boolean think);
176+
161177
/**
162178
* Build the {@link ChatOptions}.
163179
* @return the Chat options.

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ public class DefaultChatOptions implements ChatOptions {
4141

4242
private Double topP;
4343

44+
private Boolean think;
45+
4446
@Override
4547
public String getModel() {
4648
return this.model;
@@ -113,6 +115,15 @@ public void setTopP(Double topP) {
113115
this.topP = topP;
114116
}
115117

118+
@Override
119+
public Boolean isThink() {
120+
return this.think;
121+
}
122+
123+
public void setThink(Boolean think) {
124+
this.think = think;
125+
}
126+
116127
@Override
117128
@SuppressWarnings("unchecked")
118129
public <T extends ChatOptions> T copy() {
@@ -125,6 +136,7 @@ public <T extends ChatOptions> T copy() {
125136
copy.setTemperature(this.getTemperature());
126137
copy.setTopK(this.getTopK());
127138
copy.setTopP(this.getTopP());
139+
copy.setThink(this.isThink());
128140
return (T) copy;
129141
}
130142

spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ public DefaultChatOptionsBuilder topP(Double topP) {
7373
return this;
7474
}
7575

76+
public DefaultChatOptionsBuilder think(Boolean think) {
77+
this.options.setThink(think);
78+
return this;
79+
}
80+
7681
public ChatOptions build() {
7782
return this.options.copy();
7883
}

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
7070
@Nullable
7171
private Double topP;
7272

73+
@Nullable
74+
private Boolean think;
75+
7376
@Override
7477
public List<ToolCallback> getToolCallbacks() {
7578
return List.copyOf(this.toolCallbacks);
@@ -198,6 +201,16 @@ public void setTopP(@Nullable Double topP) {
198201
this.topP = topP;
199202
}
200203

204+
@Override
205+
@Nullable
206+
public Boolean isThink() {
207+
return this.think;
208+
}
209+
210+
public void setThink(@Nullable Boolean think) {
211+
this.think = think;
212+
}
213+
201214
@Override
202215
@SuppressWarnings("unchecked")
203216
public <T extends ChatOptions> T copy() {
@@ -325,6 +338,12 @@ public ToolCallingChatOptions.Builder topP(@Nullable Double topP) {
325338
return this;
326339
}
327340

341+
@Override
342+
public ToolCallingChatOptions.Builder think(Boolean think) {
343+
this.options.setThink(think);
344+
return this;
345+
}
346+
328347
@Override
329348
public ToolCallingChatOptions build() {
330349
return this.options;

spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ interface Builder extends ChatOptions.Builder {
219219
@Override
220220
Builder topP(@Nullable Double topP);
221221

222+
@Override
223+
Builder think(@Nullable Boolean think);
224+
222225
@Override
223226
ToolCallingChatOptions build();
224227

spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ void shouldBuildWithAllOptions() {
5353
.topP(1.0)
5454
.topK(40)
5555
.stopSequences(List.of("stop1", "stop2"))
56+
.think(true)
5657
.build();
5758

5859
assertThat(options.getModel()).isEqualTo("gpt-4");
5960
assertThat(options.getMaxTokens()).isEqualTo(100);
6061
assertThat(options.getTemperature()).isEqualTo(0.7);
6162
assertThat(options.getTopP()).isEqualTo(1.0);
6263
assertThat(options.getTopK()).isEqualTo(40);
64+
assertThat(options.isThink()).isEqualTo(true);
6365
assertThat(options.getStopSequences()).containsExactly("stop1", "stop2");
6466
}
6567

@@ -82,6 +84,7 @@ void shouldCopyOptions() {
8284
.temperature(0.7)
8385
.topP(1.0)
8486
.topK(40)
87+
.think(true)
8588
.stopSequences(List.of("stop1", "stop2"))
8689
.build();
8790

@@ -107,6 +110,7 @@ void shouldUpcastToChatOptions() {
107110
.temperature(0.7)
108111
.topP(1.0)
109112
.topK(40)
113+
.think(true)
110114
.stopSequences(List.of("stop1", "stop2"))
111115
.toolNames(Set.of("function1", "function2"))
112116
.toolCallbacks(List.of(callback))
@@ -121,6 +125,7 @@ void shouldUpcastToChatOptions() {
121125
assertThat(chatOptions.getTemperature()).isEqualTo(0.7);
122126
assertThat(chatOptions.getTopP()).isEqualTo(1.0);
123127
assertThat(chatOptions.getTopK()).isEqualTo(40);
128+
assertThat(chatOptions.isThink()).isEqualTo(true);
124129
assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2");
125130
}
126131

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ void builderShouldCreateOptionsWithAllProperties() {
188188
.stopSequences(List.of("stop"))
189189
.topK(3)
190190
.topP(0.9)
191+
.think(true)
191192
.build();
192193

193194
assertThat(options).satisfies(o -> {
@@ -203,6 +204,7 @@ void builderShouldCreateOptionsWithAllProperties() {
203204
assertThat(o.getStopSequences()).containsExactly("stop");
204205
assertThat(o.getTopK()).isEqualTo(3);
205206
assertThat(o.getTopP()).isEqualTo(0.9);
207+
assertThat(o.isThink()).isEqualTo(true);
206208
});
207209
}
208210

0 commit comments

Comments
 (0)