fix detector model config defaults

This commit is contained in:
Dennis George 2022-12-10 04:54:46 -06:00
parent f5ff975530
commit edd6818d3c
6 changed files with 82 additions and 15 deletions

View File

@ -186,7 +186,12 @@ class FrigateApp:
self.detection_out_events[name] = mp.Event() self.detection_out_events[name] = mp.Event()
try: try:
size = max([det.model.height * det.model.width * 3 for (name, det) in self.config.detectors.items()]) size = max(
[
det.model.height * det.model.width * 3
for (name, det) in self.config.detectors.items()
]
)
shm_in = mp.shared_memory.SharedMemory( shm_in = mp.shared_memory.SharedMemory(
name=name, name=name,
create=True, create=True,

View File

@ -975,6 +975,26 @@ class FrigateConfig(FrigateBaseModel):
detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector) detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector)
if detector_config.model is None: if detector_config.model is None:
detector_config.model = config.model detector_config.model = config.model
else:
model = detector_config.model
schema = ModelConfig.schema()["properties"]
if (
model.width != schema["width"]["default"]
or model.height != schema["height"]["default"]
or model.labelmap_path is not None
or model.labelmap is not {}
or model.input_tensor != schema["input_tensor"]["default"]
or model.input_pixel_format
!= schema["input_pixel_format"]["default"]
):
logger.warning(
"Customizing more than a detector model path is unsupported."
)
merged_model = deep_merge(
detector_config.model.dict(exclude_unset=True),
config.model.dict(exclude_unset=True),
)
detector_config.model = ModelConfig.parse_obj(merged_model)
config.detectors[key] = detector_config config.detectors[key] = detector_config
return config return config

View File

@ -60,7 +60,7 @@ class ModelConfig(BaseModel):
class BaseDetectorConfig(BaseModel): class BaseDetectorConfig(BaseModel):
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type") type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
model: ModelConfig = Field( model: ModelConfig = Field(
default_factory=ModelConfig, title="Detector specific model configuration." default=None, title="Detector specific model configuration."
) )
class Config: class Config:

View File

