mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-15 19:42:08 +03:00
Cleanup comments and typing
This commit is contained in:
parent
6e5c1f49d9
commit
f306d761f0
@ -29,15 +29,15 @@ class CudaGraphRunner:
|
|||||||
This runner assumes a single tensor input and binds all model outputs.
|
This runner assumes a single tensor input and binds all model outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, session, cuda_device_id: int):
|
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
|
||||||
self._session = session
|
self._session = session
|
||||||
self._cuda_device_id = cuda_device_id
|
self._cuda_device_id = cuda_device_id
|
||||||
self._captured = False
|
self._captured = False
|
||||||
self._io_binding = None
|
self._io_binding: ort.IOBinding | None = None
|
||||||
self._input_name = None
|
self._input_name: str | None = None
|
||||||
self._output_names = None
|
self._output_names: list[str] | None = None
|
||||||
self._input_ortvalue = None
|
self._input_ortvalue: ort.OrtValue | None = None
|
||||||
self._output_ortvalues = None
|
self._output_ortvalues: ort.OrtValue | None = None
|
||||||
|
|
||||||
def run(self, input_name: str, tensor_input: np.ndarray):
|
def run(self, input_name: str, tensor_input: np.ndarray):
|
||||||
tensor_input = np.ascontiguousarray(tensor_input)
|
tensor_input = np.ascontiguousarray(tensor_input)
|
||||||
@ -109,15 +109,14 @@ class ONNXDetector(DetectionApi):
|
|||||||
if self.onnx_model_type == ModelTypeEnum.yolox:
|
if self.onnx_model_type == ModelTypeEnum.yolox:
|
||||||
self.calculate_grids_strides()
|
self.calculate_grids_strides()
|
||||||
|
|
||||||
# Internal CUDA Graphs state
|
|
||||||
self._cuda_device_id = 0
|
self._cuda_device_id = 0
|
||||||
self._cg_runner = None
|
self._cg_runner: CudaGraphRunner | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "CUDAExecutionProvider" in providers:
|
if "CUDAExecutionProvider" in providers:
|
||||||
cuda_idx = providers.index("CUDAExecutionProvider")
|
cuda_idx = providers.index("CUDAExecutionProvider")
|
||||||
self._cuda_device_id = options[cuda_idx].get("device_id", 0)
|
self._cuda_device_id = options[cuda_idx].get("device_id", 0)
|
||||||
# If we enabled CUDA graphs above for supported models, set flag
|
|
||||||
if options[cuda_idx].get("enable_cuda_graph"):
|
if options[cuda_idx].get("enable_cuda_graph"):
|
||||||
self._cg_runner = CudaGraphRunner(self.model, self._cuda_device_id)
|
self._cg_runner = CudaGraphRunner(self.model, self._cuda_device_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -140,7 +139,6 @@ class ONNXDetector(DetectionApi):
|
|||||||
|
|
||||||
model_input_name = self.model.get_inputs()[0].name
|
model_input_name = self.model.get_inputs()[0].name
|
||||||
|
|
||||||
|
|
||||||
if self._cg_runner is not None:
|
if self._cg_runner is not None:
|
||||||
try:
|
try:
|
||||||
# Run using CUDA graphs if available
|
# Run using CUDA graphs if available
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user