Use pydantic migration tool

This commit is contained in:
Nick Mowen 2023-11-16 06:56:56 -07:00
parent 7b520e8a9d
commit 1fc3660050
2 changed files with 74 additions and 47 deletions

View File

@ -10,7 +10,14 @@ from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from pydantic import BaseModel, Extra, Field, parse_obj_as, validator from pydantic import (
field_validator,
ConfigDict,
BaseModel,
Field,
parse_obj_as,
validator,
)
from pydantic.fields import PrivateAttr from pydantic.fields import PrivateAttr
from frigate.const import ( from frigate.const import (
@ -65,8 +72,7 @@ DEFAULT_TIME_LAPSE_FFMPEG_ARGS = "-vf setpts=0.04*PTS -r 30"
class FrigateBaseModel(BaseModel): class FrigateBaseModel(BaseModel):
class Config: model_config = ConfigDict(extra="forbid")
extra = Extra.forbid
class LiveModeEnum(str, Enum): class LiveModeEnum(str, Enum):
@ -92,7 +98,7 @@ class UIConfig(FrigateBaseModel):
live_mode: LiveModeEnum = Field( live_mode: LiveModeEnum = Field(
default=LiveModeEnum.mse, title="Default Live Mode." default=LiveModeEnum.mse, title="Default Live Mode."
) )
timezone: Optional[str] = Field(title="Override UI timezone.") timezone: Optional[str] = Field(default=None, title="Override UI timezone.")
use_experimental: bool = Field(default=False, title="Experimental UI") use_experimental: bool = Field(default=False, title="Experimental UI")
time_format: TimeFormatEnum = Field( time_format: TimeFormatEnum = Field(
default=TimeFormatEnum.browser, title="Override UI time format." default=TimeFormatEnum.browser, title="Override UI time format."
@ -134,13 +140,17 @@ class MqttConfig(FrigateBaseModel):
topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix") topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix")
client_id: str = Field(default="frigate", title="MQTT Client ID") client_id: str = Field(default="frigate", title="MQTT Client ID")
stats_interval: int = Field(default=60, title="MQTT Camera Stats Interval") stats_interval: int = Field(default=60, title="MQTT Camera Stats Interval")
user: Optional[str] = Field(title="MQTT Username") user: Optional[str] = Field(default=None, title="MQTT Username")
password: Optional[str] = Field(title="MQTT Password") password: Optional[str] = Field(default=None, title="MQTT Password")
tls_ca_certs: Optional[str] = Field(title="MQTT TLS CA Certificates") tls_ca_certs: Optional[str] = Field(default=None, title="MQTT TLS CA Certificates")
tls_client_cert: Optional[str] = Field(title="MQTT TLS Client Certificate") tls_client_cert: Optional[str] = Field(
tls_client_key: Optional[str] = Field(title="MQTT TLS Client Key") default=None, title="MQTT TLS Client Certificate"
tls_insecure: Optional[bool] = Field(title="MQTT TLS Insecure") )
tls_client_key: Optional[str] = Field(default=None, title="MQTT TLS Client Key")
tls_insecure: Optional[bool] = Field(default=None, title="MQTT TLS Insecure")
# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("password", pre=True, always=True) @validator("password", pre=True, always=True)
def validate_password(cls, v, values): def validate_password(cls, v, values):
if (v is None) != (values["user"] is None): if (v is None) != (values["user"] is None):
@ -188,7 +198,8 @@ class PtzAutotrackConfig(FrigateBaseModel):
title="Keep track of original state of autotracking." title="Keep track of original state of autotracking."
) )
@validator("movement_weights", pre=True) @field_validator("movement_weights", mode="before")
@classmethod
def validate_weights(cls, v): def validate_weights(cls, v):
if v is None: if v is None:
return None return None
@ -209,8 +220,8 @@ class PtzAutotrackConfig(FrigateBaseModel):
class OnvifConfig(FrigateBaseModel): class OnvifConfig(FrigateBaseModel):
host: str = Field(default="", title="Onvif Host") host: str = Field(default="", title="Onvif Host")
port: int = Field(default=8000, title="Onvif Port") port: int = Field(default=8000, title="Onvif Port")
user: Optional[str] = Field(title="Onvif Username") user: Optional[str] = Field(default=None, title="Onvif Username")
password: Optional[str] = Field(title="Onvif Password") password: Optional[str] = Field(default=None, title="Onvif Password")
autotracking: PtzAutotrackConfig = Field( autotracking: PtzAutotrackConfig = Field(
default_factory=PtzAutotrackConfig, default_factory=PtzAutotrackConfig,
title="PTZ auto tracking config.", title="PTZ auto tracking config.",
@ -239,6 +250,7 @@ class EventsConfig(FrigateBaseModel):
title="List of required zones to be entered in order to save the event.", title="List of required zones to be entered in order to save the event.",
) )
objects: Optional[List[str]] = Field( objects: Optional[List[str]] = Field(
default=None,
title="List of objects to be detected in order to save the event.", title="List of objects to be detected in order to save the event.",
) )
retain: RetainConfig = Field( retain: RetainConfig = Field(
@ -276,7 +288,7 @@ class RecordConfig(FrigateBaseModel):
default_factory=RecordExportConfig, title="Recording Export Config" default_factory=RecordExportConfig, title="Recording Export Config"
) )
enabled_in_config: Optional[bool] = Field( enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of recording." default=None, title="Keep track of original state of recording."
) )
@ -330,13 +342,11 @@ class RuntimeMotionConfig(MotionConfig):
ret.pop("raw_mask") ret.pop("raw_mask")
return ret return ret
class Config: model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
arbitrary_types_allowed = True
extra = Extra.ignore
class StationaryMaxFramesConfig(FrigateBaseModel): class StationaryMaxFramesConfig(FrigateBaseModel):
default: Optional[int] = Field(title="Default max frames.", ge=1) default: Optional[int] = Field(default=None, title="Default max frames.", ge=1)
objects: Dict[str, int] = Field( objects: Dict[str, int] = Field(
default_factory=dict, title="Object specific max frames." default_factory=dict, title="Object specific max frames."
) )
@ -344,10 +354,12 @@ class StationaryMaxFramesConfig(FrigateBaseModel):
class StationaryConfig(FrigateBaseModel): class StationaryConfig(FrigateBaseModel):
interval: Optional[int] = Field( interval: Optional[int] = Field(
default=None,
title="Frame interval for checking stationary objects.", title="Frame interval for checking stationary objects.",
gt=0, gt=0,
) )
threshold: Optional[int] = Field( threshold: Optional[int] = Field(
default=None,
title="Number of frames without a position change for an object to be considered stationary", title="Number of frames without a position change for an object to be considered stationary",
ge=1, ge=1,
) )
@ -358,17 +370,23 @@ class StationaryConfig(FrigateBaseModel):
class DetectConfig(FrigateBaseModel): class DetectConfig(FrigateBaseModel):
height: Optional[int] = Field(title="Height of the stream for the detect role.") height: Optional[int] = Field(
width: Optional[int] = Field(title="Width of the stream for the detect role.") default=None, title="Height of the stream for the detect role."
)
width: Optional[int] = Field(
default=None, title="Width of the stream for the detect role."
)
fps: int = Field( fps: int = Field(
default=5, title="Number of frames per second to process through detection." default=5, title="Number of frames per second to process through detection."
) )
enabled: bool = Field(default=True, title="Detection Enabled.") enabled: bool = Field(default=True, title="Detection Enabled.")
min_initialized: Optional[int] = Field( min_initialized: Optional[int] = Field(
title="Minimum number of consecutive hits for an object to be initialized by the tracker." default=None,
title="Minimum number of consecutive hits for an object to be initialized by the tracker.",
) )
max_disappeared: Optional[int] = Field( max_disappeared: Optional[int] = Field(
title="Maximum number of frames the object can dissapear before detection ends." default=None,
title="Maximum number of frames the object can disappear before detection ends.",
) )
stationary: StationaryConfig = Field( stationary: StationaryConfig = Field(
default_factory=StationaryConfig, default_factory=StationaryConfig,
@ -402,6 +420,7 @@ class FilterConfig(FrigateBaseModel):
default=0.5, title="Minimum detection confidence for object to be counted." default=0.5, title="Minimum detection confidence for object to be counted."
) )
mask: Optional[Union[str, List[str]]] = Field( mask: Optional[Union[str, List[str]]] = Field(
default=None,
title="Detection area polygon mask for this filter configuration.", title="Detection area polygon mask for this filter configuration.",
) )
@ -416,8 +435,8 @@ class AudioFilterConfig(FrigateBaseModel):
class RuntimeFilterConfig(FilterConfig): class RuntimeFilterConfig(FilterConfig):
mask: Optional[np.ndarray] mask: Optional[np.ndarray] = None
raw_mask: Optional[Union[str, List[str]]] raw_mask: Optional[Union[str, List[str]]] = None
def __init__(self, **config): def __init__(self, **config):
mask = config.get("mask") mask = config.get("mask")
@ -435,9 +454,7 @@ class RuntimeFilterConfig(FilterConfig):
ret.pop("raw_mask") ret.pop("raw_mask")
return ret return ret
class Config: model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
arbitrary_types_allowed = True
extra = Extra.ignore
# this uses the base model because the color is an extra attribute # this uses the base model because the color is an extra attribute
@ -504,9 +521,11 @@ class AudioConfig(FrigateBaseModel):
listen: List[str] = Field( listen: List[str] = Field(
default=DEFAULT_LISTEN_AUDIO, title="Audio to listen for." default=DEFAULT_LISTEN_AUDIO, title="Audio to listen for."
) )
filters: Optional[Dict[str, AudioFilterConfig]] = Field(title="Audio filters.") filters: Optional[Dict[str, AudioFilterConfig]] = Field(
default=None, title="Audio filters."
)
enabled_in_config: Optional[bool] = Field( enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of audio detection." default=None, title="Keep track of original state of audio detection."
) )
num_threads: int = Field(default=2, title="Number of detection threads", ge=1) num_threads: int = Field(default=2, title="Number of detection threads", ge=1)
@ -625,7 +644,8 @@ class CameraInput(FrigateBaseModel):
class CameraFfmpegConfig(FfmpegConfig): class CameraFfmpegConfig(FfmpegConfig):
inputs: List[CameraInput] = Field(title="Camera inputs.") inputs: List[CameraInput] = Field(title="Camera inputs.")
@validator("inputs") @field_validator("inputs")
@classmethod
def validate_roles(cls, v): def validate_roles(cls, v):
roles = [role for i in v for role in i.roles] roles = [role for i in v for role in i.roles]
roles_set = set(roles) roles_set = set(roles)
@ -655,7 +675,7 @@ class SnapshotsConfig(FrigateBaseModel):
default_factory=list, default_factory=list,
title="List of required zones to be entered in order to save a snapshot.", title="List of required zones to be entered in order to save a snapshot.",
) )
height: Optional[int] = Field(title="Snapshot image height.") height: Optional[int] = Field(default=None, title="Snapshot image height.")
retain: RetainConfig = Field( retain: RetainConfig = Field(
default_factory=RetainConfig, title="Snapshot retention." default_factory=RetainConfig, title="Snapshot retention."
) )
@ -692,7 +712,9 @@ class TimestampStyleConfig(FrigateBaseModel):
format: str = Field(default=DEFAULT_TIME_FORMAT, title="Timestamp format.") format: str = Field(default=DEFAULT_TIME_FORMAT, title="Timestamp format.")
color: ColorConfig = Field(default_factory=ColorConfig, title="Timestamp color.") color: ColorConfig = Field(default_factory=ColorConfig, title="Timestamp color.")
thickness: int = Field(default=2, title="Timestamp thickness.") thickness: int = Field(default=2, title="Timestamp thickness.")
effect: Optional[TimestampEffectEnum] = Field(title="Timestamp effect.") effect: Optional[TimestampEffectEnum] = Field(
default=None, title="Timestamp effect."
)
class CameraMqttConfig(FrigateBaseModel): class CameraMqttConfig(FrigateBaseModel):
@ -724,8 +746,7 @@ class CameraLiveConfig(FrigateBaseModel):
class RestreamConfig(BaseModel): class RestreamConfig(BaseModel):
class Config: model_config = ConfigDict(extra="allow")
extra = Extra.allow
class CameraUiConfig(FrigateBaseModel): class CameraUiConfig(FrigateBaseModel):
@ -736,7 +757,9 @@ class CameraUiConfig(FrigateBaseModel):
class CameraConfig(FrigateBaseModel): class CameraConfig(FrigateBaseModel):
name: Optional[str] = Field(title="Camera name.", regex=REGEX_CAMERA_NAME) name: Optional[str] = Field(
default=None, title="Camera name.", pattern=REGEX_CAMERA_NAME
)
enabled: bool = Field(default=True, title="Enable camera.") enabled: bool = Field(default=True, title="Enable camera.")
ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.") ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.")
best_image_timeout: int = Field( best_image_timeout: int = Field(
@ -744,6 +767,7 @@ class CameraConfig(FrigateBaseModel):
title="How long to wait for the image with the highest confidence score.", title="How long to wait for the image with the highest confidence score.",
) )
webui_url: Optional[str] = Field( webui_url: Optional[str] = Field(
default=None,
title="URL to visit the camera directly from system page", title="URL to visit the camera directly from system page",
) )
zones: Dict[str, ZoneConfig] = Field( zones: Dict[str, ZoneConfig] = Field(
@ -770,7 +794,9 @@ class CameraConfig(FrigateBaseModel):
audio: AudioConfig = Field( audio: AudioConfig = Field(
default_factory=AudioConfig, title="Audio events configuration." default_factory=AudioConfig, title="Audio events configuration."
) )
motion: Optional[MotionConfig] = Field(title="Motion detection configuration.") motion: Optional[MotionConfig] = Field(
default=None, title="Motion detection configuration."
)
detect: DetectConfig = Field( detect: DetectConfig = Field(
default_factory=DetectConfig, title="Object detection configuration." default_factory=DetectConfig, title="Object detection configuration."
) )
@ -1084,7 +1110,7 @@ class FrigateConfig(FrigateBaseModel):
default_factory=AudioConfig, title="Global Audio events configuration." default_factory=AudioConfig, title="Global Audio events configuration."
) )
motion: Optional[MotionConfig] = Field( motion: Optional[MotionConfig] = Field(
title="Global motion detection configuration." default=None, title="Global motion detection configuration."
) )
detect: DetectConfig = Field( detect: DetectConfig = Field(
default_factory=DetectConfig, title="Global object tracking configuration." default_factory=DetectConfig, title="Global object tracking configuration."
@ -1309,7 +1335,8 @@ class FrigateConfig(FrigateBaseModel):
return config return config
@validator("cameras") @field_validator("cameras")
@classmethod
def ensure_zones_and_cameras_have_different_names(cls, v: Dict[str, CameraConfig]): def ensure_zones_and_cameras_have_different_names(cls, v: Dict[str, CameraConfig]):
zones = [zone for camera in v.values() for zone in camera.zones.keys()] zones = [zone for camera in v.values() for zone in camera.zones.keys()]
for zone in zones: for zone in zones:

View File

@ -7,7 +7,7 @@ from typing import Dict, Optional, Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import requests import requests
from pydantic import BaseModel, Extra, Field from pydantic import ConfigDict, BaseModel, Field
from pydantic.fields import PrivateAttr from pydantic.fields import PrivateAttr
from frigate.plus import PlusApi from frigate.plus import PlusApi
@ -35,8 +35,12 @@ class ModelTypeEnum(str, Enum):
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
path: Optional[str] = Field(title="Custom Object detection model path.") path: Optional[str] = Field(
labelmap_path: Optional[str] = Field(title="Label map for custom object detector.") default=None, title="Custom Object detection model path."
)
labelmap_path: Optional[str] = Field(
default=None, title="Label map for custom object detector."
)
width: int = Field(default=320, title="Object detection model input width.") width: int = Field(default=320, title="Object detection model input width.")
height: int = Field(default=320, title="Object detection model input height.") height: int = Field(default=320, title="Object detection model input height.")
labelmap: Dict[int, str] = Field( labelmap: Dict[int, str] = Field(
@ -132,8 +136,7 @@ class ModelConfig(BaseModel):
for key, val in enumerate(enabled_labels): for key, val in enumerate(enabled_labels):
self._colormap[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3]) self._colormap[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])
class Config: model_config = ConfigDict(extra="forbid")
extra = Extra.forbid
class BaseDetectorConfig(BaseModel): class BaseDetectorConfig(BaseModel):
@ -142,7 +145,4 @@ class BaseDetectorConfig(BaseModel):
model: ModelConfig = Field( model: ModelConfig = Field(
default=None, title="Detector specific model configuration." default=None, title="Detector specific model configuration."
) )
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
class Config:
extra = Extra.allow
arbitrary_types_allowed = True