Cleanup object detection mypy

This commit is contained in:
Nicolas Mowen 2026-03-25 11:18:58 -06:00
parent 706ba0c9c7
commit f1c8c69448
4 changed files with 59 additions and 49 deletions

View File

@ -317,7 +317,7 @@ class MemryXDetector(DetectionApi):
f"Failed to remove downloaded zip {zip_path}: {e}" f"Failed to remove downloaded zip {zip_path}: {e}"
) )
def send_input(self, connection_id, tensor_input: np.ndarray): def send_input(self, connection_id, tensor_input: np.ndarray) -> None:
"""Pre-process (if needed) and send frame to MemryX input queue""" """Pre-process (if needed) and send frame to MemryX input queue"""
if tensor_input is None: if tensor_input is None:
raise ValueError("[send_input] No image data provided for inference") raise ValueError("[send_input] No image data provided for inference")

View File

@ -50,7 +50,7 @@ ignore_errors = false
[mypy-frigate.motion.*] [mypy-frigate.motion.*]
ignore_errors = false ignore_errors = false
[mypy-frigate.object_detection] [mypy-frigate.object_detection.*]
ignore_errors = false ignore_errors = false
[mypy-frigate.output] [mypy-frigate.output]

View File

@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
from collections import deque from collections import deque
from multiprocessing import Queue, Value from multiprocessing import Queue, Value
from multiprocessing.synchronize import Event as MpEvent from multiprocessing.synchronize import Event as MpEvent
from typing import Any, Optional
import numpy as np import numpy as np
import zmq import zmq
@ -34,26 +35,25 @@ logger = logging.getLogger(__name__)
class ObjectDetector(ABC): class ObjectDetector(ABC):
@abstractmethod @abstractmethod
def detect(self, tensor_input, threshold: float = 0.4): def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
pass pass
class BaseLocalDetector(ObjectDetector): class BaseLocalDetector(ObjectDetector):
def __init__( def __init__(
self, self,
detector_config: BaseDetectorConfig = None, detector_config: Optional[BaseDetectorConfig] = None,
labels: str = None, labels: Optional[str] = None,
stop_event: MpEvent = None, stop_event: Optional[MpEvent] = None,
): ) -> None:
self.fps = EventsPerSecond() self.fps = EventsPerSecond()
if labels is None: if labels is None:
self.labels = {} self.labels: dict[int, str] = {}
else: else:
self.labels = load_labels(labels) self.labels = load_labels(labels)
if detector_config: if detector_config and detector_config.model:
self.input_transform = tensor_transform(detector_config.model.input_tensor) self.input_transform = tensor_transform(detector_config.model.input_tensor)
self.dtype = detector_config.model.input_dtype self.dtype = detector_config.model.input_dtype
else: else:
self.input_transform = None self.input_transform = None
@ -77,10 +77,10 @@ class BaseLocalDetector(ObjectDetector):
return tensor_input return tensor_input
def detect(self, tensor_input: np.ndarray, threshold=0.4): def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
detections = [] detections = []
raw_detections = self.detect_raw(tensor_input) raw_detections = self.detect_raw(tensor_input) # type: ignore[attr-defined]
for d in raw_detections: for d in raw_detections:
if int(d[0]) < 0 or int(d[0]) >= len(self.labels): if int(d[0]) < 0 or int(d[0]) >= len(self.labels):
@ -96,28 +96,28 @@ class BaseLocalDetector(ObjectDetector):
class LocalObjectDetector(BaseLocalDetector): class LocalObjectDetector(BaseLocalDetector):
def detect_raw(self, tensor_input: np.ndarray): def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
tensor_input = self._transform_input(tensor_input) tensor_input = self._transform_input(tensor_input)
return self.detect_api.detect_raw(tensor_input=tensor_input) return self.detect_api.detect_raw(tensor_input=tensor_input) # type: ignore[no-any-return]
class AsyncLocalObjectDetector(BaseLocalDetector): class AsyncLocalObjectDetector(BaseLocalDetector):
def async_send_input(self, tensor_input: np.ndarray, connection_id: str): def async_send_input(self, tensor_input: np.ndarray, connection_id: str) -> None:
tensor_input = self._transform_input(tensor_input) tensor_input = self._transform_input(tensor_input)
return self.detect_api.send_input(connection_id, tensor_input) self.detect_api.send_input(connection_id, tensor_input)
def async_receive_output(self): def async_receive_output(self) -> Any:
return self.detect_api.receive_output() return self.detect_api.receive_output()
class DetectorRunner(FrigateProcess): class DetectorRunner(FrigateProcess):
def __init__( def __init__(
self, self,
name, name: str,
detection_queue: Queue, detection_queue: Queue,
cameras: list[str], cameras: list[str],
avg_speed: Value, avg_speed: Any,
start_time: Value, start_time: Any,
config: FrigateConfig, config: FrigateConfig,
detector_config: BaseDetectorConfig, detector_config: BaseDetectorConfig,
stop_event: MpEvent, stop_event: MpEvent,
@ -129,11 +129,11 @@ class DetectorRunner(FrigateProcess):
self.start_time = start_time self.start_time = start_time
self.config = config self.config = config
self.detector_config = detector_config self.detector_config = detector_config
self.outputs: dict = {} self.outputs: dict[str, Any] = {}
def create_output_shm(self, name: str): def create_output_shm(self, name: str) -> None:
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False) out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
self.outputs[name] = {"shm": out_shm, "np": out_np} self.outputs[name] = {"shm": out_shm, "np": out_np}
def run(self) -> None: def run(self) -> None:
@ -155,8 +155,8 @@ class DetectorRunner(FrigateProcess):
connection_id, connection_id,
( (
1, 1,
self.detector_config.model.height, self.detector_config.model.height, # type: ignore[union-attr]
self.detector_config.model.width, self.detector_config.model.width, # type: ignore[union-attr]
3, 3,
), ),
) )
@ -187,11 +187,11 @@ class DetectorRunner(FrigateProcess):
class AsyncDetectorRunner(FrigateProcess): class AsyncDetectorRunner(FrigateProcess):
def __init__( def __init__(
self, self,
name, name: str,
detection_queue: Queue, detection_queue: Queue,
cameras: list[str], cameras: list[str],
avg_speed: Value, avg_speed: Any,
start_time: Value, start_time: Any,
config: FrigateConfig, config: FrigateConfig,
detector_config: BaseDetectorConfig, detector_config: BaseDetectorConfig,
stop_event: MpEvent, stop_event: MpEvent,
@ -203,15 +203,15 @@ class AsyncDetectorRunner(FrigateProcess):
self.start_time = start_time self.start_time = start_time
self.config = config self.config = config
self.detector_config = detector_config self.detector_config = detector_config
self.outputs: dict = {} self.outputs: dict[str, Any] = {}
self._frame_manager: SharedMemoryFrameManager | None = None self._frame_manager: SharedMemoryFrameManager | None = None
self._publisher: ObjectDetectorPublisher | None = None self._publisher: ObjectDetectorPublisher | None = None
self._detector: AsyncLocalObjectDetector | None = None self._detector: AsyncLocalObjectDetector | None = None
self.send_times = deque() self.send_times: deque[float] = deque()
def create_output_shm(self, name: str): def create_output_shm(self, name: str) -> None:
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False) out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
self.outputs[name] = {"shm": out_shm, "np": out_np} self.outputs[name] = {"shm": out_shm, "np": out_np}
def _detect_worker(self) -> None: def _detect_worker(self) -> None:
@ -222,12 +222,13 @@ class AsyncDetectorRunner(FrigateProcess):
except queue.Empty: except queue.Empty:
continue continue
assert self._frame_manager is not None
input_frame = self._frame_manager.get( input_frame = self._frame_manager.get(
connection_id, connection_id,
( (
1, 1,
self.detector_config.model.height, self.detector_config.model.height, # type: ignore[union-attr]
self.detector_config.model.width, self.detector_config.model.width, # type: ignore[union-attr]
3, 3,
), ),
) )
@ -238,11 +239,13 @@ class AsyncDetectorRunner(FrigateProcess):
# mark start time and send to accelerator # mark start time and send to accelerator
self.send_times.append(time.perf_counter()) self.send_times.append(time.perf_counter())
assert self._detector is not None
self._detector.async_send_input(input_frame, connection_id) self._detector.async_send_input(input_frame, connection_id)
def _result_worker(self) -> None: def _result_worker(self) -> None:
logger.info("Starting Result Worker Thread") logger.info("Starting Result Worker Thread")
while not self.stop_event.is_set(): while not self.stop_event.is_set():
assert self._detector is not None
connection_id, detections = self._detector.async_receive_output() connection_id, detections = self._detector.async_receive_output()
# Handle timeout case (queue.Empty) - just continue # Handle timeout case (queue.Empty) - just continue
@ -256,6 +259,7 @@ class AsyncDetectorRunner(FrigateProcess):
duration = time.perf_counter() - ts duration = time.perf_counter() - ts
# release input buffer # release input buffer
assert self._frame_manager is not None
self._frame_manager.close(connection_id) self._frame_manager.close(connection_id)
if connection_id not in self.outputs: if connection_id not in self.outputs:
@ -264,6 +268,7 @@ class AsyncDetectorRunner(FrigateProcess):
# write results and publish # write results and publish
if detections is not None: if detections is not None:
self.outputs[connection_id]["np"][:] = detections[:] self.outputs[connection_id]["np"][:] = detections[:]
assert self._publisher is not None
self._publisher.publish(connection_id) self._publisher.publish(connection_id)
# update timers # update timers
@ -330,11 +335,14 @@ class ObjectDetectProcess:
self.stop_event = stop_event self.stop_event = stop_event
self.start_or_restart() self.start_or_restart()
def stop(self): def stop(self) -> None:
# if the process has already exited on its own, just return # if the process has already exited on its own, just return
if self.detect_process and self.detect_process.exitcode: if self.detect_process and self.detect_process.exitcode:
return return
if self.detect_process is None:
return
logging.info("Waiting for detection process to exit gracefully...") logging.info("Waiting for detection process to exit gracefully...")
self.detect_process.join(timeout=30) self.detect_process.join(timeout=30)
if self.detect_process.exitcode is None: if self.detect_process.exitcode is None:
@ -343,8 +351,8 @@ class ObjectDetectProcess:
self.detect_process.join() self.detect_process.join()
logging.info("Detection process has exited...") logging.info("Detection process has exited...")
def start_or_restart(self): def start_or_restart(self) -> None:
self.detection_start.value = 0.0 self.detection_start.value = 0.0 # type: ignore[attr-defined]
if (self.detect_process is not None) and self.detect_process.is_alive(): if (self.detect_process is not None) and self.detect_process.is_alive():
self.stop() self.stop()
@ -389,17 +397,19 @@ class RemoteObjectDetector:
self.detection_queue = detection_queue self.detection_queue = detection_queue
self.stop_event = stop_event self.stop_event = stop_event
self.shm = UntrackedSharedMemory(name=self.name, create=False) self.shm = UntrackedSharedMemory(name=self.name, create=False)
self.np_shm = np.ndarray( self.np_shm: np.ndarray = np.ndarray(
(1, model_config.height, model_config.width, 3), (1, model_config.height, model_config.width, 3),
dtype=np.uint8, dtype=np.uint8,
buffer=self.shm.buf, buffer=self.shm.buf,
) )
self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False) self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf) self.out_np_shm: np.ndarray = np.ndarray(
(20, 6), dtype=np.float32, buffer=self.out_shm.buf
)
self.detector_subscriber = ObjectDetectorSubscriber(name) self.detector_subscriber = ObjectDetectorSubscriber(name)
def detect(self, tensor_input, threshold=0.4): def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
detections = [] detections: list = []
if self.stop_event.is_set(): if self.stop_event.is_set():
return detections return detections
@ -431,7 +441,7 @@ class RemoteObjectDetector:
self.fps.update() self.fps.update()
return detections return detections
def cleanup(self): def cleanup(self) -> None:
self.detector_subscriber.stop() self.detector_subscriber.stop()
self.shm.unlink() self.shm.unlink()
self.out_shm.unlink() self.out_shm.unlink()

