feat: add train classification download weights file endpoint (#21294)

* feat: add train classification download weights file endpoint: "TF_KERAS_MOBILENET_V2_ENDPOINT"

* refactor: custom weights file url
This commit is contained in:
GuoQing Liu 2025-12-15 23:59:13 +08:00 committed by GitHub
parent fa16539429
commit 39af85625e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,7 @@ from frigate.const import (
from frigate.log import redirect_output_to_logger from frigate.log import redirect_output_to_logger
from frigate.models import Event, Recordings, ReviewSegment from frigate.models import Event, Recordings, ReviewSegment
from frigate.types import ModelStatusTypesEnum from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from frigate.util.file import get_event_thumbnail_bytes from frigate.util.file import get_event_thumbnail_bytes
from frigate.util.image import get_image_from_recording from frigate.util.image import get_image_from_recording
from frigate.util.process import FrigateProcess from frigate.util.process import FrigateProcess
@ -121,6 +122,10 @@ def get_dataset_image_count(model_name: str) -> int:
class ClassificationTrainingProcess(FrigateProcess): class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None: def __init__(self, model_name: str) -> None:
self.BASE_WEIGHT_URL = os.environ.get(
"TF_KERAS_MOBILENET_V2_WEIGHTS_URL",
"",
)
super().__init__( super().__init__(
stop_event=None, stop_event=None,
priority=PROCESS_PRIORITY_LOW, priority=PROCESS_PRIORITY_LOW,
@ -179,11 +184,23 @@ class ClassificationTrainingProcess(FrigateProcess):
) )
return False return False
weights_path = "imagenet"
# Download MobileNetV2 weights if not present
if self.BASE_WEIGHT_URL:
weights_path = os.path.join(
MODEL_CACHE_DIR, "MobileNet", "mobilenet_v2_weights.h5"
)
if not os.path.exists(weights_path):
logger.info("Downloading MobileNet V2 weights file")
ModelDownloader.download_from_url(
self.BASE_WEIGHT_URL, weights_path
)
# Start with imagenet base model with 35% of channels in each layer # Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2( base_model = MobileNetV2(
input_shape=(224, 224, 3), input_shape=(224, 224, 3),
include_top=False, include_top=False,
weights="imagenet", weights=weights_path,
alpha=0.35, alpha=0.35,
) )
base_model.trainable = False # Freeze pre-trained layers base_model.trainable = False # Freeze pre-trained layers