diff --git a/frigate/api/event.py b/frigate/api/event.py index a22502849..0815e7c66 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -11,7 +11,7 @@ import cv2 from fastapi import APIRouter, Request from fastapi.params import Depends from fastapi.responses import JSONResponse -from peewee import JOIN, DoesNotExist, fn, operator +from peewee import JOIN, DoesNotExist, Window, fn, operator from playhouse.shortcuts import model_to_dict from frigate.api.defs.events_body import ( @@ -259,7 +259,7 @@ def events(params: EventsQueryParams = Depends()): @router.get("/events/explore") def events_explore(limit: int = 10): - subquery = Event.select( + ranked_events = Event.select( Event.id, Event.camera, Event.label, @@ -275,38 +275,37 @@ def events_explore(limit: int = 10): Event.false_positive, Event.box, Event.data, - fn.rank() + fn.COUNT(Event.id).over(partition_by=[Event.label]).alias("event_count"), + Window.row_number() .over(partition_by=[Event.label], order_by=[Event.start_time.desc()]) .alias("rank"), - fn.COUNT(Event.id).over(partition_by=[Event.label]).alias("event_count"), - ).alias("subquery") + ).alias("ranked_events") query = ( - Event.select( - subquery.c.id, - subquery.c.camera, - subquery.c.label, - subquery.c.zones, - subquery.c.start_time, - subquery.c.end_time, - subquery.c.has_clip, - subquery.c.has_snapshot, - subquery.c.plus_id, - subquery.c.retain_indefinitely, - subquery.c.sub_label, - subquery.c.top_score, - subquery.c.false_positive, - subquery.c.box, - subquery.c.data, - subquery.c.event_count, + ranked_events.select( + ranked_events.c.id, + ranked_events.c.camera, + ranked_events.c.label, + ranked_events.c.zones, + ranked_events.c.start_time, + ranked_events.c.end_time, + ranked_events.c.has_clip, + ranked_events.c.has_snapshot, + ranked_events.c.plus_id, + ranked_events.c.retain_indefinitely, + ranked_events.c.sub_label, + ranked_events.c.top_score, + ranked_events.c.false_positive, + ranked_events.c.box, + ranked_events.c.data, + ranked_events.c.event_count, ) - .from_(subquery) - .where(subquery.c.rank <= limit) - .order_by(subquery.c.event_count.desc(), subquery.c.start_time.desc()) + .where(ranked_events.c.rank <= limit) + .order_by(ranked_events.c.event_count.desc(), ranked_events.c.start_time.desc()) .dicts() ) - events = list(query.iterator()) + events = list(query) processed_events = [ {k: v for k, v in event.items() if k != "data"} @@ -406,16 +405,12 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) event_filters = [] if cameras != "all": - camera_list = cameras.split(",") - event_filters.append((Event.camera << camera_list)) + event_filters.append((Event.camera << cameras.split(","))) if labels != "all": - label_list = labels.split(",") - event_filters.append((Event.label << label_list)) + event_filters.append((Event.label << labels.split(","))) if zones != "all": - # use matching so events with multiple zones - # still match on a search where any zone matches zone_clauses = [] filtered_zones = zones.split(",") @@ -426,8 +421,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) for zone in filtered_zones: zone_clauses.append((Event.zones.cast("text") % f'*"{zone}"*')) - zone_clause = reduce(operator.or_, zone_clauses) - event_filters.append((zone_clause)) + event_filters.append((reduce(operator.or_, zone_clauses))) if after: event_filters.append((Event.start_time > after)) @@ -436,13 +430,11 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) event_filters.append((Event.start_time < before)) if time_range != DEFAULT_TIME_RANGE: - # get timezone arg to ensure browser times are used tz_name = params.timezone hour_modifier, minute_modifier, _ = get_tz_modifiers(tz_name) times = time_range.split(",") - time_after = times[0] - time_before = times[1] + time_after, time_before = times start_hour_fun = fn.strftime( "%H:%M", @@ -465,23 +457,8 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) event_filters.append((start_hour_fun > time_after)) event_filters.append((start_hour_fun < time_before)) - if event_filters: - filtered_event_ids = ( - Event.select(Event.id) - .where(reduce(operator.and_, event_filters)) - .tuples() - .iterator() - ) - event_ids = [event_id[0] for event_id in filtered_event_ids] - - if not event_ids: - return JSONResponse(content=[]) # No events to search on - else: - event_ids = [] - - thumb_ids = {} - desc_ids = {} - + # Perform semantic search + search_results = {} if search_type == "similarity": try: search_event: Event = Event.get(Event.id == event_id) @@ -494,23 +471,22 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) status_code=404, ) - # Get thumbnail results for the specific event - thumb_result = context.embeddings.search_thumbnail( - search_event, event_ids, limit - ) - + thumb_result = context.embeddings.search_thumbnail(search_event) thumb_ids = dict( zip( [result[0] for result in thumb_result], context.thumb_stats.normalize([result[1] for result in thumb_result]), ) ) + search_results = { + event_id: {"distance": distance, "source": "thumbnail"} + for event_id, distance in thumb_ids.items() + } else: search_types = search_type.split(",") if "thumbnail" in search_types: - thumb_result = context.embeddings.search_thumbnail(query, event_ids, limit) - + thumb_result = context.embeddings.search_thumbnail(query) thumb_ids = dict( zip( [result[0] for result in thumb_result], @@ -519,40 +495,35 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) ), ) ) + search_results.update( + { + event_id: {"distance": distance, "source": "thumbnail"} + for event_id, distance in thumb_ids.items() + } + ) if "description" in search_types: - desc_result = context.embeddings.search_description(query, event_ids, limit) - + desc_result = context.embeddings.search_description(query) desc_ids = dict( zip( [result[0] for result in desc_result], context.desc_stats.normalize([result[1] for result in desc_result]), ) ) + for event_id, distance in desc_ids.items(): + if ( + event_id not in search_results + or distance < search_results[event_id]["distance"] + ): + search_results[event_id] = { + "distance": distance, + "source": "description", + } - results = {} - for event_id in thumb_ids.keys() | desc_ids.keys(): - thumb_distance = thumb_ids.get(event_id) - desc_distance = desc_ids.get(event_id) - - # Select the minimum distance from the available results - if thumb_distance is not None and ( - desc_distance is None or thumb_distance < desc_distance - ): - results[event_id] = { - "distance": thumb_distance, - "source": "thumbnail", - } - elif desc_distance is not None: - results[event_id] = { - "distance": desc_distance, - "source": "description", - } - - if not results: + if not search_results: return JSONResponse(content=[]) - # Get the event data + # Fetch events in a single query events = ( Event.select(*selected_columns) .join( @@ -560,11 +531,14 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) JOIN.LEFT_OUTER, on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)), ) - .where(Event.id << list(results.keys())) + .where( + (Event.id << list(search_results.keys())) + & reduce(operator.and_, event_filters) + if event_filters + else True + ) .dicts() - .iterator() ) - events = list(events) events = [ {k: v for k, v in event.items() if k != "data"} @@ -576,8 +550,8 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) } } | { - "search_distance": results[event["id"]]["distance"], - "search_source": results[event["id"]]["source"], + "search_distance": search_results[event["id"]]["distance"], + "search_source": search_results[event["id"]]["source"], } for event in events ]