address some PR review comments

This commit is contained in:
Jason Hunter 2024-06-16 00:01:32 -04:00
parent 9ec66fcc9b
commit 9cddb34f22
12 changed files with 140 additions and 112 deletions

View File

@ -8,7 +8,7 @@ ARG SLIM_BASE=debian:11-slim
FROM ${BASE_IMAGE} AS base FROM ${BASE_IMAGE} AS base
FROM --platform=${BUILDPLATFORM} ${BASE_IMAGE} AS base_host FROM --platform=${BUILDPLATFORM} debian:11 AS base_host
FROM ${SLIM_BASE} AS slim-base FROM ${SLIM_BASE} AS slim-base

View File

@ -2,9 +2,12 @@
set -euxo pipefail set -euxo pipefail
SQLITE3_VERSION="96c92aba00c8375bc32fafcdf12429c58bd8aabfcadab6683e35bbb9cdebf19e" # 3.46.0
PYSQLITE3_VERSION="0.5.3"
# Fetch the source code for the latest release of Sqlite. # Fetch the source code for the latest release of Sqlite.
if [[ ! -d "sqlite" ]]; then if [[ ! -d "sqlite" ]]; then
wget https://www.sqlite.org/src/tarball/sqlite.tar.gz?r=release -O sqlite.tar.gz wget https://www.sqlite.org/src/tarball/sqlite.tar.gz?r=${SQLITE3_VERSION} -O sqlite.tar.gz
tar xzf sqlite.tar.gz tar xzf sqlite.tar.gz
cd sqlite/ cd sqlite/
LIBS="-lm" ./configure --disable-tcl --enable-tempstore=always LIBS="-lm" ./configure --disable-tcl --enable-tempstore=always
@ -18,13 +21,15 @@ if [[ ! -d "./pysqlite3" ]]; then
git clone https://github.com/coleifer/pysqlite3.git git clone https://github.com/coleifer/pysqlite3.git
fi fi
cd pysqlite3/
git checkout ${PYSQLITE3_VERSION}
# Copy the sqlite3 source amalgamation into the pysqlite3 directory so we can # Copy the sqlite3 source amalgamation into the pysqlite3 directory so we can
# create a self-contained extension module. # create a self-contained extension module.
cp "sqlite/sqlite3.c" pysqlite3/ cp "../sqlite/sqlite3.c" ./
cp "sqlite/sqlite3.h" pysqlite3/ cp "../sqlite/sqlite3.h" ./
# Create the wheel and put it in the /wheels dir. # Create the wheel and put it in the /wheels dir.
cd pysqlite3/
sed -i "s|name='pysqlite3-binary'|name=PACKAGE_NAME|g" setup.py sed -i "s|name='pysqlite3-binary'|name=PACKAGE_NAME|g" setup.py
python3 setup.py build_static python3 setup.py build_static
pip3 wheel . -w /wheels pip3 wheel . -w /wheels

View File

@ -34,7 +34,7 @@ class EventEndPublisher(Publisher):
topic_base = "event/" topic_base = "event/"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("ended") super().__init__("finalized")
def publish( def publish(
self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]] self, payload: tuple[EventTypeEnum, EventStateEnum, str, dict[str, any]]
@ -48,4 +48,4 @@ class EventEndSubscriber(Subscriber):
topic_base = "event/" topic_base = "event/"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("ended") super().__init__("finalized")

View File

@ -746,6 +746,9 @@ class GenAIConfig(FrigateBaseModel):
title="Default caption prompt.", title="Default caption prompt.",
) )
object_prompts: Dict[str, str] = Field(default={}, title="Object specific prompts.") object_prompts: Dict[str, str] = Field(default={}, title="Object specific prompts.")
reindex: Optional[bool] = Field(
default=False, title="Reindex all detections on startup."
)
class GenAICameraConfig(FrigateBaseModel): class GenAICameraConfig(FrigateBaseModel):

View File

