mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 23:55:25 +03:00
remove old clip and minilm code
This commit is contained in:
parent
006af73844
commit
11c1a37e02
@ -1,166 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from onnx_clip import OnnxClip, Preprocessor, Tokenizer
|
||||
from PIL import Image
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Clip(OnnxClip):
|
||||
"""Override load models to use pre-downloaded models from cache directory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ViT-B/32",
|
||||
batch_size: Optional[int] = None,
|
||||
providers: List[str] = ["CPUExecutionProvider"],
|
||||
):
|
||||
"""
|
||||
Instantiates the model and required encoding classes.
|
||||
|
||||
Args:
|
||||
model: The model to utilize. Currently ViT-B/32 and RN50 are
|
||||
allowed.
|
||||
batch_size: If set, splits the lists in `get_image_embeddings`
|
||||
and `get_text_embeddings` into batches of this size before
|
||||
passing them to the model. The embeddings are then concatenated
|
||||
back together before being returned. This is necessary when
|
||||
passing large amounts of data (perhaps ~100 or more).
|
||||
"""
|
||||
allowed_models = ["ViT-B/32", "RN50"]
|
||||
if model not in allowed_models:
|
||||
raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")
|
||||
if model == "ViT-B/32":
|
||||
self.embedding_size = 512
|
||||
elif model == "RN50":
|
||||
self.embedding_size = 1024
|
||||
self.image_model, self.text_model = self._load_models(model, providers)
|
||||
self._tokenizer = Tokenizer()
|
||||
self._preprocessor = Preprocessor()
|
||||
self._batch_size = batch_size
|
||||
|
||||
@staticmethod
|
||||
def _load_models(
|
||||
model: str,
|
||||
providers: List[str],
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""
|
||||
Load models from cache directory.
|
||||
"""
|
||||
if model == "ViT-B/32":
|
||||
IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
|
||||
TEXT_MODEL_FILE = "clip_text_model_vitb32.onnx"
|
||||
elif model == "RN50":
|
||||
IMAGE_MODEL_FILE = "clip_image_model_rn50.onnx"
|
||||
TEXT_MODEL_FILE = "clip_text_model_rn50.onnx"
|
||||
else:
|
||||
raise ValueError(f"Unexpected model {model}. No `.onnx` file found.")
|
||||
|
||||
models = []
|
||||
for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
|
||||
path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
|
||||
models.append(Clip._load_model(path, providers))
|
||||
|
||||
return models[0], models[1]
|
||||
|
||||
@staticmethod
|
||||
def _load_model(path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"CLIP model file {path} not found.")
|
||||
return None
|
||||
|
||||
|
||||
class ClipEmbedding:
|
||||
"""Embedding function for CLIP model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ViT-B/32",
|
||||
silent: bool = False,
|
||||
preferred_providers: List[str] = ["CPUExecutionProvider"],
|
||||
):
|
||||
self.model_name = model
|
||||
self.silent = silent
|
||||
self.preferred_providers = preferred_providers
|
||||
self.model_files = self._get_model_files()
|
||||
self.model = None
|
||||
|
||||
self.downloader = ModelDownloader(
|
||||
model_name="clip",
|
||||
download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
|
||||
file_names=self.model_files,
|
||||
download_func=self._download_model,
|
||||
silent=self.silent,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _get_model_files(self):
|
||||
if self.model_name == "ViT-B/32":
|
||||
return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
|
||||
elif self.model_name == "RN50":
|
||||
return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected model {self.model_name}. No `.onnx` file found."
|
||||
)
|
||||
|
||||
def _download_model(self, path: str):
|
||||
s3_url = (
|
||||
f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
|
||||
)
|
||||
try:
|
||||
ModelDownloader.download_from_url(s3_url, path, self.silent)
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model(self):
|
||||
if self.model is None:
|
||||
self.downloader.wait_for_download()
|
||||
self.model = Clip(self.model_name, providers=self.preferred_providers)
|
||||
|
||||
def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
|
||||
self._load_model()
|
||||
if (
|
||||
self.model is None
|
||||
or self.model.image_model is None
|
||||
or self.model.text_model is None
|
||||
):
|
||||
logger.info(
|
||||
"CLIP model is not fully loaded. Please wait for the download to complete."
|
||||
)
|
||||
return []
|
||||
|
||||
embeddings = []
|
||||
for item in input:
|
||||
if isinstance(item, Image.Image):
|
||||
result = self.model.get_image_embeddings([item])
|
||||
embeddings.append(result[0])
|
||||
elif isinstance(item, str):
|
||||
result = self.model.get_text_embeddings([item])
|
||||
embeddings.append(result[0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(item)}")
|
||||
return embeddings
|
||||
@ -1,107 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
# importing this without pytorch or others causes a warning
|
||||
# https://github.com/huggingface/transformers/issues/27214
|
||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MiniLMEmbedding:
|
||||
"""Embedding function for ONNX MiniLM-L6 model."""
|
||||
|
||||
DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
|
||||
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
IMAGE_MODEL_FILE = "model.onnx"
|
||||
TOKENIZER_FILE = "tokenizer"
|
||||
|
||||
def __init__(self, preferred_providers=["CPUExecutionProvider"]):
|
||||
self.preferred_providers = preferred_providers
|
||||
self.tokenizer = None
|
||||
self.session = None
|
||||
|
||||
self.downloader = ModelDownloader(
|
||||
model_name=self.MODEL_NAME,
|
||||
download_path=self.DOWNLOAD_PATH,
|
||||
file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
|
||||
download_func=self._download_model,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _download_model(self, path: str):
|
||||
try:
|
||||
if os.path.basename(path) == self.IMAGE_MODEL_FILE:
|
||||
s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
|
||||
ModelDownloader.download_from_url(s3_url, path)
|
||||
elif os.path.basename(path) == self.TOKENIZER_FILE:
|
||||
logger.info("Downloading MiniLM tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.MODEL_NAME, clean_up_tokenization_spaces=True
|
||||
)
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model_and_tokenizer(self):
|
||||
if self.tokenizer is None or self.session is None:
|
||||
self.downloader.wait_for_download()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
self.session = self._load_model(
|
||||
os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
|
||||
self.preferred_providers,
|
||||
)
|
||||
|
||||
def _load_tokenizer(self):
|
||||
tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
|
||||
return AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
def _load_model(self, path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"MiniLM model file {path} not found.")
|
||||
return None
|
||||
|
||||
def __call__(self, texts: List[str]) -> List[np.ndarray]:
|
||||
self._load_model_and_tokenizer()
|
||||
|
||||
if self.session is None or self.tokenizer is None:
|
||||
logger.error("MiniLM model or tokenizer is not loaded.")
|
||||
return []
|
||||
|
||||
inputs = self.tokenizer(
|
||||
texts, padding=True, truncation=True, return_tensors="np"
|
||||
)
|
||||
input_names = [input.name for input in self.session.get_inputs()]
|
||||
onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}
|
||||
|
||||
outputs = self.session.run(None, onnx_inputs)
|
||||
embeddings = outputs[0].mean(axis=1)
|
||||
|
||||
return [embedding for embedding in embeddings]
|
||||
Loading…
Reference in New Issue
Block a user