implement _execute_find_similar_objects chat tool dispatcher

This commit is contained in:
Josh Hawkins 2026-04-08 15:22:28 -05:00
parent 1ad09f00dc
commit 1a003e7de2
2 changed files with 369 additions and 0 deletions

View File

@ -597,6 +597,193 @@ async def _execute_search_objects(
)
def _parse_iso_to_timestamp(value: Optional[str]) -> Optional[float]:
    """Parse an ISO-8601 string as server-local time -> unix timestamp.

    Mirrors the parsing _execute_search_objects uses so both tools accept the
    same format from the LLM.

    Args:
        value: Timestamp text such as ``2024-01-01T12:00:00``, or None. Any
            ``Z`` characters are removed and everything after the seconds
            field is discarded before parsing.

    Returns:
        Seconds since the epoch for that wall-clock time interpreted in the
        server's local timezone, or None when ``value`` is None or cannot be
        converted (a warning is logged in the latter case).
    """
    if value is None:
        return None
    try:
        # Keep only "YYYY-MM-DDTHH:MM:SS"; fractional seconds and UTC offsets
        # are dropped rather than applied.
        s = value.replace("Z", "").strip()[:19]
        dt = datetime.strptime(s, "%Y-%m-%dT%H:%M:%S")
        return time.mktime(dt.timetuple())
    except (ValueError, AttributeError, TypeError, OverflowError):
        # OverflowError: time.mktime rejects times the platform cannot
        # represent; treat that like any other unusable timestamp instead of
        # letting it escape to the caller.
        logger.warning("Invalid timestamp format: %s", value)
        return None
def _hydrate_event(event: Event, score: Optional[float] = None) -> Dict[str, Any]:
    """Serialize an Event row into the dict shape find_similar_objects returns.

    The "score" key is only included when a score is supplied.
    """
    field_names = (
        "id",
        "camera",
        "label",
        "sub_label",
        "start_time",
        "end_time",
        "zones",
    )
    payload: Dict[str, Any] = {name: getattr(event, name) for name in field_names}
    if score is None:
        return payload
    payload["score"] = score
    return payload
