diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 228d60477..03229cc73 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -122,9 +122,9 @@ def get_dataset_image_count(model_name: str) -> int: class ClassificationTrainingProcess(FrigateProcess): def __init__(self, model_name: str) -> None: - self.BASE_WEIGHT_PATH = os.environ.get( - "TF_KERAS_MOBILENET_V2_ENDPOINT", - "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/", + self.BASE_WEIGHT_URL = os.environ.get( + "TF_KERAS_MOBILENET_V2_WEIGHTS_URL", + "", ) super().__init__( stop_event=None, @@ -184,23 +184,24 @@ class ClassificationTrainingProcess(FrigateProcess): ) return False - alpha = 0.35 + weights_path = "imagenet" # Download MobileNetV2 weights if not present - weights_filename = ( - f"mobilenet_v2_weights_tf_dim_ordering_tf_kernels_{alpha}_224_no_top.h5" - ) - weights_path = os.path.join(MODEL_CACHE_DIR, "MobileNet", weights_filename) - if not os.path.exists(weights_path): - weights_url = self.BASE_WEIGHT_PATH + weights_filename - logger.info(f"Downloading MobileNet V2 weights file: {weights_url}") - ModelDownloader.download_from_url(weights_url, weights_path) + if self.BASE_WEIGHT_URL: + weights_path = os.path.join( + MODEL_CACHE_DIR, "MobileNet", "mobilenet_v2_weights.h5" + ) + if not os.path.exists(weights_path): + logger.info("Downloading MobileNet V2 weights file") + ModelDownloader.download_from_url( + self.BASE_WEIGHT_URL, weights_path + ) # Start with imagenet base model with 35% of channels in each layer base_model = MobileNetV2( input_shape=(224, 224, 3), include_top=False, weights=weights_path, - alpha=alpha, + alpha=0.35, ) base_model.trainable = False # Freeze pre-trained layers