mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
Add input type for dtype
This commit is contained in:
parent
eca504cb07
commit
41102d82d8
@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum):
|
|||||||
nhwc = "nhwc"
|
nhwc = "nhwc"
|
||||||
|
|
||||||
|
|
||||||
|
class InputDTypeEnum(str, Enum):
|
||||||
|
float = "float"
|
||||||
|
int = "int"
|
||||||
|
|
||||||
|
|
||||||
class ModelTypeEnum(str, Enum):
|
class ModelTypeEnum(str, Enum):
|
||||||
ssd = "ssd"
|
ssd = "ssd"
|
||||||
yolox = "yolox"
|
yolox = "yolox"
|
||||||
@ -53,6 +58,9 @@ class ModelConfig(BaseModel):
|
|||||||
input_pixel_format: PixelFormatEnum = Field(
|
input_pixel_format: PixelFormatEnum = Field(
|
||||||
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
|
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
|
||||||
)
|
)
|
||||||
|
input_dtype: InputDTypeEnum = Field(
|
||||||
|
default=InputDTypeEnum.int, title="Model Input D Type"
|
||||||
|
)
|
||||||
model_type: ModelTypeEnum = Field(
|
model_type: ModelTypeEnum = Field(
|
||||||
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
|
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi):
|
|||||||
|
|
||||||
logger.info(f"ONNX: {path} loaded")
|
logger.info(f"ONNX: {path} loaded")
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input: np.ndarray):
|
||||||
model_input_name = self.model.get_inputs()[0].name
|
model_input_name = self.model.get_inputs()[0].name
|
||||||
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,11 @@ from setproctitle import setproctitle
|
|||||||
|
|
||||||
import frigate.util as util
|
import frigate.util as util
|
||||||
from frigate.detectors import create_detector
|
from frigate.detectors import create_detector
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
|
from frigate.detectors.detector_config import (
|
||||||
|
BaseDetectorConfig,
|
||||||
|
InputDTypeEnum,
|
||||||
|
InputTensorEnum,
|
||||||
|
)
|
||||||
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
||||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||||
from frigate.util.image import SharedMemoryFrameManager
|
from frigate.util.image import SharedMemoryFrameManager
|
||||||
@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
self.input_transform = tensor_transform(
|
self.input_transform = tensor_transform(
|
||||||
detector_config.model.input_tensor
|
detector_config.model.input_tensor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.dtype = detector_config.model.input_dtype
|
||||||
else:
|
else:
|
||||||
self.input_transform = None
|
self.input_transform = None
|
||||||
|
self.dtype = InputDTypeEnum.int
|
||||||
|
|
||||||
self.detect_api = create_detector(detector_config)
|
self.detect_api = create_detector(detector_config)
|
||||||
|
|
||||||
def detect(self, tensor_input, threshold=0.4):
|
def detect(self, tensor_input: np.ndarray, threshold=0.4):
|
||||||
detections = []
|
detections = []
|
||||||
|
|
||||||
raw_detections = self.detect_raw(tensor_input)
|
raw_detections = self.detect_raw(tensor_input)
|
||||||
@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
self.fps.update()
|
self.fps.update()
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input: np.ndarray):
|
||||||
if self.input_transform:
|
if self.input_transform:
|
||||||
tensor_input = np.transpose(tensor_input, self.input_transform)
|
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||||
|
|
||||||
|
if self.dtype == InputDTypeEnum.float:
|
||||||
|
tensor_input = tensor_input.astype(np.float32)
|
||||||
|
|
||||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user