Implement tool call history keeping

This commit is contained in:
Nicolas Mowen 2026-06-11 16:14:33 -06:00
parent efe585a920
commit 7039dc5cb4
2 changed files with 55 additions and 3 deletions

View File

@ -1153,7 +1153,7 @@ async def chat_completion(
)
for msg in body.messages:
msg_dict = {
msg_dict: Dict[str, Any] = {
"role": msg.role,
"content": msg.content,
}
@ -1161,9 +1161,18 @@ async def chat_completion(
msg_dict["tool_call_id"] = msg.tool_call_id
if msg.name:
msg_dict["name"] = msg.name
# Replayed assistant turns carry their original tool_calls so the
# rendered prefix matches the prior turn exactly (prompt caching).
if msg.tool_calls is not None:
msg_dict["tool_calls"] = msg.tool_calls
conversation.append(msg_dict)
# Everything appended from here on belongs to the assistant turn we are
# about to generate. We hand this slice back to the client so it can replay
# it verbatim on the next turn, keeping the cached prompt prefix intact.
turn_start_len = len(conversation)
tool_iterations = 0
tool_calls: List[ToolCall] = []
max_iterations = body.max_tool_iterations
@ -1180,6 +1189,20 @@ async def chat_completion(
async def stream_body_llm():
nonlocal conversation, stream_tool_calls, stream_iterations
def _emit_replay_messages(extra: Optional[List[Dict[str, Any]]] = None):
# Hand the client the exact messages appended for this assistant
# turn (assistant tool-call turns, tool results, injected image
# messages, and the final assistant message) so it can replay
# them verbatim next turn and keep the prompt cache warm.
turn_messages = conversation[turn_start_len:] + (extra or [])
return (
json.dumps({"type": "messages", "messages": turn_messages}).encode(
"utf-8"
)
+ b"\n"
)
while stream_iterations < max_iterations:
if await request.is_disconnected():
logger.debug("Client disconnected, stopping chat stream")
@ -1266,9 +1289,20 @@ async def chat_completion(
)
break
else:
# Final answer: the streaming loop never appends the
# last assistant message to `conversation`, so add it
# to the replay slice explicitly.
final_assistant = {
"role": "assistant",
"content": msg.get("content"),
}
yield _emit_replay_messages(extra=[final_assistant])
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
return
else:
# Max iterations reached: replay whatever we accumulated so the
# next turn still starts from a cache-friendly prefix.
yield _emit_replay_messages()
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
return StreamingResponse(
@ -1363,6 +1397,7 @@ async def chat_completion(
finish_reason=response.get("finish_reason", "stop"),
tool_iterations=tool_iterations,
tool_calls=tool_calls,
messages=conversation[turn_start_len:],
).model_dump(),
)
@ -1395,6 +1430,7 @@ async def chat_completion(
finish_reason="length",
tool_iterations=tool_iterations,
tool_calls=tool_calls,
messages=conversation[turn_start_len:],
).model_dump(),
)

View File

@ -1,6 +1,6 @@
"""Chat API request models."""
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
@ -11,13 +11,29 @@ class ChatMessage(BaseModel):
role: str = Field(
description="Message role: 'user', 'assistant', 'system', or 'tool'"
)
content: str = Field(description="Message content")
content: Optional[Any] = Field(
default=None,
description=(
"Message content. Usually a string, but may be a multimodal content "
"list (e.g. text + image_url) or null for assistant turns that only "
"request tool calls."
),
)
tool_call_id: Optional[str] = Field(
default=None, description="For tool messages, the ID of the tool call"
)
name: Optional[str] = Field(
default=None, description="For tool messages, the tool name"
)
tool_calls: Optional[list[dict[str, Any]]] = Field(
default=None,
description=(
"For assistant messages replayed from prior turns, the OpenAI-format "
"tool calls the model previously requested. Replaying these verbatim "
"keeps the conversation prefix byte-for-byte identical so the model "
"server's prompt cache hits on follow-up turns."
),
)
class ChatCompletionRequest(BaseModel):