mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Fix parsing output memory pointer
This commit is contained in:
parent
36d2d205e7
commit
aa9271e363
@ -1,14 +1,11 @@
|
||||
import logging
|
||||
|
||||
# from frigate.config import DetectorConfig, DetectorTypeEnum
|
||||
# from frigate.util import EventsPerSecond
|
||||
import ctypes
|
||||
import numpy as np
|
||||
import tensorrt as trt
|
||||
from cuda import cuda as cuda
|
||||
from cuda import cuda, cudart
|
||||
|
||||
# import pycuda.driver as cuda
|
||||
# from .object_detector import ObjectDetector
|
||||
# import pycuda.autoinit # This is needed for initializing CUDA driver
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||
@ -19,15 +16,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DETECTOR_KEY = "tensorrt"
|
||||
|
||||
# def object_detector_factory(detector_config: DetectorConfig, model_path: str):
|
||||
# if detector_config.type != DetectorTypeEnum.tensorrt:
|
||||
# return None
|
||||
# try:
|
||||
# ctypes.cdll.LoadLibrary("/yolo4/libyolo_layer.so")
|
||||
# except OSError as e:
|
||||
# logger.error("ERROR: failed to load /yolo4/libyolo_layer.so. %s", e)
|
||||
# return LocalObjectDetector(detector_config, model_path)
|
||||
|
||||
|
||||
class TensorRTDetectorConfig(BaseDetectorConfig):
|
||||
type: Literal[DETECTOR_KEY]
|
||||
@ -159,7 +147,12 @@ class TensorRtDetector(DetectionApi):
|
||||
# Synchronize the stream
|
||||
cuda.cuStreamSynchronize(self.stream)
|
||||
# Return only the host outputs.
|
||||
return [np.array([int(out.host_dev)], dtype=np.float32) for out in self.outputs]
|
||||
return [
|
||||
np.array(
|
||||
(ctypes.c_float * out.size).from_address(out.host), dtype=np.float32
|
||||
)
|
||||
for out in self.outputs
|
||||
]
|
||||
|
||||
def __init__(self, detector_config: TensorRTDetectorConfig):
|
||||
# def __init__(self, detector_config: DetectorConfig, model_path: str):
|
||||
@ -212,9 +205,6 @@ class TensorRtDetector(DetectionApi):
|
||||
|
||||
return detections
|
||||
|
||||
# def detect(self, tensor_input, threshold=0.4):
|
||||
# pass
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
# Input tensor has the shape of the [height, width, 3]
|
||||
# Output tensor of float32 of shape [20, 6] where:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user