refactor: custom weights file url

This commit is contained in:
ZhaiSoul 2025-12-15 15:47:40 +00:00
parent 6aa1a10965
commit f2eb21c3eb

View File

@ -122,9 +122,9 @@ def get_dataset_image_count(model_name: str) -> int:
class ClassificationTrainingProcess(FrigateProcess): class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None: def __init__(self, model_name: str) -> None:
self.BASE_WEIGHT_PATH = os.environ.get( self.BASE_WEIGHT_URL = os.environ.get(
"TF_KERAS_MOBILENET_V2_ENDPOINT", "TF_KERAS_MOBILENET_V2_WEIGHTS_URL",
"https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/", "",
) )
super().__init__( super().__init__(
stop_event=None, stop_event=None,
@ -184,23 +184,24 @@ class ClassificationTrainingProcess(FrigateProcess):
) )
return False return False
alpha = 0.35 weights_path = "imagenet"
# Download MobileNetV2 weights if not present # Download MobileNetV2 weights if not present
weights_filename = ( if self.BASE_WEIGHT_URL:
f"mobilenet_v2_weights_tf_dim_ordering_tf_kernels_{alpha}_224_no_top.h5" weights_path = os.path.join(
MODEL_CACHE_DIR, "MobileNet", "mobilenet_v2_weights.h5"
) )
weights_path = os.path.join(MODEL_CACHE_DIR, "MobileNet", weights_filename)
if not os.path.exists(weights_path): if not os.path.exists(weights_path):
weights_url = self.BASE_WEIGHT_PATH + weights_filename logger.info("Downloading MobileNet V2 weights file")
logger.info(f"Downloading MobileNet V2 weights file: {weights_url}") ModelDownloader.download_from_url(
ModelDownloader.download_from_url(weights_url, weights_path) self.BASE_WEIGHT_URL, weights_path
)
# Start with imagenet base model with 35% of channels in each layer # Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2( base_model = MobileNetV2(
input_shape=(224, 224, 3), input_shape=(224, 224, 3),
include_top=False, include_top=False,
weights=weights_path, weights=weights_path,
alpha=alpha, alpha=0.35,
) )
base_model.trainable = False # Freeze pre-trained layers base_model.trainable = False # Freeze pre-trained layers