From 3e5420c4109cedd7544ccf7075bd411ff426c5b4 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:22:01 -0500 Subject: [PATCH] migrate api from chroma to sqlite_vec --- frigate/api/event.py | 85 +++++++++++--------------------------------- 1 file changed, 20 insertions(+), 65 deletions(-) diff --git a/frigate/api/event.py b/frigate/api/event.py index 3c861f901..9457d0148 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -1,8 +1,6 @@ """Event apis.""" -import base64 import datetime -import io import logging import os from functools import reduce @@ -10,12 +8,10 @@ from pathlib import Path from urllib.parse import unquote import cv2 -import numpy as np from fastapi import APIRouter, Request from fastapi.params import Depends from fastapi.responses import JSONResponse from peewee import JOIN, DoesNotExist, fn, operator -from PIL import Image from playhouse.shortcuts import model_to_dict from frigate.api.defs.events_body import ( @@ -39,7 +35,6 @@ from frigate.const import ( CLIPS_DIR, ) from frigate.embeddings import EmbeddingsContext -from frigate.embeddings.embeddings import get_metadata from frigate.models import Event, ReviewSegment, Timeline from frigate.object_processing import TrackedObject from frigate.util.builtin import get_tz_modifiers @@ -389,6 +384,8 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) context: EmbeddingsContext = request.app.embeddings + logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}") + selected_columns = [ Event.id, Event.camera, @@ -484,14 +481,10 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) else: event_ids = [] - # Build the Chroma where clause based on the event IDs - where = {"id": {"$in": event_ids}} if event_ids else {} - - thumb_ids = {} - desc_ids = {} + thumb_results = [] + desc_results = [] if search_type == "similarity": - # Grab the ids of events that match the thumbnail image embeddings try: search_event: Event = Event.get(Event.id == event_id) except DoesNotExist: @@ -504,62 +497,25 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) ), status_code=404, ) - thumbnail = base64.b64decode(search_event.thumbnail) - img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB")) - thumb_result = context.embeddings.thumbnail.query( - query_images=[img], - n_results=limit, - where=where, - ) - thumb_ids = dict( - zip( - thumb_result["ids"][0], - context.thumb_stats.normalize(thumb_result["distances"][0]), - ) - ) + thumb_results = context.embeddings.search_thumbnail(search_event.id, limit) else: search_types = search_type.split(",") if "thumbnail" in search_types: - thumb_result = context.embeddings.thumbnail.query( - query_texts=[query], - n_results=limit, - where=where, - ) - # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM. - thumb_ids = dict( - zip( - thumb_result["ids"][0], - context.thumb_stats.normalize(thumb_result["distances"][0]), - ) - ) + thumb_results = context.embeddings.search_thumbnail(query, limit) if "description" in search_types: - desc_result = context.embeddings.description.query( - query_texts=[query], - n_results=limit, - where=where, - ) - desc_ids = dict( - zip( - desc_result["ids"][0], - context.desc_stats.normalize(desc_result["distances"][0]), - ) - ) + desc_results = context.embeddings.search_description(query, limit) results = {} - for event_id in thumb_ids.keys() | desc_ids: - min_distance = min( - i - for i in (thumb_ids.get(event_id), desc_ids.get(event_id)) - if i is not None - ) - results[event_id] = { - "distance": min_distance, - "source": "thumbnail" - if min_distance == thumb_ids.get(event_id) - else "description", - } + for result in thumb_results + desc_results: + event_id, distance = result[0], result[1] + if event_id in event_ids or not event_ids: + if event_id not in results or distance < results[event_id]["distance"]: + results[event_id] = { + "distance": distance, + "source": "thumbnail" if result in thumb_results else "description", + } if not results: return JSONResponse(content=[]) @@ -975,10 +931,9 @@ def set_description( # If semantic search is enabled, update the index if request.app.frigate_config.semantic_search.enabled: context: EmbeddingsContext = request.app.embeddings - context.embeddings.description.upsert( - documents=[new_description], - metadatas=[get_metadata(event)], - ids=[event_id], + context.embeddings.upsert_description( + event_id=event_id, + description=new_description, ) response_message = ( @@ -1065,8 +1020,8 @@ def delete_event(request: Request, event_id: str): # If semantic search is enabled, update the index if request.app.frigate_config.semantic_search.enabled: context: EmbeddingsContext = request.app.embeddings - context.embeddings.thumbnail.delete(ids=[event_id]) - context.embeddings.description.delete(ids=[event_id]) + context.embeddings.delete_thumbnail(id=[event_id]) + context.embeddings.delete_description(id=[event_id]) return JSONResponse( content=({"success": True, "message": "Event " + event_id + " deleted"}), status_code=200,