Add support for filtering labels for each detector in MetaDetectorConfig

This commit is contained in:
Sergey Krashevich 2023-04-26 03:33:25 +03:00
parent 5546c085fb
commit 55997dbbd5
No known key found for this signature in database
GPG Key ID: 625171324E7D3856

View File

@ -6,6 +6,7 @@ from frigate.detectors.detector_config import BaseDetectorConfig, ModelConfig
from frigate.util import deep_merge from frigate.util import deep_merge
from typing import List, Tuple, Dict, Any, Literal from typing import List, Tuple, Dict, Any, Literal
from typing import Union from typing import Union
from typing import Optional
from typing_extensions import Annotated from typing_extensions import Annotated
from enum import Enum from enum import Enum
from pydantic import Field, parse_obj_as from pydantic import Field, parse_obj_as
@ -28,6 +29,10 @@ class MetaDetectorConfig(BaseDetectorConfig):
default={"cpu": {"type": "cpu"}}, default={"cpu": {"type": "cpu"}},
title="Detector hardware configuration.", title="Detector hardware configuration.",
) )
filtered_labels: Dict[str, Optional[List[str]]] = Field(
default={},
title="Labels to filter for each detector.",
)
max_detections: int = Field( max_detections: int = Field(
default=20, default=20,
title="Maximum number of detections to return after merging results", title="Maximum number of detections to return after merging results",
@ -39,6 +44,8 @@ class MetaDetector(DetectionApi):
def __init__(self, meta_detector_config: MetaDetectorConfig): def __init__(self, meta_detector_config: MetaDetectorConfig):
self.max_detections = meta_detector_config.max_detections self.max_detections = meta_detector_config.max_detections
self.filtered_labels = meta_detector_config.filtered_labels
self.labels = meta_detector_config.model.merged_labelmap
self.detectors = [] self.detectors = []
@ -62,12 +69,13 @@ class MetaDetector(DetectionApi):
"Customizing more than a detector model path is unsupported." "Customizing more than a detector model path is unsupported."
) )
merged_model = deep_merge( merged_model = deep_merge(
detector_config.model.dict(exclude_unset=True),
meta_detector_config.model.dict(exclude_unset=True), meta_detector_config.model.dict(exclude_unset=True),
detector_config.model.dict(exclude_unset=True),
) )
detector_config.model = ModelConfig.parse_obj(merged_model) detector_config.model = ModelConfig.parse_obj(merged_model)
meta_detector_config.detectors[key] = detector_config meta_detector_config.detectors[key] = detector_config
self.detectors.append(self.create_detector(detector_config)) self.detectors.append(self.create_detector(detector_config))
self.meta_detector_config = meta_detector_config
def merge_detections(self, detections_list: List[np.ndarray]) -> np.ndarray: def merge_detections(self, detections_list: List[np.ndarray]) -> np.ndarray:
all_detections = np.vstack(detections_list) all_detections = np.vstack(detections_list)
@ -75,11 +83,28 @@ class MetaDetector(DetectionApi):
return all_detections[sorted_indices[: self.max_detections]] return all_detections[sorted_indices[: self.max_detections]]
def detect_raw(self, tensor_input) -> np.ndarray: def detect_raw(self, tensor_input) -> np.ndarray:
detections_list = [ detections_list = []
detector.detect_raw(tensor_input) for detector in self.detectors for i, detector in enumerate(self.detectors):
] detector_key = list(self.meta_detector_config.detectors.keys())[i]
filtered_labels = self.filtered_labels.get(detector_key)
detections = detector.detect_raw(tensor_input)
if filtered_labels is not None:
detections = np.array(
[
d
for d in detections
if self.get_label_name(d[0]) in filtered_labels
]
)
detections_list.append(detections)
return self.merge_detections(detections_list) return self.merge_detections(detections_list)
def get_label_name(self, index: int) -> str:
return self.labels.get(index)
def create_detector(self, detector_config): def create_detector(self, detector_config):
current_module_name = os.path.splitext(os.path.basename(__file__))[0] current_module_name = os.path.splitext(os.path.basename(__file__))[0]
modules_folder = os.path.dirname(os.path.abspath(__file__)) modules_folder = os.path.dirname(os.path.abspath(__file__))