Add candidate query builder for find_similar_objects chat tool

This commit is contained in:
Josh Hawkins 2026-04-08 15:11:04 -05:00
parent bbd0a8943b
commit 8759d5210e
2 changed files with 194 additions and 0 deletions

View File

@ -4,8 +4,10 @@ import base64
import json import json
import logging import logging
import math import math
import operator
import time import time
from datetime import datetime from datetime import datetime
from functools import reduce
from typing import Any, Dict, Generator, List, Optional from typing import Any, Dict, Generator, List, Optional
import cv2 import cv2
@ -28,6 +30,7 @@ from frigate.api.defs.response.chat_response import (
from frigate.api.defs.tags import Tags from frigate.api.defs.tags import Tags
from frigate.api.event import events from frigate.api.event import events
from frigate.embeddings.util import ZScoreNormalization from frigate.embeddings.util import ZScoreNormalization
from frigate.models import Event
from frigate.genai.utils import build_assistant_message_for_conversation from frigate.genai.utils import build_assistant_message_for_conversation
from frigate.jobs.vlm_watch import ( from frigate.jobs.vlm_watch import (
get_vlm_watch_job, get_vlm_watch_job,
@ -146,6 +149,48 @@ def _fuse_scores(
return VISUAL_WEIGHT * visual_score + DESCRIPTION_WEIGHT * description_score return VISUAL_WEIGHT * visual_score + DESCRIPTION_WEIGHT * description_score
def _build_similar_candidates_query(
anchor_id: str,
after: Optional[float],
before: Optional[float],
cameras: Optional[List[str]],
labels: Optional[List[str]],
sub_labels: Optional[List[str]],
zones: Optional[List[str]],
) -> List[str]:
"""Return up to CANDIDATE_CAP event ids eligible as similarity candidates.
Pre-filters events by the structured fields, excludes the anchor itself,
and orders by most recent first so over-cap queries keep recent events.
"""
clauses = [Event.id != anchor_id]
if after is not None:
clauses.append(Event.start_time >= after)
if before is not None:
clauses.append(Event.start_time <= before)
if cameras:
clauses.append(Event.camera.in_(cameras))
if labels:
clauses.append(Event.label.in_(labels))
if sub_labels:
clauses.append(Event.sub_label.in_(sub_labels))
if zones:
# Mirror the pattern used by frigate/api/event.py for JSON-array zone match.
zone_clauses = [
Event.zones.cast("text") % f'*"{zone}"*' for zone in zones
]
clauses.append(reduce(operator.or_, zone_clauses))
query = (
Event.select(Event.id)
.where(reduce(operator.and_, clauses))
.order_by(Event.start_time.desc())
.limit(CANDIDATE_CAP)
)
return [row.id for row in query]
def get_tool_definitions() -> List[Dict[str, Any]]: def get_tool_definitions() -> List[Dict[str, Any]]:
""" """
Get OpenAI-compatible tool definitions for Frigate. Get OpenAI-compatible tool definitions for Frigate.

View File

@ -53,5 +53,154 @@ class TestFuseScores(unittest.TestCase):
self.assertIsNone(_fuse_scores(visual_score=None, description_score=None)) self.assertIsNone(_fuse_scores(visual_score=None, description_score=None))
import datetime
import os
import tempfile
from playhouse.sqlite_ext import SqliteExtDatabase
from frigate.api.chat import CANDIDATE_CAP, _build_similar_candidates_query
from frigate.models import Event
class TestBuildSimilarCandidatesQuery(unittest.TestCase):
def setUp(self):
self.tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
self.tmp.close()
self.db = SqliteExtDatabase(self.tmp.name)
Event.bind(self.db, bind_refs=False, bind_backrefs=False)
self.db.connect()
self.db.create_tables([Event])
# Minimal helper for creating events.
def make_event(
event_id,
camera="driveway",
label="car",
sub_label=None,
start=1_700_000_000,
zones=None,
):
Event.create(
id=event_id,
label=label,
sub_label=sub_label,
camera=camera,
start_time=start,
end_time=start + 10,
top_score=0.9,
score=0.9,
false_positive=False,
zones=zones or [],
thumbnail="",
has_clip=True,
has_snapshot=True,
region=[0, 0, 1, 1],
box=[0, 0, 1, 1],
area=1,
retain_indefinitely=False,
ratio=1.0,
plus_id="",
model_hash="",
detector_type="",
model_type="",
data={},
)
self.make_event = make_event
def tearDown(self):
self.db.close()
os.unlink(self.tmp.name)
def test_excludes_anchor(self):
self.make_event("anchor")
self.make_event("other")
ids = _build_similar_candidates_query(
anchor_id="anchor",
after=None,
before=None,
cameras=None,
labels=["car"],
sub_labels=None,
zones=None,
)
self.assertEqual(ids, ["other"])
def test_time_range_filters(self):
self.make_event("in_range", start=1_700_000_500)
self.make_event("too_early", start=1_699_999_000)
self.make_event("too_late", start=1_700_001_000)
ids = _build_similar_candidates_query(
anchor_id="nonexistent",
after=1_700_000_000,
before=1_700_000_999,
cameras=None,
labels=["car"],
sub_labels=None,
zones=None,
)
self.assertEqual(ids, ["in_range"])
def test_camera_filter(self):
self.make_event("driveway_a", camera="driveway")
self.make_event("porch_a", camera="porch")
ids = _build_similar_candidates_query(
anchor_id="nonexistent",
after=None,
before=None,
cameras=["driveway"],
labels=["car"],
sub_labels=None,
zones=None,
)
self.assertEqual(ids, ["driveway_a"])
def test_label_filter(self):
self.make_event("car_a", label="car")
self.make_event("person_a", label="person")
ids = _build_similar_candidates_query(
anchor_id="nonexistent",
after=None,
before=None,
cameras=None,
labels=["car"],
sub_labels=None,
zones=None,
)
self.assertEqual(ids, ["car_a"])
def test_zone_any_match(self):
self.make_event("in_zone", zones=["driveway_zone"])
self.make_event("other_zone", zones=["porch_zone"])
ids = _build_similar_candidates_query(
anchor_id="nonexistent",
after=None,
before=None,
cameras=None,
labels=["car"],
sub_labels=None,
zones=["driveway_zone"],
)
self.assertEqual(ids, ["in_zone"])
def test_respects_candidate_cap(self):
for i in range(CANDIDATE_CAP + 20):
self.make_event(f"e{i:04d}", start=1_700_000_000 + i)
ids = _build_similar_candidates_query(
anchor_id="nonexistent",
after=None,
before=None,
cameras=None,
labels=["car"],
sub_labels=None,
zones=None,
)
self.assertEqual(len(ids), CANDIDATE_CAP)
# Most recent first, so we should keep the latest CANDIDATE_CAP events.
self.assertIn(f"e{CANDIDATE_CAP + 19:04d}", ids)
self.assertNotIn("e0000", ids)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()