mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-22 23:11:54 +03:00
ran ruff
This commit is contained in:
parent
b2d83d0a9c
commit
515597345a
@ -18,13 +18,18 @@ except ModuleNotFoundError:
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum, InputTensorEnum
|
from frigate.detectors.detector_config import (
|
||||||
|
BaseDetectorConfig,
|
||||||
|
ModelTypeEnum,
|
||||||
|
InputTensorEnum,
|
||||||
|
)
|
||||||
from frigate.util.model import post_process_yolo
|
from frigate.util.model import post_process_yolo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DETECTOR_KEY = "memryx"
|
DETECTOR_KEY = "memryx"
|
||||||
|
|
||||||
|
|
||||||
# Configuration class for model settings
|
# Configuration class for model settings
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
path: str = Field(default=None, title="Model Path") # Path to the DFP file
|
path: str = Field(default=None, title="Model Path") # Path to the DFP file
|
||||||
@ -37,7 +42,7 @@ class MemryXDetectorConfig(BaseDetectorConfig):
|
|||||||
|
|
||||||
|
|
||||||
class MemryXDetector(DetectionApi):
|
class MemryXDetector(DetectionApi):
|
||||||
type_key = DETECTOR_KEY # Set the type key
|
type_key = DETECTOR_KEY # Set the type key
|
||||||
supported_models = [
|
supported_models = [
|
||||||
ModelTypeEnum.ssd,
|
ModelTypeEnum.ssd,
|
||||||
ModelTypeEnum.yolonas,
|
ModelTypeEnum.yolonas,
|
||||||
@ -54,7 +59,9 @@ class MemryXDetector(DetectionApi):
|
|||||||
if "model_type" in getattr(model_cfg, "__fields_set__", set()):
|
if "model_type" in getattr(model_cfg, "__fields_set__", set()):
|
||||||
detector_config.model.model_type = model_cfg.model_type
|
detector_config.model.model_type = model_cfg.model_type
|
||||||
else:
|
else:
|
||||||
logger.info("model_type not set in config — defaulting to yolonas for MemryX.")
|
logger.info(
|
||||||
|
"model_type not set in config — defaulting to yolonas for MemryX."
|
||||||
|
)
|
||||||
detector_config.model.model_type = ModelTypeEnum.yolonas
|
detector_config.model.model_type = ModelTypeEnum.yolonas
|
||||||
|
|
||||||
self.capture_queue = Queue(maxsize=10)
|
self.capture_queue = Queue(maxsize=10)
|
||||||
@ -65,13 +72,13 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
self.memx_model_path = detector_config.model.path # Path to .dfp file
|
||||||
self.memx_post_model = None # Path to .post file
|
self.memx_post_model = None # Path to .post file
|
||||||
self.expected_post_model = None
|
self.expected_post_model = None
|
||||||
|
|
||||||
self.memx_device_path = detector_config.device # Device path
|
self.memx_device_path = detector_config.device # Device path
|
||||||
# Parse the device string to split PCIe:<index>
|
# Parse the device string to split PCIe:<index>
|
||||||
device_str = self.memx_device_path
|
device_str = self.memx_device_path
|
||||||
self.device_id = []
|
self.device_id = []
|
||||||
self.device_id.append(int(device_str.split(":")[1]))
|
self.device_id.append(int(device_str.split(":")[1]))
|
||||||
|
|
||||||
self.memx_model_height = detector_config.model.height
|
self.memx_model_height = detector_config.model.height
|
||||||
self.memx_model_width = detector_config.model.width
|
self.memx_model_width = detector_config.model.width
|
||||||
self.memx_model_type = detector_config.model.model_type
|
self.memx_model_type = detector_config.model.model_type
|
||||||
@ -80,41 +87,51 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||||
model_mapping = {
|
model_mapping = {
|
||||||
(640, 640): ("https://developer.memryx.com/example_files/2p0_frigate/yolov9_640.zip", "yolov9_640"),
|
(640, 640): (
|
||||||
(320, 320): ("https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip", "yolov9_320")
|
"https://developer.memryx.com/example_files/2p0_frigate/yolov9_640.zip",
|
||||||
|
"yolov9_640",
|
||||||
|
),
|
||||||
|
(320, 320): (
|
||||||
|
"https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip",
|
||||||
|
"yolov9_320",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
self.model_url, self.model_folder = model_mapping.get(
|
self.model_url, self.model_folder = model_mapping.get(
|
||||||
(self.memx_model_height, self.memx_model_width),
|
(self.memx_model_height, self.memx_model_width),
|
||||||
("https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip", "yolov9_320")
|
(
|
||||||
)
|
"https://developer.memryx.com/example_files/2p0_frigate/yolov9_320.zip",
|
||||||
self.expected_dfp_model = (
|
"yolov9_320",
|
||||||
"YOLO_v9_small_onnx.dfp"
|
),
|
||||||
)
|
)
|
||||||
|
self.expected_dfp_model = "YOLO_v9_small_onnx.dfp"
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||||
model_mapping = {
|
model_mapping = {
|
||||||
(640, 640): ("https://developer.memryx.com/example_files/2p0_frigate/yolonas_640.zip", "yolonas_640"),
|
(640, 640): (
|
||||||
(320, 320): ("https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip", "yolonas_320")
|
"https://developer.memryx.com/example_files/2p0_frigate/yolonas_640.zip",
|
||||||
|
"yolonas_640",
|
||||||
|
),
|
||||||
|
(320, 320): (
|
||||||
|
"https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip",
|
||||||
|
"yolonas_320",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
self.model_url, self.model_folder = model_mapping.get(
|
self.model_url, self.model_folder = model_mapping.get(
|
||||||
(self.memx_model_height, self.memx_model_width),
|
(self.memx_model_height, self.memx_model_width),
|
||||||
("https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip", "yolonas_320")
|
(
|
||||||
)
|
"https://developer.memryx.com/example_files/2p0_frigate/yolonas_320.zip",
|
||||||
self.expected_dfp_model = (
|
"yolonas_320",
|
||||||
"yolo_nas_s.dfp"
|
),
|
||||||
)
|
|
||||||
self.expected_post_model = (
|
|
||||||
"yolo_nas_s_post.onnx"
|
|
||||||
)
|
)
|
||||||
|
self.expected_dfp_model = "yolo_nas_s.dfp"
|
||||||
|
self.expected_post_model = "yolo_nas_s_post.onnx"
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||||
self.model_folder = "yolox"
|
self.model_folder = "yolox"
|
||||||
self.model_url = (
|
self.model_url = (
|
||||||
"https://developer.memryx.com/example_files/2p0_frigate/yolox.zip"
|
"https://developer.memryx.com/example_files/2p0_frigate/yolox.zip"
|
||||||
)
|
)
|
||||||
self.expected_dfp_model = (
|
self.expected_dfp_model = "YOLOX_640_640_3_onnx.dfp"
|
||||||
"YOLOX_640_640_3_onnx.dfp"
|
|
||||||
)
|
|
||||||
self.set_strides_grids()
|
self.set_strides_grids()
|
||||||
|
|
||||||
elif self.memx_model_type == ModelTypeEnum.ssd:
|
elif self.memx_model_type == ModelTypeEnum.ssd:
|
||||||
@ -122,12 +139,8 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.model_url = (
|
self.model_url = (
|
||||||
"https://developer.memryx.com/example_files/2p0_frigate/ssd.zip"
|
"https://developer.memryx.com/example_files/2p0_frigate/ssd.zip"
|
||||||
)
|
)
|
||||||
self.expected_dfp_model = (
|
self.expected_dfp_model = "SSDlite_MobileNet_v2_320_320_3_onnx.dfp"
|
||||||
"SSDlite_MobileNet_v2_320_320_3_onnx.dfp"
|
self.expected_post_model = "SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
||||||
)
|
|
||||||
self.expected_post_model = (
|
|
||||||
"SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.check_and_prepare_model()
|
self.check_and_prepare_model()
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -143,9 +156,9 @@ class MemryXDetector(DetectionApi):
|
|||||||
self.accl = AsyncAccl(
|
self.accl = AsyncAccl(
|
||||||
self.memx_model_path,
|
self.memx_model_path,
|
||||||
device_ids=self.device_id, # AsyncAccl device ids
|
device_ids=self.device_id, # AsyncAccl device ids
|
||||||
local_mode=True
|
local_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
# Models that use cropped post-processing sections (YOLO-NAS and SSD)
|
||||||
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
# --> These will be moved to pure numpy in the future to improve performance on low-end CPUs
|
||||||
if self.memx_post_model:
|
if self.memx_post_model:
|
||||||
@ -161,19 +174,13 @@ class MemryXDetector(DetectionApi):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize MemryX model: {e}")
|
logger.error(f"Failed to initialize MemryX model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def load_yolo_constants(self):
|
def load_yolo_constants(self):
|
||||||
base = f"{self.cache_dir}/{self.model_folder}"
|
base = f"{self.cache_dir}/{self.model_folder}"
|
||||||
# constants for yolov9 post-processing
|
# constants for yolov9 post-processing
|
||||||
self.const_A = np.load(
|
self.const_A = np.load(f"{base}/_model_22_Constant_9_output_0.npy")
|
||||||
f"{base}/_model_22_Constant_9_output_0.npy"
|
self.const_B = np.load(f"{base}/_model_22_Constant_10_output_0.npy")
|
||||||
)
|
self.const_C = np.load(f"{base}/_model_22_Constant_12_output_0.npy")
|
||||||
self.const_B = np.load(
|
|
||||||
f"{base}/_model_22_Constant_10_output_0.npy"
|
|
||||||
)
|
|
||||||
self.const_C = np.load(
|
|
||||||
f"{base}/_model_22_Constant_12_output_0.npy"
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_and_prepare_model(self):
|
def check_and_prepare_model(self):
|
||||||
"""Check if models exist; if not, download and extract them."""
|
"""Check if models exist; if not, download and extract them."""
|
||||||
@ -182,7 +189,11 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
model_subdir = os.path.join(self.cache_dir, self.model_folder)
|
model_subdir = os.path.join(self.cache_dir, self.model_folder)
|
||||||
dfp_path = os.path.join(model_subdir, self.expected_dfp_model)
|
dfp_path = os.path.join(model_subdir, self.expected_dfp_model)
|
||||||
post_path = os.path.join(model_subdir, self.expected_post_model) if self.expected_post_model else None
|
post_path = (
|
||||||
|
os.path.join(model_subdir, self.expected_post_model)
|
||||||
|
if self.expected_post_model
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
dfp_exists = os.path.exists(dfp_path)
|
dfp_exists = os.path.exists(dfp_path)
|
||||||
post_exists = os.path.exists(post_path) if post_path else True
|
post_exists = os.path.exists(post_path) if post_path else True
|
||||||
@ -210,7 +221,11 @@ class MemryXDetector(DetectionApi):
|
|||||||
|
|
||||||
# Re-assign model paths after extraction
|
# Re-assign model paths after extraction
|
||||||
self.memx_model_path = os.path.join(model_subdir, self.expected_dfp_model)
|
self.memx_model_path = os.path.join(model_subdir, self.expected_dfp_model)
|
||||||
self.memx_post_model = os.path.join(model_subdir, self.expected_post_model) if self.expected_post_model else None
|
self.memx_post_model = (
|
||||||
|
os.path.join(model_subdir, self.expected_post_model)
|
||||||
|
if self.expected_post_model
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||||
self.load_yolo_constants()
|
self.load_yolo_constants()
|
||||||
@ -232,7 +247,9 @@ class MemryXDetector(DetectionApi):
|
|||||||
if self.memx_model_type == ModelTypeEnum.yolonas:
|
if self.memx_model_type == ModelTypeEnum.yolonas:
|
||||||
if tensor_input.ndim == 4 and tensor_input.shape[1:] == (320, 320, 3):
|
if tensor_input.ndim == 4 and tensor_input.shape[1:] == (320, 320, 3):
|
||||||
logger.debug("Transposing tensor from NHWC to NCHW for YOLO-NAS")
|
logger.debug("Transposing tensor from NHWC to NCHW for YOLO-NAS")
|
||||||
tensor_input = np.transpose(tensor_input, (0, 3, 1, 2)) # (1, H, W, C) → (1, C, H, W)
|
tensor_input = np.transpose(
|
||||||
|
tensor_input, (0, 3, 1, 2)
|
||||||
|
) # (1, H, W, C) → (1, C, H, W)
|
||||||
tensor_input = tensor_input.astype(np.float32)
|
tensor_input = tensor_input.astype(np.float32)
|
||||||
tensor_input /= 255
|
tensor_input /= 255
|
||||||
|
|
||||||
@ -390,7 +407,6 @@ class MemryXDetector(DetectionApi):
|
|||||||
return reshaped
|
return reshaped
|
||||||
|
|
||||||
def post_process_yolox(self, output):
|
def post_process_yolox(self, output):
|
||||||
|
|
||||||
output_785 = output[0] # 785
|
output_785 = output[0] # 785
|
||||||
output_794 = output[1] # 794
|
output_794 = output[1] # 794
|
||||||
output_795 = output[2] # 795
|
output_795 = output[2] # 795
|
||||||
@ -528,7 +544,6 @@ class MemryXDetector(DetectionApi):
|
|||||||
def process_output(self, *outputs):
|
def process_output(self, *outputs):
|
||||||
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
|
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
|
||||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||||
|
|
||||||
conv_out1 = outputs[0]
|
conv_out1 = outputs[0]
|
||||||
conv_out2 = outputs[1]
|
conv_out2 = outputs[1]
|
||||||
conv_out3 = outputs[2]
|
conv_out3 = outputs[2]
|
||||||
|
|||||||
@ -97,10 +97,10 @@ class AsyncLocalObjectDetector(BaseLocalDetector):
|
|||||||
def async_send_input(self, tensor_input: np.ndarray, connection_id):
|
def async_send_input(self, tensor_input: np.ndarray, connection_id):
|
||||||
tensor_input = self._transform_input(tensor_input)
|
tensor_input = self._transform_input(tensor_input)
|
||||||
return self.detect_api.send_input(connection_id, tensor_input)
|
return self.detect_api.send_input(connection_id, tensor_input)
|
||||||
|
|
||||||
def async_receive_output(self):
|
def async_receive_output(self):
|
||||||
return self.detect_api.receive_output()
|
return self.detect_api.receive_output()
|
||||||
|
|
||||||
|
|
||||||
def prepare_detector(name, out_events):
|
def prepare_detector(name, out_events):
|
||||||
threading.current_thread().name = f"detector:{name}"
|
threading.current_thread().name = f"detector:{name}"
|
||||||
@ -136,10 +136,7 @@ def run_detector(
|
|||||||
start: Value,
|
start: Value,
|
||||||
detector_config: BaseDetectorConfig,
|
detector_config: BaseDetectorConfig,
|
||||||
):
|
):
|
||||||
|
stop_event, frame_manager, outputs, logger = prepare_detector(name, out_events)
|
||||||
stop_event, frame_manager, outputs, logger = prepare_detector(
|
|
||||||
name, out_events
|
|
||||||
)
|
|
||||||
|
|
||||||
object_detector = LocalObjectDetector(detector_config=detector_config)
|
object_detector = LocalObjectDetector(detector_config=detector_config)
|
||||||
|
|
||||||
@ -179,10 +176,7 @@ def async_run_detector(
|
|||||||
start: Value,
|
start: Value,
|
||||||
detector_config: BaseDetectorConfig,
|
detector_config: BaseDetectorConfig,
|
||||||
):
|
):
|
||||||
|
stop_event, frame_manager, outputs, logger = prepare_detector(name, out_events)
|
||||||
stop_event, frame_manager, outputs, logger = prepare_detector(
|
|
||||||
name, out_events
|
|
||||||
)
|
|
||||||
|
|
||||||
object_detector = AsyncLocalObjectDetector(detector_config=detector_config)
|
object_detector = AsyncLocalObjectDetector(detector_config=detector_config)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user