mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Handle TensorRT logs with the python logger
This commit is contained in:
parent
aa9271e363
commit
643c3f21cd
@ -5,8 +5,6 @@ import numpy as np
|
|||||||
import tensorrt as trt
|
import tensorrt as trt
|
||||||
from cuda import cuda, cudart
|
from cuda import cuda, cudart
|
||||||
|
|
||||||
# import pycuda.driver as cuda
|
|
||||||
# import pycuda.autoinit # This is needed for initializing CUDA driver
|
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@ -17,6 +15,28 @@ logger = logging.getLogger(__name__)
|
|||||||
DETECTOR_KEY = "tensorrt"
|
DETECTOR_KEY = "tensorrt"
|
||||||
|
|
||||||
|
|
||||||
|
class TrtLogger(trt.ILogger):
|
||||||
|
def __init__(self):
|
||||||
|
trt.ILogger.__init__(self)
|
||||||
|
|
||||||
|
def log(self, severity, msg):
|
||||||
|
logger.log(self.getSeverity(severity), msg)
|
||||||
|
|
||||||
|
def getSeverity(self, sev: trt.ILogger.Severity) -> int:
|
||||||
|
if sev == trt.ILogger.VERBOSE:
|
||||||
|
return logging.DEBUG
|
||||||
|
elif sev == trt.ILogger.INFO:
|
||||||
|
return logging.INFO
|
||||||
|
elif sev == trt.ILogger.WARNING:
|
||||||
|
return logging.WARNING
|
||||||
|
elif sev == trt.ILogger.ERROR:
|
||||||
|
return logging.ERROR
|
||||||
|
elif sev == trt.ILogger.INTERNAL_ERROR:
|
||||||
|
return logging.CRITICAL
|
||||||
|
else:
|
||||||
|
return logging.DEBUG
|
||||||
|
|
||||||
|
|
||||||
class TensorRTDetectorConfig(BaseDetectorConfig):
|
class TensorRTDetectorConfig(BaseDetectorConfig):
|
||||||
type: Literal[DETECTOR_KEY]
|
type: Literal[DETECTOR_KEY]
|
||||||
device: str = Field(default=None, title="Device Type")
|
device: str = Field(default=None, title="Device Type")
|
||||||
@ -113,11 +133,13 @@ class TensorRtDetector(DetectionApi):
|
|||||||
bindings.append(int(device_mem))
|
bindings.append(int(device_mem))
|
||||||
# Append to the appropriate list.
|
# Append to the appropriate list.
|
||||||
if self.engine.binding_is_input(binding):
|
if self.engine.binding_is_input(binding):
|
||||||
|
logger.info(f"Input has Shape {binding_dims}")
|
||||||
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||||
else:
|
else:
|
||||||
# each grid has 3 anchors, each anchor generates a detection
|
# each grid has 3 anchors, each anchor generates a detection
|
||||||
# output of 7 float32 values
|
# output of 7 float32 values
|
||||||
assert size % 7 == 0, f"output size was {size}"
|
assert size % 7 == 0, f"output size was {size}"
|
||||||
|
logger.info(f"Output has Shape {binding_dims}")
|
||||||
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
|
||||||
output_idx += 1
|
output_idx += 1
|
||||||
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
|
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
|
||||||
@ -159,7 +181,7 @@ class TensorRtDetector(DetectionApi):
|
|||||||
# self.fps = EventsPerSecond()
|
# self.fps = EventsPerSecond()
|
||||||
self.conf_th = 0.4 ##TODO: model config parameter
|
self.conf_th = 0.4 ##TODO: model config parameter
|
||||||
self.nms_threshold = 0.4
|
self.nms_threshold = 0.4
|
||||||
self.trt_logger = trt.Logger(trt.Logger.INFO)
|
self.trt_logger = TrtLogger()
|
||||||
self.engine = self._load_engine(detector_config.model.path)
|
self.engine = self._load_engine(detector_config.model.path)
|
||||||
self.input_shape = self._get_input_shape()
|
self.input_shape = self._get_input_shape()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user