Fix typing and imports of centroid tracker

This commit is contained in:
Nicolas Mowen 2025-08-15 14:07:45 -06:00
parent 4872bb6bc9
commit e4a44ff8ff

View File

@ -1,25 +1,26 @@
import random import random
import string import string
from collections import defaultdict from collections import defaultdict
from typing import Any
import numpy as np import numpy as np
from scipy.spatial import distance as dist from scipy.spatial import distance as dist
from frigate.config import DetectConfig from frigate.config import DetectConfig
from frigate.track import ObjectTracker from frigate.track import ObjectTracker
from frigate.util import intersection_over_union from frigate.util.image import intersection_over_union
class CentroidTracker(ObjectTracker): class CentroidTracker(ObjectTracker):
def __init__(self, config: DetectConfig): def __init__(self, config: DetectConfig):
self.tracked_objects = {} self.tracked_objects: dict[str, Any] = {}
self.untracked_object_boxes = [] self.untracked_object_boxes: list[tuple[int, int, int, int]] = []
self.disappeared = {} self.disappeared: dict[str, Any] = {}
self.positions = {} self.positions: dict[str, Any] = {}
self.max_disappeared = config.max_disappeared self.max_disappeared = config.max_disappeared
self.detect_config = config self.detect_config = config
def register(self, index, obj): def register(self, obj: dict[str, Any]) -> None:
rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
id = f"{obj['frame_time']}-{rand_id}" id = f"{obj['frame_time']}-{rand_id}"
obj["id"] = id obj["id"] = id
@ -39,13 +40,13 @@ class CentroidTracker(ObjectTracker):
"ymax": self.detect_config.height, "ymax": self.detect_config.height,
} }
def deregister(self, id): def deregister(self, id: str) -> None:
del self.tracked_objects[id] del self.tracked_objects[id]
del self.disappeared[id] del self.disappeared[id]
# tracks the current position of the object based on the last N bounding boxes # tracks the current position of the object based on the last N bounding boxes
# returns False if the object has moved outside its previous position # returns False if the object has moved outside its previous position
def update_position(self, id, box): def update_position(self, id: str, box: tuple[int, int, int, int]) -> bool:
position = self.positions[id] position = self.positions[id]
position_box = ( position_box = (
position["xmin"], position["xmin"],
@ -88,7 +89,7 @@ class CentroidTracker(ObjectTracker):
return True return True
def is_expired(self, id): def is_expired(self, id: str) -> bool:
obj = self.tracked_objects[id] obj = self.tracked_objects[id]
# get the max frames for this label type or the default # get the max frames for this label type or the default
max_frames = self.detect_config.stationary.max_frames.objects.get( max_frames = self.detect_config.stationary.max_frames.objects.get(
@ -108,7 +109,7 @@ class CentroidTracker(ObjectTracker):
return False return False
def update(self, id, new_obj): def update(self, id: str, new_obj: dict[str, Any]) -> None:
self.disappeared[id] = 0 self.disappeared[id] = 0
# update the motionless count if the object has not moved to a new position # update the motionless count if the object has not moved to a new position
if self.update_position(id, new_obj["box"]): if self.update_position(id, new_obj["box"]):
@ -129,14 +130,16 @@ class CentroidTracker(ObjectTracker):
self.tracked_objects[id].update(new_obj) self.tracked_objects[id].update(new_obj)
def update_frame_times(self, frame_name, frame_time): def update_frame_times(self, frame_name: str, frame_time: float) -> None:
for id in list(self.tracked_objects.keys()): for id in list(self.tracked_objects.keys()):
self.tracked_objects[id]["frame_time"] = frame_time self.tracked_objects[id]["frame_time"] = frame_time
self.tracked_objects[id]["motionless_count"] += 1 self.tracked_objects[id]["motionless_count"] += 1
if self.is_expired(id): if self.is_expired(id):
self.deregister(id) self.deregister(id)
def match_and_update(self, frame_time, detections): def match_and_update(
self, frame_name: str, frame_time: float, detections: list[dict[Any, Any]]
) -> None:
# group by name # group by name
detection_groups = defaultdict(lambda: []) detection_groups = defaultdict(lambda: [])
for obj in detections: for obj in detections:
@ -180,7 +183,7 @@ class CentroidTracker(ObjectTracker):
if len(current_objects) == 0: if len(current_objects) == 0:
for index, obj in enumerate(group): for index, obj in enumerate(group):
self.register(index, obj) self.register(obj)
continue continue
new_centroids = np.array([o["centroid"] for o in group]) new_centroids = np.array([o["centroid"] for o in group])
@ -238,4 +241,4 @@ class CentroidTracker(ObjectTracker):
# register each new input centroid as a trackable object # register each new input centroid as a trackable object
else: else:
for col in unusedCols: for col in unusedCols:
self.register(col, group[col]) self.register(group[col])