diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py
index a4c8618e4..381d95ed1 100644
--- a/frigate/embeddings/__init__.py
+++ b/frigate/embeddings/__init__.py
@@ -73,7 +73,7 @@ class EmbeddingsContext:
     def __init__(self, db: SqliteVecQueueDatabase):
         self.embeddings = Embeddings(db)
         self.thumb_stats = ZScoreNormalization()
-        self.desc_stats = ZScoreNormalization(scale_factor=3, bias=-2.5)
+        self.desc_stats = ZScoreNormalization()
 
         # load stats from disk
         try:
diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index c763bf304..992dcd7d8 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -7,6 +7,7 @@ import struct
 import time
 from typing import List, Tuple, Union
 
+import numpy as np
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
@@ -16,8 +17,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event
 from frigate.types import ModelStatusTypesEnum
 
-from .functions.clip import ClipEmbedding
-from .functions.minilm_l6_v2 import MiniLMEmbedding
+from .functions.onnx import GenericONNXEmbedding
 
 logger = logging.getLogger(__name__)
 
@@ -53,9 +53,23 @@ def get_metadata(event: Event) -> dict:
     )
 
 
-def serialize(vector: List[float]) -> bytes:
-    """Serializes a list of floats into a compact "raw bytes" format"""
-    return struct.pack("%sf" % len(vector), *vector)
+def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
+    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
+    if isinstance(vector, np.ndarray):
+        # Convert numpy array to list of floats
+        vector = vector.flatten().tolist()
+    elif isinstance(vector, (float, np.float32, np.float64)):
+        # Handle single float values
+        vector = [vector]
+    elif not isinstance(vector, list):
+        raise TypeError(
+            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
+        )
+
+    try:
+        return struct.pack("%sf" % len(vector), *vector)
+    except struct.error as e:
+        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
Vector: {vector}") def deserialize(bytes_data: bytes) -> List[float]: @@ -74,10 +88,10 @@ class Embeddings: self._create_tables() models = [ - "sentence-transformers/all-MiniLM-L6-v2-model.onnx", - "sentence-transformers/all-MiniLM-L6-v2-tokenizer", - "clip-clip_image_model_vitb32.onnx", - "clip-clip_text_model_vitb32.onnx", + "all-jina-clip-v1-text_model_fp16.onnx", + "all-jina-clip-v1-tokenizer", + "all-jina-clip-v1-vision_model_fp16.onnx", + "all-jina-clip-v1-preprocessor_config.json", ] for model in models: @@ -89,11 +103,32 @@ class Embeddings: }, ) - self.clip_embedding = ClipEmbedding( - preferred_providers=["CPUExecutionProvider"] + def jina_text_embedding_function(outputs): + return outputs[0] + + def jina_vision_embedding_function(outputs): + return outputs[0] + + self.text_embedding = GenericONNXEmbedding( + model_name="all-jina-clip-v1", + model_file="text_model_fp16.onnx", + tokenizer_file="tokenizer", + download_urls={ + "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx" + }, + embedding_function=jina_text_embedding_function, + model_type="text", ) - self.minilm_embedding = MiniLMEmbedding( - preferred_providers=["CPUExecutionProvider"], + + self.vision_embedding = GenericONNXEmbedding( + model_name="all-jina-clip-v1", + model_file="vision_model_fp16.onnx", + tokenizer_file="preprocessor_config.json", + download_urls={ + "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx" + }, + embedding_function=jina_vision_embedding_function, + model_type="vision", ) def _create_tables(self): @@ -101,7 +136,7 @@ class Embeddings: self.db.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( id TEXT PRIMARY KEY, - thumbnail_embedding FLOAT[512] + thumbnail_embedding FLOAT[768] ); """) @@ -109,15 +144,14 @@ class Embeddings: self.db.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( id TEXT PRIMARY KEY, - description_embedding FLOAT[384] + description_embedding FLOAT[768] ); """) def upsert_thumbnail(self, event_id: str, thumbnail: bytes): # Convert thumbnail bytes to PIL Image image = Image.open(io.BytesIO(thumbnail)).convert("RGB") - # Generate embedding using CLIP - embedding = self.clip_embedding([image])[0] + embedding = self.vision_embedding([image])[0] self.db.execute_sql( """ @@ -130,8 +164,7 @@ class Embeddings: return embedding def upsert_description(self, event_id: str, description: str): - # Generate embedding using MiniLM - embedding = self.minilm_embedding([description])[0] + embedding = self.text_embedding([description])[0] self.db.execute_sql( """ @@ -177,7 +210,7 @@ class Embeddings: thumbnail = base64.b64decode(query.thumbnail) query_embedding = self.upsert_thumbnail(query.id, thumbnail) else: - query_embedding = self.clip_embedding([query])[0] + query_embedding = self.text_embedding([query])[0] sql_query = """ SELECT @@ -211,7 +244,7 @@ class Embeddings: def search_description( self, query_text: str, event_ids: List[str] = None ) -> List[Tuple[str, float]]: - query_embedding = self.minilm_embedding([query_text])[0] + query_embedding = self.text_embedding([query_text])[0] # Prepare the base SQL query sql_query = """ diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py new file mode 100644 index 000000000..be33a889f --- /dev/null +++ b/frigate/embeddings/functions/onnx.py @@ -0,0 +1,171 @@ +import logging +import os +import warnings +from io import BytesIO +from typing 
+
+import numpy as np
+import onnxruntime as ort
+import requests
+from PIL import Image
+
+# importing this without pytorch or others causes a warning
+# https://github.com/huggingface/transformers/issues/27214
+# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+from transformers import AutoFeatureExtractor, AutoTokenizer
+
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+warnings.filterwarnings(
+    "ignore",
+    category=FutureWarning,
+    message="The class CLIPFeatureExtractor is deprecated",
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class GenericONNXEmbedding:
+    """Generic embedding function for ONNX models (text and vision)."""
+
+    def __init__(
+        self,
+        model_name: str,
+        model_file: str,
+        tokenizer_file: str,
+        download_urls: Dict[str, str],
+        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
+        model_type: str,
+        preferred_providers: List[str] = ["CPUExecutionProvider"],
+    ):
+        self.model_name = model_name
+        self.model_file = model_file
+        self.tokenizer_file = tokenizer_file
+        self.download_urls = download_urls
+        self.embedding_function = embedding_function
+        self.model_type = model_type  # 'text' or 'vision'
+        self.preferred_providers = preferred_providers
+
+        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
+        self.tokenizer = None
+        self.feature_extractor = None
+        self.session = None
+
+        self.downloader = ModelDownloader(
+            model_name=self.model_name,
+            download_path=self.download_path,
+            file_names=[self.model_file, self.tokenizer_file],
+            download_func=self._download_model,
+        )
+        self.downloader.ensure_model_files()
+
+    def _download_model(self, path: str):
+        try:
+            file_name = os.path.basename(path)
+            if file_name in self.download_urls:
+                ModelDownloader.download_from_url(self.download_urls[file_name], path)
+            elif file_name == self.tokenizer_file:
+                logger.info(
+                    f"Downloading {self.model_name} tokenizer/feature extractor"
+                )
+                if self.model_type == "text":
+                    tokenizer = AutoTokenizer.from_pretrained(
+                        self.model_name, clean_up_tokenization_spaces=True
+                    )
+                    tokenizer.save_pretrained(path)
+                else:
+                    feature_extractor = AutoFeatureExtractor.from_pretrained(
+                        self.model_name
+                    )
+                    feature_extractor.save_pretrained(path)
+
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.model_name}-{file_name}",
+                    "state": ModelStatusTypesEnum.downloaded,
+                },
+            )
+        except Exception:
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.model_name}-{file_name}",
+                    "state": ModelStatusTypesEnum.error,
+                },
+            )
+
+    def _load_model_and_tokenizer(self):
+        if self.session is None:
+            self.downloader.wait_for_download()
+            if self.model_type == "text":
+                self.tokenizer = self._load_tokenizer()
+            else:
+                self.feature_extractor = self._load_feature_extractor()
+            self.session = self._load_model(
+                os.path.join(self.download_path, self.model_file),
+                self.preferred_providers,
+            )
+
+    def _load_tokenizer(self):
+        tokenizer_path = os.path.join(self.download_path, self.tokenizer_file)
+        return AutoTokenizer.from_pretrained(
+            tokenizer_path, clean_up_tokenization_spaces=True
+        )
+
+    def _load_feature_extractor(self):
+        feature_extractor_path = os.path.join(self.download_path, self.tokenizer_file)
+        return AutoFeatureExtractor.from_pretrained(feature_extractor_path)
+
+    def _load_model(self, path: str, providers: List[str]):
+        if os.path.exists(path):
+            return ort.InferenceSession(path, providers=providers)
+        else:
+            logger.warning(f"{self.model_name} model file {path} not found.")
+            return None
+
+    def _process_image(self, image):
+        if isinstance(image, str):
+            if image.startswith("http"):
+                response = requests.get(image)
+                image = Image.open(BytesIO(response.content)).convert("RGB")
+
+        return image
+
+    def __call__(
+        self, inputs: Union[List[str], List[Image.Image], List[str]]
+    ) -> List[np.ndarray]:
+        self._load_model_and_tokenizer()
+
+        if self.session is None or (
+            self.tokenizer is None and self.feature_extractor is None
+        ):
+            logger.error(
+                f"{self.model_name} model or tokenizer/feature extractor is not loaded."
+            )
+            return []
+
+        if self.model_type == "text":
+            processed_inputs = self.tokenizer(
+                inputs, padding=True, truncation=True, return_tensors="np"
+            )
+        else:
+            processed_images = [self._process_image(img) for img in inputs]
+            processed_inputs = self.feature_extractor(
+                images=processed_images, return_tensors="np"
+            )
+
+        input_names = [input.name for input in self.session.get_inputs()]
+        onnx_inputs = {
+            name: processed_inputs[name]
+            for name in input_names
+            if name in processed_inputs
+        }
+
+        outputs = self.session.run(None, onnx_inputs)
+        embeddings = self.embedding_function(outputs)
+
+        return [embedding for embedding in embeddings]
diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx
index 7fc0f9286..0a877a148 100644
--- a/web/src/pages/Explore.tsx
+++ b/web/src/pages/Explore.tsx
@@ -184,31 +184,31 @@ export default function Explore() {
 
   // model states
 
-  const { payload: minilmModelState } = useModelState(
-    "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+  const { payload: textModelState } = useModelState(
+    "all-jina-clip-v1-text_model_fp16.onnx",
   );
-  const { payload: minilmTokenizerState } = useModelState(
-    "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+  const { payload: textTokenizerState } = useModelState(
+    "all-jina-clip-v1-tokenizer",
   );
-  const { payload: clipImageModelState } = useModelState(
-    "clip-clip_image_model_vitb32.onnx",
+  const { payload: visionModelState } = useModelState(
+    "all-jina-clip-v1-vision_model_fp16.onnx",
   );
-  const { payload: clipTextModelState } = useModelState(
-    "clip-clip_text_model_vitb32.onnx",
+  const { payload: visionFeatureExtractorState } = useModelState(
+    "all-jina-clip-v1-preprocessor_config.json",
   );
 
   const allModelsLoaded = useMemo(() => {
     return (
-      minilmModelState === "downloaded" &&
-      minilmTokenizerState === "downloaded" &&
-      clipImageModelState === "downloaded" &&
-      clipTextModelState === "downloaded"
+      textModelState === "downloaded" &&
+      textTokenizerState === "downloaded" &&
+      visionModelState === "downloaded" &&
+      visionFeatureExtractorState === "downloaded"
     );
   }, [
-    minilmModelState,
-    minilmTokenizerState,
-    clipImageModelState,
-    clipTextModelState,
+    textModelState,
+    textTokenizerState,
+    visionModelState,
+    visionFeatureExtractorState,
   ]);
 
   const renderModelStateIcon = (modelState: ModelState) => {
@@ -225,10 +225,10 @@ export default function Explore() {
   };
 
   if (
-    !minilmModelState ||
-    !minilmTokenizerState ||
-    !clipImageModelState ||
-    !clipTextModelState
+    !textModelState ||
+    !textTokenizerState ||
+    !visionModelState ||
+    !visionFeatureExtractorState
   ) {
     return (
@@ -251,25 +251,26 @@ export default function Explore() {
-                {renderModelStateIcon(clipImageModelState)}
-                CLIP image model
+                {renderModelStateIcon(visionModelState)}
+                Vision model
-                {renderModelStateIcon(clipTextModelState)}
-                CLIP text model
+                {renderModelStateIcon(visionFeatureExtractorState)}
+                Vision model feature extractor
-                {renderModelStateIcon(minilmModelState)}
-                MiniLM sentence model
+                {renderModelStateIcon(textModelState)}
+                Text model
-                {renderModelStateIcon(minilmTokenizerState)}
-                MiniLM tokenizer
+                {renderModelStateIcon(textTokenizerState)}
+                Text tokenizer
-            {(minilmModelState === "error" ||
-              clipImageModelState === "error" ||
-              clipTextModelState === "error") && (
+            {(textModelState === "error" ||
+              textTokenizerState === "error" ||
+              visionModelState === "error" ||
+              visionFeatureExtractorState === "error") && (
An error has occurred. Check Frigate logs.
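
For reference, a minimal usage sketch of the GenericONNXEmbedding wrapper introduced in frigate/embeddings/functions/onnx.py. This is illustrative only and not part of the diff: it assumes it runs inside the Frigate runtime (where ModelDownloader and the model cache are available), the constructor arguments mirror the vision_embedding instance built in embeddings.py, the lambda stands in for jina_vision_embedding_function, and "thumbnail.jpg" is a hypothetical local file.

# Illustrative sketch (not part of the diff): embed one thumbnail with the
# Jina CLIP v1 vision model, the same way Embeddings.upsert_thumbnail does.
from PIL import Image

from frigate.embeddings.functions.onnx import GenericONNXEmbedding

vision_embedding = GenericONNXEmbedding(
    model_name="all-jina-clip-v1",
    model_file="vision_model_fp16.onnx",
    tokenizer_file="preprocessor_config.json",
    download_urls={
        "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
    },
    embedding_function=lambda outputs: outputs[0],  # same as jina_vision_embedding_function
    model_type="vision",
)

# "thumbnail.jpg" is a placeholder path; __call__ accepts a list of PIL images
# (or image URLs) and returns a list of 768-dimensional numpy vectors, matching
# the FLOAT[768] columns of the vec0 tables above.
embedding = vision_embedding([Image.open("thumbnail.jpg").convert("RGB")])[0]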