From f1c8c69448dfe64b26fab1d3437b32e41b8f7c6a Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 25 Mar 2026 11:18:58 -0600 Subject: [PATCH] Cleanup object detection mypy --- frigate/detectors/plugins/memryx.py | 2 +- frigate/mypy.ini | 2 +- frigate/object_detection/base.py | 90 ++++++++++++++++------------- frigate/object_detection/util.py | 14 ++--- 4 files changed, 59 insertions(+), 49 deletions(-) diff --git a/frigate/detectors/plugins/memryx.py b/frigate/detectors/plugins/memryx.py index e0ad401cb..2c03d14a4 100644 --- a/frigate/detectors/plugins/memryx.py +++ b/frigate/detectors/plugins/memryx.py @@ -317,7 +317,7 @@ class MemryXDetector(DetectionApi): f"Failed to remove downloaded zip {zip_path}: {e}" ) - def send_input(self, connection_id, tensor_input: np.ndarray): + def send_input(self, connection_id, tensor_input: np.ndarray) -> None: """Pre-process (if needed) and send frame to MemryX input queue""" if tensor_input is None: raise ValueError("[send_input] No image data provided for inference") diff --git a/frigate/mypy.ini b/frigate/mypy.ini index 2341dc629..e4ee5c796 100644 --- a/frigate/mypy.ini +++ b/frigate/mypy.ini @@ -50,7 +50,7 @@ ignore_errors = false [mypy-frigate.motion.*] ignore_errors = false -[mypy-frigate.object_detection] +[mypy-frigate.object_detection.*] ignore_errors = false [mypy-frigate.output] diff --git a/frigate/object_detection/base.py b/frigate/object_detection/base.py index d2a54afbc..a62fe4843 100644 --- a/frigate/object_detection/base.py +++ b/frigate/object_detection/base.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections import deque from multiprocessing import Queue, Value from multiprocessing.synchronize import Event as MpEvent +from typing import Any, Optional import numpy as np import zmq @@ -34,26 +35,25 @@ logger = logging.getLogger(__name__) class ObjectDetector(ABC): @abstractmethod - def detect(self, tensor_input, threshold: float = 0.4): + def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list: pass class BaseLocalDetector(ObjectDetector): def __init__( self, - detector_config: BaseDetectorConfig = None, - labels: str = None, - stop_event: MpEvent = None, - ): + detector_config: Optional[BaseDetectorConfig] = None, + labels: Optional[str] = None, + stop_event: Optional[MpEvent] = None, + ) -> None: self.fps = EventsPerSecond() if labels is None: - self.labels = {} + self.labels: dict[int, str] = {} else: self.labels = load_labels(labels) - if detector_config: + if detector_config and detector_config.model: self.input_transform = tensor_transform(detector_config.model.input_tensor) - self.dtype = detector_config.model.input_dtype else: self.input_transform = None @@ -77,10 +77,10 @@ class BaseLocalDetector(ObjectDetector): return tensor_input - def detect(self, tensor_input: np.ndarray, threshold=0.4): + def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list: detections = [] - raw_detections = self.detect_raw(tensor_input) + raw_detections = self.detect_raw(tensor_input) # type: ignore[attr-defined] for d in raw_detections: if int(d[0]) < 0 or int(d[0]) >= len(self.labels): @@ -96,28 +96,28 @@ class BaseLocalDetector(ObjectDetector): class LocalObjectDetector(BaseLocalDetector): - def detect_raw(self, tensor_input: np.ndarray): + def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray: tensor_input = self._transform_input(tensor_input) - return self.detect_api.detect_raw(tensor_input=tensor_input) + return self.detect_api.detect_raw(tensor_input=tensor_input) # type: ignore[no-any-return] class AsyncLocalObjectDetector(BaseLocalDetector): - def async_send_input(self, tensor_input: np.ndarray, connection_id: str): + def async_send_input(self, tensor_input: np.ndarray, connection_id: str) -> None: tensor_input = self._transform_input(tensor_input) - return self.detect_api.send_input(connection_id, tensor_input) + self.detect_api.send_input(connection_id, tensor_input) - def async_receive_output(self): + def async_receive_output(self) -> Any: return self.detect_api.receive_output() class DetectorRunner(FrigateProcess): def __init__( self, - name, + name: str, detection_queue: Queue, cameras: list[str], - avg_speed: Value, - start_time: Value, + avg_speed: Any, + start_time: Any, config: FrigateConfig, detector_config: BaseDetectorConfig, stop_event: MpEvent, @@ -129,11 +129,11 @@ class DetectorRunner(FrigateProcess): self.start_time = start_time self.config = config self.detector_config = detector_config - self.outputs: dict = {} + self.outputs: dict[str, Any] = {} - def create_output_shm(self, name: str): + def create_output_shm(self, name: str) -> None: out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False) - out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) + out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) self.outputs[name] = {"shm": out_shm, "np": out_np} def run(self) -> None: @@ -155,8 +155,8 @@ class DetectorRunner(FrigateProcess): connection_id, ( 1, - self.detector_config.model.height, - self.detector_config.model.width, + self.detector_config.model.height, # type: ignore[union-attr] + self.detector_config.model.width, # type: ignore[union-attr] 3, ), ) @@ -187,11 +187,11 @@ class DetectorRunner(FrigateProcess): class AsyncDetectorRunner(FrigateProcess): def __init__( self, - name, + name: str, detection_queue: Queue, cameras: list[str], - avg_speed: Value, - start_time: Value, + avg_speed: Any, + start_time: Any, config: FrigateConfig, detector_config: BaseDetectorConfig, stop_event: MpEvent, @@ -203,15 +203,15 @@ class AsyncDetectorRunner(FrigateProcess): self.start_time = start_time self.config = config self.detector_config = detector_config - self.outputs: dict = {} + self.outputs: dict[str, Any] = {} self._frame_manager: SharedMemoryFrameManager | None = None self._publisher: ObjectDetectorPublisher | None = None self._detector: AsyncLocalObjectDetector | None = None - self.send_times = deque() + self.send_times: deque[float] = deque() - def create_output_shm(self, name: str): + def create_output_shm(self, name: str) -> None: out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False) - out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) + out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) self.outputs[name] = {"shm": out_shm, "np": out_np} def _detect_worker(self) -> None: @@ -222,12 +222,13 @@ class AsyncDetectorRunner(FrigateProcess): except queue.Empty: continue + assert self._frame_manager is not None input_frame = self._frame_manager.get( connection_id, ( 1, - self.detector_config.model.height, - self.detector_config.model.width, + self.detector_config.model.height, # type: ignore[union-attr] + self.detector_config.model.width, # type: ignore[union-attr] 3, ), ) @@ -238,11 +239,13 @@ class AsyncDetectorRunner(FrigateProcess): # mark start time and send to accelerator self.send_times.append(time.perf_counter()) + assert self._detector is not None self._detector.async_send_input(input_frame, connection_id) def _result_worker(self) -> None: logger.info("Starting Result Worker Thread") while not self.stop_event.is_set(): + assert self._detector is not None connection_id, detections = self._detector.async_receive_output() # Handle timeout case (queue.Empty) - just continue @@ -256,6 +259,7 @@ class AsyncDetectorRunner(FrigateProcess): duration = time.perf_counter() - ts # release input buffer + assert self._frame_manager is not None self._frame_manager.close(connection_id) if connection_id not in self.outputs: @@ -264,6 +268,7 @@ class AsyncDetectorRunner(FrigateProcess): # write results and publish if detections is not None: self.outputs[connection_id]["np"][:] = detections[:] + assert self._publisher is not None self._publisher.publish(connection_id) # update timers @@ -330,11 +335,14 @@ class ObjectDetectProcess: self.stop_event = stop_event self.start_or_restart() - def stop(self): + def stop(self) -> None: # if the process has already exited on its own, just return if self.detect_process and self.detect_process.exitcode: return + if self.detect_process is None: + return + logging.info("Waiting for detection process to exit gracefully...") self.detect_process.join(timeout=30) if self.detect_process.exitcode is None: @@ -343,8 +351,8 @@ class ObjectDetectProcess: self.detect_process.join() logging.info("Detection process has exited...") - def start_or_restart(self): - self.detection_start.value = 0.0 + def start_or_restart(self) -> None: + self.detection_start.value = 0.0 # type: ignore[attr-defined] if (self.detect_process is not None) and self.detect_process.is_alive(): self.stop() @@ -389,17 +397,19 @@ class RemoteObjectDetector: self.detection_queue = detection_queue self.stop_event = stop_event self.shm = UntrackedSharedMemory(name=self.name, create=False) - self.np_shm = np.ndarray( + self.np_shm: np.ndarray = np.ndarray( (1, model_config.height, model_config.width, 3), dtype=np.uint8, buffer=self.shm.buf, ) self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False) - self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf) + self.out_np_shm: np.ndarray = np.ndarray( + (20, 6), dtype=np.float32, buffer=self.out_shm.buf + ) self.detector_subscriber = ObjectDetectorSubscriber(name) - def detect(self, tensor_input, threshold=0.4): - detections = [] + def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list: + detections: list = [] if self.stop_event.is_set(): return detections @@ -431,7 +441,7 @@ class RemoteObjectDetector: self.fps.update() return detections - def cleanup(self): + def cleanup(self) -> None: self.detector_subscriber.stop() self.shm.unlink() self.out_shm.unlink() diff --git a/frigate/object_detection/util.py b/frigate/object_detection/util.py index ea8bd4226..604eddf34 100644 --- a/frigate/object_detection/util.py +++ b/frigate/object_detection/util.py @@ -13,10 +13,10 @@ class RequestStore: A thread-safe hash-based response store that handles creating requests. """ - def __init__(self): + def __init__(self) -> None: self.request_counter = 0 self.request_counter_lock = threading.Lock() - self.input_queue = queue.Queue() + self.input_queue: queue.Queue[tuple[int, ndarray]] = queue.Queue() def __get_request_id(self) -> int: with self.request_counter_lock: @@ -45,17 +45,17 @@ class ResponseStore: their request's result appears. """ - def __init__(self): - self.responses = {} # Maps request_id -> (original_input, infer_results) + def __init__(self) -> None: + self.responses: dict[int, ndarray] = {} # Maps request_id -> (original_input, infer_results) self.lock = threading.Lock() self.cond = threading.Condition(self.lock) - def put(self, request_id: int, response: ndarray): + def put(self, request_id: int, response: ndarray) -> None: with self.cond: self.responses[request_id] = response self.cond.notify_all() - def get(self, request_id: int, timeout=None) -> ndarray: + def get(self, request_id: int, timeout: float | None = None) -> ndarray: with self.cond: if not self.cond.wait_for( lambda: request_id in self.responses, timeout=timeout @@ -65,7 +65,7 @@ class ResponseStore: return self.responses.pop(request_id) -def tensor_transform(desired_shape: InputTensorEnum): +def tensor_transform(desired_shape: InputTensorEnum) -> tuple[int, int, int, int] | None: # Currently this function only supports BHWC permutations if desired_shape == InputTensorEnum.nhwc: return None