Improve image cropping and model saving

This commit is contained in:
Nicolas Mowen 2025-06-26 18:07:50 -06:00
parent 1ddfbe47c5
commit 3c33178414

View File

@ -187,7 +187,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
super().__init__(config, metrics) super().__init__(config, metrics)
self.model_config = model_config self.model_config = model_config
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(self.model_dir, "train") self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None self.interpreter: Interpreter = None
self.sub_label_publisher = sub_label_publisher self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None self.tensor_input_details: dict[str, Any] = None
@ -232,20 +232,23 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
obj_data["box"][1], obj_data["box"][1],
obj_data["box"][2], obj_data["box"][2],
obj_data["box"][3], obj_data["box"][3],
224, max(
obj_data["box"][1] - obj_data["box"][0],
obj_data["box"][3] - obj_data["box"][2],
),
1.0, 1.0,
) )
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
input = rgb[ crop = rgb[
y:y2, y:y2,
x:x2, x:x2,
] ]
if input.shape != (224, 224): if crop.shape != (224, 224):
input = cv2.resize(input, (224, 224)) crop = cv2.resize(crop, (224, 224))
input = np.expand_dims(input, axis=0) input = np.expand_dims(crop, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke() self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor( res: np.ndarray = self.interpreter.get_tensor(
@ -259,7 +262,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
now, now,
self.labelmap[best_id], self.labelmap[best_id],
score, score,