mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-16 12:02:09 +03:00
Send model type
This commit is contained in:
parent
20c300df95
commit
161ed46c55
@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from frigate.detectors.detector_types import ModelTypeEnum
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
import onnxruntime as ort
|
||||
|
||||
from frigate.util.model import get_ort_providers
|
||||
@ -101,6 +103,15 @@ class CudaGraphRunner(BaseModelRunner):
|
||||
for more complex models like CLIP or PaddleOCR.
|
||||
"""
|
||||
|
||||
@staticmethod
def is_complex_model(model_type: str) -> bool:
    """Return whether ``model_type`` is listed as a complex model.

    Args:
        model_type: Model type identifier; it is compared against the
            ``.value`` of the enum members listed below.

    Returns:
        True when ``model_type`` is the YOLO-NAS detector type or one of
        the PaddleOCR / Jina v1 / Jina v2 enrichment types, False otherwise.
    """
    return model_type in (
        ModelTypeEnum.yolonas.value,
        EnrichmentModelTypeEnum.paddleocr.value,
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
    )
|
||||
|
||||
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
|
||||
self._session = session
|
||||
self._cuda_device_id = cuda_device_id
|
||||
@ -156,10 +167,14 @@ class CudaGraphRunner(BaseModelRunner):
|
||||
class OpenVINOModelRunner(BaseModelRunner):
|
||||
"""OpenVINO model runner that handles inference efficiently."""
|
||||
|
||||
def __init__(self, model_path: str, device: str, complex_model: bool, **kwargs):
|
||||
@staticmethod
def is_complex_model(model_type: str) -> bool:
    """Return whether ``model_type`` is listed as a complex model.

    Args:
        model_type: Model type identifier; it is compared against the
            ``.value`` of the enum member listed below.

    Returns:
        True only for the PaddleOCR enrichment type, False otherwise.
    """
    return model_type in (EnrichmentModelTypeEnum.paddleocr.value,)
|
||||
|
||||
def __init__(self, model_path: str, device: str, model_type: str, **kwargs):
|
||||
self.model_path = model_path
|
||||
self.device = device
|
||||
self.complex_model = complex_model
|
||||
self.complex_model = OpenVINOModelRunner.is_complex_model(model_type)
|
||||
|
||||
if not os.path.isfile(model_path):
|
||||
raise FileNotFoundError(f"OpenVINO model file {model_path} not found.")
|
||||
@ -381,7 +396,7 @@ class RKNNModelRunner(BaseModelRunner):
|
||||
|
||||
|
||||
def get_optimized_runner(
|
||||
model_path: str, device: str | None, complex_model: bool = True, **kwargs
|
||||
model_path: str, device: str | None, model_type: str, **kwargs
|
||||
) -> BaseModelRunner:
|
||||
"""Get an optimized runner for the hardware."""
|
||||
device = device or "AUTO"
|
||||
@ -398,7 +413,7 @@ def get_optimized_runner(
|
||||
# In other images we will get CUDA / ROCm which are preferred over OpenVINO
|
||||
# There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images
|
||||
if device != "CPU" and is_openvino_gpu_npu_available():
|
||||
return OpenVINOModelRunner(model_path, device, complex_model, **kwargs)
|
||||
return OpenVINOModelRunner(model_path, device, model_type, **kwargs)
|
||||
|
||||
ortSession = ort.InferenceSession(
|
||||
model_path,
|
||||
@ -406,7 +421,10 @@ def get_optimized_runner(
|
||||
provider_options=options,
|
||||
)
|
||||
|
||||
if not complex_model and providers[0] == "CUDAExecutionProvider":
|
||||
if (
|
||||
not CudaGraphRunner.is_complex_model(model_type)
|
||||
and providers[0] == "CUDAExecutionProvider"
|
||||
):
|
||||
return CudaGraphRunner(ortSession, options[0]["device_id"])
|
||||
|
||||
return ONNXModelRunner(ortSession)
|
||||
|
||||
@ -18,11 +18,6 @@ from frigate.util.downloader import ModelDownloader
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingTypeEnum(str, Enum):
|
||||
thumbnail = "thumbnail"
|
||||
description = "description"
|
||||
|
||||
|
||||
class BaseEmbedding(ABC):
|
||||
"""Base embedding class."""
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ import numpy as np
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.detectors.detection_runners import get_optimized_runner
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
from frigate.log import redirect_output_to_logger
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
@ -151,7 +152,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
device=self.config.device or "GPU",
|
||||
complex_model=False,
|
||||
model_type=EnrichmentModelTypeEnum.arcface.value,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
||||
@ -7,6 +7,7 @@ import warnings
|
||||
# importing transformers without pytorch or others causes a warning
|
||||
# https://github.com/huggingface/transformers/issues/27214
|
||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||
from transformers.utils.logging import disable_progress_bar
|
||||
|
||||
@ -128,6 +129,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
model_type=EnrichmentModelTypeEnum.jina_v1.value,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -206,6 +208,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
model_type=EnrichmentModelTypeEnum.jina_v1.value,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
||||
@ -6,6 +6,7 @@ import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
|
||||
|
||||
@ -128,6 +129,7 @@ class JinaV2Embedding(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
model_type=EnrichmentModelTypeEnum.jina_v2.value,
|
||||
)
|
||||
|
||||
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:
|
||||
|
||||
@ -8,6 +8,7 @@ import numpy as np
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
@ -79,6 +80,7 @@ class PaddleOCRDetection(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
model_type=EnrichmentModelTypeEnum.paddleocr.value,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -138,6 +140,7 @@ class PaddleOCRClassification(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
model_type=EnrichmentModelTypeEnum.paddleocr.value,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -198,6 +201,7 @@ class PaddleOCRRecognition(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
model_type=EnrichmentModelTypeEnum.paddleocr.value,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -258,7 +262,7 @@ class LicensePlateDetector(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
complex_model=False,
|
||||
model_type="yolov9",
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
||||
13
frigate/embeddings/types.py
Normal file
13
frigate/embeddings/types.py
Normal file
@ -0,0 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EmbeddingTypeEnum(str, Enum):
    """String-valued enum with two embedding kinds.

    Members also subclass ``str``, so each compares equal to its value.
    """

    thumbnail = "thumbnail"
    description = "description"
|
||||
|
||||
class EnrichmentModelTypeEnum(str, Enum):
    """String-valued enum of enrichment model type identifiers.

    Members also subclass ``str``, so each compares equal to its value.
    """

    arcface = "arcface"
    facenet = "facenet"
    jina_v1 = "jina_v1"
    jina_v2 = "jina_v2"
    paddleocr = "paddleocr"
|
||||
Loading…
Reference in New Issue
Block a user