From 4f78aa5f8c59450d46c3b4da5c3af9725735ea2c Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 21 Oct 2024 16:11:56 -0600 Subject: [PATCH] Use enum for model types --- frigate/embeddings/embeddings.py | 4 ++-- frigate/embeddings/functions/onnx.py | 15 +++++++++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 907a17584..388624fdb 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -95,7 +95,7 @@ class Embeddings: "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx", }, model_size=config.model_size, - model_type="text", + model_type=ModelTypeEnum.text, requestor=self.requestor, device="CPU", ) @@ -116,7 +116,7 @@ class Embeddings: model_file=model_file, download_urls=download_urls, model_size=config.model_size, - model_type="vision", + model_type=ModelTypeEnum.vision, requestor=self.requestor, device="GPU" if config.model_size == "large" else "CPU", ) diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index b79f2fceb..66bdfe5e7 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -1,6 +1,7 @@ import logging import os import warnings +from enum import Enum from io import BytesIO from typing import Dict, List, Optional, Union @@ -31,6 +32,12 @@ disable_progress_bar() logger = logging.getLogger(__name__) +class ModelTypeEnum(str, Enum): + face = "face" + vision = "vision" + text = "text" + + class GenericONNXEmbedding: """Generic embedding function for ONNX models (text and vision).""" @@ -88,7 +95,7 @@ class GenericONNXEmbedding: file_name = os.path.basename(path) if file_name in self.download_urls: ModelDownloader.download_from_url(self.download_urls[file_name], path) - elif file_name == self.tokenizer_file and self.model_type == "text": + elif file_name == self.tokenizer_file and self.model_type == ModelTypeEnum.text: if not os.path.exists(path + "/" + self.model_name): logger.info(f"Downloading {self.model_name} tokenizer") tokenizer = AutoTokenizer.from_pretrained( @@ -119,7 +126,7 @@ class GenericONNXEmbedding: if self.runner is None: if self.downloader: self.downloader.wait_for_download() - if self.model_type == "text": + if self.model_type == ModelTypeEnum.text: self.tokenizer = self._load_tokenizer() else: self.feature_extractor = self._load_feature_extractor() @@ -144,7 +151,7 @@ class GenericONNXEmbedding: ) def _preprocess_inputs(self, raw_inputs: any) -> any: - if self.model_type == "text": + if self.model_type == ModelTypeEnum.text: max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs) return [ self.tokenizer( @@ -156,7 +163,7 @@ class GenericONNXEmbedding: ) for text in raw_inputs ] - elif self.model_type == "image": + elif self.model_type == ModelTypeEnum.vision: processed_images = [self._process_image(img) for img in raw_inputs] return [ self.feature_extractor(images=image, return_tensors="np")