Add randomness to object classification example selection

Also ensure train_dir is fresh (removed and recreated) if the user has regenerated examples
This commit is contained in:
Josh Hawkins 2026-03-24 07:38:04 -05:00
parent 573a5ede62
commit 57a9233977

View File

@ -5,6 +5,7 @@ import json
import logging import logging
import os import os
import random import random
import shutil
from collections import defaultdict from collections import defaultdict
import cv2 import cv2
@ -397,6 +398,8 @@ def collect_state_classification_examples(
# Step 5: Save to train directory for later classification # Step 5: Save to train directory for later classification
train_dir = os.path.join(CLIPS_DIR, model_name, "train") train_dir = os.path.join(CLIPS_DIR, model_name, "train")
if os.path.exists(train_dir):
shutil.rmtree(train_dir)
os.makedirs(train_dir, exist_ok=True) os.makedirs(train_dir, exist_ok=True)
saved_count = 0 saved_count = 0
@ -411,8 +414,6 @@ def collect_state_classification_examples(
except Exception as e: except Exception as e:
logger.error(f"Failed to save image {image_path}: {e}") logger.error(f"Failed to save image {image_path}: {e}")
import shutil
try: try:
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
except Exception as e: except Exception as e:
@ -750,6 +751,8 @@ def collect_object_classification_examples(
# Step 5: Save to train directory for later classification # Step 5: Save to train directory for later classification
train_dir = os.path.join(CLIPS_DIR, model_name, "train") train_dir = os.path.join(CLIPS_DIR, model_name, "train")
if os.path.exists(train_dir):
shutil.rmtree(train_dir)
os.makedirs(train_dir, exist_ok=True) os.makedirs(train_dir, exist_ok=True)
saved_count = 0 saved_count = 0
@ -764,8 +767,6 @@ def collect_object_classification_examples(
except Exception as e: except Exception as e:
logger.error(f"Failed to save image {image_path}: {e}") logger.error(f"Failed to save image {image_path}: {e}")
import shutil
try: try:
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
except Exception as e: except Exception as e:
@ -806,24 +807,25 @@ def _select_balanced_events(
selected = [] selected = []
for group_events in grouped.values(): for group_events in grouped.values():
# Take top events by score, then randomly sample from them
sorted_events = sorted( sorted_events = sorted(
group_events, group_events,
key=lambda e: e.data.get("score", 0) if e.data else 0, key=lambda e: e.data.get("score", 0) if e.data else 0,
reverse=True, reverse=True,
) )
sample_size = min(samples_per_group, len(sorted_events)) # Consider top 3x candidates to allow randomness while preferring higher scores
selected.extend(sorted_events[:sample_size]) candidate_pool = sorted_events[: samples_per_group * 3]
sample_size = min(samples_per_group, len(candidate_pool))
selected.extend(random.sample(candidate_pool, sample_size))
if len(selected) < target_count: if len(selected) < target_count:
remaining = [e for e in events if e not in selected] remaining = [e for e in events if e not in selected]
remaining_sorted = sorted(
remaining,
key=lambda e: e.data.get("score", 0) if e.data else 0,
reverse=True,
)
needed = target_count - len(selected) needed = target_count - len(selected)
selected.extend(remaining_sorted[:needed]) if len(remaining) > needed:
selected.extend(random.sample(remaining, needed))
else:
selected.extend(remaining)
return selected[:target_count] return selected[:target_count]