diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index f94c2b28c..0e254ab0d 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -42,7 +42,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.model_config = model_config self.requestor = requestor self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) - self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name) + self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") self.interpreter: Interpreter = None self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 4ee5e1d54..6b2db3446 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -9,6 +9,8 @@ from tensorflow.keras import layers, models, optimizers from tensorflow.keras.applications import MobileNetV2 from tensorflow.keras.preprocessing.image import ImageDataGenerator +from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR + BATCH_SIZE = 16 EPOCHS = 50 LEARNING_RATE = 0.001 @@ -35,9 +37,10 @@ def generate_representative_dataset_factory(dataset_dir: str): @staticmethod -def train_classification_model(model_dir: str) -> bool: +def train_classification_model(model_name: str) -> bool: """Train a classification model.""" - dataset_dir = os.path.join(model_dir, "dataset") + dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") + model_dir = os.path.join(MODEL_CACHE_DIR, model_name) num_classes = len( [ d