initial event search api implementation

This commit is contained in:
Jason Hunter 2024-06-21 17:25:19 -04:00
parent e99f8795ca
commit 48ea8faa4a
5 changed files with 329 additions and 7 deletions

View File

@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp
from frigate.api.review import ReviewBp from frigate.api.review import ReviewBp
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR from frigate.const import CONFIG_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor from frigate.events.external import ExternalEventProcessor
from frigate.models import Event, Timeline from frigate.models import Event, Timeline
from frigate.plus import PlusApi from frigate.plus import PlusApi
@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp)
def create_app( def create_app(
frigate_config, frigate_config,
database: SqliteQueueDatabase, database: SqliteQueueDatabase,
embeddings: EmbeddingsContext,
detected_frames_processor, detected_frames_processor,
storage_maintainer: StorageMaintainer, storage_maintainer: StorageMaintainer,
onvif: OnvifController, onvif: OnvifController,
@ -79,6 +81,7 @@ def create_app(
database.close() database.close()
app.frigate_config = frigate_config app.frigate_config = frigate_config
app.embeddings = embeddings
app.detected_frames_processor = detected_frames_processor app.detected_frames_processor = detected_frames_processor
app.storage_maintainer = storage_maintainer app.storage_maintainer = storage_maintainer
app.onvif = onvif app.onvif = onvif

View File

@ -1,5 +1,7 @@
"""Event apis.""" """Event apis."""
import base64
import io
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
@ -8,6 +10,7 @@ from pathlib import Path
from urllib.parse import unquote from urllib.parse import unquote
import cv2 import cv2
import numpy as np
from flask import ( from flask import (
Blueprint, Blueprint,
current_app, current_app,
@ -15,13 +18,15 @@ from flask import (
make_response, make_response,
request, request,
) )
from peewee import DoesNotExist, fn, operator from peewee import JOIN, DoesNotExist, fn, operator
from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from frigate.const import ( from frigate.const import (
CLIPS_DIR, CLIPS_DIR,
) )
from frigate.models import Event, Timeline from frigate.embeddings import EmbeddingsContext, get_metadata
from frigate.models import Event, ReviewSegment, Timeline
from frigate.object_processing import TrackedObject from frigate.object_processing import TrackedObject
from frigate.util.builtin import get_tz_modifiers from frigate.util.builtin import get_tz_modifiers
@ -245,6 +250,190 @@ def events():
return jsonify(list(events)) return jsonify(list(events))
@EventBp.route("/events/search")
def events_search():
query = request.args.get("query", type=str)
search_type = request.args.get("search_type", "text", type=str)
include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
limit = request.args.get("limit", 100, type=int)
# Filters
cameras = request.args.get("cameras", "all", type=str)
labels = request.args.get("labels", "all", type=str)
zones = request.args.get("zones", "all", type=str)
after = request.args.get("after", type=float)
before = request.args.get("before", type=float)
if not query:
return make_response(
jsonify(
{
"success": False,
"message": "A search query must be supplied",
}
),
400,
)
if not current_app.frigate_config.semantic_search.enabled:
return make_response(
jsonify(
{
"success": False,
"message": "Semantic search is not enabled",
}
),
400,
)
context: EmbeddingsContext = current_app.embeddings
selected_columns = [
Event.id,
Event.camera,
Event.label,
Event.sub_label,
Event.zones,
Event.start_time,
Event.end_time,
Event.data,
ReviewSegment.thumb_path,
]
if include_thumbnails:
selected_columns.append(Event.thumbnail)
# Build the where clause for the embeddings query
embeddings_filters = []
if cameras != "all":
camera_list = cameras.split(",")
embeddings_filters.append({"camera": {"$in": camera_list}})
if labels != "all":
label_list = labels.split(",")
embeddings_filters.append({"label": {"$in": label_list}})
if zones != "all":
filtered_zones = zones.split(",")
zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones]
if len(zone_filters) > 1:
embeddings_filters.append({"$or": zone_filters})
else:
embeddings_filters.append(zone_filters[0])
if after:
embeddings_filters.append({"start_time": {"$gt": after}})
if before:
embeddings_filters.append({"start_time": {"$lt": before}})
where = None
if len(embeddings_filters) > 1:
where = {"$and": embeddings_filters}
elif len(embeddings_filters) == 1:
where = embeddings_filters[0]
thumb_ids = {}
desc_ids = {}
if search_type == "thumbnail":
# Grab the ids of events that match the thumbnail image embeddings
try:
search_event: Event = Event.get(Event.id == query)
except DoesNotExist:
return make_response(
jsonify(
{
"success": False,
"message": "Event not found",
}
),
404,
)
thumbnail = base64.b64decode(search_event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumb_result = context.embeddings.thumbnail.query(
query_images=[img],
n_results=int(limit),
where=where,
)
thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
else:
thumb_result = context.embeddings.thumbnail.query(
query_texts=[query],
n_results=limit,
where=where,
)
# Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
thumb_ids = dict(
zip(
thumb_result["ids"][0],
context.thumb_stats.normalize(thumb_result["distances"][0]),
)
)
desc_result = context.embeddings.description.query(
query_texts=[query],
n_results=limit,
where=where,
)
desc_ids = dict(
zip(
desc_result["ids"][0],
context.desc_stats.normalize(desc_result["distances"][0]),
)
)
results = {}
for event_id in thumb_ids.keys() | desc_ids:
min_distance = min(
i
for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
if i is not None
)
results[event_id] = {
"distance": min_distance,
"source": "thumbnail"
if min_distance == thumb_ids.get(event_id)
else "description",
}
# Get the event data
events = (
Event.select(*selected_columns)
.join(
ReviewSegment,
JOIN.LEFT_OUTER,
on=(
fn.json_extract(ReviewSegment.data, "$.detections").contains(
Event.id
)
),
)
.where(Event.id << list(results.keys()))
.dicts()
.iterator()
)
events = list(events)
events = [
{k: v for k, v in event.items() if k != "data"}
| {
k: v
for k, v in event["data"].items()
if k in ["type", "score", "top_score", "description"]
}
| {
"search_distance": results[event["id"]]["distance"],
"search_source": results[event["id"]]["source"],
}
for event in events
]
events = sorted(events, key=lambda x: x["search_distance"])[:limit]
return jsonify(events)
@EventBp.route("/events/summary") @EventBp.route("/events/summary")
def events_summary(): def events_summary():
tz_name = request.args.get("timezone", default="utc", type=str) tz_name = request.args.get("timezone", default="utc", type=str)
@ -604,6 +793,52 @@ def set_sub_label(id):
) )
@EventBp.route("/events/<id>/description", methods=("POST",))
def set_description(id):
try:
event: Event = Event.get(Event.id == id)
except DoesNotExist:
return make_response(
jsonify({"success": False, "message": "Event " + id + " not found"}), 404
)
json: dict[str, any] = request.get_json(silent=True) or {}
new_description = json.get("description")
if new_description is None or len(new_description) == 0:
return make_response(
jsonify(
{
"success": False,
"message": "description cannot be empty",
}
),
400,
)
event.data["description"] = new_description
event.save()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.description.upsert(
documents=[new_description],
metadatas=[get_metadata(event)],
ids=[id],
)
return make_response(
jsonify(
{
"success": True,
"message": "Event " + id + " description set to " + new_description,
}
),
200,
)
@EventBp.route("/events/<id>", methods=("DELETE",)) @EventBp.route("/events/<id>", methods=("DELETE",))
def delete_event(id): def delete_event(id):
try: try:
@ -625,6 +860,11 @@ def delete_event(id):
event.delete_instance() event.delete_instance()
Timeline.delete().where(Timeline.source_id == id).execute() Timeline.delete().where(Timeline.source_id == id).execute()
# If semantic search is enabled, update the index
if current_app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = current_app.embeddings
context.embeddings.thumbnail.delete(ids=[id])
context.embeddings.description.delete(ids=[id])
return make_response( return make_response(
jsonify({"success": True, "message": "Event " + id + " deleted"}), 200 jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
) )

View File

@ -37,8 +37,7 @@ from frigate.const import (
MODEL_CACHE_DIR, MODEL_CACHE_DIR,
RECORD_DIR, RECORD_DIR,
) )
from frigate.embeddings import manage_embeddings from frigate.embeddings import EmbeddingsContext, manage_embeddings
from frigate.embeddings.embeddings import Embeddings
from frigate.events.audio import listen_to_audio from frigate.events.audio import listen_to_audio
from frigate.events.cleanup import EventCleanup from frigate.events.cleanup import EventCleanup
from frigate.events.external import ExternalEventProcessor from frigate.events.external import ExternalEventProcessor
@ -322,7 +321,7 @@ class FrigateApp:
def init_embeddings_manager(self) -> None: def init_embeddings_manager(self) -> None:
# Create a client for other processes to use # Create a client for other processes to use
self.embeddings = Embeddings() self.embeddings = EmbeddingsContext()
embedding_process = mp.Process( embedding_process = mp.Process(
target=manage_embeddings, target=manage_embeddings,
name="embeddings_manager", name="embeddings_manager",
@ -384,6 +383,7 @@ class FrigateApp:
self.flask_app = create_app( self.flask_app = create_app(
self.config, self.config,
self.db, self.db,
self.embeddings,
self.detected_frames_processor, self.detected_frames_processor,
self.storage_maintainer, self.storage_maintainer,
self.onvif_controller, self.onvif_controller,
@ -811,6 +811,9 @@ class FrigateApp:
self.frigate_watchdog.join() self.frigate_watchdog.join()
self.db.stop() self.db.stop()
# Save embeddings stats to disk
self.embeddings.save_stats()
# Stop Communicators # Stop Communicators
self.inter_process_communicator.stop() self.inter_process_communicator.stop()
self.inter_config_updater.stop() self.inter_config_updater.stop()

View File

@ -1,5 +1,6 @@
"""ChromaDB embeddings database.""" """ChromaDB embeddings database."""
import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import signal import signal
@ -12,9 +13,14 @@ from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle from setproctitle import setproctitle
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.models import Event from frigate.models import Event
from frigate.util.services import listen from frigate.util.services import listen
from .embeddings import Embeddings, get_metadata
from .maintainer import EmbeddingMaintainer
from .util import ZScoreNormalization
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,8 +57,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
# Hotsawp the sqlite3 module for Chroma compatibility # Hotsawp the sqlite3 module for Chroma compatibility
__import__("pysqlite3") __import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer
embeddings = Embeddings() embeddings = Embeddings()
@ -65,3 +69,28 @@ def manage_embeddings(config: FrigateConfig) -> None:
stop_event, stop_event,
) )
maintainer.start() maintainer.start()
class EmbeddingsContext:
def __init__(self):
self.embeddings = Embeddings()
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
# load stats from disk
try:
with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
data = json.loads(f.read())
self.thumb_stats.from_dict(data["thumb_stats"])
self.desc_stats.from_dict(data["desc_stats"])
except FileNotFoundError:
pass
def save_stats(self):
"""Write the stats to disk as JSON on exit."""
contents = {
"thumb_stats": self.thumb_stats.to_dict(),
"desc_stats": self.desc_stats.to_dict(),
}
with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
f.write(json.dumps(contents))

View File

@ -0,0 +1,47 @@
"""Z-score normalization for search distance."""
import math
class ZScoreNormalization:
"""Running Z-score normalization for search distance."""
def __init__(self):
self.n = 0
self.mean = 0
self.m2 = 0
@property
def variance(self):
return self.m2 / (self.n - 1) if self.n > 1 else 0.0
@property
def stddev(self):
return math.sqrt(self.variance)
def normalize(self, distances: list[float]):
self._update(distances)
if self.stddev == 0:
return distances
return [(x - self.mean) / self.stddev for x in distances]
def _update(self, distances: list[float]):
for x in distances:
self.n += 1
delta = x - self.mean
self.mean += delta / self.n
delta2 = x - self.mean
self.m2 += delta * delta2
def to_dict(self):
return {
"n": self.n,
"mean": self.mean,
"m2": self.m2,
}
def from_dict(self, data: dict):
self.n = data["n"]
self.mean = data["mean"]
self.m2 = data["m2"]
return self