mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 14:47:40 +03:00
add randomness to object classification
also ensure train_dir is fresh if user has regenerated examples
This commit is contained in:
parent
573a5ede62
commit
57a9233977
@ -5,6 +5,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
@ -397,6 +398,8 @@ def collect_state_classification_examples(
|
||||
|
||||
# Step 5: Save to train directory for later classification
|
||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||
if os.path.exists(train_dir):
|
||||
shutil.rmtree(train_dir)
|
||||
os.makedirs(train_dir, exist_ok=True)
|
||||
|
||||
saved_count = 0
|
||||
@ -411,8 +414,6 @@ def collect_state_classification_examples(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save image {image_path}: {e}")
|
||||
|
||||
import shutil
|
||||
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
except Exception as e:
|
||||
@ -750,6 +751,8 @@ def collect_object_classification_examples(
|
||||
|
||||
# Step 5: Save to train directory for later classification
|
||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||
if os.path.exists(train_dir):
|
||||
shutil.rmtree(train_dir)
|
||||
os.makedirs(train_dir, exist_ok=True)
|
||||
|
||||
saved_count = 0
|
||||
@ -764,8 +767,6 @@ def collect_object_classification_examples(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save image {image_path}: {e}")
|
||||
|
||||
import shutil
|
||||
|
||||
try:
|
||||
shutil.rmtree(temp_dir)
|
||||
except Exception as e:
|
||||
@ -806,24 +807,25 @@ def _select_balanced_events(
|
||||
selected = []
|
||||
|
||||
for group_events in grouped.values():
|
||||
# Take top events by score, then randomly sample from them
|
||||
sorted_events = sorted(
|
||||
group_events,
|
||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
sample_size = min(samples_per_group, len(sorted_events))
|
||||
selected.extend(sorted_events[:sample_size])
|
||||
# Consider top 3x candidates to allow randomness while preferring higher scores
|
||||
candidate_pool = sorted_events[: samples_per_group * 3]
|
||||
sample_size = min(samples_per_group, len(candidate_pool))
|
||||
selected.extend(random.sample(candidate_pool, sample_size))
|
||||
|
||||
if len(selected) < target_count:
|
||||
remaining = [e for e in events if e not in selected]
|
||||
remaining_sorted = sorted(
|
||||
remaining,
|
||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
||||
reverse=True,
|
||||
)
|
||||
needed = target_count - len(selected)
|
||||
selected.extend(remaining_sorted[:needed])
|
||||
if len(remaining) > needed:
|
||||
selected.extend(random.sample(remaining, needed))
|
||||
else:
|
||||
selected.extend(remaining)
|
||||
|
||||
return selected[:target_count]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user