reformatted frigate/detectors/plugins/meta.py

This commit is contained in:
Sergey Krashevich 2023-04-26 02:45:50 +03:00
parent 3d9b1996f9
commit 5546c085fb
No known key found for this signature in database
GPG Key ID: 625171324E7D3856

View File

@ -17,9 +17,10 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "meta_detector" DETECTOR_KEY = "meta_detector"
DetectorConfig = Annotated[ DetectorConfig = Annotated[
Union[tuple(BaseDetectorConfig.__subclasses__())], Union[tuple(BaseDetectorConfig.__subclasses__())],
Field(discriminator="type"), Field(discriminator="type"),
] ]
class MetaDetectorConfig(BaseDetectorConfig): class MetaDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY] type: Literal[DETECTOR_KEY]
@ -68,8 +69,6 @@ class MetaDetector(DetectionApi):
meta_detector_config.detectors[key] = detector_config meta_detector_config.detectors[key] = detector_config
self.detectors.append(self.create_detector(detector_config)) self.detectors.append(self.create_detector(detector_config))
def merge_detections(self, detections_list: List[np.ndarray]) -> np.ndarray: def merge_detections(self, detections_list: List[np.ndarray]) -> np.ndarray:
all_detections = np.vstack(detections_list) all_detections = np.vstack(detections_list)
sorted_indices = np.argsort(-all_detections[:, 1]) sorted_indices = np.argsort(-all_detections[:, 1])
@ -80,14 +79,17 @@ class MetaDetector(DetectionApi):
detector.detect_raw(tensor_input) for detector in self.detectors detector.detect_raw(tensor_input) for detector in self.detectors
] ]
return self.merge_detections(detections_list) return self.merge_detections(detections_list)
def create_detector(self, detector_config):
def create_detector(self, detector_config):
current_module_name = os.path.splitext(os.path.basename(__file__))[0] current_module_name = os.path.splitext(os.path.basename(__file__))[0]
modules_folder = os.path.dirname(os.path.abspath(__file__)) modules_folder = os.path.dirname(os.path.abspath(__file__))
module_prefix = __package__ + "." module_prefix = __package__ + "."
_included_modules = [module for module in pkgutil.iter_modules([modules_folder], module_prefix) if module.name != current_module_name] _included_modules = [
module
for module in pkgutil.iter_modules([modules_folder], module_prefix)
if module.name != current_module_name
]
plugin_modules = [] plugin_modules = []
@ -99,8 +101,6 @@ class MetaDetector(DetectionApi):
except ImportError as e: except ImportError as e:
logger.error(f"Error importing detector runtime: {e}") logger.error(f"Error importing detector runtime: {e}")
api_types = {det.type_key: det for det in DetectionApi.__subclasses__()} api_types = {det.type_key: det for det in DetectionApi.__subclasses__()}
api = api_types.get(detector_config.type) api = api_types.get(detector_config.type)
if not api: if not api: