diff --git a/frigate/detectors/plugins/tensorrt.py b/frigate/detectors/plugins/tensorrt.py index bcab5f027..98c4fe9b5 100644 --- a/frigate/detectors/plugins/tensorrt.py +++ b/frigate/detectors/plugins/tensorrt.py @@ -5,8 +5,6 @@ import numpy as np import tensorrt as trt from cuda import cuda, cudart -# import pycuda.driver as cuda -# import pycuda.autoinit # This is needed for initializing CUDA driver from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig from typing import Literal @@ -17,6 +15,28 @@ logger = logging.getLogger(__name__) DETECTOR_KEY = "tensorrt" +class TrtLogger(trt.ILogger): + def __init__(self): + trt.ILogger.__init__(self) + + def log(self, severity, msg): + logger.log(self.getSeverity(severity), msg) + + def getSeverity(self, sev: trt.ILogger.Severity) -> int: + if sev == trt.ILogger.VERBOSE: + return logging.DEBUG + elif sev == trt.ILogger.INFO: + return logging.INFO + elif sev == trt.ILogger.WARNING: + return logging.WARNING + elif sev == trt.ILogger.ERROR: + return logging.ERROR + elif sev == trt.ILogger.INTERNAL_ERROR: + return logging.CRITICAL + else: + return logging.DEBUG + + class TensorRTDetectorConfig(BaseDetectorConfig): type: Literal[DETECTOR_KEY] device: str = Field(default=None, title="Device Type") @@ -113,11 +133,13 @@ class TensorRtDetector(DetectionApi): bindings.append(int(device_mem)) # Append to the appropriate list. if self.engine.binding_is_input(binding): + logger.info(f"Input has Shape {binding_dims}") inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size)) else: # each grid has 3 anchors, each anchor generates a detection # output of 7 float32 values assert size % 7 == 0, f"output size was {size}" + logger.info(f"Output has Shape {binding_dims}") outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size)) output_idx += 1 assert len(inputs) == 1, f"inputs len was {len(inputs)}" @@ -159,7 +181,7 @@ class TensorRtDetector(DetectionApi): # self.fps = EventsPerSecond() self.conf_th = 0.4 ##TODO: model config parameter self.nms_threshold = 0.4 - self.trt_logger = trt.Logger(trt.Logger.INFO) + self.trt_logger = TrtLogger() self.engine = self._load_engine(detector_config.model.path) self.input_shape = self._get_input_shape()