@ -59,9 +59,8 @@ def manage_embeddings(config: FrigateConfig) -> None:
embeddings = Embeddings() embeddings = Embeddings()
# Check if we need to re-index events # Check if we need to re-index events
if os.path.exists(f"{CONFIG_DIR}/.reindex"): if config.genai.reindex:
embeddings.reindex() embeddings.reindex()
os.remove(f"{CONFIG_DIR}/.reindex")
maintainer = EmbeddingMaintainer( maintainer = EmbeddingMaintainer(
config, config,

View File

@ -3,7 +3,6 @@
import base64 import base64
import io import io
import logging import logging
import os
import time import time
import numpy as np import numpy as np
@ -13,7 +12,6 @@ from chromadb.config import Settings
from PIL import Image from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from frigate.const import CONFIG_DIR
from frigate.models import Event from frigate.models import Event
from .functions.clip import ClipEmbedding from .functions.clip import ClipEmbedding
@ -112,5 +110,3 @@ class Embeddings:
len(descriptions["ids"]), len(descriptions["ids"]),
time.time() - st, time.time() - st,
) )
os.remove(f"{CONFIG_DIR}/.reindex_events")

View File

@ -46,14 +46,21 @@ class EmbeddingMaintainer(threading.Thread):
def run(self) -> None: def run(self) -> None:
"""Maintain a Chroma vector database for semantic search.""" """Maintain a Chroma vector database for semantic search."""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
self._process_updates()
self._process_finalized()
def _process_updates(self) -> None:
"""Process event updates"""
update = self.event_subscriber.check_for_update() update = self.event_subscriber.check_for_update()
if update is None: if update is None:
continue return
source_type, _, camera, data = update source_type, _, camera, data = update
if camera and source_type == EventTypeEnum.tracked_object: if not camera or source_type != EventTypeEnum.tracked_object:
return
camera_config = self.config.cameras[camera] camera_config = self.config.cameras[camera]
if data["id"] not in self.tracked_events: if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = [] self.tracked_events[data["id"]] = []
@ -61,16 +68,15 @@ class EmbeddingMaintainer(threading.Thread):
# Create our own thumbnail based on the bounding box and the frame time # Create our own thumbnail based on the bounding box and the frame time
try: try:
frame_id = f"{camera}{data['frame_time']}" frame_id = f"{camera}{data['frame_time']}"
yuv_frame = self.frame_manager.get( yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
frame_id, camera_config.frame_shape_yuv
)
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"]) data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
self.tracked_events[data["id"]].append(data) self.tracked_events[data["id"]].append(data)
self.frame_manager.close(frame_id) self.frame_manager.close(frame_id)
except FileNotFoundError: except FileNotFoundError:
continue pass
# Embed thumbnails when an event ends def _process_finalized(self) -> None:
"""Process the end of an event."""
while True: while True:
ended = self.event_end_subscriber.check_for_update() ended = self.event_end_subscriber.check_for_update()
@ -111,10 +117,10 @@ class EmbeddingMaintainer(threading.Thread):
event, event,
[ [
data["thumbnail"] data["thumbnail"]
for data in self.tracked_events.get( for data in self.tracked_events[event_id]
event_id, [{"thumbnail": thumbnail}] ]
) if len(self.tracked_events.get(event_id, [])) > 0
], else [thumbnail],
metadata, metadata,
), ),
).start() ).start()

View File

@ -206,10 +206,13 @@ class EventCleanup(threading.Thread):
) )
events_to_delete = [e.id for e in events] events_to_delete = [e.id for e in events]
if len(events_to_delete) > 0: if len(events_to_delete) > 0:
Event.delete().where(Event.id << events_to_delete).execute() chunk_size = 50
for i in range(0, len(events_to_delete), chunk_size):
chunk = events_to_delete[i : i + chunk_size]
Event.delete().where(Event.id << chunk).execute()
if self.config.semantic_search.enabled: if self.config.semantic_search.enabled:
self.embeddings.thumbnail.delete(ids=events_to_delete) self.embeddings.thumbnail.delete(ids=chunk)
self.embeddings.description.delete(ids=events_to_delete) self.embeddings.description.delete(ids=chunk)
logger.info("Exiting event cleanup...") logger.info("Exiting event cleanup...")

View File

