Fine tune behavior

This commit is contained in:
Nicolas Mowen 2025-11-28 14:50:26 -07:00
parent ecb59ff943
commit fb4fe8c430

View File

@ -99,10 +99,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
if self.inference_speed: if self.inference_speed:
self.inference_speed.update(duration) self.inference_speed.update(duration)
def _should_save_image(self, camera: str, detected_state: str) -> bool: def _should_save_image(
self, camera: str, detected_state: str, score: float = 1.0
) -> bool:
""" """
Determine if we should save the image for training. Determine if we should save the image for training.
Only save when the state is changing or being verified, not when it's stable. Save when:
- State is changing or being verified (regardless of score)
- Score is less than 100% (even if state matches, useful for training)
Don't save when:
- State is stable (matches current_state) AND score is 100%
""" """
if camera not in self.state_history: if camera not in self.state_history:
# First detection for this camera, save it # First detection for this camera, save it
@ -121,7 +127,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
if current_state is not None and detected_state != current_state: if current_state is not None and detected_state != current_state:
return True return True
# Don't save if state is stable (detected_state == current_state) # If score is less than 100%, save even if state matches
# (useful for training to improve confidence)
if score < 1.0:
return True
# Don't save if state is stable (detected_state == current_state) AND score is 100%
return False return False
def verify_state_change(self, camera: str, detected_state: str) -> str | None: def verify_state_change(self, camera: str, detected_state: str) -> str | None:
@ -237,7 +248,8 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
return return
if self.interpreter is None: if self.interpreter is None:
if self._should_save_image(camera, "unknown"): # When interpreter is None, always save (score is 0.0, which is < 1.0)
if self._should_save_image(camera, "unknown", 0.0):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
@ -264,7 +276,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
detected_state = self.labelmap[best_id] detected_state = self.labelmap[best_id]
if self._should_save_image(camera, detected_state): if self._should_save_image(camera, detected_state, score):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),