mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-06-21 03:41:55 +03:00
Implement tool call history keeping
This commit is contained in:
parent
efe585a920
commit
7039dc5cb4
@ -1153,7 +1153,7 @@ async def chat_completion(
|
||||
)
|
||||
|
||||
for msg in body.messages:
|
||||
msg_dict = {
|
||||
msg_dict: Dict[str, Any] = {
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
}
|
||||
@ -1161,9 +1161,18 @@ async def chat_completion(
|
||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||
if msg.name:
|
||||
msg_dict["name"] = msg.name
|
||||
# Replayed assistant turns carry their original tool_calls so the
|
||||
# rendered prefix matches the prior turn exactly (prompt caching).
|
||||
if msg.tool_calls is not None:
|
||||
msg_dict["tool_calls"] = msg.tool_calls
|
||||
|
||||
conversation.append(msg_dict)
|
||||
|
||||
# Everything appended from here on belongs to the assistant turn we are
|
||||
# about to generate. We hand this slice back to the client so it can replay
|
||||
# it verbatim on the next turn, keeping the cached prompt prefix intact.
|
||||
turn_start_len = len(conversation)
|
||||
|
||||
tool_iterations = 0
|
||||
tool_calls: List[ToolCall] = []
|
||||
max_iterations = body.max_tool_iterations
|
||||
@ -1180,6 +1189,20 @@ async def chat_completion(
|
||||
|
||||
async def stream_body_llm():
|
||||
nonlocal conversation, stream_tool_calls, stream_iterations
|
||||
|
||||
def _emit_replay_messages(extra: Optional[List[Dict[str, Any]]] = None):
    """Build the newline-terminated JSON "messages" event for this turn.

    The event carries every entry appended to ``conversation`` from index
    ``turn_start_len`` onward, followed by any ``extra`` messages supplied
    by the caller. The client is meant to replay these messages verbatim on
    the next turn so the prompt prefix stays unchanged and the prompt cache
    stays warm.

    Args:
        extra: Optional messages to append after the conversation slice.
            ``None`` or an empty list adds nothing.

    Returns:
        UTF-8 encoded bytes: one JSON object of the form
        ``{"type": "messages", "messages": [...]}`` followed by ``\\n``.
    """
    # Slicing yields a new list, so `conversation` itself is never mutated.
    replay = conversation[turn_start_len:]
    if extra:
        replay = replay + extra

    event = {"type": "messages", "messages": replay}
    serialized = json.dumps(event).encode("utf-8")
    return serialized + b"\n"
|
||||
|
||||
while stream_iterations < max_iterations:
|
||||
if await request.is_disconnected():
|
||||
logger.debug("Client disconnected, stopping chat stream")
|
||||
@ -1266,9 +1289,20 @@ async def chat_completion(
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Final answer: the streaming loop never appends the
|
||||
# last assistant message to `conversation`, so add it
|
||||
# to the replay slice explicitly.
|
||||
final_assistant = {
|
||||
"role": "assistant",
|
||||
"content": msg.get("content"),
|
||||
}
|
||||
yield _emit_replay_messages(extra=[final_assistant])
|
||||
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
|
||||
return
|
||||
else:
|
||||
# Max iterations reached: replay whatever we accumulated so the
|
||||
# next turn still starts from a cache-friendly prefix.
|
||||
yield _emit_replay_messages()
|
||||
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
|
||||
|
||||
return StreamingResponse(
|
||||
@ -1363,6 +1397,7 @@ async def chat_completion(
|
||||
finish_reason=response.get("finish_reason", "stop"),
|
||||
tool_iterations=tool_iterations,
|
||||
tool_calls=tool_calls,
|
||||
messages=conversation[turn_start_len:],
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
@ -1395,6 +1430,7 @@ async def chat_completion(
|
||||
finish_reason="length",
|
||||
tool_iterations=tool_iterations,
|
||||
tool_calls=tool_calls,
|
||||
messages=conversation[turn_start_len:],
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Chat API request models."""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -11,13 +11,29 @@ class ChatMessage(BaseModel):
|
||||
role: str = Field(
|
||||
description="Message role: 'user', 'assistant', 'system', or 'tool'"
|
||||
)
|
||||
content: str = Field(description="Message content")
|
||||
content: Optional[Any] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Message content. Usually a string, but may be a multimodal content "
|
||||
"list (e.g. text + image_url) or null for assistant turns that only "
|
||||
"request tool calls."
|
||||
),
|
||||
)
|
||||
tool_call_id: Optional[str] = Field(
|
||||
default=None, description="For tool messages, the ID of the tool call"
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
default=None, description="For tool messages, the tool name"
|
||||
)
|
||||
tool_calls: Optional[list[dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"For assistant messages replayed from prior turns, the OpenAI-format "
|
||||
"tool calls the model previously requested. Replaying these verbatim "
|
||||
"keeps the conversation prefix byte-for-byte identical so the model "
|
||||
"server's prompt cache hits on follow-up turns."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user