From a8b5aebacbddce6a073f9e5ee7c9e1af780be428 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:03:18 -0500 Subject: [PATCH] Only apply normalization to multi modal searches --- frigate/api/event.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/frigate/api/event.py b/frigate/api/event.py index 3a8d003ad..b1c98bb19 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) ) thumb_result = context.search_thumbnail(search_event) - thumb_ids = dict( - zip( - [result[0] for result in thumb_result], - context.thumb_stats.normalize([result[1] for result in thumb_result]), - ) - ) + thumb_ids = {result[0]: result[1] for result in thumb_result} search_results = { event_id: {"distance": distance, "source": "thumbnail"} for event_id, distance in thumb_ids.items() @@ -486,15 +481,23 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) else: search_types = search_type.split(",") + # only normalize multi-modal searches + apply_normalization = ( + "thumbnail" in search_types and "description" in search_types + ) + if "thumbnail" in search_types: thumb_result = context.search_thumbnail(query) - thumb_ids = dict( - zip( - [result[0] for result in thumb_result], - context.thumb_stats.normalize( - [result[1] for result in thumb_result] - ), + + if apply_normalization: + thumb_distances = context.thumb_stats.normalize( + [result[1] for result in thumb_result] ) + else: + thumb_distances = [result[1] for result in thumb_result] + + thumb_ids = dict( + zip([result[0] for result in thumb_result], thumb_distances) ) search_results.update( { @@ -505,12 +508,17 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) if "description" in search_types: desc_result = context.search_description(query) - desc_ids = dict( - zip( - [result[0] for result in desc_result], - context.desc_stats.normalize([result[1] for result in desc_result]), + if apply_normalization: + desc_distances = context.desc_stats.normalize( + [result[1] for result in desc_result] ) - ) + else: + desc_distances = [ + result[1] for result in desc_result + ] # Use raw distances + + desc_ids = dict(zip([result[0] for result in desc_result], desc_distances)) + for event_id, distance in desc_ids.items(): if ( event_id not in search_results