mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Fixes from rebase to detector factory
This commit is contained in:
parent
556d82da7a
commit
36d2d205e7
@ -11,9 +11,13 @@ from cuda import cuda as cuda
|
||||
# from .object_detector import ObjectDetector
|
||||
# import pycuda.autoinit # This is needed for initializing CUDA driver
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DETECTOR_KEY = "tensorrt"
|
||||
|
||||
# def object_detector_factory(detector_config: DetectorConfig, model_path: str):
|
||||
# if detector_config.type != DetectorTypeEnum.tensorrt:
|
||||
@ -25,6 +29,11 @@ logger = logging.getLogger(__name__)
|
||||
# return LocalObjectDetector(detector_config, model_path)
|
||||
|
||||
|
||||
class TensorRTDetectorConfig(BaseDetectorConfig):
|
||||
type: Literal[DETECTOR_KEY]
|
||||
device: str = Field(default=None, title="Device Type")
|
||||
|
||||
|
||||
class HostDeviceMem(object):
|
||||
"""Simple helper data class that's a little nicer to use than a 2-tuple."""
|
||||
|
||||
@ -47,6 +56,7 @@ class HostDeviceMem(object):
|
||||
|
||||
|
||||
class TensorRtDetector(DetectionApi):
|
||||
type_key = DETECTOR_KEY
|
||||
# class LocalObjectDetector(ObjectDetector):
|
||||
def _load_engine(self, model_path):
|
||||
try:
|
||||
@ -151,13 +161,13 @@ class TensorRtDetector(DetectionApi):
|
||||
# Return only the host outputs.
|
||||
return [np.array([int(out.host_dev)], dtype=np.float32) for out in self.outputs]
|
||||
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=1):
|
||||
def __init__(self, detector_config: TensorRTDetectorConfig):
|
||||
# def __init__(self, detector_config: DetectorConfig, model_path: str):
|
||||
# self.fps = EventsPerSecond()
|
||||
self.conf_th = 0.4 ##TODO: model config parameter
|
||||
self.nms_threshold = 0.4
|
||||
self.trt_logger = trt.Logger(trt.Logger.INFO)
|
||||
self.engine = self._load_engine(model_config.path)
|
||||
self.engine = self._load_engine(detector_config.model.path)
|
||||
self.input_shape = self._get_input_shape()
|
||||
|
||||
try:
|
||||
@ -34,4 +34,4 @@ nvidia-cuda-runtime-cu117 == 11.7.*; platform_machine == 'x86_64'
|
||||
nvidia-cublas-cu11 == 2022.4.8; platform_machine == 'x86_64'
|
||||
nvidia-cublas-cu117 == 11.10.*; platform_machine == 'x86_64'
|
||||
nvidia-cudnn-cu11 == 2022.5.19; platform_machine == 'x86_64'
|
||||
nvidia-cudnn-cu116 == 8.4.1*; platform_machine == 'x86_64'
|
||||
nvidia-cudnn-cu116 == 8.4.*; platform_machine == 'x86_64'
|
||||
Loading…
Reference in New Issue
Block a user