@ -38,16 +38,14 @@ class LocalObjectDetector(ObjectDetector):
detector_config=None, detector_config=None,
labels=None, labels=None,
): ):
model_config = detector_config.model
self.fps = EventsPerSecond() self.fps = EventsPerSecond()
if labels is None: if labels is None:
self.labels = {} self.labels = {}
else: else:
self.labels = load_labels(labels) self.labels = load_labels(labels)
if model_config: if detector_config:
self.input_transform = tensor_transform(model_config.input_tensor) self.input_transform = tensor_transform(detector_config.model.input_tensor)
else: else:
self.input_transform = None self.input_transform = None
@ -96,9 +94,7 @@ def run_detector(
signal.signal(signal.SIGINT, receiveSignal) signal.signal(signal.SIGINT, receiveSignal)
frame_manager = SharedMemoryFrameManager() frame_manager = SharedMemoryFrameManager()
object_detector = LocalObjectDetector( object_detector = LocalObjectDetector(detector_config=detector_config)
detector_config=detector_config
)
outputs = {} outputs = {}
for name in out_events.keys(): for name in out_events.keys():
@ -112,7 +108,8 @@ def run_detector(
except queue.Empty: except queue.Empty:
continue continue
input_frame = frame_manager.get( input_frame = frame_manager.get(
connection_id, (1, detector_config.model.height, detector_config.model.width, 3) connection_id,
(1, detector_config.model.height, detector_config.model.width, 3),
) )
if input_frame is None: if input_frame is None:

View File

@ -7,7 +7,7 @@ from frigate.config import (
FrigateConfig, FrigateConfig,
) )
from frigate.detectors import DetectorTypeEnum from frigate.detectors import DetectorTypeEnum
from frigate.util import load_config_with_no_duplicates from frigate.util import deep_merge, load_config_with_no_duplicates
class TestConfig(unittest.TestCase): class TestConfig(unittest.TestCase):
@ -37,6 +37,49 @@ class TestConfig(unittest.TestCase):
runtime_config = frigate_config.runtime_config runtime_config = frigate_config.runtime_config
assert "cpu" in runtime_config.detectors.keys() assert "cpu" in runtime_config.detectors.keys()
assert runtime_config.detectors["cpu"].type == DetectorTypeEnum.cpu assert runtime_config.detectors["cpu"].type == DetectorTypeEnum.cpu
assert runtime_config.detectors["cpu"].model.width == 320
def test_detector_custom_model_path(self):
config = {
"detectors": {
"cpu": {
"type": "cpu",
"num_threads": 5,
"model": {"path": "/cpu_model.tflite"},
},
"edgetpu": {
"type": "edgetpu",
"device": "usb",
"model": {"path": "/edgetpu_model.tflite", "width": 160},
},
"openvino": {
"type": "openvino",
"device": "usb",
},
},
"model": {"path": "/default.tflite", "width": 512},
}
frigate_config = FrigateConfig(**(deep_merge(config, self.minimal)))
runtime_config = frigate_config.runtime_config
assert "cpu" in runtime_config.detectors.keys()
assert "edgetpu" in runtime_config.detectors.keys()
assert "openvino" in runtime_config.detectors.keys()
assert runtime_config.detectors["cpu"].type == DetectorTypeEnum.cpu
assert runtime_config.detectors["edgetpu"].type == DetectorTypeEnum.edgetpu
assert runtime_config.detectors["openvino"].type == DetectorTypeEnum.openvino
assert runtime_config.model.path == "/default.tflite"
assert runtime_config.detectors["cpu"].model.path == "/cpu_model.tflite"
assert runtime_config.detectors["edgetpu"].model.path == "/edgetpu_model.tflite"
assert runtime_config.detectors["openvino"].model.path == "/default.tflite"
assert runtime_config.model.width == 512
assert runtime_config.detectors["cpu"].model.width == 512
assert runtime_config.detectors["edgetpu"].model.width == 160
assert runtime_config.detectors["openvino"].model.width == 512
def test_invalid_mqtt_config(self): def test_invalid_mqtt_config(self):
config = { config = {

View File

@ -24,7 +24,9 @@ class TestLocalObjectDetector(unittest.TestCase):
"frigate.detectors.api_types", "frigate.detectors.api_types",
{det_type: Mock() for det_type in DetectorTypeEnum}, {det_type: Mock() for det_type in DetectorTypeEnum},
): ):
test_cfg = parse_obj_as(DetectorConfig, {"type": det_type}) test_cfg = parse_obj_as(
DetectorConfig, {"type": det_type, "model": {}}
)
test_cfg.model.path = "/test/modelpath" test_cfg.model.path = "/test/modelpath"
test_obj = frigate.object_detection.LocalObjectDetector( test_obj = frigate.object_detection.LocalObjectDetector(
detector_config=test_cfg detector_config=test_cfg
@ -47,7 +49,7 @@ class TestLocalObjectDetector(unittest.TestCase):
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32]) TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.LocalObjectDetector(
detector_config=CpuDetectorConfig() detector_config=CpuDetectorConfig(model=ModelConfig())
) )
mock_det_api = mock_cputfl.return_value mock_det_api = mock_cputfl.return_value
@ -70,7 +72,7 @@ class TestLocalObjectDetector(unittest.TestCase):
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8) TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32]) TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_cfg = CpuDetectorConfig() test_cfg = CpuDetectorConfig(model=ModelConfig())
test_cfg.model.input_tensor = InputTensorEnum.nchw test_cfg.model.input_tensor = InputTensorEnum.nchw
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.LocalObjectDetector(
@ -119,7 +121,7 @@ class TestLocalObjectDetector(unittest.TestCase):
"label-5", "label-5",
] ]
test_cfg = CpuDetectorConfig() test_cfg = CpuDetectorConfig(model=ModelConfig())
test_cfg.model = ModelConfig() test_cfg.model = ModelConfig()
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.LocalObjectDetector(
detector_config=test_cfg, detector_config=test_cfg,