Use openvino directly for onnx embeddings if available

Nicolas Mowen 2024-10-11 09:53:43 -06:00
parent cf7dac63a1
commit 1215b598d5
2 changed files with 72 additions and 23 deletions


@@ -5,7 +5,6 @@ from io import BytesIO
 from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
-import onnxruntime as ort
 import requests
 from PIL import Image
 
@@ -19,7 +18,7 @@ from frigate.comms.inter_process import InterProcessRequestor
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
-from frigate.util.model import get_ort_providers
+from frigate.util.model import ONNXModelRunner
 
 warnings.filterwarnings(
     "ignore",
@@ -54,16 +53,12 @@ class GenericONNXEmbedding:
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
-        self.providers, self.provider_options = get_ort_providers(
-            force_cpu=device == "CPU",
-            requires_fp16=model_size == "large" or self.model_type == "text",
-            openvino_device=device,
-        )
+        self.model_size = model_size
+        self.device = device
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
         self.tokenizer = None
         self.feature_extractor = None
-        self.session = None
+        self.runner = None
         files_names = list(self.download_urls.keys()) + (
             [self.tokenizer_file] if self.tokenizer_file else []
         )
@@ -124,15 +119,17 @@ class GenericONNXEmbedding:
         )
 
     def _load_model_and_tokenizer(self):
-        if self.session is None:
+        if self.runner is None:
             if self.downloader:
                 self.downloader.wait_for_download()
             if self.model_type == "text":
                 self.tokenizer = self._load_tokenizer()
             else:
                 self.feature_extractor = self._load_feature_extractor()
-            self.session = self._load_model(
-                os.path.join(self.download_path, self.model_file)
+            self.runner = ONNXModelRunner(
+                os.path.join(self.download_path, self.model_file),
+                self.device,
+                self.model_size,
             )
 
     def _load_tokenizer(self):
@@ -149,14 +146,6 @@ class GenericONNXEmbedding:
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
-    def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
-        if os.path.exists(path):
-            return ort.InferenceSession(
-                path, providers=self.providers, provider_options=self.provider_options
-            )
-        else:
-            return None
-
     def _process_image(self, image):
         if isinstance(image, str):
             if image.startswith("http"):
@@ -170,7 +159,7 @@ class GenericONNXEmbedding:
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
 
-        if self.session is None or (
+        if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
             logger.error(
@@ -188,14 +177,14 @@ class GenericONNXEmbedding:
                 images=processed_images, return_tensors="np"
             )
 
-        input_names = [input.name for input in self.session.get_inputs()]
+        input_names = self.runner.get_input_names()
         onnx_inputs = {
            name: processed_inputs[name]
            for name in input_names
            if name in processed_inputs
        }
 
-        outputs = self.session.run(None, onnx_inputs)
+        outputs = self.runner.run(onnx_inputs)
         embeddings = self.embedding_function(outputs)
 
         return [embedding for embedding in embeddings]
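
In short, the embedding class no longer builds an ort.InferenceSession itself; it hands the model path, device, and model size to the new runner and works with whatever backend that picks. A minimal sketch of that flow (the model path, device, and tokenizer output below are illustrative stand-ins; only ONNXModelRunner and its methods come from this commit):

import numpy as np

from frigate.util.model import ONNXModelRunner

# Hypothetical text-model file; the real path is assembled by GenericONNXEmbedding.
runner = ONNXModelRunner("/config/model_cache/example/model.onnx", device="GPU")

# Stand-in for tokenizer output produced with return_tensors="np".
processed_inputs = {
    "input_ids": np.zeros((1, 16), dtype=np.int64),
    "attention_mask": np.ones((1, 16), dtype=np.int64),
}

# Same filtering the commit adds to GenericONNXEmbedding.__call__:
# only feed the inputs the model actually declares.
input_names = runner.get_input_names()
onnx_inputs = {k: v for k, v in processed_inputs.items() if k in input_names}
outputs = runner.run(onnx_inputs)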


@@ -1,9 +1,16 @@
 """Model Utils"""
 
 import os
+from typing import Any
 
 import onnxruntime as ort
 
+try:
+    import openvino as ov
+except ImportError:
+    # openvino is not included
+    pass
+
 
 def get_ort_providers(
     force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
@@ -42,3 +49,56 @@ def get_ort_providers(
         options.append({})
 
     return (providers, options)
+
+
+class ONNXModelRunner:
+    """Run onnx models optimally based on available hardware."""
+
+    def __init__(self, model_path: str, device: str, requires_fp16: bool = False):
+        self.model_path = model_path
+        self.ort: ort.InferenceSession = None
+        self.ov: ov.Core = None
+        providers, options = get_ort_providers(device == "CPU", device, requires_fp16)
+
+        if "OpenVINOExecutionProvider" in providers:
+            # use OpenVINO directly
+            self.type = "ov"
+            self.ov = ov.Core()
+            self.ov.set_property(
+                {ov.properties.cache_dir: "/config/model_cache/openvino"}
+            )
+            self.interpreter = self.ov.compile_model(
+                model=model_path, device_name=device
+            )
+        else:
+            # Use ONNXRuntime
+            self.type = "ort"
+            self.ort = ort.InferenceSession(
+                model_path, providers=providers, provider_options=options
+            )
+
+    def get_input_names(self) -> list[str]:
+        if self.type == "ov":
+            input_names = []
+
+            for input in self.interpreter.inputs:
+                input_names.extend(input.names)
+
+            return input_names
+        elif self.type == "ort":
+            return [input.name for input in self.ort.get_inputs()]
+
+    def run(self, input: dict[str, Any]) -> Any:
+        if self.type == "ov":
+            infer_request = self.interpreter.create_infer_request()
+            input_tensor = list(input.values())
+
+            if len(input_tensor) == 1:
+                input_tensor = ov.Tensor(array=input_tensor[0])
+            else:
+                input_tensor = ov.Tensor(array=input_tensor)
+
+            infer_request.infer(input_tensor)
+
+            return [infer_request.get_output_tensor().data]
+        elif self.type == "ort":
+            return self.ort.run(None, input)
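
Backend selection happens entirely in the constructor: when get_ort_providers reports the OpenVINOExecutionProvider, the model is compiled with OpenVINO (with its cache under /config/model_cache/openvino); otherwise a plain ONNX Runtime session is created. Note the two run() branches are not identical: the OpenVINO path passes the input arrays positionally and returns a one-element list with the first output tensor, while the ONNX Runtime path returns every model output. A rough usage sketch, with a made-up model path and input name:

import numpy as np

# Hypothetical vision model; device "AUTO" lets get_ort_providers decide the backend.
runner = ONNXModelRunner("/config/model_cache/example/vision.onnx", device="AUTO")
print(runner.type)  # "ov" when the OpenVINO provider is available, otherwise "ort"

# Callers only ever deal with a name-to-array dict, regardless of backend.
outputs = runner.run({"pixel_values": np.zeros((1, 3, 224, 224), dtype=np.float32)})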