Update detect to normalize input tensor using model input type

This commit is contained in:
Nate Meyer 2022-12-29 00:45:39 -05:00
parent c01507081e
commit de251c2c21
2 changed files with 16 additions and 6 deletions

View File

@ -166,6 +166,7 @@ To generate the model files, create a new folder to save the models, download th
```bash ```bash
mkdir trt-models mkdir trt-models
wget https://github.com/blakeblackshear/frigate/raw/master/docker/tensorrt_models.sh wget https://github.com/blakeblackshear/frigate/raw/master/docker/tensorrt_models.sh
chmod +x tensorrt_models.sh
docker run --gpus=all --rm -it -v `pwd`/trt-models:/tensorrt_models -v `pwd`/tensorrt_models.sh:/tensorrt_models.sh nvcr.io/nvidia/tensorrt:22.07-py3 /tensorrt_models.sh docker run --gpus=all --rm -it -v `pwd`/trt-models:/tensorrt_models -v `pwd`/tensorrt_models.sh:/tensorrt_models.sh nvcr.io/nvidia/tensorrt:22.07-py3 /tensorrt_models.sh
``` ```

View File

@ -5,7 +5,7 @@ import numpy as np
try: try:
import tensorrt as trt import tensorrt as trt
from cuda import cuda, cudart from cuda import cuda
TRT_SUPPORT = True TRT_SUPPORT = True
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
@ -72,7 +72,7 @@ class HostDeviceMem(object):
class TensorRtDetector(DetectionApi): class TensorRtDetector(DetectionApi):
type_key = DETECTOR_KEY type_key = DETECTOR_KEY
# class LocalObjectDetector(ObjectDetector):
def _load_engine(self, model_path): def _load_engine(self, model_path):
try: try:
ctypes.cdll.LoadLibrary( ctypes.cdll.LoadLibrary(
@ -100,9 +100,15 @@ class TensorRtDetector(DetectionApi):
assert self.engine.binding_is_input(binding) assert self.engine.binding_is_input(binding)
binding_dims = self.engine.get_binding_shape(binding) binding_dims = self.engine.get_binding_shape(binding)
if len(binding_dims) == 4: if len(binding_dims) == 4:
return tuple(binding_dims[2:]) return (
tuple(binding_dims[2:]),
trt.nptype(self.engine.get_binding_dtype(binding)),
)
elif len(binding_dims) == 3: elif len(binding_dims) == 3:
return tuple(binding_dims[1:]) return (
tuple(binding_dims[1:]),
trt.nptype(self.engine.get_binding_dtype(binding)),
)
else: else:
raise ValueError( raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims)) "bad dims of binding %s: %s" % (binding, str(binding_dims))
@ -249,10 +255,13 @@ class TensorRtDetector(DetectionApi):
# 2..5 - a value between 0 and 1 of the box: [top, left, bottom, right] # 2..5 - a value between 0 and 1 of the box: [top, left, bottom, right]
# normalize # normalize
tensor_input = tensor_input.astype(np.float32) if self.input_shape[-1] != trt.int8:
tensor_input = tensor_input.astype(self.input_shape[-1])
tensor_input /= 255.0 tensor_input /= 255.0
self.inputs[0].host = np.ascontiguousarray(tensor_input.astype(np.float32)) self.inputs[0].host = np.ascontiguousarray(
tensor_input.astype(self.input_shape[-1])
)
trt_outputs = self._do_inference() trt_outputs = self._do_inference()
raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th) raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)