Mirror of https://github.com/blakeblackshear/frigate.git
manual minilm onnx inference

parent 654efe6be1
commit 4444d82089

@@ -1,11 +1,107 @@
-"""Embedding function for ONNX MiniLM-L6 model used in Chroma."""
+"""Embedding function for ONNX MiniLM-L6 model."""
 
-from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
+import errno
+import logging
+import os
+from pathlib import Path
+from typing import List
+
+import numpy as np
+import onnxruntime as ort
+import requests
+from transformers import AutoTokenizer
 
 from frigate.const import MODEL_CACHE_DIR
 
 
-class MiniLMEmbedding(ONNXMiniLM_L6_V2):
-    """Override DOWNLOAD_PATH to download to cache directory."""
+class MiniLMEmbedding:
+    """Embedding function for ONNX MiniLM-L6 model."""
 
     DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
+    MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+    IMAGE_MODEL_FILE = "model.onnx"
+    TOKENIZER_FILE = "tokenizer"
+
+    def __init__(self, preferred_providers=None):
+        """Initialize MiniLM Embedding function."""
+        self.tokenizer = self._load_tokenizer()
+
+        model_path = os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE)
+        if not os.path.exists(model_path):
+            self._download_model()
+
+        if preferred_providers is None:
+            preferred_providers = ["CPUExecutionProvider"]
+
+        self.session = self._load_model(model_path)
+
+    def _load_tokenizer(self):
+        """Load the tokenizer from the local path or download it if not available."""
+        tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
+        if os.path.exists(tokenizer_path):
+            return AutoTokenizer.from_pretrained(tokenizer_path)
+        else:
+            return AutoTokenizer.from_pretrained(self.MODEL_NAME)
+
+    def _download_model(self):
+        """Download the ONNX model and tokenizer from a remote source if they don't exist."""
+        logging.info(f"Downloading {self.MODEL_NAME} ONNX model and tokenizer...")
+
+        # Download the tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
+        os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
+        tokenizer.save_pretrained(os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE))
+
+        # Download the ONNX model
+        s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
+        model_path = os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE)
+        self._download_from_url(s3_url, model_path)
+
+        logging.info(f"Model and tokenizer saved to {self.DOWNLOAD_PATH}")
+
+    def _download_from_url(self, url: str, save_path: str):
+        """Download a file from a URL and save it to a specified path."""
+        temporary_filename = Path(save_path).with_name(
+            os.path.basename(save_path) + ".part"
+        )
+        temporary_filename.parent.mkdir(parents=True, exist_ok=True)
+        with requests.get(url, stream=True, allow_redirects=True) as r:
+            # if the content type is HTML, it's not the actual model file
+            if "text/html" in r.headers.get("Content-Type", ""):
+                raise ValueError(
+                    f"Expected an ONNX file but received HTML from the URL: {url}"
+                )
+
+            # Ensure the download is successful
+            r.raise_for_status()
+
+            # Write the model to a temporary file first
+            with open(temporary_filename, "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+        temporary_filename.rename(save_path)
+
+    def _load_model(self, path: str):
+        """Load the ONNX model from a given path."""
+        providers = ["CPUExecutionProvider"]
+        if os.path.exists(path):
+            return ort.InferenceSession(path, providers=providers)
+        else:
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path)
+
+    def __call__(self, texts: List[str]) -> List[np.ndarray]:
+        """Generate embeddings for the given texts."""
+        inputs = self.tokenizer(
+            texts, padding=True, truncation=True, return_tensors="np"
+        )
+
+        input_names = [input.name for input in self.session.get_inputs()]
+        onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}
+
+        # Run inference
+        outputs = self.session.run(None, onnx_inputs)
+
+        embeddings = outputs[0].mean(axis=1)
+
+        return [embedding for embedding in embeddings]
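
For reference, a minimal usage sketch of the class this commit adds. The import path below is a hypothetical placeholder (the diff does not show where the file lives in the tree); on first use the class downloads the tokenizer and model.onnx into MODEL_CACHE_DIR/all-MiniLM-L6-v2.

# Usage sketch, not part of the commit; the module path is hypothetical.
from frigate.embeddings.minilm import MiniLMEmbedding

embedder = MiniLMEmbedding()  # fetches tokenizer + model.onnx on first run
vectors = embedder(["a person at the front door", "a red car in the driveway"])
print(len(vectors), vectors[0].shape)  # 2 (384,) -- one 384-dim vector per text

Note that, as committed, _load_model always builds the session with CPUExecutionProvider, so the preferred_providers argument normalized in __init__ never reaches the session.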
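
One design note on pooling: __call__ reduces with a plain outputs[0].mean(axis=1), which averages padding positions into the embedding whenever a batch mixes text lengths. The upstream sentence-transformers model card pools with the attention mask instead; below is a mask-weighted sketch of that variant, offered as an alternative for comparison, not as what this commit ships.

import numpy as np

def masked_mean_pool(last_hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # Zero out padding positions before averaging: mask is (batch, seq, 1).
    mask = attention_mask[..., None].astype(last_hidden.dtype)
    summed = (last_hidden * mask).sum(axis=1)  # sum over real tokens only
    counts = mask.sum(axis=1).clip(min=1e-9)   # guard against divide-by-zero
    return summed / counts                     # (batch, hidden_size)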