Implement reasoning for other providers

This commit is contained in:
Nicolas Mowen 2026-05-19 10:23:35 -06:00
parent d7683fd797
commit ea2e423e11
3 changed files with 73 additions and 5 deletions

View File

@ -248,6 +248,13 @@ class GeminiClient(GenAIClient):
if tool_config:
config_params["tool_config"] = tool_config
# Ask thinking-capable models (Gemini 2.5+) to include their
# reasoning trace as separate `thought` parts so we can surface
# it on the reasoning channel. Older models ignore this field.
config_params["thinking_config"] = types.ThinkingConfig(
include_thoughts=True
)
# Merge runtime_options
if isinstance(self.genai_config.runtime_options, dict):
config_params.update(self.genai_config.runtime_options)
@ -262,19 +269,24 @@ class GeminiClient(GenAIClient):
if not response or not response.candidates:
return {
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
}
candidate = response.candidates[0]
content = None
reasoning_parts: list[str] = []
tool_calls = None
# Extract content and tool calls from response
# Extract content, reasoning, and tool calls from response
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.text:
content = part.text.strip()
if getattr(part, "thought", False):
reasoning_parts.append(part.text)
else:
content = part.text.strip()
elif part.function_call:
# Handle function call
if tool_calls is None:
@ -297,6 +309,8 @@ class GeminiClient(GenAIClient):
}
)
reasoning = "".join(reasoning_parts).strip() or None
# Determine finish reason
finish_reason = "error"
if hasattr(candidate, "finish_reason") and candidate.finish_reason:
@ -322,6 +336,7 @@ class GeminiClient(GenAIClient):
return {
"content": content,
"reasoning": reasoning,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
@ -330,6 +345,7 @@ class GeminiClient(GenAIClient):
logger.warning("Gemini API error during chat_with_tools: %s", str(e))
return {
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
}
@ -339,6 +355,7 @@ class GeminiClient(GenAIClient):
)
return {
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
}
@ -477,12 +494,19 @@ class GeminiClient(GenAIClient):
if tool_config:
config_params["tool_config"] = tool_config
# Ask thinking-capable models to include their reasoning trace
# as separate `thought` parts (Gemini 2.5+; ignored elsewhere).
config_params["thinking_config"] = types.ThinkingConfig(
include_thoughts=True
)
# Merge runtime_options
if isinstance(self.genai_config.runtime_options, dict):
config_params.update(self.genai_config.runtime_options)
# Use streaming API
content_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop"
usage_stats: Optional[dict[str, Any]] = None
@ -519,12 +543,16 @@ class GeminiClient(GenAIClient):
]:
finish_reason = "error"
# Extract content and tool calls from chunk
# Extract content, reasoning, and tool calls from chunk
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.text:
content_parts.append(part.text)
yield ("content_delta", part.text)
if getattr(part, "thought", False):
reasoning_parts.append(part.text)
yield ("reasoning_delta", part.text)
else:
content_parts.append(part.text)
yield ("content_delta", part.text)
elif part.function_call:
# Handle function call
try:
@ -565,6 +593,7 @@ class GeminiClient(GenAIClient):
# Build final message
full_content = "".join(content_parts).strip() or None
full_reasoning = "".join(reasoning_parts).strip() or None
# Convert tool calls to list format
tool_calls_list = None
@ -593,6 +622,7 @@ class GeminiClient(GenAIClient):
"message",
{
"content": full_content,
"reasoning": full_reasoning,
"tool_calls": tool_calls_list,
"finish_reason": finish_reason,
},
@ -604,6 +634,7 @@ class GeminiClient(GenAIClient):
"message",
{
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
},
@ -616,6 +647,7 @@ class GeminiClient(GenAIClient):
"message",
{
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
},

View File

