mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 22:57:40 +03:00
Cleanup object detection mypy
This commit is contained in:
parent
706ba0c9c7
commit
f1c8c69448
@ -317,7 +317,7 @@ class MemryXDetector(DetectionApi):
|
||||
f"Failed to remove downloaded zip {zip_path}: {e}"
|
||||
)
|
||||
|
||||
def send_input(self, connection_id, tensor_input: np.ndarray):
|
||||
def send_input(self, connection_id, tensor_input: np.ndarray) -> None:
|
||||
"""Pre-process (if needed) and send frame to MemryX input queue"""
|
||||
if tensor_input is None:
|
||||
raise ValueError("[send_input] No image data provided for inference")
|
||||
|
||||
@ -50,7 +50,7 @@ ignore_errors = false
|
||||
[mypy-frigate.motion.*]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.object_detection]
|
||||
[mypy-frigate.object_detection.*]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.output]
|
||||
|
||||
@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from multiprocessing import Queue, Value
|
||||
from multiprocessing.synchronize import Event as MpEvent
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
@ -34,26 +35,25 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectDetector(ABC):
|
||||
@abstractmethod
|
||||
def detect(self, tensor_input, threshold: float = 0.4):
|
||||
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
|
||||
pass
|
||||
|
||||
|
||||
class BaseLocalDetector(ObjectDetector):
|
||||
def __init__(
|
||||
self,
|
||||
detector_config: BaseDetectorConfig = None,
|
||||
labels: str = None,
|
||||
stop_event: MpEvent = None,
|
||||
):
|
||||
detector_config: Optional[BaseDetectorConfig] = None,
|
||||
labels: Optional[str] = None,
|
||||
stop_event: Optional[MpEvent] = None,
|
||||
) -> None:
|
||||
self.fps = EventsPerSecond()
|
||||
if labels is None:
|
||||
self.labels = {}
|
||||
self.labels: dict[int, str] = {}
|
||||
else:
|
||||
self.labels = load_labels(labels)
|
||||
|
||||
if detector_config:
|
||||
if detector_config and detector_config.model:
|
||||
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
||||
|
||||
self.dtype = detector_config.model.input_dtype
|
||||
else:
|
||||
self.input_transform = None
|
||||
@ -77,10 +77,10 @@ class BaseLocalDetector(ObjectDetector):
|
||||
|
||||
return tensor_input
|
||||
|
||||
def detect(self, tensor_input: np.ndarray, threshold=0.4):
|
||||
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
|
||||
detections = []
|
||||
|
||||
raw_detections = self.detect_raw(tensor_input)
|
||||
raw_detections = self.detect_raw(tensor_input) # type: ignore[attr-defined]
|
||||
|
||||
for d in raw_detections:
|
||||
if int(d[0]) < 0 or int(d[0]) >= len(self.labels):
|
||||
@ -96,28 +96,28 @@ class BaseLocalDetector(ObjectDetector):
|
||||
|
||||
|
||||
class LocalObjectDetector(BaseLocalDetector):
|
||||
def detect_raw(self, tensor_input: np.ndarray):
|
||||
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
|
||||
tensor_input = self._transform_input(tensor_input)
|
||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||
return self.detect_api.detect_raw(tensor_input=tensor_input) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class AsyncLocalObjectDetector(BaseLocalDetector):
|
||||
def async_send_input(self, tensor_input: np.ndarray, connection_id: str):
|
||||
def async_send_input(self, tensor_input: np.ndarray, connection_id: str) -> None:
|
||||
tensor_input = self._transform_input(tensor_input)
|
||||
return self.detect_api.send_input(connection_id, tensor_input)
|
||||
self.detect_api.send_input(connection_id, tensor_input)
|
||||
|
||||
def async_receive_output(self):
|
||||
def async_receive_output(self) -> Any:
|
||||
return self.detect_api.receive_output()
|
||||
|
||||
|
||||
class DetectorRunner(FrigateProcess):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
name: str,
|
||||
detection_queue: Queue,
|
||||
cameras: list[str],
|
||||
avg_speed: Value,
|
||||
start_time: Value,
|
||||
avg_speed: Any,
|
||||
start_time: Any,
|
||||
config: FrigateConfig,
|
||||
detector_config: BaseDetectorConfig,
|
||||
stop_event: MpEvent,
|
||||
@ -129,11 +129,11 @@ class DetectorRunner(FrigateProcess):
|
||||
self.start_time = start_time
|
||||
self.config = config
|
||||
self.detector_config = detector_config
|
||||
self.outputs: dict = {}
|
||||
self.outputs: dict[str, Any] = {}
|
||||
|
||||
def create_output_shm(self, name: str):
|
||||
def create_output_shm(self, name: str) -> None:
|
||||
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
self.outputs[name] = {"shm": out_shm, "np": out_np}
|
||||
|
||||
def run(self) -> None:
|
||||
@ -155,8 +155,8 @@ class DetectorRunner(FrigateProcess):
|
||||
connection_id,
|
||||
(
|
||||
1,
|
||||
self.detector_config.model.height,
|
||||
self.detector_config.model.width,
|
||||
self.detector_config.model.height, # type: ignore[union-attr]
|
||||
self.detector_config.model.width, # type: ignore[union-attr]
|
||||
3,
|
||||
),
|
||||
)
|
||||
@ -187,11 +187,11 @@ class DetectorRunner(FrigateProcess):
|
||||
class AsyncDetectorRunner(FrigateProcess):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
name: str,
|
||||
detection_queue: Queue,
|
||||
cameras: list[str],
|
||||
avg_speed: Value,
|
||||
start_time: Value,
|
||||
avg_speed: Any,
|
||||
start_time: Any,
|
||||
config: FrigateConfig,
|
||||
detector_config: BaseDetectorConfig,
|
||||
stop_event: MpEvent,
|
||||
@ -203,15 +203,15 @@ class AsyncDetectorRunner(FrigateProcess):
|
||||
self.start_time = start_time
|
||||
self.config = config
|
||||
self.detector_config = detector_config
|
||||
self.outputs: dict = {}
|
||||
self.outputs: dict[str, Any] = {}
|
||||
self._frame_manager: SharedMemoryFrameManager | None = None
|
||||
self._publisher: ObjectDetectorPublisher | None = None
|
||||
self._detector: AsyncLocalObjectDetector | None = None
|
||||
self.send_times = deque()
|
||||
self.send_times: deque[float] = deque()
|
||||
|
||||
def create_output_shm(self, name: str):
|
||||
def create_output_shm(self, name: str) -> None:
|
||||
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
self.outputs[name] = {"shm": out_shm, "np": out_np}
|
||||
|
||||
def _detect_worker(self) -> None:
|
||||
@ -222,12 +222,13 @@ class AsyncDetectorRunner(FrigateProcess):
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
assert self._frame_manager is not None
|
||||
input_frame = self._frame_manager.get(
|
||||
connection_id,
|
||||
(
|
||||
1,
|
||||
self.detector_config.model.height,
|
||||
self.detector_config.model.width,
|
||||
self.detector_config.model.height, # type: ignore[union-attr]
|
||||
self.detector_config.model.width, # type: ignore[union-attr]
|
||||
3,
|
||||
),
|
||||
)
|
||||
@ -238,11 +239,13 @@ class AsyncDetectorRunner(FrigateProcess):
|
||||
|
||||
# mark start time and send to accelerator
|
||||
self.send_times.append(time.perf_counter())
|
||||
assert self._detector is not None
|
||||
self._detector.async_send_input(input_frame, connection_id)
|
||||
|
||||
def _result_worker(self) -> None:
|
||||
logger.info("Starting Result Worker Thread")
|
||||
while not self.stop_event.is_set():
|
||||
assert self._detector is not None
|
||||
connection_id, detections = self._detector.async_receive_output()
|
||||
|
||||
# Handle timeout case (queue.Empty) - just continue
|
||||
@ -256,6 +259,7 @@ class AsyncDetectorRunner(FrigateProcess):
|
||||
duration = time.perf_counter() - ts
|
||||
|
||||
# release input buffer
|
||||
assert self._frame_manager is not None
|
||||
self._frame_manager.close(connection_id)
|
||||
|
||||
if connection_id not in self.outputs:
|
||||
@ -264,6 +268,7 @@ class AsyncDetectorRunner(FrigateProcess):
|
||||
# write results and publish
|
||||
if detections is not None:
|
||||
self.outputs[connection_id]["np"][:] = detections[:]
|
||||
assert self._publisher is not None
|
||||
self._publisher.publish(connection_id)
|
||||
|
||||
# update timers
|
||||
@ -330,11 +335,14 @@ class ObjectDetectProcess:
|
||||
self.stop_event = stop_event
|
||||
self.start_or_restart()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
# if the process has already exited on its own, just return
|
||||
if self.detect_process and self.detect_process.exitcode:
|
||||
return
|
||||
|
||||
if self.detect_process is None:
|
||||
return
|
||||
|
||||
logging.info("Waiting for detection process to exit gracefully...")
|
||||
self.detect_process.join(timeout=30)
|
||||
if self.detect_process.exitcode is None:
|
||||
@ -343,8 +351,8 @@ class ObjectDetectProcess:
|
||||
self.detect_process.join()
|
||||
logging.info("Detection process has exited...")
|
||||
|
||||
def start_or_restart(self):
|
||||
self.detection_start.value = 0.0
|
||||
def start_or_restart(self) -> None:
|
||||
self.detection_start.value = 0.0 # type: ignore[attr-defined]
|
||||
if (self.detect_process is not None) and self.detect_process.is_alive():
|
||||
self.stop()
|
||||
|
||||
@ -389,17 +397,19 @@ class RemoteObjectDetector:
|
||||
self.detection_queue = detection_queue
|
||||
self.stop_event = stop_event
|
||||
self.shm = UntrackedSharedMemory(name=self.name, create=False)
|
||||
self.np_shm = np.ndarray(
|
||||
self.np_shm: np.ndarray = np.ndarray(
|
||||
(1, model_config.height, model_config.width, 3),
|
||||
dtype=np.uint8,
|
||||
buffer=self.shm.buf,
|
||||
)
|
||||
self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
|
||||
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
|
||||
self.out_np_shm: np.ndarray = np.ndarray(
|
||||
(20, 6), dtype=np.float32, buffer=self.out_shm.buf
|
||||
)
|
||||
self.detector_subscriber = ObjectDetectorSubscriber(name)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
|
||||
detections: list = []
|
||||
|
||||
if self.stop_event.is_set():
|
||||
return detections
|
||||
@ -431,7 +441,7 @@ class RemoteObjectDetector:
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def cleanup(self):
|
||||
def cleanup(self) -> None:
|
||||
self.detector_subscriber.stop()
|
||||
self.shm.unlink()
|
||||
self.out_shm.unlink()
|
||||
|
||||
@ -13,10 +13,10 @@ class RequestStore:
|
||||
A thread-safe hash-based response store that handles creating requests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.request_counter = 0
|
||||
self.request_counter_lock = threading.Lock()
|
||||
self.input_queue = queue.Queue()
|
||||
self.input_queue: queue.Queue[tuple[int, ndarray]] = queue.Queue()
|
||||
|
||||
def __get_request_id(self) -> int:
|
||||
with self.request_counter_lock:
|
||||
@ -45,17 +45,17 @@ class ResponseStore:
|
||||
their request's result appears.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.responses = {} # Maps request_id -> (original_input, infer_results)
|
||||
def __init__(self) -> None:
|
||||
self.responses: dict[int, ndarray] = {} # Maps request_id -> (original_input, infer_results)
|
||||
self.lock = threading.Lock()
|
||||
self.cond = threading.Condition(self.lock)
|
||||
|
||||
def put(self, request_id: int, response: ndarray):
|
||||
def put(self, request_id: int, response: ndarray) -> None:
|
||||
with self.cond:
|
||||
self.responses[request_id] = response
|
||||
self.cond.notify_all()
|
||||
|
||||
def get(self, request_id: int, timeout=None) -> ndarray:
|
||||
def get(self, request_id: int, timeout: float | None = None) -> ndarray:
|
||||
with self.cond:
|
||||
if not self.cond.wait_for(
|
||||
lambda: request_id in self.responses, timeout=timeout
|
||||
@ -65,7 +65,7 @@ class ResponseStore:
|
||||
return self.responses.pop(request_id)
|
||||
|
||||
|
||||
def tensor_transform(desired_shape: InputTensorEnum):
|
||||
def tensor_transform(desired_shape: InputTensorEnum) -> tuple[int, int, int, int] | None:
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == InputTensorEnum.nhwc:
|
||||
return None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user