mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-15 11:32:09 +03:00
Add support for non-complex models for CudaExecutionProvider
This commit is contained in:
parent
9a3fca6b2b
commit
0b8ac5c6ee
@ -3,6 +3,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from frigate.detectors.plugins.onnx import CudaGraphRunner
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
from frigate.detectors.plugins.openvino import OpenVINOModelRunner
|
from frigate.detectors.plugins.openvino import OpenVINOModelRunner
|
||||||
@ -51,7 +52,7 @@ class ONNXModelRunner(BaseModelRunner):
|
|||||||
return self.ort.run(None, input)
|
return self.ort.run(None, input)
|
||||||
|
|
||||||
|
|
||||||
def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRunner:
|
def get_optimized_runner(model_path: str, device: str, complex_model: bool = True, **kwargs) -> BaseModelRunner:
|
||||||
"""Get an optimized runner for the hardware."""
|
"""Get an optimized runner for the hardware."""
|
||||||
if device == "CPU":
|
if device == "CPU":
|
||||||
return ONNXModelRunner(model_path, device, **kwargs)
|
return ONNXModelRunner(model_path, device, **kwargs)
|
||||||
@ -73,4 +74,7 @@ def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRun
|
|||||||
provider_options=options,
|
provider_options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not complex_model and providers[0] == "CUDAExecutionProvider":
|
||||||
|
return CudaGraphRunner(ort, options[0]["device_id"])
|
||||||
|
|
||||||
return ONNXModelRunner(model_path, device, **kwargs)
|
return ONNXModelRunner(model_path, device, **kwargs)
|
||||||
|
|||||||
@ -151,6 +151,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
|||||||
self.runner = get_optimized_runner(
|
self.runner = get_optimized_runner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
device=self.config.device or "GPU",
|
device=self.config.device or "GPU",
|
||||||
|
complex_model=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user