diff --git a/frigate/api/chat.py b/frigate/api/chat.py index 41413048c..a549bf70f 100644 --- a/frigate/api/chat.py +++ b/frigate/api/chat.py @@ -597,6 +597,193 @@ async def _execute_search_objects( ) +def _parse_iso_to_timestamp(value: Optional[str]) -> Optional[float]: + """Parse an ISO-8601 string as server-local time -> unix timestamp. + + Mirrors the parsing _execute_search_objects uses so both tools accept the + same format from the LLM. + """ + if value is None: + return None + try: + s = value.replace("Z", "").strip()[:19] + dt = datetime.strptime(s, "%Y-%m-%dT%H:%M:%S") + return time.mktime(dt.timetuple()) + except (ValueError, AttributeError, TypeError): + logger.warning("Invalid timestamp format: %s", value) + return None + + +def _hydrate_event(event: Event, score: Optional[float] = None) -> Dict[str, Any]: + """Convert an Event row into the dict shape returned by find_similar_objects.""" + data: Dict[str, Any] = { + "id": event.id, + "camera": event.camera, + "label": event.label, + "sub_label": event.sub_label, + "start_time": event.start_time, + "end_time": event.end_time, + "zones": event.zones, + } + if score is not None: + data["score"] = score + return data + + +async def _execute_find_similar_objects( + request: Request, + arguments: Dict[str, Any], + allowed_cameras: List[str], +) -> Dict[str, Any]: + """Execute the find_similar_objects tool. + + Returns a plain dict (not JSONResponse) so the chat loop can embed it + directly in tool-result messages. + """ + # 1. Semantic search enabled? + config = request.app.frigate_config + if not getattr(config.semantic_search, "enabled", False): + return { + "error": "semantic_search_disabled", + "message": ( + "Semantic search must be enabled to find similar objects. " + "Enable it in the Frigate config under semantic_search." + ), + } + + context = request.app.embeddings + if context is None: + return { + "error": "semantic_search_disabled", + "message": "Embeddings context is not available.", + } + + # 2. Anchor lookup. + event_id = arguments.get("event_id") + if not event_id: + return {"error": "missing_event_id", "message": "event_id is required."} + + try: + anchor = Event.get(Event.id == event_id) + except Event.DoesNotExist: + return { + "error": "anchor_not_found", + "message": f"Could not find event {event_id}.", + } + + # 3. Parse params. + after = _parse_iso_to_timestamp(arguments.get("after")) + before = _parse_iso_to_timestamp(arguments.get("before")) + + cameras = arguments.get("cameras") + if cameras: + # Respect RBAC: intersect with the user's allowed cameras. + cameras = [c for c in cameras if c in allowed_cameras] + else: + cameras = list(allowed_cameras) if allowed_cameras else None + + labels = arguments.get("labels") or [anchor.label] + sub_labels = arguments.get("sub_labels") + zones = arguments.get("zones") + + similarity_mode = arguments.get("similarity_mode", "fused") + if similarity_mode not in ("visual", "semantic", "fused"): + similarity_mode = "fused" + + min_score = arguments.get("min_score") + limit = int(arguments.get("limit", 10)) + limit = max(1, min(limit, 50)) + + # 4. Pre-filter candidates. + candidate_ids = _build_similar_candidates_query( + anchor_id=anchor.id, + after=after, + before=before, + cameras=cameras, + labels=labels, + sub_labels=sub_labels, + zones=zones, + ) + candidate_truncated = len(candidate_ids) == CANDIDATE_CAP + + if not candidate_ids: + return { + "anchor": _hydrate_event(anchor), + "results": [], + "similarity_mode": similarity_mode, + "candidate_truncated": False, + } + + # 5. Run similarity searches. + visual_distances: Dict[str, float] = {} + description_distances: Dict[str, float] = {} + + try: + if similarity_mode in ("visual", "fused"): + rows = context.search_thumbnail(anchor, event_ids=candidate_ids) + visual_distances = {row[0]: row[1] for row in rows} + + if similarity_mode in ("semantic", "fused"): + query_text = ( + (anchor.data or {}).get("description") + or anchor.sub_label + or anchor.label + ) + rows = context.search_description(query_text, event_ids=candidate_ids) + description_distances = {row[0]: row[1] for row in rows} + except Exception: + logger.exception("Similarity search failed") + return { + "error": "similarity_search_failed", + "message": "Failed to run similarity search.", + } + + # 6. Fuse and rank. + scored: List[tuple[str, float]] = [] + matched_ids = set(visual_distances) | set(description_distances) + for eid in matched_ids: + v_score = ( + _distance_to_score(visual_distances[eid], context.thumb_stats) + if eid in visual_distances + else None + ) + d_score = ( + _distance_to_score(description_distances[eid], context.desc_stats) + if eid in description_distances + else None + ) + fused = _fuse_scores(v_score, d_score) + if fused is None: + continue + if min_score is not None and fused < min_score: + continue + scored.append((eid, fused)) + + scored.sort(key=lambda pair: pair[1], reverse=True) + scored = scored[:limit] + + # 7. Hydrate. + if scored: + event_rows = { + e.id: e + for e in Event.select().where(Event.id.in_([eid for eid, _ in scored])) + } + results = [ + _hydrate_event(event_rows[eid], score=score) + for eid, score in scored + if eid in event_rows + ] + else: + results = [] + + return { + "anchor": _hydrate_event(anchor), + "results": results, + "similarity_mode": similarity_mode, + "candidate_truncated": candidate_truncated, + } + + @router.post( "/chat/execute", dependencies=[Depends(allow_any_authenticated())], diff --git a/frigate/test/test_chat_find_similar_objects.py b/frigate/test/test_chat_find_similar_objects.py index 7987f6588..886ed753c 100644 --- a/frigate/test/test_chat_find_similar_objects.py +++ b/frigate/test/test_chat_find_similar_objects.py @@ -235,5 +235,187 @@ class TestToolDefinition(unittest.TestCase): ) +import asyncio +from types import SimpleNamespace +from unittest.mock import patch + + +def _run(coro): + return asyncio.new_event_loop().run_until_complete(coro) + + +class TestExecuteFindSimilarObjects(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + self.tmp.close() + self.db = SqliteExtDatabase(self.tmp.name) + Event.bind(self.db, bind_refs=False, bind_backrefs=False) + self.db.connect() + self.db.create_tables([Event]) + + # Insert an anchor plus two candidates. + def make(event_id, label="car", camera="driveway", start=1_700_000_100): + Event.create( + id=event_id, label=label, sub_label=None, camera=camera, + start_time=start, end_time=start + 10, + top_score=0.9, score=0.9, false_positive=False, + zones=[], thumbnail="", + has_clip=True, has_snapshot=True, + region=[0, 0, 1, 1], box=[0, 0, 1, 1], area=1, + retain_indefinitely=False, ratio=1.0, + plus_id="", model_hash="", detector_type="", + model_type="", data={"description": "a green sedan"}, + ) + + make("anchor", start=1_700_000_200) + make("cand_a", start=1_700_000_100) + make("cand_b", start=1_700_000_150) + self.make = make + + def tearDown(self): + self.db.close() + os.unlink(self.tmp.name) + + def _make_request(self, semantic_enabled=True, embeddings=None): + app = SimpleNamespace( + embeddings=embeddings, + frigate_config=SimpleNamespace( + semantic_search=SimpleNamespace(enabled=semantic_enabled), + ), + ) + return SimpleNamespace(app=app) + + def test_semantic_search_disabled_returns_error(self): + from frigate.api.chat import _execute_find_similar_objects + req = self._make_request(semantic_enabled=False) + result = _run(_execute_find_similar_objects( + req, {"event_id": "anchor"}, allowed_cameras=["driveway"], + )) + self.assertEqual(result["error"], "semantic_search_disabled") + + def test_anchor_not_found_returns_error(self): + from frigate.api.chat import _execute_find_similar_objects + embeddings = MagicMock() + req = self._make_request(embeddings=embeddings) + result = _run(_execute_find_similar_objects( + req, {"event_id": "nope"}, allowed_cameras=["driveway"], + )) + self.assertEqual(result["error"], "anchor_not_found") + + def test_empty_candidates_returns_empty_results(self): + from frigate.api.chat import _execute_find_similar_objects + embeddings = MagicMock() + req = self._make_request(embeddings=embeddings) + # Filter to a camera with no other events. + result = _run(_execute_find_similar_objects( + req, + {"event_id": "anchor", "cameras": ["nonexistent_cam"]}, + allowed_cameras=["nonexistent_cam"], + )) + self.assertEqual(result["results"], []) + self.assertFalse(result["candidate_truncated"]) + self.assertEqual(result["anchor"]["id"], "anchor") + + def test_fused_calls_both_searches_and_ranks(self): + from frigate.api.chat import _execute_find_similar_objects + embeddings = MagicMock() + # cand_a visually closer, cand_b semantically closer. + embeddings.search_thumbnail.return_value = [ + ("cand_a", 0.10), ("cand_b", 0.40), + ] + embeddings.search_description.return_value = [ + ("cand_a", 0.50), ("cand_b", 0.20), + ] + embeddings.thumb_stats = ZScoreNormalization() + embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5]) + embeddings.desc_stats = ZScoreNormalization() + embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5]) + + req = self._make_request(embeddings=embeddings) + result = _run(_execute_find_similar_objects( + req, {"event_id": "anchor"}, allowed_cameras=["driveway"], + )) + embeddings.search_thumbnail.assert_called_once() + embeddings.search_description.assert_called_once() + # cand_a should rank first because visual is weighted higher. + self.assertEqual(result["results"][0]["id"], "cand_a") + self.assertIn("score", result["results"][0]) + self.assertEqual(result["similarity_mode"], "fused") + + def test_visual_mode_only_calls_thumbnail(self): + from frigate.api.chat import _execute_find_similar_objects + embeddings = MagicMock() + embeddings.search_thumbnail.return_value = [("cand_a", 0.1)] + embeddings.thumb_stats = ZScoreNormalization() + embeddings.thumb_stats._update([0.1, 0.2, 0.3]) + + req = self._make_request(embeddings=embeddings) + _run(_execute_find_similar_objects( + req, + {"event_id": "anchor", "similarity_mode": "visual"}, + allowed_cameras=["driveway"], + )) + embeddings.search_thumbnail.assert_called_once() + embeddings.search_description.assert_not_called() + + def test_semantic_mode_only_calls_description(self): + from frigate.api.chat import _execute_find_similar_objects + embeddings = MagicMock() + embeddings.search_description.return_value = [("cand_a", 0.1)] + embeddings.desc_stats = ZScoreNormalization() + embeddings.desc_stats._update([0.1, 0.2, 0.3]) + + req = self._make_request(embeddings=embeddings) + _run(_execute_find_similar_objects( + req, + {"event_id": "anchor", "similarity_mode": "semantic"}, + allowed_cameras=["driveway"], + )) + embeddings.search_description.assert_called_once() + embeddings.search_thumbnail.assert_not_called() + + def test_min_score_drops_low_scoring_results(self): + from frigate.api.chat import _execute_find_similar_objects + embeddings = MagicMock() + embeddings.search_thumbnail.return_value = [ + ("cand_a", 0.10), ("cand_b", 0.90), + ] + embeddings.search_description.return_value = [] + embeddings.thumb_stats = ZScoreNormalization() + embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5]) + embeddings.desc_stats = ZScoreNormalization() + + req = self._make_request(embeddings=embeddings) + result = _run(_execute_find_similar_objects( + req, + {"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6}, + allowed_cameras=["driveway"], + )) + ids = [r["id"] for r in result["results"]] + self.assertIn("cand_a", ids) + self.assertNotIn("cand_b", ids) + + def test_labels_defaults_to_anchor_label(self): + from frigate.api.chat import _execute_find_similar_objects + self.make("person_a", label="person") + embeddings = MagicMock() + embeddings.search_thumbnail.return_value = [ + ("cand_a", 0.1), ("cand_b", 0.2), + ] + embeddings.search_description.return_value = [] + embeddings.thumb_stats = ZScoreNormalization() + embeddings.thumb_stats._update([0.1, 0.2, 0.3]) + embeddings.desc_stats = ZScoreNormalization() + + req = self._make_request(embeddings=embeddings) + result = _run(_execute_find_similar_objects( + req, + {"event_id": "anchor", "similarity_mode": "visual"}, + allowed_cameras=["driveway"], + )) + ids = [r["id"] for r in result["results"]] + self.assertNotIn("person_a", ids) + + if __name__ == "__main__": unittest.main()