Add support for non-complex models for CUDAExecutionProvider

This commit is contained in:
Nicolas Mowen 2025-09-13 20:13:41 -06:00
parent 9a3fca6b2b
commit 0b8ac5c6ee
2 changed files with 6 additions and 1 deletion

View File

@ -3,6 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from frigate.detectors.plugins.onnx import CudaGraphRunner
import onnxruntime as ort import onnxruntime as ort
from frigate.detectors.plugins.openvino import OpenVINOModelRunner from frigate.detectors.plugins.openvino import OpenVINOModelRunner
@ -51,7 +52,7 @@ class ONNXModelRunner(BaseModelRunner):
return self.ort.run(None, input) return self.ort.run(None, input)
def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRunner: def get_optimized_runner(model_path: str, device: str, complex_model: bool = True, **kwargs) -> BaseModelRunner:
"""Get an optimized runner for the hardware.""" """Get an optimized runner for the hardware."""
if device == "CPU": if device == "CPU":
return ONNXModelRunner(model_path, device, **kwargs) return ONNXModelRunner(model_path, device, **kwargs)
@ -73,4 +74,7 @@ def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRun
provider_options=options, provider_options=options,
) )
if not complex_model and providers[0] == "CUDAExecutionProvider":
return CudaGraphRunner(ort, options[0]["device_id"])
return ONNXModelRunner(model_path, device, **kwargs) return ONNXModelRunner(model_path, device, **kwargs)

View File

@ -151,6 +151,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
device=self.config.device or "GPU", device=self.config.device or "GPU",
complex_model=False,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):