diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index e9cb1c8b7..b5f15f391 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -50,7 +50,9 @@ class GenericONNXEmbedding: self.download_urls = download_urls self.embedding_function = embedding_function self.model_type = model_type # 'text' or 'vision' - self.providers, self.provider_options = get_ort_providers(force_cpu=force_cpu) + self.providers, self.provider_options = get_ort_providers( + force_cpu=force_cpu, requires_fp16=True + ) self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) self.tokenizer = None diff --git a/frigate/util/model.py b/frigate/util/model.py index 6716b2405..fabade387 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -6,7 +6,7 @@ import onnxruntime as ort def get_ort_providers( - force_cpu: bool = False, openvino_device: str = "AUTO" + force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False ) -> tuple[list[str], list[dict[str, any]]]: if force_cpu: return (["CPUExecutionProvider"], [{}]) @@ -17,14 +17,19 @@ def get_ort_providers( for provider in providers: if provider == "TensorrtExecutionProvider": os.makedirs("/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True) - options.append( - { - "trt_timing_cache_enable": True, - "trt_engine_cache_enable": True, - "trt_timing_cache_path": "/config/model_cache/tensorrt/ort", - "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines", - } - ) + + if not requires_fp16 or os.environ.get("USE_FP_16", "True") != "False": + options.append( + { + "trt_fp16_enable": requires_fp16, + "trt_timing_cache_enable": True, + "trt_engine_cache_enable": True, + "trt_timing_cache_path": "/config/model_cache/tensorrt/ort", + "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines", + } + ) + else: + options.append({}) elif provider == "OpenVINOExecutionProvider": os.makedirs("/config/model_cache/openvino/ort", exist_ok=True) options.append(