Fine tune behavior

This commit is contained in:
Nicolas Mowen 2025-11-28 14:50:26 -07:00
parent ecb59ff943
commit fb4fe8c430

View File

@ -99,10 +99,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
if self.inference_speed:
self.inference_speed.update(duration)
def _should_save_image(self, camera: str, detected_state: str) -> bool:
def _should_save_image(
self, camera: str, detected_state: str, score: float = 1.0
) -> bool:
"""
Determine if we should save the image for training.
Only save when the state is changing or being verified, not when it's stable.
Save when:
- State is changing or being verified (regardless of score)
- Score is less than 100% (even if state matches, useful for training)
Don't save when:
- State is stable (matches current_state) AND score is 100%
"""
if camera not in self.state_history:
# First detection for this camera, save it
@ -121,7 +127,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
if current_state is not None and detected_state != current_state:
return True
# Don't save if state is stable (detected_state == current_state)
# If score is less than 100%, save even if state matches
# (useful for training to improve confidence)
if score < 1.0:
return True
# Don't save if state is stable (detected_state == current_state) AND score is 100%
return False
def verify_state_change(self, camera: str, detected_state: str) -> str | None:
@ -237,7 +248,8 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
return
if self.interpreter is None:
if self._should_save_image(camera, "unknown"):
# When interpreter is None, always save (score is 0.0, which is < 1.0)
if self._should_save_image(camera, "unknown", 0.0):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
@ -264,7 +276,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
detected_state = self.labelmap[best_id]
if self._should_save_image(camera, detected_state):
if self._should_save_image(camera, detected_state, score):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),