From 53aa9faf7a9c097fd16c774ba9d0f114ddb1a0f6 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 30 Sep 2024 11:41:54 -0600 Subject: [PATCH] Improve rocm handling of different models --- docker/rocm/Dockerfile | 1 + frigate/object_detection.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/docker/rocm/Dockerfile b/docker/rocm/Dockerfile index a1d6ce832..eebe04878 100644 --- a/docker/rocm/Dockerfile +++ b/docker/rocm/Dockerfile @@ -83,6 +83,7 @@ ARG AMDGPU COPY --from=rocm /opt/rocm-$ROCM/bin/rocminfo /opt/rocm-$ROCM/bin/migraphx-driver /opt/rocm-$ROCM/bin/ COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*$AMDGPU* /opt/rocm-$ROCM/share/miopen/db/ +COPY --from=rocm /opt/rocm-$ROCM/share/miopen/db/*gfx908* /opt/rocm-$ROCM/share/miopen/db/ COPY --from=rocm /opt/rocm-$ROCM/lib/rocblas/library/*$AMDGPU* /opt/rocm-$ROCM/lib/rocblas/library/ COPY --from=rocm /opt/rocm-dist/ / COPY --from=debian-build /opt/rocm/lib/migraphx.cpython-39-x86_64-linux-gnu.so /opt/rocm-$ROCM/lib/ diff --git a/frigate/object_detection.py b/frigate/object_detection.py index eac019a7a..7784a5520 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -12,7 +12,8 @@ from setproctitle import setproctitle import frigate.util as util from frigate.detectors import create_detector -from frigate.detectors.detector_config import InputTensorEnum +from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum +from frigate.detectors.plugins.rocm import ROCmDetector from frigate.util.builtin import EventsPerSecond, load_labels from frigate.util.image import SharedMemoryFrameManager from frigate.util.services import listen @@ -22,11 +23,11 @@ logger = logging.getLogger(__name__) class ObjectDetector(ABC): @abstractmethod - def detect(self, tensor_input, threshold=0.4): + def detect(self, tensor_input, threshold: float = 0.4): pass -def tensor_transform(desired_shape): +def tensor_transform(desired_shape: InputTensorEnum): # Currently this function only supports BHWC permutations if desired_shape == InputTensorEnum.nhwc: return None @@ -37,8 +38,8 @@ def tensor_transform(desired_shape): class LocalObjectDetector(ObjectDetector): def __init__( self, - detector_config=None, - labels=None, + detector_config: BaseDetectorConfig = None, + labels: str = None, ): self.fps = EventsPerSecond() if labels is None: @@ -47,7 +48,11 @@ class LocalObjectDetector(ObjectDetector): self.labels = load_labels(labels) if detector_config: - self.input_transform = tensor_transform(detector_config.model.input_tensor) + if detector_config is ROCmDetector: + # ROCm requires NHWC as input + self.input_transform = None + else: + self.input_transform = tensor_transform(detector_config.model.input_tensor) else: self.input_transform = None