mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-09 08:37:37 +03:00
implement _execute_find_similar_objects chat tool dispatcher
This commit is contained in:
parent
1ad09f00dc
commit
1a003e7de2
@ -597,6 +597,193 @@ async def _execute_search_objects(
|
||||
)
|
||||
|
||||
|
||||
def _parse_iso_to_timestamp(value: Optional[str]) -> Optional[float]:
|
||||
"""Parse an ISO-8601 string as server-local time -> unix timestamp.
|
||||
|
||||
Mirrors the parsing _execute_search_objects uses so both tools accept the
|
||||
same format from the LLM.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
s = value.replace("Z", "").strip()[:19]
|
||||
dt = datetime.strptime(s, "%Y-%m-%dT%H:%M:%S")
|
||||
return time.mktime(dt.timetuple())
|
||||
except (ValueError, AttributeError, TypeError):
|
||||
logger.warning("Invalid timestamp format: %s", value)
|
||||
return None
|
||||
|
||||
|
||||
def _hydrate_event(event: Event, score: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""Convert an Event row into the dict shape returned by find_similar_objects."""
|
||||
data: Dict[str, Any] = {
|
||||
"id": event.id,
|
||||
"camera": event.camera,
|
||||
"label": event.label,
|
||||
"sub_label": event.sub_label,
|
||||
"start_time": event.start_time,
|
||||
"end_time": event.end_time,
|
||||
"zones": event.zones,
|
||||
}
|
||||
if score is not None:
|
||||
data["score"] = score
|
||||
return data
|
||||
|
||||
|
||||
async def _execute_find_similar_objects(
|
||||
request: Request,
|
||||
arguments: Dict[str, Any],
|
||||
allowed_cameras: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the find_similar_objects tool.
|
||||
|
||||
Returns a plain dict (not JSONResponse) so the chat loop can embed it
|
||||
directly in tool-result messages.
|
||||
"""
|
||||
# 1. Semantic search enabled?
|
||||
config = request.app.frigate_config
|
||||
if not getattr(config.semantic_search, "enabled", False):
|
||||
return {
|
||||
"error": "semantic_search_disabled",
|
||||
"message": (
|
||||
"Semantic search must be enabled to find similar objects. "
|
||||
"Enable it in the Frigate config under semantic_search."
|
||||
),
|
||||
}
|
||||
|
||||
context = request.app.embeddings
|
||||
if context is None:
|
||||
return {
|
||||
"error": "semantic_search_disabled",
|
||||
"message": "Embeddings context is not available.",
|
||||
}
|
||||
|
||||
# 2. Anchor lookup.
|
||||
event_id = arguments.get("event_id")
|
||||
if not event_id:
|
||||
return {"error": "missing_event_id", "message": "event_id is required."}
|
||||
|
||||
try:
|
||||
anchor = Event.get(Event.id == event_id)
|
||||
except Event.DoesNotExist:
|
||||
return {
|
||||
"error": "anchor_not_found",
|
||||
"message": f"Could not find event {event_id}.",
|
||||
}
|
||||
|
||||
# 3. Parse params.
|
||||
after = _parse_iso_to_timestamp(arguments.get("after"))
|
||||
before = _parse_iso_to_timestamp(arguments.get("before"))
|
||||
|
||||
cameras = arguments.get("cameras")
|
||||
if cameras:
|
||||
# Respect RBAC: intersect with the user's allowed cameras.
|
||||
cameras = [c for c in cameras if c in allowed_cameras]
|
||||
else:
|
||||
cameras = list(allowed_cameras) if allowed_cameras else None
|
||||
|
||||
labels = arguments.get("labels") or [anchor.label]
|
||||
sub_labels = arguments.get("sub_labels")
|
||||
zones = arguments.get("zones")
|
||||
|
||||
similarity_mode = arguments.get("similarity_mode", "fused")
|
||||
if similarity_mode not in ("visual", "semantic", "fused"):
|
||||
similarity_mode = "fused"
|
||||
|
||||
min_score = arguments.get("min_score")
|
||||
limit = int(arguments.get("limit", 10))
|
||||
limit = max(1, min(limit, 50))
|
||||
|
||||
# 4. Pre-filter candidates.
|
||||
candidate_ids = _build_similar_candidates_query(
|
||||
anchor_id=anchor.id,
|
||||
after=after,
|
||||
before=before,
|
||||
cameras=cameras,
|
||||
labels=labels,
|
||||
sub_labels=sub_labels,
|
||||
zones=zones,
|
||||
)
|
||||
candidate_truncated = len(candidate_ids) == CANDIDATE_CAP
|
||||
|
||||
if not candidate_ids:
|
||||
return {
|
||||
"anchor": _hydrate_event(anchor),
|
||||
"results": [],
|
||||
"similarity_mode": similarity_mode,
|
||||
"candidate_truncated": False,
|
||||
}
|
||||
|
||||
# 5. Run similarity searches.
|
||||
visual_distances: Dict[str, float] = {}
|
||||
description_distances: Dict[str, float] = {}
|
||||
|
||||
try:
|
||||
if similarity_mode in ("visual", "fused"):
|
||||
rows = context.search_thumbnail(anchor, event_ids=candidate_ids)
|
||||
visual_distances = {row[0]: row[1] for row in rows}
|
||||
|
||||
if similarity_mode in ("semantic", "fused"):
|
||||
query_text = (
|
||||
(anchor.data or {}).get("description")
|
||||
or anchor.sub_label
|
||||
or anchor.label
|
||||
)
|
||||
rows = context.search_description(query_text, event_ids=candidate_ids)
|
||||
description_distances = {row[0]: row[1] for row in rows}
|
||||
except Exception:
|
||||
logger.exception("Similarity search failed")
|
||||
return {
|
||||
"error": "similarity_search_failed",
|
||||
"message": "Failed to run similarity search.",
|
||||
}
|
||||
|
||||
# 6. Fuse and rank.
|
||||
scored: List[tuple[str, float]] = []
|
||||
matched_ids = set(visual_distances) | set(description_distances)
|
||||
for eid in matched_ids:
|
||||
v_score = (
|
||||
_distance_to_score(visual_distances[eid], context.thumb_stats)
|
||||
if eid in visual_distances
|
||||
else None
|
||||
)
|
||||
d_score = (
|
||||
_distance_to_score(description_distances[eid], context.desc_stats)
|
||||
if eid in description_distances
|
||||
else None
|
||||
)
|
||||
fused = _fuse_scores(v_score, d_score)
|
||||
if fused is None:
|
||||
continue
|
||||
if min_score is not None and fused < min_score:
|
||||
continue
|
||||
scored.append((eid, fused))
|
||||
|
||||
scored.sort(key=lambda pair: pair[1], reverse=True)
|
||||
scored = scored[:limit]
|
||||
|
||||
# 7. Hydrate.
|
||||
if scored:
|
||||
event_rows = {
|
||||
e.id: e
|
||||
for e in Event.select().where(Event.id.in_([eid for eid, _ in scored]))
|
||||
}
|
||||
results = [
|
||||
_hydrate_event(event_rows[eid], score=score)
|
||||
for eid, score in scored
|
||||
if eid in event_rows
|
||||
]
|
||||
else:
|
||||
results = []
|
||||
|
||||
return {
|
||||
"anchor": _hydrate_event(anchor),
|
||||
"results": results,
|
||||
"similarity_mode": similarity_mode,
|
||||
"candidate_truncated": candidate_truncated,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/execute",
|
||||
dependencies=[Depends(allow_any_authenticated())],
|
||||
|
||||
@ -235,5 +235,187 @@ class TestToolDefinition(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.new_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
class TestExecuteFindSimilarObjects(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
self.tmp.close()
|
||||
self.db = SqliteExtDatabase(self.tmp.name)
|
||||
Event.bind(self.db, bind_refs=False, bind_backrefs=False)
|
||||
self.db.connect()
|
||||
self.db.create_tables([Event])
|
||||
|
||||
# Insert an anchor plus two candidates.
|
||||
def make(event_id, label="car", camera="driveway", start=1_700_000_100):
|
||||
Event.create(
|
||||
id=event_id, label=label, sub_label=None, camera=camera,
|
||||
start_time=start, end_time=start + 10,
|
||||
top_score=0.9, score=0.9, false_positive=False,
|
||||
zones=[], thumbnail="",
|
||||
has_clip=True, has_snapshot=True,
|
||||
region=[0, 0, 1, 1], box=[0, 0, 1, 1], area=1,
|
||||
retain_indefinitely=False, ratio=1.0,
|
||||
plus_id="", model_hash="", detector_type="",
|
||||
model_type="", data={"description": "a green sedan"},
|
||||
)
|
||||
|
||||
make("anchor", start=1_700_000_200)
|
||||
make("cand_a", start=1_700_000_100)
|
||||
make("cand_b", start=1_700_000_150)
|
||||
self.make = make
|
||||
|
||||
def tearDown(self):
|
||||
self.db.close()
|
||||
os.unlink(self.tmp.name)
|
||||
|
||||
def _make_request(self, semantic_enabled=True, embeddings=None):
|
||||
app = SimpleNamespace(
|
||||
embeddings=embeddings,
|
||||
frigate_config=SimpleNamespace(
|
||||
semantic_search=SimpleNamespace(enabled=semantic_enabled),
|
||||
),
|
||||
)
|
||||
return SimpleNamespace(app=app)
|
||||
|
||||
def test_semantic_search_disabled_returns_error(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
req = self._make_request(semantic_enabled=False)
|
||||
result = _run(_execute_find_similar_objects(
|
||||
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
|
||||
))
|
||||
self.assertEqual(result["error"], "semantic_search_disabled")
|
||||
|
||||
def test_anchor_not_found_returns_error(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
embeddings = MagicMock()
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
result = _run(_execute_find_similar_objects(
|
||||
req, {"event_id": "nope"}, allowed_cameras=["driveway"],
|
||||
))
|
||||
self.assertEqual(result["error"], "anchor_not_found")
|
||||
|
||||
def test_empty_candidates_returns_empty_results(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
embeddings = MagicMock()
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
# Filter to a camera with no other events.
|
||||
result = _run(_execute_find_similar_objects(
|
||||
req,
|
||||
{"event_id": "anchor", "cameras": ["nonexistent_cam"]},
|
||||
allowed_cameras=["nonexistent_cam"],
|
||||
))
|
||||
self.assertEqual(result["results"], [])
|
||||
self.assertFalse(result["candidate_truncated"])
|
||||
self.assertEqual(result["anchor"]["id"], "anchor")
|
||||
|
||||
def test_fused_calls_both_searches_and_ranks(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
embeddings = MagicMock()
|
||||
# cand_a visually closer, cand_b semantically closer.
|
||||
embeddings.search_thumbnail.return_value = [
|
||||
("cand_a", 0.10), ("cand_b", 0.40),
|
||||
]
|
||||
embeddings.search_description.return_value = [
|
||||
("cand_a", 0.50), ("cand_b", 0.20),
|
||||
]
|
||||
embeddings.thumb_stats = ZScoreNormalization()
|
||||
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
embeddings.desc_stats = ZScoreNormalization()
|
||||
embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
result = _run(_execute_find_similar_objects(
|
||||
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
|
||||
))
|
||||
embeddings.search_thumbnail.assert_called_once()
|
||||
embeddings.search_description.assert_called_once()
|
||||
# cand_a should rank first because visual is weighted higher.
|
||||
self.assertEqual(result["results"][0]["id"], "cand_a")
|
||||
self.assertIn("score", result["results"][0])
|
||||
self.assertEqual(result["similarity_mode"], "fused")
|
||||
|
||||
def test_visual_mode_only_calls_thumbnail(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
embeddings = MagicMock()
|
||||
embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
|
||||
embeddings.thumb_stats = ZScoreNormalization()
|
||||
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
|
||||
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
_run(_execute_find_similar_objects(
|
||||
req,
|
||||
{"event_id": "anchor", "similarity_mode": "visual"},
|
||||
allowed_cameras=["driveway"],
|
||||
))
|
||||
embeddings.search_thumbnail.assert_called_once()
|
||||
embeddings.search_description.assert_not_called()
|
||||
|
||||
def test_semantic_mode_only_calls_description(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
embeddings = MagicMock()
|
||||
embeddings.search_description.return_value = [("cand_a", 0.1)]
|
||||
embeddings.desc_stats = ZScoreNormalization()
|
||||
embeddings.desc_stats._update([0.1, 0.2, 0.3])
|
||||
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
_run(_execute_find_similar_objects(
|
||||
req,
|
||||
{"event_id": "anchor", "similarity_mode": "semantic"},
|
||||
allowed_cameras=["driveway"],
|
||||
))
|
||||
embeddings.search_description.assert_called_once()
|
||||
embeddings.search_thumbnail.assert_not_called()
|
||||
|
||||
def test_min_score_drops_low_scoring_results(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
embeddings = MagicMock()
|
||||
embeddings.search_thumbnail.return_value = [
|
||||
("cand_a", 0.10), ("cand_b", 0.90),
|
||||
]
|
||||
embeddings.search_description.return_value = []
|
||||
embeddings.thumb_stats = ZScoreNormalization()
|
||||
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
embeddings.desc_stats = ZScoreNormalization()
|
||||
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
result = _run(_execute_find_similar_objects(
|
||||
req,
|
||||
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
|
||||
allowed_cameras=["driveway"],
|
||||
))
|
||||
ids = [r["id"] for r in result["results"]]
|
||||
self.assertIn("cand_a", ids)
|
||||
self.assertNotIn("cand_b", ids)
|
||||
|
||||
def test_labels_defaults_to_anchor_label(self):
|
||||
from frigate.api.chat import _execute_find_similar_objects
|
||||
self.make("person_a", label="person")
|
||||
embeddings = MagicMock()
|
||||
embeddings.search_thumbnail.return_value = [
|
||||
("cand_a", 0.1), ("cand_b", 0.2),
|
||||
]
|
||||
embeddings.search_description.return_value = []
|
||||
embeddings.thumb_stats = ZScoreNormalization()
|
||||
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
|
||||
embeddings.desc_stats = ZScoreNormalization()
|
||||
|
||||
req = self._make_request(embeddings=embeddings)
|
||||
result = _run(_execute_find_similar_objects(
|
||||
req,
|
||||
{"event_id": "anchor", "similarity_mode": "visual"},
|
||||
allowed_cameras=["driveway"],
|
||||
))
|
||||
ids = [r["id"] for r in result["results"]]
|
||||
self.assertNotIn("person_a", ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user