View File

@ -13,10 +13,10 @@ class RequestStore:
A thread-safe hash-based response store that handles creating requests. A thread-safe hash-based response store that handles creating requests.
""" """
def __init__(self): def __init__(self) -> None:
self.request_counter = 0 self.request_counter = 0
self.request_counter_lock = threading.Lock() self.request_counter_lock = threading.Lock()
self.input_queue = queue.Queue() self.input_queue: queue.Queue[tuple[int, ndarray]] = queue.Queue()
def __get_request_id(self) -> int: def __get_request_id(self) -> int:
with self.request_counter_lock: with self.request_counter_lock:
@ -45,17 +45,17 @@ class ResponseStore:
their request's result appears. their request's result appears.
""" """
def __init__(self): def __init__(self) -> None:
self.responses = {} # Maps request_id -> (original_input, infer_results) self.responses: dict[int, ndarray] = {} # Maps request_id -> (original_input, infer_results)
self.lock = threading.Lock() self.lock = threading.Lock()
self.cond = threading.Condition(self.lock) self.cond = threading.Condition(self.lock)
def put(self, request_id: int, response: ndarray): def put(self, request_id: int, response: ndarray) -> None:
with self.cond: with self.cond:
self.responses[request_id] = response self.responses[request_id] = response
self.cond.notify_all() self.cond.notify_all()
def get(self, request_id: int, timeout=None) -> ndarray: def get(self, request_id: int, timeout: float | None = None) -> ndarray:
with self.cond: with self.cond:
if not self.cond.wait_for( if not self.cond.wait_for(
lambda: request_id in self.responses, timeout=timeout lambda: request_id in self.responses, timeout=timeout
@ -65,7 +65,7 @@ class ResponseStore:
return self.responses.pop(request_id) return self.responses.pop(request_id)
def tensor_transform(desired_shape: InputTensorEnum): def tensor_transform(desired_shape: InputTensorEnum) -> tuple[int, int, int, int] | None:
# Currently this function only supports BHWC permutations # Currently this function only supports BHWC permutations
if desired_shape == InputTensorEnum.nhwc: if desired_shape == InputTensorEnum.nhwc:
return None return None