mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-09 08:37:37 +03:00
formatting
This commit is contained in:
parent
30997c20d7
commit
635e41fc8c
@ -30,13 +30,13 @@ from frigate.api.defs.response.chat_response import (
|
|||||||
from frigate.api.defs.tags import Tags
|
from frigate.api.defs.tags import Tags
|
||||||
from frigate.api.event import events
|
from frigate.api.event import events
|
||||||
from frigate.embeddings.util import ZScoreNormalization
|
from frigate.embeddings.util import ZScoreNormalization
|
||||||
from frigate.models import Event
|
|
||||||
from frigate.genai.utils import build_assistant_message_for_conversation
|
from frigate.genai.utils import build_assistant_message_for_conversation
|
||||||
from frigate.jobs.vlm_watch import (
|
from frigate.jobs.vlm_watch import (
|
||||||
get_vlm_watch_job,
|
get_vlm_watch_job,
|
||||||
start_vlm_watch_job,
|
start_vlm_watch_job,
|
||||||
stop_vlm_watch_job,
|
stop_vlm_watch_job,
|
||||||
)
|
)
|
||||||
|
from frigate.models import Event
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,30 @@
|
|||||||
"""Tests for the find_similar_objects chat tool."""
|
"""Tests for the find_similar_objects chat tool."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
|
|
||||||
from frigate.api.chat import (
|
from frigate.api.chat import (
|
||||||
|
CANDIDATE_CAP,
|
||||||
DESCRIPTION_WEIGHT,
|
DESCRIPTION_WEIGHT,
|
||||||
VISUAL_WEIGHT,
|
VISUAL_WEIGHT,
|
||||||
|
_build_similar_candidates_query,
|
||||||
_distance_to_score,
|
_distance_to_score,
|
||||||
|
_execute_find_similar_objects,
|
||||||
_fuse_scores,
|
_fuse_scores,
|
||||||
|
get_tool_definitions,
|
||||||
)
|
)
|
||||||
from frigate.embeddings.util import ZScoreNormalization
|
from frigate.embeddings.util import ZScoreNormalization
|
||||||
|
from frigate.models import Event
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.new_event_loop().run_until_complete(coro)
|
||||||
|
|
||||||
|
|
||||||
class TestDistanceToScore(unittest.TestCase):
|
class TestDistanceToScore(unittest.TestCase):
|
||||||
@ -53,16 +68,6 @@ class TestFuseScores(unittest.TestCase):
|
|||||||
self.assertIsNone(_fuse_scores(visual_score=None, description_score=None))
|
self.assertIsNone(_fuse_scores(visual_score=None, description_score=None))
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
|
||||||
|
|
||||||
from frigate.api.chat import CANDIDATE_CAP, _build_similar_candidates_query
|
|
||||||
from frigate.models import Event
|
|
||||||
|
|
||||||
|
|
||||||
class TestBuildSimilarCandidatesQuery(unittest.TestCase):
|
class TestBuildSimilarCandidatesQuery(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||||
@ -202,9 +207,6 @@ class TestBuildSimilarCandidatesQuery(unittest.TestCase):
|
|||||||
self.assertNotIn("e0000", ids)
|
self.assertNotIn("e0000", ids)
|
||||||
|
|
||||||
|
|
||||||
from frigate.api.chat import get_tool_definitions
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolDefinition(unittest.TestCase):
|
class TestToolDefinition(unittest.TestCase):
|
||||||
def test_find_similar_objects_is_registered(self):
|
def test_find_similar_objects_is_registered(self):
|
||||||
tools = get_tool_definitions()
|
tools = get_tool_definitions()
|
||||||
@ -213,9 +215,7 @@ class TestToolDefinition(unittest.TestCase):
|
|||||||
|
|
||||||
def test_find_similar_objects_schema(self):
|
def test_find_similar_objects_schema(self):
|
||||||
tools = get_tool_definitions()
|
tools = get_tool_definitions()
|
||||||
tool = next(
|
tool = next(t for t in tools if t["function"]["name"] == "find_similar_objects")
|
||||||
t for t in tools if t["function"]["name"] == "find_similar_objects"
|
|
||||||
)
|
|
||||||
params = tool["function"]["parameters"]["properties"]
|
params = tool["function"]["parameters"]["properties"]
|
||||||
self.assertIn("event_id", params)
|
self.assertIn("event_id", params)
|
||||||
self.assertIn("after", params)
|
self.assertIn("after", params)
|
||||||
@ -227,23 +227,12 @@ class TestToolDefinition(unittest.TestCase):
|
|||||||
self.assertIn("similarity_mode", params)
|
self.assertIn("similarity_mode", params)
|
||||||
self.assertIn("min_score", params)
|
self.assertIn("min_score", params)
|
||||||
self.assertIn("limit", params)
|
self.assertIn("limit", params)
|
||||||
self.assertEqual(
|
self.assertEqual(tool["function"]["parameters"]["required"], ["event_id"])
|
||||||
tool["function"]["parameters"]["required"], ["event_id"]
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
params["similarity_mode"]["enum"], ["visual", "semantic", "fused"]
|
params["similarity_mode"]["enum"], ["visual", "semantic", "fused"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
|
|
||||||
def _run(coro):
|
|
||||||
return asyncio.new_event_loop().run_until_complete(coro)
|
|
||||||
|
|
||||||
|
|
||||||
class TestExecuteFindSimilarObjects(unittest.TestCase):
|
class TestExecuteFindSimilarObjects(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||||
@ -256,15 +245,29 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
|
|||||||
# Insert an anchor plus two candidates.
|
# Insert an anchor plus two candidates.
|
||||||
def make(event_id, label="car", camera="driveway", start=1_700_000_100):
|
def make(event_id, label="car", camera="driveway", start=1_700_000_100):
|
||||||
Event.create(
|
Event.create(
|
||||||
id=event_id, label=label, sub_label=None, camera=camera,
|
id=event_id,
|
||||||
start_time=start, end_time=start + 10,
|
label=label,
|
||||||
top_score=0.9, score=0.9, false_positive=False,
|
sub_label=None,
|
||||||
zones=[], thumbnail="",
|
camera=camera,
|
||||||
has_clip=True, has_snapshot=True,
|
start_time=start,
|
||||||
region=[0, 0, 1, 1], box=[0, 0, 1, 1], area=1,
|
end_time=start + 10,
|
||||||
retain_indefinitely=False, ratio=1.0,
|
top_score=0.9,
|
||||||
plus_id="", model_hash="", detector_type="",
|
score=0.9,
|
||||||
model_type="", data={"description": "a green sedan"},
|
false_positive=False,
|
||||||
|
zones=[],
|
||||||
|
thumbnail="",
|
||||||
|
has_clip=True,
|
||||||
|
has_snapshot=True,
|
||||||
|
region=[0, 0, 1, 1],
|
||||||
|
box=[0, 0, 1, 1],
|
||||||
|
area=1,
|
||||||
|
retain_indefinitely=False,
|
||||||
|
ratio=1.0,
|
||||||
|
plus_id="",
|
||||||
|
model_hash="",
|
||||||
|
detector_type="",
|
||||||
|
model_type="",
|
||||||
|
data={"description": "a green sedan"},
|
||||||
)
|
)
|
||||||
|
|
||||||
make("anchor", start=1_700_000_200)
|
make("anchor", start=1_700_000_200)
|
||||||
@ -286,45 +289,53 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
|
|||||||
return SimpleNamespace(app=app)
|
return SimpleNamespace(app=app)
|
||||||
|
|
||||||
def test_semantic_search_disabled_returns_error(self):
|
def test_semantic_search_disabled_returns_error(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
req = self._make_request(semantic_enabled=False)
|
req = self._make_request(semantic_enabled=False)
|
||||||
result = _run(_execute_find_similar_objects(
|
result = _run(
|
||||||
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
|
_execute_find_similar_objects(
|
||||||
))
|
req,
|
||||||
|
{"event_id": "anchor"},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
)
|
||||||
|
)
|
||||||
self.assertEqual(result["error"], "semantic_search_disabled")
|
self.assertEqual(result["error"], "semantic_search_disabled")
|
||||||
|
|
||||||
def test_anchor_not_found_returns_error(self):
|
def test_anchor_not_found_returns_error(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
result = _run(_execute_find_similar_objects(
|
result = _run(
|
||||||
req, {"event_id": "nope"}, allowed_cameras=["driveway"],
|
_execute_find_similar_objects(
|
||||||
))
|
req,
|
||||||
|
{"event_id": "nope"},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
)
|
||||||
|
)
|
||||||
self.assertEqual(result["error"], "anchor_not_found")
|
self.assertEqual(result["error"], "anchor_not_found")
|
||||||
|
|
||||||
def test_empty_candidates_returns_empty_results(self):
|
def test_empty_candidates_returns_empty_results(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
# Filter to a camera with no other events.
|
# Filter to a camera with no other events.
|
||||||
result = _run(_execute_find_similar_objects(
|
result = _run(
|
||||||
|
_execute_find_similar_objects(
|
||||||
req,
|
req,
|
||||||
{"event_id": "anchor", "cameras": ["nonexistent_cam"]},
|
{"event_id": "anchor", "cameras": ["nonexistent_cam"]},
|
||||||
allowed_cameras=["nonexistent_cam"],
|
allowed_cameras=["nonexistent_cam"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
self.assertEqual(result["results"], [])
|
self.assertEqual(result["results"], [])
|
||||||
self.assertFalse(result["candidate_truncated"])
|
self.assertFalse(result["candidate_truncated"])
|
||||||
self.assertEqual(result["anchor"]["id"], "anchor")
|
self.assertEqual(result["anchor"]["id"], "anchor")
|
||||||
|
|
||||||
def test_fused_calls_both_searches_and_ranks(self):
|
def test_fused_calls_both_searches_and_ranks(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
# cand_a visually closer, cand_b semantically closer.
|
# cand_a visually closer, cand_b semantically closer.
|
||||||
embeddings.search_thumbnail.return_value = [
|
embeddings.search_thumbnail.return_value = [
|
||||||
("cand_a", 0.10), ("cand_b", 0.40),
|
("cand_a", 0.10),
|
||||||
|
("cand_b", 0.40),
|
||||||
]
|
]
|
||||||
embeddings.search_description.return_value = [
|
embeddings.search_description.return_value = [
|
||||||
("cand_a", 0.50), ("cand_b", 0.20),
|
("cand_a", 0.50),
|
||||||
|
("cand_b", 0.20),
|
||||||
]
|
]
|
||||||
embeddings.thumb_stats = ZScoreNormalization()
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
embeddings.thumb_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
@ -332,9 +343,13 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
|
|||||||
embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
embeddings.desc_stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||||
|
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
result = _run(_execute_find_similar_objects(
|
result = _run(
|
||||||
req, {"event_id": "anchor"}, allowed_cameras=["driveway"],
|
_execute_find_similar_objects(
|
||||||
))
|
req,
|
||||||
|
{"event_id": "anchor"},
|
||||||
|
allowed_cameras=["driveway"],
|
||||||
|
)
|
||||||
|
)
|
||||||
embeddings.search_thumbnail.assert_called_once()
|
embeddings.search_thumbnail.assert_called_once()
|
||||||
embeddings.search_description.assert_called_once()
|
embeddings.search_description.assert_called_once()
|
||||||
# cand_a should rank first because visual is weighted higher.
|
# cand_a should rank first because visual is weighted higher.
|
||||||
@ -343,42 +358,44 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
|
|||||||
self.assertEqual(result["similarity_mode"], "fused")
|
self.assertEqual(result["similarity_mode"], "fused")
|
||||||
|
|
||||||
def test_visual_mode_only_calls_thumbnail(self):
|
def test_visual_mode_only_calls_thumbnail(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
|
embeddings.search_thumbnail.return_value = [("cand_a", 0.1)]
|
||||||
embeddings.thumb_stats = ZScoreNormalization()
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
|
embeddings.thumb_stats._update([0.1, 0.2, 0.3])
|
||||||
|
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
_run(_execute_find_similar_objects(
|
_run(
|
||||||
|
_execute_find_similar_objects(
|
||||||
req,
|
req,
|
||||||
{"event_id": "anchor", "similarity_mode": "visual"},
|
{"event_id": "anchor", "similarity_mode": "visual"},
|
||||||
allowed_cameras=["driveway"],
|
allowed_cameras=["driveway"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
embeddings.search_thumbnail.assert_called_once()
|
embeddings.search_thumbnail.assert_called_once()
|
||||||
embeddings.search_description.assert_not_called()
|
embeddings.search_description.assert_not_called()
|
||||||
|
|
||||||
def test_semantic_mode_only_calls_description(self):
|
def test_semantic_mode_only_calls_description(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
embeddings.search_description.return_value = [("cand_a", 0.1)]
|
embeddings.search_description.return_value = [("cand_a", 0.1)]
|
||||||
embeddings.desc_stats = ZScoreNormalization()
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
embeddings.desc_stats._update([0.1, 0.2, 0.3])
|
embeddings.desc_stats._update([0.1, 0.2, 0.3])
|
||||||
|
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
_run(_execute_find_similar_objects(
|
_run(
|
||||||
|
_execute_find_similar_objects(
|
||||||
req,
|
req,
|
||||||
{"event_id": "anchor", "similarity_mode": "semantic"},
|
{"event_id": "anchor", "similarity_mode": "semantic"},
|
||||||
allowed_cameras=["driveway"],
|
allowed_cameras=["driveway"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
embeddings.search_description.assert_called_once()
|
embeddings.search_description.assert_called_once()
|
||||||
embeddings.search_thumbnail.assert_not_called()
|
embeddings.search_thumbnail.assert_not_called()
|
||||||
|
|
||||||
def test_min_score_drops_low_scoring_results(self):
|
def test_min_score_drops_low_scoring_results(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
embeddings.search_thumbnail.return_value = [
|
embeddings.search_thumbnail.return_value = [
|
||||||
("cand_a", 0.10), ("cand_b", 0.90),
|
("cand_a", 0.10),
|
||||||
|
("cand_b", 0.90),
|
||||||
]
|
]
|
||||||
embeddings.search_description.return_value = []
|
embeddings.search_description.return_value = []
|
||||||
embeddings.thumb_stats = ZScoreNormalization()
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
@ -386,21 +403,23 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
|
|||||||
embeddings.desc_stats = ZScoreNormalization()
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
|
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
result = _run(_execute_find_similar_objects(
|
result = _run(
|
||||||
|
_execute_find_similar_objects(
|
||||||
req,
|
req,
|
||||||
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
|
{"event_id": "anchor", "similarity_mode": "visual", "min_score": 0.6},
|
||||||
allowed_cameras=["driveway"],
|
allowed_cameras=["driveway"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
ids = [r["id"] for r in result["results"]]
|
ids = [r["id"] for r in result["results"]]
|
||||||
self.assertIn("cand_a", ids)
|
self.assertIn("cand_a", ids)
|
||||||
self.assertNotIn("cand_b", ids)
|
self.assertNotIn("cand_b", ids)
|
||||||
|
|
||||||
def test_labels_defaults_to_anchor_label(self):
|
def test_labels_defaults_to_anchor_label(self):
|
||||||
from frigate.api.chat import _execute_find_similar_objects
|
|
||||||
self.make("person_a", label="person")
|
self.make("person_a", label="person")
|
||||||
embeddings = MagicMock()
|
embeddings = MagicMock()
|
||||||
embeddings.search_thumbnail.return_value = [
|
embeddings.search_thumbnail.return_value = [
|
||||||
("cand_a", 0.1), ("cand_b", 0.2),
|
("cand_a", 0.1),
|
||||||
|
("cand_b", 0.2),
|
||||||
]
|
]
|
||||||
embeddings.search_description.return_value = []
|
embeddings.search_description.return_value = []
|
||||||
embeddings.thumb_stats = ZScoreNormalization()
|
embeddings.thumb_stats = ZScoreNormalization()
|
||||||
@ -408,11 +427,13 @@ class TestExecuteFindSimilarObjects(unittest.TestCase):
|
|||||||
embeddings.desc_stats = ZScoreNormalization()
|
embeddings.desc_stats = ZScoreNormalization()
|
||||||
|
|
||||||
req = self._make_request(embeddings=embeddings)
|
req = self._make_request(embeddings=embeddings)
|
||||||
result = _run(_execute_find_similar_objects(
|
result = _run(
|
||||||
|
_execute_find_similar_objects(
|
||||||
req,
|
req,
|
||||||
{"event_id": "anchor", "similarity_mode": "visual"},
|
{"event_id": "anchor", "similarity_mode": "visual"},
|
||||||
allowed_cameras=["driveway"],
|
allowed_cameras=["driveway"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
ids = [r["id"] for r in result["results"]]
|
ids = [r["id"] for r in result["results"]]
|
||||||
self.assertNotIn("person_a", ids)
|
self.assertNotIn("person_a", ids)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user