mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Fixes from rebase to detector factory
This commit is contained in:
parent
556d82da7a
commit
36d2d205e7
@ -11,9 +11,13 @@ from cuda import cuda as cuda
|
|||||||
# from .object_detector import ObjectDetector
|
# from .object_detector import ObjectDetector
|
||||||
# import pycuda.autoinit # This is needed for initializing CUDA driver
|
# import pycuda.autoinit # This is needed for initializing CUDA driver
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
|
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||||
|
from typing import Literal
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DETECTOR_KEY = "tensorrt"
|
||||||
|
|
||||||
# def object_detector_factory(detector_config: DetectorConfig, model_path: str):
|
# def object_detector_factory(detector_config: DetectorConfig, model_path: str):
|
||||||
# if detector_config.type != DetectorTypeEnum.tensorrt:
|
# if detector_config.type != DetectorTypeEnum.tensorrt:
|
||||||
@ -25,6 +29,11 @@ logger = logging.getLogger(__name__)
|
|||||||
# return LocalObjectDetector(detector_config, model_path)
|
# return LocalObjectDetector(detector_config, model_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorRTDetectorConfig(BaseDetectorConfig):
|
||||||
|
type: Literal[DETECTOR_KEY]
|
||||||
|
device: str = Field(default=None, title="Device Type")
|
||||||
|
|
||||||
|
|
||||||
class HostDeviceMem(object):
|
class HostDeviceMem(object):
|
||||||
"""Simple helper data class that's a little nicer to use than a 2-tuple."""
|
"""Simple helper data class that's a little nicer to use than a 2-tuple."""
|
||||||
|
|
||||||
@ -47,6 +56,7 @@ class HostDeviceMem(object):
|
|||||||
|
|
||||||
|
|
||||||
class TensorRtDetector(DetectionApi):
|
class TensorRtDetector(DetectionApi):
|
||||||
|
type_key = DETECTOR_KEY
|
||||||
# class LocalObjectDetector(ObjectDetector):
|
# class LocalObjectDetector(ObjectDetector):
|
||||||
def _load_engine(self, model_path):
|
def _load_engine(self, model_path):
|
||||||
try:
|
try:
|
||||||
@ -151,13 +161,13 @@ class TensorRtDetector(DetectionApi):
|
|||||||
# Return only the host outputs.
|
# Return only the host outputs.
|
||||||
return [np.array([int(out.host_dev)], dtype=np.float32) for out in self.outputs]
|
return [np.array([int(out.host_dev)], dtype=np.float32) for out in self.outputs]
|
||||||
|
|
||||||
def __init__(self, det_device=None, model_config=None, num_threads=1):
|
def __init__(self, detector_config: TensorRTDetectorConfig):
|
||||||
# def __init__(self, detector_config: DetectorConfig, model_path: str):
|
# def __init__(self, detector_config: DetectorConfig, model_path: str):
|
||||||
# self.fps = EventsPerSecond()
|
# self.fps = EventsPerSecond()
|
||||||
self.conf_th = 0.4 ##TODO: model config parameter
|
self.conf_th = 0.4 ##TODO: model config parameter
|
||||||
self.nms_threshold = 0.4
|
self.nms_threshold = 0.4
|
||||||
self.trt_logger = trt.Logger(trt.Logger.INFO)
|
self.trt_logger = trt.Logger(trt.Logger.INFO)
|
||||||
self.engine = self._load_engine(model_config.path)
|
self.engine = self._load_engine(detector_config.model.path)
|
||||||
self.input_shape = self._get_input_shape()
|
self.input_shape = self._get_input_shape()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -34,4 +34,4 @@ nvidia-cuda-runtime-cu117 == 11.7.*; platform_machine == 'x86_64'
|
|||||||
nvidia-cublas-cu11 == 2022.4.8; platform_machine == 'x86_64'
|
nvidia-cublas-cu11 == 2022.4.8; platform_machine == 'x86_64'
|
||||||
nvidia-cublas-cu117 == 11.10.*; platform_machine == 'x86_64'
|
nvidia-cublas-cu117 == 11.10.*; platform_machine == 'x86_64'
|
||||||
nvidia-cudnn-cu11 == 2022.5.19; platform_machine == 'x86_64'
|
nvidia-cudnn-cu11 == 2022.5.19; platform_machine == 'x86_64'
|
||||||
nvidia-cudnn-cu116 == 8.4.1*; platform_machine == 'x86_64'
|
nvidia-cudnn-cu116 == 8.4.*; platform_machine == 'x86_64'
|
||||||
Loading…
Reference in New Issue
Block a user