formatting

This commit is contained in:
Josh Hawkins 2026-04-08 15:31:59 -05:00
parent 30997c20d7
commit 635e41fc8c
2 changed files with 105 additions and 84 deletions

View File

@ -30,13 +30,13 @@ from frigate.api.defs.response.chat_response import (
from frigate.api.defs.tags import Tags
from frigate.api.event import events
from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
from frigate.genai.utils import build_assistant_message_for_conversation
from frigate.jobs.vlm_watch import (
get_vlm_watch_job,
start_vlm_watch_job,
stop_vlm_watch_job,
)
from frigate.models import Event
logger = logging.getLogger(__name__)

View File

@ -1,15 +1,30 @@
"""Tests for the find_similar_objects chat tool."""
import asyncio
import os
import tempfile
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock
from playhouse.sqlite_ext import SqliteExtDatabase
from frigate.api.chat import (
CANDIDATE_CAP,
DESCRIPTION_WEIGHT,
VISUAL_WEIGHT,
_build_similar_candidates_query,
_distance_to_score,
_execute_find_similar_objects,
_fuse_scores,
get_tool_definitions,
)
from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
def _run(coro):
return asyncio.new_event_loop().run_until_complete(coro)
class TestDistanceToScore(unittest.TestCase):
@ -53,16 +68,6 @@ class TestFuseScores(unittest.TestCase):
self.assertIsNone(_fuse_scores(visual_score=None, description_score=None))
import datetime
import os
import tempfile
from playhouse.sqlite_ext import SqliteExtDatabase
from frigate.api.chat import CANDIDATE_CAP, _build_similar_candidates_query
from frigate.models import Event
class TestBuildSimilarCandidatesQuery(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
@ -202,9 +207,6 @@ class TestBuildSimilarCandidatesQuery(unittest.TestCase):
self.assertNotIn("e0000", ids)
from frigate.api.chat import get_tool_definitions
class TestToolDefinition(unittest.TestCase):
def test_find_similar_objects_is_registered(self):
tools = get_tool_definitions()
@ -213,9 +215,7 @@ class TestToolDefinition(unittest.TestCase):
def test_find_similar_objects_schema(self):
tools = get_tool_definitions()
tool = next(
t for t in tools if t["function"]["name"] == "find_similar_objects"
)
tool = next(t for t in tools if t["function"]["name"] == "find_similar_objects")
params = tool["function"]["parameters"]["properties"]
self.assertIn("event_id", params)
self.assertIn("after", params)
@ -227,23 +227,12 @@ class TestToolDefinition(unittest.TestCase):
self.assertIn("similarity_mode", params)
self.assertIn("min_score", params)
self.assertIn("limit", params)
self.assertEqual(
tool["function"]["parameters"]["required"], ["event_id"]
)
self.assertEqual(tool["function"]["parameters"]["required"], ["event_id"])
self.assertEqual(
params["similarity_mode"]["enum"], ["visual", "semantic", "fused"]
)
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
def _run(coro):
return asyncio.new_event_loop().run_until_complete(coro)
class TestExecuteFindSimilarObjects(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
@ -256,15 +245,29 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
# Insert an anchor plus two candidates.
def make(event_id, label="car", camera="driveway", start=1_700_000_100):
Event.create(
id=event_id, label=label, sub_label=None, camera=camera,
start_time=start, end_time=start + 10,
top_score=0.9, score=0.9, false_positive=False,
zones=[], thumbnail="",
has_clip=True, has_snapshot=True,
region=[0, 0, 1, 1], box=[0, 0, 1, 1], area=1,
retain_indefinitely=False, ratio=1.0,
plus_id="", model_hash="", detector_type="",
model_type="", data={"description": "a green sedan"},
id=event_id,
label=label,
sub_label=None,
camera=camera,
start_time=start,
end_time=start + 10,
top_score=0.9,
score=0.9,
false_positive=False,
zones=[],
thumbnail="",
has_clip=True,
has_snapshot=True,
region=[0, 0, 1, 1],
box=[0, 0, 1, 1],
area=1,
retain_indefinitely=False,
ratio=1.0,
plus_id="",
model_hash="",
detector_type="",
model_type="",
data={"description": "a green sedan"},
)
make("anchor", start=1_700_000_200)
@ -286,45 +289,53 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
return SimpleNamespace(app=app)
def test_semantic_search_disabled_returns_error(self):
from frigate.api.chat import _execute_find_similar_objects
req = self._make_request(semantic_enabled=False)
result = _run(_execute_find_similar_objects(
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
))
result = _run(
_execute_find_similar_objects(
req,
{"event_id": "anchor"},
allowed_cameras=["driveway"],
)
)
self.assertEqual(result["error"], "semantic_search_disabled")
def test_anchor_not_found_returns_error(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock()
req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects(
req, {"event_id": "nope"}, allowed_cameras=["driveway"],
))
result = _run(
_execute_find_similar_objects(
req,
{"event_id": "nope"},
allowed_cameras=["driveway"],
)
)
self.assertEqual(result["error"], "anchor_not_found")
def test_empty_candidates_returns_empty_results(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock()
req = self._make_request(embeddings=embeddings)
# Filter to a camera with no other events.
result = _run(_execute_find_similar_objects(
req,
{"event_id": "anchor", "cameras": ["nonexistent_cam"]},
allowed_cameras=["nonexistent_cam"],
))
result = _run(
_execute_find_similar_objects(
req,
{"event_id": "anchor", "cameras": ["nonexistent_cam"]},
allowed_cameras=["nonexistent_cam"],
)
)
self.assertEqual(result["results"], [])
self.assertFalse(result["candidate_truncated"])
self.assertEqual(result["anchor"]["id"], "anchor")
def test_fused_calls_both_searches_and_ranks(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock()
# cand_a visually closer, cand_b semantically closer.
embeddings.search_thumbnail.return_value = [
("cand_a", 0.10), ("cand_b", 0.40),
("cand_a", 0.10),
("cand_b", 0.40),
]
embeddings.search_description.return_value = [
("cand_a", 0.50), ("cand_b", 0.20),
("cand_a", 0.50),
("cand_b", 0.20),
]
embeddings.thumb_stats = ZScoreNormalization()
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
@ -332,9 +343,13 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects(
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
))
result = _run(
_execute_find_similar_objects(
req,
{"event_id": "anchor"},
allowed_cameras=["driveway"],
)
)
embeddings.search_thumbnail.assert_called_once()
embeddings.search_description.assert_called_once()
# cand_a should rank first because visual is weighted higher.
@ -343,42 +358,44 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
self.assertEqual(result["similarity_mode"], "fused")
def test_visual_mode_only_calls_thumbnail(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock()
embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
embeddings.thumb_stats = ZScoreNormalization()
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
req = self._make_request(embeddings=embeddings)
_run(_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "visual"},
allowed_cameras=["driveway"],
))
_run(
_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "visual"},
allowed_cameras=["driveway"],
)
)
embeddings.search_thumbnail.assert_called_once()
embeddings.search_description.assert_not_called()
def test_semantic_mode_only_calls_description(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock()
embeddings.search_description.return_value = [("cand_a", 0.1)]
embeddings.desc_stats = ZScoreNormalization()
embeddings.desc_stats._update([0.1, 0.2, 0.3])
req = self._make_request(embeddings=embeddings)
_run(_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "semantic"},
allowed_cameras=["driveway"],
))
_run(
_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "semantic"},
allowed_cameras=["driveway"],
)
)
embeddings.search_description.assert_called_once()
embeddings.search_thumbnail.assert_not_called()
def test_min_score_drops_low_scoring_results(self):
from frigate.api.chat import _execute_find_similar_objects
embeddings = MagicMock()
embeddings.search_thumbnail.return_value = [
("cand_a", 0.10), ("cand_b", 0.90),
("cand_a", 0.10),
("cand_b", 0.90),
]
embeddings.search_description.return_value = []
embeddings.thumb_stats = ZScoreNormalization()
@ -386,21 +403,23 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
embeddings.desc_stats = ZScoreNormalization()
req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
allowed_cameras=["driveway"],
))
result = _run(
_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
allowed_cameras=["driveway"],
)
)
ids = [r["id"] for r in result["results"]]
self.assertIn("cand_a", ids)
self.assertNotIn("cand_b", ids)
def test_labels_defaults_to_anchor_label(self):
from frigate.api.chat import _execute_find_similar_objects
self.make("person_a", label="person")
embeddings = MagicMock()
embeddings.search_thumbnail.return_value = [
("cand_a", 0.1), ("cand_b", 0.2),
("cand_a", 0.1),
("cand_b", 0.2),
]
embeddings.search_description.return_value = []
embeddings.thumb_stats = ZScoreNormalization()
@ -408,11 +427,13 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
embeddings.desc_stats = ZScoreNormalization()
req = self._make_request(embeddings=embeddings)
result = _run(_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "visual"},
allowed_cameras=["driveway"],
))
result = _run(
_execute_find_similar_objects(
req,
{"event_id": "anchor", "similarity_mode": "visual"},
allowed_cameras=["driveway"],
)
)
ids = [r["id"] for r in result["results"]]
self.assertNotIn("person_a", ids)