mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-23 07:14:35 +03:00
Send model type
This commit is contained in:
parent
20c300df95
commit
161ed46c55
@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from frigate.detectors.detector_types import ModelTypeEnum
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
from frigate.util.model import get_ort_providers
|
from frigate.util.model import get_ort_providers
|
||||||
@ -101,6 +103,15 @@ class CudaGraphRunner(BaseModelRunner):
|
|||||||
for more complex models like CLIP or PaddleOCR.
|
for more complex models like CLIP or PaddleOCR.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_complex_model(model_type: str) -> bool:
|
||||||
|
return model_type in [
|
||||||
|
ModelTypeEnum.yolonas.value,
|
||||||
|
EnrichmentModelTypeEnum.paddleocr.value,
|
||||||
|
EnrichmentModelTypeEnum.jina_v1.value,
|
||||||
|
EnrichmentModelTypeEnum.jina_v2.value,
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
|
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
|
||||||
self._session = session
|
self._session = session
|
||||||
self._cuda_device_id = cuda_device_id
|
self._cuda_device_id = cuda_device_id
|
||||||
@ -156,10 +167,14 @@ class CudaGraphRunner(BaseModelRunner):
|
|||||||
class OpenVINOModelRunner(BaseModelRunner):
|
class OpenVINOModelRunner(BaseModelRunner):
|
||||||
"""OpenVINO model runner that handles inference efficiently."""
|
"""OpenVINO model runner that handles inference efficiently."""
|
||||||
|
|
||||||
def __init__(self, model_path: str, device: str, complex_model: bool, **kwargs):
|
@staticmethod
|
||||||
|
def is_complex_model(model_type: str) -> bool:
|
||||||
|
return model_type in [EnrichmentModelTypeEnum.paddleocr.value]
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, device: str, model_type: str, **kwargs):
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.device = device
|
self.device = device
|
||||||
self.complex_model = complex_model
|
self.complex_model = OpenVINOModelRunner.is_complex_model(model_type)
|
||||||
|
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
raise FileNotFoundError(f"OpenVINO model file {model_path} not found.")
|
raise FileNotFoundError(f"OpenVINO model file {model_path} not found.")
|
||||||
@ -381,7 +396,7 @@ class RKNNModelRunner(BaseModelRunner):
|
|||||||
|
|
||||||
|
|
||||||
def get_optimized_runner(
|
def get_optimized_runner(
|
||||||
model_path: str, device: str | None, complex_model: bool = True, **kwargs
|
model_path: str, device: str | None, model_type: str, **kwargs
|
||||||
) -> BaseModelRunner:
|
) -> BaseModelRunner:
|
||||||
"""Get an optimized runner for the hardware."""
|
"""Get an optimized runner for the hardware."""
|
||||||
device = device or "AUTO"
|
device = device or "AUTO"
|
||||||
@ -398,7 +413,7 @@ def get_optimized_runner(
|
|||||||
# In other images we will get CUDA / ROCm which are preferred over OpenVINO
|
# In other images we will get CUDA / ROCm which are preferred over OpenVINO
|
||||||
# There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images
|
# There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images
|
||||||
if device != "CPU" and is_openvino_gpu_npu_available():
|
if device != "CPU" and is_openvino_gpu_npu_available():
|
||||||
return OpenVINOModelRunner(model_path, device, complex_model, **kwargs)
|
return OpenVINOModelRunner(model_path, device, model_type, **kwargs)
|
||||||
|
|
||||||
ortSession = ort.InferenceSession(
|
ortSession = ort.InferenceSession(
|
||||||
model_path,
|
model_path,
|
||||||
@ -406,7 +421,10 @@ def get_optimized_runner(
|
|||||||
provider_options=options,
|
provider_options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not complex_model and providers[0] == "CUDAExecutionProvider":
|
if (
|
||||||
|
not CudaGraphRunner.is_complex_model(model_type)
|
||||||
|
and providers[0] == "CUDAExecutionProvider"
|
||||||
|
):
|
||||||
return CudaGraphRunner(ortSession, options[0]["device_id"])
|
return CudaGraphRunner(ortSession, options[0]["device_id"])
|
||||||
|
|
||||||
return ONNXModelRunner(ortSession)
|
return ONNXModelRunner(ortSession)
|
||||||
|
|||||||
@ -18,11 +18,6 @@ from frigate.util.downloader import ModelDownloader
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingTypeEnum(str, Enum):
|
|
||||||
thumbnail = "thumbnail"
|
|
||||||
description = "description"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbedding(ABC):
|
class BaseEmbedding(ABC):
|
||||||
"""Base embedding class."""
|
"""Base embedding class."""
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import numpy as np
|
|||||||
|
|
||||||
from frigate.const import MODEL_CACHE_DIR
|
from frigate.const import MODEL_CACHE_DIR
|
||||||
from frigate.detectors.detection_runners import get_optimized_runner
|
from frigate.detectors.detection_runners import get_optimized_runner
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
from frigate.log import redirect_output_to_logger
|
from frigate.log import redirect_output_to_logger
|
||||||
from frigate.util.downloader import ModelDownloader
|
from frigate.util.downloader import ModelDownloader
|
||||||
|
|
||||||
@ -151,7 +152,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
device=self.config.device or "GPU",
|
device=self.config.device or "GPU",
|
||||||
complex_model=False,
|
model_type=EnrichmentModelTypeEnum.arcface.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import warnings
|
|||||||
# importing this without pytorch or others causes a warning
|
# importing this without pytorch or others causes a warning
|
||||||
# https://github.com/huggingface/transformers/issues/27214
|
# https://github.com/huggingface/transformers/issues/27214
|
||||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
from transformers import AutoFeatureExtractor, AutoTokenizer
|
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||||
from transformers.utils.logging import disable_progress_bar
|
from transformers.utils.logging import disable_progress_bar
|
||||||
|
|
||||||
@ -128,6 +129,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
|
model_type=EnrichmentModelTypeEnum.jina_v1.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
@ -206,6 +208,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
|
model_type=EnrichmentModelTypeEnum.jina_v1.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
|
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
|
||||||
|
|
||||||
@ -128,6 +129,7 @@ class JinaV2Embedding(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
|
model_type=EnrichmentModelTypeEnum.jina_v2.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:
|
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import numpy as np
|
|||||||
from frigate.comms.inter_process import InterProcessRequestor
|
from frigate.comms.inter_process import InterProcessRequestor
|
||||||
from frigate.const import MODEL_CACHE_DIR
|
from frigate.const import MODEL_CACHE_DIR
|
||||||
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
|
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
from frigate.types import ModelStatusTypesEnum
|
from frigate.types import ModelStatusTypesEnum
|
||||||
from frigate.util.downloader import ModelDownloader
|
from frigate.util.downloader import ModelDownloader
|
||||||
|
|
||||||
@ -79,6 +80,7 @@ class PaddleOCRDetection(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
|
model_type=EnrichmentModelTypeEnum.paddleocr.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
@ -138,6 +140,7 @@ class PaddleOCRClassification(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
|
model_type=EnrichmentModelTypeEnum.paddleocr.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
@ -198,6 +201,7 @@ class PaddleOCRRecognition(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
|
model_type=EnrichmentModelTypeEnum.paddleocr.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
@ -258,7 +262,7 @@ class LicensePlateDetector(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
self.device,
|
||||||
complex_model=False,
|
model_type="yolov9",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
|
|||||||
13
frigate/embeddings/types.py
Normal file
13
frigate/embeddings/types.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingTypeEnum(str, Enum):
|
||||||
|
thumbnail = "thumbnail"
|
||||||
|
description = "description"
|
||||||
|
|
||||||
|
class EnrichmentModelTypeEnum(str, Enum):
|
||||||
|
arcface = "arcface"
|
||||||
|
facenet = "facenet"
|
||||||
|
jina_v1 = "jina_v1"
|
||||||
|
jina_v2 = "jina_v2"
|
||||||
|
paddleocr = "paddleocr"
|
||||||
Loading…
Reference in New Issue
Block a user