@ -22,11 +22,14 @@ def register_genai_provider(key: GenAIProviderEnum):
class GenAIClient: class GenAIClient:
"""Generative AI client for Frigate.""" """Generative AI client for Frigate."""
def __init__(self, genai_config: GenAIConfig) -> None: def __init__(self, genai_config: GenAIConfig, timeout: int = 60) -> None:
self.genai_config: GenAIConfig = genai_config self.genai_config: GenAIConfig = genai_config
self.timeout = timeout
self.provider = self._init_provider() self.provider = self._init_provider()
def generate_description(self, thumbnails: list[bytes], metadata: dict[str, any]): def generate_description(
self, thumbnails: list[bytes], metadata: dict[str, any]
) -> Optional[str]:
"""Generate a description for the frame.""" """Generate a description for the frame."""
prompt = self.genai_config.object_prompts.get( prompt = self.genai_config.object_prompts.get(
metadata["label"], self.genai_config.prompt metadata["label"], self.genai_config.prompt

View File

@ -3,6 +3,7 @@
from typing import Optional from typing import Optional
import google.generativeai as genai import google.generativeai as genai
from google.api_core.exceptions import DeadlineExceeded
from frigate.config import GenAIProviderEnum from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider from frigate.genai import GenAIClient, register_genai_provider
@ -28,12 +29,18 @@ class GeminiClient(GenAIClient):
} }
for img in images for img in images
] + [prompt] ] + [prompt]
try:
response = self.provider.generate_content( response = self.provider.generate_content(
data, data,
generation_config=genai.types.GenerationConfig( generation_config=genai.types.GenerationConfig(
candidate_count=1, candidate_count=1,
), ),
request_options=genai.types.RequestOptions(
timeout=self.timeout,
),
) )
except DeadlineExceeded:
return None
try: try:
description = response.text.strip() description = response.text.strip()
except ValueError: except ValueError:

View File

@ -3,6 +3,7 @@
import logging import logging
from typing import Optional from typing import Optional
from httpx import TimeoutException
from ollama import Client as ApiClient from ollama import Client as ApiClient
from ollama import ResponseError from ollama import ResponseError
@ -20,7 +21,7 @@ class OllamaClient(GenAIClient):
def _init_provider(self): def _init_provider(self):
"""Initialize the client.""" """Initialize the client."""
client = ApiClient(host=self.genai_config.base_url) client = ApiClient(host=self.genai_config.base_url, timeout=self.timeout)
response = client.pull(self.genai_config.model) response = client.pull(self.genai_config.model)
if response["status"] != "success": if response["status"] != "success":
logger.error("Failed to pull %s model from Ollama", self.genai_config.model) logger.error("Failed to pull %s model from Ollama", self.genai_config.model)
@ -28,7 +29,7 @@ class OllamaClient(GenAIClient):
return client return client
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to Ollama.""" """Submit a request to Ollama"""
try: try:
result = self.provider.generate( result = self.provider.generate(
self.genai_config.model, self.genai_config.model,
@ -36,5 +37,5 @@ class OllamaClient(GenAIClient):
images=images, images=images,
) )
return result["response"].strip() return result["response"].strip()
except ResponseError: except (TimeoutException, ResponseError):
return None return None

View File

@ -3,6 +3,7 @@
import base64 import base64
from typing import Optional from typing import Optional
from httpx import TimeoutException
from openai import OpenAI from openai import OpenAI
from frigate.config import GenAIProviderEnum from frigate.config import GenAIProviderEnum
@ -22,6 +23,7 @@ class OpenAIClient(GenAIClient):
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(self, prompt: str, images: list[bytes]) -> Optional[str]:
"""Submit a request to OpenAI.""" """Submit a request to OpenAI."""
encoded_images = [base64.b64encode(image).decode("utf-8") for image in images] encoded_images = [base64.b64encode(image).decode("utf-8") for image in images]
try:
result = self.provider.chat.completions.create( result = self.provider.chat.completions.create(
model=self.genai_config.model, model=self.genai_config.model,
messages=[ messages=[
@ -40,7 +42,10 @@ class OpenAIClient(GenAIClient):
+ [prompt], + [prompt],
}, },
], ],
timeout=self.timeout,
) )
except TimeoutException:
return None
if len(result.choices) > 0: if len(result.choices) > 0:
return result.choices[0].message.content.strip() return result.choices[0].message.content.strip()
return None return None