mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-11 05:35:25 +03:00
Refactor to reduce code duplication
This commit is contained in:
parent
408f295416
commit
6c0abe4833
@ -89,43 +89,50 @@ class TensorRtDetector(DetectionApi):
|
|||||||
with open(model_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime:
|
with open(model_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime:
|
||||||
return runtime.deserialize_cuda_engine(f.read())
|
return runtime.deserialize_cuda_engine(f.read())
|
||||||
|
|
||||||
|
def _binding_is_input(self, binding):
|
||||||
|
if TRT_VERSION < 10:
|
||||||
|
assert self.engine.binding_is_input(binding)
|
||||||
|
else:
|
||||||
|
assert binding == "input"
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_binding_dims(self, binding):
|
||||||
|
if TRT_VERSION < 10:
|
||||||
|
return self.engine.get_binding_shape(binding)
|
||||||
|
else:
|
||||||
|
return self.engine.get_tensor_shape(binding)
|
||||||
|
|
||||||
|
def _get_binding_dtype(self, binding):
|
||||||
|
if TRT_VERSION < 10:
|
||||||
|
return self.engine.get_binding_dtype(binding)
|
||||||
|
else:
|
||||||
|
return self.engine.get_tensor_shape(binding)
|
||||||
|
|
||||||
|
def _execute(self):
|
||||||
|
if TRT_VERSION < 10:
|
||||||
|
return self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream)
|
||||||
|
else:
|
||||||
|
return self.context.execute_v2(self.bindings)
|
||||||
|
|
||||||
def _get_input_shape(self):
|
def _get_input_shape(self):
|
||||||
"""Get input shape of the TensorRT YOLO engine."""
|
"""Get input shape of the TensorRT YOLO engine."""
|
||||||
binding = self.engine[0]
|
binding = self.engine[0]
|
||||||
if TRT_VERSION < 10:
|
assert self._binding_is_input(binding)
|
||||||
assert self.engine.binding_is_input(binding)
|
binding_dims = self._get_binding_dims(binding)
|
||||||
binding_dims = self.engine.get_binding_shape(binding)
|
if len(binding_dims) == 4:
|
||||||
if len(binding_dims) == 4:
|
return (
|
||||||
return (
|
tuple(binding_dims[2:]),
|
||||||
tuple(binding_dims[2:]),
|
trt.nptype(self._get_binding_dtype(binding)),
|
||||||
trt.nptype(self.engine.get_binding_dtype(binding)),
|
)
|
||||||
)
|
elif len(binding_dims) == 3:
|
||||||
elif len(binding_dims) == 3:
|
return (
|
||||||
return (
|
tuple(binding_dims[1:]),
|
||||||
tuple(binding_dims[1:]),
|
trt.nptype(self._get_binding_dtype(binding)),
|
||||||
trt.nptype(self.engine.get_binding_dtype(binding)),
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert binding == "input"
|
raise ValueError(
|
||||||
binding_dims = self.engine.get_tensor_shape("input")
|
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
||||||
if len(binding_dims) == 4:
|
)
|
||||||
return (
|
|
||||||
tuple(binding_dims[2:]),
|
|
||||||
trt.nptype(self.engine.get_tensor_dtype(binding)),
|
|
||||||
)
|
|
||||||
elif len(binding_dims) == 3:
|
|
||||||
return (
|
|
||||||
tuple(binding_dims[1:]),
|
|
||||||
trt.nptype(self.engine.get_tensor_dtype(binding)),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
|
||||||
)
|
|
||||||
|
|
||||||
def _allocate_buffers(self):
|
def _allocate_buffers(self):
|
||||||
"""Allocates all host/device in/out buffers required for an engine."""
|
"""Allocates all host/device in/out buffers required for an engine."""
|
||||||
@ -134,75 +141,41 @@ class TensorRtDetector(DetectionApi):
|
|||||||
bindings = []
|
bindings = []
|
||||||
output_idx = 0
|
output_idx = 0
|
||||||
for binding in self.engine:
|
for binding in self.engine:
|
||||||
if TRT_VERSION < 10:
|
binding_dims = self._get_binding_dims(binding)
|
||||||
binding_dims = self.engine.get_binding_shape(binding)
|
if len(binding_dims) == 4:
|
||||||
if len(binding_dims) == 4:
|
# explicit batch case (TensorRT 7+)
|
||||||
# explicit batch case (TensorRT 7+)
|
size = trt.volume(binding_dims)
|
||||||
size = trt.volume(binding_dims)
|
elif len(binding_dims) == 3:
|
||||||
elif len(binding_dims) == 3:
|
# implicit batch case (TensorRT 6 or older)
|
||||||
# implicit batch case (TensorRT 6 or older)
|
size = trt.volume(binding_dims) * self.engine.max_batch_size
|
||||||
size = trt.volume(binding_dims) * self.engine.max_batch_size
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
|
||||||
)
|
|
||||||
nbytes = size * self.engine.get_binding_dtype(binding).itemsize
|
|
||||||
# Allocate host and device buffers
|
|
||||||
err, host_mem = cuda.cuMemHostAlloc(
|
|
||||||
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
|
|
||||||
)
|
|
||||||
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
|
|
||||||
logger.debug(
|
|
||||||
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_binding_dtype(binding)})"
|
|
||||||
)
|
|
||||||
err, device_mem = cuda.cuMemAlloc(nbytes)
|
|
||||||
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
|
|
||||||
# Append the device buffer to device bindings.
|
|
||||||
bindings.append(int(device_mem))
|
|
||||||
# Append to the appropriate list.
|
|
||||||
if self.engine.binding_is_input(binding):
|
|
||||||
logger.debug(f"Input has Shape {binding_dims}")
|
|
||||||
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
|
||||||
else:
|
|
||||||
# each grid has 3 anchors, each anchor generates a detection
|
|
||||||
# output of 7 float32 values
|
|
||||||
assert size % 7 == 0, f"output size was {size}"
|
|
||||||
logger.debug(f"Output has Shape {binding_dims}")
|
|
||||||
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
|
||||||
output_idx += 1
|
|
||||||
else:
|
else:
|
||||||
binding_dims = self.engine.get_tensor_shape(binding)
|
raise ValueError(
|
||||||
if len(binding_dims) == 4:
|
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
||||||
# explicit batch case (TensorRT 7+)
|
|
||||||
size = trt.volume(binding_dims)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"bad dims of binding %s: %s" % (binding, str(binding_dims))
|
|
||||||
)
|
|
||||||
nbytes = size * self.engine.get_tensor_dtype(binding).itemsize
|
|
||||||
# Allocate host and device buffers
|
|
||||||
err, host_mem = cuda.cuMemHostAlloc(
|
|
||||||
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
|
|
||||||
)
|
)
|
||||||
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
|
nbytes = size * self._get_binding_dtype(binding).itemsize
|
||||||
logger.debug(
|
# Allocate host and device buffers
|
||||||
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_tensor_dtype(binding)})"
|
err, host_mem = cuda.cuMemHostAlloc(
|
||||||
)
|
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
|
||||||
err, device_mem = cuda.cuMemAlloc(nbytes)
|
)
|
||||||
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
|
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
|
||||||
# Append the device buffer to device bindings.
|
logger.debug(
|
||||||
bindings.append(int(device_mem))
|
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self._get_binding_dtype(binding)})"
|
||||||
# Append to the appropriate list.
|
)
|
||||||
if binding == "input":
|
err, device_mem = cuda.cuMemAlloc(nbytes)
|
||||||
logger.debug(f"Input has Shape {binding_dims}")
|
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
|
||||||
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
# Append the device buffer to device bindings.
|
||||||
else:
|
bindings.append(int(device_mem))
|
||||||
# each grid has 3 anchors, each anchor generates a detection
|
# Append to the appropriate list.
|
||||||
# output of 7 float32 values
|
if self._binding_is_input(binding):
|
||||||
assert size % 7 == 0, f"output size was {size}"
|
logger.debug(f"Input has Shape {binding_dims}")
|
||||||
logger.debug(f"Output has Shape {binding_dims}")
|
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||||
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
else:
|
||||||
output_idx += 1
|
# each grid has 3 anchors, each anchor generates a detection
|
||||||
|
# output of 7 float32 values
|
||||||
|
assert size % 7 == 0, f"output size was {size}"
|
||||||
|
logger.debug(f"Output has Shape {binding_dims}")
|
||||||
|
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||||
|
output_idx += 1
|
||||||
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
|
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
|
||||||
assert len(outputs) == 1, f"output len was {len(outputs)}"
|
assert len(outputs) == 1, f"output len was {len(outputs)}"
|
||||||
return inputs, outputs, bindings
|
return inputs, outputs, bindings
|
||||||
@ -223,16 +196,8 @@ class TensorRtDetector(DetectionApi):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Run inference.
|
# Run inference.
|
||||||
if TRT_VERSION < 10:
|
if not self._execute():
|
||||||
if not self.context.execute_async_v2(
|
logger.warn("Execute returned false")
|
||||||
bindings=self.bindings, stream_handle=self.stream
|
|
||||||
):
|
|
||||||
logger.warn("Execute returned false")
|
|
||||||
else:
|
|
||||||
if not self.context.execute_v2(
|
|
||||||
self.bindings
|
|
||||||
):
|
|
||||||
logger.warn("Execute returned false")
|
|
||||||
|
|
||||||
# Transfer predictions back from the GPU.
|
# Transfer predictions back from the GPU.
|
||||||
[
|
[
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user