diff --git a/frigate/data_processing/real_time/bird.py b/frigate/data_processing/real_time/bird.py index e599ab0fb..8d6e1b2dc 100644 --- a/frigate/data_processing/real_time/bird.py +++ b/frigate/data_processing/real_time/bird.py @@ -19,11 +19,6 @@ from frigate.util.object import calculate_region from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi -try: - from tflite_runtime.interpreter import Interpreter -except ModuleNotFoundError: - from tensorflow.lite.python.interpreter import Interpreter - logger = logging.getLogger(__name__) @@ -35,7 +30,7 @@ class BirdRealTimeProcessor(RealTimeProcessorApi): metrics: DataProcessorMetrics, ): super().__init__(config, metrics) - self.interpreter: Interpreter = None + self.interpreter: Any | None = None self.sub_label_publisher = sub_label_publisher self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None @@ -82,6 +77,11 @@ class BirdRealTimeProcessor(RealTimeProcessorApi): @redirect_output_to_logger(logger, logging.DEBUG) def __build_detector(self) -> None: + try: + from tflite_runtime.interpreter import Interpreter + except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter + self.interpreter = Interpreter( model_path=os.path.join(MODEL_CACHE_DIR, "bird/bird.tflite"), num_threads=2, diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 25ec3bb86..dd011b48e 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -29,11 +29,6 @@ from frigate.util.object import box_overlaps, calculate_region from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi -try: - from tflite_runtime.interpreter import Interpreter -except ModuleNotFoundError: - from tensorflow.lite.python.interpreter import Interpreter - logger = logging.getLogger(__name__) MAX_OBJECT_CLASSIFICATIONS = 16 @@ -52,7 +47,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.requestor = requestor self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") - self.interpreter: Interpreter | None = None + self.interpreter: Any | None = None self.tensor_input_details: dict[str, Any] | None = None self.tensor_output_details: dict[str, Any] | None = None self.labelmap: dict[int, str] = {} @@ -74,6 +69,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): @redirect_output_to_logger(logger, logging.DEBUG) def __build_detector(self) -> None: + try: + from tflite_runtime.interpreter import Interpreter + except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter + model_path = os.path.join(self.model_dir, "model.tflite") labelmap_path = os.path.join(self.model_dir, "labelmap.txt") @@ -345,7 +345,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.model_config = model_config self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") - self.interpreter: Interpreter | None = None + self.interpreter: Any | None = None self.sub_label_publisher = sub_label_publisher self.requestor = requestor self.tensor_input_details: dict[str, Any] | None = None @@ -368,6 +368,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): @redirect_output_to_logger(logger, logging.DEBUG) def __build_detector(self) -> None: + try: + from tflite_runtime.interpreter import Interpreter + except ModuleNotFoundError: + from tensorflow.lite.python.interpreter import Interpreter + model_path = os.path.join(self.model_dir, "model.tflite") labelmap_path = os.path.join(self.model_dir, "labelmap.txt") diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 78a251c42..33d09dcc3 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -146,6 +146,29 @@ class EmbeddingMaintainer(threading.Thread): self.detected_license_plates: dict[str, dict[str, Any]] = {} self.genai_client = get_genai_client(config) + # Pre-import TensorFlow/tflite on main thread to avoid atexit registration issues + # when importing from worker threads later (e.g., during dynamic config updates) + if ( + self.config.classification.bird.enabled + or len(self.config.classification.custom) > 0 + ): + try: + from tflite_runtime.interpreter import Interpreter # noqa: F401 + except ModuleNotFoundError: + try: + from tensorflow.lite.python.interpreter import ( # noqa: F401 + Interpreter, + ) + + logger.debug( + "Pre-imported TensorFlow Interpreter on main thread for classification models" + ) + except Exception as e: + logger.warning( + f"Failed to pre-import TensorFlow Interpreter: {e}. " + "Classification models may fail to load if added dynamically." + ) + # model runners to share between realtime and post processors if self.config.lpr.enabled: lpr_model_runner = LicensePlateModelRunner(