diff --git a/frigate/api/chat.py b/frigate/api/chat.py
index 7b450bac7..83d2f503c 100644
--- a/frigate/api/chat.py
+++ b/frigate/api/chat.py
@@ -1411,6 +1411,11 @@ When a user refers to a specific object they have seen or describe with identify
                     )
                     + b"\n"
                 )
+            elif kind == "stats":
+                yield (
+                    json.dumps({"type": "stats", **value}).encode("utf-8")
+                    + b"\n"
+                )
             elif kind == "message":
                 msg = value
                 if msg.get("finish_reason") == "error":
diff --git a/frigate/genai/azure-openai.py b/frigate/genai/azure-openai.py
index 66d7d1568..04a2b8d55 100644
--- a/frigate/genai/azure-openai.py
+++ b/frigate/genai/azure-openai.py
@@ -10,6 +10,7 @@ from openai import AzureOpenAI
 from frigate.config import GenAIProviderEnum
 from frigate.genai import GenAIClient, register_genai_provider
+from frigate.genai.openai import _stats_from_openai_usage

 logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@
             "messages": messages,
             "timeout": self.timeout,
             "stream": True,
+            "stream_options": {"include_usage": True},
         }

         if tools:
@@ -221,10 +223,15 @@
         content_parts: list[str] = []
         tool_calls_by_index: dict[int, dict[str, Any]] = {}
         finish_reason = "stop"
+        usage_stats: Optional[dict[str, Any]] = None

         stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]

         for chunk in stream:
+            chunk_usage = getattr(chunk, "usage", None)
+            if chunk_usage is not None:
+                usage_stats = _stats_from_openai_usage(chunk_usage)
+
             if not chunk or not chunk.choices:
                 continue
@@ -284,6 +291,9 @@
             )
             finish_reason = "tool_calls"

+        if usage_stats is not None:
+            yield ("stats", usage_stats)
+
         yield (
             "message",
             {
diff --git a/frigate/genai/gemini.py b/frigate/genai/gemini.py
index eec22a991..c1046428e 100644
--- a/frigate/genai/gemini.py
+++ b/frigate/genai/gemini.py
@@ -14,6 +14,20 @@ from frigate.genai import GenAIClient, register_genai_provider

 logger = logging.getLogger(__name__)

+def _stats_from_gemini_usage(usage: Any) -> Optional[dict[str, Any]]:
+    """Build a stats dict from a Gemini usage_metadata object."""
+    prompt_tokens = getattr(usage, "prompt_token_count", None)
+    completion_tokens = getattr(usage, "candidates_token_count", None)
+    if prompt_tokens is None and completion_tokens is None:
+        return None
+    stats: dict[str, Any] = {}
+    if isinstance(prompt_tokens, int):
+        stats["prompt_tokens"] = prompt_tokens
+    if isinstance(completion_tokens, int):
+        stats["completion_tokens"] = completion_tokens
+    return stats or None
+
+
 @register_genai_provider(GenAIProviderEnum.gemini)
 class GeminiClient(GenAIClient):
     """Generative AI client for Frigate using Gemini."""
@@ -471,6 +485,7 @@
         content_parts: list[str] = []
         tool_calls_by_index: dict[int, dict[str, Any]] = {}
         finish_reason = "stop"
+        usage_stats: Optional[dict[str, Any]] = None

         stream = await self.provider.aio.models.generate_content_stream(
             model=self.genai_config.model,
@@ -479,6 +494,12 @@
         )

         async for chunk in stream:
+            chunk_usage = getattr(chunk, "usage_metadata", None)
+            if chunk_usage is not None:
+                maybe_stats = _stats_from_gemini_usage(chunk_usage)
+                if maybe_stats is not None:
+                    usage_stats = maybe_stats
+
             if not chunk or not chunk.candidates:
                 continue
@@ -565,6 +586,9 @@
             )
             finish_reason = "tool_calls"

+        if usage_stats is not None:
+            yield ("stats", usage_stats)
+
         yield (
             "message",
             {
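Note: each provider change reduces to the same ("stats", usage_stats) tuple, and the chat.py hunk above frames every such tuple as one NDJSON line. A minimal sketch of the resulting wire format (token values invented for illustration):

    import json

    # A hypothetical value as yielded by one of the provider clients.
    value = {"prompt_tokens": 1843, "completion_tokens": 212}
    line = json.dumps({"type": "stats", **value}).encode("utf-8") + b"\n"
    # line == b'{"type": "stats", "prompt_tokens": 1843, "completion_tokens": 212}\n'
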
diff --git a/frigate/genai/llama_cpp.py b/frigate/genai/llama_cpp.py
index 24accdc02..c935207bf 100644
--- a/frigate/genai/llama_cpp.py
+++ b/frigate/genai/llama_cpp.py
@@ -18,6 +18,52 @@ from frigate.genai.utils import parse_tool_calls_from_message

 logger = logging.getLogger(__name__)

+def _stats_from_llama_cpp_chunk(data: dict[str, Any]) -> Optional[dict[str, Any]]:
+    """Build a stats dict from a llama.cpp streaming chunk.
+
+    Final-chunk `usage` carries authoritative token counts. Per-chunk
+    `timings` (enabled via timings_per_token) carries the running token
+    counts (prompt_n, predicted_n) and generation rate, so live updates
+    work mid-stream.
+    """
+    usage = data.get("usage") or {}
+    timings = data.get("timings") or {}
+    prompt_tokens = usage.get("prompt_tokens")
+    completion_tokens = usage.get("completion_tokens")
+    predicted_ms = timings.get("predicted_ms")
+    tps = timings.get("predicted_per_second")
+    stats: dict[str, Any] = {}
+
+    if not isinstance(prompt_tokens, int):
+        prompt_n = timings.get("prompt_n")
+
+        if isinstance(prompt_n, int):
+            prompt_tokens = prompt_n
+
+    if not isinstance(completion_tokens, int):
+        predicted_n = timings.get("predicted_n")
+
+        if isinstance(predicted_n, int):
+            completion_tokens = predicted_n
+
+    if not isinstance(prompt_tokens, int) and not isinstance(completion_tokens, int):
+        return None
+
+    if isinstance(prompt_tokens, int):
+        stats["prompt_tokens"] = prompt_tokens
+
+    if isinstance(completion_tokens, int):
+        stats["completion_tokens"] = completion_tokens
+
+    if isinstance(predicted_ms, (int, float)) and predicted_ms > 0:
+        stats["completion_duration_ms"] = float(predicted_ms)
+
+    if isinstance(tps, (int, float)) and tps > 0:
+        stats["tokens_per_second"] = float(tps)
+
+    return stats or None
+
+
 def _parse_launch_arg(args: list[str], flag: str) -> str | None:
     """Return the value following `flag` in a positional argv list, or None."""
     try:
@@ -462,6 +508,8 @@
         }
         if stream:
             payload["stream"] = True
+            payload["stream_options"] = {"include_usage": True}
+            payload["timings_per_token"] = True
         if tools:
             payload["tools"] = tools
         if openai_tool_choice is not None:
@@ -724,6 +772,9 @@
                 data = json.loads(data_str)
             except json.JSONDecodeError:
                 continue
+            maybe_stats = _stats_from_llama_cpp_chunk(data)
+            if maybe_stats is not None:
+                yield ("stats", maybe_stats)
             choices = data.get("choices") or []
             if not choices:
                 continue
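Note: a worked example of the fallback order in _stats_from_llama_cpp_chunk, using an invented mid-stream chunk shaped as the docstring describes (usage absent, timings present):

    # Hypothetical llama.cpp chunk with timings_per_token enabled.
    chunk = {
        "choices": [{"delta": {"content": "..."}}],
        "timings": {
            "prompt_n": 512,
            "predicted_n": 34,
            "predicted_ms": 820.0,
            "predicted_per_second": 41.5,
        },
    }
    _stats_from_llama_cpp_chunk(chunk)
    # -> {"prompt_tokens": 512, "completion_tokens": 34,
    #     "completion_duration_ms": 820.0, "tokens_per_second": 41.5}
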
+ """ + if not response: + return None + if hasattr(response, "get"): + getter = response.get + else: + getter = lambda key: getattr(response, key, None) # noqa: E731 + + eval_count = getter("eval_count") + eval_duration_ns = getter("eval_duration") + prompt_eval_count = getter("prompt_eval_count") + if eval_count is None and prompt_eval_count is None: + return None + + stats: dict[str, Any] = {} + if isinstance(prompt_eval_count, int): + stats["prompt_tokens"] = prompt_eval_count + if isinstance(eval_count, int): + stats["completion_tokens"] = eval_count + if isinstance(eval_duration_ns, int) and eval_duration_ns > 0: + stats["completion_duration_ms"] = eval_duration_ns / 1_000_000 + if isinstance(eval_count, int) and eval_count > 0: + stats["tokens_per_second"] = eval_count / (eval_duration_ns / 1_000_000_000) + return stats or None + + def _normalize_multimodal_content( content: Any, ) -> tuple[Optional[str], Optional[list[bytes]]]: @@ -403,6 +434,9 @@ class OllamaClient(GenAIClient): content = result.get("content") if content: yield ("content_delta", content) + stats = _extract_ollama_stats(response) + if stats is not None: + yield ("stats", stats) yield ("message", result) return @@ -416,6 +450,7 @@ class OllamaClient(GenAIClient): ) content_parts: list[str] = [] final_message: dict[str, Any] | None = None + final_chunk: Any = None stream = await async_client.chat(**request_params) async for chunk in stream: if not chunk or "message" not in chunk: @@ -426,6 +461,7 @@ class OllamaClient(GenAIClient): content_parts.append(delta) yield ("content_delta", delta) if chunk.get("done"): + final_chunk = chunk full_content = "".join(content_parts).strip() or None final_message = { "content": full_content, @@ -434,6 +470,10 @@ class OllamaClient(GenAIClient): } break + stats = _extract_ollama_stats(final_chunk) + if stats is not None: + yield ("stats", stats) + if final_message is not None: yield ("message", final_message) else: diff --git a/frigate/genai/openai.py b/frigate/genai/openai.py index 432641332..09e0cf538 100644 --- a/frigate/genai/openai.py +++ b/frigate/genai/openai.py @@ -14,6 +14,22 @@ from frigate.genai import GenAIClient, register_genai_provider logger = logging.getLogger(__name__) +def _stats_from_openai_usage(usage: Any) -> Optional[dict[str, Any]]: + """Build a stats dict from an OpenAI-compatible usage object.""" + if usage is None: + return None + prompt_tokens = getattr(usage, "prompt_tokens", None) + completion_tokens = getattr(usage, "completion_tokens", None) + if prompt_tokens is None and completion_tokens is None: + return None + stats: dict[str, Any] = {} + if isinstance(prompt_tokens, int): + stats["prompt_tokens"] = prompt_tokens + if isinstance(completion_tokens, int): + stats["completion_tokens"] = completion_tokens + return stats or None + + @register_genai_provider(GenAIProviderEnum.openai) class OpenAIClient(GenAIClient): """Generative AI client for Frigate using OpenAI.""" @@ -298,6 +314,7 @@ class OpenAIClient(GenAIClient): "messages": messages, "timeout": self.timeout, "stream": True, + "stream_options": {"include_usage": True}, } if tools: @@ -318,10 +335,15 @@ class OpenAIClient(GenAIClient): content_parts: list[str] = [] tool_calls_by_index: dict[int, dict[str, Any]] = {} finish_reason = "stop" + usage_stats: Optional[dict[str, Any]] = None stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload] for chunk in stream: + chunk_usage = getattr(chunk, "usage", None) + if chunk_usage is not None: + usage_stats = 
diff --git a/frigate/genai/openai.py b/frigate/genai/openai.py
index 432641332..09e0cf538 100644
--- a/frigate/genai/openai.py
+++ b/frigate/genai/openai.py
@@ -14,6 +14,22 @@ from frigate.genai import GenAIClient, register_genai_provider

 logger = logging.getLogger(__name__)

+def _stats_from_openai_usage(usage: Any) -> Optional[dict[str, Any]]:
+    """Build a stats dict from an OpenAI-compatible usage object."""
+    if usage is None:
+        return None
+    prompt_tokens = getattr(usage, "prompt_tokens", None)
+    completion_tokens = getattr(usage, "completion_tokens", None)
+    if prompt_tokens is None and completion_tokens is None:
+        return None
+    stats: dict[str, Any] = {}
+    if isinstance(prompt_tokens, int):
+        stats["prompt_tokens"] = prompt_tokens
+    if isinstance(completion_tokens, int):
+        stats["completion_tokens"] = completion_tokens
+    return stats or None
+
+
 @register_genai_provider(GenAIProviderEnum.openai)
 class OpenAIClient(GenAIClient):
     """Generative AI client for Frigate using OpenAI."""
@@ -298,6 +314,7 @@
             "messages": messages,
             "timeout": self.timeout,
             "stream": True,
+            "stream_options": {"include_usage": True},
         }

         if tools:
@@ -318,10 +335,15 @@
         content_parts: list[str] = []
         tool_calls_by_index: dict[int, dict[str, Any]] = {}
         finish_reason = "stop"
+        usage_stats: Optional[dict[str, Any]] = None

         stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]

         for chunk in stream:
+            chunk_usage = getattr(chunk, "usage", None)
+            if chunk_usage is not None:
+                usage_stats = _stats_from_openai_usage(chunk_usage)
+
             if not chunk or not chunk.choices:
                 continue
@@ -381,6 +403,9 @@
             )
             finish_reason = "tool_calls"

+        if usage_stats is not None:
+            yield ("stats", usage_stats)
+
         yield (
             "message",
             {
diff --git a/web/public/locales/en/views/chat.json b/web/public/locales/en/views/chat.json
index 6d78dc71f..bc320c204 100644
--- a/web/public/locales/en/views/chat.json
+++ b/web/public/locales/en/views/chat.json
@@ -42,5 +42,23 @@
     "show_camera_status": "What is the current status of my cameras?",
     "recap": "What happened while I was away?",
     "watch_camera": "Watch the front door and let me know if anyone shows up"
-  }
+  },
+  "new_chat": "New chat",
+  "settings": {
+    "title": "Chat settings",
+    "show_stats": {
+      "title": "Show stats",
+      "desc": "Show generation rate and context size for chat responses.",
+      "while_generating": "While generating",
+      "always": "Always"
+    },
+    "auto_scroll": {
+      "title": "Auto-scroll",
+      "desc": "Follow new messages as they arrive."
+    }
+  },
+  "stats": {
+    "context": "{{tokens}} tokens",
+    "tokens_per_second": "{{rate}} t/s"
+  }
 }
diff --git a/web/src/components/chat/ChatEventThumbnailsRow.tsx b/web/src/components/chat/ChatEventThumbnailsRow.tsx
index a12153e89..94eca3f2d 100644
--- a/web/src/components/chat/ChatEventThumbnailsRow.tsx
+++ b/web/src/components/chat/ChatEventThumbnailsRow.tsx
@@ -6,7 +6,6 @@ import {
   TooltipContent,
   TooltipTrigger,
 } from "@/components/ui/tooltip";
-import { cn } from "@/lib/utils";

 type ChatEvent = { id: string; score?: number };

@@ -37,10 +36,7 @@
   const renderThumb = (event: ChatEvent, isAnchor = false) => (
);
diff --git a/web/src/components/chat/ChatMessage.tsx b/web/src/components/chat/ChatMessage.tsx
index 9a91d7035..0a5c02763 100644
--- a/web/src/components/chat/ChatMessage.tsx
+++ b/web/src/components/chat/ChatMessage.tsx
@@ -17,6 +17,7 @@ import {
 import { cn } from "@/lib/utils";
 import { ChatAttachmentChip } from "@/components/chat/ChatAttachmentChip";
 import { parseAttachedEvent } from "@/utils/chatUtil";
+import type { ChatStats, ShowStatsMode } from "@/types/chat";

 type MessageBubbleProps = {
   role: "user" | "assistant";
@@ -24,14 +25,29 @@
   messageIndex?: number;
   onEditSubmit?: (messageIndex: number, newContent: string) => void;
   isComplete?: boolean;
+  stats?: ChatStats;
+  showStats?: ShowStatsMode;
 };

+function formatTokens(n: number | undefined): string | null {
+  if (n === undefined) return null;
+  if (n >= 1000) return `${(n / 1000).toFixed(1)}k`;
+  return String(n);
+}
+
+function formatRate(rate: number | undefined): string | null {
+  if (rate === undefined || rate <= 0) return null;
+  return rate >= 10 ? rate.toFixed(0) : rate.toFixed(1);
+}
+
 export function MessageBubble({
   role,
   content,
   messageIndex = 0,
   onEditSubmit,
   isComplete = true,
+  stats,
+  showStats = "while_generating",
 }: MessageBubbleProps) {
   const { t } = useTranslation(["views/chat", "common"]);
   const isUser = role === "user";
@@ -214,7 +230,7 @@
       )}
-
+
       {isUser && onEditSubmit != null && (
@@ -256,6 +272,27 @@
       )}
+      {!isUser &&
+        stats &&
+        (showStats === "always" || !isComplete) &&
+        (() => {
+          const ctx = formatTokens(stats.promptTokens);
+          const rate = formatRate(stats.tokensPerSecond);
+          if (ctx === null && rate === null) return null;
+          return (
+
+ {ctx !== null && ( + {t("stats.context", { tokens: ctx })} + )} + {ctx !== null && rate !== null && ( + + )} + {rate !== null && ( + {t("stats.tokens_per_second", { rate })} + )} +
+ ); + })()}
);
}
diff --git a/web/src/components/chat/ChatSettings.tsx b/web/src/components/chat/ChatSettings.tsx
new file mode 100644
index 000000000..0f68ef22d
--- /dev/null
+++ b/web/src/components/chat/ChatSettings.tsx
@@ -0,0 +1,108 @@
+import { Button } from "@/components/ui/button";
+import { useState } from "react";
+import { isDesktop } from "react-device-detect";
+import { cn } from "@/lib/utils";
+import PlatformAwareDialog from "../overlay/dialog/PlatformAwareDialog";
+import { FaCog } from "react-icons/fa";
+import {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectTrigger,
+} from "@/components/ui/select";
+import { Switch } from "@/components/ui/switch";
+import { Label } from "@/components/ui/label";
+import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
+import { useTranslation } from "react-i18next";
+import type { ShowStatsMode } from "@/types/chat";
+
+type ChatSettingsProps = {
+  showStats: ShowStatsMode;
+  setShowStats: (mode: ShowStatsMode) => void;
+  autoScroll: boolean;
+  setAutoScroll: (enabled: boolean) => void;
+};
+
+export default function ChatSettings({
+  showStats,
+  setShowStats,
+  autoScroll,
+  setAutoScroll,
+}: ChatSettingsProps) {
+  const { t } = useTranslation(["views/chat"]);
+  const [open, setOpen] = useState(false);
+
+  const trigger = (
+
+  );
+
+  const content = (
+
+
+
{t("settings.show_stats.title")}
+
+ {t("settings.show_stats.desc")} +
+
+ +
+ +
+
+ +
+ {t("settings.auto_scroll.desc")} +
+
+ +
+
+ ); + + return ( + + ); +} diff --git a/web/src/components/chat/ChatStartingState.tsx b/web/src/components/chat/ChatStartingState.tsx index e6b611bf9..a0a3a044c 100644 --- a/web/src/components/chat/ChatStartingState.tsx +++ b/web/src/components/chat/ChatStartingState.tsx @@ -54,7 +54,9 @@ export function ChatStartingState({ onSendMessage }: ChatStartingStateProps) {

{t("title")}

-

{t("subtitle")}

+

+ {t("subtitle")} +

diff --git a/web/src/pages/Chat.tsx b/web/src/pages/Chat.tsx index 474aa6d21..970fa3d36 100644 --- a/web/src/pages/Chat.tsx +++ b/web/src/pages/Chat.tsx @@ -1,7 +1,7 @@ import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { FaArrowUpLong, FaStop } from "react-icons/fa6"; -import { LuCircleAlert } from "react-icons/lu"; +import { LuCircleAlert, LuMessageSquarePlus } from "react-icons/lu"; import { useTranslation } from "react-i18next"; import { useState, useCallback, useRef, useEffect, useMemo } from "react"; import axios from "axios"; @@ -12,7 +12,9 @@ import { ChatStartingState } from "@/components/chat/ChatStartingState"; import { ChatAttachmentChip } from "@/components/chat/ChatAttachmentChip"; import { ChatQuickReplies } from "@/components/chat/ChatQuickReplies"; import { ChatPaperclipButton } from "@/components/chat/ChatPaperclipButton"; -import type { ChatMessage } from "@/types/chat"; +import ChatSettings from "@/components/chat/ChatSettings"; +import type { ChatMessage, ShowStatsMode } from "@/types/chat"; +import { usePersistence } from "@/hooks/use-persistence"; import { getEventIdsFromSearchObjectsToolCalls, getFindSimilarObjectsFromToolCalls, @@ -27,6 +29,14 @@ export default function ChatPage() { const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); const [attachedEventId, setAttachedEventId] = useState(null); + const [showStats, setShowStats] = usePersistence( + "chat-show-stats", + "while_generating", + ); + const [autoScroll, setAutoScroll] = usePersistence( + "chat-auto-scroll", + true, + ); const scrollRef = useRef(null); const abortRef = useRef(null); @@ -36,13 +46,14 @@ export default function ChatPage() { // Auto-scroll to bottom when messages change, but only if near bottom useEffect(() => { + if (!autoScroll) return; const el = scrollRef.current; if (!el) return; const isNearBottom = el.scrollHeight - el.scrollTop - el.clientHeight < 150; if (isNearBottom) { el.scrollTo({ top: el.scrollHeight, behavior: "smooth" }); } - }, [messages]); + }, [messages, autoScroll]); const submitConversation = useCallback( async (messagesToSend: ChatMessage[]) => { @@ -125,6 +136,16 @@ export default function ChatPage() { setIsLoading(false); }, []); + const startNewChat = useCallback(() => { + abortRef.current?.abort(); + abortRef.current = null; + setIsLoading(false); + setMessages([]); + setInput(""); + setAttachedEventId(null); + setError(null); + }, []); + const handleEditSubmit = useCallback( (messageIndex: number, newContent: string) => { const newList: ChatMessage[] = [ @@ -140,127 +161,157 @@ export default function ChatPage() { setAttachedEventId(null); }, []); + const hasStarted = messages.length > 0; + return ( -
-
- {messages.length === 0 ? ( - { - setInput(""); - submitConversation([{ role: "user", content: message }]); - }} - /> - ) : ( - <> -
- {messages.map((msg, i) => { - const isLastAssistant = - i === messages.length - 1 && msg.role === "assistant"; - const isComplete = - msg.role === "user" || !isLoading || !isLastAssistant; - const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0; - const hasContent = !!msg.content?.trim(); - const showProcessing = - isLastAssistant && isLoading && !hasContent; +
+
+ {hasStarted && ( + + )} + +
+
+
+
+ {hasStarted ? ( +
+ {messages.map((msg, i) => { + const isLastAssistant = + i === messages.length - 1 && msg.role === "assistant"; + const isComplete = + msg.role === "user" || !isLoading || !isLastAssistant; + const hasToolCalls = + msg.toolCalls && msg.toolCalls.length > 0; + const hasContent = !!msg.content?.trim(); + const showProcessing = + isLastAssistant && isLoading && !hasContent; - // Hide empty placeholder only when there are no tool calls yet - if ( - isLastAssistant && - isLoading && - !hasContent && - !hasToolCalls - ) - return ( -
- - - -
- ); - - return ( -
- {msg.role === "assistant" && hasToolCalls && ( - - )} - {showProcessing ? ( -
- - - + // Hide empty placeholder only when there are no tool calls yet + if ( + isLastAssistant && + isLoading && + !hasContent && + !hasToolCalls + ) + return ( +
+ + +
- ) : ( - - )} - {msg.role === "assistant" && - isComplete && - (() => { - const similar = getFindSimilarObjectsFromToolCalls( - msg.toolCalls, - ); - if (similar) { + ); + + return ( +
+ {msg.role === "assistant" && hasToolCalls && ( + + )} + {showProcessing ? ( +
+ + + +
+ ) : ( + + )} + {msg.role === "assistant" && + isComplete && + (() => { + const similar = getFindSimilarObjectsFromToolCalls( + msg.toolCalls, + ); + if (similar) { + return ( + + ); + } + const events = getEventIdsFromSearchObjectsToolCalls( + msg.toolCalls, + ); return ( ); - } - const events = getEventIdsFromSearchObjectsToolCalls( - msg.toolCalls, - ); - return ( - - ); - })()} -
- ); - })} - {error && ( -

- - {error} -

- )} -
- - )} - {messages.length > 0 && ( - - )} + })()} +
+ ); + })} + {error && ( +

+ + {error} +

+ )} +
+ ) : ( + { + setInput(""); + submitConversation([{ role: "user", content: message }]); + }} + /> + )} +
+
+ {hasStarted && ( +
+
+ +
+
+ )}
); } @@ -298,7 +349,7 @@ function ChatEntry({ }; return ( -
+
{attachedEventId && (
void;
  /** Called when the stream finishes (success or error). */
  onDone: () => void;
+  /** Called when the stream emits token/timing stats. The stats are also
+   * attached to the last assistant message in updateMessages, so consumers
+   * can usually rely on the message itself rather than wiring this up. */
+  onStats?: (stats: ChatStats) => void;
  /** Message used when fetch throws and no server error is available. */
  defaultErrorMessage?: string;
};

+type StatsChunk = {
+  type: "stats";
+  prompt_tokens?: number;
+  completion_tokens?: number;
+  completion_duration_ms?: number;
+  tokens_per_second?: number;
+};
+
 type StreamChunk =
   | { type: "error"; error: string }
   | { type: "tool_calls"; tool_calls: ToolCall[] }
-  | { type: "content"; delta: string };
+  | { type: "content"; delta: string }
+  | StatsChunk;

 /**
  * POST to chat/completion with stream: true, parse NDJSON stream, and invoke
@@ -31,6 +44,7 @@
     updateMessages,
     onError,
     onDone,
+    onStats,
     defaultErrorMessage = "Something went wrong. Please try again.",
   } = callbacks;

@@ -95,6 +109,23 @@
       });
       return "continue";
     }
+    if (data.type === "stats") {
+      const stats: ChatStats = {
+        promptTokens: data.prompt_tokens,
+        completionTokens: data.completion_tokens,
+        completionDurationMs: data.completion_duration_ms,
+        tokensPerSecond: data.tokens_per_second,
+      };
+      updateMessages((prev) => {
+        const next = [...prev];
+        const lastMsg = next[next.length - 1];
+        if (lastMsg?.role === "assistant")
+          next[next.length - 1] = { ...lastMsg, stats };
+        return next;
+      });
+      onStats?.(stats);
+      return "continue";
+    }
     return "continue";
   };
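
Note: end to end, the path is provider tuple -> NDJSON line -> camelCase ChatStats. A minimal Python sketch of the consuming side (the endpoint path and payload shape are assumptions based on the "POST to chat/completion" comment above, not confirmed by this diff):

    import json
    import requests

    resp = requests.post(
        "http://frigate.local:5000/api/chat/completion",  # assumed path
        json={"messages": [{"role": "user", "content": "What happened today?"}], "stream": True},
        stream=True,
    )
    for raw in resp.iter_lines():
        if not raw:
            continue
        data = json.loads(raw)
        if data.get("type") == "stats":
            # Same snake_case fields the StatsChunk type above declares.
            print(data.get("prompt_tokens"), data.get("tokens_per_second"))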