Fix typing and imports of centroid tracker

This commit is contained in:
Nicolas Mowen 2025-08-15 14:07:45 -06:00
parent 4872bb6bc9
commit e4a44ff8ff

View File

@ -1,25 +1,26 @@
import random
import string
from collections import defaultdict
from typing import Any
import numpy as np
from scipy.spatial import distance as dist
from frigate.config import DetectConfig
from frigate.track import ObjectTracker
from frigate.util import intersection_over_union
from frigate.util.image import intersection_over_union
class CentroidTracker(ObjectTracker):
def __init__(self, config: DetectConfig):
self.tracked_objects = {}
self.untracked_object_boxes = []
self.disappeared = {}
self.positions = {}
self.tracked_objects: dict[str, Any] = {}
self.untracked_object_boxes: list[tuple[int, int, int, int]] = []
self.disappeared: dict[str, Any] = {}
self.positions: dict[str, Any] = {}
self.max_disappeared = config.max_disappeared
self.detect_config = config
def register(self, index, obj):
def register(self, obj: dict[str, Any]) -> None:
rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
id = f"{obj['frame_time']}-{rand_id}"
obj["id"] = id
@ -39,13 +40,13 @@ class CentroidTracker(ObjectTracker):
"ymax": self.detect_config.height,
}
def deregister(self, id):
def deregister(self, id: str) -> None:
del self.tracked_objects[id]
del self.disappeared[id]
# tracks the current position of the object based on the last N bounding boxes
# returns False if the object has moved outside its previous position
def update_position(self, id, box):
def update_position(self, id: str, box: tuple[int, int, int, int]) -> bool:
position = self.positions[id]
position_box = (
position["xmin"],
@ -88,7 +89,7 @@ class CentroidTracker(ObjectTracker):
return True
def is_expired(self, id):
def is_expired(self, id: str) -> bool:
obj = self.tracked_objects[id]
# get the max frames for this label type or the default
max_frames = self.detect_config.stationary.max_frames.objects.get(
@ -108,7 +109,7 @@ class CentroidTracker(ObjectTracker):
return False
def update(self, id, new_obj):
def update(self, id: str, new_obj: dict[str, Any]) -> None:
self.disappeared[id] = 0
# update the motionless count if the object has not moved to a new position
if self.update_position(id, new_obj["box"]):
@ -129,14 +130,16 @@ class CentroidTracker(ObjectTracker):
self.tracked_objects[id].update(new_obj)
def update_frame_times(self, frame_name, frame_time):
def update_frame_times(self, frame_name: str, frame_time: float) -> None:
for id in list(self.tracked_objects.keys()):
self.tracked_objects[id]["frame_time"] = frame_time
self.tracked_objects[id]["motionless_count"] += 1
if self.is_expired(id):
self.deregister(id)
def match_and_update(self, frame_time, detections):
def match_and_update(
self, frame_name: str, frame_time: float, detections: list[dict[Any, Any]]
) -> None:
# group by name
detection_groups = defaultdict(lambda: [])
for obj in detections:
@ -180,7 +183,7 @@ class CentroidTracker(ObjectTracker):
if len(current_objects) == 0:
for index, obj in enumerate(group):
self.register(index, obj)
self.register(obj)
continue
new_centroids = np.array([o["centroid"] for o in group])
@ -238,4 +241,4 @@ class CentroidTracker(ObjectTracker):
# register each new input centroid as a trackable object
else:
for col in unusedCols:
self.register(col, group[col])
self.register(group[col])