From 39af85625e8f8f1caec1e4a69eaba3952cb61856 Mon Sep 17 00:00:00 2001 From: GuoQing Liu <842607283@qq.com> Date: Mon, 15 Dec 2025 23:59:13 +0800 Subject: [PATCH] feat: add train classification download weights file endpoint (#21294) * feat: add train classification download weights file endpoint: "TF_KERAS_MOBILENET_V2_ENDPOINT" * refactor: custom weights file url --- frigate/util/classification.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 1f4213315..03229cc73 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -22,6 +22,7 @@ from frigate.const import ( from frigate.log import redirect_output_to_logger from frigate.models import Event, Recordings, ReviewSegment from frigate.types import ModelStatusTypesEnum +from frigate.util.downloader import ModelDownloader from frigate.util.file import get_event_thumbnail_bytes from frigate.util.image import get_image_from_recording from frigate.util.process import FrigateProcess @@ -121,6 +122,10 @@ def get_dataset_image_count(model_name: str) -> int: class ClassificationTrainingProcess(FrigateProcess): def __init__(self, model_name: str) -> None: + self.BASE_WEIGHT_URL = os.environ.get( + "TF_KERAS_MOBILENET_V2_WEIGHTS_URL", + "", + ) super().__init__( stop_event=None, priority=PROCESS_PRIORITY_LOW, @@ -179,11 +184,23 @@ class ClassificationTrainingProcess(FrigateProcess): ) return False + weights_path = "imagenet" + # Download MobileNetV2 weights if not present + if self.BASE_WEIGHT_URL: + weights_path = os.path.join( + MODEL_CACHE_DIR, "MobileNet", "mobilenet_v2_weights.h5" + ) + if not os.path.exists(weights_path): + logger.info("Downloading MobileNet V2 weights file") + ModelDownloader.download_from_url( + self.BASE_WEIGHT_URL, weights_path + ) + # Start with imagenet base model with 35% of channels in each layer base_model = MobileNetV2( input_shape=(224, 224, 3), include_top=False, - weights="imagenet", + weights=weights_path, alpha=0.35, ) base_model.trainable = False # Freeze pre-trained layers