migrate api from chroma to sqlite_vec

This commit is contained in:
Josh Hawkins 2024-10-04 13:22:01 -05:00
parent 3c334175c7
commit 3e5420c410

View File

@ -1,8 +1,6 @@
"""Event apis."""
import base64
import datetime
import io
import logging
import os
from functools import reduce
@ -10,12 +8,10 @@ from pathlib import Path
from urllib.parse import unquote
import cv2
import numpy as np
from fastapi import APIRouter, Request
from fastapi.params import Depends
from fastapi.responses import JSONResponse
from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict
from frigate.api.defs.events_body import (
@ -39,7 +35,6 @@ from frigate.const import (
CLIPS_DIR,
)
from frigate.embeddings import EmbeddingsContext
from frigate.embeddings.embeddings import get_metadata
from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers
@ -389,6 +384,8 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
context: EmbeddingsContext = request.app.embeddings
logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}")
selected_columns = [
Event.id,
Event.camera,
@ -484,14 +481,10 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else:
event_ids = []
# Build the Chroma where clause based on the event IDs
where = {"id": {"$in": event_ids}} if event_ids else {}
thumb_ids = {}
desc_ids = {}
thumb_results = []
desc_results = []
if search_type == "similarity":
# Grab the ids of events that match the thumbnail image embeddings
try:
search_event: Event = Event.get(Event.id == event_id)
except DoesNotExist:
@ -504,62 +497,25 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
),
status_code=404,
)
thumbnail = base64.b64decode(search_event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumb_result = context.embeddings.thumbnail.query(
query_images=[img],
n_results=limit,
where=where,
)
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
thumb_results = context.embeddings.search_thumbnail(search_event.id, limit)
else:
search_types = search_type.split(",")
if "thumbnail" in search_types:
thumb_result = context.embeddings.thumbnail.query(
query_texts=[query],
n_results=limit,
where=where,
)
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
thumb_results = context.embeddings.search_thumbnail(query, limit)
if "description" in search_types:
desc_result = context.embeddings.description.query(
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
)
)
desc_results = context.embeddings.search_description(query, limit)
results = {}
for event_id in thumb_ids.keys() | desc_ids:
min_distance = min(
i
for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
if i is not None
)
results[event_id] = {
"distance": min_distance,
"source": "thumbnail"
if min_distance == thumb_ids.get(event_id)
else "description",
}
for result in thumb_results + desc_results:
event_id, distance = result[0], result[1]
if event_id in event_ids or not event_ids:
if event_id not in results or distance < results[event_id]["distance"]:
results[event_id] = {
"distance": distance,
"source": "thumbnail" if result in thumb_results else "description",
}
if not results:
return JSONResponse(content=[])
@ -975,10 +931,9 @@ def set_description(
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.description.upsert(
documents=[new_description],
metadatas=[get_metadata(event)],
ids=[event_id],
context.embeddings.upsert_description(
event_id=event_id,
description=new_description,
)
response_message = (
@ -1065,8 +1020,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.thumbnail.delete(ids=[event_id])
context.embeddings.description.delete(ids=[event_id])
context.embeddings.delete_thumbnail(id=[event_id])
context.embeddings.delete_description(id=[event_id])
return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200,