mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-11 05:35:25 +03:00
Add support for TensorRT v10 (multiple API calls have changed)
This commit is contained in:
parent
1c9626ecff
commit
9a642086f9
@ -6,6 +6,7 @@ import numpy as np
|
||||
try:
|
||||
import tensorrt as trt
|
||||
from cuda import cuda
|
||||
TRT_VERSION=int(trt.__version__[0:trt.__version__.find(".")])
|
||||
|
||||
TRT_SUPPORT = True
|
||||
except ModuleNotFoundError:
|
||||
@ -91,6 +92,7 @@ class TensorRtDetector(DetectionApi):
|
||||
def _get_input_shape(self):
|
||||
"""Get input shape of the TensorRT YOLO engine."""
|
||||
binding = self.engine[0]
|
||||
if TRT_VERSION < 10:
|
||||
assert self.engine.binding_is_input(binding)
|
||||
binding_dims = self.engine.get_binding_shape(binding)
|
||||
if len(binding_dims) == 4:
|
||||
@ -107,6 +109,23 @@ class TensorRtDetector(DetectionApi):
|
||||
raise ValueError(
|
||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
||||
)
|
||||
else:
|
||||
assert binding == "input"
|
||||
binding_dims = self.engine.get_tensor_shape("input")
|
||||
if len(binding_dims) == 4:
|
||||
return (
|
||||
tuple(binding_dims[2:]),
|
||||
trt.nptype(self.engine.get_tensor_dtype(binding)),
|
||||
)
|
||||
elif len(binding_dims) == 3:
|
||||
return (
|
||||
tuple(binding_dims[1:]),
|
||||
trt.nptype(self.engine.get_tensor_dtype(binding)),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
||||
)
|
||||
|
||||
def _allocate_buffers(self):
|
||||
"""Allocates all host/device in/out buffers required for an engine."""
|
||||
@ -115,6 +134,7 @@ class TensorRtDetector(DetectionApi):
|
||||
bindings = []
|
||||
output_idx = 0
|
||||
for binding in self.engine:
|
||||
if TRT_VERSION < 10:
|
||||
binding_dims = self.engine.get_binding_shape(binding)
|
||||
if len(binding_dims) == 4:
|
||||
# explicit batch case (TensorRT 7+)
|
||||
@ -150,6 +170,42 @@ class TensorRtDetector(DetectionApi):
|
||||
logger.debug(f"Output has Shape {binding_dims}")
|
||||
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||
output_idx += 1
|
||||
else:
|
||||
binding_dims = self.engine.get_tensor_shape(binding)
|
||||
if len(binding_dims) == 4:
|
||||
# explicit batch case (TensorRT 7+)
|
||||
size = trt.volume(binding_dims)
|
||||
elif len(binding_dims) == 3:
|
||||
# implicit batch case (TensorRT 6 or older)
|
||||
size = trt.volume(binding_dims) * self.engine.max_batch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
||||
)
|
||||
nbytes = size * self.engine.get_tensor_dtype(binding).itemsize
|
||||
# Allocate host and device buffers
|
||||
err, host_mem = cuda.cuMemHostAlloc(
|
||||
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
|
||||
)
|
||||
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
|
||||
logger.debug(
|
||||
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_tensor_dtype(binding)})"
|
||||
)
|
||||
err, device_mem = cuda.cuMemAlloc(nbytes)
|
||||
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
|
||||
# Append the device buffer to device bindings.
|
||||
bindings.append(int(device_mem))
|
||||
# Append to the appropriate list.
|
||||
if binding == "input":
|
||||
logger.debug(f"Input has Shape {binding_dims}")
|
||||
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||
else:
|
||||
# each grid has 3 anchors, each anchor generates a detection
|
||||
# output of 7 float32 values
|
||||
assert size % 7 == 0, f"output size was {size}"
|
||||
logger.debug(f"Output has Shape {binding_dims}")
|
||||
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||
output_idx += 1
|
||||
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
|
||||
assert len(outputs) == 1, f"output len was {len(outputs)}"
|
||||
return inputs, outputs, bindings
|
||||
@ -170,10 +226,16 @@ class TensorRtDetector(DetectionApi):
|
||||
]
|
||||
|
||||
# Run inference.
|
||||
if TRT_VERSION < 10:
|
||||
if not self.context.execute_async_v2(
|
||||
bindings=self.bindings, stream_handle=self.stream
|
||||
):
|
||||
logger.warn("Execute returned false")
|
||||
else:
|
||||
if not self.context.execute_v2(
|
||||
self.bindings
|
||||
):
|
||||
logger.warn("Execute returned false")
|
||||
|
||||
# Transfer predictions back from the GPU.
|
||||
[
|
||||
|
||||
Loading…
Reference in New Issue
Block a user