simplify config

This commit is contained in:
Dennis George 2022-12-10 07:50:49 -06:00
parent 61c3d2adc0
commit b7fd53ecc7
4 changed files with 15 additions and 46 deletions

View File

@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import yaml import yaml
from pydantic import BaseModel, Extra, Field, validator, parse_obj_as from pydantic import BaseModel, Extra, Field, validator
from pydantic.fields import PrivateAttr from pydantic.fields import PrivateAttr
from frigate.const import ( from frigate.const import (
@ -972,7 +972,7 @@ class FrigateConfig(FrigateBaseModel):
config.cameras[name] = camera_config config.cameras[name] = camera_config
for key, detector in config.detectors.items(): for key, detector in config.detectors.items():
detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector) detector_config: DetectorConfig = DetectorConfig.parse_obj(detector)
if detector_config.model is None: if detector_config.model is None:
detector_config.model = config.model detector_config.model = config.model
else: else:

View File

@ -57,41 +57,17 @@ class ModelConfig(BaseModel):
extra = Extra.forbid extra = Extra.forbid
class BaseDetectorConfig(BaseModel): class DetectorConfig(BaseModel):
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type") type: str = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
model: ModelConfig = Field( model: ModelConfig = Field(
default=None, title="Detector specific model configuration." default=None, title="Detector specific model configuration."
) )
num_threads: Optional[int] = Field(default=3, title="Number of detection threads")
device: Optional[str] = Field(default="usb", title="Device Type")
class Config: class Config:
extra = Extra.forbid extra = Extra.allow
arbitrary_types_allowed = True arbitrary_types_allowed = True
class CpuDetectorConfig(BaseDetectorConfig):
type: Literal[DetectorTypeEnum.cpu] = Field(
default=DetectorTypeEnum.cpu, title="Detector Type"
)
num_threads: int = Field(default=3, title="Number of detection threads")
class EdgeTpuDetectorConfig(BaseDetectorConfig):
type: Literal[DetectorTypeEnum.edgetpu] = Field(
default=DetectorTypeEnum.edgetpu, title="Detector Type"
)
device: str = Field(default="usb", title="Device Type")
class OpenVinoDetectorConfig(BaseDetectorConfig):
type: Literal[DetectorTypeEnum.openvino] = Field(
default=DetectorTypeEnum.openvino, title="Detector Type"
)
device: str = Field(default="usb", title="Device Type")
DetectorConfig = Annotated[
Union[CpuDetectorConfig, EdgeTpuDetectorConfig, OpenVinoDetectorConfig],
Field(discriminator="type"),
]
DEFAULT_DETECTORS = parse_obj_as(Dict[str, DetectorConfig], {"cpu": {"type": "cpu"}}) DEFAULT_DETECTORS = parse_obj_as(Dict[str, DetectorConfig], {"cpu": {"type": "cpu"}})

View File

@ -44,12 +44,10 @@ class TestConfig(unittest.TestCase):
"detectors": { "detectors": {
"cpu": { "cpu": {
"type": "cpu", "type": "cpu",
"num_threads": 5,
"model": {"path": "/cpu_model.tflite"}, "model": {"path": "/cpu_model.tflite"},
}, },
"edgetpu": { "edgetpu": {
"type": "edgetpu", "type": "edgetpu",
"device": "usb",
"model": {"path": "/edgetpu_model.tflite", "width": 160}, "model": {"path": "/edgetpu_model.tflite", "width": 160},
}, },
"openvino": { "openvino": {
@ -71,6 +69,9 @@ class TestConfig(unittest.TestCase):
assert runtime_config.detectors["edgetpu"].type == DetectorTypeEnum.edgetpu assert runtime_config.detectors["edgetpu"].type == DetectorTypeEnum.edgetpu
assert runtime_config.detectors["openvino"].type == DetectorTypeEnum.openvino assert runtime_config.detectors["openvino"].type == DetectorTypeEnum.openvino
assert runtime_config.detectors["cpu"].num_threads == 3
assert runtime_config.detectors["edgetpu"].device == "usb"
assert runtime_config.model.path == "/default.tflite" assert runtime_config.model.path == "/default.tflite"
assert runtime_config.detectors["cpu"].model.path == "/cpu_model.tflite" assert runtime_config.detectors["cpu"].model.path == "/cpu_model.tflite"
assert runtime_config.detectors["edgetpu"].model.path == "/edgetpu_model.tflite" assert runtime_config.detectors["edgetpu"].model.path == "/edgetpu_model.tflite"

View File

@ -2,16 +2,10 @@ import unittest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import numpy as np import numpy as np
from pydantic import parse_obj_as
from frigate.enums import InputTensorEnum from frigate.enums import InputTensorEnum
from frigate.detectors import DetectorTypeEnum from frigate.detectors import DetectorTypeEnum
from frigate.detectors.config import ( from frigate.detectors.config import DetectorConfig, ModelConfig
CpuDetectorConfig,
EdgeTpuDetectorConfig,
DetectorConfig,
ModelConfig,
)
import frigate.detectors as detectors import frigate.detectors as detectors
import frigate.object_detection import frigate.object_detection
@ -24,9 +18,7 @@ class TestLocalObjectDetector(unittest.TestCase):
"frigate.detectors.api_types", "frigate.detectors.api_types",
{det_type: Mock() for det_type in DetectorTypeEnum}, {det_type: Mock() for det_type in DetectorTypeEnum},
): ):
test_cfg = parse_obj_as( test_cfg = DetectorConfig.parse_obj({"type": det_type, "model": {}})
DetectorConfig, {"type": det_type, "model": {}}
)
test_cfg.model.path = "/test/modelpath" test_cfg.model.path = "/test/modelpath"
test_obj = frigate.object_detection.LocalObjectDetector( test_obj = frigate.object_detection.LocalObjectDetector(
detector_config=test_cfg detector_config=test_cfg
@ -49,7 +41,7 @@ class TestLocalObjectDetector(unittest.TestCase):
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32]) TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.LocalObjectDetector(
detector_config=CpuDetectorConfig(model=ModelConfig()) detector_config=DetectorConfig(type="cpu", model=ModelConfig())
) )
mock_det_api = mock_cputfl.return_value mock_det_api = mock_cputfl.return_value
@ -72,7 +64,7 @@ class TestLocalObjectDetector(unittest.TestCase):
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8) TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32]) TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_cfg = CpuDetectorConfig(model=ModelConfig()) test_cfg = DetectorConfig(type="cpu", model=ModelConfig())
test_cfg.model.input_tensor = InputTensorEnum.nchw test_cfg.model.input_tensor = InputTensorEnum.nchw
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.LocalObjectDetector(
@ -121,7 +113,7 @@ class TestLocalObjectDetector(unittest.TestCase):
"label-5", "label-5",
] ]
test_cfg = CpuDetectorConfig(model=ModelConfig()) test_cfg = DetectorConfig(type="cpu", model=ModelConfig())
test_cfg.model = ModelConfig() test_cfg.model = ModelConfig()
test_obj_detect = frigate.object_detection.LocalObjectDetector( test_obj_detect = frigate.object_detection.LocalObjectDetector(
detector_config=test_cfg, detector_config=test_cfg,