refactor centroidtracker into standard API

This commit is contained in:
Blake Blackshear 2023-05-14 08:11:12 -05:00
parent ab50d0b006
commit fc6d98a2ed
4 changed files with 42 additions and 20 deletions

13
frigate/track/__init__.py Normal file
View File

@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
from frigate.config import DetectConfig
class ObjectTracker(ABC):
@abstractmethod
def __init__(self, config: DetectConfig):
pass
@abstractmethod
def match_and_update(self, detections):
pass

View File

@ -4,13 +4,13 @@ from collections import defaultdict
import numpy as np
from scipy.spatial import distance as dist
from frigate.config import DetectConfig
from frigate.track import ObjectTracker
from frigate.util import intersection_over_union
class ObjectTracker:
def __init__(self, config: DetectConfig):
class CentroidTracker(ObjectTracker):
def __init__(self, config: DetectConfig):
self.tracked_objects = {}
self.disappeared = {}
self.positions = {}
@ -134,11 +134,11 @@ class ObjectTracker:
if self.is_expired(id):
self.deregister(id)
def match_and_update(self, frame_time, new_objects):
def match_and_update(self, frame_time, detections):
# group by name
new_object_groups = defaultdict(lambda: [])
for obj in new_objects:
new_object_groups[obj[0]].append(
detection_groups = defaultdict(lambda: [])
for obj in detections:
detection_groups[obj[0]].append(
{
"label": obj[0],
"score": obj[1],
@ -153,17 +153,17 @@ class ObjectTracker:
# update any tracked objects with labels that are not
# seen in the current objects and deregister if needed
for obj in list(self.tracked_objects.values()):
if obj["label"] not in new_object_groups:
if obj["label"] not in detection_groups:
if self.disappeared[obj["id"]] >= self.max_disappeared:
self.deregister(obj["id"])
else:
self.disappeared[obj["id"]] += 1
if len(new_objects) == 0:
if len(detections) == 0:
return
# track objects for each label type
for label, group in new_object_groups.items():
for label, group in detection_groups.items():
current_objects = [
o for o in self.tracked_objects.values() if o["label"] == label
]
@ -236,4 +236,4 @@ class ObjectTracker:
# register each new input centroid as a trackable object
else:
for col in unusedCols:
self.register(col, group[col])
self.register(col, group[col])

View File

@ -19,7 +19,8 @@ from frigate.const import CACHE_DIR
from frigate.log import LogPipe
from frigate.motion import MotionDetector
from frigate.object_detection import RemoteObjectDetector
from frigate.objects import ObjectTracker
from frigate.track import ObjectTracker
from frigate.track.centroid_tracker import CentroidTracker
from frigate.util import (
EventsPerSecond,
FrameManager,
@ -472,7 +473,7 @@ def track_camera(
name, labelmap, detection_queue, result_connection, model_config, stop_event
)
object_tracker = ObjectTracker(config.detect)
object_tracker = CentroidTracker(config.detect)
frame_manager = SharedMemoryFrameManager()

View File

@ -1,4 +1,11 @@
import csv
import sys
from typing_extensions import runtime
from frigate.track.centroid_tracker import CentroidTracker
sys.path.append("/workspace/frigate")
import json
import logging
import multiprocessing as mp
@ -10,13 +17,14 @@ import click
import cv2
import numpy as np
sys.path.append("/lab/frigate")
from frigate.config import FrigateConfig # noqa: E402
from frigate.motion import MotionDetector # noqa: E402
from frigate.object_detection import LocalObjectDetector # noqa: E402
from frigate.object_processing import CameraState # noqa: E402
from frigate.objects import ObjectTracker # noqa: E402
from frigate.config import FrigateConfig
from frigate.object_detection import LocalObjectDetector
from frigate.motion import MotionDetector
from frigate.object_processing import CameraState
from frigate.util import ( # noqa: E402
EventsPerSecond,
SharedMemoryFrameManager,
@ -108,7 +116,7 @@ class ProcessClip:
motion_detector = MotionDetector(self.frame_shape, self.camera_config.motion)
motion_detector.save_images = False
object_tracker = ObjectTracker(self.camera_config.detect)
object_tracker = CentroidTracker(self.camera_config.detect)
process_info = {
"process_fps": mp.Value("d", 0.0),
"detection_fps": mp.Value("d", 0.0),
@ -248,7 +256,7 @@ def process(path, label, output, debug_path):
clips.append(path)
json_config = {
"mqtt": {"host": "mqtt"},
"mqtt": {"enabled": False},
"detectors": {"coral": {"type": "edgetpu", "device": "usb"}},
"cameras": {
"camera": {
@ -282,7 +290,7 @@ def process(path, label, output, debug_path):
json_config["cameras"]["camera"]["ffmpeg"]["inputs"][0]["path"] = c
frigate_config = FrigateConfig(**json_config)
runtime_config = frigate_config.runtime_config
runtime_config = frigate_config.runtime_config()
runtime_config.cameras["camera"].create_ffmpeg_cmds()
process_clip = ProcessClip(c, frame_shape, runtime_config)