From 57a9233977700d8b1fe5704000fc2ee82a2659dc Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Tue, 24 Mar 2026 07:38:04 -0500 Subject: [PATCH] add randomness to object classification also ensure train_dir is fresh if user has regenerated examples --- frigate/util/classification.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 643f77d3b..ada3ee1f7 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -5,6 +5,7 @@ import json import logging import os import random +import shutil from collections import defaultdict import cv2 @@ -397,6 +398,8 @@ def collect_state_classification_examples( # Step 5: Save to train directory for later classification train_dir = os.path.join(CLIPS_DIR, model_name, "train") + if os.path.exists(train_dir): + shutil.rmtree(train_dir) os.makedirs(train_dir, exist_ok=True) saved_count = 0 @@ -411,8 +414,6 @@ def collect_state_classification_examples( except Exception as e: logger.error(f"Failed to save image {image_path}: {e}") - import shutil - try: shutil.rmtree(temp_dir) except Exception as e: @@ -750,6 +751,8 @@ def collect_object_classification_examples( # Step 5: Save to train directory for later classification train_dir = os.path.join(CLIPS_DIR, model_name, "train") + if os.path.exists(train_dir): + shutil.rmtree(train_dir) os.makedirs(train_dir, exist_ok=True) saved_count = 0 @@ -764,8 +767,6 @@ def collect_object_classification_examples( except Exception as e: logger.error(f"Failed to save image {image_path}: {e}") - import shutil - try: shutil.rmtree(temp_dir) except Exception as e: @@ -806,24 +807,25 @@ def _select_balanced_events( selected = [] for group_events in grouped.values(): + # Take top events by score, then randomly sample from them sorted_events = sorted( group_events, key=lambda e: e.data.get("score", 0) if e.data else 0, reverse=True, ) - sample_size = min(samples_per_group, len(sorted_events)) - selected.extend(sorted_events[:sample_size]) + # Consider top 3x candidates to allow randomness while preferring higher scores + candidate_pool = sorted_events[: samples_per_group * 3] + sample_size = min(samples_per_group, len(candidate_pool)) + selected.extend(random.sample(candidate_pool, sample_size)) if len(selected) < target_count: remaining = [e for e in events if e not in selected] - remaining_sorted = sorted( - remaining, - key=lambda e: e.data.get("score", 0) if e.data else 0, - reverse=True, - ) needed = target_count - len(selected) - selected.extend(remaining_sorted[:needed]) + if len(remaining) > needed: + selected.extend(random.sample(remaining, needed)) + else: + selected.extend(remaining) return selected[:target_count]