Cleanup typing

This commit is contained in:
Nicolas Mowen 2025-08-15 14:53:49 -06:00
parent e4a44ff8ff
commit 2c16e4f300
8 changed files with 102 additions and 87 deletions

View File

@ -54,7 +54,7 @@ class CameraState:
self.ptz_autotracker_thread = ptz_autotracker_thread
self.prev_enabled = self.camera_config.enabled
def get_current_frame(self, draw_options: dict[str, Any] = {}):
def get_current_frame(self, draw_options: dict[str, Any] = {}) -> np.ndarray:
with self.current_frame_lock:
frame_copy = np.copy(self._current_frame)
frame_time = self.current_frame_time
@ -272,7 +272,7 @@ class CameraState:
def finished(self, obj_id):
del self.tracked_objects[obj_id]
def on(self, event_type: str, callback: Callable[[dict], None]):
def on(self, event_type: str, callback: Callable[[str, TrackedObject, str]]):
self.callbacks[event_type].append(callback)
def update(

View File

@ -8,7 +8,7 @@ from .zmq_proxy import Publisher, Subscriber
class EventUpdatePublisher(
Publisher[tuple[EventTypeEnum, EventStateEnum, str, str, dict[str, Any]]]
Publisher[tuple[EventTypeEnum, EventStateEnum, str | None, str, dict[str, Any]]]
):
"""Publishes events (objects, audio, manual)."""
@ -19,7 +19,7 @@ class EventUpdatePublisher(
def publish(
self,
payload: tuple[EventTypeEnum, EventStateEnum, str, str, dict[str, Any]],
payload: tuple[EventTypeEnum, EventStateEnum, str | None, str, dict[str, Any]],
sub_topic: str = "",
) -> None:
super().publish(payload, sub_topic)

View File

@ -70,7 +70,7 @@ class Publisher(Generic[T]):
self.context.destroy()
class Subscriber:
class Subscriber(Generic[T]):
"""Receives messages."""
topic_base: str = ""
@ -82,9 +82,7 @@ class Subscriber:
self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic)
self.socket.connect(SOCKET_SUB)
def check_for_update(
self, timeout: float | None = FAST_QUEUE_TIMEOUT
) -> tuple[str, Any] | tuple[None, None] | None:
def check_for_update(self, timeout: float | None = FAST_QUEUE_TIMEOUT) -> T | None:
"""Returns message or None if no update."""
try:
has_update, _, _ = zmq.select([self.socket], [], [], timeout)
@ -101,7 +99,5 @@ class Subscriber:
self.socket.close()
self.context.destroy()
def _return_object(
self, topic: str, payload: Optional[tuple[str, Any]]
) -> tuple[str, Any] | tuple[None, None] | None:
def _return_object(self, topic: str, payload: T | None) -> T | None:
return payload

View File

@ -60,10 +60,10 @@ class PtzMotionEstimator:
def motion_estimator(
self,
detections: list[dict[str, Any]],
detections: list[tuple[Any, Any, Any, Any, Any, Any]],
frame_name: str,
frame_time: float,
camera: str,
camera: str | None,
):
# If we've just started up or returned to our preset, reset motion estimator for new tracking session
if self.ptz_metrics.reset.is_set():

View File

@ -11,6 +11,6 @@ class ObjectTracker(ABC):
@abstractmethod
def match_and_update(
self, frame_name: str, frame_time: float, detections: list[dict[str, Any]]
self, frame_name: str, frame_time: float, detections: list[tuple[Any, Any, Any, Any, Any, Any]]
) -> None:
pass

View File

@ -13,7 +13,7 @@ from frigate.util.image import intersection_over_union
class CentroidTracker(ObjectTracker):
def __init__(self, config: DetectConfig):
self.tracked_objects: dict[str, Any] = {}
self.tracked_objects: dict[str, dict[str, Any]] = {}
self.untracked_object_boxes: list[tuple[int, int, int, int]] = []
self.disappeared: dict[str, Any] = {}
self.positions: dict[str, Any] = {}
@ -138,19 +138,19 @@ class CentroidTracker(ObjectTracker):
self.deregister(id)
def match_and_update(
self, frame_name: str, frame_time: float, detections: list[dict[Any, Any]]
self, frame_name: str, frame_time: float, detections: list[tuple[Any, Any, Any, Any, Any, Any]]
) -> None:
# group by name
detection_groups = defaultdict(lambda: [])
for obj in detections:
detection_groups[obj[0]].append(
for det in detections:
detection_groups[det[0]].append(
{
"label": obj[0],
"score": obj[1],
"box": obj[2],
"area": obj[3],
"ratio": obj[4],
"region": obj[5],
"label": det[0],
"score": det[1],
"box": det[2],
"area": det[3],
"ratio": det[4],
"region": det[5],
"frame_time": frame_time,
}
)

View File

@ -13,6 +13,7 @@ from norfair import (
draw_boxes,
)
from norfair.drawing.drawer import Drawer
from norfair.tracker import TrackedObject
from rich import print
from rich.console import Console
from rich.table import Table
@ -43,7 +44,7 @@ MAX_STATIONARY_HISTORY = 10
# - could be variable based on time since last_detection
# - include estimated velocity in the distance (car driving by of a parked car)
# - include some visual similarity factor in the distance for occlusions
def distance(detection: np.array, estimate: np.array) -> float:
def distance(detection: np.ndarray, estimate: np.ndarray) -> float:
# ultimately, this should try and estimate distance in 3-dimensional space
# consider change in location, width, and height
@ -73,14 +74,16 @@ def distance(detection: np.array, estimate: np.array) -> float:
change = np.append(distance, np.array([width_ratio, height_ratio]))
# calculate euclidean distance of the change vector
return np.linalg.norm(change)
return float(np.linalg.norm(change))
def frigate_distance(detection: Detection, tracked_object) -> float:
def frigate_distance(detection: Detection, tracked_object: TrackedObject) -> float:
return distance(detection.points, tracked_object.estimate)
def histogram_distance(matched_not_init_trackers, unmatched_trackers):
def histogram_distance(
matched_not_init_trackers: TrackedObject, unmatched_trackers: TrackedObject
) -> float:
snd_embedding = unmatched_trackers.last_detection.embedding
if snd_embedding is None:
@ -110,17 +113,17 @@ class NorfairTracker(ObjectTracker):
ptz_metrics: PTZMetrics,
):
self.frame_manager = SharedMemoryFrameManager()
self.tracked_objects = {}
self.tracked_objects: dict[str, dict[str, Any]] = {}
self.untracked_object_boxes: list[list[int]] = []
self.disappeared = {}
self.positions = {}
self.stationary_box_history: dict[str, list[list[int, int, int, int]]] = {}
self.disappeared: dict[str, int] = {}
self.positions: dict[str, dict[str, Any]] = {}
self.stationary_box_history: dict[str, list[list[int]]] = {}
self.camera_config = config
self.detect_config = config.detect
self.ptz_metrics = ptz_metrics
self.ptz_motion_estimator = {}
self.ptz_motion_estimator: PtzMotionEstimator | None = None
self.camera_name = config.name
self.track_id_map = {}
self.track_id_map: dict[str, str] = {}
# Define tracker configurations for static camera
self.object_type_configs = {
@ -169,7 +172,7 @@ class NorfairTracker(ObjectTracker):
"distance_threshold": 3,
}
self.trackers = {}
self.trackers: dict[str, dict[str, Tracker]] = {}
# Handle static trackers
for obj_type, tracker_config in self.object_type_configs.items():
if obj_type in self.camera_config.objects.track:
@ -216,7 +219,7 @@ class NorfairTracker(ObjectTracker):
self.camera_config, self.ptz_metrics
)
def _create_tracker(self, obj_type, tracker_config):
def _create_tracker(self, obj_type: str, tracker_config: dict[str, Any]) -> Tracker:
"""Helper function to create a tracker with given configuration."""
tracker_params = {
"distance_function": tracker_config["distance_function"],
@ -258,7 +261,7 @@ class NorfairTracker(ObjectTracker):
return self.trackers[object_type][mode]
return self.default_tracker[mode]
def register(self, track_id, obj):
def register(self, track_id: str, obj: dict[str, Any]) -> None:
rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
id = f"{obj['frame_time']}-{rand_id}"
self.track_id_map[track_id] = id
@ -297,7 +300,7 @@ class NorfairTracker(ObjectTracker):
}
self.stationary_box_history[id] = boxes
def deregister(self, id, track_id):
def deregister(self, id: str, track_id: str) -> None:
obj = self.tracked_objects[id]
del self.tracked_objects[id]
@ -321,7 +324,7 @@ class NorfairTracker(ObjectTracker):
# tracks the current position of the object based on the last N bounding boxes
# returns False if the object has moved outside its previous position
def update_position(self, id: str, box: list[int, int, int, int], stationary: bool):
def update_position(self, id: str, box: list[int], stationary: bool) -> bool:
xmin, ymin, xmax, ymax = box
position = self.positions[id]
self.stationary_box_history[id].append(box)
@ -396,7 +399,7 @@ class NorfairTracker(ObjectTracker):
return True
def is_expired(self, id):
def is_expired(self, id: str) -> bool:
obj = self.tracked_objects[id]
# get the max frames for this label type or the default
max_frames = self.detect_config.stationary.max_frames.objects.get(
@ -416,7 +419,7 @@ class NorfairTracker(ObjectTracker):
return False
def update(self, track_id, obj):
def update(self, track_id: str, obj: dict[str, Any]) -> None:
id = self.track_id_map[track_id]
self.disappeared[id] = 0
stationary = (
@ -443,7 +446,7 @@ class NorfairTracker(ObjectTracker):
self.tracked_objects[id].update(obj)
def update_frame_times(self, frame_name: str, frame_time: float):
def update_frame_times(self, frame_name: str, frame_time: float) -> None:
# if the object was there in the last frame, assume it's still there
detections = [
(
@ -460,10 +463,13 @@ class NorfairTracker(ObjectTracker):
self.match_and_update(frame_name, frame_time, detections=detections)
def match_and_update(
self, frame_name: str, frame_time: float, detections: list[dict[str, Any]]
):
self,
frame_name: str,
frame_time: float,
detections: list[tuple[Any, Any, Any, Any, Any, Any]],
) -> None:
# Group detections by object type
detections_by_type = {}
detections_by_type: dict[str, list[Detection]] = {}
for obj in detections:
label = obj[0]
if label not in detections_by_type:
@ -551,17 +557,17 @@ class NorfairTracker(ObjectTracker):
estimate = (
max(0, estimate[0]),
max(0, estimate[1]),
min(self.detect_config.width - 1, estimate[2]),
min(self.detect_config.height - 1, estimate[3]),
min(self.detect_config.width - 1, estimate[2]), # type: ignore[operator]
min(self.detect_config.height - 1, estimate[3]), # type: ignore[operator]
)
obj = {
new_obj = {
**t.last_detection.data,
"estimate": estimate,
"estimate_velocity": t.estimate_velocity,
}
active_ids.append(t.global_id)
if t.global_id not in self.track_id_map:
self.register(t.global_id, obj)
self.register(t.global_id, new_obj)
# if there wasn't a detection in this frame, increment disappeared
elif t.last_detection.data["frame_time"] != frame_time:
id = self.track_id_map[t.global_id]
@ -569,10 +575,10 @@ class NorfairTracker(ObjectTracker):
# sometimes the estimate gets way off
# only update if the upper left corner is actually upper left
if estimate[0] < estimate[2] and estimate[1] < estimate[3]:
self.tracked_objects[id]["estimate"] = obj["estimate"]
self.tracked_objects[id]["estimate"] = new_obj["estimate"]
# else update it
else:
self.update(t.global_id, obj)
self.update(t.global_id, new_obj)
# clear expired tracks
expired_ids = [k for k in self.track_id_map.keys() if k not in active_ids]
@ -585,7 +591,7 @@ class NorfairTracker(ObjectTracker):
o[2] for o in detections if o[2] not in tracked_object_boxes
]
def print_objects_as_table(self, tracked_objects: Sequence):
def print_objects_as_table(self, tracked_objects: Sequence) -> None:
"""Used for helping in debugging"""
print()
console = Console()
@ -605,13 +611,13 @@ class NorfairTracker(ObjectTracker):
)
console.print(table)
def debug_draw(self, frame, frame_time):
def debug_draw(self, frame: np.ndarray, frame_time: float) -> None:
# Collect all tracked objects from each tracker
all_tracked_objects = []
# print a table to the console with norfair tracked object info
if False:
if len(self.trackers["license_plate"]["static"].tracked_objects) > 0:
if len(self.trackers["license_plate"]["static"].tracked_objects) > 0: # type: ignore[unreachable]
self.print_objects_as_table(
self.trackers["license_plate"]["static"].tracked_objects
)
@ -662,7 +668,7 @@ class NorfairTracker(ObjectTracker):
if False:
# draw the current formatted time on the frame
from datetime import datetime
from datetime import datetime # type: ignore[unreachable]
formatted_time = datetime.fromtimestamp(frame_time).strftime(
"%m/%d/%Y %I:%M:%S %p"

View File

@ -6,6 +6,7 @@ import queue
import threading
from collections import defaultdict
from enum import Enum
from multiprocessing import Queue as MpQueue
from multiprocessing.synchronize import Event as MpEvent
from typing import Any
@ -39,6 +40,7 @@ from frigate.const import (
)
from frigate.events.types import EventStateEnum, EventTypeEnum
from frigate.models import Event, ReviewSegment, Timeline
from frigate.ptz.autotrack import PtzAutoTrackerThread
from frigate.track.tracked_object import TrackedObject
from frigate.util.image import SharedMemoryFrameManager
@ -56,10 +58,10 @@ class TrackedObjectProcessor(threading.Thread):
self,
config: FrigateConfig,
dispatcher: Dispatcher,
tracked_objects_queue,
ptz_autotracker_thread,
stop_event,
):
tracked_objects_queue: MpQueue,
ptz_autotracker_thread: PtzAutoTrackerThread,
stop_event: MpEvent,
) -> None:
super().__init__(name="detected_frames_processor")
self.config = config
self.dispatcher = dispatcher
@ -98,8 +100,12 @@ class TrackedObjectProcessor(threading.Thread):
# }
# }
# }
self.zone_data = defaultdict(lambda: defaultdict(dict))
self.active_zone_data = defaultdict(lambda: defaultdict(dict))
self.zone_data: dict[str, dict[str, Any]] = defaultdict(
lambda: defaultdict(dict)
)
self.active_zone_data: dict[str, dict[str, Any]] = defaultdict(
lambda: defaultdict(dict)
)
for camera in self.config.cameras.keys():
self.create_camera_state(camera)
@ -107,7 +113,7 @@ class TrackedObjectProcessor(threading.Thread):
def create_camera_state(self, camera: str) -> None:
"""Creates a new camera state."""
def start(camera: str, obj: TrackedObject, frame_name: str):
def start(camera: str, obj: TrackedObject, frame_name: str) -> None:
self.event_sender.publish(
(
EventTypeEnum.tracked_object,
@ -118,7 +124,7 @@ class TrackedObjectProcessor(threading.Thread):
)
)
def update(camera: str, obj: TrackedObject, frame_name: str):
def update(camera: str, obj: TrackedObject, frame_name: str) -> None:
obj.has_snapshot = self.should_save_snapshot(camera, obj)
obj.has_clip = self.should_retain_recording(camera, obj)
after = obj.to_dict()
@ -139,10 +145,10 @@ class TrackedObjectProcessor(threading.Thread):
)
)
def autotrack(camera: str, obj: TrackedObject, frame_name: str):
def autotrack(camera: str, obj: TrackedObject, frame_name: str) -> None:
self.ptz_autotracker_thread.ptz_autotracker.autotrack_object(camera, obj)
def end(camera: str, obj: TrackedObject, frame_name: str):
def end(camera: str, obj: TrackedObject, frame_name: str) -> None:
# populate has_snapshot
obj.has_snapshot = self.should_save_snapshot(camera, obj)
obj.has_clip = self.should_retain_recording(camera, obj)
@ -211,7 +217,7 @@ class TrackedObjectProcessor(threading.Thread):
return False
def camera_activity(camera, activity):
def camera_activity(camera: str, activity: dict[str, Any]) -> None:
last_activity = self.camera_activity.get(camera)
if not last_activity or activity != last_activity:
@ -229,7 +235,7 @@ class TrackedObjectProcessor(threading.Thread):
camera_state.on("camera_activity", camera_activity)
self.camera_states[camera] = camera_state
def should_save_snapshot(self, camera, obj: TrackedObject):
def should_save_snapshot(self, camera: str, obj: TrackedObject) -> bool:
if obj.false_positive:
return False
@ -252,7 +258,7 @@ class TrackedObjectProcessor(threading.Thread):
return True
def should_retain_recording(self, camera: str, obj: TrackedObject):
def should_retain_recording(self, camera: str, obj: TrackedObject) -> bool:
if obj.false_positive:
return False
@ -272,7 +278,7 @@ class TrackedObjectProcessor(threading.Thread):
return True
def should_mqtt_snapshot(self, camera, obj: TrackedObject):
def should_mqtt_snapshot(self, camera: str, obj: TrackedObject) -> bool:
# object never changed position
if obj.is_stationary():
return False
@ -287,7 +293,9 @@ class TrackedObjectProcessor(threading.Thread):
return True
def update_mqtt_motion(self, camera, frame_time, motion_boxes):
def update_mqtt_motion(
self, camera: str, frame_time: float, motion_boxes: list
) -> None:
# publish if motion is currently being detected
if motion_boxes:
# only send ON if motion isn't already active
@ -313,11 +321,15 @@ class TrackedObjectProcessor(threading.Thread):
# reset the last_motion so redundant `off` commands aren't sent
self.last_motion_detected[camera] = 0
def get_best(self, camera, label):
def get_best(self, camera: str, label: str) -> dict[str, Any]:
# TODO: need a lock here
camera_state = self.camera_states[camera]
if label in camera_state.best_objects:
best_obj = camera_state.best_objects[label]
if not best_obj.thumbnail_data:
return {}
best = best_obj.thumbnail_data.copy()
best["frame"] = camera_state.frame_cache.get(
best_obj.thumbnail_data["frame_time"]
@ -340,7 +352,7 @@ class TrackedObjectProcessor(threading.Thread):
return self.camera_states[camera].get_current_frame(draw_options)
def get_current_frame_time(self, camera) -> int:
def get_current_frame_time(self, camera: str) -> float:
"""Returns the latest frame time for a given camera."""
return self.camera_states[camera].current_frame_time
@ -348,7 +360,7 @@ class TrackedObjectProcessor(threading.Thread):
self, event_id: str, sub_label: str | None, score: float | None
) -> None:
"""Update sub label for given event id."""
tracked_obj: TrackedObject = None
tracked_obj: TrackedObject | None = None
for state in self.camera_states.values():
tracked_obj = state.tracked_objects.get(event_id)
@ -357,7 +369,7 @@ class TrackedObjectProcessor(threading.Thread):
break
try:
event: Event = Event.get(Event.id == event_id)
event: Event | None = Event.get(Event.id == event_id)
except DoesNotExist:
event = None
@ -368,7 +380,7 @@ class TrackedObjectProcessor(threading.Thread):
tracked_obj.obj_data["sub_label"] = (sub_label, score)
if event:
event.sub_label = sub_label
event.sub_label = sub_label # type: ignore[assignment]
data = event.data
if sub_label is None:
data["sub_label_score"] = None
@ -402,7 +414,7 @@ class TrackedObjectProcessor(threading.Thread):
objects_list = []
sub_labels = set()
events = Event.select(Event.id, Event.label, Event.sub_label).where(
Event.id.in_(detection_ids)
Event.id.in_(detection_ids) # type: ignore[call-arg, misc]
)
for det_event in events:
if det_event.sub_label:
@ -431,13 +443,11 @@ class TrackedObjectProcessor(threading.Thread):
f"Updated sub_label for event {event_id} in review segment {review_segment.id}"
)
except ReviewSegment.DoesNotExist:
except DoesNotExist:
logger.debug(
f"No review segment found with event ID {event_id} when updating sub_label"
)
return True
def set_object_attribute(
self,
event_id: str,
@ -446,7 +456,7 @@ class TrackedObjectProcessor(threading.Thread):
score: float | None,
) -> None:
"""Update attribute for given event id."""
tracked_obj: TrackedObject = None
tracked_obj: TrackedObject | None = None
for state in self.camera_states.values():
tracked_obj = state.tracked_objects.get(event_id)
@ -455,7 +465,7 @@ class TrackedObjectProcessor(threading.Thread):
break
try:
event: Event = Event.get(Event.id == event_id)
event: Event | None = Event.get(Event.id == event_id)
except DoesNotExist:
event = None
@ -478,8 +488,6 @@ class TrackedObjectProcessor(threading.Thread):
event.data = data
event.save()
return True
def save_lpr_snapshot(self, payload: tuple) -> None:
# save the snapshot image
(frame, event_id, camera) = payload
@ -638,7 +646,7 @@ class TrackedObjectProcessor(threading.Thread):
)
self.ongoing_manual_events.pop(event_id)
def force_end_all_events(self, camera: str, camera_state: CameraState):
def force_end_all_events(self, camera: str, camera_state: CameraState) -> None:
"""Ends all active events on camera when disabling."""
last_frame_name = camera_state.previous_frame_id
for obj_id, obj in list(camera_state.tracked_objects.items()):
@ -656,7 +664,7 @@ class TrackedObjectProcessor(threading.Thread):
{"enabled": False, "motion": 0, "objects": []},
)
def run(self):
def run(self) -> None:
while not self.stop_event.is_set():
# check for config updates
updated_topics = self.camera_config_subscriber.check_for_updates()
@ -698,11 +706,16 @@ class TrackedObjectProcessor(threading.Thread):
# check for sub label updates
while True:
(raw_topic, payload) = self.sub_label_subscriber.check_for_update(
update = self.sub_label_subscriber.check_for_update(
timeout=0
)
if not raw_topic:
if not update:
break
(raw_topic, payload) = update
if not raw_topic or not payload:
break
topic = str(raw_topic)