add randomness to object classification

also ensure train_dir is fresh if user has regenerated examples
This commit is contained in:
Josh Hawkins 2026-03-24 07:38:04 -05:00
parent 573a5ede62
commit 57a9233977

View File

@ -5,6 +5,7 @@ import json
import logging
import os
import random
import shutil
from collections import defaultdict
import cv2
@ -397,6 +398,8 @@ def collect_state_classification_examples(
# Step 5: Save to train directory for later classification
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
if os.path.exists(train_dir):
shutil.rmtree(train_dir)
os.makedirs(train_dir, exist_ok=True)
saved_count = 0
@ -411,8 +414,6 @@ def collect_state_classification_examples(
except Exception as e:
logger.error(f"Failed to save image {image_path}: {e}")
import shutil
try:
shutil.rmtree(temp_dir)
except Exception as e:
@ -750,6 +751,8 @@ def collect_object_classification_examples(
# Step 5: Save to train directory for later classification
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
if os.path.exists(train_dir):
shutil.rmtree(train_dir)
os.makedirs(train_dir, exist_ok=True)
saved_count = 0
@ -764,8 +767,6 @@ def collect_object_classification_examples(
except Exception as e:
logger.error(f"Failed to save image {image_path}: {e}")
import shutil
try:
shutil.rmtree(temp_dir)
except Exception as e:
@ -806,24 +807,25 @@ def _select_balanced_events(
selected = []
for group_events in grouped.values():
# Take top events by score, then randomly sample from them
sorted_events = sorted(
group_events,
key=lambda e: e.data.get("score", 0) if e.data else 0,
reverse=True,
)
sample_size = min(samples_per_group, len(sorted_events))
selected.extend(sorted_events[:sample_size])
# Consider top 3x candidates to allow randomness while preferring higher scores
candidate_pool = sorted_events[: samples_per_group * 3]
sample_size = min(samples_per_group, len(candidate_pool))
selected.extend(random.sample(candidate_pool, sample_size))
if len(selected) < target_count:
remaining = [e for e in events if e not in selected]
remaining_sorted = sorted(
remaining,
key=lambda e: e.data.get("score", 0) if e.data else 0,
reverse=True,
)
needed = target_count - len(selected)
selected.extend(remaining_sorted[:needed])
if len(remaining) > needed:
selected.extend(random.sample(remaining, needed))
else:
selected.extend(remaining)
return selected[:target_count]