mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-12 22:25:24 +03:00
support for yolo-nas in openvino
This commit is contained in:
parent
9381f257fa
commit
ebf1f54fc9
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DetectionApi(ABC):
|
class DetectionApi(ABC):
|
||||||
type_key: str
|
type_key: str
|
||||||
|
supported_models: List[ModelTypeEnum]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, detector_config):
|
def __init__(self, detector_config):
|
||||||
|
|||||||
@ -20,6 +20,7 @@ class OvDetectorConfig(BaseDetectorConfig):
|
|||||||
|
|
||||||
class OvDetector(DetectionApi):
|
class OvDetector(DetectionApi):
|
||||||
type_key = DETECTOR_KEY
|
type_key = DETECTOR_KEY
|
||||||
|
supported_models = [ModelTypeEnum.ssd, ModelTypeEnum.yolonas, ModelTypeEnum.yolox]
|
||||||
|
|
||||||
def __init__(self, detector_config: OvDetectorConfig):
|
def __init__(self, detector_config: OvDetectorConfig):
|
||||||
self.ov_core = ov.Core()
|
self.ov_core = ov.Core()
|
||||||
@ -34,6 +35,12 @@ class OvDetector(DetectionApi):
|
|||||||
|
|
||||||
self.model_invalid = False
|
self.model_invalid = False
|
||||||
|
|
||||||
|
if self.ov_model_type not in self.supported_models:
|
||||||
|
logger.error(
|
||||||
|
f"OpenVino detector does not support {self.ov_model_type} models."
|
||||||
|
)
|
||||||
|
self.model_invalid = True
|
||||||
|
|
||||||
# Ensure the SSD model has the right input and output shapes
|
# Ensure the SSD model has the right input and output shapes
|
||||||
if self.ov_model_type == ModelTypeEnum.ssd:
|
if self.ov_model_type == ModelTypeEnum.ssd:
|
||||||
model_inputs = self.interpreter.inputs
|
model_inputs = self.interpreter.inputs
|
||||||
@ -61,6 +68,34 @@ class OvDetector(DetectionApi):
|
|||||||
logger.error(f"SSD model output doesn't match. Found {output_shape}.")
|
logger.error(f"SSD model output doesn't match. Found {output_shape}.")
|
||||||
self.model_invalid = True
|
self.model_invalid = True
|
||||||
|
|
||||||
|
if self.ov_model_type == ModelTypeEnum.yolonas:
|
||||||
|
model_inputs = self.interpreter.inputs
|
||||||
|
model_outputs = self.interpreter.outputs
|
||||||
|
|
||||||
|
if len(model_inputs) != 1:
|
||||||
|
logger.error(
|
||||||
|
f"YoloNAS models must only have 1 input. Found {len(model_inputs)}."
|
||||||
|
)
|
||||||
|
self.model_invalid = True
|
||||||
|
if len(model_outputs) != 1:
|
||||||
|
logger.error(
|
||||||
|
f"YoloNAS models must be exported in flat format and only have 1 output. Found {len(model_outputs)}."
|
||||||
|
)
|
||||||
|
self.model_invalid = True
|
||||||
|
|
||||||
|
if model_inputs[0].get_shape() != ov.Shape([1, 3, self.w, self.h]):
|
||||||
|
logger.error(
|
||||||
|
f"YoloNAS model input doesn't match. Found {model_inputs[0].get_shape()}, but expected {[1, 3, self.w, self.h]}."
|
||||||
|
)
|
||||||
|
self.model_invalid = True
|
||||||
|
|
||||||
|
output_shape = model_outputs[0].partial_shape
|
||||||
|
if output_shape[-1] != 7:
|
||||||
|
logger.error(
|
||||||
|
f"YoloNAS models must be exported in flat format. Model output doesn't match. Found {output_shape}."
|
||||||
|
)
|
||||||
|
self.model_invalid = True
|
||||||
|
|
||||||
if self.ov_model_type == ModelTypeEnum.yolox:
|
if self.ov_model_type == ModelTypeEnum.yolox:
|
||||||
self.output_indexes = 0
|
self.output_indexes = 0
|
||||||
while True:
|
while True:
|
||||||
@ -113,12 +148,12 @@ class OvDetector(DetectionApi):
|
|||||||
input_tensor = ov.Tensor(array=tensor_input)
|
input_tensor = ov.Tensor(array=tensor_input)
|
||||||
infer_request.infer(input_tensor)
|
infer_request.infer(input_tensor)
|
||||||
|
|
||||||
if self.ov_model_type == ModelTypeEnum.ssd:
|
|
||||||
detections = np.zeros((20, 6), np.float32)
|
detections = np.zeros((20, 6), np.float32)
|
||||||
|
|
||||||
if self.model_invalid:
|
if self.model_invalid:
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
if self.ov_model_type == ModelTypeEnum.ssd:
|
||||||
results = infer_request.get_output_tensor(0).data[0][0]
|
results = infer_request.get_output_tensor(0).data[0][0]
|
||||||
|
|
||||||
for i, (_, class_id, score, xmin, ymin, xmax, ymax) in enumerate(results):
|
for i, (_, class_id, score, xmin, ymin, xmax, ymax) in enumerate(results):
|
||||||
@ -134,6 +169,26 @@ class OvDetector(DetectionApi):
|
|||||||
]
|
]
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
if self.ov_model_type == ModelTypeEnum.yolonas:
|
||||||
|
predictions = infer_request.get_output_tensor(0).data
|
||||||
|
|
||||||
|
for i, prediction in enumerate(predictions):
|
||||||
|
if i == 20:
|
||||||
|
break
|
||||||
|
(_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction
|
||||||
|
# when running in GPU mode, empty predictions in the output have class_id of -1
|
||||||
|
if class_id < 0:
|
||||||
|
break
|
||||||
|
detections[i] = [
|
||||||
|
class_id,
|
||||||
|
confidence,
|
||||||
|
y_min / self.h,
|
||||||
|
x_min / self.w,
|
||||||
|
y_max / self.h,
|
||||||
|
x_max / self.w,
|
||||||
|
]
|
||||||
|
return detections
|
||||||
|
|
||||||
if self.ov_model_type == ModelTypeEnum.yolox:
|
if self.ov_model_type == ModelTypeEnum.yolox:
|
||||||
out_tensor = infer_request.get_output_tensor()
|
out_tensor = infer_request.get_output_tensor()
|
||||||
# [x, y, h, w, box_score, class_no_1, ..., class_no_80],
|
# [x, y, h, w, box_score, class_no_1, ..., class_no_80],
|
||||||
@ -155,8 +210,6 @@ class OvDetector(DetectionApi):
|
|||||||
|
|
||||||
ordered = dets[dets[:, 5].argsort()[::-1]][:20]
|
ordered = dets[dets[:, 5].argsort()[::-1]][:20]
|
||||||
|
|
||||||
detections = np.zeros((20, 6), np.float32)
|
|
||||||
|
|
||||||
for i, object_detected in enumerate(ordered):
|
for i, object_detected in enumerate(ordered):
|
||||||
detections[i] = self.process_yolo(
|
detections[i] = self.process_yolo(
|
||||||
object_detected[6], object_detected[5], object_detected[:4]
|
object_detected[6], object_detected[5], object_detected[:4]
|
||||||
|
|||||||
101
notebooks/YOLO_NAS_Pretrained_Export.ipynb
Normal file
101
notebooks/YOLO_NAS_Pretrained_Export.ipynb
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "rmuF9iKWTbdk"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"! pip install -q super_gradients==3.7.1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "dTB0jy_NNSFz"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from super_gradients.common.object_names import Models\n",
|
||||||
|
"from super_gradients.conversion import DetectionOutputFormatMode\n",
|
||||||
|
"from super_gradients.training import models\n",
|
||||||
|
"\n",
|
||||||
|
"model = models.get(Models.YOLO_NAS_S, pretrained_weights=\"coco\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "GymUghyCNXem"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# export the model for compatibility with Frigate\n",
|
||||||
|
"\n",
|
||||||
|
"model.export(\"yolo_nas_s.onnx\",\n",
|
||||||
|
" output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,\n",
|
||||||
|
" max_predictions_per_image=20,\n",
|
||||||
|
" confidence_threshold=0.4,\n",
|
||||||
|
" input_image_shape=(320,320),\n",
|
||||||
|
" )"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/",
|
||||||
|
"height": 17
|
||||||
|
},
|
||||||
|
"id": "uBhXV5g4Nh42",
|
||||||
|
"outputId": "303104fa-97dd-4efd-8b56-808e0ea00166"
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/javascript": "\n async function download(id, filename, size) {\n if (!google.colab.kernel.accessAllowed) {\n return;\n }\n const div = document.createElement('div');\n const label = document.createElement('label');\n label.textContent = `Downloading \"${filename}\": `;\n div.appendChild(label);\n const progress = document.createElement('progress');\n progress.max = size;\n div.appendChild(progress);\n document.body.appendChild(div);\n\n const buffers = [];\n let downloaded = 0;\n\n const channel = await google.colab.kernel.comms.open(id);\n // Send a message to notify the kernel that we're ready.\n channel.send({})\n\n for await (const message of channel.messages) {\n // Send a message to notify the kernel that we're ready.\n channel.send({})\n if (message.buffers) {\n for (const buffer of message.buffers) {\n buffers.push(buffer);\n downloaded += buffer.byteLength;\n progress.value = downloaded;\n }\n }\n }\n const blob = new Blob(buffers, {type: 'application/binary'});\n const a = document.createElement('a');\n a.href = window.URL.createObjectURL(blob);\n a.download = filename;\n div.appendChild(a);\n a.click();\n div.remove();\n }\n ",
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.Javascript object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/javascript": "download(\"download_9976a3d8-3025-4b5b-b060-6a3aed3c752f\", \"yolo_nas_s.onnx\", 48803558)",
|
||||||
|
"text/plain": [
|
||||||
|
"<IPython.core.display.Javascript object>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from google.colab import files\n",
|
||||||
|
"\n",
|
||||||
|
"files.download('yolo_nas_s.onnx')"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user