async def _execute_find_similar_objects(
    request: Request,
    arguments: Dict[str, Any],
    allowed_cameras: List[str],
) -> Dict[str, Any]:
    """Execute the find_similar_objects tool.

    Returns a plain dict (not JSONResponse) so the chat loop can embed it
    directly in tool-result messages.

    Args:
        request: Incoming request; ``request.app.frigate_config`` and
            ``request.app.embeddings`` are read from it.
        arguments: Tool arguments supplied by the LLM. ``event_id`` is
            required; ``after``, ``before``, ``cameras``, ``labels``,
            ``sub_labels``, ``zones``, ``similarity_mode``, ``min_score`` and
            ``limit`` are optional. Malformed ``limit`` / ``min_score`` values
            fall back to their defaults instead of raising.
        allowed_cameras: Cameras the calling user may access.

    Returns:
        On success a dict with ``anchor``, ``results`` (highest score first),
        ``similarity_mode`` and ``candidate_truncated``. On failure a dict
        with ``error`` and ``message`` keys.
    """
    # 1. Semantic search enabled?
    config = request.app.frigate_config
    if not getattr(config.semantic_search, "enabled", False):
        return {
            "error": "semantic_search_disabled",
            "message": (
                "Semantic search must be enabled to find similar objects. "
                "Enable it in the Frigate config under semantic_search."
            ),
        }

    context = request.app.embeddings
    if context is None:
        return {
            "error": "semantic_search_disabled",
            "message": "Embeddings context is not available.",
        }

    # 2. Anchor lookup.
    event_id = arguments.get("event_id")
    if not event_id:
        return {"error": "missing_event_id", "message": "event_id is required."}

    try:
        anchor = Event.get(Event.id == event_id)
    except Event.DoesNotExist:
        return {
            "error": "anchor_not_found",
            "message": f"Could not find event {event_id}.",
        }
    # NOTE(review): the anchor's own camera is not checked against
    # allowed_cameras, so its metadata is returned even when it belongs to a
    # camera outside the caller's list - confirm whether that is intended.

    # 3. Parse params.
    similarity_mode = arguments.get("similarity_mode", "fused")
    if similarity_mode not in ("visual", "semantic", "fused"):
        similarity_mode = "fused"

    after = _parse_iso_to_timestamp(arguments.get("after"))
    before = _parse_iso_to_timestamp(arguments.get("before"))

    requested_cameras = arguments.get("cameras")
    if requested_cameras:
        # Respect RBAC: intersect with the user's allowed cameras.
        permitted = allowed_cameras or []
        cameras = [c for c in requested_cameras if c in permitted]
        if not cameras:
            # Fail closed: none of the requested cameras are accessible, so
            # never hand an empty camera filter to the candidate query where
            # it could be read as "no restriction".
            return {
                "anchor": _hydrate_event(anchor),
                "results": [],
                "similarity_mode": similarity_mode,
                "candidate_truncated": False,
            }
    else:
        cameras = list(allowed_cameras) if allowed_cameras else None

    labels = arguments.get("labels") or [anchor.label]
    sub_labels = arguments.get("sub_labels")
    zones = arguments.get("zones")

    # LLM-provided numbers may arrive as strings or null; coerce defensively.
    min_score = arguments.get("min_score")
    if min_score is not None:
        try:
            min_score = float(min_score)
        except (TypeError, ValueError):
            logger.warning("Ignoring invalid min_score: %r", min_score)
            min_score = None

    try:
        limit = int(arguments.get("limit", 10))
    except (TypeError, ValueError, OverflowError):
        limit = 10
    limit = max(1, min(limit, 50))

    # 4. Pre-filter candidates.
    candidate_ids = _build_similar_candidates_query(
        anchor_id=anchor.id,
        after=after,
        before=before,
        cameras=cameras,
        labels=labels,
        sub_labels=sub_labels,
        zones=zones,
    )
    candidate_truncated = len(candidate_ids) == CANDIDATE_CAP

    if not candidate_ids:
        return {
            "anchor": _hydrate_event(anchor),
            "results": [],
            "similarity_mode": similarity_mode,
            "candidate_truncated": False,
        }

    # 5. Run similarity searches.
    visual_distances: Dict[str, float] = {}
    description_distances: Dict[str, float] = {}

    try:
        if similarity_mode in ("visual", "fused"):
            rows = context.search_thumbnail(anchor, event_ids=candidate_ids)
            visual_distances = {row[0]: row[1] for row in rows}

        if similarity_mode in ("semantic", "fused"):
            # Fall back to the sub label, then the label, when the anchor has
            # no stored description to search with.
            query_text = (
                (anchor.data or {}).get("description")
                or anchor.sub_label
                or anchor.label
            )
            rows = context.search_description(query_text, event_ids=candidate_ids)
            description_distances = {row[0]: row[1] for row in rows}
    except Exception:
        logger.exception("Similarity search failed")
        return {
            "error": "similarity_search_failed",
            "message": "Failed to run similarity search.",
        }

    # 6. Fuse and rank.
    scored: List[tuple[str, float]] = []
    matched_ids = set(visual_distances) | set(description_distances)
    for eid in matched_ids:
        v_score = (
            _distance_to_score(visual_distances[eid], context.thumb_stats)
            if eid in visual_distances
            else None
        )
        d_score = (
            _distance_to_score(description_distances[eid], context.desc_stats)
            if eid in description_distances
            else None
        )
        fused = _fuse_scores(v_score, d_score)
        if fused is None:
            continue
        if min_score is not None and fused < min_score:
            continue
        scored.append((eid, fused))

    # Highest score first; break ties on event id so the order does not depend
    # on set iteration order.
    scored.sort(key=lambda pair: (-pair[1], pair[0]))
    scored = scored[:limit]

    # 7. Hydrate.
    if scored:
        event_rows = {
            e.id: e
            for e in Event.select().where(Event.id.in_([eid for eid, _ in scored]))
        }
        # Skip ids whose rows are no longer present in the events table.
        results = [
            _hydrate_event(event_rows[eid], score=score)
            for eid, score in scored
            if eid in event_rows
        ]
    else:
        results = []

    return {
        "anchor": _hydrate_event(anchor),
        "results": results,
        "similarity_mode": similarity_mode,
        "candidate_truncated": candidate_truncated,
    }