@ -337,6 +337,9 @@ class OllamaClient(GenAIClient):
response.get("done"),
)
content = message.get("content", "").strip() if message.get("content") else None
reasoning = (
message.get("thinking", "").strip() if message.get("thinking") else None
)
tool_calls = parse_tool_calls_from_message(message)
finish_reason = "error"
if response.get("done"):
@ -349,6 +352,7 @@ class OllamaClient(GenAIClient):
finish_reason = "stop"
return {
"content": content,
"reasoning": reasoning,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
@ -432,6 +436,9 @@ class OllamaClient(GenAIClient):
)
response = await async_client.chat(**request_params)
result = self._message_from_response(response)
reasoning = result.get("reasoning")
if reasoning:
yield ("reasoning_delta", reasoning)
content = result.get("content")
if content:
yield ("content_delta", content)
@ -450,6 +457,7 @@ class OllamaClient(GenAIClient):
headers=self._auth_headers(),
)
content_parts: list[str] = []
reasoning_parts: list[str] = []
final_message: dict[str, Any] | None = None
final_chunk: Any = None
stream = await async_client.chat(**request_params)
@ -457,6 +465,10 @@ class OllamaClient(GenAIClient):
if not chunk or "message" not in chunk:
continue
msg = chunk.get("message", {})
reasoning_delta = msg.get("thinking") or ""
if reasoning_delta:
reasoning_parts.append(reasoning_delta)
yield ("reasoning_delta", reasoning_delta)
delta = msg.get("content") or ""
if delta:
content_parts.append(delta)
@ -464,8 +476,10 @@ class OllamaClient(GenAIClient):
if chunk.get("done"):
final_chunk = chunk
full_content = "".join(content_parts).strip() or None
full_reasoning = "".join(reasoning_parts).strip() or None
final_message = {
"content": full_content,
"reasoning": full_reasoning,
"tool_calls": None,
"finish_reason": "stop",
}
@ -482,6 +496,7 @@ class OllamaClient(GenAIClient):
"message",
{
"content": "".join(content_parts).strip() or None,
"reasoning": "".join(reasoning_parts).strip() or None,
"tool_calls": None,
"finish_reason": "stop",
},

View File

@ -236,6 +236,10 @@ class OpenAIClient(GenAIClient):
choice = result.choices[0]
message = choice.message
content = message.content.strip() if message.content else None
raw_reasoning = getattr(message, "reasoning_content", None) or getattr(
message, "reasoning", None
)
reasoning = raw_reasoning.strip() if raw_reasoning else None
tool_calls = None
if message.tool_calls:
@ -270,6 +274,7 @@ class OpenAIClient(GenAIClient):
return {
"content": content,
"reasoning": reasoning,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
@ -278,6 +283,7 @@ class OpenAIClient(GenAIClient):
logger.warning("OpenAI request timed out: %s", str(e))
return {
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
}
@ -285,6 +291,7 @@ class OpenAIClient(GenAIClient):
logger.warning("OpenAI returned an error: %s", str(e))
return {
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
}
@ -335,6 +342,7 @@ class OpenAIClient(GenAIClient):
# Use streaming API
content_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop"
usage_stats: Optional[dict[str, Any]] = None
@ -356,6 +364,15 @@ class OpenAIClient(GenAIClient):
if choice.finish_reason:
finish_reason = choice.finish_reason
# Extract reasoning deltas (reasoning_content or reasoning,
# depending on the server)
reasoning_delta = getattr(delta, "reasoning_content", None) or getattr(
delta, "reasoning", None
)
if reasoning_delta:
reasoning_parts.append(reasoning_delta)
yield ("reasoning_delta", reasoning_delta)
# Extract content deltas
if delta.content:
content_parts.append(delta.content)
@ -384,6 +401,7 @@ class OpenAIClient(GenAIClient):
# Build final message
full_content = "".join(content_parts).strip() or None
full_reasoning = "".join(reasoning_parts).strip() or None
# Convert tool calls to list format
tool_calls_list = None
@ -412,6 +430,7 @@ class OpenAIClient(GenAIClient):
"message",
{
"content": full_content,
"reasoning": full_reasoning,
"tool_calls": tool_calls_list,
"finish_reason": finish_reason,
},
@ -423,6 +442,7 @@ class OpenAIClient(GenAIClient):
"message",
{
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
},
@ -433,6 +453,7 @@ class OpenAIClient(GenAIClient):
"message",
{
"content": None,
"reasoning": None,
"tool_calls": None,
"finish_reason": "error",
},