add generic onnx model class and use jina ai clip models for all embeddings

commit 8f90f4d954 (parent 74efc94649)
@@ -73,7 +73,7 @@ class EmbeddingsContext:
     def __init__(self, db: SqliteVecQueueDatabase):
         self.embeddings = Embeddings(db)
         self.thumb_stats = ZScoreNormalization()
-        self.desc_stats = ZScoreNormalization(scale_factor=3, bias=-2.5)
+        self.desc_stats = ZScoreNormalization()

         # load stats from disk
         try:
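Note: with a single Jina CLIP model producing both distance distributions, the description scores no longer need the custom scale_factor/bias correction that compensated for MiniLM's different distance range. For context, a minimal sketch of what a running z-score normalizer with these parameters does; the interface below is an assumption for illustration, not Frigate's exact ZScoreNormalization class:

import numpy as np

class ZScoreNormalizationSketch:
    """Hypothetical stand-in for Frigate's ZScoreNormalization."""

    def __init__(self, scale_factor: float = 1.0, bias: float = 0.0):
        self.values: list[float] = []
        self.scale_factor = scale_factor  # stretches the normalized score
        self.bias = bias  # shifts it

    def normalize(self, distances: list[float]) -> list[float]:
        # accumulate raw distances, then map each one to
        # (x - mean) / std * scale_factor + bias
        self.values.extend(distances)
        mean = float(np.mean(self.values))
        std = float(np.std(self.values)) or 1.0
        return [(x - mean) / std * self.scale_factor + self.bias for x in distances]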
@@ -7,6 +7,7 @@ import struct
 import time
 from typing import List, Tuple, Union

+import numpy as np
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
@@ -16,8 +17,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event
 from frigate.types import ModelStatusTypesEnum

-from .functions.clip import ClipEmbedding
-from .functions.minilm_l6_v2 import MiniLMEmbedding
+from .functions.onnx import GenericONNXEmbedding

 logger = logging.getLogger(__name__)
@@ -53,9 +53,23 @@ def get_metadata(event: Event) -> dict:
     )


-def serialize(vector: List[float]) -> bytes:
-    """Serializes a list of floats into a compact "raw bytes" format"""
-    return struct.pack("%sf" % len(vector), *vector)
+def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
+    """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
+    if isinstance(vector, np.ndarray):
+        # Convert numpy array to list of floats
+        vector = vector.flatten().tolist()
+    elif isinstance(vector, (float, np.float32, np.float64)):
+        # Handle single float values
+        vector = [vector]
+    elif not isinstance(vector, list):
+        raise TypeError(
+            f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
+        )
+
+    try:
+        return struct.pack("%sf" % len(vector), *vector)
+    except struct.error as e:
+        raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")


 def deserialize(bytes_data: bytes) -> List[float]:
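Note: the widened serialize() signature lets callers pass the raw numpy output of an ONNX session straight through. A quick round-trip sketch using the function above; deserialize's body is outside this hunk, so the struct.unpack inverse below is an assumption:

import struct

import numpy as np

def deserialize_sketch(bytes_data: bytes) -> list:
    # assumed inverse of serialize: 4 bytes per packed float32
    return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))

embedding = np.random.rand(768).astype(np.float32)  # jina-clip-v1 output size
blob = serialize(embedding)  # numpy arrays are now accepted directly
assert len(blob) == 768 * 4  # compact raw float32 bytes
assert np.allclose(deserialize_sketch(blob), embedding)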
@@ -74,10 +88,10 @@ class Embeddings:
         self._create_tables()

         models = [
-            "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
-            "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
-            "clip-clip_image_model_vitb32.onnx",
-            "clip-clip_text_model_vitb32.onnx",
+            "all-jina-clip-v1-text_model_fp16.onnx",
+            "all-jina-clip-v1-tokenizer",
+            "all-jina-clip-v1-vision_model_fp16.onnx",
+            "all-jina-clip-v1-preprocessor_config.json",
         ]

         for model in models:
@@ -89,11 +103,32 @@ class Embeddings:
                 },
             )

-        self.clip_embedding = ClipEmbedding(
-            preferred_providers=["CPUExecutionProvider"]
-        )
-        self.minilm_embedding = MiniLMEmbedding(
-            preferred_providers=["CPUExecutionProvider"],
-        )
+        def jina_text_embedding_function(outputs):
+            return outputs[0]
+
+        def jina_vision_embedding_function(outputs):
+            return outputs[0]
+
+        self.text_embedding = GenericONNXEmbedding(
+            model_name="all-jina-clip-v1",
+            model_file="text_model_fp16.onnx",
+            tokenizer_file="tokenizer",
+            download_urls={
+                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
+            },
+            embedding_function=jina_text_embedding_function,
+            model_type="text",
+        )
+
+        self.vision_embedding = GenericONNXEmbedding(
+            model_name="all-jina-clip-v1",
+            model_file="vision_model_fp16.onnx",
+            tokenizer_file="preprocessor_config.json",
+            download_urls={
+                "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
+            },
+            embedding_function=jina_vision_embedding_function,
+            model_type="vision",
+        )

     def _create_tables(self):
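Note: both Jina hooks just return outputs[0], the pooled embeddings from each ONNX graph. The embedding_function hook exists so other models can post-process differently; a hedged sketch of a hook for a hypothetical model that emits per-token embeddings (not something jina-clip-v1's exported graphs require):

import numpy as np

def pooled_normalized_embedding_function(outputs: list) -> np.ndarray:
    # hypothetical: first output is per-token embeddings, shape (batch, tokens, dim)
    token_embeddings = outputs[0]
    pooled = token_embeddings.mean(axis=1)  # mean-pool over the token axis
    # L2-normalize so cosine similarity reduces to a dot product
    norms = np.linalg.norm(pooled, axis=1, keepdims=True)
    return pooled / np.clip(norms, 1e-12, None)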
@@ -101,7 +136,7 @@ class Embeddings:
         self.db.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
                 id TEXT PRIMARY KEY,
-                thumbnail_embedding FLOAT[512]
+                thumbnail_embedding FLOAT[768]
             );
             """)

@@ -109,15 +144,14 @@ class Embeddings:
         self.db.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
                 id TEXT PRIMARY KEY,
-                description_embedding FLOAT[384]
+                description_embedding FLOAT[768]
             );
             """)

     def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
         # Convert thumbnail bytes to PIL Image
         image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
-        # Generate embedding using CLIP
-        embedding = self.clip_embedding([image])[0]
+        embedding = self.vision_embedding([image])[0]

         self.db.execute_sql(
             """
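Note: both virtual tables move to FLOAT[768] to match jina-clip-v1, but they are created with IF NOT EXISTS, so databases carrying the old 512/384-dim tables need those tables dropped or migrated before the new embeddings fit. For reference, a KNN lookup against a vec0 table in this schema looks roughly like the sketch below; embeddings and db stand for an Embeddings instance and its database, and the MATCH/ORDER BY distance form is sqlite-vec's documented query shape:

# embed the query image, then ask sqlite-vec for the 10 nearest thumbnails
query_embedding = embeddings.vision_embedding([image])[0]

results = db.execute_sql(
    """
    SELECT id, distance
    FROM vec_thumbnails
    WHERE thumbnail_embedding MATCH ?
    ORDER BY distance
    LIMIT 10
    """,
    [serialize(query_embedding)],
).fetchall()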
@@ -130,8 +164,7 @@ class Embeddings:
         return embedding

     def upsert_description(self, event_id: str, description: str):
-        # Generate embedding using MiniLM
-        embedding = self.minilm_embedding([description])[0]
+        embedding = self.text_embedding([description])[0]

         self.db.execute_sql(
             """
@@ -177,7 +210,7 @@ class Embeddings:
             thumbnail = base64.b64decode(query.thumbnail)
             query_embedding = self.upsert_thumbnail(query.id, thumbnail)
         else:
-            query_embedding = self.clip_embedding([query])[0]
+            query_embedding = self.text_embedding([query])[0]

         sql_query = """
             SELECT
@@ -211,7 +244,7 @@ class Embeddings:
     def search_description(
         self, query_text: str, event_ids: List[str] = None
     ) -> List[Tuple[str, float]]:
-        query_embedding = self.minilm_embedding([query_text])[0]
+        query_embedding = self.text_embedding([query_text])[0]

         # Prepare the base SQL query
         sql_query = """
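Note: the hunk above search_description embeds a text query with text_embedding and compares it against image vectors in vec_thumbnails. That cross-modal comparison only works because jina-clip-v1 aligns its text and vision towers in one shared 768-dim space, which is the point of this commit. A sketch of the idea; embeddings stands for an Embeddings instance and thumbnail_image for any PIL image:

import numpy as np

text_vec = embeddings.text_embedding(["person walking a dog"])[0]
image_vec = embeddings.vision_embedding([thumbnail_image])[0]

# in a shared CLIP space, cosine similarity across modalities is meaningful
cosine = np.dot(text_vec, image_vec) / (
    np.linalg.norm(text_vec) * np.linalg.norm(image_vec)
)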
frigate/embeddings/functions/onnx.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import logging
import os
import warnings
from io import BytesIO
from typing import Callable, Dict, List, Union

import numpy as np
import onnxruntime as ort
import requests
from PIL import Image

# importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from transformers import AutoFeatureExtractor, AutoTokenizer

from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader

warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="The class CLIPFeatureExtractor is deprecated",
)


logger = logging.getLogger(__name__)


class GenericONNXEmbedding:
    """Generic embedding function for ONNX models (text and vision)."""

    def __init__(
        self,
        model_name: str,
        model_file: str,
        tokenizer_file: str,
        download_urls: Dict[str, str],
        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
        model_type: str,
        preferred_providers: List[str] = ["CPUExecutionProvider"],
    ):
        self.model_name = model_name
        self.model_file = model_file
        self.tokenizer_file = tokenizer_file
        self.download_urls = download_urls
        self.embedding_function = embedding_function
        self.model_type = model_type  # 'text' or 'vision'
        self.preferred_providers = preferred_providers

        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.tokenizer = None
        self.feature_extractor = None
        self.session = None

        self.downloader = ModelDownloader(
            model_name=self.model_name,
            download_path=self.download_path,
            file_names=[self.model_file, self.tokenizer_file],
            download_func=self._download_model,
        )
        self.downloader.ensure_model_files()

    def _download_model(self, path: str):
        try:
            file_name = os.path.basename(path)
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)
            elif file_name == self.tokenizer_file:
                logger.info(
                    f"Downloading {self.model_name} tokenizer/feature extractor"
                )
                if self.model_type == "text":
                    tokenizer = AutoTokenizer.from_pretrained(
                        self.model_name, clean_up_tokenization_spaces=True
                    )
                    tokenizer.save_pretrained(path)
                else:
                    feature_extractor = AutoFeatureExtractor.from_pretrained(
                        self.model_name
                    )
                    feature_extractor.save_pretrained(path)

            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            self.downloader.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_tokenizer(self):
        if self.session is None:
            self.downloader.wait_for_download()
            if self.model_type == "text":
                self.tokenizer = self._load_tokenizer()
            else:
                self.feature_extractor = self._load_feature_extractor()
            self.session = self._load_model(
                os.path.join(self.download_path, self.model_file),
                self.preferred_providers,
            )

    def _load_tokenizer(self):
        tokenizer_path = os.path.join(self.download_path, self.tokenizer_file)
        return AutoTokenizer.from_pretrained(
            tokenizer_path, clean_up_tokenization_spaces=True
        )

    def _load_feature_extractor(self):
        feature_extractor_path = os.path.join(self.download_path, self.tokenizer_file)
        return AutoFeatureExtractor.from_pretrained(feature_extractor_path)

    def _load_model(self, path: str, providers: List[str]):
        if os.path.exists(path):
            return ort.InferenceSession(path, providers=providers)
        else:
            logger.warning(f"{self.model_name} model file {path} not found.")
            return None

    def _process_image(self, image):
        if isinstance(image, str):
            if image.startswith("http"):
                response = requests.get(image)
                image = Image.open(BytesIO(response.content)).convert("RGB")

        return image

    def __call__(
        self, inputs: Union[List[str], List[Image.Image]]
    ) -> List[np.ndarray]:
        self._load_model_and_tokenizer()

        if self.session is None or (
            self.tokenizer is None and self.feature_extractor is None
        ):
            logger.error(
                f"{self.model_name} model or tokenizer/feature extractor is not loaded."
            )
            return []

        if self.model_type == "text":
            processed_inputs = self.tokenizer(
                inputs, padding=True, truncation=True, return_tensors="np"
            )
        else:
            processed_images = [self._process_image(img) for img in inputs]
            processed_inputs = self.feature_extractor(
                images=processed_images, return_tensors="np"
            )

        # only feed the tensors the ONNX graph actually declares as inputs
        input_names = [input.name for input in self.session.get_inputs()]
        onnx_inputs = {
            name: processed_inputs[name]
            for name in input_names
            if name in processed_inputs
        }

        outputs = self.session.run(None, onnx_inputs)
        embeddings = self.embedding_function(outputs)

        return [embedding for embedding in embeddings]
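Note: standalone usage mirrors the wiring in the Embeddings.__init__ hunk earlier in this commit; a minimal sketch:

from frigate.embeddings.functions.onnx import GenericONNXEmbedding

text_embedding = GenericONNXEmbedding(
    model_name="all-jina-clip-v1",
    model_file="text_model_fp16.onnx",
    tokenizer_file="tokenizer",
    download_urls={
        "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
    },
    embedding_function=lambda outputs: outputs[0],
    model_type="text",
)

vectors = text_embedding(["a person near the front door"])
print(vectors[0].shape)  # (768,) for jina-clip-v1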
@@ -184,31 +184,31 @@ export default function Explore() {

   // model states

-  const { payload: minilmModelState } = useModelState(
-    "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+  const { payload: textModelState } = useModelState(
+    "all-jina-clip-v1-text_model_fp16.onnx",
   );
-  const { payload: minilmTokenizerState } = useModelState(
-    "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+  const { payload: textTokenizerState } = useModelState(
+    "all-jina-clip-v1-tokenizer",
   );
-  const { payload: clipImageModelState } = useModelState(
-    "clip-clip_image_model_vitb32.onnx",
+  const { payload: visionModelState } = useModelState(
+    "all-jina-clip-v1-vision_model_fp16.onnx",
   );
-  const { payload: clipTextModelState } = useModelState(
-    "clip-clip_text_model_vitb32.onnx",
+  const { payload: visionFeatureExtractorState } = useModelState(
+    "all-jina-clip-v1-preprocessor_config.json",
   );

   const allModelsLoaded = useMemo(() => {
     return (
-      minilmModelState === "downloaded" &&
-      minilmTokenizerState === "downloaded" &&
-      clipImageModelState === "downloaded" &&
-      clipTextModelState === "downloaded"
+      textModelState === "downloaded" &&
+      textTokenizerState === "downloaded" &&
+      visionModelState === "downloaded" &&
+      visionFeatureExtractorState === "downloaded"
     );
   }, [
-    minilmModelState,
-    minilmTokenizerState,
-    clipImageModelState,
-    clipTextModelState,
+    textModelState,
+    textTokenizerState,
+    visionModelState,
+    visionFeatureExtractorState,
   ]);

   const renderModelStateIcon = (modelState: ModelState) => {
@@ -225,10 +225,10 @@ export default function Explore() {
   };

   if (
-    !minilmModelState ||
-    !minilmTokenizerState ||
-    !clipImageModelState ||
-    !clipTextModelState
+    !textModelState ||
+    !textTokenizerState ||
+    !visionModelState ||
+    !visionFeatureExtractorState
   ) {
     return (
       <ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
@@ -251,25 +251,26 @@ export default function Explore() {
         </div>
         <div className="flex w-96 flex-col gap-2 py-5">
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(clipImageModelState)}
-            CLIP image model
+            {renderModelStateIcon(visionModelState)}
+            Vision model
           </div>
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(clipTextModelState)}
-            CLIP text model
+            {renderModelStateIcon(visionFeatureExtractorState)}
+            Vision model feature extractor
           </div>
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(minilmModelState)}
-            MiniLM sentence model
+            {renderModelStateIcon(textModelState)}
+            Text model
           </div>
           <div className="flex flex-row items-center justify-center gap-2">
-            {renderModelStateIcon(minilmTokenizerState)}
-            MiniLM tokenizer
+            {renderModelStateIcon(textTokenizerState)}
+            Text tokenizer
           </div>
         </div>
-        {(minilmModelState === "error" ||
-          clipImageModelState === "error" ||
-          clipTextModelState === "error") && (
+        {(textModelState === "error" ||
+          textTokenizerState === "error" ||
+          visionModelState === "error" ||
+          visionFeatureExtractorState === "error") && (
           <div className="my-3 max-w-96 text-center text-danger">
             An error has occurred. Check Frigate logs.
           </div>