diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 11376563d..2b81678d9 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -15,9 +15,9 @@ LEARNING_RATE = 0.001 @staticmethod -def generate_representative_dataset(train_dir: str): +def generate_representative_dataset(dataset_dir: str): image_paths = [] - for root, dirs, files in os.walk("train"): + for root, dirs, files in os.walk(dataset_dir): for file in files: if file.lower().endswith((".jpg", ".jpeg", ".png")): image_paths.append(os.path.join(root, file)) @@ -32,10 +32,15 @@ def generate_representative_dataset(train_dir: str): @staticmethod -def train_classification_model(train_dir: str) -> bool: +def train_classification_model(model_dir: str) -> bool: """Train a classification model.""" + dataset_dir = os.path.join(model_dir, "dataset") num_classes = len( - [d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))] + [ + d + for d in os.listdir(dataset_dir) + if os.path.isdir(os.path.join(dataset_dir, d)) + ] ) # Start with imagenet base model with 35% of channels in each layer @@ -77,7 +82,7 @@ def train_classification_model(train_dir: str) -> bool: class_indices = train_gen.class_indices index_to_class = {v: k for k, v in class_indices.items()} sorted_classes = [index_to_class[i] for i in range(len(index_to_class))] - with open(os.path.join(train_dir, "labelmap.txt"), "w") as f: + with open(os.path.join(model_dir, "labelmap.txt"), "w") as f: for class_name in sorted_classes: f.write(f"{class_name}\n") @@ -87,12 +92,12 @@ def train_classification_model(train_dir: str) -> bool: # convert model to tflite converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = generate_representative_dataset + converter.representative_dataset = generate_representative_dataset(dataset_dir) converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8 tflite_model = converter.convert() # write model - with open(os.path.join(train_dir, "model.tflite"), "wb") as f: + with open(os.path.join(model_dir, "model.tflite"), "wb") as f: f.write(tflite_model)