mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-15 09:50:51 +03:00
Chat improvements (#23195)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
* Support token streaming stats * Propogate streaming token stats to chat calls * Show token stats for each image * Add settings to handle token stats and other options * i18n * Use select * Improve mobile layout and spacing
This commit is contained in:
parent
78fc472026
commit
d9c1ea908d
@ -1411,6 +1411,11 @@ When a user refers to a specific object they have seen or describe with identify
|
||||
)
|
||||
+ b"\n"
|
||||
)
|
||||
elif kind == "stats":
|
||||
yield (
|
||||
json.dumps({"type": "stats", **value}).encode("utf-8")
|
||||
+ b"\n"
|
||||
)
|
||||
elif kind == "message":
|
||||
msg = value
|
||||
if msg.get("finish_reason") == "error":
|
||||
|
||||
@ -10,6 +10,7 @@ from openai import AzureOpenAI
|
||||
|
||||
from frigate.config import GenAIProviderEnum
|
||||
from frigate.genai import GenAIClient, register_genai_provider
|
||||
from frigate.genai.openai import _stats_from_openai_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -210,6 +211,7 @@ class OpenAIClient(GenAIClient):
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
|
||||
if tools:
|
||||
@ -221,10 +223,15 @@ class OpenAIClient(GenAIClient):
|
||||
content_parts: list[str] = []
|
||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
usage_stats: Optional[dict[str, Any]] = None
|
||||
|
||||
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||
|
||||
for chunk in stream:
|
||||
chunk_usage = getattr(chunk, "usage", None)
|
||||
if chunk_usage is not None:
|
||||
usage_stats = _stats_from_openai_usage(chunk_usage)
|
||||
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
@ -284,6 +291,9 @@ class OpenAIClient(GenAIClient):
|
||||
)
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
if usage_stats is not None:
|
||||
yield ("stats", usage_stats)
|
||||
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
|
||||
@ -14,6 +14,20 @@ from frigate.genai import GenAIClient, register_genai_provider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stats_from_gemini_usage(usage: Any) -> Optional[dict[str, Any]]:
|
||||
"""Build a stats dict from a Gemini usage_metadata object."""
|
||||
prompt_tokens = getattr(usage, "prompt_token_count", None)
|
||||
completion_tokens = getattr(usage, "candidates_token_count", None)
|
||||
if prompt_tokens is None and completion_tokens is None:
|
||||
return None
|
||||
stats: dict[str, Any] = {}
|
||||
if isinstance(prompt_tokens, int):
|
||||
stats["prompt_tokens"] = prompt_tokens
|
||||
if isinstance(completion_tokens, int):
|
||||
stats["completion_tokens"] = completion_tokens
|
||||
return stats or None
|
||||
|
||||
|
||||
@register_genai_provider(GenAIProviderEnum.gemini)
|
||||
class GeminiClient(GenAIClient):
|
||||
"""Generative AI client for Frigate using Gemini."""
|
||||
@ -471,6 +485,7 @@ class GeminiClient(GenAIClient):
|
||||
content_parts: list[str] = []
|
||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
usage_stats: Optional[dict[str, Any]] = None
|
||||
|
||||
stream = await self.provider.aio.models.generate_content_stream(
|
||||
model=self.genai_config.model,
|
||||
@ -479,6 +494,12 @@ class GeminiClient(GenAIClient):
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
chunk_usage = getattr(chunk, "usage_metadata", None)
|
||||
if chunk_usage is not None:
|
||||
maybe_stats = _stats_from_gemini_usage(chunk_usage)
|
||||
if maybe_stats is not None:
|
||||
usage_stats = maybe_stats
|
||||
|
||||
if not chunk or not chunk.candidates:
|
||||
continue
|
||||
|
||||
@ -565,6 +586,9 @@ class GeminiClient(GenAIClient):
|
||||
)
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
if usage_stats is not None:
|
||||
yield ("stats", usage_stats)
|
||||
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
|
||||
@ -18,6 +18,52 @@ from frigate.genai.utils import parse_tool_calls_from_message
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stats_from_llama_cpp_chunk(data: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""Build a stats dict from a llama.cpp streaming chunk.
|
||||
|
||||
Final-chunk `usage` carries authoritative token counts. Per-chunk
|
||||
`timings` (enabled via timings_per_token) carries the running token
|
||||
counts (prompt_n, predicted_n) and generation rate, so live updates
|
||||
work mid-stream.
|
||||
"""
|
||||
usage = data.get("usage") or {}
|
||||
timings = data.get("timings") or {}
|
||||
prompt_tokens = usage.get("prompt_tokens")
|
||||
completion_tokens = usage.get("completion_tokens")
|
||||
predicted_ms = timings.get("predicted_ms")
|
||||
tps = timings.get("predicted_per_second")
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
if not isinstance(prompt_tokens, int):
|
||||
prompt_n = timings.get("prompt_n")
|
||||
|
||||
if isinstance(prompt_n, int):
|
||||
prompt_tokens = prompt_n
|
||||
|
||||
if not isinstance(completion_tokens, int):
|
||||
predicted_n = timings.get("predicted_n")
|
||||
|
||||
if isinstance(predicted_n, int):
|
||||
completion_tokens = predicted_n
|
||||
|
||||
if not isinstance(prompt_tokens, int) and not isinstance(completion_tokens, int):
|
||||
return None
|
||||
|
||||
if isinstance(prompt_tokens, int):
|
||||
stats["prompt_tokens"] = prompt_tokens
|
||||
|
||||
if isinstance(completion_tokens, int):
|
||||
stats["completion_tokens"] = completion_tokens
|
||||
|
||||
if isinstance(predicted_ms, (int, float)) and predicted_ms > 0:
|
||||
stats["completion_duration_ms"] = float(predicted_ms)
|
||||
|
||||
if isinstance(tps, (int, float)) and tps > 0:
|
||||
stats["tokens_per_second"] = float(tps)
|
||||
|
||||
return stats or None
|
||||
|
||||
|
||||
def _parse_launch_arg(args: list[str], flag: str) -> str | None:
|
||||
"""Return the value following `flag` in a positional argv list, or None."""
|
||||
try:
|
||||
@ -462,6 +508,8 @@ class LlamaCppClient(GenAIClient):
|
||||
}
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
payload["stream_options"] = {"include_usage": True}
|
||||
payload["timings_per_token"] = True
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
if openai_tool_choice is not None:
|
||||
@ -724,6 +772,9 @@ class LlamaCppClient(GenAIClient):
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
maybe_stats = _stats_from_llama_cpp_chunk(data)
|
||||
if maybe_stats is not None:
|
||||
yield ("stats", maybe_stats)
|
||||
choices = data.get("choices") or []
|
||||
if not choices:
|
||||
continue
|
||||
|
||||
@ -18,6 +18,37 @@ from frigate.genai.utils import parse_tool_calls_from_message
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_ollama_stats(response: Any) -> Optional[dict[str, Any]]:
|
||||
"""Build a stats dict from Ollama's response metadata.
|
||||
|
||||
Ollama reports eval_count/eval_duration (generation) and
|
||||
prompt_eval_count (context size). Durations are nanoseconds.
|
||||
"""
|
||||
if not response:
|
||||
return None
|
||||
if hasattr(response, "get"):
|
||||
getter = response.get
|
||||
else:
|
||||
getter = lambda key: getattr(response, key, None) # noqa: E731
|
||||
|
||||
eval_count = getter("eval_count")
|
||||
eval_duration_ns = getter("eval_duration")
|
||||
prompt_eval_count = getter("prompt_eval_count")
|
||||
if eval_count is None and prompt_eval_count is None:
|
||||
return None
|
||||
|
||||
stats: dict[str, Any] = {}
|
||||
if isinstance(prompt_eval_count, int):
|
||||
stats["prompt_tokens"] = prompt_eval_count
|
||||
if isinstance(eval_count, int):
|
||||
stats["completion_tokens"] = eval_count
|
||||
if isinstance(eval_duration_ns, int) and eval_duration_ns > 0:
|
||||
stats["completion_duration_ms"] = eval_duration_ns / 1_000_000
|
||||
if isinstance(eval_count, int) and eval_count > 0:
|
||||
stats["tokens_per_second"] = eval_count / (eval_duration_ns / 1_000_000_000)
|
||||
return stats or None
|
||||
|
||||
|
||||
def _normalize_multimodal_content(
|
||||
content: Any,
|
||||
) -> tuple[Optional[str], Optional[list[bytes]]]:
|
||||
@ -403,6 +434,9 @@ class OllamaClient(GenAIClient):
|
||||
content = result.get("content")
|
||||
if content:
|
||||
yield ("content_delta", content)
|
||||
stats = _extract_ollama_stats(response)
|
||||
if stats is not None:
|
||||
yield ("stats", stats)
|
||||
yield ("message", result)
|
||||
return
|
||||
|
||||
@ -416,6 +450,7 @@ class OllamaClient(GenAIClient):
|
||||
)
|
||||
content_parts: list[str] = []
|
||||
final_message: dict[str, Any] | None = None
|
||||
final_chunk: Any = None
|
||||
stream = await async_client.chat(**request_params)
|
||||
async for chunk in stream:
|
||||
if not chunk or "message" not in chunk:
|
||||
@ -426,6 +461,7 @@ class OllamaClient(GenAIClient):
|
||||
content_parts.append(delta)
|
||||
yield ("content_delta", delta)
|
||||
if chunk.get("done"):
|
||||
final_chunk = chunk
|
||||
full_content = "".join(content_parts).strip() or None
|
||||
final_message = {
|
||||
"content": full_content,
|
||||
@ -434,6 +470,10 @@ class OllamaClient(GenAIClient):
|
||||
}
|
||||
break
|
||||
|
||||
stats = _extract_ollama_stats(final_chunk)
|
||||
if stats is not None:
|
||||
yield ("stats", stats)
|
||||
|
||||
if final_message is not None:
|
||||
yield ("message", final_message)
|
||||
else:
|
||||
|
||||
@ -14,6 +14,22 @@ from frigate.genai import GenAIClient, register_genai_provider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stats_from_openai_usage(usage: Any) -> Optional[dict[str, Any]]:
|
||||
"""Build a stats dict from an OpenAI-compatible usage object."""
|
||||
if usage is None:
|
||||
return None
|
||||
prompt_tokens = getattr(usage, "prompt_tokens", None)
|
||||
completion_tokens = getattr(usage, "completion_tokens", None)
|
||||
if prompt_tokens is None and completion_tokens is None:
|
||||
return None
|
||||
stats: dict[str, Any] = {}
|
||||
if isinstance(prompt_tokens, int):
|
||||
stats["prompt_tokens"] = prompt_tokens
|
||||
if isinstance(completion_tokens, int):
|
||||
stats["completion_tokens"] = completion_tokens
|
||||
return stats or None
|
||||
|
||||
|
||||
@register_genai_provider(GenAIProviderEnum.openai)
|
||||
class OpenAIClient(GenAIClient):
|
||||
"""Generative AI client for Frigate using OpenAI."""
|
||||
@ -298,6 +314,7 @@ class OpenAIClient(GenAIClient):
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
|
||||
if tools:
|
||||
@ -318,10 +335,15 @@ class OpenAIClient(GenAIClient):
|
||||
content_parts: list[str] = []
|
||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
usage_stats: Optional[dict[str, Any]] = None
|
||||
|
||||
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||
|
||||
for chunk in stream:
|
||||
chunk_usage = getattr(chunk, "usage", None)
|
||||
if chunk_usage is not None:
|
||||
usage_stats = _stats_from_openai_usage(chunk_usage)
|
||||
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
@ -381,6 +403,9 @@ class OpenAIClient(GenAIClient):
|
||||
)
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
if usage_stats is not None:
|
||||
yield ("stats", usage_stats)
|
||||
|
||||
yield (
|
||||
"message",
|
||||
{
|
||||
|
||||
@ -42,5 +42,23 @@
|
||||
"show_camera_status": "What is the current status of my cameras?",
|
||||
"recap": "What happened while I was away?",
|
||||
"watch_camera": "Watch the front door and let me know if anyone shows up"
|
||||
},
|
||||
"new_chat": "New chat",
|
||||
"settings": {
|
||||
"title": "Chat settings",
|
||||
"show_stats": {
|
||||
"title": "Show stats",
|
||||
"desc": "Show generation rate and context size for chat responses.",
|
||||
"while_generating": "While generating",
|
||||
"always": "Always"
|
||||
},
|
||||
"auto_scroll": {
|
||||
"title": "Auto-scroll",
|
||||
"desc": "Follow new messages as they arrive."
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
"context": "{{tokens}} tokens",
|
||||
"tokens_per_second": "{{rate}} t/s"
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,6 @@ import {
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
type ChatEvent = { id: string; score?: number };
|
||||
|
||||
@ -37,10 +36,7 @@ export function ChatEventThumbnailsRow({
|
||||
const renderThumb = (event: ChatEvent, isAnchor = false) => (
|
||||
<div
|
||||
key={event.id}
|
||||
className={cn(
|
||||
"relative aspect-square size-32 shrink-0 overflow-hidden rounded-lg",
|
||||
isAnchor && "ring-2 ring-primary",
|
||||
)}
|
||||
className="relative aspect-square size-32 shrink-0 overflow-hidden rounded-lg"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
@ -71,9 +67,15 @@ export function ChatEventThumbnailsRow({
|
||||
<TooltipContent>{t("open_in_explore")}</TooltipContent>
|
||||
</Tooltip>
|
||||
{isAnchor && (
|
||||
<>
|
||||
<span
|
||||
aria-hidden="true"
|
||||
className="pointer-events-none absolute inset-0 rounded-lg ring-2 ring-inset ring-primary"
|
||||
/>
|
||||
<span className="pointer-events-none absolute left-1 top-1 rounded bg-primary px-1 text-[10px] text-primary-foreground">
|
||||
{t("anchor")}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -17,6 +17,7 @@ import {
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChatAttachmentChip } from "@/components/chat/ChatAttachmentChip";
|
||||
import { parseAttachedEvent } from "@/utils/chatUtil";
|
||||
import type { ChatStats, ShowStatsMode } from "@/types/chat";
|
||||
|
||||
type MessageBubbleProps = {
|
||||
role: "user" | "assistant";
|
||||
@ -24,14 +25,29 @@ type MessageBubbleProps = {
|
||||
messageIndex?: number;
|
||||
onEditSubmit?: (messageIndex: number, newContent: string) => void;
|
||||
isComplete?: boolean;
|
||||
stats?: ChatStats;
|
||||
showStats?: ShowStatsMode;
|
||||
};
|
||||
|
||||
function formatTokens(n: number | undefined): string | null {
|
||||
if (n === undefined) return null;
|
||||
if (n >= 1000) return `${(n / 1000).toFixed(1)}k`;
|
||||
return String(n);
|
||||
}
|
||||
|
||||
function formatRate(rate: number | undefined): string | null {
|
||||
if (rate === undefined || rate <= 0) return null;
|
||||
return rate >= 10 ? rate.toFixed(0) : rate.toFixed(1);
|
||||
}
|
||||
|
||||
export function MessageBubble({
|
||||
role,
|
||||
content,
|
||||
messageIndex = 0,
|
||||
onEditSubmit,
|
||||
isComplete = true,
|
||||
stats,
|
||||
showStats = "while_generating",
|
||||
}: MessageBubbleProps) {
|
||||
const { t } = useTranslation(["views/chat", "common"]);
|
||||
const isUser = role === "user";
|
||||
@ -214,7 +230,7 @@ export function MessageBubble({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-0.5">
|
||||
<div className="flex items-center gap-1.5">
|
||||
{isUser && onEditSubmit != null && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
@ -256,6 +272,27 @@ export function MessageBubble({
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
{!isUser &&
|
||||
stats &&
|
||||
(showStats === "always" || !isComplete) &&
|
||||
(() => {
|
||||
const ctx = formatTokens(stats.promptTokens);
|
||||
const rate = formatRate(stats.tokensPerSecond);
|
||||
if (ctx === null && rate === null) return null;
|
||||
return (
|
||||
<div className="flex items-center gap-1 text-xs text-muted-foreground">
|
||||
{ctx !== null && (
|
||||
<span>{t("stats.context", { tokens: ctx })}</span>
|
||||
)}
|
||||
{ctx !== null && rate !== null && (
|
||||
<span aria-hidden="true">·</span>
|
||||
)}
|
||||
{rate !== null && (
|
||||
<span>{t("stats.tokens_per_second", { rate })}</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
108
web/src/components/chat/ChatSettings.tsx
Normal file
108
web/src/components/chat/ChatSettings.tsx
Normal file
@ -0,0 +1,108 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useState } from "react";
|
||||
import { isDesktop } from "react-device-detect";
|
||||
import { cn } from "@/lib/utils";
|
||||
import PlatformAwareDialog from "../overlay/dialog/PlatformAwareDialog";
|
||||
import { FaCog } from "react-icons/fa";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectGroup,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
} from "@/components/ui/select";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import type { ShowStatsMode } from "@/types/chat";
|
||||
|
||||
type ChatSettingsProps = {
|
||||
showStats: ShowStatsMode;
|
||||
setShowStats: (mode: ShowStatsMode) => void;
|
||||
autoScroll: boolean;
|
||||
setAutoScroll: (enabled: boolean) => void;
|
||||
};
|
||||
|
||||
export default function ChatSettings({
|
||||
showStats,
|
||||
setShowStats,
|
||||
autoScroll,
|
||||
setAutoScroll,
|
||||
}: ChatSettingsProps) {
|
||||
const { t } = useTranslation(["views/chat"]);
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
const trigger = (
|
||||
<Button
|
||||
className="flex items-center md:gap-2"
|
||||
aria-label={t("settings.title")}
|
||||
size="sm"
|
||||
>
|
||||
<FaCog className="text-secondary-foreground" />
|
||||
<span className="hidden md:inline">{t("settings.title")}</span>
|
||||
</Button>
|
||||
);
|
||||
|
||||
const content = (
|
||||
<div className="my-3 space-y-5 py-3 md:mt-0 md:py-0">
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-0.5">
|
||||
<div className="text-md">{t("settings.show_stats.title")}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{t("settings.show_stats.desc")}
|
||||
</div>
|
||||
</div>
|
||||
<Select
|
||||
value={showStats}
|
||||
onValueChange={(v) => setShowStats(v as ShowStatsMode)}
|
||||
>
|
||||
<SelectTrigger className="w-full">
|
||||
{showStats === "always"
|
||||
? t("settings.show_stats.always")
|
||||
: t("settings.show_stats.while_generating")}
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem className="cursor-pointer" value="while_generating">
|
||||
{t("settings.show_stats.while_generating")}
|
||||
</SelectItem>
|
||||
<SelectItem className="cursor-pointer" value="always">
|
||||
{t("settings.show_stats.always")}
|
||||
</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
<DropdownMenuSeparator />
|
||||
<div className="flex items-center justify-between gap-3">
|
||||
<div className="space-y-0.5">
|
||||
<Label htmlFor="auto-scroll" className="text-md cursor-pointer">
|
||||
{t("settings.auto_scroll.title")}
|
||||
</Label>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{t("settings.auto_scroll.desc")}
|
||||
</div>
|
||||
</div>
|
||||
<Switch
|
||||
id="auto-scroll"
|
||||
checked={autoScroll}
|
||||
onCheckedChange={setAutoScroll}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<PlatformAwareDialog
|
||||
trigger={trigger}
|
||||
content={content}
|
||||
contentClassName={cn(
|
||||
"scrollbar-container h-auto overflow-y-auto",
|
||||
isDesktop ? "max-h-[80dvh] w-72" : "px-4",
|
||||
)}
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -54,7 +54,9 @@ export function ChatStartingState({ onSendMessage }: ChatStartingStateProps) {
|
||||
<div className="flex size-full flex-col items-center justify-center gap-6 p-8">
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
<h1 className="text-4xl font-bold text-foreground">{t("title")}</h1>
|
||||
<p className="text-muted-foreground">{t("subtitle")}</p>
|
||||
<p className="text-center text-muted-foreground md:text-left">
|
||||
{t("subtitle")}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex w-full max-w-2xl flex-col items-center gap-4">
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { FaArrowUpLong, FaStop } from "react-icons/fa6";
|
||||
import { LuCircleAlert } from "react-icons/lu";
|
||||
import { LuCircleAlert, LuMessageSquarePlus } from "react-icons/lu";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
|
||||
import axios from "axios";
|
||||
@ -12,7 +12,9 @@ import { ChatStartingState } from "@/components/chat/ChatStartingState";
|
||||
import { ChatAttachmentChip } from "@/components/chat/ChatAttachmentChip";
|
||||
import { ChatQuickReplies } from "@/components/chat/ChatQuickReplies";
|
||||
import { ChatPaperclipButton } from "@/components/chat/ChatPaperclipButton";
|
||||
import type { ChatMessage } from "@/types/chat";
|
||||
import ChatSettings from "@/components/chat/ChatSettings";
|
||||
import type { ChatMessage, ShowStatsMode } from "@/types/chat";
|
||||
import { usePersistence } from "@/hooks/use-persistence";
|
||||
import {
|
||||
getEventIdsFromSearchObjectsToolCalls,
|
||||
getFindSimilarObjectsFromToolCalls,
|
||||
@ -27,6 +29,14 @@ export default function ChatPage() {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [attachedEventId, setAttachedEventId] = useState<string | null>(null);
|
||||
const [showStats, setShowStats] = usePersistence<ShowStatsMode>(
|
||||
"chat-show-stats",
|
||||
"while_generating",
|
||||
);
|
||||
const [autoScroll, setAutoScroll] = usePersistence<boolean>(
|
||||
"chat-auto-scroll",
|
||||
true,
|
||||
);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
|
||||
@ -36,13 +46,14 @@ export default function ChatPage() {
|
||||
|
||||
// Auto-scroll to bottom when messages change, but only if near bottom
|
||||
useEffect(() => {
|
||||
if (!autoScroll) return;
|
||||
const el = scrollRef.current;
|
||||
if (!el) return;
|
||||
const isNearBottom = el.scrollHeight - el.scrollTop - el.clientHeight < 150;
|
||||
if (isNearBottom) {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: "smooth" });
|
||||
}
|
||||
}, [messages]);
|
||||
}, [messages, autoScroll]);
|
||||
|
||||
const submitConversation = useCallback(
|
||||
async (messagesToSend: ChatMessage[]) => {
|
||||
@ -125,6 +136,16 @@ export default function ChatPage() {
|
||||
setIsLoading(false);
|
||||
}, []);
|
||||
|
||||
const startNewChat = useCallback(() => {
|
||||
abortRef.current?.abort();
|
||||
abortRef.current = null;
|
||||
setIsLoading(false);
|
||||
setMessages([]);
|
||||
setInput("");
|
||||
setAttachedEventId(null);
|
||||
setError(null);
|
||||
}, []);
|
||||
|
||||
const handleEditSubmit = useCallback(
|
||||
(messageIndex: number, newContent: string) => {
|
||||
const newList: ChatMessage[] = [
|
||||
@ -140,28 +161,44 @@ export default function ChatPage() {
|
||||
setAttachedEventId(null);
|
||||
}, []);
|
||||
|
||||
const hasStarted = messages.length > 0;
|
||||
|
||||
return (
|
||||
<div className="flex size-full justify-center p-2 md:p-4">
|
||||
<div className="flex size-full flex-col xl:w-[50%] 3xl:w-[35%]">
|
||||
{messages.length === 0 ? (
|
||||
<ChatStartingState
|
||||
onSendMessage={(message) => {
|
||||
setInput("");
|
||||
submitConversation([{ role: "user", content: message }]);
|
||||
}}
|
||||
<div className="flex size-full flex-col">
|
||||
<div className="flex shrink-0 items-center justify-end gap-2 px-2 pb-3 pt-2 md:px-4 md:pt-4">
|
||||
{hasStarted && (
|
||||
<Button
|
||||
className="flex items-center md:gap-2"
|
||||
aria-label={t("new_chat")}
|
||||
size="sm"
|
||||
onClick={startNewChat}
|
||||
>
|
||||
<LuMessageSquarePlus className="text-secondary-foreground" />
|
||||
<span className="hidden md:inline">{t("new_chat")}</span>
|
||||
</Button>
|
||||
)}
|
||||
<ChatSettings
|
||||
showStats={showStats ?? "while_generating"}
|
||||
setShowStats={setShowStats}
|
||||
autoScroll={autoScroll ?? true}
|
||||
setAutoScroll={setAutoScroll}
|
||||
/>
|
||||
) : (
|
||||
<>
|
||||
</div>
|
||||
<div
|
||||
ref={scrollRef}
|
||||
className="scrollbar-container flex min-h-0 w-full flex-1 flex-col gap-3 overflow-y-auto"
|
||||
className="scrollbar-container flex min-h-0 flex-1 flex-col overflow-y-auto"
|
||||
>
|
||||
<div className="flex flex-1 justify-center px-2 md:px-4">
|
||||
<div className="flex w-full flex-col xl:w-[50%] 3xl:w-[35%]">
|
||||
{hasStarted ? (
|
||||
<div className="flex w-full flex-1 flex-col gap-3 pb-3">
|
||||
{messages.map((msg, i) => {
|
||||
const isLastAssistant =
|
||||
i === messages.length - 1 && msg.role === "assistant";
|
||||
const isComplete =
|
||||
msg.role === "user" || !isLoading || !isLastAssistant;
|
||||
const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0;
|
||||
const hasToolCalls =
|
||||
msg.toolCalls && msg.toolCalls.length > 0;
|
||||
const hasContent = !!msg.content?.trim();
|
||||
const showProcessing =
|
||||
isLastAssistant && isLoading && !hasContent;
|
||||
@ -204,6 +241,8 @@ export default function ChatPage() {
|
||||
msg.role === "user" ? handleEditSubmit : undefined
|
||||
}
|
||||
isComplete={isComplete}
|
||||
stats={msg.stats}
|
||||
showStats={showStats}
|
||||
/>
|
||||
)}
|
||||
{msg.role === "assistant" &&
|
||||
@ -244,9 +283,20 @@ export default function ChatPage() {
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<ChatStartingState
|
||||
onSendMessage={(message) => {
|
||||
setInput("");
|
||||
submitConversation([{ role: "user", content: message }]);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{messages.length > 0 && (
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{hasStarted && (
|
||||
<div className="flex shrink-0 justify-center p-2 md:px-4 md:pb-4">
|
||||
<div className="flex w-full xl:w-[50%] 3xl:w-[35%]">
|
||||
<ChatEntry
|
||||
input={input}
|
||||
setInput={setInput}
|
||||
@ -259,9 +309,10 @@ export default function ChatPage() {
|
||||
onStop={stopGeneration}
|
||||
recentEventIds={recentEventIds}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -298,7 +349,7 @@ function ChatEntry({
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mt-2 flex w-full flex-col items-stretch justify-center gap-2 rounded-xl bg-secondary p-3">
|
||||
<div className="flex w-full flex-col items-stretch justify-center gap-2 rounded-xl bg-secondary p-3">
|
||||
{attachedEventId && (
|
||||
<div className="flex items-center">
|
||||
<ChatAttachmentChip
|
||||
|
||||
@ -8,9 +8,19 @@ export type ChatMessage = {
|
||||
role: "user" | "assistant";
|
||||
content: string;
|
||||
toolCalls?: ToolCall[];
|
||||
stats?: ChatStats;
|
||||
};
|
||||
|
||||
export type StartingRequest = {
|
||||
label: string;
|
||||
prompt: string;
|
||||
};
|
||||
|
||||
export type ChatStats = {
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
completionDurationMs?: number;
|
||||
tokensPerSecond?: number;
|
||||
};
|
||||
|
||||
export type ShowStatsMode = "while_generating" | "always";
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { ChatMessage, ToolCall } from "@/types/chat";
|
||||
import type { ChatMessage, ChatStats, ToolCall } from "@/types/chat";
|
||||
|
||||
export type StreamChatCallbacks = {
|
||||
/** Update the messages array (e.g. pass to setState). */
|
||||
@ -7,14 +7,27 @@ export type StreamChatCallbacks = {
|
||||
onError: (message: string) => void;
|
||||
/** Called when the stream finishes (success or error). */
|
||||
onDone: () => void;
|
||||
/** Called when the stream emits token/timing stats. The stats are also
|
||||
* attached to the last assistant message in updateMessages, so consumers
|
||||
* can usually rely on the message itself rather than wiring this up. */
|
||||
onStats?: (stats: ChatStats) => void;
|
||||
/** Message used when fetch throws and no server error is available. */
|
||||
defaultErrorMessage?: string;
|
||||
};
|
||||
|
||||
type StatsChunk = {
|
||||
type: "stats";
|
||||
prompt_tokens?: number;
|
||||
completion_tokens?: number;
|
||||
completion_duration_ms?: number;
|
||||
tokens_per_second?: number;
|
||||
};
|
||||
|
||||
type StreamChunk =
|
||||
| { type: "error"; error: string }
|
||||
| { type: "tool_calls"; tool_calls: ToolCall[] }
|
||||
| { type: "content"; delta: string };
|
||||
| { type: "content"; delta: string }
|
||||
| StatsChunk;
|
||||
|
||||
/**
|
||||
* POST to chat/completion with stream: true, parse NDJSON stream, and invoke
|
||||
@ -31,6 +44,7 @@ export async function streamChatCompletion(
|
||||
updateMessages,
|
||||
onError,
|
||||
onDone,
|
||||
onStats,
|
||||
defaultErrorMessage = "Something went wrong. Please try again.",
|
||||
} = callbacks;
|
||||
|
||||
@ -95,6 +109,23 @@ export async function streamChatCompletion(
|
||||
});
|
||||
return "continue";
|
||||
}
|
||||
if (data.type === "stats") {
|
||||
const stats: ChatStats = {
|
||||
promptTokens: data.prompt_tokens,
|
||||
completionTokens: data.completion_tokens,
|
||||
completionDurationMs: data.completion_duration_ms,
|
||||
tokensPerSecond: data.tokens_per_second,
|
||||
};
|
||||
updateMessages((prev) => {
|
||||
const next = [...prev];
|
||||
const lastMsg = next[next.length - 1];
|
||||
if (lastMsg?.role === "assistant")
|
||||
next[next.length - 1] = { ...lastMsg, stats };
|
||||
return next;
|
||||
});
|
||||
onStats?.(stats);
|
||||
return "continue";
|
||||
}
|
||||
return "continue";
|
||||
};
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user