diff --git a/frigate/track/centroid_tracker.py b/frigate/track/centroid_tracker.py index 25d4cb860..9b28cc236 100644 --- a/frigate/track/centroid_tracker.py +++ b/frigate/track/centroid_tracker.py @@ -1,25 +1,26 @@ import random import string from collections import defaultdict +from typing import Any import numpy as np from scipy.spatial import distance as dist from frigate.config import DetectConfig from frigate.track import ObjectTracker -from frigate.util import intersection_over_union +from frigate.util.image import intersection_over_union class CentroidTracker(ObjectTracker): def __init__(self, config: DetectConfig): - self.tracked_objects = {} - self.untracked_object_boxes = [] - self.disappeared = {} - self.positions = {} + self.tracked_objects: dict[str, Any] = {} + self.untracked_object_boxes: list[tuple[int, int, int, int]] = [] + self.disappeared: dict[str, Any] = {} + self.positions: dict[str, Any] = {} self.max_disappeared = config.max_disappeared self.detect_config = config - def register(self, index, obj): + def register(self, obj: dict[str, Any]) -> None: rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) id = f"{obj['frame_time']}-{rand_id}" obj["id"] = id @@ -39,13 +40,13 @@ class CentroidTracker(ObjectTracker): "ymax": self.detect_config.height, } - def deregister(self, id): + def deregister(self, id: str) -> None: del self.tracked_objects[id] del self.disappeared[id] # tracks the current position of the object based on the last N bounding boxes # returns False if the object has moved outside its previous position - def update_position(self, id, box): + def update_position(self, id: str, box: tuple[int, int, int, int]) -> bool: position = self.positions[id] position_box = ( position["xmin"], @@ -88,7 +89,7 @@ class CentroidTracker(ObjectTracker): return True - def is_expired(self, id): + def is_expired(self, id: str) -> bool: obj = self.tracked_objects[id] # get the max frames for this label type or the default max_frames = self.detect_config.stationary.max_frames.objects.get( @@ -108,7 +109,7 @@ class CentroidTracker(ObjectTracker): return False - def update(self, id, new_obj): + def update(self, id: str, new_obj: dict[str, Any]) -> None: self.disappeared[id] = 0 # update the motionless count if the object has not moved to a new position if self.update_position(id, new_obj["box"]): @@ -129,14 +130,16 @@ class CentroidTracker(ObjectTracker): self.tracked_objects[id].update(new_obj) - def update_frame_times(self, frame_name, frame_time): + def update_frame_times(self, frame_name: str, frame_time: float) -> None: for id in list(self.tracked_objects.keys()): self.tracked_objects[id]["frame_time"] = frame_time self.tracked_objects[id]["motionless_count"] += 1 if self.is_expired(id): self.deregister(id) - def match_and_update(self, frame_time, detections): + def match_and_update( + self, frame_name: str, frame_time: float, detections: list[dict[Any, Any]] + ) -> None: # group by name detection_groups = defaultdict(lambda: []) for obj in detections: @@ -180,7 +183,7 @@ class CentroidTracker(ObjectTracker): if len(current_objects) == 0: for index, obj in enumerate(group): - self.register(index, obj) + self.register(obj) continue new_centroids = np.array([o["centroid"] for o in group]) @@ -238,4 +241,4 @@ class CentroidTracker(ObjectTracker): # register each new input centroid as a trackable object else: for col in unusedCols: - self.register(col, group[col]) + self.register(group[col])