diff --git a/frigate/api/event.py b/frigate/api/event.py index f0f846b00..a22502849 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -479,42 +479,75 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) else: event_ids = [] - thumb_results = [] - desc_results = [] + thumb_ids = {} + desc_ids = {} if search_type == "similarity": try: search_event: Event = Event.get(Event.id == event_id) except DoesNotExist: return JSONResponse( - content=( - { - "success": False, - "message": "Event not found", - } - ), + content={ + "success": False, + "message": "Event not found", + }, status_code=404, ) - thumb_results = context.embeddings.search_thumbnail(search_event, limit) + + # Get thumbnail results for the specific event + thumb_result = context.embeddings.search_thumbnail( + search_event, event_ids, limit + ) + + thumb_ids = dict( + zip( + [result[0] for result in thumb_result], + context.thumb_stats.normalize([result[1] for result in thumb_result]), + ) + ) else: search_types = search_type.split(",") if "thumbnail" in search_types: - thumb_results = context.embeddings.search_thumbnail(query, limit) - logger.info(f"thumb results: {thumb_results}") + thumb_result = context.embeddings.search_thumbnail(query, event_ids, limit) + + thumb_ids = dict( + zip( + [result[0] for result in thumb_result], + context.thumb_stats.normalize( + [result[1] for result in thumb_result] + ), + ) + ) if "description" in search_types: - desc_results = context.embeddings.search_description(query, limit) + desc_result = context.embeddings.search_description(query, event_ids, limit) + + desc_ids = dict( + zip( + [result[0] for result in desc_result], + context.desc_stats.normalize([result[1] for result in desc_result]), + ) + ) results = {} - for result in thumb_results + desc_results: - event_id, distance = result[0], result[1] - if event_id in event_ids or not event_ids: - if event_id not in results or distance < results[event_id]["distance"]: - results[event_id] = { - "distance": distance, - "source": "thumbnail" if result in thumb_results else "description", - } + for event_id in thumb_ids.keys() | desc_ids.keys(): + thumb_distance = thumb_ids.get(event_id) + desc_distance = desc_ids.get(event_id) + + # Select the minimum distance from the available results + if thumb_distance is not None and ( + desc_distance is None or thumb_distance < desc_distance + ): + results[event_id] = { + "distance": thumb_distance, + "source": "thumbnail", + } + elif desc_distance is not None: + results[event_id] = { + "distance": desc_distance, + "source": "description", + } if not results: return JSONResponse(content=[])