mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-11 17:47:37 +03:00
feat: add train classification download weights file endpoint: "TF_KERAS_MOBILENET_V2_ENDPOINT"
This commit is contained in:
parent
e1545a8db8
commit
6aa1a10965
@ -22,6 +22,7 @@ from frigate.const import (
|
|||||||
from frigate.log import redirect_output_to_logger
|
from frigate.log import redirect_output_to_logger
|
||||||
from frigate.models import Event, Recordings, ReviewSegment
|
from frigate.models import Event, Recordings, ReviewSegment
|
||||||
from frigate.types import ModelStatusTypesEnum
|
from frigate.types import ModelStatusTypesEnum
|
||||||
|
from frigate.util.downloader import ModelDownloader
|
||||||
from frigate.util.file import get_event_thumbnail_bytes
|
from frigate.util.file import get_event_thumbnail_bytes
|
||||||
from frigate.util.image import get_image_from_recording
|
from frigate.util.image import get_image_from_recording
|
||||||
from frigate.util.process import FrigateProcess
|
from frigate.util.process import FrigateProcess
|
||||||
@ -121,6 +122,10 @@ def get_dataset_image_count(model_name: str) -> int:
|
|||||||
|
|
||||||
class ClassificationTrainingProcess(FrigateProcess):
|
class ClassificationTrainingProcess(FrigateProcess):
|
||||||
def __init__(self, model_name: str) -> None:
|
def __init__(self, model_name: str) -> None:
|
||||||
|
self.BASE_WEIGHT_PATH = os.environ.get(
|
||||||
|
"TF_KERAS_MOBILENET_V2_ENDPOINT",
|
||||||
|
"https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/",
|
||||||
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
stop_event=None,
|
stop_event=None,
|
||||||
priority=PROCESS_PRIORITY_LOW,
|
priority=PROCESS_PRIORITY_LOW,
|
||||||
@ -179,12 +184,23 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
alpha = 0.35
|
||||||
|
# Download MobileNetV2 weights if not present
|
||||||
|
weights_filename = (
|
||||||
|
f"mobilenet_v2_weights_tf_dim_ordering_tf_kernels_{alpha}_224_no_top.h5"
|
||||||
|
)
|
||||||
|
weights_path = os.path.join(MODEL_CACHE_DIR, "MobileNet", weights_filename)
|
||||||
|
if not os.path.exists(weights_path):
|
||||||
|
weights_url = self.BASE_WEIGHT_PATH + weights_filename
|
||||||
|
logger.info(f"Downloading MobileNet V2 weights file: {weights_url}")
|
||||||
|
ModelDownloader.download_from_url(weights_url, weights_path)
|
||||||
|
|
||||||
# Start with imagenet base model with 35% of channels in each layer
|
# Start with imagenet base model with 35% of channels in each layer
|
||||||
base_model = MobileNetV2(
|
base_model = MobileNetV2(
|
||||||
input_shape=(224, 224, 3),
|
input_shape=(224, 224, 3),
|
||||||
include_top=False,
|
include_top=False,
|
||||||
weights="imagenet",
|
weights=weights_path,
|
||||||
alpha=0.35,
|
alpha=alpha,
|
||||||
)
|
)
|
||||||
base_model.trainable = False # Freeze pre-trained layers
|
base_model.trainable = False # Freeze pre-trained layers
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user