Improve stationary classification (#20303)

* Improve stationary classification

* Cleanup for mypy
This commit is contained in:
Nicolas Mowen 2025-10-01 06:39:11 -06:00 committed by GitHub
parent 28e3aa39f0
commit 8f0be18422
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 75 additions and 20 deletions

View File

@ -17,7 +17,11 @@ from frigate.camera import PTZMetrics
from frigate.config import CameraConfig
from frigate.ptz.autotrack import PtzMotionEstimator
from frigate.track import ObjectTracker
from frigate.track.stationary_classifier import StationaryMotionClassifier
from frigate.track.stationary_classifier import (
StationaryMotionClassifier,
StationaryThresholds,
get_stationary_threshold,
)
from frigate.util.image import (
SharedMemoryFrameManager,
get_histogram,
@ -28,12 +32,6 @@ from frigate.util.object import average_boxes, median_of_boxes
logger = logging.getLogger(__name__)
THRESHOLD_KNOWN_ACTIVE_IOU = 0.2
THRESHOLD_STATIONARY_CHECK_IOU = 0.6
THRESHOLD_ACTIVE_CHECK_IOU = 0.9
MAX_STATIONARY_HISTORY = 10
# Normalizes distance from estimate relative to object size
# Other ideas:
# - if estimates are inaccurate for first N detections, compare with last_detection (may be fine)
@ -328,6 +326,7 @@ class NorfairTracker(ObjectTracker):
id: str,
box: list[int],
stationary: bool,
thresholds: StationaryThresholds,
yuv_frame: np.ndarray | None,
) -> bool:
def reset_position(xmin: int, ymin: int, xmax: int, ymax: int) -> None:
@ -346,9 +345,9 @@ class NorfairTracker(ObjectTracker):
position = self.positions[id]
self.stationary_box_history[id].append(box)
if len(self.stationary_box_history[id]) > MAX_STATIONARY_HISTORY:
if len(self.stationary_box_history[id]) > thresholds.max_stationary_history:
self.stationary_box_history[id] = self.stationary_box_history[id][
-MAX_STATIONARY_HISTORY:
-thresholds.max_stationary_history :
]
avg_box = average_boxes(self.stationary_box_history[id])
@ -367,7 +366,7 @@ class NorfairTracker(ObjectTracker):
# object has minimal or zero iou
# assume object is active
if avg_iou < THRESHOLD_KNOWN_ACTIVE_IOU:
if avg_iou < thresholds.known_active_iou:
if stationary and yuv_frame is not None:
if not self.stationary_classifier.evaluate(
id, yuv_frame, cast(tuple[int, int, int, int], tuple(box))
@ -379,7 +378,9 @@ class NorfairTracker(ObjectTracker):
return False
threshold = (
THRESHOLD_STATIONARY_CHECK_IOU if stationary else THRESHOLD_ACTIVE_CHECK_IOU
thresholds.stationary_check_iou
if stationary
else thresholds.active_check_iou
)
# object has iou below threshold, check median and optionally crop similarity
@ -447,6 +448,7 @@ class NorfairTracker(ObjectTracker):
self,
track_id: str,
obj: dict[str, Any],
thresholds: StationaryThresholds,
yuv_frame: np.ndarray | None,
) -> None:
id = self.track_id_map[track_id]
@ -456,7 +458,7 @@ class NorfairTracker(ObjectTracker):
>= self.detect_config.stationary.threshold
)
# update the motionless count if the object has not moved to a new position
if self.update_position(id, obj["box"], stationary, yuv_frame):
if self.update_position(id, obj["box"], stationary, thresholds, yuv_frame):
self.tracked_objects[id]["motionless_count"] += 1
if self.is_expired(id):
self.deregister(id, track_id)
@ -502,9 +504,9 @@ class NorfairTracker(ObjectTracker):
detections_by_type: dict[str, list[Detection]] = {}
yuv_frame: np.ndarray | None = None
if self.ptz_metrics.autotracker_enabled.value or (
self.detect_config.stationary.classifier
and any(obj[0] == "car" for obj in detections)
if (
self.ptz_metrics.autotracker_enabled.value
or self.detect_config.stationary.classifier
):
yuv_frame = self.frame_manager.get(
frame_name, self.camera_config.frame_shape_yuv
@ -614,10 +616,12 @@ class NorfairTracker(ObjectTracker):
self.tracked_objects[id]["estimate"] = new_obj["estimate"]
# else update it
else:
thresholds = get_stationary_threshold(new_obj["label"])
self.update(
str(t.global_id),
new_obj,
yuv_frame if new_obj["label"] == "car" else None,
thresholds,
yuv_frame if thresholds.motion_classifier_enabled else None,
)
# clear expired tracks

View File

@ -1,6 +1,7 @@
"""Tools for determining if an object is stationary."""
import logging
from dataclasses import dataclass
from typing import Any, cast
import cv2
@ -10,10 +11,60 @@ from scipy.ndimage import gaussian_filter
logger = logging.getLogger(__name__)
THRESHOLD_KNOWN_ACTIVE_IOU = 0.2
THRESHOLD_STATIONARY_CHECK_IOU = 0.6
THRESHOLD_ACTIVE_CHECK_IOU = 0.9
MAX_STATIONARY_HISTORY = 10
@dataclass
class StationaryThresholds:
"""IOU thresholds and history parameters for stationary object classification.
This allows different sensitivity settings for different object types.
"""
# Objects to apply these thresholds to
# If None, apply to all objects
objects: list[str] = []
# Threshold of IoU that causes the object to immediately be considered active
# Below this threshold, assume object is active
known_active_iou: float = 0.2
# IOU threshold for checking if stationary object has moved
# If mean and median IOU drops below this, assume object is no longer stationary
stationary_check_iou: float = 0.6
# IOU threshold for checking if active object has changed position
# Higher threshold makes it more difficult for the object to be considered stationary
active_check_iou: float = 0.9
# Maximum number of bounding boxes to keep in stationary history
max_stationary_history: int = 10
# Whether to use the motion classifier
motion_classifier_enabled: bool = False
# Thresholds for objects that are expected to be stationary
STATIONARY_OBJECT_THRESHOLDS = StationaryThresholds(
objects=["bbq_grill", "package", "waste_bin"],
known_active_iou=0.0,
motion_classifier_enabled=True,
)
# Thresholds for objects that are active but can be stationary for longer periods of time
DYNAMIC_OBJECT_THRESHOLDS = StationaryThresholds(
objects=["bicycle", "boat", "car", "motorcycle", "tractor", "truck"],
motion_classifier_enabled=True,
)
def get_stationary_threshold(label: str) -> StationaryThresholds:
"""Get the stationary thresholds for a given object label."""
if label in STATIONARY_OBJECT_THRESHOLDS.objects:
return STATIONARY_OBJECT_THRESHOLDS
if label in DYNAMIC_OBJECT_THRESHOLDS.objects:
return DYNAMIC_OBJECT_THRESHOLDS
return StationaryThresholds()
class StationaryMotionClassifier: