Change the model input tensor config to use an enumeration

This commit is contained in:
Nate Meyer 2022-10-26 00:25:54 -04:00
parent 46df2a6734
commit fc676e7107
5 changed files with 16 additions and 14 deletions

View File

@ -61,7 +61,7 @@ Custom models may also require different input tensor formats. The colorspace co
```yaml
model:
input_tensor: ["B", "H", "W", "C"]
input_tensor: "nhwc"
```
The labelmap can be customized to your needs. A common reason to do this is to combine multiple object types that are easily confused when you don't need to be as granular such as car/truck. By default, truck is renamed to car because they are often confused. You cannot add new object types, but you can change the names of existing objects in the model.
@ -85,6 +85,7 @@ Note that if you rename objects in the labelmap, you will also need to update yo
Included with Frigate is a build of ffmpeg that works for the vast majority of users. However, there exists some hardware setups which have incompatibilities with the included build. In this case, a docker volume mapping can be used to overwrite the included ffmpeg build with an ffmpeg build that works for your specific hardware setup.
To do this:
1. Download your ffmpeg build and uncompress to a folder on the host (let's use `/home/appdata/frigate/custom-ffmpeg` for this example).
2. Update your docker-compose or docker CLI to include `'/home/appdata/frigate/custom-ffmpeg':'/usr/lib/btbn-ffmpeg':'ro'` in the volume mappings.
3. Restart frigate and the custom version will be used if the mapping was done correctly.

View File

@ -101,7 +101,7 @@ model:
# Valid values are rgb, bgr, or yuv. (default: shown below)
input_pixel_format: rgb
# Optional: Object detection model input tensor format (default: shown below)
input_tensor: ["B", "H", "W", "C"]
input_tensor: "nhwc"
# Optional: Label name modifications. These are merged into the standard labelmap.
labelmap:
2: vehicle

View File

@ -693,6 +693,11 @@ class PixelFormatEnum(str, Enum):
yuv = "yuv"
class InputTensorEnum(str, Enum):
nchw = "nchw"
nhwc = "nhwc"
class ModelConfig(FrigateBaseModel):
path: Optional[str] = Field(title="Custom Object detection model path.")
labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
@ -701,8 +706,8 @@ class ModelConfig(FrigateBaseModel):
labelmap: Dict[int, str] = Field(
default_factory=dict, title="Labelmap customization."
)
input_tensor: List[str] = Field(
default=["B", "H", "W", "C"], title="Model Input Tensor Shape"
input_tensor: InputTensorEnum = Field(
default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
)
input_pixel_format: PixelFormatEnum = Field(
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"

View File

@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
import numpy as np
from setproctitle import setproctitle
from frigate.config import DetectorTypeEnum
from frigate.config import DetectorTypeEnum, InputTensorEnum
from frigate.detectors.edgetpu_tfl import EdgeTpuTfl
from frigate.detectors.cpu_tfl import CpuTfl
@ -27,14 +27,10 @@ class ObjectDetector(ABC):
def tensor_transform(desired_shape):
# Currently this function only supports BHWC permutations
if desired_shape == ["B", "H", "W", "C"]:
if desired_shape == InputTensorEnum.nhwc:
return None
else:
transform = [0] * 4
transform[desired_shape.index("H")] = 1
transform[desired_shape.index("W")] = 2
transform[desired_shape.index("C")] = 3
return tuple(transform)
elif desired_shape == InputTensorEnum.nchw:
return (0, 3, 1, 2)
class LocalObjectDetector(ObjectDetector):

View File

@ -2,7 +2,7 @@ import unittest
from unittest.mock import patch
import numpy as np
from frigate.config import DetectorTypeEnum, ModelConfig
from frigate.config import DetectorTypeEnum, InputTensorEnum, ModelConfig
import frigate.object_detection
@ -66,7 +66,7 @@ class TestLocalObjectDetector(unittest.TestCase):
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_cfg = ModelConfig()
test_cfg.input_tensor = ["B", "C", "H", "W"]
test_cfg.input_tensor = InputTensorEnum.nchw
test_obj_detect = frigate.object_detection.LocalObjectDetector(
det_device=DetectorTypeEnum.cpu, model_config=test_cfg