mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 22:57:40 +03:00
Cleanup object detection mypy
This commit is contained in:
parent
706ba0c9c7
commit
f1c8c69448
@ -317,7 +317,7 @@ class MemryXDetector(DetectionApi):
|
|||||||
f"Failed to remove downloaded zip {zip_path}: {e}"
|
f"Failed to remove downloaded zip {zip_path}: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_input(self, connection_id, tensor_input: np.ndarray):
|
def send_input(self, connection_id, tensor_input: np.ndarray) -> None:
|
||||||
"""Pre-process (if needed) and send frame to MemryX input queue"""
|
"""Pre-process (if needed) and send frame to MemryX input queue"""
|
||||||
if tensor_input is None:
|
if tensor_input is None:
|
||||||
raise ValueError("[send_input] No image data provided for inference")
|
raise ValueError("[send_input] No image data provided for inference")
|
||||||
|
|||||||
@ -50,7 +50,7 @@ ignore_errors = false
|
|||||||
[mypy-frigate.motion.*]
|
[mypy-frigate.motion.*]
|
||||||
ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
[mypy-frigate.object_detection]
|
[mypy-frigate.object_detection.*]
|
||||||
ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
[mypy-frigate.output]
|
[mypy-frigate.output]
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from multiprocessing import Queue, Value
|
from multiprocessing import Queue, Value
|
||||||
from multiprocessing.synchronize import Event as MpEvent
|
from multiprocessing.synchronize import Event as MpEvent
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
@ -34,26 +35,25 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ObjectDetector(ABC):
|
class ObjectDetector(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def detect(self, tensor_input, threshold: float = 0.4):
|
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseLocalDetector(ObjectDetector):
|
class BaseLocalDetector(ObjectDetector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector_config: BaseDetectorConfig = None,
|
detector_config: Optional[BaseDetectorConfig] = None,
|
||||||
labels: str = None,
|
labels: Optional[str] = None,
|
||||||
stop_event: MpEvent = None,
|
stop_event: Optional[MpEvent] = None,
|
||||||
):
|
) -> None:
|
||||||
self.fps = EventsPerSecond()
|
self.fps = EventsPerSecond()
|
||||||
if labels is None:
|
if labels is None:
|
||||||
self.labels = {}
|
self.labels: dict[int, str] = {}
|
||||||
else:
|
else:
|
||||||
self.labels = load_labels(labels)
|
self.labels = load_labels(labels)
|
||||||
|
|
||||||
if detector_config:
|
if detector_config and detector_config.model:
|
||||||
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
||||||
|
|
||||||
self.dtype = detector_config.model.input_dtype
|
self.dtype = detector_config.model.input_dtype
|
||||||
else:
|
else:
|
||||||
self.input_transform = None
|
self.input_transform = None
|
||||||
@ -77,10 +77,10 @@ class BaseLocalDetector(ObjectDetector):
|
|||||||
|
|
||||||
return tensor_input
|
return tensor_input
|
||||||
|
|
||||||
def detect(self, tensor_input: np.ndarray, threshold=0.4):
|
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
|
||||||
detections = []
|
detections = []
|
||||||
|
|
||||||
raw_detections = self.detect_raw(tensor_input)
|
raw_detections = self.detect_raw(tensor_input) # type: ignore[attr-defined]
|
||||||
|
|
||||||
for d in raw_detections:
|
for d in raw_detections:
|
||||||
if int(d[0]) < 0 or int(d[0]) >= len(self.labels):
|
if int(d[0]) < 0 or int(d[0]) >= len(self.labels):
|
||||||
@ -96,28 +96,28 @@ class BaseLocalDetector(ObjectDetector):
|
|||||||
|
|
||||||
|
|
||||||
class LocalObjectDetector(BaseLocalDetector):
|
class LocalObjectDetector(BaseLocalDetector):
|
||||||
def detect_raw(self, tensor_input: np.ndarray):
|
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
|
||||||
tensor_input = self._transform_input(tensor_input)
|
tensor_input = self._transform_input(tensor_input)
|
||||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
return self.detect_api.detect_raw(tensor_input=tensor_input) # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
class AsyncLocalObjectDetector(BaseLocalDetector):
|
class AsyncLocalObjectDetector(BaseLocalDetector):
|
||||||
def async_send_input(self, tensor_input: np.ndarray, connection_id: str):
|
def async_send_input(self, tensor_input: np.ndarray, connection_id: str) -> None:
|
||||||
tensor_input = self._transform_input(tensor_input)
|
tensor_input = self._transform_input(tensor_input)
|
||||||
return self.detect_api.send_input(connection_id, tensor_input)
|
self.detect_api.send_input(connection_id, tensor_input)
|
||||||
|
|
||||||
def async_receive_output(self):
|
def async_receive_output(self) -> Any:
|
||||||
return self.detect_api.receive_output()
|
return self.detect_api.receive_output()
|
||||||
|
|
||||||
|
|
||||||
class DetectorRunner(FrigateProcess):
|
class DetectorRunner(FrigateProcess):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name,
|
name: str,
|
||||||
detection_queue: Queue,
|
detection_queue: Queue,
|
||||||
cameras: list[str],
|
cameras: list[str],
|
||||||
avg_speed: Value,
|
avg_speed: Any,
|
||||||
start_time: Value,
|
start_time: Any,
|
||||||
config: FrigateConfig,
|
config: FrigateConfig,
|
||||||
detector_config: BaseDetectorConfig,
|
detector_config: BaseDetectorConfig,
|
||||||
stop_event: MpEvent,
|
stop_event: MpEvent,
|
||||||
@ -129,11 +129,11 @@ class DetectorRunner(FrigateProcess):
|
|||||||
self.start_time = start_time
|
self.start_time = start_time
|
||||||
self.config = config
|
self.config = config
|
||||||
self.detector_config = detector_config
|
self.detector_config = detector_config
|
||||||
self.outputs: dict = {}
|
self.outputs: dict[str, Any] = {}
|
||||||
|
|
||||||
def create_output_shm(self, name: str):
|
def create_output_shm(self, name: str) -> None:
|
||||||
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
||||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||||
self.outputs[name] = {"shm": out_shm, "np": out_np}
|
self.outputs[name] = {"shm": out_shm, "np": out_np}
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
@ -155,8 +155,8 @@ class DetectorRunner(FrigateProcess):
|
|||||||
connection_id,
|
connection_id,
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.detector_config.model.height,
|
self.detector_config.model.height, # type: ignore[union-attr]
|
||||||
self.detector_config.model.width,
|
self.detector_config.model.width, # type: ignore[union-attr]
|
||||||
3,
|
3,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -187,11 +187,11 @@ class DetectorRunner(FrigateProcess):
|
|||||||
class AsyncDetectorRunner(FrigateProcess):
|
class AsyncDetectorRunner(FrigateProcess):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name,
|
name: str,
|
||||||
detection_queue: Queue,
|
detection_queue: Queue,
|
||||||
cameras: list[str],
|
cameras: list[str],
|
||||||
avg_speed: Value,
|
avg_speed: Any,
|
||||||
start_time: Value,
|
start_time: Any,
|
||||||
config: FrigateConfig,
|
config: FrigateConfig,
|
||||||
detector_config: BaseDetectorConfig,
|
detector_config: BaseDetectorConfig,
|
||||||
stop_event: MpEvent,
|
stop_event: MpEvent,
|
||||||
@ -203,15 +203,15 @@ class AsyncDetectorRunner(FrigateProcess):
|
|||||||
self.start_time = start_time
|
self.start_time = start_time
|
||||||
self.config = config
|
self.config = config
|
||||||
self.detector_config = detector_config
|
self.detector_config = detector_config
|
||||||
self.outputs: dict = {}
|
self.outputs: dict[str, Any] = {}
|
||||||
self._frame_manager: SharedMemoryFrameManager | None = None
|
self._frame_manager: SharedMemoryFrameManager | None = None
|
||||||
self._publisher: ObjectDetectorPublisher | None = None
|
self._publisher: ObjectDetectorPublisher | None = None
|
||||||
self._detector: AsyncLocalObjectDetector | None = None
|
self._detector: AsyncLocalObjectDetector | None = None
|
||||||
self.send_times = deque()
|
self.send_times: deque[float] = deque()
|
||||||
|
|
||||||
def create_output_shm(self, name: str):
|
def create_output_shm(self, name: str) -> None:
|
||||||
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
|
||||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
out_np: np.ndarray = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||||
self.outputs[name] = {"shm": out_shm, "np": out_np}
|
self.outputs[name] = {"shm": out_shm, "np": out_np}
|
||||||
|
|
||||||
def _detect_worker(self) -> None:
|
def _detect_worker(self) -> None:
|
||||||
@ -222,12 +222,13 @@ class AsyncDetectorRunner(FrigateProcess):
|
|||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
assert self._frame_manager is not None
|
||||||
input_frame = self._frame_manager.get(
|
input_frame = self._frame_manager.get(
|
||||||
connection_id,
|
connection_id,
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
self.detector_config.model.height,
|
self.detector_config.model.height, # type: ignore[union-attr]
|
||||||
self.detector_config.model.width,
|
self.detector_config.model.width, # type: ignore[union-attr]
|
||||||
3,
|
3,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -238,11 +239,13 @@ class AsyncDetectorRunner(FrigateProcess):
|
|||||||
|
|
||||||
# mark start time and send to accelerator
|
# mark start time and send to accelerator
|
||||||
self.send_times.append(time.perf_counter())
|
self.send_times.append(time.perf_counter())
|
||||||
|
assert self._detector is not None
|
||||||
self._detector.async_send_input(input_frame, connection_id)
|
self._detector.async_send_input(input_frame, connection_id)
|
||||||
|
|
||||||
def _result_worker(self) -> None:
|
def _result_worker(self) -> None:
|
||||||
logger.info("Starting Result Worker Thread")
|
logger.info("Starting Result Worker Thread")
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
|
assert self._detector is not None
|
||||||
connection_id, detections = self._detector.async_receive_output()
|
connection_id, detections = self._detector.async_receive_output()
|
||||||
|
|
||||||
# Handle timeout case (queue.Empty) - just continue
|
# Handle timeout case (queue.Empty) - just continue
|
||||||
@ -256,6 +259,7 @@ class AsyncDetectorRunner(FrigateProcess):
|
|||||||
duration = time.perf_counter() - ts
|
duration = time.perf_counter() - ts
|
||||||
|
|
||||||
# release input buffer
|
# release input buffer
|
||||||
|
assert self._frame_manager is not None
|
||||||
self._frame_manager.close(connection_id)
|
self._frame_manager.close(connection_id)
|
||||||
|
|
||||||
if connection_id not in self.outputs:
|
if connection_id not in self.outputs:
|
||||||
@ -264,6 +268,7 @@ class AsyncDetectorRunner(FrigateProcess):
|
|||||||
# write results and publish
|
# write results and publish
|
||||||
if detections is not None:
|
if detections is not None:
|
||||||
self.outputs[connection_id]["np"][:] = detections[:]
|
self.outputs[connection_id]["np"][:] = detections[:]
|
||||||
|
assert self._publisher is not None
|
||||||
self._publisher.publish(connection_id)
|
self._publisher.publish(connection_id)
|
||||||
|
|
||||||
# update timers
|
# update timers
|
||||||
@ -330,11 +335,14 @@ class ObjectDetectProcess:
|
|||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
self.start_or_restart()
|
self.start_or_restart()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
# if the process has already exited on its own, just return
|
# if the process has already exited on its own, just return
|
||||||
if self.detect_process and self.detect_process.exitcode:
|
if self.detect_process and self.detect_process.exitcode:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.detect_process is None:
|
||||||
|
return
|
||||||
|
|
||||||
logging.info("Waiting for detection process to exit gracefully...")
|
logging.info("Waiting for detection process to exit gracefully...")
|
||||||
self.detect_process.join(timeout=30)
|
self.detect_process.join(timeout=30)
|
||||||
if self.detect_process.exitcode is None:
|
if self.detect_process.exitcode is None:
|
||||||
@ -343,8 +351,8 @@ class ObjectDetectProcess:
|
|||||||
self.detect_process.join()
|
self.detect_process.join()
|
||||||
logging.info("Detection process has exited...")
|
logging.info("Detection process has exited...")
|
||||||
|
|
||||||
def start_or_restart(self):
|
def start_or_restart(self) -> None:
|
||||||
self.detection_start.value = 0.0
|
self.detection_start.value = 0.0 # type: ignore[attr-defined]
|
||||||
if (self.detect_process is not None) and self.detect_process.is_alive():
|
if (self.detect_process is not None) and self.detect_process.is_alive():
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
@ -389,17 +397,19 @@ class RemoteObjectDetector:
|
|||||||
self.detection_queue = detection_queue
|
self.detection_queue = detection_queue
|
||||||
self.stop_event = stop_event
|
self.stop_event = stop_event
|
||||||
self.shm = UntrackedSharedMemory(name=self.name, create=False)
|
self.shm = UntrackedSharedMemory(name=self.name, create=False)
|
||||||
self.np_shm = np.ndarray(
|
self.np_shm: np.ndarray = np.ndarray(
|
||||||
(1, model_config.height, model_config.width, 3),
|
(1, model_config.height, model_config.width, 3),
|
||||||
dtype=np.uint8,
|
dtype=np.uint8,
|
||||||
buffer=self.shm.buf,
|
buffer=self.shm.buf,
|
||||||
)
|
)
|
||||||
self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
|
self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
|
||||||
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
|
self.out_np_shm: np.ndarray = np.ndarray(
|
||||||
|
(20, 6), dtype=np.float32, buffer=self.out_shm.buf
|
||||||
|
)
|
||||||
self.detector_subscriber = ObjectDetectorSubscriber(name)
|
self.detector_subscriber = ObjectDetectorSubscriber(name)
|
||||||
|
|
||||||
def detect(self, tensor_input, threshold=0.4):
|
def detect(self, tensor_input: np.ndarray, threshold: float = 0.4) -> list:
|
||||||
detections = []
|
detections: list = []
|
||||||
|
|
||||||
if self.stop_event.is_set():
|
if self.stop_event.is_set():
|
||||||
return detections
|
return detections
|
||||||
@ -431,7 +441,7 @@ class RemoteObjectDetector:
|
|||||||
self.fps.update()
|
self.fps.update()
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self) -> None:
|
||||||
self.detector_subscriber.stop()
|
self.detector_subscriber.stop()
|
||||||
self.shm.unlink()
|
self.shm.unlink()
|
||||||
self.out_shm.unlink()
|
self.out_shm.unlink()
|
||||||
|
|||||||
@ -13,10 +13,10 @@ class RequestStore:
|
|||||||
A thread-safe hash-based response store that handles creating requests.
|
A thread-safe hash-based response store that handles creating requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.request_counter = 0
|
self.request_counter = 0
|
||||||
self.request_counter_lock = threading.Lock()
|
self.request_counter_lock = threading.Lock()
|
||||||
self.input_queue = queue.Queue()
|
self.input_queue: queue.Queue[tuple[int, ndarray]] = queue.Queue()
|
||||||
|
|
||||||
def __get_request_id(self) -> int:
|
def __get_request_id(self) -> int:
|
||||||
with self.request_counter_lock:
|
with self.request_counter_lock:
|
||||||
@ -45,17 +45,17 @@ class ResponseStore:
|
|||||||
their request's result appears.
|
their request's result appears.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.responses = {} # Maps request_id -> (original_input, infer_results)
|
self.responses: dict[int, ndarray] = {} # Maps request_id -> (original_input, infer_results)
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.cond = threading.Condition(self.lock)
|
self.cond = threading.Condition(self.lock)
|
||||||
|
|
||||||
def put(self, request_id: int, response: ndarray):
|
def put(self, request_id: int, response: ndarray) -> None:
|
||||||
with self.cond:
|
with self.cond:
|
||||||
self.responses[request_id] = response
|
self.responses[request_id] = response
|
||||||
self.cond.notify_all()
|
self.cond.notify_all()
|
||||||
|
|
||||||
def get(self, request_id: int, timeout=None) -> ndarray:
|
def get(self, request_id: int, timeout: float | None = None) -> ndarray:
|
||||||
with self.cond:
|
with self.cond:
|
||||||
if not self.cond.wait_for(
|
if not self.cond.wait_for(
|
||||||
lambda: request_id in self.responses, timeout=timeout
|
lambda: request_id in self.responses, timeout=timeout
|
||||||
@ -65,7 +65,7 @@ class ResponseStore:
|
|||||||
return self.responses.pop(request_id)
|
return self.responses.pop(request_id)
|
||||||
|
|
||||||
|
|
||||||
def tensor_transform(desired_shape: InputTensorEnum):
|
def tensor_transform(desired_shape: InputTensorEnum) -> tuple[int, int, int, int] | None:
|
||||||
# Currently this function only supports BHWC permutations
|
# Currently this function only supports BHWC permutations
|
||||||
if desired_shape == InputTensorEnum.nhwc:
|
if desired_shape == InputTensorEnum.nhwc:
|
||||||
return None
|
return None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user