mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Handle TensorRT logs with the python logger
This commit is contained in:
parent
aa9271e363
commit
643c3f21cd
@ -5,8 +5,6 @@ import numpy as np
|
||||
import tensorrt as trt
|
||||
from cuda import cuda, cudart
|
||||
|
||||
# import pycuda.driver as cuda
|
||||
# import pycuda.autoinit # This is needed for initializing CUDA driver
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||
from typing import Literal
|
||||
@ -17,6 +15,28 @@ logger = logging.getLogger(__name__)
|
||||
DETECTOR_KEY = "tensorrt"
|
||||
|
||||
|
||||
class TrtLogger(trt.ILogger):
|
||||
def __init__(self):
|
||||
trt.ILogger.__init__(self)
|
||||
|
||||
def log(self, severity, msg):
|
||||
logger.log(self.getSeverity(severity), msg)
|
||||
|
||||
def getSeverity(self, sev: trt.ILogger.Severity) -> int:
|
||||
if sev == trt.ILogger.VERBOSE:
|
||||
return logging.DEBUG
|
||||
elif sev == trt.ILogger.INFO:
|
||||
return logging.INFO
|
||||
elif sev == trt.ILogger.WARNING:
|
||||
return logging.WARNING
|
||||
elif sev == trt.ILogger.ERROR:
|
||||
return logging.ERROR
|
||||
elif sev == trt.ILogger.INTERNAL_ERROR:
|
||||
return logging.CRITICAL
|
||||
else:
|
||||
return logging.DEBUG
|
||||
|
||||
|
||||
class TensorRTDetectorConfig(BaseDetectorConfig):
|
||||
type: Literal[DETECTOR_KEY]
|
||||
device: str = Field(default=None, title="Device Type")
|
||||
@ -113,11 +133,13 @@ class TensorRtDetector(DetectionApi):
|
||||
bindings.append(int(device_mem))
|
||||
# Append to the appropriate list.
|
||||
if self.engine.binding_is_input(binding):
|
||||
logger.info(f"Input has Shape {binding_dims}")
|
||||
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||
else:
|
||||
# each grid has 3 anchors, each anchor generates a detection
|
||||
# output of 7 float32 values
|
||||
assert size % 7 == 0, f"output size was {size}"
|
||||
logger.info(f"Output has Shape {binding_dims}")
|
||||
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||
output_idx += 1
|
||||
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
|
||||
@ -159,7 +181,7 @@ class TensorRtDetector(DetectionApi):
|
||||
# self.fps = EventsPerSecond()
|
||||
self.conf_th = 0.4 ##TODO: model config parameter
|
||||
self.nms_threshold = 0.4
|
||||
self.trt_logger = trt.Logger(trt.Logger.INFO)
|
||||
self.trt_logger = TrtLogger()
|
||||
self.engine = self._load_engine(detector_config.model.path)
|
||||
self.input_shape = self._get_input_shape()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user