mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-15 19:42:08 +03:00
Add support for non-complex models for CudaExecutionProvider
This commit is contained in:
parent
9a3fca6b2b
commit
0b8ac5c6ee
@ -3,6 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from frigate.detectors.plugins.onnx import CudaGraphRunner
|
||||
import onnxruntime as ort
|
||||
|
||||
from frigate.detectors.plugins.openvino import OpenVINOModelRunner
|
||||
@ -51,7 +52,7 @@ class ONNXModelRunner(BaseModelRunner):
|
||||
return self.ort.run(None, input)
|
||||
|
||||
|
||||
def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRunner:
|
||||
def get_optimized_runner(model_path: str, device: str, complex_model: bool = True, **kwargs) -> BaseModelRunner:
|
||||
"""Get an optimized runner for the hardware."""
|
||||
if device == "CPU":
|
||||
return ONNXModelRunner(model_path, device, **kwargs)
|
||||
@ -73,4 +74,7 @@ def get_optimized_runner(model_path: str, device: str, **kwargs) -> BaseModelRun
|
||||
provider_options=options,
|
||||
)
|
||||
|
||||
if not complex_model and providers[0] == "CUDAExecutionProvider":
|
||||
return CudaGraphRunner(ort, options[0]["device_id"])
|
||||
|
||||
return ONNXModelRunner(model_path, device, **kwargs)
|
||||
|
||||
@ -151,6 +151,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
device=self.config.device or "GPU",
|
||||
complex_model=False,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user