From 6aa1a10965b3ce0d471a30f55add7d5f8f9bbdcb Mon Sep 17 00:00:00 2001 From: ZhaiSoul <842607283@qq.com> Date: Mon, 15 Dec 2025 15:17:36 +0000 Subject: [PATCH] feat: add train classification download weights file endpoint: "TF_KERAS_MOBILENET_V2_ENDPOINT" --- frigate/util/classification.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 1f4213315..228d60477 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -22,6 +22,7 @@ from frigate.const import ( from frigate.log import redirect_output_to_logger from frigate.models import Event, Recordings, ReviewSegment from frigate.types import ModelStatusTypesEnum +from frigate.util.downloader import ModelDownloader from frigate.util.file import get_event_thumbnail_bytes from frigate.util.image import get_image_from_recording from frigate.util.process import FrigateProcess @@ -121,6 +122,10 @@ def get_dataset_image_count(model_name: str) -> int: class ClassificationTrainingProcess(FrigateProcess): def __init__(self, model_name: str) -> None: + self.BASE_WEIGHT_PATH = os.environ.get( + "TF_KERAS_MOBILENET_V2_ENDPOINT", + "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/", + ) super().__init__( stop_event=None, priority=PROCESS_PRIORITY_LOW, @@ -179,12 +184,23 @@ class ClassificationTrainingProcess(FrigateProcess): ) return False + alpha = 0.35 + # Download MobileNetV2 weights if not present + weights_filename = ( + f"mobilenet_v2_weights_tf_dim_ordering_tf_kernels_{alpha}_224_no_top.h5" + ) + weights_path = os.path.join(MODEL_CACHE_DIR, "MobileNet", weights_filename) + if not os.path.exists(weights_path): + weights_url = self.BASE_WEIGHT_PATH + weights_filename + logger.info(f"Downloading MobileNet V2 weights file: {weights_url}") + ModelDownloader.download_from_url(weights_url, weights_path) + # Start with imagenet base model with 35% of channels in each layer base_model = MobileNetV2( input_shape=(224, 224, 3), include_top=False, - weights="imagenet", - alpha=0.35, + weights=weights_path, + alpha=alpha, ) base_model.trainable = False # Freeze pre-trained layers