mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-01-22 20:18:30 +03:00
ensure jina loading takes place in the main thread to prevent lazily importing tensorflow in another thread later
reverts atexit changes in https://github.com/blakeblackshear/frigate/pull/21301 and fixes https://github.com/blakeblackshear/frigate/discussions/21306
This commit is contained in:
parent
7c2f4b5d1e
commit
b12552571e
@ -29,6 +29,11 @@ from frigate.util.object import box_overlaps, calculate_region
|
||||
from ..types import DataProcessorMetrics
|
||||
from .api import RealTimeProcessorApi
|
||||
|
||||
try:
|
||||
from tflite_runtime.interpreter import Interpreter
|
||||
except ModuleNotFoundError:
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_OBJECT_CLASSIFICATIONS = 16
|
||||
@ -47,7 +52,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
self.requestor = requestor
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Any | None = None
|
||||
self.interpreter: Interpreter | None = None
|
||||
self.tensor_input_details: dict[str, Any] | None = None
|
||||
self.tensor_output_details: dict[str, Any] | None = None
|
||||
self.labelmap: dict[int, str] = {}
|
||||
@ -345,7 +350,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.model_config = model_config
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Any | None = None
|
||||
self.interpreter: Interpreter | None = None
|
||||
self.sub_label_publisher = sub_label_publisher
|
||||
self.requestor = requestor
|
||||
self.tensor_input_details: dict[str, Any] | None = None
|
||||
@ -368,11 +373,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
|
||||
@redirect_output_to_logger(logger, logging.DEBUG)
|
||||
def __build_detector(self) -> None:
|
||||
try:
|
||||
from tflite_runtime.interpreter import Interpreter
|
||||
except ModuleNotFoundError:
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
|
||||
model_path = os.path.join(self.model_dir, "model.tflite")
|
||||
labelmap_path = os.path.join(self.model_dir, "labelmap.txt")
|
||||
|
||||
|
||||
@ -146,29 +146,6 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self.detected_license_plates: dict[str, dict[str, Any]] = {}
|
||||
self.genai_client = get_genai_client(config)
|
||||
|
||||
# Pre-import TensorFlow/tflite on main thread to avoid atexit registration issues
|
||||
# when importing from worker threads later (e.g., during dynamic config updates)
|
||||
if (
|
||||
self.config.classification.bird.enabled
|
||||
or len(self.config.classification.custom) > 0
|
||||
):
|
||||
try:
|
||||
from tflite_runtime.interpreter import Interpreter # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
from tensorflow.lite.python.interpreter import ( # noqa: F401
|
||||
Interpreter,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Pre-imported TensorFlow Interpreter on main thread for classification models"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to pre-import TensorFlow Interpreter: {e}. "
|
||||
"Classification models may fail to load if added dynamically."
|
||||
)
|
||||
|
||||
# model runners to share between realtime and post processors
|
||||
if self.config.lpr.enabled:
|
||||
lpr_model_runner = LicensePlateModelRunner(
|
||||
|
||||
@ -186,6 +186,9 @@ class JinaV1ImageEmbedding(BaseEmbedding):
|
||||
download_func=self._download_model,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
# Avoid lazy loading in worker threads: block until downloads complete
|
||||
# and load the model on the main thread during initialization.
|
||||
self._load_model_and_utils()
|
||||
else:
|
||||
self.downloader = None
|
||||
ModelDownloader.mark_files_state(
|
||||
|
||||
@ -65,6 +65,9 @@ class JinaV2Embedding(BaseEmbedding):
|
||||
download_func=self._download_model,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
# Avoid lazy loading in worker threads: block until downloads complete
|
||||
# and load the model on the main thread during initialization.
|
||||
self._load_model_and_utils()
|
||||
else:
|
||||
self.downloader = None
|
||||
ModelDownloader.mark_files_state(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user