Include system message to not break cache

This commit is contained in:
Nicolas Mowen 2026-06-11 17:07:55 -06:00
parent 7d66d063aa
commit d616230a04
5 changed files with 62 additions and 57 deletions

View File

@ -1138,19 +1138,23 @@ async def chat_completion(
)
conversation = []
system_prompt = build_chat_system_prompt(
config=config,
allowed_cameras=allowed_cameras,
semantic_search_enabled=semantic_search_enabled,
attribute_classifications=attribute_classifications,
)
conversation.append(
{
"role": "system",
"content": system_prompt,
}
)
# Build the system message only when the client hasn't already pinned one.
# The first turn has no system message; we generate it (with the current
# timestamp) and return the whole chain so the client persists it. Later
# turns send it back verbatim, freezing the timestamp so the prompt prefix
# stays byte-identical and the model server's prompt cache keeps hitting.
if not body.messages or body.messages[0].role != "system":
conversation.append(
{
"role": "system",
"content": build_chat_system_prompt(
config=config,
allowed_cameras=allowed_cameras,
semantic_search_enabled=semantic_search_enabled,
attribute_classifications=attribute_classifications,
),
}
)
for msg in body.messages:
msg_dict = {
@ -1166,9 +1170,6 @@ async def chat_completion(
conversation.append(msg_dict)
# Messages appended past this point form this turn's replay record.
turn_start_len = len(conversation)
tool_iterations = 0
tool_calls: list[ToolCall] = []
max_iterations = body.max_tool_iterations
@ -1185,12 +1186,12 @@ async def chat_completion(
async def stream_body_llm():
nonlocal conversation, stream_iterations
def _emit_replay_messages(extra: Optional[list[dict[str, Any]]] = None):
turn_messages = conversation[turn_start_len:] + (extra or [])
def _emit_chain(extra: Optional[list[dict[str, Any]]] = None):
# Return the full conversation (including the system message) so
# the client persists and replays it verbatim next turn.
chain = conversation + (extra or [])
return (
json.dumps({"type": "messages", "messages": turn_messages}).encode(
"utf-8"
)
json.dumps({"type": "messages", "messages": chain}).encode("utf-8")
+ b"\n"
)
@ -1266,14 +1267,14 @@ async def chat_completion(
)
conversation.extend(tool_results)
conversation.extend(extra_msgs)
# Running turn slice: lets the client render tool
# Emit the running chain so the client can render tool
# calls live and replay them verbatim next turn.
yield _emit_replay_messages()
yield _emit_chain()
break
else:
# Streaming never appends the final assistant message
# to the conversation, so add it to the replay slice.
yield _emit_replay_messages(
# to the conversation, so add it to the chain.
yield _emit_chain(
extra=[
{
"role": "assistant",
@ -1284,7 +1285,7 @@ async def chat_completion(
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
return
else:
yield _emit_replay_messages()
yield _emit_chain()
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
return StreamingResponse(
@ -1331,12 +1332,12 @@ async def chat_completion(
if body.stream:
final_reasoning = response.get("reasoning")
turn_messages = conversation[turn_start_len:]
chain = list(conversation)
async def stream_body() -> Any:
yield (
json.dumps(
{"type": "messages", "messages": turn_messages}
{"type": "messages", "messages": chain}
).encode("utf-8")
+ b"\n"
)
@ -1375,7 +1376,7 @@ async def chat_completion(
finish_reason=response.get("finish_reason", "stop"),
tool_iterations=tool_iterations,
tool_calls=tool_calls,
messages=conversation[turn_start_len:],
messages=list(conversation),
).model_dump(),
)
@ -1408,7 +1409,7 @@ async def chat_completion(
finish_reason="length",
tool_iterations=tool_iterations,
tool_calls=tool_calls,
messages=conversation[turn_start_len:],
messages=list(conversation),
).model_dump(),
)

View File

@ -59,9 +59,9 @@ class ChatCompletionResponse(BaseModel):
messages: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"The exact conversation messages appended for this assistant turn "
"(assistant tool-call turns, tool results, and the final assistant "
"message). Replay these verbatim as conversation history on the next "
"request to keep the model server's prompt cache prefix intact."
"The full conversation chain, including the system message. Persist "
"and replay this verbatim on the next request so the prompt prefix "
"stays byte-identical and the model server's prompt cache keeps "
"hitting."
),
)

View File

@ -30,7 +30,7 @@ import {
type StreamingTurn = {
content: string;
reasoning: string;
turn: ChatMessage[];
chain: ChatMessage[];
stats?: ChatStats;
};
@ -99,7 +99,7 @@ export default function ChatPage() {
setError(null);
setMessages(messagesToSend);
setStreaming({ content: "", reasoning: "", turn: [] });
setStreaming({ content: "", reasoning: "", chain: [] });
setIsLoading(true);
const baseURL = axios.defaults.baseURL ?? "";
@ -112,7 +112,7 @@ export default function ChatPage() {
const controller = new AbortController();
abortRef.current = controller;
let turn: ChatMessage[] = [];
let chain: ChatMessage[] = [];
let stats: ChatStats | undefined;
let reasoning = "";
let hadError = false;
@ -130,9 +130,9 @@ export default function ChatPage() {
s ? { ...s, reasoning: s.reasoning + delta } : s,
);
},
onTurnMessages: (turnMessages) => {
turn = turnMessages;
setStreaming((s) => (s ? { ...s, turn: turnMessages } : s));
onChain: (fullChain) => {
chain = fullChain;
setStreaming((s) => (s ? { ...s, chain: fullChain } : s));
},
onStats: (s) => {
stats = s;
@ -146,14 +146,15 @@ export default function ChatPage() {
abortRef.current = null;
setIsLoading(false);
setStreaming(null);
const lastMsg = turn[turn.length - 1];
const lastMsg = chain[chain.length - 1];
if (!hadError && lastMsg?.role === "assistant") {
const committed = turn.map((m, i) =>
i === turn.length - 1
? { ...m, reasoning: reasoning || undefined, stats }
: m,
setMessages(
chain.map((m, i) =>
i === chain.length - 1
? { ...m, reasoning: reasoning || undefined, stats }
: m,
),
);
setMessages((prev) => [...prev, ...committed]);
}
},
defaultErrorMessage: t("error"),
@ -228,15 +229,17 @@ export default function ChatPage() {
const hasStarted = messages.length > 0 || streaming != null;
// The conversation plus any in-flight turn, rendered as one flat list.
const renderList = streaming ? [...messages, ...streaming.turn] : messages;
// While streaming, the backend's in-flight chain is the source of truth;
// otherwise the committed conversation is.
const renderList =
streaming && streaming.chain.length ? streaming.chain : messages;
const responses = toolResponsesById(renderList);
const streamingTail = streaming?.turn[streaming.turn.length - 1];
const renderTail = renderList[renderList.length - 1];
const finalShown =
streamingTail?.role === "assistant" && hasText(streamingTail.content);
renderTail?.role === "assistant" && hasText(renderTail.content);
const renderMessage = (msg: ChatMessage, i: number) => {
if (msg.role === "tool") return null;
if (msg.role === "system" || msg.role === "tool") return null;
if (msg.role === "user") {
if (!hasText(msg.content)) return null;

View File

@ -10,7 +10,7 @@ export type WireToolCall = {
};
export type ChatMessage = {
role: "user" | "assistant" | "tool";
role: "system" | "user" | "assistant" | "tool";
content: unknown;
tool_call_id?: string;
name?: string;

View File

@ -5,9 +5,10 @@ export type StreamChatCallbacks = {
onContentDelta: (delta: string) => void;
/** Streamed delta of the assistant's reasoning trace. */
onReasoningDelta: (delta: string) => void;
/** The exact wire messages appended for this turn so far (tool-call turns,
* tool results, and on the final emission the final assistant message). */
onTurnMessages: (messages: ChatMessage[]) => void;
/** The full conversation chain so far (system message, history, this turn's
* tool-call turns, tool results, and on the final emission the final
* assistant message). */
onChain: (chain: ChatMessage[]) => void;
/** Token/timing stats for the turn. */
onStats: (stats: ChatStats) => void;
/** Called when the stream sends an error or fetch fails. */
@ -52,7 +53,7 @@ export async function streamChatCompletion(
const {
onContentDelta,
onReasoningDelta,
onTurnMessages,
onChain,
onStats,
onError,
onDone,
@ -99,7 +100,7 @@ export async function streamChatCompletion(
return "break";
}
if (data.type === "messages") {
onTurnMessages(data.messages ?? []);
onChain(data.messages ?? []);
return "continue";
}
if (data.type === "content" && data.delta !== undefined) {