diff --git a/frigate/detectors/plugins/tensorrt.py b/frigate/detectors/plugins/tensorrt.py index cee4626ef..bcab5f027 100644 --- a/frigate/detectors/plugins/tensorrt.py +++ b/frigate/detectors/plugins/tensorrt.py @@ -1,14 +1,11 @@ import logging -# from frigate.config import DetectorConfig, DetectorTypeEnum -# from frigate.util import EventsPerSecond import ctypes import numpy as np import tensorrt as trt -from cuda import cuda as cuda +from cuda import cuda, cudart # import pycuda.driver as cuda -# from .object_detector import ObjectDetector # import pycuda.autoinit # This is needed for initializing CUDA driver from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig @@ -19,15 +16,6 @@ logger = logging.getLogger(__name__) DETECTOR_KEY = "tensorrt" -# def object_detector_factory(detector_config: DetectorConfig, model_path: str): -# if detector_config.type != DetectorTypeEnum.tensorrt: -# return None -# try: -# ctypes.cdll.LoadLibrary("/yolo4/libyolo_layer.so") -# except OSError as e: -# logger.error("ERROR: failed to load /yolo4/libyolo_layer.so. %s", e) -# return LocalObjectDetector(detector_config, model_path) - class TensorRTDetectorConfig(BaseDetectorConfig): type: Literal[DETECTOR_KEY] @@ -159,7 +147,12 @@ class TensorRtDetector(DetectionApi): # Synchronize the stream cuda.cuStreamSynchronize(self.stream) # Return only the host outputs. - return [np.array([int(out.host_dev)], dtype=np.float32) for out in self.outputs] + return [ + np.array( + (ctypes.c_float * out.size).from_address(out.host), dtype=np.float32 + ) + for out in self.outputs + ] def __init__(self, detector_config: TensorRTDetectorConfig): # def __init__(self, detector_config: DetectorConfig, model_path: str): @@ -212,9 +205,6 @@ class TensorRtDetector(DetectionApi): return detections - # def detect(self, tensor_input, threshold=0.4): - # pass - def detect_raw(self, tensor_input): # Input tensor has the shape of the [height, width, 3] # Output tensor of float32 of shape [20, 6] where: