Handle TensorRT logs with the python logger

This commit is contained in:
Nate Meyer 2022-12-20 19:21:20 -05:00
parent aa9271e363
commit 643c3f21cd

View File

@ -5,8 +5,6 @@ import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
# import pycuda.driver as cuda
# import pycuda.autoinit # This is needed for initializing CUDA driver
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from typing import Literal
@ -17,6 +15,28 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "tensorrt"
class TrtLogger(trt.ILogger):
def __init__(self):
trt.ILogger.__init__(self)
def log(self, severity, msg):
logger.log(self.getSeverity(severity), msg)
def getSeverity(self, sev: trt.ILogger.Severity) -> int:
if sev == trt.ILogger.VERBOSE:
return logging.DEBUG
elif sev == trt.ILogger.INFO:
return logging.INFO
elif sev == trt.ILogger.WARNING:
return logging.WARNING
elif sev == trt.ILogger.ERROR:
return logging.ERROR
elif sev == trt.ILogger.INTERNAL_ERROR:
return logging.CRITICAL
else:
return logging.DEBUG
class TensorRTDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY]
device: str = Field(default=None, title="Device Type")
@ -113,11 +133,13 @@ class TensorRtDetector(DetectionApi):
bindings.append(int(device_mem))
# Append to the appropriate list.
if self.engine.binding_is_input(binding):
logger.info(f"Input has Shape {binding_dims}")
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
else:
# each grid has 3 anchors, each anchor generates a detection
# output of 7 float32 values
assert size % 7 == 0, f"output size was {size}"
logger.info(f"Output has Shape {binding_dims}")
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
output_idx += 1
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
@ -159,7 +181,7 @@ class TensorRtDetector(DetectionApi):
# self.fps = EventsPerSecond()
self.conf_th = 0.4 ##TODO: model config parameter
self.nms_threshold = 0.4
self.trt_logger = trt.Logger(trt.Logger.INFO)
self.trt_logger = TrtLogger()
self.engine = self._load_engine(detector_config.model.path)
self.input_shape = self._get_input_shape()