mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 22:57:40 +03:00
add randomness to object classification
also ensure train_dir is fresh if user has regenerated examples
This commit is contained in:
parent
573a5ede62
commit
57a9233977
@ -5,6 +5,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -397,6 +398,8 @@ def collect_state_classification_examples(
|
|||||||
|
|
||||||
# Step 5: Save to train directory for later classification
|
# Step 5: Save to train directory for later classification
|
||||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||||
|
if os.path.exists(train_dir):
|
||||||
|
shutil.rmtree(train_dir)
|
||||||
os.makedirs(train_dir, exist_ok=True)
|
os.makedirs(train_dir, exist_ok=True)
|
||||||
|
|
||||||
saved_count = 0
|
saved_count = 0
|
||||||
@ -411,8 +414,6 @@ def collect_state_classification_examples(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save image {image_path}: {e}")
|
logger.error(f"Failed to save image {image_path}: {e}")
|
||||||
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(temp_dir)
|
shutil.rmtree(temp_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -750,6 +751,8 @@ def collect_object_classification_examples(
|
|||||||
|
|
||||||
# Step 5: Save to train directory for later classification
|
# Step 5: Save to train directory for later classification
|
||||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||||
|
if os.path.exists(train_dir):
|
||||||
|
shutil.rmtree(train_dir)
|
||||||
os.makedirs(train_dir, exist_ok=True)
|
os.makedirs(train_dir, exist_ok=True)
|
||||||
|
|
||||||
saved_count = 0
|
saved_count = 0
|
||||||
@ -764,8 +767,6 @@ def collect_object_classification_examples(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save image {image_path}: {e}")
|
logger.error(f"Failed to save image {image_path}: {e}")
|
||||||
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(temp_dir)
|
shutil.rmtree(temp_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -806,24 +807,25 @@ def _select_balanced_events(
|
|||||||
selected = []
|
selected = []
|
||||||
|
|
||||||
for group_events in grouped.values():
|
for group_events in grouped.values():
|
||||||
|
# Take top events by score, then randomly sample from them
|
||||||
sorted_events = sorted(
|
sorted_events = sorted(
|
||||||
group_events,
|
group_events,
|
||||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_size = min(samples_per_group, len(sorted_events))
|
# Consider top 3x candidates to allow randomness while preferring higher scores
|
||||||
selected.extend(sorted_events[:sample_size])
|
candidate_pool = sorted_events[: samples_per_group * 3]
|
||||||
|
sample_size = min(samples_per_group, len(candidate_pool))
|
||||||
|
selected.extend(random.sample(candidate_pool, sample_size))
|
||||||
|
|
||||||
if len(selected) < target_count:
|
if len(selected) < target_count:
|
||||||
remaining = [e for e in events if e not in selected]
|
remaining = [e for e in events if e not in selected]
|
||||||
remaining_sorted = sorted(
|
|
||||||
remaining,
|
|
||||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
needed = target_count - len(selected)
|
needed = target_count - len(selected)
|
||||||
selected.extend(remaining_sorted[:needed])
|
if len(remaining) > needed:
|
||||||
|
selected.extend(random.sample(remaining, needed))
|
||||||
|
else:
|
||||||
|
selected.extend(remaining)
|
||||||
|
|
||||||
return selected[:target_count]
|
return selected[:target_count]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user