diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 55a204616..4ee5e1d54 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -15,20 +15,23 @@ LEARNING_RATE = 0.001 @staticmethod -def generate_representative_dataset(dataset_dir: str): - image_paths = [] - for root, dirs, files in os.walk(dataset_dir): - for file in files: - if file.lower().endswith((".jpg", ".jpeg", ".png")): - image_paths.append(os.path.join(root, file)) +def generate_representative_dataset_factory(dataset_dir: str): + def generate_representative_dataset(): + image_paths = [] + for root, dirs, files in os.walk(dataset_dir): + for file in files: + if file.lower().endswith((".jpg", ".jpeg", ".png")): + image_paths.append(os.path.join(root, file)) - for path in image_paths[:300]: - img = cv2.imread(path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = cv2.resize(img, (224, 224)) - img_array = np.array(img, dtype=np.float32) / 255.0 - img_array = img_array[None, ...] - yield [img_array] + for path in image_paths[:300]: + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (224, 224)) + img_array = np.array(img, dtype=np.float32) / 255.0 + img_array = img_array[None, ...] + yield [img_array] + + return generate_representative_dataset @staticmethod @@ -92,7 +95,9 @@ def train_classification_model(model_dir: str) -> bool: # convert model to tflite converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = generate_representative_dataset(dataset_dir) + converter.representative_dataset = generate_representative_dataset_factory( + dataset_dir + ) converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.inference_input_type = tf.uint8 converter.inference_output_type = tf.uint8