Cleanup object detection mypy

This commit is contained in:
Nicolas Mowen 2026-03-25 11:18:58 -06:00
parent 706ba0c9c7
commit f1c8c69448
4 changed files with 59 additions and 49 deletions

View File

@ -317,7 +317,7 @@ class MemryXDetector(DetectionApi):
f"Failed to remove downloaded zip {zip_path}: {e}"
)
def send_input(self, connection_id, tensor_input: np.ndarray):
def send_input(self, connection_id, tensor_input: np.ndarray) -> None:
"""Pre-process (if needed) and send frame to MemryX input queue"""
if tensor_input is None:
raise ValueError("[send_input] No image data provided for inference")

View File

@ -50,7 +50,7 @@ ignore_errors = false
[mypy-frigate.motion.*]
ignore_errors = false
[mypy-frigate.object_detection]
[mypy-frigate.object_detection.*]
ignore_errors = false
[mypy-frigate.output]

View File

@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
from collections import deque
from multiprocessing import Queue, Value
from multiprocessing.synchronize import Event as MpEvent
from typing import Any, Optional
import numpy as np
import zmq
@ -34,26 +35,25 @@ logger = logging.getLogger(__name__)
class ObjectDetector(ABC):
@abstractmethod
def detect(self, tensor_input, threshold: float = 0.4):
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
pass
class BaseLocalDetector(ObjectDetector):
def __init__(
self,
detector_config: BaseDetectorConfig = None,
labels: str = None,
stop_event: MpEvent = None,
):
detector_config: Optional[BaseDetectorConfig] = None,
labels: Optional[str] = None,
stop_event: Optional[MpEvent] = None,
) -> None:
self.fps = EventsPerSecond()
if labels is None:
self.labels = {}
self.labels: dict[int, str] = {}
else:
self.labels = load_labels(labels)
if detector_config:
if detector_config and detector_config.model:
self.input_transform = tensor_transform(detector_config.model.input_tensor)
self.dtype = detector_config.model.input_dtype
else:
self.input_transform = None
@ -77,10 +77,10 @@ class BaseLocalDetector(ObjectDetector):
return tensor_input
def detect(self, tensor_input: np.ndarray, threshold=0.4):
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
detections = []
raw_detections = self.detect_raw(tensor_input)
raw_detections = self.detect_raw(tensor_input) # type: ignore[attr-defined]
for d in raw_detections:
if int(d[0]) < 0 or int(d[0]) >= len(self.labels):
@ -96,28 +96,28 @@ class BaseLocalDetector(ObjectDetector):
class LocalObjectDetector(BaseLocalDetector):
def detect_raw(self, tensor_input: np.ndarray):
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
tensor_input = self._transform_input(tensor_input)
return self.detect_api.detect_raw(tensor_input=tensor_input)
return self.detect_api.detect_raw(tensor_input=tensor_input) # type: ignore[no-any-return]
class AsyncLocalObjectDetector(BaseLocalDetector):
def async_send_input(self, tensor_input: np.ndarray, connection_id: str):
def async_send_input(self, tensor_input: np.ndarray, connection_id: str) -> None:
tensor_input = self._transform_input(tensor_input)
return self.detect_api.send_input(connection_id, tensor_input)
self.detect_api.send_input(connection_id, tensor_input)
def async_receive_output(self):
def async_receive_output(self) -> Any:
return self.detect_api.receive_output()
class DetectorRunner(FrigateProcess):
def __init__(
self,
name,
name: str,
detection_queue: Queue,
cameras: list[str],
avg_speed: Value,
start_time: Value,
avg_speed: Any,
start_time: Any,
config: FrigateConfig,
detector_config: BaseDetectorConfig,
stop_event: MpEvent,
@ -129,11 +129,11 @@ class DetectorRunner(FrigateProcess):
self.start_time = start_time
self.config = config
self.detector_config = detector_config
self.outputs: dict = {}
self.outputs: dict[str, Any] = {}
def create_output_shm(self, name: str):
def create_output_shm(self, name: str) -> None:
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
self.outputs[name] = {"shm": out_shm, "np": out_np}
def run(self) -> None:
@ -155,8 +155,8 @@ class DetectorRunner(FrigateProcess):
connection_id,
(
1,
self.detector_config.model.height,
self.detector_config.model.width,
self.detector_config.model.height, # type: ignore[union-attr]
self.detector_config.model.width, # type: ignore[union-attr]
3,
),
)
@ -187,11 +187,11 @@ class DetectorRunner(FrigateProcess):
class AsyncDetectorRunner(FrigateProcess):
def __init__(
self,
name,
name: str,
detection_queue: Queue,
cameras: list[str],
avg_speed: Value,
start_time: Value,
avg_speed: Any,
start_time: Any,
config: FrigateConfig,
detector_config: BaseDetectorConfig,
stop_event: MpEvent,
@ -203,15 +203,15 @@ class AsyncDetectorRunner(FrigateProcess):
self.start_time = start_time
self.config = config
self.detector_config = detector_config
self.outputs: dict = {}
self.outputs: dict[str, Any] = {}
self._frame_manager: SharedMemoryFrameManager | None = None
self._publisher: ObjectDetectorPublisher | None = None
self._detector: AsyncLocalObjectDetector | None = None
self.send_times = deque()
self.send_times: deque[float] = deque()
def create_output_shm(self, name: str):
def create_output_shm(self, name: str) -> None:
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
self.outputs[name] = {"shm": out_shm, "np": out_np}
def _detect_worker(self) -> None:
@ -222,12 +222,13 @@ class AsyncDetectorRunner(FrigateProcess):
except queue.Empty:
continue
assert self._frame_manager is not None
input_frame = self._frame_manager.get(
connection_id,
(
1,
self.detector_config.model.height,
self.detector_config.model.width,
self.detector_config.model.height, # type: ignore[union-attr]
self.detector_config.model.width, # type: ignore[union-attr]
3,
),
)
@ -238,11 +239,13 @@ class AsyncDetectorRunner(FrigateProcess):
# mark start time and send to accelerator
self.send_times.append(time.perf_counter())
assert self._detector is not None
self._detector.async_send_input(input_frame, connection_id)
def _result_worker(self) -> None:
logger.info("Starting Result Worker Thread")
while not self.stop_event.is_set():
assert self._detector is not None
connection_id, detections = self._detector.async_receive_output()
# Handle timeout case (queue.Empty) - just continue
@ -256,6 +259,7 @@ class AsyncDetectorRunner(FrigateProcess):
duration = time.perf_counter() - ts
# release input buffer
assert self._frame_manager is not None
self._frame_manager.close(connection_id)
if connection_id not in self.outputs:
@ -264,6 +268,7 @@ class AsyncDetectorRunner(FrigateProcess):
# write results and publish
if detections is not None:
self.outputs[connection_id]["np"][:] = detections[:]
assert self._publisher is not None
self._publisher.publish(connection_id)
# update timers
@ -330,11 +335,14 @@ class ObjectDetectProcess:
self.stop_event = stop_event
self.start_or_restart()
def stop(self):
def stop(self) -> None:
# if the process has already exited on its own, just return
if self.detect_process and self.detect_process.exitcode:
return
if self.detect_process is None:
return
logging.info("Waiting for detection process to exit gracefully...")
self.detect_process.join(timeout=30)
if self.detect_process.exitcode is None:
@ -343,8 +351,8 @@ class ObjectDetectProcess:
self.detect_process.join()
logging.info("Detection process has exited...")
def start_or_restart(self):
self.detection_start.value = 0.0
def start_or_restart(self) -> None:
self.detection_start.value = 0.0 # type: ignore[attr-defined]
if (self.detect_process is not None) and self.detect_process.is_alive():
self.stop()
@ -389,17 +397,19 @@ class RemoteObjectDetector:
self.detection_queue = detection_queue
self.stop_event = stop_event
self.shm = UntrackedSharedMemory(name=self.name, create=False)
self.np_shm = np.ndarray(
self.np_shm: np.ndarray = np.ndarray(
(1, model_config.height, model_config.width, 3),
dtype=np.uint8,
buffer=self.shm.buf,
)
self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
self.out_np_shm: np.ndarray = np.ndarray(
(20, 6), dtype=np.float32, buffer=self.out_shm.buf
)
self.detector_subscriber = ObjectDetectorSubscriber(name)
def detect(self, tensor_input, threshold=0.4):
detections = []
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
detections: list = []
if self.stop_event.is_set():
return detections
@ -431,7 +441,7 @@ class RemoteObjectDetector:
self.fps.update()
return detections
def cleanup(self):
def cleanup(self) -> None:
self.detector_subscriber.stop()
self.shm.unlink()
self.out_shm.unlink()

View File

@ -13,10 +13,10 @@ class RequestStore:
A thread-safe hash-based response store that handles creating requests.
"""
def __init__(self):
def __init__(self) -> None:
self.request_counter = 0
self.request_counter_lock = threading.Lock()
self.input_queue = queue.Queue()
self.input_queue: queue.Queue[tuple[int, ndarray]] = queue.Queue()
def __get_request_id(self) -> int:
with self.request_counter_lock:
@ -45,17 +45,17 @@ class ResponseStore:
their request's result appears.
"""
def __init__(self):
self.responses = {} # Maps request_id -> (original_input, infer_results)
def __init__(self) -> None:
self.responses: dict[int, ndarray] = {} # Maps request_id -> (original_input, infer_results)
self.lock = threading.Lock()
self.cond = threading.Condition(self.lock)
def put(self, request_id: int, response: ndarray):
def put(self, request_id: int, response: ndarray) -> None:
with self.cond:
self.responses[request_id] = response
self.cond.notify_all()
def get(self, request_id: int, timeout=None) -> ndarray:
def get(self, request_id: int, timeout: float | None = None) -> ndarray:
with self.cond:
if not self.cond.wait_for(
lambda: request_id in self.responses, timeout=timeout
@ -65,7 +65,7 @@ class ResponseStore:
return self.responses.pop(request_id)
def tensor_transform(desired_shape: InputTensorEnum):
def tensor_transform(desired_shape: InputTensorEnum) -> tuple[int, int, int, int] | None:
# Currently this function only supports BHWC permutations
if desired_shape == InputTensorEnum.nhwc:
return None