mirror of https://github.com/blakeblackshear/frigate.git
Add support for filtering labels for each detector in MetaDetectorConfig
commit 55997dbbd5
parent 5546c085fb
@@ -6,6 +6,7 @@ from frigate.detectors.detector_config import BaseDetectorConfig, ModelConfig
 from frigate.util import deep_merge
 from typing import List, Tuple, Dict, Any, Literal
 from typing import Union
+from typing import Optional
 from typing_extensions import Annotated
 from enum import Enum
 from pydantic import Field, parse_obj_as
@@ -28,6 +29,10 @@ class MetaDetectorConfig(BaseDetectorConfig):
         default={"cpu": {"type": "cpu"}},
         title="Detector hardware configuration.",
     )
+    filtered_labels: Dict[str, Optional[List[str]]] = Field(
+        default={},
+        title="Labels to filter for each detector.",
+    )
     max_detections: int = Field(
         default=20,
         title="Maximum number of detections to return after merging results",
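The new filtered_labels option maps each detector key to an optional list of label names; a detector with no entry (or an explicit None) keeps all of its detections. A minimal sketch of the expected shape, with made-up detector keys and label names (not taken from this commit):

# Hypothetical config fragment matching Dict[str, Optional[List[str]]]:
filtered_labels = {
    "coral": ["person", "car"],  # keep only these labels from detector "coral"
    "cpu": None,                 # None (or a missing key) disables filtering
}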
@@ -39,6 +44,8 @@ class MetaDetector(DetectionApi):

     def __init__(self, meta_detector_config: MetaDetectorConfig):
         self.max_detections = meta_detector_config.max_detections
+        self.filtered_labels = meta_detector_config.filtered_labels
+        self.labels = meta_detector_config.model.merged_labelmap

         self.detectors = []

@@ -62,12 +69,13 @@ class MetaDetector(DetectionApi):
                     "Customizing more than a detector model path is unsupported."
                 )
             merged_model = deep_merge(
-                detector_config.model.dict(exclude_unset=True),
                 meta_detector_config.model.dict(exclude_unset=True),
+                detector_config.model.dict(exclude_unset=True),
             )
             detector_config.model = ModelConfig.parse_obj(merged_model)
             meta_detector_config.detectors[key] = detector_config
             self.detectors.append(self.create_detector(detector_config))
+        self.meta_detector_config = meta_detector_config

     def merge_detections(self, detections_list: List[np.ndarray]) -> np.ndarray:
         all_detections = np.vstack(detections_list)
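Two things change in the hunk above: the argument order passed to deep_merge is swapped, which flips which side's model settings take precedence when both the shared meta model and a per-detector model set the same field (assuming deep_merge is order-sensitive), and the full meta_detector_config is stored on the instance so detect_raw can map each detector back to its key. Purely as an illustration of why argument order matters, here is a first-argument-wins recursive merge; this is an assumption for the example, not frigate.util.deep_merge itself:

# Illustrative first-argument-wins recursive merge (assumption, not Frigate's deep_merge).
def merge_first_wins(primary: dict, fallback: dict) -> dict:
    out = dict(fallback)
    for key, value in primary.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge_first_wins(value, out[key])
        else:
            out[key] = value
    return out

merged = merge_first_wins(
    {"path": "/models/a.tflite"},
    {"path": "/models/b.tflite", "width": 320},
)
# merged == {"path": "/models/a.tflite", "width": 320}; swapping the two
# arguments would make "/models/b.tflite" win instead.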
@@ -75,11 +83,28 @@ class MetaDetector(DetectionApi):
         return all_detections[sorted_indices[: self.max_detections]]

     def detect_raw(self, tensor_input) -> np.ndarray:
-        detections_list = [
-            detector.detect_raw(tensor_input) for detector in self.detectors
-        ]
+        detections_list = []
+        for i, detector in enumerate(self.detectors):
+            detector_key = list(self.meta_detector_config.detectors.keys())[i]
+            filtered_labels = self.filtered_labels.get(detector_key)
+            detections = detector.detect_raw(tensor_input)
+
+            if filtered_labels is not None:
+                detections = np.array(
+                    [
+                        d
+                        for d in detections
+                        if self.get_label_name(d[0]) in filtered_labels
+                    ]
+                )
+
+            detections_list.append(detections)
+
         return self.merge_detections(detections_list)

+    def get_label_name(self, index: int) -> str:
+        return self.labels.get(index)
+
     def create_detector(self, detector_config):
         current_module_name = os.path.splitext(os.path.basename(__file__))[0]
         modules_folder = os.path.dirname(os.path.abspath(__file__))
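In detect_raw, each detector's raw results are now filtered against its allow-list before merging: any row whose label name is not listed in filtered_labels for that detector is dropped. A small standalone sketch of that step, assuming the usual raw-detection layout of [label_id, score, y_min, x_min, y_max, x_max] per row; the label map and values below are made up:

import numpy as np

labels = {0: "person", 1: "car", 2: "dog"}  # hypothetical merged labelmap
detections = np.array(
    [
        [0, 0.9, 0.1, 0.1, 0.5, 0.5],  # person
        [2, 0.8, 0.2, 0.2, 0.6, 0.6],  # dog
    ]
)

filtered_labels = ["person", "car"]  # allow-list for this detector
kept = np.array([d for d in detections if labels.get(d[0]) in filtered_labels])
# kept holds only the "person" row; the "dog" row is dropped before merging.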