From a78518bd94a87f2d1c921610bd83209daab00ba1 Mon Sep 17 00:00:00 2001 From: Jimmy Hon Date: Tue, 20 May 2025 04:46:09 +0000 Subject: [PATCH] Refactor common functions for tflite detector implementations --- frigate/detectors/detector_utils.py | 36 ++++++++++++++++++++++++++++ frigate/detectors/plugins/cpu_tfl.py | 36 ++++------------------------ 2 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 frigate/detectors/detector_utils.py diff --git a/frigate/detectors/detector_utils.py b/frigate/detectors/detector_utils.py new file mode 100644 index 000000000..d6445d5c0 --- /dev/null +++ b/frigate/detectors/detector_utils.py @@ -0,0 +1,36 @@ +import numpy as np + + +def tflite_init(self, interpreter): + self.interpreter = interpreter + + self.interpreter.allocate_tensors() + + self.tensor_input_details = self.interpreter.get_input_details() + self.tensor_output_details = self.interpreter.get_output_details() + + +def tflite_detect_raw(self, tensor_input): + self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input) + self.interpreter.invoke() + + boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0] + class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0] + scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0] + count = int(self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]) + + detections = np.zeros((20, 6), np.float32) + + for i in range(count): + if scores[i] < 0.4 or i == 20: + break + detections[i] = [ + class_ids[i], + float(scores[i]), + boxes[i][0], + boxes[i][1], + boxes[i][2], + boxes[i][3], + ] + + return detections diff --git a/frigate/detectors/plugins/cpu_tfl.py b/frigate/detectors/plugins/cpu_tfl.py index 8a54363e1..fc8db0f4b 100644 --- a/frigate/detectors/plugins/cpu_tfl.py +++ b/frigate/detectors/plugins/cpu_tfl.py @@ -1,12 +1,13 @@ import logging -import numpy as np from pydantic import Field from typing_extensions import Literal from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig +from ..detector_utils import tflite_detect_raw, tflite_init + try: from tflite_runtime.interpreter import Interpreter except ModuleNotFoundError: @@ -27,39 +28,12 @@ class CpuTfl(DetectionApi): type_key = DETECTOR_KEY def __init__(self, detector_config: CpuDetectorConfig): - self.interpreter = Interpreter( + interpreter = Interpreter( model_path=detector_config.model.path, num_threads=detector_config.num_threads or 3, ) - self.interpreter.allocate_tensors() - - self.tensor_input_details = self.interpreter.get_input_details() - self.tensor_output_details = self.interpreter.get_output_details() + tflite_init(self, interpreter) def detect_raw(self, tensor_input): - self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input) - self.interpreter.invoke() - - boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0] - class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0] - scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0] - count = int( - self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0] - ) - - detections = np.zeros((20, 6), np.float32) - - for i in range(count): - if scores[i] < 0.4 or i == 20: - break - detections[i] = [ - class_ids[i], - float(scores[i]), - boxes[i][0], - boxes[i][1], - boxes[i][2], - boxes[i][3], - ] - - return detections + return tflite_detect_raw(self, tensor_input)