mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-09 08:37:37 +03:00
Add score fusion helpers for find_similar_objects chat tool
This commit is contained in:
parent
5d2a725428
commit
bbd0a8943b
@ -3,6 +3,7 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
@ -26,6 +27,7 @@ from frigate.api.defs.response.chat_response import (
|
||||
)
|
||||
from frigate.api.defs.tags import Tags
|
||||
from frigate.api.event import events
|
||||
from frigate.embeddings.util import ZScoreNormalization
|
||||
from frigate.genai.utils import build_assistant_message_for_conversation
|
||||
from frigate.jobs.vlm_watch import (
|
||||
get_vlm_watch_job,
|
||||
@ -98,6 +100,52 @@ class VLMMonitorRequest(BaseModel):
|
||||
zones: List[str] = []
|
||||
|
||||
|
||||
# Similarity fusion weights for find_similar_objects.
|
||||
# Visual dominates because the feature's primary use case is "same specific object."
|
||||
# If these change, update the test in test_chat_find_similar_objects.py.
|
||||
VISUAL_WEIGHT = 0.65
|
||||
DESCRIPTION_WEIGHT = 0.35
|
||||
|
||||
# Must match or stay <= the k used by EmbeddingsContext vec searches
|
||||
# (see frigate/embeddings/__init__.py search_thumbnail/search_description).
|
||||
# Pre-filtering a larger pool is wasted work — vec will only rank top-k anyway.
|
||||
CANDIDATE_CAP = 100
|
||||
|
||||
|
||||
def _distance_to_score(distance: float, stats: ZScoreNormalization) -> float:
|
||||
"""Convert a cosine distance to a [0, 1] similarity score.
|
||||
|
||||
Uses the existing ZScoreNormalization stats maintained by EmbeddingsContext
|
||||
to normalize across deployments, then a bounded sigmoid. Lower distance ->
|
||||
higher score. If stats are uninitialized (stddev == 0), returns a neutral
|
||||
0.5 so the fallback ordering by raw distance still dominates.
|
||||
"""
|
||||
if stats.stddev == 0:
|
||||
return 0.5
|
||||
z = (distance - stats.mean) / stats.stddev
|
||||
# Sigmoid on -z so that small distance (good) -> high score.
|
||||
return 1.0 / (1.0 + math.exp(z))
|
||||
|
||||
|
||||
def _fuse_scores(
|
||||
visual_score: Optional[float],
|
||||
description_score: Optional[float],
|
||||
) -> Optional[float]:
|
||||
"""Weighted fusion of visual and description similarity scores.
|
||||
|
||||
If one side is missing (e.g., no description embedding for this event),
|
||||
the other side's score is returned alone with no penalty. If both are
|
||||
missing, returns None and the caller should drop the event.
|
||||
"""
|
||||
if visual_score is None and description_score is None:
|
||||
return None
|
||||
if visual_score is None:
|
||||
return description_score
|
||||
if description_score is None:
|
||||
return visual_score
|
||||
return VISUAL_WEIGHT * visual_score + DESCRIPTION_WEIGHT * description_score
|
||||
|
||||
|
||||
def get_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get OpenAI-compatible tool definitions for Frigate.
|
||||
|
||||
57
frigate/test/test_chat_find_similar_objects.py
Normal file
57
frigate/test/test_chat_find_similar_objects.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Tests for the find_similar_objects chat tool."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from frigate.api.chat import (
|
||||
DESCRIPTION_WEIGHT,
|
||||
VISUAL_WEIGHT,
|
||||
_distance_to_score,
|
||||
_fuse_scores,
|
||||
)
|
||||
from frigate.embeddings.util import ZScoreNormalization
|
||||
|
||||
|
||||
class TestDistanceToScore(unittest.TestCase):
|
||||
def test_lower_distance_gives_higher_score(self):
|
||||
stats = ZScoreNormalization()
|
||||
# Seed the stats with a small distribution so stddev > 0.
|
||||
stats._update([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
|
||||
close_score = _distance_to_score(0.1, stats)
|
||||
far_score = _distance_to_score(0.5, stats)
|
||||
|
||||
self.assertGreater(close_score, far_score)
|
||||
self.assertGreaterEqual(close_score, 0.0)
|
||||
self.assertLessEqual(close_score, 1.0)
|
||||
self.assertGreaterEqual(far_score, 0.0)
|
||||
self.assertLessEqual(far_score, 1.0)
|
||||
|
||||
def test_uninitialized_stats_returns_neutral_score(self):
|
||||
stats = ZScoreNormalization() # n == 0, stddev == 0
|
||||
self.assertEqual(_distance_to_score(0.3, stats), 0.5)
|
||||
|
||||
|
||||
class TestFuseScores(unittest.TestCase):
|
||||
def test_weights_sum_to_one(self):
|
||||
self.assertAlmostEqual(VISUAL_WEIGHT + DESCRIPTION_WEIGHT, 1.0)
|
||||
|
||||
def test_fuses_both_sides(self):
|
||||
fused = _fuse_scores(visual_score=0.8, description_score=0.4)
|
||||
expected = VISUAL_WEIGHT * 0.8 + DESCRIPTION_WEIGHT * 0.4
|
||||
self.assertAlmostEqual(fused, expected)
|
||||
|
||||
def test_missing_description_uses_visual_only(self):
|
||||
fused = _fuse_scores(visual_score=0.7, description_score=None)
|
||||
self.assertAlmostEqual(fused, 0.7)
|
||||
|
||||
def test_missing_visual_uses_description_only(self):
|
||||
fused = _fuse_scores(visual_score=None, description_score=0.6)
|
||||
self.assertAlmostEqual(fused, 0.6)
|
||||
|
||||
def test_both_missing_returns_none(self):
|
||||
self.assertIsNone(_fuse_scores(visual_score=None, description_score=None))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user