Fix missing

This commit is contained in:
Nicolas Mowen 2025-09-19 07:18:49 -06:00
parent 4c02aec854
commit 161564b91a

View File

@ -27,11 +27,6 @@ from frigate.util.object import box_overlaps, calculate_region
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
try:
from tflite_runtime.interpreter import Interpreter
except ModuleNotFoundError:
from tensorflow.lite.python.interpreter import Interpreter
logger = logging.getLogger(__name__)
@ -44,11 +39,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
self.model_config = model_config
self.requestor = requestor
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.interpreter = None
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
self.labelmap: dict[int, str] = {}
@ -61,6 +57,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
@redirect_output_to_logger(logger, logging.DEBUG)
def __build_detector(self) -> None:
try:
from tflite_runtime.interpreter import Interpreter
except ModuleNotFoundError:
from tensorflow.lite.python.interpreter import Interpreter
self.interpreter = Interpreter(
model_path=os.path.join(self.model_dir, "model.tflite"),
num_threads=2,
@ -197,7 +198,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.model_config = model_config
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.interpreter = None
self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
@ -211,6 +212,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
@redirect_output_to_logger(logger, logging.DEBUG)
def __build_detector(self) -> None:
try:
from tflite_runtime.interpreter import Interpreter
except ModuleNotFoundError:
from tensorflow.lite.python.interpreter import Interpreter
self.interpreter = Interpreter(
model_path=os.path.join(self.model_dir, "model.tflite"),
num_threads=2,