Chat improvements (#22823)

* Add score fusion helpers for find_similar_objects chat tool

* Add candidate query builder for find_similar_objects chat tool

* register find_similar_objects chat tool definition

* implement _execute_find_similar_objects chat tool dispatcher

* Dispatch find_similar_objects in chat tool executor

* Teach chat system prompt when to use find_similar_objects

* Add i18n strings for find_similar_objects chat tool

* Add frontend extractor for find_similar_objects tool response

* Render anchor badge and similarity scores in chat results

* formatting

* filter similarity results in python, not sqlite-vec

* extract pure chat helpers to chat_util module

* Teach chat system prompt about attached_event marker

* Add parseAttachedEvent and prependAttachment helpers

* Add i18n strings for chat event attachments

* Add ChatAttachmentChip component

* Make chat thumbnails attach to composer on click

* Render attachment chip in user chat bubbles

* Add ChatQuickReplies pill row component

* Add ChatPaperclipButton with event picker popover

* Wire event attachments into chat composer and messages

* add ability to stop streaming

* tweak cursor to appear at the end of the same line of the streaming response

* use abort signal

* add tooltip

* display label and camera on attachment chip
This commit is contained in:
Josh Hawkins 2026-04-09 15:31:37 -05:00 committed by GitHub
parent 556d5d8c9d
commit 98c2fe00c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1318 additions and 109 deletions

View File

@ -3,9 +3,11 @@
import base64 import base64
import json import json
import logging import logging
import operator
import time import time
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Generator, List, Optional from functools import reduce
from typing import Any, Dict, List, Optional
import cv2 import cv2
from fastapi import APIRouter, Body, Depends, Request from fastapi import APIRouter, Body, Depends, Request
@ -17,6 +19,14 @@ from frigate.api.auth import (
get_allowed_cameras_for_filter, get_allowed_cameras_for_filter,
require_camera_access, require_camera_access,
) )
from frigate.api.chat_util import (
chunk_content,
distance_to_score,
format_events_with_local_time,
fuse_scores,
hydrate_event,
parse_iso_to_timestamp,
)
from frigate.api.defs.query.events_query_parameters import EventsQueryParams from frigate.api.defs.query.events_query_parameters import EventsQueryParams
from frigate.api.defs.request.chat_body import ChatCompletionRequest from frigate.api.defs.request.chat_body import ChatCompletionRequest
from frigate.api.defs.response.chat_response import ( from frigate.api.defs.response.chat_response import (
@ -32,55 +42,13 @@ from frigate.jobs.vlm_watch import (
start_vlm_watch_job, start_vlm_watch_job,
stop_vlm_watch_job, stop_vlm_watch_job,
) )
from frigate.models import Event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(tags=[Tags.chat]) router = APIRouter(tags=[Tags.chat])
def _chunk_content(content: str, chunk_size: int = 80) -> Generator[str, None, None]:
"""Yield content in word-aware chunks for streaming."""
if not content:
return
words = content.split(" ")
current: List[str] = []
current_len = 0
for w in words:
current.append(w)
current_len += len(w) + 1
if current_len >= chunk_size:
yield " ".join(current) + " "
current = []
current_len = 0
if current:
yield " ".join(current)
def _format_events_with_local_time(
events_list: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Add human-readable local start/end times to each event for the LLM."""
result = []
for evt in events_list:
if not isinstance(evt, dict):
result.append(evt)
continue
copy_evt = dict(evt)
try:
start_ts = evt.get("start_time")
end_ts = evt.get("end_time")
if start_ts is not None:
dt_start = datetime.fromtimestamp(start_ts)
copy_evt["start_time_local"] = dt_start.strftime("%Y-%m-%d %I:%M:%S %p")
if end_ts is not None:
dt_end = datetime.fromtimestamp(end_ts)
copy_evt["end_time_local"] = dt_end.strftime("%Y-%m-%d %I:%M:%S %p")
except (TypeError, ValueError, OSError):
pass
result.append(copy_evt)
return result
class ToolExecuteRequest(BaseModel): class ToolExecuteRequest(BaseModel):
"""Request model for tool execution.""" """Request model for tool execution."""
@ -158,6 +126,76 @@ def get_tool_definitions() -> List[Dict[str, Any]]:
"required": [], "required": [],
}, },
}, },
{
"type": "function",
"function": {
"name": "find_similar_objects",
"description": (
"Find tracked objects that are visually and semantically similar "
"to a specific past event. Use this when the user references a "
"particular object they have seen and wants to find other "
"sightings of the same or similar one ('that green car', 'the "
"person in the red jacket', 'the package that was delivered'). "
"Prefer this over search_objects whenever the user's intent is "
"'find more like this specific one.' Use search_objects first "
"only if you need to locate the anchor event. Requires semantic "
"search to be enabled."
),
"parameters": {
"type": "object",
"properties": {
"event_id": {
"type": "string",
"description": "The id of the anchor event to find similar objects to.",
},
"after": {
"type": "string",
"description": "Start time in ISO 8601 format (e.g., '2024-01-01T00:00:00Z').",
},
"before": {
"type": "string",
"description": "End time in ISO 8601 format (e.g., '2024-01-01T23:59:59Z').",
},
"cameras": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of cameras to restrict to. Defaults to all.",
},
"labels": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of labels to restrict to. Defaults to the anchor event's label.",
},
"sub_labels": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of sub_labels (names) to restrict to.",
},
"zones": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of zones. An event matches if any of its zones overlap.",
},
"similarity_mode": {
"type": "string",
"enum": ["visual", "semantic", "fused"],
"description": "Which similarity signal(s) to use. 'fused' (default) combines visual and semantic.",
"default": "fused",
},
"min_score": {
"type": "number",
"description": "Drop matches with a similarity score below this threshold (0.0-1.0).",
},
"limit": {
"type": "integer",
"description": "Maximum number of matches to return (default: 10).",
"default": 10,
},
},
"required": ["event_id"],
},
},
},
{ {
"type": "function", "type": "function",
"function": { "function": {
@ -434,6 +472,166 @@ async def _execute_search_objects(
) )
async def _execute_find_similar_objects(
    request: Request,
    arguments: Dict[str, Any],
    allowed_cameras: List[str],
) -> Dict[str, Any]:
    """Execute the find_similar_objects tool.

    Args:
        request: FastAPI request; ``request.app`` supplies ``frigate_config``
            and the ``embeddings`` context.
        arguments: Raw tool-call arguments from the LLM. ``event_id`` is the
            only required key; the rest are optional filters.
        allowed_cameras: Cameras the calling user is permitted to see (RBAC).

    Returns a plain dict (not JSONResponse) so the chat loop can embed it
    directly in tool-result messages. Error dicts carry "error" and
    "message" keys; success dicts carry "anchor", "results",
    "similarity_mode", and "candidate_truncated".
    """
    # 1. Semantic search enabled?
    config = request.app.frigate_config
    if not getattr(config.semantic_search, "enabled", False):
        return {
            "error": "semantic_search_disabled",
            "message": (
                "Semantic search must be enabled to find similar objects. "
                "Enable it in the Frigate config under semantic_search."
            ),
        }

    context = request.app.embeddings
    if context is None:
        return {
            "error": "semantic_search_disabled",
            "message": "Embeddings context is not available.",
        }

    # 2. Anchor lookup.
    event_id = arguments.get("event_id")
    if not event_id:
        return {"error": "missing_event_id", "message": "event_id is required."}

    try:
        anchor = Event.get(Event.id == event_id)
    except Event.DoesNotExist:
        return {
            "error": "anchor_not_found",
            "message": f"Could not find event {event_id}.",
        }

    # 3. Parse params.
    after = parse_iso_to_timestamp(arguments.get("after"))
    before = parse_iso_to_timestamp(arguments.get("before"))

    cameras = arguments.get("cameras")
    if cameras:
        # Respect RBAC: intersect with the user's allowed cameras.
        cameras = [c for c in cameras if c in allowed_cameras]
    else:
        cameras = list(allowed_cameras) if allowed_cameras else None

    # Defaulting labels to the anchor's label keeps "more like this one"
    # results in the same object class unless the LLM widens it explicitly.
    labels = arguments.get("labels") or [anchor.label]
    sub_labels = arguments.get("sub_labels")
    zones = arguments.get("zones")

    similarity_mode = arguments.get("similarity_mode", "fused")
    if similarity_mode not in ("visual", "semantic", "fused"):
        similarity_mode = "fused"

    min_score = arguments.get("min_score")
    # NOTE(review): assumes the LLM supplies a numeric limit — int() would
    # raise on a non-numeric string; confirm upstream argument validation.
    limit = int(arguments.get("limit", 10))
    # Clamp to a sane window so the LLM cannot request unbounded results.
    limit = max(1, min(limit, 50))

    # 4. Run similarity searches. We deliberately do NOT pass event_ids into
    # the vec queries — the IN filter on sqlite-vec is broken in the installed
    # version (see frigate/embeddings/__init__.py). Mirror the pattern used by
    # frigate/api/event.py events_search: fetch top-k globally, then intersect
    # with the structured filters via Peewee.
    visual_distances: Dict[str, float] = {}
    description_distances: Dict[str, float] = {}
    try:
        if similarity_mode in ("visual", "fused"):
            rows = context.search_thumbnail(anchor)
            visual_distances = {row[0]: row[1] for row in rows}
        if similarity_mode in ("semantic", "fused"):
            # Fall back from description -> sub_label -> label so the text
            # query is never empty for events without a description.
            query_text = (
                (anchor.data or {}).get("description")
                or anchor.sub_label
                or anchor.label
            )
            rows = context.search_description(query_text)
            description_distances = {row[0]: row[1] for row in rows}
    except Exception:
        logger.exception("Similarity search failed")
        return {
            "error": "similarity_search_failed",
            "message": "Failed to run similarity search.",
        }

    vec_ids = set(visual_distances) | set(description_distances)
    # The anchor itself is always its own best match; exclude it.
    vec_ids.discard(anchor.id)

    # vec layer returns up to k=100 per modality; flag when we hit that ceiling
    # so the LLM can mention there may be more matches beyond what we saw.
    candidate_truncated = (
        len(visual_distances) >= 100 or len(description_distances) >= 100
    )

    if not vec_ids:
        return {
            "anchor": hydrate_event(anchor),
            "results": [],
            "similarity_mode": similarity_mode,
            "candidate_truncated": candidate_truncated,
        }

    # 5. Apply structured filters, intersected with vec hits.
    clauses = [Event.id.in_(list(vec_ids))]
    if after is not None:
        clauses.append(Event.start_time >= after)
    if before is not None:
        clauses.append(Event.start_time <= before)
    if cameras:
        clauses.append(Event.camera.in_(cameras))
    if labels:
        clauses.append(Event.label.in_(labels))
    if sub_labels:
        clauses.append(Event.sub_label.in_(sub_labels))
    if zones:
        # Mirror the pattern used by frigate/api/event.py for JSON-array zone match.
        zone_clauses = [Event.zones.cast("text") % f'*"{zone}"*' for zone in zones]
        clauses.append(reduce(operator.or_, zone_clauses))

    eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}

    # 6. Fuse and rank.
    scored: List[tuple[str, float]] = []
    for eid in eligible:
        # An event may appear in only one modality; fuse_scores handles the
        # missing side without penalty.
        v_score = (
            distance_to_score(visual_distances[eid], context.thumb_stats)
            if eid in visual_distances
            else None
        )
        d_score = (
            distance_to_score(description_distances[eid], context.desc_stats)
            if eid in description_distances
            else None
        )
        fused = fuse_scores(v_score, d_score)
        if fused is None:
            continue
        if min_score is not None and fused < min_score:
            continue
        scored.append((eid, fused))

    # Highest similarity first, then truncate to the requested page size.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    scored = scored[:limit]

    results = [hydrate_event(eligible[eid], score=score) for eid, score in scored]
    return {
        "anchor": hydrate_event(anchor),
        "results": results,
        "similarity_mode": similarity_mode,
        "candidate_truncated": candidate_truncated,
    }
@router.post( @router.post(
"/chat/execute", "/chat/execute",
dependencies=[Depends(allow_any_authenticated())], dependencies=[Depends(allow_any_authenticated())],
@ -459,6 +657,13 @@ async def execute_tool(
if tool_name == "search_objects": if tool_name == "search_objects":
return await _execute_search_objects(arguments, allowed_cameras) return await _execute_search_objects(arguments, allowed_cameras)
if tool_name == "find_similar_objects":
result = await _execute_find_similar_objects(
request, arguments, allowed_cameras
)
status_code = 200 if "error" not in result else 400
return JSONResponse(content=result, status_code=status_code)
if tool_name == "set_camera_state": if tool_name == "set_camera_state":
result = await _execute_set_camera_state(request, arguments) result = await _execute_set_camera_state(request, arguments)
return JSONResponse( return JSONResponse(
@ -642,6 +847,8 @@ async def _execute_tool_internal(
except (json.JSONDecodeError, AttributeError) as e: except (json.JSONDecodeError, AttributeError) as e:
logger.warning(f"Failed to extract tool result: {e}") logger.warning(f"Failed to extract tool result: {e}")
return {"error": "Failed to parse tool result"} return {"error": "Failed to parse tool result"}
elif tool_name == "find_similar_objects":
return await _execute_find_similar_objects(request, arguments, allowed_cameras)
elif tool_name == "set_camera_state": elif tool_name == "set_camera_state":
return await _execute_set_camera_state(request, arguments) return await _execute_set_camera_state(request, arguments)
elif tool_name == "get_live_context": elif tool_name == "get_live_context":
@ -664,8 +871,9 @@ async def _execute_tool_internal(
return _execute_get_recap(arguments, allowed_cameras) return _execute_get_recap(arguments, allowed_cameras)
else: else:
logger.error( logger.error(
"Tool call failed: unknown tool %r. Expected one of: search_objects, get_live_context, " "Tool call failed: unknown tool %r. Expected one of: search_objects, find_similar_objects, "
"start_camera_watch, stop_camera_watch, get_profile_status, get_recap. Arguments received: %s", "get_live_context, start_camera_watch, stop_camera_watch, get_profile_status, get_recap. "
"Arguments received: %s",
tool_name, tool_name,
json.dumps(arguments), json.dumps(arguments),
) )
@ -927,7 +1135,7 @@ async def _execute_pending_tools(
json.dumps(tool_args), json.dumps(tool_args),
) )
if tool_name == "search_objects" and isinstance(tool_result, list): if tool_name == "search_objects" and isinstance(tool_result, list):
tool_result = _format_events_with_local_time(tool_result) tool_result = format_events_with_local_time(tool_result)
_keys = { _keys = {
"id", "id",
"camera", "camera",
@ -1080,7 +1288,9 @@ Do not start your response with phrases like "I will check...", "Let me see...",
Always present times to the user in the server's local timezone. When tool results include start_time_local and end_time_local, use those exact strings when listing or describing detection times—do not convert or invent timestamps. Do not use UTC or ISO format with Z for the user-facing answer unless the tool result only provides Unix timestamps without local time fields. Always present times to the user in the server's local timezone. When tool results include start_time_local and end_time_local, use those exact strings when listing or describing detection times—do not convert or invent timestamps. Do not use UTC or ISO format with Z for the user-facing answer unless the tool result only provides Unix timestamps without local time fields.
When users ask about "today", "yesterday", "this week", etc., use the current date above as reference. When users ask about "today", "yesterday", "this week", etc., use the current date above as reference.
When searching for objects or events, use ISO 8601 format for dates (e.g., {current_date_str}T00:00:00Z for the start of today). When searching for objects or events, use ISO 8601 format for dates (e.g., {current_date_str}T00:00:00Z for the start of today).
Always be accurate with time calculations based on the current date provided.{cameras_section}""" Always be accurate with time calculations based on the current date provided.
When a user refers to a specific object they have seen or describe with identifying details ("that green car", "the person in the red jacket", "a package left today"), prefer the find_similar_objects tool over search_objects. Use search_objects first only to locate the anchor event, then pass its id to find_similar_objects. For generic queries like "show me all cars today", keep using search_objects. If a user message begins with [attached_event:<id>], treat that event id as the anchor for any similarity or "tell me more" request in the same message and call find_similar_objects with that id.{cameras_section}"""
conversation.append( conversation.append(
{ {
@ -1118,6 +1328,9 @@ Always be accurate with time calculations based on the current date provided.{ca
async def stream_body_llm(): async def stream_body_llm():
nonlocal conversation, stream_tool_calls, stream_iterations nonlocal conversation, stream_tool_calls, stream_iterations
while stream_iterations < max_iterations: while stream_iterations < max_iterations:
if await request.is_disconnected():
logger.debug("Client disconnected, stopping chat stream")
return
logger.debug( logger.debug(
f"Streaming LLM (iteration {stream_iterations + 1}/{max_iterations}) " f"Streaming LLM (iteration {stream_iterations + 1}/{max_iterations}) "
f"with {len(conversation)} message(s)" f"with {len(conversation)} message(s)"
@ -1127,6 +1340,9 @@ Always be accurate with time calculations based on the current date provided.{ca
tools=tools if tools else None, tools=tools if tools else None,
tool_choice="auto", tool_choice="auto",
): ):
if await request.is_disconnected():
logger.debug("Client disconnected, stopping chat stream")
return
kind, value = event kind, value = event
if kind == "content_delta": if kind == "content_delta":
yield ( yield (
@ -1156,6 +1372,11 @@ Always be accurate with time calculations based on the current date provided.{ca
msg.get("content"), pending msg.get("content"), pending
) )
) )
if await request.is_disconnected():
logger.debug(
"Client disconnected before tool execution"
)
return
( (
executed_calls, executed_calls,
tool_results, tool_results,
@ -1240,7 +1461,7 @@ Always be accurate with time calculations based on the current date provided.{ca
+ b"\n" + b"\n"
) )
# Stream content in word-sized chunks for smooth UX # Stream content in word-sized chunks for smooth UX
for part in _chunk_content(final_content): for part in chunk_content(final_content):
yield ( yield (
json.dumps({"type": "content", "delta": part}).encode( json.dumps({"type": "content", "delta": part}).encode(
"utf-8" "utf-8"

135
frigate/api/chat_util.py Normal file
View File

@ -0,0 +1,135 @@
"""Pure, stateless helpers used by the chat tool dispatchers.
These were extracted from frigate/api/chat.py to keep that module focused on
route handlers, tool dispatchers, and streaming loop internals. Nothing in
this file touches the FastAPI request, the embeddings context, or the chat
loop state all inputs and outputs are plain data.
"""
import logging
import math
import time
from datetime import datetime
from typing import Any, Dict, Generator, List, Optional
from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
logger = logging.getLogger(__name__)
# Similarity fusion weights for find_similar_objects.
# Visual dominates because the feature's primary use case is "same specific object."
# If these change, update the test in test_chat_find_similar_objects.py.
VISUAL_WEIGHT = 0.65
DESCRIPTION_WEIGHT = 0.35
def chunk_content(content: str, chunk_size: int = 80) -> Generator[str, None, None]:
    """Stream *content* as word-aligned chunks of roughly *chunk_size* chars.

    Words are never split across chunks; every full chunk carries a trailing
    space so consecutive chunks concatenate cleanly (the joined output may
    end with one extra trailing space). Empty input yields nothing.
    """
    if not content:
        return
    buffer: List[str] = []
    buffered = 0
    for word in content.split(" "):
        buffer.append(word)
        buffered += len(word) + 1
        if buffered < chunk_size:
            continue
        yield " ".join(buffer) + " "
        buffer = []
        buffered = 0
    if buffer:
        yield " ".join(buffer)
def format_events_with_local_time(
    events_list: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return copies of *events_list* with human-readable local times added.

    For each dict entry, a present start_time/end_time unix timestamp gains a
    start_time_local/end_time_local companion formatted like
    "2024-01-01 01:02:03 PM" in the server's local timezone. Non-dict entries
    and entries with unconvertible timestamps pass through unchanged; the
    input dicts are never mutated.
    """
    formatted: List[Dict[str, Any]] = []
    for item in events_list:
        if not isinstance(item, dict):
            formatted.append(item)
            continue
        enriched = dict(item)
        try:
            for ts_key, local_key in (
                ("start_time", "start_time_local"),
                ("end_time", "end_time_local"),
            ):
                ts = item.get(ts_key)
                if ts is not None:
                    enriched[local_key] = datetime.fromtimestamp(ts).strftime(
                        "%Y-%m-%d %I:%M:%S %p"
                    )
        except (TypeError, ValueError, OSError):
            # Best-effort: keep whatever local fields were set before the
            # bad timestamp and still return the event.
            pass
        formatted.append(enriched)
    return formatted
def distance_to_score(distance: float, stats: ZScoreNormalization) -> float:
    """Map a cosine *distance* onto a [0, 1] similarity score.

    The distance is z-normalized using the deployment-wide stats maintained
    by EmbeddingsContext, then squashed through a sigmoid so smaller
    distances land closer to 1.0. With uninitialized stats (stddev == 0) a
    neutral 0.5 is returned so the fallback ordering by raw distance still
    dominates.
    """
    if stats.stddev == 0:
        return 0.5
    z_value = (distance - stats.mean) / stats.stddev
    # exp(z) grows with distance, so the score shrinks as distance grows.
    return 1.0 / (1.0 + math.exp(z_value))
def fuse_scores(
    visual_score: Optional[float],
    description_score: Optional[float],
) -> Optional[float]:
    """Blend visual and description similarity into one weighted score.

    A missing side contributes nothing and imposes no penalty: the other
    side's score is returned as-is. When both sides are missing the result
    is None, signalling the caller to drop the event.
    """
    if visual_score is None:
        # Covers the both-missing case too: description_score is then None.
        return description_score
    if description_score is None:
        return visual_score
    return VISUAL_WEIGHT * visual_score + DESCRIPTION_WEIGHT * description_score
def parse_iso_to_timestamp(value: Optional[str]) -> Optional[float]:
    """Interpret an ISO-8601 string as server-local time -> unix timestamp.

    Mirrors the parsing _execute_search_objects uses so both tools accept the
    same format from the LLM. A trailing 'Z' is stripped rather than treated
    as UTC; anything unparseable logs a warning and yields None.
    """
    if value is None:
        return None
    try:
        # Keep only the first 19 chars ("YYYY-MM-DDTHH:MM:SS"), dropping any
        # fractional seconds or offset the model tacked on.
        trimmed = value.replace("Z", "").strip()[:19]
        return time.mktime(
            datetime.strptime(trimmed, "%Y-%m-%dT%H:%M:%S").timetuple()
        )
    except (ValueError, AttributeError, TypeError):
        logger.warning("Invalid timestamp format: %s", value)
        return None
def hydrate_event(event: Event, score: Optional[float] = None) -> Dict[str, Any]:
    """Project an Event row onto the dict shape returned by find_similar_objects.

    Only identity and context fields are exposed (no thumbnails or raw data);
    a "score" key is added only when a similarity score is supplied.
    """
    hydrated: Dict[str, Any] = {
        field: getattr(event, field)
        for field in (
            "id",
            "camera",
            "label",
            "sub_label",
            "start_time",
            "end_time",
            "zones",
        )
    }
    if score is not None:
        hydrated["score"] = score
    return hydrated

View File

@ -0,0 +1,303 @@
"""Tests for the find_similar_objects chat tool."""
import asyncio
import os
import tempfile
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock
from playhouse.sqlite_ext import SqliteExtDatabase
from frigate.api.chat import (
_execute_find_similar_objects,
get_tool_definitions,
)
from frigate.api.chat_util import (
DESCRIPTION_WEIGHT,
VISUAL_WEIGHT,
distance_to_score,
fuse_scores,
)
from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
def _run(coro):
return asyncio.new_event_loop().run_until_complete(coro)
class TestDistanceToScore(unittest.TestCase):
    """distance_to_score maps distances into [0, 1] with sane fallbacks."""

    def test_lower_distance_gives_higher_score(self):
        stats = ZScoreNormalization()
        # Seed with a spread of distances so stddev is nonzero.
        stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
        near = distance_to_score(0.1, stats)
        far = distance_to_score(0.5, stats)
        self.assertGreater(near, far)
        for score in (near, far):
            self.assertGreaterEqual(score, 0.0)
            self.assertLessEqual(score, 1.0)

    def test_uninitialized_stats_returns_neutral_score(self):
        # A fresh normalizer has n == 0 and stddev == 0; must not divide.
        self.assertEqual(distance_to_score(0.3, ZScoreNormalization()), 0.5)
class TestFuseScores(unittest.TestCase):
    """fuse_scores blends visual/description scores with fixed weights."""

    def test_weights_sum_to_one(self):
        self.assertAlmostEqual(VISUAL_WEIGHT + DESCRIPTION_WEIGHT, 1.0)

    def test_fuses_both_sides(self):
        blended = fuse_scores(visual_score=0.8, description_score=0.4)
        self.assertAlmostEqual(
            blended, VISUAL_WEIGHT * 0.8 + DESCRIPTION_WEIGHT * 0.4
        )

    def test_missing_description_uses_visual_only(self):
        self.assertAlmostEqual(
            fuse_scores(visual_score=0.7, description_score=None), 0.7
        )

    def test_missing_visual_uses_description_only(self):
        self.assertAlmostEqual(
            fuse_scores(visual_score=None, description_score=0.6), 0.6
        )

    def test_both_missing_returns_none(self):
        self.assertIsNone(fuse_scores(visual_score=None, description_score=None))
class TestToolDefinition(unittest.TestCase):
    """find_similar_objects is registered with the expected JSON schema."""

    def test_find_similar_objects_is_registered(self):
        registered = {t["function"]["name"] for t in get_tool_definitions()}
        self.assertIn("find_similar_objects", registered)

    def test_find_similar_objects_schema(self):
        tool = next(
            t
            for t in get_tool_definitions()
            if t["function"]["name"] == "find_similar_objects"
        )
        schema = tool["function"]["parameters"]
        for param in (
            "event_id",
            "after",
            "before",
            "cameras",
            "labels",
            "sub_labels",
            "zones",
            "similarity_mode",
            "min_score",
            "limit",
        ):
            self.assertIn(param, schema["properties"])
        self.assertEqual(schema["required"], ["event_id"])
        self.assertEqual(
            schema["properties"]["similarity_mode"]["enum"],
            ["visual", "semantic", "fused"],
        )
class TestExecuteFindSimilarObjects(unittest.TestCase):
    """Integration-style tests: real sqlite-backed Events, mocked embeddings."""

    def setUp(self):
        # Isolated throwaway sqlite DB; Event is re-bound to it so the
        # dispatcher's Peewee queries hit this DB instead of the real one.
        self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
        self.tmp.close()
        self.db = SqliteExtDatabase(self.tmp.name)
        Event.bind(self.db, bind_refs=False, bind_backrefs=False)
        self.db.connect()
        self.db.create_tables([Event])

        # Insert an anchor plus two candidates.
        def make(event_id, label="car", camera="driveway", start=1_700_000_100):
            # Minimal-but-complete Event row; only id/label/camera/start vary.
            Event.create(
                id=event_id,
                label=label,
                sub_label=None,
                camera=camera,
                start_time=start,
                end_time=start + 10,
                top_score=0.9,
                score=0.9,
                false_positive=False,
                zones=[],
                thumbnail="",
                has_clip=True,
                has_snapshot=True,
                region=[0, 0, 1, 1],
                box=[0, 0, 1, 1],
                area=1,
                retain_indefinitely=False,
                ratio=1.0,
                plus_id="",
                model_hash="",
                detector_type="",
                model_type="",
                data={"description": "a green sedan"},
            )

        make("anchor", start=1_700_000_200)
        make("cand_a", start=1_700_000_100)
        make("cand_b", start=1_700_000_150)
        # Exposed so individual tests can add extra rows.
        self.make = make

    def tearDown(self):
        # Close and delete the temp DB file created in setUp.
        self.db.close()
        os.unlink(self.tmp.name)

    def _make_request(self, semantic_enabled=True, embeddings=None):
        # Mimics only the attributes the dispatcher reads off the FastAPI
        # request: app.embeddings and app.frigate_config.semantic_search.
        app = SimpleNamespace(
            embeddings=embeddings,
            frigate_config=SimpleNamespace(
                semantic_search=SimpleNamespace(enabled=semantic_enabled),
            ),
        )
        return SimpleNamespace(app=app)

    def test_semantic_search_disabled_returns_error(self):
        req = self._make_request(semantic_enabled=False)
        result = _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor"},
                allowed_cameras=["driveway"],
            )
        )
        self.assertEqual(result["error"], "semantic_search_disabled")

    def test_anchor_not_found_returns_error(self):
        embeddings = MagicMock()
        req = self._make_request(embeddings=embeddings)
        result = _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "nope"},
                allowed_cameras=["driveway"],
            )
        )
        self.assertEqual(result["error"], "anchor_not_found")

    def test_empty_candidates_returns_empty_results(self):
        embeddings = MagicMock()
        req = self._make_request(embeddings=embeddings)
        # Filter to a camera with no other events.
        result = _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor", "cameras": ["nonexistent_cam"]},
                allowed_cameras=["nonexistent_cam"],
            )
        )
        self.assertEqual(result["results"], [])
        self.assertFalse(result["candidate_truncated"])
        self.assertEqual(result["anchor"]["id"], "anchor")

    def test_fused_calls_both_searches_and_ranks(self):
        embeddings = MagicMock()
        # cand_a visually closer, cand_b semantically closer.
        embeddings.search_thumbnail.return_value = [
            ("cand_a", 0.10),
            ("cand_b", 0.40),
        ]
        embeddings.search_description.return_value = [
            ("cand_a", 0.50),
            ("cand_b", 0.20),
        ]
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
        embeddings.desc_stats = ZScoreNormalization()
        embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
        req = self._make_request(embeddings=embeddings)
        result = _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor"},
                allowed_cameras=["driveway"],
            )
        )
        embeddings.search_thumbnail.assert_called_once()
        embeddings.search_description.assert_called_once()
        # cand_a should rank first because visual is weighted higher.
        self.assertEqual(result["results"][0]["id"], "cand_a")
        self.assertIn("score", result["results"][0])
        self.assertEqual(result["similarity_mode"], "fused")

    def test_visual_mode_only_calls_thumbnail(self):
        embeddings = MagicMock()
        embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3])
        req = self._make_request(embeddings=embeddings)
        _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor", "similarity_mode": "visual"},
                allowed_cameras=["driveway"],
            )
        )
        embeddings.search_thumbnail.assert_called_once()
        embeddings.search_description.assert_not_called()

    def test_semantic_mode_only_calls_description(self):
        embeddings = MagicMock()
        embeddings.search_description.return_value = [("cand_a", 0.1)]
        embeddings.desc_stats = ZScoreNormalization()
        embeddings.desc_stats._update([0.1, 0.2, 0.3])
        req = self._make_request(embeddings=embeddings)
        _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor", "similarity_mode": "semantic"},
                allowed_cameras=["driveway"],
            )
        )
        embeddings.search_description.assert_called_once()
        embeddings.search_thumbnail.assert_not_called()

    def test_min_score_drops_low_scoring_results(self):
        embeddings = MagicMock()
        embeddings.search_thumbnail.return_value = [
            ("cand_a", 0.10),
            ("cand_b", 0.90),
        ]
        embeddings.search_description.return_value = []
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
        embeddings.desc_stats = ZScoreNormalization()
        req = self._make_request(embeddings=embeddings)
        result = _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
                allowed_cameras=["driveway"],
            )
        )
        ids = [r["id"] for r in result["results"]]
        self.assertIn("cand_a", ids)
        self.assertNotIn("cand_b", ids)

    def test_labels_defaults_to_anchor_label(self):
        # person_a is never returned by the vec search mock, but this also
        # exercises the label filter defaulting to the anchor's "car" label.
        self.make("person_a", label="person")
        embeddings = MagicMock()
        embeddings.search_thumbnail.return_value = [
            ("cand_a", 0.1),
            ("cand_b", 0.2),
        ]
        embeddings.search_description.return_value = []
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3])
        embeddings.desc_stats = ZScoreNormalization()
        req = self._make_request(embeddings=embeddings)
        result = _run(
            _execute_find_similar_objects(
                req,
                {"event_id": "anchor", "similarity_mode": "visual"},
                allowed_cameras=["driveway"],
            )
        )
        ids = [r["id"] for r in result["results"]]
        self.assertNotIn("person_a", ids)


