mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
Use enum for model types
This commit is contained in:
parent
843cce2349
commit
4f78aa5f8c
@ -95,7 +95,7 @@ class Embeddings:
|
|||||||
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
|
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
|
||||||
},
|
},
|
||||||
model_size=config.model_size,
|
model_size=config.model_size,
|
||||||
model_type="text",
|
model_type=ModelTypeEnum.text,
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="CPU",
|
device="CPU",
|
||||||
)
|
)
|
||||||
@ -116,7 +116,7 @@ class Embeddings:
|
|||||||
model_file=model_file,
|
model_file=model_file,
|
||||||
download_urls=download_urls,
|
download_urls=download_urls,
|
||||||
model_size=config.model_size,
|
model_size=config.model_size,
|
||||||
model_type="vision",
|
model_type=ModelTypeEnum.vision,
|
||||||
requestor=self.requestor,
|
requestor=self.requestor,
|
||||||
device="GPU" if config.model_size == "large" else "CPU",
|
device="GPU" if config.model_size == "large" else "CPU",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
@ -31,6 +32,12 @@ disable_progress_bar()
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTypeEnum(str, Enum):
|
||||||
|
face = "face"
|
||||||
|
vision = "vision"
|
||||||
|
text = "text"
|
||||||
|
|
||||||
|
|
||||||
class GenericONNXEmbedding:
|
class GenericONNXEmbedding:
|
||||||
"""Generic embedding function for ONNX models (text and vision)."""
|
"""Generic embedding function for ONNX models (text and vision)."""
|
||||||
|
|
||||||
@ -88,7 +95,7 @@ class GenericONNXEmbedding:
|
|||||||
file_name = os.path.basename(path)
|
file_name = os.path.basename(path)
|
||||||
if file_name in self.download_urls:
|
if file_name in self.download_urls:
|
||||||
ModelDownloader.download_from_url(self.download_urls[file_name], path)
|
ModelDownloader.download_from_url(self.download_urls[file_name], path)
|
||||||
elif file_name == self.tokenizer_file and self.model_type == "text":
|
elif file_name == self.tokenizer_file and self.model_type == ModelTypeEnum.text:
|
||||||
if not os.path.exists(path + "/" + self.model_name):
|
if not os.path.exists(path + "/" + self.model_name):
|
||||||
logger.info(f"Downloading {self.model_name} tokenizer")
|
logger.info(f"Downloading {self.model_name} tokenizer")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
@ -119,7 +126,7 @@ class GenericONNXEmbedding:
|
|||||||
if self.runner is None:
|
if self.runner is None:
|
||||||
if self.downloader:
|
if self.downloader:
|
||||||
self.downloader.wait_for_download()
|
self.downloader.wait_for_download()
|
||||||
if self.model_type == "text":
|
if self.model_type == ModelTypeEnum.text:
|
||||||
self.tokenizer = self._load_tokenizer()
|
self.tokenizer = self._load_tokenizer()
|
||||||
else:
|
else:
|
||||||
self.feature_extractor = self._load_feature_extractor()
|
self.feature_extractor = self._load_feature_extractor()
|
||||||
@ -144,7 +151,7 @@ class GenericONNXEmbedding:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs: any) -> any:
|
def _preprocess_inputs(self, raw_inputs: any) -> any:
|
||||||
if self.model_type == "text":
|
if self.model_type == ModelTypeEnum.text:
|
||||||
max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
|
max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
|
||||||
return [
|
return [
|
||||||
self.tokenizer(
|
self.tokenizer(
|
||||||
@ -156,7 +163,7 @@ class GenericONNXEmbedding:
|
|||||||
)
|
)
|
||||||
for text in raw_inputs
|
for text in raw_inputs
|
||||||
]
|
]
|
||||||
elif self.model_type == "image":
|
elif self.model_type == ModelTypeEnum.vision:
|
||||||
processed_images = [self._process_image(img) for img in raw_inputs]
|
processed_images = [self._process_image(img) for img in raw_inputs]
|
||||||
return [
|
return [
|
||||||
self.feature_extractor(images=image, return_tensors="np")
|
self.feature_extractor(images=image, return_tensors="np")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user