migrate api from chroma to sqlite_vec

This commit is contained in:
Josh Hawkins 2024-10-04 13:22:01 -05:00
parent 3c334175c7
commit 3e5420c410

View File

@ -1,8 +1,6 @@
"""Event apis.""" """Event apis."""
import base64
import datetime import datetime
import io
import logging import logging
import os import os
from functools import reduce from functools import reduce
@ -10,12 +8,10 @@ from pathlib import Path
from urllib.parse import unquote from urllib.parse import unquote
import cv2 import cv2
import numpy as np
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.params import Depends from fastapi.params import Depends
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from peewee import JOIN, DoesNotExist, fn, operator from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from frigate.api.defs.events_body import ( from frigate.api.defs.events_body import (
@ -39,7 +35,6 @@ from frigate.const import (
CLIPS_DIR, CLIPS_DIR,
) )
from frigate.embeddings import EmbeddingsContext from frigate.embeddings import EmbeddingsContext
from frigate.embeddings.embeddings import get_metadata
from frigate.models import Event, ReviewSegment, Timeline from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers from frigate.util.builtin import get_tz_modifiers
@ -389,6 +384,8 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
context: EmbeddingsContext = request.app.embeddings context: EmbeddingsContext = request.app.embeddings
logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}")
selected_columns = [ selected_columns = [
Event.id, Event.id,
Event.camera, Event.camera,
@ -484,14 +481,10 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else: else:
event_ids = [] event_ids = []
# Build the Chroma where clause based on the event IDs thumb_results = []
where = {"id": {"$in": event_ids}} if event_ids else {} desc_results = []
thumb_ids = {}
desc_ids = {}
if search_type == "similarity": if search_type == "similarity":
# Grab the ids of events that match the thumbnail image embeddings
try: try:
search_event: Event = Event.get(Event.id == event_id) search_event: Event = Event.get(Event.id == event_id)
except DoesNotExist: except DoesNotExist:
@ -504,61 +497,24 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
), ),
status_code=404, status_code=404,
) )
thumbnail = base64.b64decode(search_event.thumbnail) thumb_results = context.embeddings.search_thumbnail(search_event.id, limit)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumb_result = context.embeddings.thumbnail.query(
query_images=[img],
n_results=limit,
where=where,
)
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
else: else:
search_types = search_type.split(",") search_types = search_type.split(",")
if "thumbnail" in search_types: if "thumbnail" in search_types:
thumb_result = context.embeddings.thumbnail.query( thumb_results = context.embeddings.search_thumbnail(query, limit)
query_texts=[query],
n_results=limit,
where=where,
)
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
if "description" in search_types: if "description" in search_types:
desc_result = context.embeddings.description.query( desc_results = context.embeddings.search_description(query, limit)
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
)
)
results = {} results = {}
for event_id in thumb_ids.keys() | desc_ids: for result in thumb_results + desc_results:
min_distance = min( event_id, distance = result[0], result[1]
i if event_id in event_ids or not event_ids:
for i in (thumb_ids.get(event_id), desc_ids.get(event_id)) if event_id not in results or distance < results[event_id]["distance"]:
if i is not None
)
results[event_id] = { results[event_id] = {
"distance": min_distance, "distance": distance,
"source": "thumbnail" "source": "thumbnail" if result in thumb_results else "description",
if min_distance == thumb_ids.get(event_id)
else "description",
} }
if not results: if not results:
@ -975,10 +931,9 @@ def set_description(
# If semantic search is enabled, update the index # If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled: if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings context: EmbeddingsContext = request.app.embeddings
context.embeddings.description.upsert( context.embeddings.upsert_description(
documents=[new_description], event_id=event_id,
metadatas=[get_metadata(event)], description=new_description,
ids=[event_id],
) )
response_message = ( response_message = (
@ -1065,8 +1020,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index # If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled: if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings context: EmbeddingsContext = request.app.embeddings
context.embeddings.thumbnail.delete(ids=[event_id]) context.embeddings.delete_thumbnail(id=[event_id])
context.embeddings.description.delete(ids=[event_id]) context.embeddings.delete_description(id=[event_id])
return JSONResponse( return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}), content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200, status_code=200,