formatting

This commit is contained in:
Josh Hawkins 2026-04-08 15:31:59 -05:00
parent 30997c20d7
commit 635e41fc8c
2 changed files with 105 additions and 84 deletions

View File

@ -30,13 +30,13 @@ from frigate.api.defs.response.chat_response import (
from frigate.api.defs.tags import Tags from frigate.api.defs.tags import Tags
from frigate.api.event import events from frigate.api.event import events
from frigate.embeddings.util import ZScoreNormalization from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
from frigate.genai.utils import build_assistant_message_for_conversation from frigate.genai.utils import build_assistant_message_for_conversation
from frigate.jobs.vlm_watch import ( from frigate.jobs.vlm_watch import (
get_vlm_watch_job, get_vlm_watch_job,
start_vlm_watch_job, start_vlm_watch_job,
stop_vlm_watch_job, stop_vlm_watch_job,
) )
from frigate.models import Event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,15 +1,30 @@
"""Tests for the find_similar_objects chat tool.""" """Tests for the find_similar_objects chat tool."""
import asyncio
import os
import tempfile
import unittest import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
from playhouse.sqlite_ext import SqliteExtDatabase
from frigate.api.chat import ( from frigate.api.chat import (
CANDIDATE_CAP,
DESCRIPTION_WEIGHT, DESCRIPTION_WEIGHT,
VISUAL_WEIGHT, VISUAL_WEIGHT,
_build_similar_candidates_query,
_distance_to_score, _distance_to_score,
_execute_find_similar_objects,
_fuse_scores, _fuse_scores,
get_tool_definitions,
) )
from frigate.embeddings.util import ZScoreNormalization from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
def _run(coro):
return asyncio.new_event_loop().run_until_complete(coro)
class TestDistanceToScore(unittest.TestCase): class TestDistanceToScore(unittest.TestCase):
@ -53,16 +68,6 @@ class TestFuseScores(unittest.TestCase):
self.assertIsNone(_fuse_scores(visual_score=None, description_score=None)) self.assertIsNone(_fuse_scores(visual_score=None, description_score=None))
import datetime
import os
import tempfile
from playhouse.sqlite_ext import SqliteExtDatabase
from frigate.api.chat import CANDIDATE_CAP, _build_similar_candidates_query
from frigate.models import Event
class TestBuildSimilarCandidatesQuery(unittest.TestCase): class TestBuildSimilarCandidatesQuery(unittest.TestCase):
def setUp(self): def setUp(self):
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
@ -202,9 +207,6 @@ class TestBuildSimilarCandidatesQuery(unittest.TestCase):
self.assertNotIn("e0000", ids) self.assertNotIn("e0000", ids)
from frigate.api.chat import get_tool_definitions
class TestToolDefinition(unittest.TestCase): class TestToolDefinition(unittest.TestCase):
def test_find_similar_objects_is_registered(self): def test_find_similar_objects_is_registered(self):
tools = get_tool_definitions() tools = get_tool_definitions()
@ -213,9 +215,7 @@ class TestToolDefinition(unittest.TestCase):
def test_find_similar_objects_schema(self): def test_find_similar_objects_schema(self):
tools = get_tool_definitions() tools = get_tool_definitions()
tool = next( tool = next(t for t in tools if t["function"]["name"] == "find_similar_objects")
t for t in tools if t["function"]["name"] == "find_similar_objects"
)
params = tool["function"]["parameters"]["properties"] params = tool["function"]["parameters"]["properties"]
self.assertIn("event_id", params) self.assertIn("event_id", params)
self.assertIn("after", params) self.assertIn("after", params)
@ -227,23 +227,12 @@ class TestToolDefinition(unittest.TestCase):
self.assertIn("similarity_mode", params) self.assertIn("similarity_mode", params)
self.assertIn("min_score", params) self.assertIn("min_score", params)
self.assertIn("limit", params) self.assertIn("limit", params)
self.assertEqual( self.assertEqual(tool["function"]["parameters"]["required"], ["event_id"])
tool["function"]["parameters"]["required"], ["event_id"]
)
self.assertEqual( self.assertEqual(
params["similarity_mode"]["enum"], ["visual", "semantic", "fused"] params["similarity_mode"]["enum"], ["visual", "semantic", "fused"]
) )
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
def _run(coro):
return asyncio.new_event_loop().run_until_complete(coro)
class TestExecuteFindSimilarObjects(unittest.TestCase): class TestExecuteFindSimilarObjects(unittest.TestCase):
def setUp(self): def setUp(self):
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
@ -256,15 +245,29 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
# Insert an anchor plus two candidates. # Insert an anchor plus two candidates.
def make(event_id, label="car", camera="driveway", start=1_700_000_100): def make(event_id, label="car", camera="driveway", start=1_700_000_100):
Event.create( Event.create(
id=event_id, label=label, sub_label=None, camera=camera, id=event_id,
start_time=start, end_time=start + 10, label=label,
top_score=0.9, score=0.9, false_positive=False, sub_label=None,
zones=[], thumbnail="", camera=camera,
has_clip=True, has_snapshot=True, start_time=start,
region=[0, 0, 1, 1], box=[0, 0, 1, 1], area=1, end_time=start + 10,
retain_indefinitely=False, ratio=1.0, top_score=0.9,
plus_id="", model_hash="", detector_type="", score=0.9,
model_type="", data={"description": "a green sedan"}, false_positive=False,
zones=[],
thumbnail="",
has_clip=True,
has_snapshot=True,
region=[0, 0, 1, 1],
box=[0, 0, 1, 1],
area=1,
retain_indefinitely=False,
ratio=1.0,
plus_id="",
model_hash="",
detector_type="",
model_type="",
data={"description": "a green sedan"},
) )
make("anchor", start=1_700_000_200) make("anchor", start=1_700_000_200)
@ -286,45 +289,53 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
return SimpleNamespace(app=app) return SimpleNamespace(app=app)
def test_semantic_search_disabled_returns_error(self): def test_semantic_search_disabled_returns_error(self):
from frigate.api.chat import _execute_find_similar_objects
req = self._make_request(semantic_enabled=False) req = self._make_request(semantic_enabled=False)
result = _run(_execute_find_similar_objects( result = _run(
req, {"event_id": "anchor"}, allowed_cameras=["driveway"], _execute_find_similar_objects(
)) req,
{"event_id": "anchor"},
allowed_cameras=["driveway"],
)
)
self.assertEqual(result["error"], "semantic_search_disabled") self.assertEqual(result["error"], "semantic_search_disabled")
def test_anchor_not_found_returns_error(self): def test_anchor_not_found_returns_error(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock() embeddings = MagicMock()
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects( result = _run(
req, {"event_id": "nope"}, allowed_cameras=["driveway"], _execute_find_similar_objects(
)) req,
{"event_id": "nope"},
allowed_cameras=["driveway"],
)
)
self.assertEqual(result["error"], "anchor_not_found") self.assertEqual(result["error"], "anchor_not_found")
def test_empty_candidates_returns_empty_results(self): def test_empty_candidates_returns_empty_results(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock() embeddings = MagicMock()
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
# Filter to a camera with no other events. # Filter to a camera with no other events.
result = _run(_execute_find_similar_objects( result = _run(
req, _execute_find_similar_objects(
{"event_id": "anchor", "cameras": ["nonexistent_cam"]}, req,
allowed_cameras=["nonexistent_cam"], {"event_id": "anchor", "cameras": ["nonexistent_cam"]},
)) allowed_cameras=["nonexistent_cam"],
)
)
self.assertEqual(result["results"], []) self.assertEqual(result["results"], [])
self.assertFalse(result["candidate_truncated"]) self.assertFalse(result["candidate_truncated"])
self.assertEqual(result["anchor"]["id"], "anchor") self.assertEqual(result["anchor"]["id"], "anchor")
def test_fused_calls_both_searches_and_ranks(self): def test_fused_calls_both_searches_and_ranks(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock() embeddings = MagicMock()
# cand_a visually closer, cand_b semantically closer. # cand_a visually closer, cand_b semantically closer.
embeddings.search_thumbnail.return_value = [ embeddings.search_thumbnail.return_value = [
("cand_a", 0.10), ("cand_b", 0.40), ("cand_a", 0.10),
("cand_b", 0.40),
] ]
embeddings.search_description.return_value = [ embeddings.search_description.return_value = [
("cand_a", 0.50), ("cand_b", 0.20), ("cand_a", 0.50),
("cand_b", 0.20),
] ]
embeddings.thumb_stats = ZScoreNormalization() embeddings.thumb_stats = ZScoreNormalization()
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5]) embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
@ -332,9 +343,13 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5]) embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects( result = _run(
req, {"event_id": "anchor"}, allowed_cameras=["driveway"], _execute_find_similar_objects(
)) req,
{"event_id": "anchor"},
allowed_cameras=["driveway"],
)
)
embeddings.search_thumbnail.assert_called_once() embeddings.search_thumbnail.assert_called_once()
embeddings.search_description.assert_called_once() embeddings.search_description.assert_called_once()
# cand_a should rank first because visual is weighted higher. # cand_a should rank first because visual is weighted higher.
@ -343,42 +358,44 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
self.assertEqual(result["similarity_mode"], "fused") self.assertEqual(result["similarity_mode"], "fused")
def test_visual_mode_only_calls_thumbnail(self): def test_visual_mode_only_calls_thumbnail(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock() embeddings = MagicMock()
embeddings.search_thumbnail.return_value = [("cand_a", 0.1)] embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
embeddings.thumb_stats = ZScoreNormalization() embeddings.thumb_stats = ZScoreNormalization()
embeddings.thumb_stats._update([0.1, 0.2, 0.3]) embeddings.thumb_stats._update([0.1, 0.2, 0.3])
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
_run(_execute_find_similar_objects( _run(
req, _execute_find_similar_objects(
{"event_id": "anchor", "similarity_mode": "visual"}, req,
allowed_cameras=["driveway"], {"event_id": "anchor", "similarity_mode": "visual"},
)) allowed_cameras=["driveway"],
)
)
embeddings.search_thumbnail.assert_called_once() embeddings.search_thumbnail.assert_called_once()
embeddings.search_description.assert_not_called() embeddings.search_description.assert_not_called()
def test_semantic_mode_only_calls_description(self): def test_semantic_mode_only_calls_description(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock() embeddings = MagicMock()
embeddings.search_description.return_value = [("cand_a", 0.1)] embeddings.search_description.return_value = [("cand_a", 0.1)]
embeddings.desc_stats = ZScoreNormalization() embeddings.desc_stats = ZScoreNormalization()
embeddings.desc_stats._update([0.1, 0.2, 0.3]) embeddings.desc_stats._update([0.1, 0.2, 0.3])
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
_run(_execute_find_similar_objects( _run(
req, _execute_find_similar_objects(
{"event_id": "anchor", "similarity_mode": "semantic"}, req,
allowed_cameras=["driveway"], {"event_id": "anchor", "similarity_mode": "semantic"},
)) allowed_cameras=["driveway"],
)
)
embeddings.search_description.assert_called_once() embeddings.search_description.assert_called_once()
embeddings.search_thumbnail.assert_not_called() embeddings.search_thumbnail.assert_not_called()
def test_min_score_drops_low_scoring_results(self): def test_min_score_drops_low_scoring_results(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock() embeddings = MagicMock()
embeddings.search_thumbnail.return_value = [ embeddings.search_thumbnail.return_value = [
("cand_a", 0.10), ("cand_b", 0.90), ("cand_a", 0.10),
("cand_b", 0.90),
] ]
embeddings.search_description.return_value = [] embeddings.search_description.return_value = []
embeddings.thumb_stats = ZScoreNormalization() embeddings.thumb_stats = ZScoreNormalization()
@ -386,21 +403,23 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
embeddings.desc_stats = ZScoreNormalization() embeddings.desc_stats = ZScoreNormalization()
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects( result = _run(
req, _execute_find_similar_objects(
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6}, req,
allowed_cameras=["driveway"], {"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
)) allowed_cameras=["driveway"],
)
)
ids = [r["id"] for r in result["results"]] ids = [r["id"] for r in result["results"]]
self.assertIn("cand_a", ids) self.assertIn("cand_a", ids)
self.assertNotIn("cand_b", ids) self.assertNotIn("cand_b", ids)
def test_labels_defaults_to_anchor_label(self): def test_labels_defaults_to_anchor_label(self):
from frigate.api.chat import _execute_find_similar_objects
self.make("person_a", label="person") self.make("person_a", label="person")
embeddings = MagicMock() embeddings = MagicMock()
embeddings.search_thumbnail.return_value = [ embeddings.search_thumbnail.return_value = [
("cand_a", 0.1), ("cand_b", 0.2), ("cand_a", 0.1),
("cand_b", 0.2),
] ]
embeddings.search_description.return_value = [] embeddings.search_description.return_value = []
embeddings.thumb_stats = ZScoreNormalization() embeddings.thumb_stats = ZScoreNormalization()
@ -408,11 +427,13 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
embeddings.desc_stats = ZScoreNormalization() embeddings.desc_stats = ZScoreNormalization()
req = self._make_request(embeddings=embeddings) req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects( result = _run(
req, _execute_find_similar_objects(
{"event_id": "anchor", "similarity_mode": "visual"}, req,
allowed_cameras=["driveway"], {"event_id": "anchor", "similarity_mode": "visual"},
)) allowed_cameras=["driveway"],
)
)
ids = [r["id"] for r in result["results"]] ids = [r["id"] for r in result["results"]]
self.assertNotIn("person_a", ids) self.assertNotIn("person_a", ids)