add generic onnx model class and use jina ai clip models for all embeddings

Josh Hawkins 2024-10-09 11:17:49 -05:00
parent a2f42d51fd
commit 79f82c36ae
4 changed files with 258 additions and 54 deletions

View File

@@ -73,7 +73,7 @@ class EmbeddingsContext:
     def __init__(self, db: SqliteVecQueueDatabase):
         self.embeddings = Embeddings(db)
         self.thumb_stats = ZScoreNormalization()
-        self.desc_stats = ZScoreNormalization(scale_factor=3, bias=-2.5)
+        self.desc_stats = ZScoreNormalization()
         # load stats from disk
         try:

View File

@@ -7,6 +7,7 @@ import struct
 import time
 from typing import List, Tuple, Union
+import numpy as np
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
@@ -16,8 +17,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event
 from frigate.types import ModelStatusTypesEnum
-from .functions.clip import ClipEmbedding
-from .functions.minilm_l6_v2 import MiniLMEmbedding
+from .functions.onnx import GenericONNXEmbedding

 logger = logging.getLogger(__name__)
@@ -53,9 +53,23 @@ def get_metadata(event: Event) -> dict:
     )

-def serialize(vector: List[float]) -> bytes:
-    """Serializes a list of floats into a compact "raw bytes" format"""
-    return struct.pack("%sf" % len(vector), *vector)
+def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
+    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
+    if isinstance(vector, np.ndarray):
+        # Convert numpy array to list of floats
+        vector = vector.flatten().tolist()
+    elif isinstance(vector, (float, np.float32, np.float64)):
+        # Handle single float values
+        vector = [vector]
+    elif not isinstance(vector, list):
+        raise TypeError(
+            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
+        )
+
+    try:
+        return struct.pack("%sf" % len(vector), *vector)
+    except struct.error as e:
+        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")

 def deserialize(bytes_data: bytes) -> List[float]:
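
As a quick illustration (not part of this commit), a minimal sketch of what the widened serialize() accepts and produces; the 768 dimension matches the new vec table schemas further down, and the unpacking step mirrors what deserialize() (whose body is not shown in this hunk) presumably does:

import struct

import numpy as np

# serialize() now takes a numpy array, a plain list of floats, or a single
# float, and always packs native-order 32-bit floats via struct.
embedding = np.random.rand(768).astype(np.float32)  # e.g. one jina-clip-v1 vector
packed = struct.pack("%sf" % len(embedding), *embedding.flatten().tolist())
assert len(packed) == 768 * 4  # "f" packs 4 bytes per value

# Unpacking recovers the floats.
recovered = list(struct.unpack("%sf" % (len(packed) // 4), packed))
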
@@ -74,10 +88,10 @@ class Embeddings:
         self._create_tables()

         models = [
-            "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
-            "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
-            "clip-clip_image_model_vitb32.onnx",
-            "clip-clip_text_model_vitb32.onnx",
+            "all-jina-clip-v1-text_model_fp16.onnx",
+            "all-jina-clip-v1-tokenizer",
+            "all-jina-clip-v1-vision_model_fp16.onnx",
+            "all-jina-clip-v1-preprocessor_config.json",
         ]

         for model in models:
@@ -89,11 +103,32 @@ class Embeddings:
                 },
             )

-        self.clip_embedding = ClipEmbedding(
-            preferred_providers=["CPUExecutionProvider"]
-        )
-        self.minilm_embedding = MiniLMEmbedding(
-            preferred_providers=["CPUExecutionProvider"],
-        )
+        def jina_text_embedding_function(outputs):
+            return outputs[0]
+
+        def jina_vision_embedding_function(outputs):
+            return outputs[0]
+
+        self.text_embedding = GenericONNXEmbedding(
+            model_name="all-jina-clip-v1",
+            model_file="text_model_fp16.onnx",
+            tokenizer_file="tokenizer",
+            download_urls={
+                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
+            },
+            embedding_function=jina_text_embedding_function,
+            model_type="text",
+        )
+
+        self.vision_embedding = GenericONNXEmbedding(
+            model_name="all-jina-clip-v1",
+            model_file="vision_model_fp16.onnx",
+            tokenizer_file="preprocessor_config.json",
+            download_urls={
+                "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
+            },
+            embedding_function=jina_vision_embedding_function,
+            model_type="vision",
+        )

     def _create_tables(self):
@@ -101,7 +136,7 @@ class Embeddings:
         self.db.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
                 id TEXT PRIMARY KEY,
-                thumbnail_embedding FLOAT[512]
+                thumbnail_embedding FLOAT[768]
             );
         """)
@@ -109,15 +144,14 @@ class Embeddings:
         self.db.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
                 id TEXT PRIMARY KEY,
-                description_embedding FLOAT[384]
+                description_embedding FLOAT[768]
             );
         """)

     def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
-        # Generate embedding using CLIP
-        embedding = self.clip_embedding([image])[0]
+        embedding = self.vision_embedding([image])[0]

         self.db.execute_sql(
             """
@@ -130,8 +164,7 @@ class Embeddings:
         return embedding

     def upsert_description(self, event_id: str, description: str):
-        # Generate embedding using MiniLM
-        embedding = self.minilm_embedding([description])[0]
+        embedding = self.text_embedding([description])[0]

         self.db.execute_sql(
             """
@@ -177,7 +210,7 @@ class Embeddings:
             thumbnail = base64.b64decode(query.thumbnail)
             query_embedding = self.upsert_thumbnail(query.id, thumbnail)
         else:
-            query_embedding = self.clip_embedding([query])[0]
+            query_embedding = self.text_embedding([query])[0]

         sql_query = """
             SELECT
@@ -211,7 +244,7 @@ class Embeddings:
     def search_description(
         self, query_text: str, event_ids: List[str] = None
     ) -> List[Tuple[str, float]]:
-        query_embedding = self.minilm_embedding([query_text])[0]
+        query_embedding = self.text_embedding([query_text])[0]

         # Prepare the base SQL query
         sql_query = """
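
One consequence worth spelling out (not part of this commit): search() above now embeds a text query with text_embedding and compares it against stored thumbnail embeddings, which only works because jina-clip-v1's text and vision towers project into the same 768-dimensional space. A hedged sketch of that cross-modal comparison, with hypothetical variable names:

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Both towers emit 768-dim vectors, so a text query can be scored
    # directly against a stored image embedding.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Runnable stand-in with random vectors:
a = np.random.rand(768).astype(np.float32)
b = np.random.rand(768).astype(np.float32)
print(cosine_similarity(a, b))

# Hypothetical usage against an Embeddings instance named `embeddings`:
# text_vec = embeddings.text_embedding(["person walking a dog"])[0]
# image_vec = embeddings.vision_embedding([thumbnail_image])[0]
# score = cosine_similarity(text_vec, image_vec)
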

View File

@@ -0,0 +1,171 @@
+import logging
+import os
+import warnings
+from io import BytesIO
+from typing import Callable, Dict, List, Union
+
+import numpy as np
+import onnxruntime as ort
+import requests
+from PIL import Image
+
+# importing this without pytorch or others causes a warning
+# https://github.com/huggingface/transformers/issues/27214
+# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+from transformers import AutoFeatureExtractor, AutoTokenizer
+
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+warnings.filterwarnings(
+    "ignore",
+    category=FutureWarning,
+    message="The class CLIPFeatureExtractor is deprecated",
+)
+
+logger = logging.getLogger(__name__)
+
+
+class GenericONNXEmbedding:
+    """Generic embedding function for ONNX models (text and vision)."""
+
+    def __init__(
+        self,
+        model_name: str,
+        model_file: str,
+        tokenizer_file: str,
+        download_urls: Dict[str, str],
+        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
+        model_type: str,
+        preferred_providers: List[str] = ["CPUExecutionProvider"],
+    ):
+        self.model_name = model_name
+        self.model_file = model_file
+        self.tokenizer_file = tokenizer_file
+        self.download_urls = download_urls
+        self.embedding_function = embedding_function
+        self.model_type = model_type  # 'text' or 'vision'
+        self.preferred_providers = preferred_providers
+
+        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
+        self.tokenizer = None
+        self.feature_extractor = None
+        self.session = None
+
+        self.downloader = ModelDownloader(
+            model_name=self.model_name,
+            download_path=self.download_path,
+            file_names=[self.model_file, self.tokenizer_file],
+            download_func=self._download_model,
+        )
+        self.downloader.ensure_model_files()
+
+    def _download_model(self, path: str):
+        try:
+            file_name = os.path.basename(path)
+            if file_name in self.download_urls:
+                ModelDownloader.download_from_url(self.download_urls[file_name], path)
+            elif file_name == self.tokenizer_file:
+                logger.info(
+                    f"Downloading {self.model_name} tokenizer/feature extractor"
+                )
+                if self.model_type == "text":
+                    tokenizer = AutoTokenizer.from_pretrained(
+                        self.model_name, clean_up_tokenization_spaces=True
+                    )
+                    tokenizer.save_pretrained(path)
+                else:
+                    feature_extractor = AutoFeatureExtractor.from_pretrained(
+                        self.model_name
+                    )
+                    feature_extractor.save_pretrained(path)
+
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.model_name}-{file_name}",
+                    "state": ModelStatusTypesEnum.downloaded,
+                },
+            )
+        except Exception:
+            self.downloader.requestor.send_data(
+                UPDATE_MODEL_STATE,
+                {
+                    "model": f"{self.model_name}-{file_name}",
+                    "state": ModelStatusTypesEnum.error,
+                },
+            )
+
+    def _load_model_and_tokenizer(self):
+        if self.session is None:
+            self.downloader.wait_for_download()
+            if self.model_type == "text":
+                self.tokenizer = self._load_tokenizer()
+            else:
+                self.feature_extractor = self._load_feature_extractor()
+            self.session = self._load_model(
+                os.path.join(self.download_path, self.model_file),
+                self.preferred_providers,
+            )
+
+    def _load_tokenizer(self):
+        tokenizer_path = os.path.join(self.download_path, self.tokenizer_file)
+        return AutoTokenizer.from_pretrained(
+            tokenizer_path, clean_up_tokenization_spaces=True
+        )
+
+    def _load_feature_extractor(self):
+        feature_extractor_path = os.path.join(self.download_path, self.tokenizer_file)
+        return AutoFeatureExtractor.from_pretrained(feature_extractor_path)
+
+    def _load_model(self, path: str, providers: List[str]):
+        if os.path.exists(path):
+            return ort.InferenceSession(path, providers=providers)
+        else:
+            logger.warning(f"{self.model_name} model file {path} not found.")
+            return None
+
+    def _process_image(self, image):
+        if isinstance(image, str):
+            if image.startswith("http"):
+                response = requests.get(image)
+                image = Image.open(BytesIO(response.content)).convert("RGB")
+
+        return image
+
+    def __call__(
+        self, inputs: Union[List[str], List[Image.Image], List[str]]
+    ) -> List[np.ndarray]:
+        self._load_model_and_tokenizer()
+
+        if self.session is None or (
+            self.tokenizer is None and self.feature_extractor is None
+        ):
+            logger.error(
+                f"{self.model_name} model or tokenizer/feature extractor is not loaded."
+            )
+            return []
+
+        if self.model_type == "text":
+            processed_inputs = self.tokenizer(
+                inputs, padding=True, truncation=True, return_tensors="np"
+            )
+        else:
+            processed_images = [self._process_image(img) for img in inputs]
+            processed_inputs = self.feature_extractor(
+                images=processed_images, return_tensors="np"
+            )
+
+        input_names = [input.name for input in self.session.get_inputs()]
+        onnx_inputs = {
+            name: processed_inputs[name]
+            for name in input_names
+            if name in processed_inputs
+        }
+
+        outputs = self.session.run(None, onnx_inputs)
+        embeddings = self.embedding_function(outputs)
+
+        return [embedding for embedding in embeddings]
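
To make the embedding_function hook concrete (not part of this commit): onnxruntime's session.run(None, inputs) returns a list of output arrays, and the callable supplied at construction time decides how that list becomes embeddings. A small sketch, assuming (as the outputs[0] functions in the Embeddings hunk above suggest) that the exported jina-clip-v1 models put the pooled (batch, 768) embeddings in their first output:

import numpy as np

def jina_text_embedding_function(outputs: list) -> np.ndarray:
    # Pick the first ONNX output, assumed to hold the pooled embeddings.
    return outputs[0]

# Stand-in for what ort.InferenceSession.run(None, onnx_inputs) would return
# for a batch of two inputs.
fake_outputs = [np.zeros((2, 768), dtype=np.float32)]
embeddings = jina_text_embedding_function(fake_outputs)
assert embeddings.shape == (2, 768)

# GenericONNXEmbedding.__call__ then unpacks this into per-input vectors:
per_input = [e for e in embeddings]
assert len(per_input) == 2 and per_input[0].shape == (768,)
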

View File

@@ -184,31 +184,31 @@ export default function Explore() {
   // model states
-  const { payload: minilmModelState } = useModelState(
-    "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+  const { payload: textModelState } = useModelState(
+    "all-jina-clip-v1-text_model_fp16.onnx",
   );
-  const { payload: minilmTokenizerState } = useModelState(
-    "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+  const { payload: textTokenizerState } = useModelState(
+    "all-jina-clip-v1-tokenizer",
   );
-  const { payload: clipImageModelState } = useModelState(
-    "clip-clip_image_model_vitb32.onnx",
+  const { payload: visionModelState } = useModelState(
+    "all-jina-clip-v1-vision_model_fp16.onnx",
   );
-  const { payload: clipTextModelState } = useModelState(
-    "clip-clip_text_model_vitb32.onnx",
+  const { payload: visionFeatureExtractorState } = useModelState(
+    "all-jina-clip-v1-preprocessor_config.json",
   );

   const allModelsLoaded = useMemo(() => {
     return (
-      minilmModelState === "downloaded" &&
-      minilmTokenizerState === "downloaded" &&
-      clipImageModelState === "downloaded" &&
-      clipTextModelState === "downloaded"
+      textModelState === "downloaded" &&
+      textTokenizerState === "downloaded" &&
+      visionModelState === "downloaded" &&
+      visionFeatureExtractorState === "downloaded"
     );
   }, [
-    minilmModelState,
-    minilmTokenizerState,
-    clipImageModelState,
-    clipTextModelState,
+    textModelState,
+    textTokenizerState,
+    visionModelState,
+    visionFeatureExtractorState,
   ]);

   const renderModelStateIcon = (modelState: ModelState) => {
@@ -225,11 +225,10 @@ export default function Explore() {
   };

   if (
-    config?.semantic_search.enabled &&
-    (!minilmModelState ||
-      !minilmTokenizerState ||
-      !clipImageModelState ||
-      !clipTextModelState)
+    !textModelState ||
+    !textTokenizerState ||
+    !visionModelState ||
+    !visionFeatureExtractorState
   ) {
     return (
       <ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
@@ -252,25 +251,26 @@ export default function Explore() {
         </div>
         <div className="flex w-96 flex-col gap-2 py-5">
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(clipImageModelState)}
-            CLIP image model
+            {renderModelStateIcon(visionModelState)}
+            Vision model
           </div>
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(clipTextModelState)}
-            CLIP text model
+            {renderModelStateIcon(visionFeatureExtractorState)}
+            Vision model feature extractor
           </div>
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(minilmModelState)}
-            MiniLM sentence model
+            {renderModelStateIcon(textModelState)}
+            Text model
           </div>
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(minilmTokenizerState)}
-            MiniLM tokenizer
+            {renderModelStateIcon(textTokenizerState)}
+            Text tokenizer
           </div>
         </div>
-        {(minilmModelState === "error" ||
-          clipImageModelState === "error" ||
-          clipTextModelState === "error") && (
+        {(textModelState === "error" ||
+          textTokenizerState === "error" ||
+          visionModelState === "error" ||
+          visionFeatureExtractorState === "error") && (
           <div className="my-3 max-w-96 text-center text-danger">
             An error has occurred. Check Frigate logs.
           </div>