mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-12 22:25:24 +03:00
support for yolo-nas in openvino
This commit is contained in:
parent
9381f257fa
commit
ebf1f54fc9
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -10,6 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DetectionApi(ABC):
|
||||
type_key: str
|
||||
supported_models: List[ModelTypeEnum]
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, detector_config):
|
||||
|
||||
@ -20,6 +20,7 @@ class OvDetectorConfig(BaseDetectorConfig):
|
||||
|
||||
class OvDetector(DetectionApi):
|
||||
type_key = DETECTOR_KEY
|
||||
supported_models = [ModelTypeEnum.ssd, ModelTypeEnum.yolonas, ModelTypeEnum.yolox]
|
||||
|
||||
def __init__(self, detector_config: OvDetectorConfig):
|
||||
self.ov_core = ov.Core()
|
||||
@ -34,6 +35,12 @@ class OvDetector(DetectionApi):
|
||||
|
||||
self.model_invalid = False
|
||||
|
||||
if self.ov_model_type not in self.supported_models:
|
||||
logger.error(
|
||||
f"OpenVino detector does not support {self.ov_model_type} models."
|
||||
)
|
||||
self.model_invalid = True
|
||||
|
||||
# Ensure the SSD model has the right input and output shapes
|
||||
if self.ov_model_type == ModelTypeEnum.ssd:
|
||||
model_inputs = self.interpreter.inputs
|
||||
@ -61,6 +68,34 @@ class OvDetector(DetectionApi):
|
||||
logger.error(f"SSD model output doesn't match. Found {output_shape}.")
|
||||
self.model_invalid = True
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.yolonas:
|
||||
model_inputs = self.interpreter.inputs
|
||||
model_outputs = self.interpreter.outputs
|
||||
|
||||
if len(model_inputs) != 1:
|
||||
logger.error(
|
||||
f"YoloNAS models must only have 1 input. Found {len(model_inputs)}."
|
||||
)
|
||||
self.model_invalid = True
|
||||
if len(model_outputs) != 1:
|
||||
logger.error(
|
||||
f"YoloNAS models must be exported in flat format and only have 1 output. Found {len(model_outputs)}."
|
||||
)
|
||||
self.model_invalid = True
|
||||
|
||||
if model_inputs[0].get_shape() != ov.Shape([1, 3, self.w, self.h]):
|
||||
logger.error(
|
||||
f"YoloNAS model input doesn't match. Found {model_inputs[0].get_shape()}, but expected {[1, 3, self.w, self.h]}."
|
||||
)
|
||||
self.model_invalid = True
|
||||
|
||||
output_shape = model_outputs[0].partial_shape
|
||||
if output_shape[-1] != 7:
|
||||
logger.error(
|
||||
f"YoloNAS models must be exported in flat format. Model output doesn't match. Found {output_shape}."
|
||||
)
|
||||
self.model_invalid = True
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.yolox:
|
||||
self.output_indexes = 0
|
||||
while True:
|
||||
@ -113,12 +148,12 @@ class OvDetector(DetectionApi):
|
||||
input_tensor = ov.Tensor(array=tensor_input)
|
||||
infer_request.infer(input_tensor)
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.ssd:
|
||||
detections = np.zeros((20, 6), np.float32)
|
||||
|
||||
if self.model_invalid:
|
||||
return detections
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.ssd:
|
||||
results = infer_request.get_output_tensor(0).data[0][0]
|
||||
|
||||
for i, (_, class_id, score, xmin, ymin, xmax, ymax) in enumerate(results):
|
||||
@ -134,6 +169,26 @@ class OvDetector(DetectionApi):
|
||||
]
|
||||
return detections
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.yolonas:
|
||||
predictions = infer_request.get_output_tensor(0).data
|
||||
|
||||
for i, prediction in enumerate(predictions):
|
||||
if i == 20:
|
||||
break
|
||||
(_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction
|
||||
# when running in GPU mode, empty predictions in the output have class_id of -1
|
||||
if class_id < 0:
|
||||
break
|
||||
detections[i] = [
|
||||
class_id,
|
||||
confidence,
|
||||
y_min / self.h,
|
||||
x_min / self.w,
|
||||
y_max / self.h,
|
||||
x_max / self.w,
|
||||
]
|
||||
return detections
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.yolox:
|
||||
out_tensor = infer_request.get_output_tensor()
|
||||
# [x, y, h, w, box_score, class_no_1, ..., class_no_80],
|
||||
@ -155,8 +210,6 @@ class OvDetector(DetectionApi):
|
||||
|
||||
ordered = dets[dets[:, 5].argsort()[::-1]][:20]
|
||||
|
||||
detections = np.zeros((20, 6), np.float32)
|
||||
|
||||
for i, object_detected in enumerate(ordered):
|
||||
detections[i] = self.process_yolo(
|
||||
object_detected[6], object_detected[5], object_detected[:4]
|
||||
|
||||
101
notebooks/YOLO_NAS_Pretrained_Export.ipynb
Normal file
101
notebooks/YOLO_NAS_Pretrained_Export.ipynb
Normal file
@ -0,0 +1,101 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rmuF9iKWTbdk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install -q super_gradients==3.7.1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dTB0jy_NNSFz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from super_gradients.common.object_names import Models\n",
|
||||
"from super_gradients.conversion import DetectionOutputFormatMode\n",
|
||||
"from super_gradients.training import models\n",
|
||||
"\n",
|
||||
"model = models.get(Models.YOLO_NAS_S, pretrained_weights=\"coco\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "GymUghyCNXem"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# export the model for compatibility with Frigate\n",
|
||||
"\n",
|
||||
"model.export(\"yolo_nas_s.onnx\",\n",
|
||||
" output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,\n",
|
||||
" max_predictions_per_image=20,\n",
|
||||
" confidence_threshold=0.4,\n",
|
||||
" input_image_shape=(320,320),\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 17
|
||||
},
|
||||
"id": "uBhXV5g4Nh42",
|
||||
"outputId": "303104fa-97dd-4efd-8b56-808e0ea00166"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/javascript": "\n async function download(id, filename, size) {\n if (!google.colab.kernel.accessAllowed) {\n return;\n }\n const div = document.createElement('div');\n const label = document.createElement('label');\n label.textContent = `Downloading \"${filename}\": `;\n div.appendChild(label);\n const progress = document.createElement('progress');\n progress.max = size;\n div.appendChild(progress);\n document.body.appendChild(div);\n\n const buffers = [];\n let downloaded = 0;\n\n const channel = await google.colab.kernel.comms.open(id);\n // Send a message to notify the kernel that we're ready.\n channel.send({})\n\n for await (const message of channel.messages) {\n // Send a message to notify the kernel that we're ready.\n channel.send({})\n if (message.buffers) {\n for (const buffer of message.buffers) {\n buffers.push(buffer);\n downloaded += buffer.byteLength;\n progress.value = downloaded;\n }\n }\n }\n const blob = new Blob(buffers, {type: 'application/binary'});\n const a = document.createElement('a');\n a.href = window.URL.createObjectURL(blob);\n a.download = filename;\n div.appendChild(a);\n a.click();\n div.remove();\n }\n ",
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Javascript object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/javascript": "download(\"download_9976a3d8-3025-4b5b-b060-6a3aed3c752f\", \"yolo_nas_s.onnx\", 48803558)",
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Javascript object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from google.colab import files\n",
|
||||
"\n",
|
||||
"files.download('yolo_nas_s.onnx')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user