Add support for non-complex models for CUDAExecutionProvider

This commit is contained in:
Nicolas Mowen 2025-09-13 20:13:41 -06:00
parent 9a3fca6b2b
commit 0b8ac5c6ee
2 changed files with 6 additions and 1 deletion

View File

@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from typing import Any
from frigate.detectors.plugins.onnx import CudaGraphRunner
import onnxruntime as ort
from frigate.detectors.plugins.openvino import OpenVINOModelRunner
@@ -51,7 +52,7 @@ class ONNXModelRunner(BaseModelRunner):
return self.ort.run(None, input)
def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRunner:
def get_optimized_runner(model_path: str, device: str, complex_model: bool = True, **kwargs) -> BaseModelRunner:
"""Get an optimized runner for the hardware."""
if device == "CPU":
return ONNXModelRunner(model_path, device, **kwargs)
@@ -73,4 +74,7 @@ def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRun
provider_options=options,
)
if not complex_model and providers[0] == "CUDAExecutionProvider":
return CudaGraphRunner(ort, options[0]["device_id"])
return ONNXModelRunner(model_path, device, **kwargs)

View File

@@ -151,6 +151,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
device=self.config.device or "GPU",
complex_model=False,
)
def _preprocess_inputs(self, raw_inputs):