From 161564b91a4f394de2d3a34275e09996b32e9354 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Fri, 19 Sep 2025 07:18:49 -0600 Subject: [PATCH] Fix missing --- .../real_time/custom_classification.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 841267a60..2a37fc6c9 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -27,11 +27,6 @@ from frigate.util.object import box_overlaps, calculate_region from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi -try: - from tflite_runtime.interpreter import Interpreter -except ModuleNotFoundError: - from tensorflow.lite.python.interpreter import Interpreter - logger = logging.getLogger(__name__) @@ -44,11 +39,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): metrics: DataProcessorMetrics, ): super().__init__(config, metrics) + self.model_config = model_config self.requestor = requestor self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") - self.interpreter: Interpreter = None + self.interpreter = None self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None self.labelmap: dict[int, str] = {} @@ -61,6 +57,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): @redirect_output_to_logger(logger, logging.DEBUG) def __build_detector(self) -> None: + try: + from tflite_runtime.interpreter import Interpreter + except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter + self.interpreter = Interpreter( model_path=os.path.join(self.model_dir, "model.tflite"), num_threads=2, @@ -197,7 +198,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.model_config = model_config self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") - self.interpreter: Interpreter = None + self.interpreter = None self.sub_label_publisher = sub_label_publisher self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None @@ -211,6 +212,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): @redirect_output_to_logger(logger, logging.DEBUG) def __build_detector(self) -> None: + try: + from tflite_runtime.interpreter import Interpreter + except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter + self.interpreter = Interpreter( model_path=os.path.join(self.model_dir, "model.tflite"), num_threads=2,