Make number of classification images to be kept configurable

This commit is contained in:
Nicolas Mowen 2025-12-01 12:49:13 -07:00
parent d5f5e93f4f
commit 755f51f1ad
5 changed files with 63 additions and 3 deletions

View File

@ -710,6 +710,38 @@ audio_transcription:
# List of language codes: https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
language: en
# Optional: Configuration for custom classification models
classification:
custom:
# Required: name of the classification model
model_name:
# Optional: Enable running the model (default: shown below)
enabled: True
# Optional: Name of classification model (default: shown below)
name: None
# Optional: Classification score threshold to change the state (default: shown below)
threshold: 0.8
# Optional: Number of classification attempts to save in the recent classifications tab (default: shown below)
# NOTE: Defaults to 200 for object classification and 100 for state classification if not specified
save_attempts: None
# Optional: Object classification configuration
object_config:
# Required: Object types to classify
objects: [dog]
# Optional: Type of classification that is applied (default: shown below)
classification_type: sub_label
# Optional: State classification configuration
state_config:
# Required: Cameras to run classification on
cameras:
camera_name:
# Required: Crop of image frame on this camera to run classification on
crop: [0, 180, 220, 400]
# Optional: If classification should be run when motion is detected in the crop (default: shown below)
motion: False
# Optional: Interval to run classification on in seconds (default: shown below)
interval: None
# Optional: Restream configuration
# Uses https://github.com/AlexxIT/go2rtc (v1.9.10)
# NOTE: The default go2rtc API port (1984) must be used,

View File

@ -105,6 +105,11 @@ class CustomClassificationConfig(FrigateBaseModel):
threshold: float = Field(
    default=0.8, title="Classification score threshold to change the state."
)
save_attempts: int | None = Field(
default=None,
title="Number of classification attempts to save in the recent classifications tab. If not specified, defaults to 200 for object classification and 100 for state classification.",
ge=0,
)
object_config: CustomClassificationObjectConfig | None = Field(default=None)
state_config: CustomClassificationStateConfig | None = Field(default=None)

View File

@ -250,6 +250,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
if self.interpreter is None:
    # When interpreter is None, always save (score is 0.0, which is < 1.0)
    if self._should_save_image(camera, "unknown", 0.0):
save_attempts = (
self.model_config.save_attempts
if self.model_config.save_attempts is not None
else 100
)
write_classification_attempt(
    self.train_dir,
    cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
@ -257,6 +262,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
    now,
    "unknown",
    0.0,
    max_files=save_attempts,
)
return
@ -277,6 +283,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
detected_state = self.labelmap[best_id]
if self._should_save_image(camera, detected_state, score):
save_attempts = (
self.model_config.save_attempts
if self.model_config.save_attempts is not None
else 100
)
write_classification_attempt(
    self.train_dir,
    cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
@ -284,6 +295,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
    now,
    detected_state,
    score,
    max_files=save_attempts,
)
if score < self.model_config.threshold:
@ -482,6 +494,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
return
if self.interpreter is None:
save_attempts = (
self.model_config.save_attempts
if self.model_config.save_attempts is not None
else 200
)
write_classification_attempt(
    self.train_dir,
    cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
@ -489,6 +506,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
    now,
    "unknown",
    0.0,
    max_files=save_attempts,
)
return
@ -506,6 +524,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now)
save_attempts = (
self.model_config.save_attempts
if self.model_config.save_attempts is not None
else 200
)
write_classification_attempt(
    self.train_dir,
    cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
@ -513,7 +536,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
    now,
    self.labelmap[best_id],
    score,
-   max_files=200,
+   max_files=save_attempts,
)
if score < self.model_config.threshold:

View File

@ -330,7 +330,7 @@ def collect_state_classification_examples(
1. Queries review items from specified cameras
2. Selects 100 balanced timestamps across the data
3. Extracts keyframes from recordings (cropped to specified regions)
- 4. Selects 20 most visually distinct images
+ 4. Selects 24 most visually distinct images
5. Saves them to the dataset directory
Args:
@ -660,7 +660,6 @@ def collect_object_classification_examples(
Args:
    model_name: Name of the classification model
    label: Object label to collect (e.g., "person", "car")
cameras: List of camera names to collect examples from
""" """
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
temp_dir = os.path.join(dataset_dir, "temp") temp_dir = os.path.join(dataset_dir, "temp")

View File

@ -305,6 +305,7 @@ export type CustomClassificationModelConfig = {
enabled: boolean;
name: string;
threshold: number;
save_attempts?: number;
object_config?: {
  objects: string[];
  classification_type: string;