Add input type for dtype

This commit is contained in:
Nicolas Mowen 2024-10-28 09:48:28 -06:00
parent eca504cb07
commit 41102d82d8
3 changed files with 23 additions and 4 deletions

View File

@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum):
nhwc = "nhwc"
class InputDTypeEnum(str, Enum):
float = "float"
int = "int"
class ModelTypeEnum(str, Enum):
ssd = "ssd"
yolox = "yolox"
@ -53,6 +58,9 @@ class ModelConfig(BaseModel):
input_pixel_format: PixelFormatEnum = Field(
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
)
input_dtype: InputDTypeEnum = Field(
default=InputDTypeEnum.int, title="Model Input D Type"
)
model_type: ModelTypeEnum = Field(
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
)

View File

@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi):
logger.info(f"ONNX: {path} loaded")
def detect_raw(self, tensor_input):
def detect_raw(self, tensor_input: np.ndarray):
model_input_name = self.model.get_inputs()[0].name
tensor_output = self.model.run(None, {model_input_name: tensor_input})

View File

@ -12,7 +12,11 @@ from setproctitle import setproctitle
import frigate.util as util
from frigate.detectors import create_detector
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
from frigate.detectors.detector_config import (
BaseDetectorConfig,
InputDTypeEnum,
InputTensorEnum,
)
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
from frigate.util.builtin import EventsPerSecond, load_labels
from frigate.util.image import SharedMemoryFrameManager
@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
self.input_transform = tensor_transform(
detector_config.model.input_tensor
)
self.dtype = detector_config.model.input_dtype
else:
self.input_transform = None
self.dtype = InputDTypeEnum.int
self.detect_api = create_detector(detector_config)
def detect(self, tensor_input, threshold=0.4):
def detect(self, tensor_input: np.ndarray, threshold=0.4):
detections = []
raw_detections = self.detect_raw(tensor_input)
@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
self.fps.update()
return detections
def detect_raw(self, tensor_input):
def detect_raw(self, tensor_input: np.ndarray):
if self.input_transform:
tensor_input = np.transpose(tensor_input, self.input_transform)
if self.dtype == InputDTypeEnum.float:
tensor_input = tensor_input.astype(np.float32)
return self.detect_api.detect_raw(tensor_input=tensor_input)