mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-14 23:25:25 +03:00
Implement ROCm detectors
This commit is contained in:
parent
efd1194307
commit
f532c9c333
@ -69,7 +69,6 @@ RUN apt-get -y install libnuma1
|
|||||||
|
|
||||||
WORKDIR /opt/frigate/
|
WORKDIR /opt/frigate/
|
||||||
COPY --from=rootfs / /
|
COPY --from=rootfs / /
|
||||||
COPY docker/rocm/rootfs/ /
|
|
||||||
|
|
||||||
#######################################################################
|
#######################################################################
|
||||||
FROM scratch AS rocm-dist
|
FROM scratch AS rocm-dist
|
||||||
|
|||||||
@ -1,20 +0,0 @@
|
|||||||
#!/command/with-contenv bash
|
|
||||||
# shellcheck shell=bash
|
|
||||||
# Compile YoloV8 ONNX files into ROCm MIGraphX files
|
|
||||||
|
|
||||||
OVERRIDE=$(cd /opt/frigate && python3 -c 'import frigate.detectors.plugins.rocm as rocm; print(rocm.auto_override_gfx_version())')
|
|
||||||
|
|
||||||
if ! test -z "$OVERRIDE"; then
|
|
||||||
echo "Using HSA_OVERRIDE_GFX_VERSION=${OVERRIDE}"
|
|
||||||
export HSA_OVERRIDE_GFX_VERSION=$OVERRIDE
|
|
||||||
fi
|
|
||||||
|
|
||||||
for onnx in /config/model_cache/yolov8/*.onnx
|
|
||||||
do
|
|
||||||
mxr="${onnx%.onnx}.mxr"
|
|
||||||
if ! test -f $mxr; then
|
|
||||||
echo "processing $onnx into $mxr"
|
|
||||||
/opt/rocm/bin/migraphx-driver compile $onnx --optimize --gpu --enable-offload-copy --binary -o $mxr
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
@ -1 +0,0 @@
|
|||||||
oneshot
|
|
||||||
@ -1 +0,0 @@
|
|||||||
/etc/s6-overlay/s6-rc.d/compile-rocm-models/run
|
|
||||||
@ -24,7 +24,6 @@ from typing_extensions import Literal
|
|||||||
|
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||||
from frigate.detectors.util import preprocess # Assuming this function is available
|
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -146,17 +145,9 @@ class HailoDetector(DetectionApi):
|
|||||||
f"[detect_raw] Converted tensor_input to numpy array: shape {tensor_input.shape}"
|
f"[detect_raw] Converted tensor_input to numpy array: shape {tensor_input.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Preprocess the tensor input using Frigate's preprocess function
|
input_data = tensor_input
|
||||||
processed_tensor = preprocess(
|
|
||||||
tensor_input, (1, self.h8l_model_height, self.h8l_model_width, 3), np.uint8
|
|
||||||
)
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[detect_raw] Tensor data and shape after preprocessing: {processed_tensor} {processed_tensor.shape}"
|
f"[detect_raw] Input data for inference shape: {tensor_input.shape}, dtype: {tensor_input.dtype}"
|
||||||
)
|
|
||||||
|
|
||||||
input_data = processed_tensor
|
|
||||||
logger.debug(
|
|
||||||
f"[detect_raw] Input data for inference shape: {processed_tensor.shape}, dtype: {processed_tensor.dtype}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -4,13 +4,13 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
|
||||||
from frigate.detectors.util import preprocess
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -74,8 +74,14 @@ class ROCmDetector(DetectionApi):
|
|||||||
logger.error("AMD/ROCm: module loading failed, missing ROCm environment?")
|
logger.error("AMD/ROCm: module loading failed, missing ROCm environment?")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if detector_config.conserve_cpu:
|
||||||
|
logger.info("AMD/ROCm: switching HIP to blocking mode to conserve CPU")
|
||||||
|
ctypes.CDLL("/opt/rocm/lib/libamdhip64.so").hipSetDeviceFlags(4)
|
||||||
|
|
||||||
|
self.rocm_model_type = detector_config.model.model_type
|
||||||
path = detector_config.model.path
|
path = detector_config.model.path
|
||||||
mxr_path = os.path.splitext(path)[0] + ".mxr"
|
mxr_path = os.path.splitext(path)[0] + ".mxr"
|
||||||
|
|
||||||
if path.endswith(".mxr"):
|
if path.endswith(".mxr"):
|
||||||
logger.info(f"AMD/ROCm: loading parsed model from {mxr_path}")
|
logger.info(f"AMD/ROCm: loading parsed model from {mxr_path}")
|
||||||
self.model = migraphx.load(mxr_path)
|
self.model = migraphx.load(mxr_path)
|
||||||
@ -84,6 +90,7 @@ class ROCmDetector(DetectionApi):
|
|||||||
self.model = migraphx.load(mxr_path)
|
self.model = migraphx.load(mxr_path)
|
||||||
else:
|
else:
|
||||||
logger.info(f"AMD/ROCm: loading model from {path}")
|
logger.info(f"AMD/ROCm: loading model from {path}")
|
||||||
|
|
||||||
if path.endswith(".onnx"):
|
if path.endswith(".onnx"):
|
||||||
self.model = migraphx.parse_onnx(path)
|
self.model = migraphx.parse_onnx(path)
|
||||||
elif (
|
elif (
|
||||||
@ -95,13 +102,18 @@ class ROCmDetector(DetectionApi):
|
|||||||
self.model = migraphx.parse_tf(path)
|
self.model = migraphx.parse_tf(path)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"AMD/ROCm: unknown model format {path}")
|
raise Exception(f"AMD/ROCm: unknown model format {path}")
|
||||||
|
|
||||||
logger.info("AMD/ROCm: compiling the model")
|
logger.info("AMD/ROCm: compiling the model")
|
||||||
|
|
||||||
self.model.compile(
|
self.model.compile(
|
||||||
migraphx.get_target("gpu"), offload_copy=True, fast_math=True
|
migraphx.get_target("gpu"), offload_copy=True, fast_math=True
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"AMD/ROCm: saving parsed model into {mxr_path}")
|
logger.info(f"AMD/ROCm: saving parsed model into {mxr_path}")
|
||||||
|
|
||||||
os.makedirs("/config/model_cache/rocm", exist_ok=True)
|
os.makedirs("/config/model_cache/rocm", exist_ok=True)
|
||||||
migraphx.save(self.model, mxr_path)
|
migraphx.save(self.model, mxr_path)
|
||||||
|
|
||||||
logger.info("AMD/ROCm: model loaded")
|
logger.info("AMD/ROCm: model loaded")
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input):
|
||||||
@ -109,16 +121,28 @@ class ROCmDetector(DetectionApi):
|
|||||||
model_input_shape = tuple(
|
model_input_shape = tuple(
|
||||||
self.model.get_parameter_shapes()[model_input_name].lens()
|
self.model.get_parameter_shapes()[model_input_name].lens()
|
||||||
)
|
)
|
||||||
tensor_input = preprocess(tensor_input, model_input_shape, np.float32)
|
logger.info(f"the model input shape is {model_input_shape}")
|
||||||
|
|
||||||
|
tensor_input = cv2.dnn.blobFromImage(
|
||||||
|
tensor_input[0],
|
||||||
|
1.0 / 255,
|
||||||
|
model_input_shape,
|
||||||
|
None,
|
||||||
|
swapRB=False,
|
||||||
|
)
|
||||||
|
|
||||||
detector_result = self.model.run({model_input_name: tensor_input})[0]
|
detector_result = self.model.run({model_input_name: tensor_input})[0]
|
||||||
|
|
||||||
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
|
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
|
||||||
|
|
||||||
# ruff: noqa: F841
|
# ruff: noqa: F841
|
||||||
tensor_output = np.ctypeslib.as_array(
|
tensor_output = np.ctypeslib.as_array(
|
||||||
addr, shape=detector_result.get_shape().lens()
|
addr, shape=detector_result.get_shape().lens()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.rocm_model_type == ModelTypeEnum.yolonas:
|
||||||
|
logger.info(f"ROCM output has {tensor_output.shape[2]} boxes")
|
||||||
|
return np.zeros((20, 6), np.float32)
|
||||||
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"No models are currently supported for rocm. See the docs for more info."
|
f"{self.rocm_model_type} is currently not supported for rocm. See the docs for more info on supported models."
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,36 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(tensor_input, model_input_shape, model_input_element_type):
|
|
||||||
model_input_shape = tuple(model_input_shape)
|
|
||||||
assert tensor_input.dtype == np.uint8, f"tensor_input.dtype: {tensor_input.dtype}"
|
|
||||||
if len(tensor_input.shape) == 3:
|
|
||||||
tensor_input = tensor_input[np.newaxis, :]
|
|
||||||
if model_input_element_type == np.uint8:
|
|
||||||
# nothing to do for uint8 model input
|
|
||||||
assert (
|
|
||||||
model_input_shape == tensor_input.shape
|
|
||||||
), f"model_input_shape: {model_input_shape}, tensor_input.shape: {tensor_input.shape}"
|
|
||||||
return tensor_input
|
|
||||||
assert (
|
|
||||||
model_input_element_type == np.float32
|
|
||||||
), f"model_input_element_type: {model_input_element_type}"
|
|
||||||
# tensor_input must be nhwc
|
|
||||||
assert tensor_input.shape[3] == 3, f"tensor_input.shape: {tensor_input.shape}"
|
|
||||||
if tensor_input.shape[1:3] != model_input_shape[2:4]:
|
|
||||||
logger.warn(
|
|
||||||
f"preprocess: tensor_input.shape {tensor_input.shape} and model_input_shape {model_input_shape} do not match!"
|
|
||||||
)
|
|
||||||
# cv2.dnn.blobFromImage is faster than running it through numpy
|
|
||||||
return cv2.dnn.blobFromImage(
|
|
||||||
tensor_input[0],
|
|
||||||
1.0 / 255,
|
|
||||||
(model_input_shape[3], model_input_shape[2]),
|
|
||||||
None,
|
|
||||||
swapRB=False,
|
|
||||||
)
|
|
||||||
Loading…
Reference in New Issue
Block a user