mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
Use enum for model types
This commit is contained in:
parent
843cce2349
commit
4f78aa5f8c
@ -95,7 +95,7 @@ class Embeddings:
|
||||
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
|
||||
},
|
||||
model_size=config.model_size,
|
||||
model_type="text",
|
||||
model_type=ModelTypeEnum.text,
|
||||
requestor=self.requestor,
|
||||
device="CPU",
|
||||
)
|
||||
@ -116,7 +116,7 @@ class Embeddings:
|
||||
model_file=model_file,
|
||||
download_urls=download_urls,
|
||||
model_size=config.model_size,
|
||||
model_type="vision",
|
||||
model_type=ModelTypeEnum.vision,
|
||||
requestor=self.requestor,
|
||||
device="GPU" if config.model_size == "large" else "CPU",
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@ -31,6 +32,12 @@ disable_progress_bar()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelTypeEnum(str, Enum):
|
||||
face = "face"
|
||||
vision = "vision"
|
||||
text = "text"
|
||||
|
||||
|
||||
class GenericONNXEmbedding:
|
||||
"""Generic embedding function for ONNX models (text and vision)."""
|
||||
|
||||
@ -88,7 +95,7 @@ class GenericONNXEmbedding:
|
||||
file_name = os.path.basename(path)
|
||||
if file_name in self.download_urls:
|
||||
ModelDownloader.download_from_url(self.download_urls[file_name], path)
|
||||
elif file_name == self.tokenizer_file and self.model_type == "text":
|
||||
elif file_name == self.tokenizer_file and self.model_type == ModelTypeEnum.text:
|
||||
if not os.path.exists(path + "/" + self.model_name):
|
||||
logger.info(f"Downloading {self.model_name} tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@ -119,7 +126,7 @@ class GenericONNXEmbedding:
|
||||
if self.runner is None:
|
||||
if self.downloader:
|
||||
self.downloader.wait_for_download()
|
||||
if self.model_type == "text":
|
||||
if self.model_type == ModelTypeEnum.text:
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
else:
|
||||
self.feature_extractor = self._load_feature_extractor()
|
||||
@ -144,7 +151,7 @@ class GenericONNXEmbedding:
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs: any) -> any:
|
||||
if self.model_type == "text":
|
||||
if self.model_type == ModelTypeEnum.text:
|
||||
max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
|
||||
return [
|
||||
self.tokenizer(
|
||||
@ -156,7 +163,7 @@ class GenericONNXEmbedding:
|
||||
)
|
||||
for text in raw_inputs
|
||||
]
|
||||
elif self.model_type == "image":
|
||||
elif self.model_type == ModelTypeEnum.vision:
|
||||
processed_images = [self._process_image(img) for img in raw_inputs]
|
||||
return [
|
||||
self.feature_extractor(images=image, return_tensors="np")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user