diff --git a/frigate/detectors/plugins/meta.py b/frigate/detectors/plugins/meta.py index 9afd334ce..3b01d5fdf 100644 --- a/frigate/detectors/plugins/meta.py +++ b/frigate/detectors/plugins/meta.py @@ -6,6 +6,7 @@ from frigate.detectors.detector_config import BaseDetectorConfig, ModelConfig from frigate.util import deep_merge from typing import List, Tuple, Dict, Any, Literal from typing import Union +from typing import Optional from typing_extensions import Annotated from enum import Enum from pydantic import Field, parse_obj_as @@ -28,6 +29,10 @@ class MetaDetectorConfig(BaseDetectorConfig): default={"cpu": {"type": "cpu"}}, title="Detector hardware configuration.", ) + filtered_labels: Dict[str, Optional[List[str]]] = Field( + default={}, + title="Labels to filter for each detector.", + ) max_detections: int = Field( default=20, title="Maximum number of detections to return after merging results", @@ -39,6 +44,8 @@ class MetaDetector(DetectionApi): def __init__(self, meta_detector_config: MetaDetectorConfig): self.max_detections = meta_detector_config.max_detections + self.filtered_labels = meta_detector_config.filtered_labels + self.labels = meta_detector_config.model.merged_labelmap self.detectors = [] @@ -62,12 +69,13 @@ class MetaDetector(DetectionApi): "Customizing more than a detector model path is unsupported." ) merged_model = deep_merge( - detector_config.model.dict(exclude_unset=True), meta_detector_config.model.dict(exclude_unset=True), + detector_config.model.dict(exclude_unset=True), ) detector_config.model = ModelConfig.parse_obj(merged_model) meta_detector_config.detectors[key] = detector_config self.detectors.append(self.create_detector(detector_config)) + self.meta_detector_config = meta_detector_config def merge_detections(self, detections_list: List[np.ndarray]) -> np.ndarray: all_detections = np.vstack(detections_list) @@ -75,11 +83,28 @@ class MetaDetector(DetectionApi): return all_detections[sorted_indices[: self.max_detections]] def detect_raw(self, tensor_input) -> np.ndarray: - detections_list = [ - detector.detect_raw(tensor_input) for detector in self.detectors - ] + detections_list = [] + for i, detector in enumerate(self.detectors): + detector_key = list(self.meta_detector_config.detectors.keys())[i] + filtered_labels = self.filtered_labels.get(detector_key) + detections = detector.detect_raw(tensor_input) + + if filtered_labels is not None: + detections = np.array( + [ + d + for d in detections + if self.get_label_name(d[0]) in filtered_labels + ] + ) + + detections_list.append(detections) + return self.merge_detections(detections_list) + def get_label_name(self, index: int) -> str: + return self.labels.get(index) + def create_detector(self, detector_config): current_module_name = os.path.splitext(os.path.basename(__file__))[0] modules_folder = os.path.dirname(os.path.abspath(__file__))