mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-09 08:37:37 +03:00
implement _execute_find_similar_objects chat tool dispatcher
This commit is contained in:
parent
1ad09f00dc
commit
1a003e7de2
@ -597,6 +597,193 @@ async def _execute_search_objects(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_iso_to_timestamp(value: Optional[str]) -> Optional[float]:
|
||||||
|
"""Parse an ISO-8601 string as server-local time -> unix timestamp.
|
||||||
|
|
||||||
|
Mirrors the parsing _execute_search_objects uses so both tools accept the
|
||||||
|
same format from the LLM.
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
s = value.replace("Z", "").strip()[:19]
|
||||||
|
dt = datetime.strptime(s, "%Y-%m-%dT%H:%M:%S")
|
||||||
|
return time.mktime(dt.timetuple())
|
||||||
|
except (ValueError, AttributeError, TypeError):
|
||||||
|
logger.warning("Invalid timestamp format: %s", value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _hydrate_event(event: Event, score: Optional[float] = None) -> Dict[str, Any]:
|
||||||
|
"""Convert an Event row into the dict shape returned by find_similar_objects."""
|
||||||
|
data: Dict[str, Any] = {
|
||||||
|
"id": event.id,
|
||||||
|
"camera": event.camera,
|
||||||
|
"label": event.label,
|
||||||
|
"sub_label": event.sub_label,
|
||||||
|
"start_time": event.start_time,
|
||||||
|
"end_time": event.end_time,
|
||||||
|
"zones": event.zones,
|
||||||
|
}
|
||||||
|
if score is not None:
|
||||||
|
data["score"] = score
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_find_similar_objects(
|
||||||
|
request: Request,
|
||||||
|
arguments: Dict[str, Any],
|
||||||
|
allowed_cameras: List[str],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Execute the find_similar_objects tool.
|
||||||
|
|
||||||
|
Returns a plain dict (not JSONResponse) so the chat loop can embed it
|
||||||
|
directly in tool-result messages.
|
||||||
|
"""
|
||||||
|
# 1. Semantic search enabled?
|
||||||
|
config = request.app.frigate_config
|
||||||
|
if not getattr(config.semantic_search, "enabled", False):
|
||||||
|
return {
|
||||||
|
"error": "semantic_search_disabled",
|
||||||
|
"message": (
|
||||||
|
"Semantic search must be enabled to find similar objects. "
|
||||||
|
"Enable it in the Frigate config under semantic_search."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
context = request.app.embeddings
|
||||||
|
if context is None:
|
||||||
|
return {
|
||||||
|
"error": "semantic_search_disabled",
|
||||||
|
"message": "Embeddings context is not available.",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Anchor lookup.
|
||||||
|
event_id = arguments.get("event_id")
|
||||||
|
if not event_id:
|
||||||
|
return {"error": "missing_event_id", "message": "event_id is required."}
|
||||||
|
|
||||||
|
try:
|
||||||
|
anchor = Event.get(Event.id == event_id)
|
||||||
|
except Event.DoesNotExist:
|
||||||
|
return {
|
||||||
|
"error": "anchor_not_found",
|
||||||
|
"message": f"Could not find event {event_id}.",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. Parse params.
|
||||||
|
after = _parse_iso_to_timestamp(arguments.get("after"))
|
||||||
|
before = _parse_iso_to_timestamp(arguments.get("before"))
|
||||||
|
|
||||||
|
cameras = arguments.get("cameras")
|
||||||
|
if cameras:
|
||||||
|
# Respect RBAC: intersect with the user's allowed cameras.
|
||||||
|
cameras = [c for c in cameras if c in allowed_cameras]
|
||||||
|
else:
|
||||||
|
cameras = list(allowed_cameras) if allowed_cameras else None
|
||||||
|
|
||||||
|
labels = arguments.get("labels") or [anchor.label]
|
||||||
|
sub_labels = arguments.get("sub_labels")
|
||||||
|
zones = arguments.get("zones")
|
||||||
|
|
||||||
|
similarity_mode = arguments.get("similarity_mode", "fused")
|
||||||
|
if similarity_mode not in ("visual", "semantic", "fused"):
|
||||||
|
similarity_mode = "fused"
|
||||||
|
|
||||||
|
min_score = arguments.get("min_score")
|
||||||
|
limit = int(arguments.get("limit", 10))
|
||||||
|
limit = max(1, min(limit, 50))
|
||||||
|
|
||||||
|
# 4. Pre-filter candidates.
|
||||||
|
candidate_ids = _build_similar_candidates_query(
|
||||||
|
anchor_id=anchor.id,
|
||||||
|
after=after,
|
||||||
|
before=before,
|
||||||
|
cameras=cameras,
|
||||||
|
labels=labels,
|
||||||
|
sub_labels=sub_labels,
|
||||||
|
zones=zones,
|
||||||
|
)
|
||||||
|
candidate_truncated = len(candidate_ids) == CANDIDATE_CAP
|
||||||
|
|
||||||
|
if not candidate_ids:
|
||||||
|
return {
|
||||||
|
"anchor": _hydrate_event(anchor),
|
||||||
|
"results": [],
|
||||||
|
"similarity_mode": similarity_mode,
|
||||||
|
"candidate_truncated": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 5. Run similarity searches.
|
||||||
|
visual_distances: Dict[str, float] = {}
|
||||||
|
description_distances: Dict[str, float] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if similarity_mode in ("visual", "fused"):
|
||||||
|
rows = context.search_thumbnail(anchor, event_ids=candidate_ids)
|
||||||
|
visual_distances = {row[0]: row[1] for row in rows}
|
||||||
|
|
||||||
|
if similarity_mode in ("semantic", "fused"):
|
||||||
|
query_text = (
|
||||||
|
(anchor.data or {}).get("description")
|
||||||
|
or anchor.sub_label
|
||||||
|
or anchor.label
|
||||||
|
)
|
||||||
|
rows = context.search_description(query_text, event_ids=candidate_ids)
|
||||||
|
description_distances = {row[0]: row[1] for row in rows}
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Similarity search failed")
|
||||||
|
return {
|
||||||
|
"error": "similarity_search_failed",
|
||||||
|
"message": "Failed to run similarity search.",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 6. Fuse and rank.
|
||||||
|
scored: List[tuple[str, float]] = []
|
||||||
|
matched_ids = set(visual_distances) | set(description_distances)
|
||||||
|
for eid in matched_ids:
|
||||||
|
v_score = (
|
||||||
|
_distance_to_score(visual_distances[eid], context.thumb_stats)
|
||||||
|
if eid in visual_distances
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
d_score = (
|
||||||
|
_distance_to_score(description_distances[eid], context.desc_stats)
|
||||||
|
if eid in description_distances
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
fused = _fuse_scores(v_score, d_score)
|
||||||
|
if fused is None:
|
||||||
|
continue
|
||||||
|
if min_score is not None and fused < min_score:
|
||||||
|
continue
|
||||||
|
scored.append((eid, fused))
|
||||||
|
|
||||||
|
scored.sort(key=lambda pair: pair[1], reverse=True)
|
||||||
|
scored = scored[:limit]
|
||||||
|
|
||||||
|
# 7. Hydrate.
|
||||||
|
if scored:
|
||||||
|
event_rows = {
|
||||||
|
e.id: e
|
||||||
|
for e in Event.select().where(Event.id.in_([eid for eid, _ in scored]))
|
||||||
|
}
|
||||||
|
results = [
|
||||||
|
_hydrate_event(event_rows[eid], score=score)
|
||||||
|
for eid, score in scored
|
||||||
|
if eid in event_rows
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
results = []
|
||||||
|
|
||||||
|
return {
|
||||||
|
"anchor": _hydrate_event(anchor),
|
||||||
|
"results": results,
|
||||||
|
"similarity_mode": similarity_mode,
|
||||||
|
"candidate_truncated": candidate_truncated,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/chat/execute",
|
"/chat/execute",
|
||||||
dependencies=[Depends(allow_any_authenticated())],
|
dependencies=[Depends(allow_any_authenticated())],
|
||||||
|
|||||||
@ -235,5 +235,187 @@ class TestToolDefinition(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.new_event_loop().run_until_complete(coro)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteFindSimilarObjects(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||||
|
self.tmp.close()
|
||||||
|
self.db = SqliteExtDatabase(self.tmp.name)
|
||||||
|
Event.bind(self.db, bind_refs=False, bind_backrefs=False)
|
||||||
|
self.db.connect()
|
||||||
|
self.db.create_tables([Event])
|
||||||
|
|
||||||
|
# Insert an anchor plus two candidates.
|
||||||
|
def make(event_id, label="car", camera="driveway", start=1_700_000_100):
|
||||||
|
Event.create(
|
||||||
|
id=event_id, label=label, sub_label=None, camera=camera,
|
||||||
|
start_time=start, end_time=start + 10,
|
||||||
|
top_score=0.9, score=0.9, false_positive=False,
|
||||||
|
zones=[], thumbnail="",
|
||||||
|
has_clip=True, has_snapshot=True,
|
||||||
|
region=[0, 0, 1, 1], box=[0, 0, 1, 1], area=1,
|
||||||
|
retain_indefinitely=False, ratio=1.0,
|
||||||
|
plus_id="", model_hash="", detector_type="",
|
||||||
|
model_type="", data={"description": "a green sedan"},
|
||||||
|
)
|
||||||
|
|
||||||
|
make("anchor", start=1_700_000_200)
|
||||||
|
make("cand_a", start=1_700_000_100)
|
||||||
|
make("cand_b", start=1_700_000_150)
|
||||||
|
self.make = make
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.db.close()
|
||||||
|
os.unlink(self.tmp.name)
|
||||||
|
|
||||||
|
def _make_request(self, semantic_enabled=True, embeddings=None):
|
||||||
|
app = SimpleNamespace(
|
||||||
|
embeddings=embeddings,
|
||||||
|
frigate_config=SimpleNamespace(
|
||||||
|
semantic_search=SimpleNamespace(enabled=semantic_enabled),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return SimpleNamespace(app=app)
|
||||||
|
|
||||||
|
def test_semantic_search_disabled_returns_error(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
req = self._make_request(semantic_enabled=False)
|
||||||
|
result = _run(_execute_find_similar_objects(
|
||||||
|
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
self.assertEqual(result["error"], "semantic_search_disabled")
|
||||||
|
|
||||||
|
def test_anchor_not_found_returns_error(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
embeddings = MagicMock()
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
result = _run(_execute_find_similar_objects(
|
||||||
|
req, {"event_id": "nope"}, allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
self.assertEqual(result["error"], "anchor_not_found")
|
||||||
|
|
||||||
|
def test_empty_candidates_returns_empty_results(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
embeddings = MagicMock()
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
# Filter to a camera with no other events.
|
||||||
|
result = _run(_execute_find_similar_objects(
|
||||||
|
req,
|
||||||
|
{"event_id": "anchor", "cameras": ["nonexistent_cam"]},
|
||||||
|
allowed_cameras=["nonexistent_cam"],
|
||||||
|
))
|
||||||
|
self.assertEqual(result["results"], [])
|
||||||
|
self.assertFalse(result["candidate_truncated"])
|
||||||
|
self.assertEqual(result["anchor"]["id"], "anchor")
|
||||||
|
|
||||||
|
def test_fused_calls_both_searches_and_ranks(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
embeddings = MagicMock()
|
||||||
|
# cand_a visually closer, cand_b semantically closer.
|
||||||
|
embeddings.search_thumbnail.return_value = [
|
||||||
|
("cand_a", 0.10), ("cand_b", 0.40),
|
||||||
|
]
|
||||||
|
embeddings.search_description.return_value = [
|
||||||
|
("cand_a", 0.50), ("cand_b", 0.20),
|
||||||
|
]
|
||||||
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
|
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
|
embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
result = _run(_execute_find_similar_objects(
|
||||||
|
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
embeddings.search_thumbnail.assert_called_once()
|
||||||
|
embeddings.search_description.assert_called_once()
|
||||||
|
# cand_a should rank first because visual is weighted higher.
|
||||||
|
self.assertEqual(result["results"][0]["id"], "cand_a")
|
||||||
|
self.assertIn("score", result["results"][0])
|
||||||
|
self.assertEqual(result["similarity_mode"], "fused")
|
||||||
|
|
||||||
|
def test_visual_mode_only_calls_thumbnail(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
embeddings = MagicMock()
|
||||||
|
embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
|
||||||
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
|
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
|
||||||
|
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
_run(_execute_find_similar_objects(
|
||||||
|
req,
|
||||||
|
{"event_id": "anchor", "similarity_mode": "visual"},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
embeddings.search_thumbnail.assert_called_once()
|
||||||
|
embeddings.search_description.assert_not_called()
|
||||||
|
|
||||||
|
def test_semantic_mode_only_calls_description(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
embeddings = MagicMock()
|
||||||
|
embeddings.search_description.return_value = [("cand_a", 0.1)]
|
||||||
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
|
embeddings.desc_stats._update([0.1, 0.2, 0.3])
|
||||||
|
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
_run(_execute_find_similar_objects(
|
||||||
|
req,
|
||||||
|
{"event_id": "anchor", "similarity_mode": "semantic"},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
embeddings.search_description.assert_called_once()
|
||||||
|
embeddings.search_thumbnail.assert_not_called()
|
||||||
|
|
||||||
|
def test_min_score_drops_low_scoring_results(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
embeddings = MagicMock()
|
||||||
|
embeddings.search_thumbnail.return_value = [
|
||||||
|
("cand_a", 0.10), ("cand_b", 0.90),
|
||||||
|
]
|
||||||
|
embeddings.search_description.return_value = []
|
||||||
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
|
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
|
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
result = _run(_execute_find_similar_objects(
|
||||||
|
req,
|
||||||
|
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
ids = [r["id"] for r in result["results"]]
|
||||||
|
self.assertIn("cand_a", ids)
|
||||||
|
self.assertNotIn("cand_b", ids)
|
||||||
|
|
||||||
|
def test_labels_defaults_to_anchor_label(self):
|
||||||
|
from frigate.api.chat import _execute_find_similar_objects
|
||||||
|
self.make("person_a", label="person")
|
||||||
|
embeddings = MagicMock()
|
||||||
|
embeddings.search_thumbnail.return_value = [
|
||||||
|
("cand_a", 0.1), ("cand_b", 0.2),
|
||||||
|
]
|
||||||
|
embeddings.search_description.return_value = []
|
||||||
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
|
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
|
||||||
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
|
|
||||||
|
req = self._make_request(embeddings=embeddings)
|
||||||
|
result = _run(_execute_find_similar_objects(
|
||||||
|
req,
|
||||||
|
{"event_id": "anchor", "similarity_mode": "visual"},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
))
|
||||||
|
ids = [r["id"] for r in result["results"]]
|
||||||
|
self.assertNotIn("person_a", ids)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user