initial event search api implementation

This commit is contained in:
Jason Hunter 2024-06-21 17:25:19 -04:00
parent e99f8795ca
commit 48ea8faa4a
5 changed files with 329 additions and 7 deletions

View File

@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp
from frigate.api.review import ReviewBp
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor
from frigate.models import Event, Timeline
from frigate.plus import PlusApi
@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp)
def create_app(
frigate_config,
database: SqliteQueueDatabase,
embeddings: EmbeddingsContext,
detected_frames_processor,
storage_maintainer: StorageMaintainer,
onvif: OnvifController,
@ -79,6 +81,7 @@ def create_app(
database.close()
app.frigate_config = frigate_config
app.embeddings = embeddings
app.detected_frames_processor = detected_frames_processor
app.storage_maintainer = storage_maintainer
app.onvif = onvif

View File

@ -1,5 +1,7 @@
"""Event apis."""
import base64
import io
import logging
import os
from datetime import datetime
@ -8,6 +10,7 @@ from pathlib import Path
from urllib.parse import unquote
import cv2
import numpy as np
from flask import (
Blueprint,
current_app,
@ -15,13 +18,15 @@ from flask import (
make_response,
request,
)
from peewee import DoesNotExist, fn, operator
from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict
from frigate.const import (
CLIPS_DIR,
)
from frigate.models import Event, Timeline
from frigate.embeddings import EmbeddingsContext, get_metadata
from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers
@ -245,6 +250,190 @@ def events():
return jsonify(list(events))
@EventBp.route("/events/search")
def events_search():
query = request.args.get("query", type=str)
search_type = request.args.get("search_type", "text", type=str)
include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
limit = request.args.get("limit", 100, type=int)
# Filters
cameras = request.args.get("cameras", "all", type=str)
labels = request.args.get("labels", "all", type=str)
zones = request.args.get("zones", "all", type=str)
after = request.args.get("after", type=float)
before = request.args.get("before", type=float)
if not query:
return make_response(
jsonify(
{
"success": False,
"message": "A search query must be supplied",
}
),
400,
)
if not current_app.frigate_config.semantic_search.enabled:
return make_response(
jsonify(
{
"success": False,
"message": "Semantic search is not enabled",
}
),
400,
)
context: EmbeddingsContext = current_app.embeddings
selected_columns = [
Event.id,
Event.camera,
Event.label,
Event.sub_label,
Event.zones,
Event.start_time,
Event.end_time,
Event.data,
ReviewSegment.thumb_path,
]
if include_thumbnails:
selected_columns.append(Event.thumbnail)
# Build the where clause for the embeddings query
embeddings_filters = []
if cameras != "all":
camera_list = cameras.split(",")
embeddings_filters.append({"camera": {"$in": camera_list}})
if labels != "all":
label_list = labels.split(",")
embeddings_filters.append({"label": {"$in": label_list}})
if zones != "all":
filtered_zones = zones.split(",")
zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones]
if len(zone_filters) > 1:
embeddings_filters.append({"$or": zone_filters})
else:
embeddings_filters.append(zone_filters[0])
if after:
embeddings_filters.append({"start_time": {"$gt": after}})
if before:
embeddings_filters.append({"start_time": {"$lt": before}})
where = None
if len(embeddings_filters) > 1:
where = {"$and": embeddings_filters}
elif len(embeddings_filters) == 1:
where = embeddings_filters[0]
thumb_ids = {}
desc_ids = {}
if search_type == "thumbnail":
# Grab the ids of events that match the thumbnail image embeddings
try:
search_event: Event = Event.get(Event.id == query)
except DoesNotExist:
return make_response(
jsonify(
{
"success": False,
"message": "Event not found",
}
),
404,
)
thumbnail = base64.b64decode(search_event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumb_result = context.embeddings.thumbnail.query(
query_images=[img],
n_results=int(limit),
where=where,
)
thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
else:
thumb_result = context.embeddings.thumbnail.query(
query_texts=[query],
n_results=limit,
where=where,
)
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
desc_result = context.embeddings.description.query(
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
)
)
results = {}
for event_id in thumb_ids.keys() | desc_ids:
min_distance = min(
i
for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
if i is not None
)
results[event_id] = {
"distance": min_distance,
"source": "thumbnail"
if min_distance == thumb_ids.get(event_id)
else "description",
}
# Get the event data
events = (
Event.select(*selected_columns)
.join(
ReviewSegment,
JOIN.LEFT_OUTER,
on=(
fn.json_extract(ReviewSegment.data, "$.detections").contains(
Event.id
)
),
)
.where(Event.id << list(results.keys()))
.dicts()
.iterator()
)
events = list(events)
events = [
{k: v for k, v in event.items() if k != "data"}
| {
k: v
for k, v in event["data"].items()
if k in ["type", "score", "top_score", "description"]
}
| {
"search_distance": results[event["id"]]["distance"],
"search_source": results[event["id"]]["source"],
}
for event in events
]
events = sorted(events, key=lambda x: x["search_distance"])[:limit]
return jsonify(events)
@EventBp.route("/events/summary")
def events_summary():
tz_name = request.args.get("timezone", default="utc", type=str)
@ -604,6 +793,52 @@ def set_sub_label(id):
)
@EventBp.route("/events/<id>/description", methods=("POST",))
def set_description(id):
try:
event: Event = Event.get(Event.id == id)
except DoesNotExist:
return make_response(
jsonify({"success": False, "message": "Event " + id + " not found"}), 404
)
json: dict[str, any] = request.get_json(silent=True) or {}
new_description = json.get("description")
if new_description is None or len(new_description) == 0:
return make_response(
jsonify(
{
"success": False,
"message": "description cannot be empty",
}
),
400,
)
event.data["description"] = new_description
event.save()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.description.upsert(
documents=[new_description],
metadatas=[get_metadata(event)],
ids=[id],
)
return make_response(
jsonify(
{
"success": True,
"message": "Event " + id + " description set to " + new_description,
}
),
200,
)
@EventBp.route("/events/<id>", methods=("DELETE",))
def delete_event(id):
try:
@ -625,6 +860,11 @@ def delete_event(id):
event.delete_instance()
Timeline.delete().where(Timeline.source_id == id).execute()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.thumbnail.delete(ids=[id])
context.embeddings.description.delete(ids=[id])
return make_response(
jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
)

View File

@ -37,8 +37,7 @@ from frigate.const import (
MODEL_CACHE_DIR,
RECORD_DIR,
)
from frigate.embeddings import manage_embeddings
from frigate.embeddings.embeddings import Embeddings
from frigate.embeddings import EmbeddingsContext, manage_embeddings
from frigate.events.audio import listen_to_audio
from frigate.events.cleanup import EventCleanup
from frigate.events.external import ExternalEventProcessor
@ -322,7 +321,7 @@ class FrigateApp:
def init_embeddings_manager(self) -> None:
# Create a client for other processes to use
self.embeddings = Embeddings()
self.embeddings = EmbeddingsContext()
embedding_process = mp.Process(
target=manage_embeddings,
name="embeddings_manager",
@ -384,6 +383,7 @@ class FrigateApp:
self.flask_app = create_app(
self.config,
self.db,
self.embeddings,
self.detected_frames_processor,
self.storage_maintainer,
self.onvif_controller,
@ -811,6 +811,9 @@ class FrigateApp:
self.frigate_watchdog.join()
self.db.stop()
# Save embeddings stats to disk
self.embeddings.save_stats()
# Stop Communicators
self.inter_process_communicator.stop()
self.inter_config_updater.stop()

View File

@ -1,5 +1,6 @@
"""ChromaDB embeddings database."""
import json
import logging
import multiprocessing as mp
import signal
@ -12,9 +13,14 @@ from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.models import Event
from frigate.util.services import listen
from .embeddings import Embeddings, get_metadata
from .maintainer import EmbeddingMaintainer
from .util import ZScoreNormalization
logger = logging.getLogger(__name__)
@ -51,8 +57,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
# Hotsawp the sqlite3 module for Chroma compatibility
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer
embeddings = Embeddings()
@ -65,3 +69,28 @@ def manage_embeddings(config: FrigateConfig) -> None:
stop_event,
)
maintainer.start()
class EmbeddingsContext:
def __init__(self):
self.embeddings = Embeddings()
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
# load stats from disk
try:
with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
data = json.loads(f.read())
self.thumb_stats.from_dict(data["thumb_stats"])
self.desc_stats.from_dict(data["desc_stats"])
except FileNotFoundError:
pass
def save_stats(self):
"""Write the stats to disk as JSON on exit."""
contents = {
"thumb_stats": self.thumb_stats.to_dict(),
"desc_stats": self.desc_stats.to_dict(),
}
with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
f.write(json.dumps(contents))

View File

@ -0,0 +1,47 @@
"""Z-score normalization for search distance."""
import math
class ZScoreNormalization:
"""Running Z-score normalization for search distance."""
def __init__(self):
self.n = 0
self.mean = 0
self.m2 = 0
@property
def variance(self):
return self.m2 / (self.n - 1) if self.n > 1 else 0.0
@property
def stddev(self):
return math.sqrt(self.variance)
def normalize(self, distances: list[float]):
self._update(distances)
if self.stddev == 0:
return distances
return [(x - self.mean) / self.stddev for x in distances]
def _update(self, distances: list[float]):
for x in distances:
self.n += 1
delta = x - self.mean
self.mean += delta / self.n
delta2 = x - self.mean
self.m2 += delta * delta2
def to_dict(self):
return {
"n": self.n,
"mean": self.mean,
"m2": self.m2,
}
def from_dict(self, data: dict):
self.n = data["n"]
self.mean = data["mean"]
self.m2 = data["m2"]
return self