@router.post(
"/chat/execute",
dependencies=[Depends(allow_any_authenticated())],

View File

@ -235,5 +235,187 @@ class TestToolDefinition(unittest.TestCase):
)
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
def _run(coro):
    """Run ``coro`` to completion on a fresh event loop and return its result.

    The loop is closed afterwards, even when the coroutine raises, so each
    call does not leave an open event loop behind.
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
class TestExecuteFindSimilarObjects(unittest.TestCase):
    """Exercise _execute_find_similar_objects against a temporary SQLite DB.

    Each test starts with three "car" events on the "driveway" camera:
    "anchor", "cand_a" and "cand_b". The embeddings context is mocked.
    """

    def setUp(self):
        self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
        self.tmp.close()
        # Cleanups run even if a later setUp step fails (tearDown would not),
        # and in reverse order: the DB is closed before the file is removed.
        self.addCleanup(os.unlink, self.tmp.name)

        self.db = SqliteExtDatabase(self.tmp.name)
        Event.bind(self.db, bind_refs=False, bind_backrefs=False)
        self.db.connect()
        self.addCleanup(self.db.close)
        self.db.create_tables([Event])

        # Insert an anchor plus two candidates.
        self.make("anchor", start=1_700_000_200)
        self.make("cand_a", start=1_700_000_100)
        self.make("cand_b", start=1_700_000_150)

    @staticmethod
    def make(event_id, label="car", camera="driveway", start=1_700_000_100):
        """Insert one Event row with the given id, label, camera and start."""
        Event.create(
            id=event_id,
            label=label,
            sub_label=None,
            camera=camera,
            start_time=start,
            end_time=start + 10,
            top_score=0.9,
            score=0.9,
            false_positive=False,
            zones=[],
            thumbnail="",
            has_clip=True,
            has_snapshot=True,
            region=[0, 0, 1, 1],
            box=[0, 0, 1, 1],
            area=1,
            retain_indefinitely=False,
            ratio=1.0,
            plus_id="",
            model_hash="",
            detector_type="",
            model_type="",
            data={"description": "a green sedan"},
        )

    def _make_request(self, semantic_enabled=True, embeddings=None):
        """Build a stand-in request exposing only the attributes the tool reads."""
        app = SimpleNamespace(
            embeddings=embeddings,
            frigate_config=SimpleNamespace(
                semantic_search=SimpleNamespace(enabled=semantic_enabled),
            ),
        )
        return SimpleNamespace(app=app)

    def _find_similar(self, req, arguments, allowed_cameras):
        """Run the tool synchronously and return its result dict."""
        # Imported inside the call, as every test in this class originally did.
        from frigate.api.chat import _execute_find_similar_objects

        return _run(
            _execute_find_similar_objects(
                req,
                arguments,
                allowed_cameras=allowed_cameras,
            )
        )

    def test_semantic_search_disabled_returns_error(self):
        req = self._make_request(semantic_enabled=False)
        result = self._find_similar(req, {"event_id": "anchor"}, ["driveway"])
        self.assertEqual(result["error"], "semantic_search_disabled")

    def test_anchor_not_found_returns_error(self):
        embeddings = MagicMock()
        req = self._make_request(embeddings=embeddings)
        result = self._find_similar(req, {"event_id": "nope"}, ["driveway"])
        self.assertEqual(result["error"], "anchor_not_found")

    def test_empty_candidates_returns_empty_results(self):
        embeddings = MagicMock()
        req = self._make_request(embeddings=embeddings)
        # Filter to a camera with no other events.
        result = self._find_similar(
            req,
            {"event_id": "anchor", "cameras": ["nonexistent_cam"]},
            ["nonexistent_cam"],
        )
        self.assertEqual(result["results"], [])
        self.assertFalse(result["candidate_truncated"])
        self.assertEqual(result["anchor"]["id"], "anchor")

    def test_fused_calls_both_searches_and_ranks(self):
        embeddings = MagicMock()
        # cand_a visually closer, cand_b semantically closer.
        embeddings.search_thumbnail.return_value = [
            ("cand_a", 0.10),
            ("cand_b", 0.40),
        ]
        embeddings.search_description.return_value = [
            ("cand_a", 0.50),
            ("cand_b", 0.20),
        ]
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
        embeddings.desc_stats = ZScoreNormalization()
        embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])

        req = self._make_request(embeddings=embeddings)
        result = self._find_similar(req, {"event_id": "anchor"}, ["driveway"])

        embeddings.search_thumbnail.assert_called_once()
        embeddings.search_description.assert_called_once()
        # cand_a should rank first because visual is weighted higher.
        self.assertEqual(result["results"][0]["id"], "cand_a")
        self.assertIn("score", result["results"][0])
        self.assertEqual(result["similarity_mode"], "fused")

    def test_visual_mode_only_calls_thumbnail(self):
        embeddings = MagicMock()
        embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3])

        req = self._make_request(embeddings=embeddings)
        self._find_similar(
            req,
            {"event_id": "anchor", "similarity_mode": "visual"},
            ["driveway"],
        )

        embeddings.search_thumbnail.assert_called_once()
        embeddings.search_description.assert_not_called()

    def test_semantic_mode_only_calls_description(self):
        embeddings = MagicMock()
        embeddings.search_description.return_value = [("cand_a", 0.1)]
        embeddings.desc_stats = ZScoreNormalization()
        embeddings.desc_stats._update([0.1, 0.2, 0.3])

        req = self._make_request(embeddings=embeddings)
        self._find_similar(
            req,
            {"event_id": "anchor", "similarity_mode": "semantic"},
            ["driveway"],
        )

        embeddings.search_description.assert_called_once()
        embeddings.search_thumbnail.assert_not_called()

    def test_min_score_drops_low_scoring_results(self):
        embeddings = MagicMock()
        embeddings.search_thumbnail.return_value = [
            ("cand_a", 0.10),
            ("cand_b", 0.90),
        ]
        embeddings.search_description.return_value = []
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
        embeddings.desc_stats = ZScoreNormalization()

        req = self._make_request(embeddings=embeddings)
        result = self._find_similar(
            req,
            {"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
            ["driveway"],
        )

        ids = [r["id"] for r in result["results"]]
        self.assertIn("cand_a", ids)
        self.assertNotIn("cand_b", ids)

    def test_labels_defaults_to_anchor_label(self):
        self.make("person_a", label="person")
        embeddings = MagicMock()
        embeddings.search_thumbnail.return_value = [
            ("cand_a", 0.1),
            ("cand_b", 0.2),
        ]
        embeddings.search_description.return_value = []
        embeddings.thumb_stats = ZScoreNormalization()
        embeddings.thumb_stats._update([0.1, 0.2, 0.3])
        embeddings.desc_stats = ZScoreNormalization()

        req = self._make_request(embeddings=embeddings)
        result = self._find_similar(
            req,
            {"event_id": "anchor", "similarity_mode": "visual"},
            ["driveway"],
        )

        ids = [r["id"] for r in result["results"]]
        self.assertNotIn("person_a", ids)
        # The mocked search returns a fixed list, so the result ids alone
        # cannot show that the default label filter ran. Also check the
        # candidate set that was actually handed to the search.
        _, search_kwargs = embeddings.search_thumbnail.call_args
        self.assertNotIn("person_a", search_kwargs["event_ids"])
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()