Create utility to train mobilenet classification models

This commit is contained in:
Nicolas Mowen 2025-05-29 10:39:23 -06:00
parent 2bd6fa53fe
commit 34a85b238b
2 changed files with 100 additions and 0 deletions

View File

@ -11,6 +11,8 @@ joserfc == 1.0.*
pathvalidate == 3.2.*
markupsafe == 3.0.*
python-multipart == 0.0.12
# Classification Model
tensorflow == 2.19.*
# General
mypy == 1.6.1
onvif-zeep-async == 3.1.*

View File

@ -0,0 +1,98 @@
"""Util for classification models."""
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 0.001
@staticmethod
def generate_representative_dataset(train_dir: str):
image_paths = []
for root, dirs, files in os.walk("train"):
for file in files:
if file.lower().endswith((".jpg", ".jpeg", ".png")):
image_paths.append(os.path.join(root, file))
for path in image_paths[:300]:
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))
img_array = np.array(img, dtype=np.float32) / 255.0
img_array = img_array[None, ...]
yield [img_array]
@staticmethod
def train_classification_model(train_dir: str) -> bool:
"""Train a classification model."""
num_classes = len(
[d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))]
)
# Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2(
input_shape=(224, 224, 3),
include_top=False,
weights="imagenet",
alpha=0.35,
)
base_model.trainable = False # Freeze pre-trained layers
model = models.Sequential(
[
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(128, activation="relu"),
layers.Dropout(0.3),
layers.Dense(num_classes, activation="softmax"),
]
)
model.compile(
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
loss="categorical_crossentropy",
metrics=["accuracy"],
)
# create training set
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
train_gen = datagen.flow_from_directory(
"train",
target_size=(224, 224),
batch_size=BATCH_SIZE,
class_mode="categorical",
subset="training",
)
# write labelmap
class_indices = train_gen.class_indices
index_to_class = {v: k for k, v in class_indices.items()}
sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
with open(os.path.join(train_dir, "labelmap.txt"), "w") as f:
for class_name in sorted_classes:
f.write(f"{class_name}\n")
# train the model
model.fit(train_gen, epochs=EPOCHS)
# convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = generate_representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
# write model
with open(os.path.join(train_dir, "model.tflite"), "wb") as f:
f.write(tflite_model)