From 79f82c36aede059772b801ebd5a483dcab0fb2f6 Mon Sep 17 00:00:00 2001
From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com>
Date: Wed, 9 Oct 2024 11:17:49 -0500
Subject: [PATCH] Add generic ONNX model class and use Jina AI CLIP models for
 all embeddings
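
Replaces the separate ClipEmbedding (ViT-B/32) and MiniLMEmbedding
classes with a single GenericONNXEmbedding wrapper that handles model
download, tokenizer / feature-extractor loading, and ONNX Runtime
inference for both the text and vision branches of jina-clip-v1.
Thumbnail and description embeddings are now both 768-dimensional, so
the vec_thumbnails and vec_descriptions tables are declared with
FLOAT[768] columns, the description ZScoreNormalization drops its
custom scale/bias, and the Explore page tracks the new model state
keys.

Rough usage sketch (mirrors the construction in embeddings.py below;
the input string is illustrative, and instantiating the class assumes
Frigate's model downloader/cache and network access are available):

    from frigate.embeddings.functions.onnx import GenericONNXEmbedding

    text_embedding = GenericONNXEmbedding(
        model_name="all-jina-clip-v1",
        model_file="text_model_fp16.onnx",
        tokenizer_file="tokenizer",
        download_urls={
            "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
        },
        embedding_function=lambda outputs: outputs[0],
        model_type="text",
    )

    # should yield one 768-dim numpy vector per input string, matching
    # the FLOAT[768] columns created in _create_tables()
    vectors = text_embedding(["person walking a dog"])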
---
frigate/embeddings/__init__.py | 2 +-
frigate/embeddings/embeddings.py | 75 ++++++++----
frigate/embeddings/functions/onnx.py | 171 +++++++++++++++++++++++++++
web/src/pages/Explore.tsx | 64 +++++-----
4 files changed, 258 insertions(+), 54 deletions(-)
create mode 100644 frigate/embeddings/functions/onnx.py
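
Note (placed after the diffstat, dropped by git am): serialize() now
accepts numpy arrays directly because GenericONNXEmbedding returns
np.ndarray rows. A minimal standalone sketch of the intended round
trip; the deserialize() body here is an assumed inverse, since it is
not part of the hunks below:

    import struct
    from typing import List, Union

    import numpy as np

    def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
        # flatten ONNX output to a plain list of floats, then pack as raw float32 bytes
        if isinstance(vector, np.ndarray):
            vector = vector.flatten().tolist()
        elif isinstance(vector, (float, np.float32, np.float64)):
            vector = [vector]
        return struct.pack("%sf" % len(vector), *vector)

    def deserialize(bytes_data: bytes) -> List[float]:
        # assumed inverse: unpack raw float32 bytes back into floats
        return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))

    embedding = np.random.rand(768).astype(np.float32)  # jina-clip-v1 vector shape
    blob = serialize(embedding)
    assert len(blob) == 768 * 4
    assert np.allclose(deserialize(blob), embedding)
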
diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py
index a4c8618e4..381d95ed1 100644
--- a/frigate/embeddings/__init__.py
+++ b/frigate/embeddings/__init__.py
@@ -73,7 +73,7 @@ class EmbeddingsContext:
def __init__(self, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(db)
self.thumb_stats = ZScoreNormalization()
- self.desc_stats = ZScoreNormalization(scale_factor=3, bias=-2.5)
+ self.desc_stats = ZScoreNormalization()
# load stats from disk
try:
diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index c763bf304..992dcd7d8 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -7,6 +7,7 @@ import struct
import time
from typing import List, Tuple, Union
+import numpy as np
from PIL import Image
from playhouse.shortcuts import model_to_dict
@@ -16,8 +17,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.types import ModelStatusTypesEnum
-from .functions.clip import ClipEmbedding
-from .functions.minilm_l6_v2 import MiniLMEmbedding
+from .functions.onnx import GenericONNXEmbedding
logger = logging.getLogger(__name__)
@@ -53,9 +53,23 @@ def get_metadata(event: Event) -> dict:
)
-def serialize(vector: List[float]) -> bytes:
- """Serializes a list of floats into a compact "raw bytes" format"""
- return struct.pack("%sf" % len(vector), *vector)
+def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
+ """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
+ if isinstance(vector, np.ndarray):
+ # Convert numpy array to list of floats
+ vector = vector.flatten().tolist()
+ elif isinstance(vector, (float, np.float32, np.float64)):
+ # Handle single float values
+ vector = [vector]
+ elif not isinstance(vector, list):
+ raise TypeError(
+ f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
+ )
+
+ try:
+ return struct.pack("%sf" % len(vector), *vector)
+ except struct.error as e:
+        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}") from e
def deserialize(bytes_data: bytes) -> List[float]:
@@ -74,10 +88,10 @@ class Embeddings:
self._create_tables()
models = [
- "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
- "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
- "clip-clip_image_model_vitb32.onnx",
- "clip-clip_text_model_vitb32.onnx",
+ "all-jina-clip-v1-text_model_fp16.onnx",
+ "all-jina-clip-v1-tokenizer",
+ "all-jina-clip-v1-vision_model_fp16.onnx",
+ "all-jina-clip-v1-preprocessor_config.json",
]
for model in models:
@@ -89,11 +103,32 @@ class Embeddings:
},
)
- self.clip_embedding = ClipEmbedding(
- preferred_providers=["CPUExecutionProvider"]
+ def jina_text_embedding_function(outputs):
+ return outputs[0]
+
+ def jina_vision_embedding_function(outputs):
+ return outputs[0]
+
+ self.text_embedding = GenericONNXEmbedding(
+ model_name="all-jina-clip-v1",
+ model_file="text_model_fp16.onnx",
+ tokenizer_file="tokenizer",
+ download_urls={
+ "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
+ },
+ embedding_function=jina_text_embedding_function,
+ model_type="text",
)
- self.minilm_embedding = MiniLMEmbedding(
- preferred_providers=["CPUExecutionProvider"],
+
+ self.vision_embedding = GenericONNXEmbedding(
+ model_name="all-jina-clip-v1",
+ model_file="vision_model_fp16.onnx",
+ tokenizer_file="preprocessor_config.json",
+ download_urls={
+ "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
+ },
+ embedding_function=jina_vision_embedding_function,
+ model_type="vision",
)
def _create_tables(self):
@@ -101,7 +136,7 @@ class Embeddings:
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
- thumbnail_embedding FLOAT[512]
+ thumbnail_embedding FLOAT[768]
);
""")
@@ -109,15 +144,14 @@ class Embeddings:
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
- description_embedding FLOAT[384]
+ description_embedding FLOAT[768]
);
""")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
- # Generate embedding using CLIP
- embedding = self.clip_embedding([image])[0]
+ embedding = self.vision_embedding([image])[0]
self.db.execute_sql(
"""
@@ -130,8 +164,7 @@ class Embeddings:
return embedding
def upsert_description(self, event_id: str, description: str):
- # Generate embedding using MiniLM
- embedding = self.minilm_embedding([description])[0]
+ embedding = self.text_embedding([description])[0]
self.db.execute_sql(
"""
@@ -177,7 +210,7 @@ class Embeddings:
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
- query_embedding = self.clip_embedding([query])[0]
+ query_embedding = self.text_embedding([query])[0]
sql_query = """
SELECT
@@ -211,7 +244,7 @@ class Embeddings:
def search_description(
self, query_text: str, event_ids: List[str] = None
) -> List[Tuple[str, float]]:
- query_embedding = self.minilm_embedding([query_text])[0]
+ query_embedding = self.text_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py
new file mode 100644
index 000000000..be33a889f
--- /dev/null
+++ b/frigate/embeddings/functions/onnx.py
@@ -0,0 +1,171 @@
+import logging
+import os
+import warnings
+from io import BytesIO
+from typing import Callable, Dict, List, Union
+
+import numpy as np
+import onnxruntime as ort
+import requests
+from PIL import Image
+
+# importing transformers without torch (or another supported backend) installed causes a warning
+# https://github.com/huggingface/transformers/issues/27214
+# suppressed by setting the env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+from transformers import AutoFeatureExtractor, AutoTokenizer
+
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+warnings.filterwarnings(
+ "ignore",
+ category=FutureWarning,
+ message="The class CLIPFeatureExtractor is deprecated",
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class GenericONNXEmbedding:
+ """Generic embedding function for ONNX models (text and vision)."""
+
+ def __init__(
+ self,
+ model_name: str,
+ model_file: str,
+ tokenizer_file: str,
+ download_urls: Dict[str, str],
+ embedding_function: Callable[[List[np.ndarray]], np.ndarray],
+ model_type: str,
+ preferred_providers: List[str] = ["CPUExecutionProvider"],
+ ):
+ self.model_name = model_name
+ self.model_file = model_file
+ self.tokenizer_file = tokenizer_file
+ self.download_urls = download_urls
+ self.embedding_function = embedding_function
+ self.model_type = model_type # 'text' or 'vision'
+ self.preferred_providers = preferred_providers
+
+ self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
+ self.tokenizer = None
+ self.feature_extractor = None
+ self.session = None
+
+ self.downloader = ModelDownloader(
+ model_name=self.model_name,
+ download_path=self.download_path,
+ file_names=[self.model_file, self.tokenizer_file],
+ download_func=self._download_model,
+ )
+ self.downloader.ensure_model_files()
+
+ def _download_model(self, path: str):
+ try:
+ file_name = os.path.basename(path)
+ if file_name in self.download_urls:
+ ModelDownloader.download_from_url(self.download_urls[file_name], path)
+ elif file_name == self.tokenizer_file:
+ logger.info(
+ f"Downloading {self.model_name} tokenizer/feature extractor"
+ )
+ if self.model_type == "text":
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.model_name, clean_up_tokenization_spaces=True
+ )
+ tokenizer.save_pretrained(path)
+ else:
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ self.model_name
+ )
+ feature_extractor.save_pretrained(path)
+
+ self.downloader.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.model_name}-{file_name}",
+ "state": ModelStatusTypesEnum.downloaded,
+ },
+ )
+ except Exception:
+ self.downloader.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.model_name}-{file_name}",
+ "state": ModelStatusTypesEnum.error,
+ },
+ )
+
+ def _load_model_and_tokenizer(self):
+ if self.session is None:
+ self.downloader.wait_for_download()
+ if self.model_type == "text":
+ self.tokenizer = self._load_tokenizer()
+ else:
+ self.feature_extractor = self._load_feature_extractor()
+ self.session = self._load_model(
+ os.path.join(self.download_path, self.model_file),
+ self.preferred_providers,
+ )
+
+ def _load_tokenizer(self):
+ tokenizer_path = os.path.join(self.download_path, self.tokenizer_file)
+ return AutoTokenizer.from_pretrained(
+ tokenizer_path, clean_up_tokenization_spaces=True
+ )
+
+ def _load_feature_extractor(self):
+ feature_extractor_path = os.path.join(self.download_path, self.tokenizer_file)
+ return AutoFeatureExtractor.from_pretrained(feature_extractor_path)
+
+ def _load_model(self, path: str, providers: List[str]):
+ if os.path.exists(path):
+ return ort.InferenceSession(path, providers=providers)
+ else:
+ logger.warning(f"{self.model_name} model file {path} not found.")
+ return None
+
+ def _process_image(self, image):
+ if isinstance(image, str):
+ if image.startswith("http"):
+ response = requests.get(image)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ return image
+
+ def __call__(
+        self, inputs: Union[List[str], List[Image.Image]]
+ ) -> List[np.ndarray]:
+ self._load_model_and_tokenizer()
+
+ if self.session is None or (
+ self.tokenizer is None and self.feature_extractor is None
+ ):
+ logger.error(
+ f"{self.model_name} model or tokenizer/feature extractor is not loaded."
+ )
+ return []
+
+ if self.model_type == "text":
+ processed_inputs = self.tokenizer(
+ inputs, padding=True, truncation=True, return_tensors="np"
+ )
+ else:
+ processed_images = [self._process_image(img) for img in inputs]
+ processed_inputs = self.feature_extractor(
+ images=processed_images, return_tensors="np"
+ )
+
+ input_names = [input.name for input in self.session.get_inputs()]
+ onnx_inputs = {
+ name: processed_inputs[name]
+ for name in input_names
+ if name in processed_inputs
+ }
+
+ outputs = self.session.run(None, onnx_inputs)
+ embeddings = self.embedding_function(outputs)
+
+ return [embedding for embedding in embeddings]
diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx
index 2bed37ac1..def68727f 100644
--- a/web/src/pages/Explore.tsx
+++ b/web/src/pages/Explore.tsx
@@ -184,31 +184,31 @@ export default function Explore() {
// model states
- const { payload: minilmModelState } = useModelState(
- "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+ const { payload: textModelState } = useModelState(
+ "all-jina-clip-v1-text_model_fp16.onnx",
);
- const { payload: minilmTokenizerState } = useModelState(
- "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+ const { payload: textTokenizerState } = useModelState(
+ "all-jina-clip-v1-tokenizer",
);
- const { payload: clipImageModelState } = useModelState(
- "clip-clip_image_model_vitb32.onnx",
+ const { payload: visionModelState } = useModelState(
+ "all-jina-clip-v1-vision_model_fp16.onnx",
);
- const { payload: clipTextModelState } = useModelState(
- "clip-clip_text_model_vitb32.onnx",
+ const { payload: visionFeatureExtractorState } = useModelState(
+ "all-jina-clip-v1-preprocessor_config.json",
);
const allModelsLoaded = useMemo(() => {
return (
- minilmModelState === "downloaded" &&
- minilmTokenizerState === "downloaded" &&
- clipImageModelState === "downloaded" &&
- clipTextModelState === "downloaded"
+ textModelState === "downloaded" &&
+ textTokenizerState === "downloaded" &&
+ visionModelState === "downloaded" &&
+ visionFeatureExtractorState === "downloaded"
);
}, [
- minilmModelState,
- minilmTokenizerState,
- clipImageModelState,
- clipTextModelState,
+ textModelState,
+ textTokenizerState,
+ visionModelState,
+ visionFeatureExtractorState,
]);
const renderModelStateIcon = (modelState: ModelState) => {
@@ -225,11 +225,10 @@ export default function Explore() {
};
if (
- config?.semantic_search.enabled &&
- (!minilmModelState ||
- !minilmTokenizerState ||
- !clipImageModelState ||
- !clipTextModelState)
+ !textModelState ||
+ !textTokenizerState ||
+ !visionModelState ||
+ !visionFeatureExtractorState
) {
return (