if __name__ == "__main__":
    unittest.main()

View File

@ -12,6 +12,23 @@
"result": "Result", "result": "Result",
"arguments": "Arguments:", "arguments": "Arguments:",
"response": "Response:", "response": "Response:",
"attachment_chip_label": "{{label}} on {{camera}}",
"attachment_chip_remove": "Remove attachment",
"open_in_explore": "Open in Explore",
"attach_event_aria": "Attach event {{eventId}}",
"attachment_picker_paste_label": "Or paste event ID",
"attachment_picker_attach": "Attach",
"attachment_picker_placeholder": "Attach an event",
"quick_reply_find_similar": "Find similar sightings",
"quick_reply_tell_me_more": "Tell me more about this",
"quick_reply_when_else": "When else was it seen?",
"quick_reply_find_similar_text": "Find similar sightings to this.",
"quick_reply_tell_me_more_text": "Tell me more about this one.",
"quick_reply_when_else_text": "When else was this seen?",
"anchor": "Reference",
"similarity_score": "Similarity",
"no_similar_objects_found": "No similar objects found.",
"semantic_search_required": "Semantic search must be enabled to find similar objects.",
"send": "Send", "send": "Send",
"suggested_requests": "Try asking:", "suggested_requests": "Try asking:",
"starting_requests": { "starting_requests": {

View File

@ -0,0 +1,111 @@
import { useApiHost } from "@/api";
import { useCameraFriendlyName } from "@/hooks/use-camera-friendly-name";
import { useTranslation } from "react-i18next";
import useSWR from "swr";
import { LuX, LuExternalLink } from "react-icons/lu";
import { Button } from "@/components/ui/button";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { cn } from "@/lib/utils";
import { getTranslatedLabel } from "@/utils/i18n";
// Props for ChatAttachmentChip.
type ChatAttachmentChipProps = {
  // ID of the event rendered as an attachment.
  eventId: string;
  // "composer": removable chip in the input area; "bubble": read-only chip
  // inside an already-sent user message (muted styling, no remove button).
  mode: "composer" | "bubble";
  // Invoked when the user clicks the remove X (composer mode only).
  onRemove?: () => void;
};

/**
 * Small horizontal chip rendering an event as an "attachment": a thumbnail,
 * a friendly label like "Person on driveway", an optional remove X (composer
 * mode), and an external-link icon that opens the event in Explore.
 */
export function ChatAttachmentChip({
  eventId,
  mode,
  onRemove,
}: ChatAttachmentChipProps) {
  const apiHost = useApiHost();
  const { t } = useTranslation(["views/chat"]);
  // Fetch label/camera metadata for the event; a spinner is shown until it
  // resolves.
  const { data: eventData } = useSWR<{ label: string; camera: string }[]>(
    `event_ids?ids=${eventId}`,
  );
  const evt = eventData?.[0];
  const cameraName = useCameraFriendlyName(evt?.camera);
  // e.g. "Person on driveway". The eventId fallback is only reachable while
  // `evt` is undefined, and the label is only rendered once `evt` exists, so
  // the fallback is effectively defensive.
  const displayLabel = evt
    ? t("attachment_chip_label", {
        label: getTranslatedLabel(evt.label),
        camera: cameraName,
      })
    : eventId;

  return (
    <div
      className={cn(
        "inline-flex max-w-full items-center gap-2 rounded-lg border border-border bg-background/80 p-1.5 pr-2",
        mode === "bubble" && "border-primary-foreground/30 bg-transparent",
      )}
    >
      {/* Event thumbnail */}
      <div className="relative size-10 shrink-0 overflow-hidden rounded-md">
        <img
          className="size-full object-cover"
          src={`${apiHost}api/events/${eventId}/thumbnail.webp`}
          alt=""
          loading="lazy"
          onError={(e) => {
            // Hide the broken-image glyph if the thumbnail fails to load.
            (e.currentTarget as HTMLImageElement).style.visibility = "hidden";
          }}
        />
      </div>
      {/* Friendly label once metadata has loaded, spinner before that */}
      {evt ? (
        <span
          className={cn(
            "truncate text-xs",
            mode === "bubble"
              ? "text-primary-foreground/90"
              : "text-foreground",
          )}
        >
          {displayLabel}
        </span>
      ) : (
        <ActivityIndicator className="size-4" />
      )}
      {/* External link to the event in Explore (new tab) */}
      <Tooltip>
        <TooltipTrigger asChild>
          <a
            href={`/explore?event_id=${eventId}`}
            target="_blank"
            rel="noopener noreferrer"
            className={cn(
              "flex size-6 shrink-0 items-center justify-center rounded text-muted-foreground hover:text-foreground",
              mode === "bubble" &&
                "text-primary-foreground/70 hover:text-primary-foreground",
            )}
            onClick={(e) => e.stopPropagation()}
            aria-label={t("open_in_explore")}
          >
            <LuExternalLink className="size-3.5" />
          </a>
        </TooltipTrigger>
        <TooltipContent>{t("open_in_explore")}</TooltipContent>
      </Tooltip>
      {/* Remove button only exists in composer mode */}
      {mode === "composer" && onRemove && (
        <Button
          variant="ghost"
          size="icon"
          className="size-6 shrink-0 text-muted-foreground hover:text-foreground"
          onClick={onRemove}
          aria-label={t("attachment_chip_remove")}
        >
          <LuX className="size-3.5" />
        </Button>
      )}
    </div>
  );
}

View File

@ -1,42 +1,97 @@
import { useApiHost } from "@/api"; import { useApiHost } from "@/api";
import { useTranslation } from "react-i18next";
import { LuExternalLink } from "react-icons/lu";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
type ChatEvent = { id: string; score?: number };
type ChatEventThumbnailsRowProps = { type ChatEventThumbnailsRowProps = {
events: { id: string }[]; events: ChatEvent[];
anchor?: { id: string } | null;
onAttach?: (eventId: string) => void;
}; };
/** /**
* Horizontal scroll row of event thumbnail images for chat (e.g. after search_objects). * Horizontal scroll row of event thumbnail images for chat.
* Renders nothing when events is empty. * Optionally renders an anchor thumbnail with a "reference" badge above the
* results, and per-event similarity scores when provided.
* Clicking a thumbnail calls onAttach; a small external-link overlay opens
* the event in Explore.
* Renders nothing when there is nothing to show.
*/ */
export function ChatEventThumbnailsRow({ export function ChatEventThumbnailsRow({
events, events,
anchor = null,
onAttach,
}: ChatEventThumbnailsRowProps) { }: ChatEventThumbnailsRowProps) {
const apiHost = useApiHost(); const apiHost = useApiHost();
const { t } = useTranslation(["views/chat"]);
if (events.length === 0) return null; if (events.length === 0 && !anchor) return null;
const renderThumb = (event: ChatEvent, isAnchor = false) => (
<div
key={event.id}
className={cn(
"relative aspect-square size-32 shrink-0 overflow-hidden rounded-lg",
isAnchor && "ring-2 ring-primary",
)}
>
<button
type="button"
className="block size-full"
onClick={() => onAttach?.(event.id)}
aria-label={t("attach_event_aria", { eventId: event.id })}
>
<img
className="size-full object-cover"
src={`${apiHost}api/events/${event.id}/thumbnail.webp`}
alt=""
loading="lazy"
/>
</button>
<Tooltip>
<TooltipTrigger asChild>
<a
href={`/explore?event_id=${event.id}`}
target="_blank"
rel="noopener noreferrer"
onClick={(e) => e.stopPropagation()}
className="absolute right-1 top-1 flex size-6 items-center justify-center rounded bg-black/60 text-white hover:bg-black/80"
aria-label={t("open_in_explore")}
>
<LuExternalLink className="size-3" />
</a>
</TooltipTrigger>
<TooltipContent>{t("open_in_explore")}</TooltipContent>
</Tooltip>
{isAnchor && (
<span className="pointer-events-none absolute left-1 top-1 rounded bg-primary px-1 text-[10px] text-primary-foreground">
{t("anchor")}
</span>
)}
</div>
);
return ( return (
<div className="flex min-w-0 max-w-full flex-col gap-1 self-start"> <div className="flex min-w-0 max-w-full flex-col gap-2 self-start">
<div className="scrollbar-container min-w-0 overflow-x-auto"> {anchor && (
<div className="flex w-max gap-2"> <div className="scrollbar-container min-w-0 overflow-x-auto">
{events.map((event) => ( <div className="flex w-max gap-2">{renderThumb(anchor, true)}</div>
<a
key={event.id}
href={`/explore?event_id=${event.id}`}
target="_blank"
rel="noopener noreferrer"
className="relative aspect-square size-32 shrink-0 overflow-hidden rounded-lg"
>
<img
className="size-full object-cover"
src={`${apiHost}api/events/${event.id}/thumbnail.webp`}
alt=""
loading="lazy"
/>
</a>
))}
</div> </div>
</div> )}
{events.length > 0 && (
<div className="scrollbar-container min-w-0 overflow-x-auto">
<div className="flex w-max gap-2">
{events.map((event) => renderThumb(event))}
</div>
</div>
)}
</div> </div>
); );
} }

View File

@ -15,6 +15,8 @@ import {
TooltipTrigger, TooltipTrigger,
} from "@/components/ui/tooltip"; } from "@/components/ui/tooltip";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { ChatAttachmentChip } from "@/components/chat/ChatAttachmentChip";
import { parseAttachedEvent } from "@/utils/chatUtil";
type MessageBubbleProps = { type MessageBubbleProps = {
role: "user" | "assistant"; role: "user" | "assistant";
@ -126,6 +128,10 @@ export function MessageBubble({
); );
} }
const { eventId: attachedEventId, body: displayContent } = isUser
? parseAttachedEvent(content)
: { eventId: null, body: content };
return ( return (
<div <div
className={cn( className={cn(
@ -140,9 +146,20 @@ export function MessageBubble({
)} )}
> >
{isUser ? ( {isUser ? (
content <div className="flex flex-col gap-2">
{attachedEventId && (
<ChatAttachmentChip eventId={attachedEventId} mode="bubble" />
)}
<div className="whitespace-pre-wrap">{displayContent}</div>
</div>
) : ( ) : (
<> <div
className={cn(
"[&>*:last-child]:inline",
!isComplete &&
"after:ml-0.5 after:inline-block after:h-4 after:w-2 after:animate-cursor-blink after:rounded-sm after:bg-foreground after:align-middle after:content-['']",
)}
>
<ReactMarkdown <ReactMarkdown
remarkPlugins={[remarkGfm]} remarkPlugins={[remarkGfm]}
components={{ components={{
@ -168,10 +185,7 @@ export function MessageBubble({
> >
{content} {content}
</ReactMarkdown> </ReactMarkdown>
{!isComplete && ( </div>
<span className="ml-1 inline-block h-4 w-0.5 animate-pulse bg-foreground align-middle" />
)}
</>
)} )}
</div> </div>
<div className="flex items-center gap-0.5"> <div className="flex items-center gap-0.5">

View File

@ -0,0 +1,114 @@
import { useState } from "react";
import { useTranslation } from "react-i18next";
import { LuPaperclip } from "react-icons/lu";
import { useApiHost } from "@/api";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
// Conservative character allow-list for manually pasted event IDs; anything
// outside it is rejected client-side before calling onAttach.
const EVENT_ID_RE = /^[A-Za-z0-9._-]+$/;

// Props for ChatPaperclipButton.
type ChatPaperclipButtonProps = {
  // Event IDs to offer as a quick-pick thumbnail grid (at most 8 are shown).
  recentEventIds: string[];
  // Called with the chosen or pasted event ID.
  onAttach: (eventId: string) => void;
  // Disables the trigger button entirely.
  disabled?: boolean;
};

/**
 * Paperclip button with a popover for picking an event to attach.
 * Shows a grid of recent thumbnails (from the latest assistant message) and a
 * "paste event ID" fallback input.
 */
export function ChatPaperclipButton({
  recentEventIds,
  onAttach,
  disabled = false,
}: ChatPaperclipButtonProps) {
  const apiHost = useApiHost();
  const { t } = useTranslation(["views/chat"]);
  const [open, setOpen] = useState(false);
  const [pasteId, setPasteId] = useState("");

  // Attach a thumbnail pick, then close the popover and clear the paste field.
  const handlePickThumbnail = (eventId: string) => {
    onAttach(eventId);
    setOpen(false);
    setPasteId("");
  };

  // Validate and attach a manually pasted ID; silently no-op on invalid input
  // (the Attach button is also disabled in that case).
  const handlePasteSubmit = () => {
    const trimmed = pasteId.trim();
    if (!trimmed || !EVENT_ID_RE.test(trimmed)) return;
    onAttach(trimmed);
    setOpen(false);
    setPasteId("");
  };

  // Enter submits the pasted ID; preventDefault stops any surrounding form
  // submission.
  const handlePasteKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
    if (e.key === "Enter") {
      e.preventDefault();
      handlePasteSubmit();
    }
  };

  return (
    <Popover open={open} onOpenChange={setOpen}>
      <PopoverTrigger asChild>
        <Button
          variant="ghost"
          size="icon"
          className="size-10 shrink-0 rounded-full"
          disabled={disabled}
          aria-label={t("attachment_picker_placeholder")}
        >
          <LuPaperclip className="size-4" />
        </Button>
      </PopoverTrigger>
      <PopoverContent className="w-72" align="start">
        <div className="flex flex-col gap-3">
          {/* Quick-pick grid, capped at 8 thumbnails */}
          {recentEventIds.length > 0 && (
            <div className="grid grid-cols-4 gap-2">
              {recentEventIds.slice(0, 8).map((id) => (
                <button
                  key={id}
                  type="button"
                  onClick={() => handlePickThumbnail(id)}
                  className="relative aspect-square overflow-hidden rounded-md ring-offset-background hover:ring-2 hover:ring-primary"
                  aria-label={t("attach_event_aria", { eventId: id })}
                >
                  <img
                    className="size-full object-cover"
                    src={`${apiHost}api/events/${id}/thumbnail.webp`}
                    alt=""
                    loading="lazy"
                  />
                </button>
              ))}
            </div>
          )}
          {/* Fallback: paste an event ID directly */}
          <div className="flex items-center gap-2">
            <Input
              placeholder={t("attachment_picker_paste_label")}
              value={pasteId}
              onChange={(e) => setPasteId(e.target.value)}
              onKeyDown={handlePasteKeyDown}
              className="h-8 text-xs"
            />
            <Button
              size="sm"
              variant="select"
              className="h-8"
              disabled={!pasteId.trim() || !EVENT_ID_RE.test(pasteId.trim())}
              onClick={handlePasteSubmit}
            >
              {t("attachment_picker_attach")}
            </Button>
          </div>
        </div>
      </PopoverContent>
    </Popover>
  );
}

View File

@ -0,0 +1,49 @@
import { useTranslation } from "react-i18next";
import { Button } from "@/components/ui/button";
// A canned reply: labelKey is the short pill text, textKey is the full
// message sent when the pill is clicked. Both are i18n keys.
type QuickReply = { labelKey: string; textKey: string };

const REPLIES: QuickReply[] = [
  {
    labelKey: "quick_reply_find_similar",
    textKey: "quick_reply_find_similar_text",
  },
  {
    labelKey: "quick_reply_tell_me_more",
    textKey: "quick_reply_tell_me_more_text",
  },
  { labelKey: "quick_reply_when_else", textKey: "quick_reply_when_else_text" },
];

// Props for ChatQuickReplies.
type ChatQuickRepliesProps = {
  // Called with the translated canned text when a pill is clicked.
  onSend: (text: string) => void;
  // Disables all pills (e.g. while a response is streaming).
  disabled?: boolean;
};

/**
 * Row of pill buttons shown in the composer while an attachment is pending.
 * Clicking a pill immediately calls onSend with the canned text.
 */
export function ChatQuickReplies({
  onSend,
  disabled = false,
}: ChatQuickRepliesProps) {
  const { t } = useTranslation(["views/chat"]);

  return (
    <div className="flex w-full flex-wrap gap-2">
      {REPLIES.map((reply) => (
        <Button
          key={reply.labelKey}
          variant="outline"
          size="sm"
          className="h-7 rounded-full px-3 text-xs"
          disabled={disabled}
          onClick={() => onSend(t(reply.textKey))}
        >
          {t(reply.labelKey)}
        </Button>
      ))}
    </div>
  );
}

View File

@ -1,17 +1,22 @@
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { FaArrowUpLong } from "react-icons/fa6"; import { FaArrowUpLong, FaStop } from "react-icons/fa6";
import { LuCircleAlert } from "react-icons/lu"; import { LuCircleAlert } from "react-icons/lu";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { useState, useCallback, useRef, useEffect } from "react"; import { useState, useCallback, useRef, useEffect, useMemo } from "react";
import axios from "axios"; import axios from "axios";
import { ChatEventThumbnailsRow } from "@/components/chat/ChatEventThumbnailsRow"; import { ChatEventThumbnailsRow } from "@/components/chat/ChatEventThumbnailsRow";
import { MessageBubble } from "@/components/chat/ChatMessage"; import { MessageBubble } from "@/components/chat/ChatMessage";
import { ToolCallsGroup } from "@/components/chat/ToolCallsGroup"; import { ToolCallsGroup } from "@/components/chat/ToolCallsGroup";
import { ChatStartingState } from "@/components/chat/ChatStartingState"; import { ChatStartingState } from "@/components/chat/ChatStartingState";
import { ChatAttachmentChip } from "@/components/chat/ChatAttachmentChip";
import { ChatQuickReplies } from "@/components/chat/ChatQuickReplies";
import { ChatPaperclipButton } from "@/components/chat/ChatPaperclipButton";
import type { ChatMessage } from "@/types/chat"; import type { ChatMessage } from "@/types/chat";
import { import {
getEventIdsFromSearchObjectsToolCalls, getEventIdsFromSearchObjectsToolCalls,
getFindSimilarObjectsFromToolCalls,
prependAttachment,
streamChatCompletion, streamChatCompletion,
} from "@/utils/chatUtil"; } from "@/utils/chatUtil";
@ -21,7 +26,9 @@ export default function ChatPage() {
const [messages, setMessages] = useState<ChatMessage[]>([]); const [messages, setMessages] = useState<ChatMessage[]>([]);
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [attachedEventId, setAttachedEventId] = useState<string | null>(null);
const scrollRef = useRef<HTMLDivElement>(null); const scrollRef = useRef<HTMLDivElement>(null);
const abortRef = useRef<AbortController | null>(null);
useEffect(() => { useEffect(() => {
document.title = t("documentTitle"); document.title = t("documentTitle");
@ -64,22 +71,59 @@ export default function ChatPage() {
...(axios.defaults.headers.common as Record<string, string>), ...(axios.defaults.headers.common as Record<string, string>),
}; };
await streamChatCompletion(url, headers, apiMessages, { const controller = new AbortController();
updateMessages: (updater) => setMessages(updater), abortRef.current = controller;
onError: (message) => setError(message),
onDone: () => setIsLoading(false), await streamChatCompletion(
defaultErrorMessage: t("error"), url,
}); headers,
apiMessages,
{
updateMessages: (updater) => setMessages(updater),
onError: (message) => setError(message),
onDone: () => {
abortRef.current = null;
setIsLoading(false);
},
defaultErrorMessage: t("error"),
},
controller.signal,
);
}, },
[isLoading, t], [isLoading, t],
); );
const sendMessage = useCallback(() => { const recentEventIds = useMemo(() => {
const text = input.trim(); for (let i = messages.length - 1; i >= 0; i--) {
if (!text || isLoading) return; const msg = messages[i];
setInput(""); if (msg.role !== "assistant" || !msg.toolCalls) continue;
submitConversation([...messages, { role: "user", content: text }]); const similar = getFindSimilarObjectsFromToolCalls(msg.toolCalls);
}, [input, isLoading, messages, submitConversation]); if (similar) return similar.results.map((e) => e.id);
const events = getEventIdsFromSearchObjectsToolCalls(msg.toolCalls);
if (events.length > 0) return events.map((e) => e.id);
}
return [];
}, [messages]);
const sendMessage = useCallback(
(textOverride?: string) => {
const text = (textOverride ?? input).trim();
if (!text || isLoading) return;
const wireText = attachedEventId
? prependAttachment(text, attachedEventId)
: text;
setInput("");
setAttachedEventId(null);
submitConversation([...messages, { role: "user", content: wireText }]);
},
[attachedEventId, input, isLoading, messages, submitConversation],
);
const stopGeneration = useCallback(() => {
abortRef.current?.abort();
abortRef.current = null;
setIsLoading(false);
}, []);
const handleEditSubmit = useCallback( const handleEditSubmit = useCallback(
(messageIndex: number, newContent: string) => { (messageIndex: number, newContent: string) => {
@ -92,6 +136,10 @@ export default function ChatPage() {
[messages, submitConversation], [messages, submitConversation],
); );
const handleClearAttachment = useCallback(() => {
setAttachedEventId(null);
}, []);
return ( return (
<div className="flex size-full justify-center p-2 md:p-4"> <div className="flex size-full justify-center p-2 md:p-4">
<div className="flex size-full flex-col xl:w-[50%] 3xl:w-[35%]"> <div className="flex size-full flex-col xl:w-[50%] 3xl:w-[35%]">
@ -161,10 +209,27 @@ export default function ChatPage() {
{msg.role === "assistant" && {msg.role === "assistant" &&
isComplete && isComplete &&
(() => { (() => {
const similar = getFindSimilarObjectsFromToolCalls(
msg.toolCalls,
);
if (similar) {
return (
<ChatEventThumbnailsRow
events={similar.results}
anchor={similar.anchor}
onAttach={setAttachedEventId}
/>
);
}
const events = getEventIdsFromSearchObjectsToolCalls( const events = getEventIdsFromSearchObjectsToolCalls(
msg.toolCalls, msg.toolCalls,
); );
return <ChatEventThumbnailsRow events={events} />; return (
<ChatEventThumbnailsRow
events={events}
onAttach={setAttachedEventId}
/>
);
})()} })()}
</div> </div>
); );
@ -188,6 +253,11 @@ export default function ChatPage() {
sendMessage={sendMessage} sendMessage={sendMessage}
isLoading={isLoading} isLoading={isLoading}
placeholder={t("placeholder")} placeholder={t("placeholder")}
attachedEventId={attachedEventId}
onClearAttachment={handleClearAttachment}
onAttach={setAttachedEventId}
onStop={stopGeneration}
recentEventIds={recentEventIds}
/> />
)} )}
</div> </div>
@ -198,9 +268,14 @@ export default function ChatPage() {
type ChatEntryProps = { type ChatEntryProps = {
input: string; input: string;
setInput: (value: string) => void; setInput: (value: string) => void;
sendMessage: () => void; sendMessage: (textOverride?: string) => void;
isLoading: boolean; isLoading: boolean;
placeholder: string; placeholder: string;
attachedEventId: string | null;
onClearAttachment: () => void;
onAttach: (eventId: string) => void;
onStop: () => void;
recentEventIds: string[];
}; };
function ChatEntry({ function ChatEntry({
@ -209,6 +284,11 @@ function ChatEntry({
sendMessage, sendMessage,
isLoading, isLoading,
placeholder, placeholder,
attachedEventId,
onClearAttachment,
onAttach,
onStop,
recentEventIds,
}: ChatEntryProps) { }: ChatEntryProps) {
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => { const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === "Enter" && !e.shiftKey) { if (e.key === "Enter" && !e.shiftKey) {
@ -218,8 +298,28 @@ function ChatEntry({
}; };
return ( return (
<div className="mt-2 flex w-full flex-col items-center justify-center rounded-xl bg-secondary p-3"> <div className="mt-2 flex w-full flex-col items-stretch justify-center gap-2 rounded-xl bg-secondary p-3">
{attachedEventId && (
<div className="flex items-center">
<ChatAttachmentChip
eventId={attachedEventId}
mode="composer"
onRemove={onClearAttachment}
/>
</div>
)}
{attachedEventId && (
<ChatQuickReplies
onSend={(text) => sendMessage(text)}
disabled={isLoading}
/>
)}
<div className="flex w-full flex-row items-center gap-2"> <div className="flex w-full flex-row items-center gap-2">
<ChatPaperclipButton
recentEventIds={recentEventIds}
onAttach={onAttach}
disabled={isLoading || attachedEventId != null}
/>
<Input <Input
className="w-full flex-1 border-transparent bg-transparent shadow-none focus-visible:ring-0 dark:bg-transparent" className="w-full flex-1 border-transparent bg-transparent shadow-none focus-visible:ring-0 dark:bg-transparent"
placeholder={placeholder} placeholder={placeholder}
@ -228,14 +328,24 @@ function ChatEntry({
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
aria-busy={isLoading} aria-busy={isLoading}
/> />
<Button {isLoading ? (
variant="select" <Button
className="size-10 shrink-0 rounded-full" variant="destructive"
disabled={!input.trim() || isLoading} className="size-10 shrink-0 rounded-full"
onClick={sendMessage} onClick={onStop}
> >
<FaArrowUpLong size="16" /> <FaStop className="size-3" />
</Button> </Button>
) : (
<Button
variant="select"
className="size-10 shrink-0 rounded-full"
disabled={!input.trim()}
onClick={() => sendMessage()}
>
<FaArrowUpLong className="size-4" />
</Button>
)}
</div> </div>
</div> </div>
); );

View File

@ -25,6 +25,7 @@ export async function streamChatCompletion(
headers: Record<string, string>, headers: Record<string, string>,
apiMessages: { role: string; content: string }[], apiMessages: { role: string; content: string }[],
callbacks: StreamChatCallbacks, callbacks: StreamChatCallbacks,
signal?: AbortSignal,
): Promise<void> { ): Promise<void> {
const { const {
updateMessages, updateMessages,
@ -38,6 +39,7 @@ export async function streamChatCompletion(
method: "POST", method: "POST",
headers, headers,
body: JSON.stringify({ messages: apiMessages, stream: true }), body: JSON.stringify({ messages: apiMessages, stream: true }),
signal,
}); });
if (!res.ok) { if (!res.ok) {
@ -152,11 +154,15 @@ export async function streamChatCompletion(
return next; return next;
}); });
} }
} catch { } catch (err) {
onError(defaultErrorMessage); if (err instanceof DOMException && err.name === "AbortError") {
updateMessages((prev) => // User stopped generation — not an error
prev.filter((m) => !(m.role === "assistant" && m.content === "")), } else {
); onError(defaultErrorMessage);
updateMessages((prev) =>
prev.filter((m) => !(m.role === "assistant" && m.content === "")),
);
}
} finally { } finally {
onDone(); onDone();
} }
@ -191,3 +197,72 @@ export function getEventIdsFromSearchObjectsToolCalls(
} }
return results; return results;
} }
// Leading marker the composer prepends when the user attaches an event,
// e.g. "[attached_event:abc-1]\n\nWhat is this?". The capture group is the
// event ID.
const ATTACHED_EVENT_MARKER = /^\[attached_event:([A-Za-z0-9._-]+)\]\s*\n?/;

/**
 * Split an optional leading [attached_event:<id>] marker off a user message.
 * Returns the event ID (null when no marker is present) and the remaining
 * message body with any leading newlines stripped.
 */
export function parseAttachedEvent(content: string): {
  eventId: string | null;
  body: string;
} {
  // Empty/undefined content cannot carry a marker.
  if (!content) {
    return { eventId: null, body: content };
  }
  const match = ATTACHED_EVENT_MARKER.exec(content);
  if (match === null) {
    return { eventId: null, body: content };
  }
  const remainder = content.slice(match[0].length);
  return { eventId: match[1], body: remainder.replace(/^\n+/, "") };
}
/**
 * Build the wire-format user message for an attached event: the
 * [attached_event:<id>] marker, a blank line, then the message body.
 * Inverse of parseAttachedEvent for well-formed IDs.
 */
export function prependAttachment(body: string, eventId: string): string {
  const marker = `[attached_event:${eventId}]`;
  return `${marker}\n\n${body}`;
}
// Parsed payload of a find_similar_objects tool response: the anchor event
// (if the backend reported one) plus the ranked result list.
export type FindSimilarObjectsResult = {
  anchor: { id: string } | null;
  results: { id: string; score?: number }[];
};

/**
 * Parse find_similar_objects tool call response(s) into anchor + ranked results.
 * Returns null if no find_similar_objects call is present so the caller can
 * decide whether to render.
 */
export function getFindSimilarObjectsFromToolCalls(
  toolCalls: ToolCall[] | undefined,
): FindSimilarObjectsResult | null {
  if (!toolCalls || toolCalls.length === 0) {
    return null;
  }

  for (const call of toolCalls) {
    // Only consider find_similar_objects calls with a non-blank response.
    if (call.name !== "find_similar_objects") {
      continue;
    }
    const raw = call.response;
    if (!raw?.trim()) {
      continue;
    }

    let parsed: { anchor?: { id?: unknown }; results?: unknown };
    try {
      parsed = JSON.parse(raw) as { anchor?: { id?: unknown }; results?: unknown };
    } catch {
      // Malformed JSON: skip this call and keep scanning later ones.
      continue;
    }

    // Accept the anchor only when it carries a string id.
    const anchorId = parsed.anchor?.id;
    const anchor = typeof anchorId === "string" ? { id: anchorId } : null;

    // Keep only result entries shaped like { id: string }, carrying the
    // numeric score through when present.
    const ranked: { id: string; score?: number }[] = [];
    if (Array.isArray(parsed.results)) {
      for (const item of parsed.results) {
        if (!item || typeof item !== "object") {
          continue;
        }
        const record = item as { id?: unknown; score?: unknown };
        if (typeof record.id !== "string") {
          continue;
        }
        const entry: { id: string; score?: number } = { id: record.id };
        if (typeof record.score === "number") {
          entry.score = record.score;
        }
        ranked.push(entry);
      }
    }
    return { anchor, results: ranked };
  }
  return null;
}

View File

@ -49,6 +49,7 @@ module.exports = {
scale4: "scale4 3s ease-in-out infinite", scale4: "scale4 3s ease-in-out infinite",
"timeline-zoom-in": "timeline-zoom-in 0.3s ease-out", "timeline-zoom-in": "timeline-zoom-in 0.3s ease-out",
"timeline-zoom-out": "timeline-zoom-out 0.3s ease-out", "timeline-zoom-out": "timeline-zoom-out 0.3s ease-out",
"cursor-blink": "cursor-blink 1s step-end infinite",
}, },
aspectRatio: { aspectRatio: {
wide: "32 / 9", wide: "32 / 9",
@ -189,6 +190,10 @@ module.exports = {
"50%": { transform: "translateY(0%)", opacity: "0.5" }, "50%": { transform: "translateY(0%)", opacity: "0.5" },
"100%": { transform: "translateY(0)", opacity: "1" }, "100%": { transform: "translateY(0)", opacity: "1" },
}, },
"cursor-blink": {
"0%, 100%": { opacity: "1" },
"50%": { opacity: "0" },
},
}, },
screens: { screens: {
xs: "480px", xs: "480px",