Fix factory

This commit is contained in:
Nicolas Mowen 2025-05-29 14:59:57 -06:00
parent e3af887927
commit 6e2edac90d

View File

@ -15,20 +15,23 @@ LEARNING_RATE = 0.001
@staticmethod @staticmethod
def generate_representative_dataset(dataset_dir: str): def generate_representative_dataset_factory(dataset_dir: str):
image_paths = [] def generate_representative_dataset():
for root, dirs, files in os.walk(dataset_dir): image_paths = []
for file in files: for root, dirs, files in os.walk(dataset_dir):
if file.lower().endswith((".jpg", ".jpeg", ".png")): for file in files:
image_paths.append(os.path.join(root, file)) if file.lower().endswith((".jpg", ".jpeg", ".png")):
image_paths.append(os.path.join(root, file))
for path in image_paths[:300]: for path in image_paths[:300]:
img = cv2.imread(path) img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224)) img = cv2.resize(img, (224, 224))
img_array = np.array(img, dtype=np.float32) / 255.0 img_array = np.array(img, dtype=np.float32) / 255.0
img_array = img_array[None, ...] img_array = img_array[None, ...]
yield [img_array] yield [img_array]
return generate_representative_dataset
@staticmethod @staticmethod
@ -92,7 +95,9 @@ def train_classification_model(model_dir: str) -> bool:
# convert model to tflite # convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model) converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = generate_representative_dataset(dataset_dir) converter.representative_dataset = generate_representative_dataset_factory(
dataset_dir
)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8 converter.inference_output_type = tf.uint8