mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-07-03 18:41:14 +03:00
Implement tool call history keeping
This commit is contained in:
parent
efe585a920
commit
7039dc5cb4
@ -1153,7 +1153,7 @@ async def chat_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for msg in body.messages:
|
for msg in body.messages:
|
||||||
msg_dict = {
|
msg_dict: Dict[str, Any] = {
|
||||||
"role": msg.role,
|
"role": msg.role,
|
||||||
"content": msg.content,
|
"content": msg.content,
|
||||||
}
|
}
|
||||||
@ -1161,9 +1161,18 @@ async def chat_completion(
|
|||||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
msg_dict["tool_call_id"] = msg.tool_call_id
|
||||||
if msg.name:
|
if msg.name:
|
||||||
msg_dict["name"] = msg.name
|
msg_dict["name"] = msg.name
|
||||||
|
# Replayed assistant turns carry their original tool_calls so the
|
||||||
|
# rendered prefix matches the prior turn exactly (prompt caching).
|
||||||
|
if msg.tool_calls is not None:
|
||||||
|
msg_dict["tool_calls"] = msg.tool_calls
|
||||||
|
|
||||||
conversation.append(msg_dict)
|
conversation.append(msg_dict)
|
||||||
|
|
||||||
|
# Everything appended from here on belongs to the assistant turn we are
|
||||||
|
# about to generate. We hand this slice back to the client so it can replay
|
||||||
|
# it verbatim on the next turn, keeping the cached prompt prefix intact.
|
||||||
|
turn_start_len = len(conversation)
|
||||||
|
|
||||||
tool_iterations = 0
|
tool_iterations = 0
|
||||||
tool_calls: List[ToolCall] = []
|
tool_calls: List[ToolCall] = []
|
||||||
max_iterations = body.max_tool_iterations
|
max_iterations = body.max_tool_iterations
|
||||||
@ -1180,6 +1189,20 @@ async def chat_completion(
|
|||||||
|
|
||||||
async def stream_body_llm():
|
async def stream_body_llm():
|
||||||
nonlocal conversation, stream_tool_calls, stream_iterations
|
nonlocal conversation, stream_tool_calls, stream_iterations
|
||||||
|
|
||||||
|
def _emit_replay_messages(extra: Optional[List[Dict[str, Any]]] = None):
|
||||||
|
# Hand the client the exact messages appended for this assistant
|
||||||
|
# turn (assistant tool-call turns, tool results, injected image
|
||||||
|
# messages, and the final assistant message) so it can replay
|
||||||
|
# them verbatim next turn and keep the prompt cache warm.
|
||||||
|
turn_messages = conversation[turn_start_len:] + (extra or [])
|
||||||
|
return (
|
||||||
|
json.dumps({"type": "messages", "messages": turn_messages}).encode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
+ b"\n"
|
||||||
|
)
|
||||||
|
|
||||||
while stream_iterations < max_iterations:
|
while stream_iterations < max_iterations:
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
logger.debug("Client disconnected, stopping chat stream")
|
logger.debug("Client disconnected, stopping chat stream")
|
||||||
@ -1266,9 +1289,20 @@ async def chat_completion(
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# Final answer: the streaming loop never appends the
|
||||||
|
# last assistant message to `conversation`, so add it
|
||||||
|
# to the replay slice explicitly.
|
||||||
|
final_assistant = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": msg.get("content"),
|
||||||
|
}
|
||||||
|
yield _emit_replay_messages(extra=[final_assistant])
|
||||||
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
|
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
# Max iterations reached: replay whatever we accumulated so the
|
||||||
|
# next turn still starts from a cache-friendly prefix.
|
||||||
|
yield _emit_replay_messages()
|
||||||
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
|
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@ -1363,6 +1397,7 @@ async def chat_completion(
|
|||||||
finish_reason=response.get("finish_reason", "stop"),
|
finish_reason=response.get("finish_reason", "stop"),
|
||||||
tool_iterations=tool_iterations,
|
tool_iterations=tool_iterations,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
|
messages=conversation[turn_start_len:],
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1395,6 +1430,7 @@ async def chat_completion(
|
|||||||
finish_reason="length",
|
finish_reason="length",
|
||||||
tool_iterations=tool_iterations,
|
tool_iterations=tool_iterations,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
|
messages=conversation[turn_start_len:],
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Chat API request models."""
|
"""Chat API request models."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -11,13 +11,29 @@ class ChatMessage(BaseModel):
|
|||||||
role: str = Field(
|
role: str = Field(
|
||||||
description="Message role: 'user', 'assistant', 'system', or 'tool'"
|
description="Message role: 'user', 'assistant', 'system', or 'tool'"
|
||||||
)
|
)
|
||||||
content: str = Field(description="Message content")
|
content: Optional[Any] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Message content. Usually a string, but may be a multimodal content "
|
||||||
|
"list (e.g. text + image_url) or null for assistant turns that only "
|
||||||
|
"request tool calls."
|
||||||
|
),
|
||||||
|
)
|
||||||
tool_call_id: Optional[str] = Field(
|
tool_call_id: Optional[str] = Field(
|
||||||
default=None, description="For tool messages, the ID of the tool call"
|
default=None, description="For tool messages, the ID of the tool call"
|
||||||
)
|
)
|
||||||
name: Optional[str] = Field(
|
name: Optional[str] = Field(
|
||||||
default=None, description="For tool messages, the tool name"
|
default=None, description="For tool messages, the tool name"
|
||||||
)
|
)
|
||||||
|
tool_calls: Optional[list[dict[str, Any]]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"For assistant messages replayed from prior turns, the OpenAI-format "
|
||||||
|
"tool calls the model previously requested. Replaying these verbatim "
|
||||||
|
"keeps the conversation prefix byte-for-byte identical so the model "
|
||||||
|
"server's prompt cache hits on follow-up turns."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user