Implement reasoning for other providers

This commit is contained in:
Nicolas Mowen 2026-05-19 10:23:35 -06:00
parent d7683fd797
commit ea2e423e11
3 changed files with 73 additions and 5 deletions

View File

@ -248,6 +248,13 @@ class GeminiClient(GenAIClient):
if tool_config: if tool_config:
config_params["tool_config"] = tool_config config_params["tool_config"] = tool_config
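# Ask thinking-capable models (Gemini 2.5+) to include their
# reasoning trace as separate `thought` parts so we can surface
# it on the reasoning channel. NOTE(review): this assumes models
# without thinking support ignore the field rather than rejecting
# the request; confirm against the Gemini API. runtime_options is
# merged afterwards, so a `thinking_config` entry there overrides this.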
config_params["thinking_config"] = types.ThinkingConfig(
include_thoughts=True
)
# Merge runtime_options # Merge runtime_options
if isinstance(self.genai_config.runtime_options, dict): if isinstance(self.genai_config.runtime_options, dict):
config_params.update(self.genai_config.runtime_options) config_params.update(self.genai_config.runtime_options)
@ -262,19 +269,24 @@ class GeminiClient(GenAIClient):
if not response or not response.candidates: if not response or not response.candidates:
return { return {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
candidate = response.candidates[0] candidate = response.candidates[0]
content = None content = None
reasoning_parts: list[str] = []
tool_calls = None tool_calls = None
# Extract content and tool calls from response # Extract content, reasoning, and tool calls from response
if candidate.content and candidate.content.parts: if candidate.content and candidate.content.parts:
for part in candidate.content.parts: for part in candidate.content.parts:
if part.text: if part.text:
content = part.text.strip() if getattr(part, "thought", False):
reasoning_parts.append(part.text)
else:
content = part.text.strip()
elif part.function_call: elif part.function_call:
# Handle function call # Handle function call
if tool_calls is None: if tool_calls is None:
@ -297,6 +309,8 @@ class GeminiClient(GenAIClient):
} }
) )
reasoning = "".join(reasoning_parts).strip() or None
# Determine finish reason # Determine finish reason
finish_reason = "error" finish_reason = "error"
if hasattr(candidate, "finish_reason") and candidate.finish_reason: if hasattr(candidate, "finish_reason") and candidate.finish_reason:
@ -322,6 +336,7 @@ class GeminiClient(GenAIClient):
return { return {
"content": content, "content": content,
"reasoning": reasoning,
"tool_calls": tool_calls, "tool_calls": tool_calls,
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
@ -330,6 +345,7 @@ class GeminiClient(GenAIClient):
logger.warning("Gemini API error during chat_with_tools: %s", str(e)) logger.warning("Gemini API error during chat_with_tools: %s", str(e))
return { return {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
@ -339,6 +355,7 @@ class GeminiClient(GenAIClient):
) )
return { return {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
@ -477,12 +494,19 @@ class GeminiClient(GenAIClient):
if tool_config: if tool_config:
config_params["tool_config"] = tool_config config_params["tool_config"] = tool_config
# Ask thinking-capable models to include their reasoning trace
# as separate `thought` parts (Gemini 2.5+). NOTE(review): assumes
# other models ignore this field rather than rejecting it; confirm.
config_params["thinking_config"] = types.ThinkingConfig(
include_thoughts=True
)
# Merge runtime_options # Merge runtime_options
if isinstance(self.genai_config.runtime_options, dict): if isinstance(self.genai_config.runtime_options, dict):
config_params.update(self.genai_config.runtime_options) config_params.update(self.genai_config.runtime_options)
# Use streaming API # Use streaming API
content_parts: list[str] = [] content_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls_by_index: dict[int, dict[str, Any]] = {} tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop" finish_reason = "stop"
usage_stats: Optional[dict[str, Any]] = None usage_stats: Optional[dict[str, Any]] = None
@ -519,12 +543,16 @@ class GeminiClient(GenAIClient):
]: ]:
finish_reason = "error" finish_reason = "error"
# Extract content and tool calls from chunk # Extract content, reasoning, and tool calls from chunk
if candidate.content and candidate.content.parts: if candidate.content and candidate.content.parts:
for part in candidate.content.parts: for part in candidate.content.parts:
if part.text: if part.text:
content_parts.append(part.text) if getattr(part, "thought", False):
yield ("content_delta", part.text) reasoning_parts.append(part.text)
yield ("reasoning_delta", part.text)
else:
content_parts.append(part.text)
yield ("content_delta", part.text)
elif part.function_call: elif part.function_call:
# Handle function call # Handle function call
try: try:
@ -565,6 +593,7 @@ class GeminiClient(GenAIClient):
# Build final message # Build final message
full_content = "".join(content_parts).strip() or None full_content = "".join(content_parts).strip() or None
full_reasoning = "".join(reasoning_parts).strip() or None
# Convert tool calls to list format # Convert tool calls to list format
tool_calls_list = None tool_calls_list = None
@ -593,6 +622,7 @@ class GeminiClient(GenAIClient):
"message", "message",
{ {
"content": full_content, "content": full_content,
"reasoning": full_reasoning,
"tool_calls": tool_calls_list, "tool_calls": tool_calls_list,
"finish_reason": finish_reason, "finish_reason": finish_reason,
}, },
@ -604,6 +634,7 @@ class GeminiClient(GenAIClient):
"message", "message",
{ {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
}, },
@ -616,6 +647,7 @@ class GeminiClient(GenAIClient):
"message", "message",
{ {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
}, },

View File

@ -337,6 +337,9 @@ class OllamaClient(GenAIClient):
response.get("done"), response.get("done"),
) )
content = message.get("content", "").strip() if message.get("content") else None content = message.get("content", "").strip() if message.get("content") else None
reasoning = (
message.get("thinking", "").strip() if message.get("thinking") else None
)
tool_calls = parse_tool_calls_from_message(message) tool_calls = parse_tool_calls_from_message(message)
finish_reason = "error" finish_reason = "error"
if response.get("done"): if response.get("done"):
@ -349,6 +352,7 @@ class OllamaClient(GenAIClient):
finish_reason = "stop" finish_reason = "stop"
return { return {
"content": content, "content": content,
"reasoning": reasoning,
"tool_calls": tool_calls, "tool_calls": tool_calls,
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
@ -432,6 +436,9 @@ class OllamaClient(GenAIClient):
) )
response = await async_client.chat(**request_params) response = await async_client.chat(**request_params)
result = self._message_from_response(response) result = self._message_from_response(response)
reasoning = result.get("reasoning")
if reasoning:
yield ("reasoning_delta", reasoning)
content = result.get("content") content = result.get("content")
if content: if content:
yield ("content_delta", content) yield ("content_delta", content)
@ -450,6 +457,7 @@ class OllamaClient(GenAIClient):
headers=self._auth_headers(), headers=self._auth_headers(),
) )
content_parts: list[str] = [] content_parts: list[str] = []
reasoning_parts: list[str] = []
final_message: dict[str, Any] | None = None final_message: dict[str, Any] | None = None
final_chunk: Any = None final_chunk: Any = None
stream = await async_client.chat(**request_params) stream = await async_client.chat(**request_params)
@ -457,6 +465,10 @@ class OllamaClient(GenAIClient):
if not chunk or "message" not in chunk: if not chunk or "message" not in chunk:
continue continue
msg = chunk.get("message", {}) msg = chunk.get("message", {})
reasoning_delta = msg.get("thinking") or ""
if reasoning_delta:
reasoning_parts.append(reasoning_delta)
yield ("reasoning_delta", reasoning_delta)
delta = msg.get("content") or "" delta = msg.get("content") or ""
if delta: if delta:
content_parts.append(delta) content_parts.append(delta)
@ -464,8 +476,10 @@ class OllamaClient(GenAIClient):
if chunk.get("done"): if chunk.get("done"):
final_chunk = chunk final_chunk = chunk
full_content = "".join(content_parts).strip() or None full_content = "".join(content_parts).strip() or None
full_reasoning = "".join(reasoning_parts).strip() or None
final_message = { final_message = {
"content": full_content, "content": full_content,
"reasoning": full_reasoning,
"tool_calls": None, "tool_calls": None,
"finish_reason": "stop", "finish_reason": "stop",
} }
@ -482,6 +496,7 @@ class OllamaClient(GenAIClient):
"message", "message",
{ {
"content": "".join(content_parts).strip() or None, "content": "".join(content_parts).strip() or None,
"reasoning": "".join(reasoning_parts).strip() or None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "stop", "finish_reason": "stop",
}, },

View File

@ -236,6 +236,10 @@ class OpenAIClient(GenAIClient):
choice = result.choices[0] choice = result.choices[0]
message = choice.message message = choice.message
content = message.content.strip() if message.content else None content = message.content.strip() if message.content else None
raw_reasoning = getattr(message, "reasoning_content", None) or getattr(
message, "reasoning", None
)
reasoning = raw_reasoning.strip() if raw_reasoning else None
tool_calls = None tool_calls = None
if message.tool_calls: if message.tool_calls:
@ -270,6 +274,7 @@ class OpenAIClient(GenAIClient):
return { return {
"content": content, "content": content,
"reasoning": reasoning,
"tool_calls": tool_calls, "tool_calls": tool_calls,
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
@ -278,6 +283,7 @@ class OpenAIClient(GenAIClient):
logger.warning("OpenAI request timed out: %s", str(e)) logger.warning("OpenAI request timed out: %s", str(e))
return { return {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
@ -285,6 +291,7 @@ class OpenAIClient(GenAIClient):
logger.warning("OpenAI returned an error: %s", str(e)) logger.warning("OpenAI returned an error: %s", str(e))
return { return {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
@ -335,6 +342,7 @@ class OpenAIClient(GenAIClient):
# Use streaming API # Use streaming API
content_parts: list[str] = [] content_parts: list[str] = []
reasoning_parts: list[str] = []
tool_calls_by_index: dict[int, dict[str, Any]] = {} tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop" finish_reason = "stop"
usage_stats: Optional[dict[str, Any]] = None usage_stats: Optional[dict[str, Any]] = None
@ -356,6 +364,15 @@ class OpenAIClient(GenAIClient):
if choice.finish_reason: if choice.finish_reason:
finish_reason = choice.finish_reason finish_reason = choice.finish_reason
# Extract reasoning deltas (reasoning_content or reasoning,
# depending on the server)
reasoning_delta = getattr(delta, "reasoning_content", None) or getattr(
delta, "reasoning", None
)
if reasoning_delta:
reasoning_parts.append(reasoning_delta)
yield ("reasoning_delta", reasoning_delta)
# Extract content deltas # Extract content deltas
if delta.content: if delta.content:
content_parts.append(delta.content) content_parts.append(delta.content)
@ -384,6 +401,7 @@ class OpenAIClient(GenAIClient):
# Build final message # Build final message
full_content = "".join(content_parts).strip() or None full_content = "".join(content_parts).strip() or None
full_reasoning = "".join(reasoning_parts).strip() or None
# Convert tool calls to list format # Convert tool calls to list format
tool_calls_list = None tool_calls_list = None
@ -412,6 +430,7 @@ class OpenAIClient(GenAIClient):
"message", "message",
{ {
"content": full_content, "content": full_content,
"reasoning": full_reasoning,
"tool_calls": tool_calls_list, "tool_calls": tool_calls_list,
"finish_reason": finish_reason, "finish_reason": finish_reason,
}, },
@ -423,6 +442,7 @@ class OpenAIClient(GenAIClient):
"message", "message",
{ {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
}, },
@ -433,6 +453,7 @@ class OpenAIClient(GenAIClient):
"message", "message",
{ {
"content": None, "content": None,
"reasoning": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
}, },