From 1fc3660050dbd10a1a3dc771d6eaecab2d51ca09 Mon Sep 17 00:00:00 2001 From: Nick Mowen Date: Thu, 16 Nov 2023 06:56:56 -0700 Subject: [PATCH] Use pydantic migration tool --- frigate/config.py | 103 +++++++++++++++++---------- frigate/detectors/detector_config.py | 18 ++--- 2 files changed, 74 insertions(+), 47 deletions(-) diff --git a/frigate/config.py b/frigate/config.py index a7ceaeecb..613361835 100644 --- a/frigate/config.py +++ b/frigate/config.py @@ -10,7 +10,14 @@ from typing import Dict, List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np -from pydantic import BaseModel, Extra, Field, parse_obj_as, validator +from pydantic import ( + field_validator, + ConfigDict, + BaseModel, + Field, + parse_obj_as, + validator, +) from pydantic.fields import PrivateAttr from frigate.const import ( @@ -65,8 +72,7 @@ DEFAULT_TIME_LAPSE_FFMPEG_ARGS = "-vf setpts=0.04*PTS -r 30" class FrigateBaseModel(BaseModel): - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class LiveModeEnum(str, Enum): @@ -92,7 +98,7 @@ class UIConfig(FrigateBaseModel): live_mode: LiveModeEnum = Field( default=LiveModeEnum.mse, title="Default Live Mode." ) - timezone: Optional[str] = Field(title="Override UI timezone.") + timezone: Optional[str] = Field(default=None, title="Override UI timezone.") use_experimental: bool = Field(default=False, title="Experimental UI") time_format: TimeFormatEnum = Field( default=TimeFormatEnum.browser, title="Override UI time format." @@ -134,13 +140,17 @@ class MqttConfig(FrigateBaseModel): topic_prefix: str = Field(default="frigate", title="MQTT Topic Prefix") client_id: str = Field(default="frigate", title="MQTT Client ID") stats_interval: int = Field(default=60, title="MQTT Camera Stats Interval") - user: Optional[str] = Field(title="MQTT Username") - password: Optional[str] = Field(title="MQTT Password") - tls_ca_certs: Optional[str] = Field(title="MQTT TLS CA Certificates") - tls_client_cert: Optional[str] = Field(title="MQTT TLS Client Certificate") - tls_client_key: Optional[str] = Field(title="MQTT TLS Client Key") - tls_insecure: Optional[bool] = Field(title="MQTT TLS Insecure") + user: Optional[str] = Field(default=None, title="MQTT Username") + password: Optional[str] = Field(default=None, title="MQTT Password") + tls_ca_certs: Optional[str] = Field(default=None, title="MQTT TLS CA Certificates") + tls_client_cert: Optional[str] = Field( + default=None, title="MQTT TLS Client Certificate" + ) + tls_client_key: Optional[str] = Field(default=None, title="MQTT TLS Client Key") + tls_insecure: Optional[bool] = Field(default=None, title="MQTT TLS Insecure") + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @validator("password", pre=True, always=True) def validate_password(cls, v, values): if (v is None) != (values["user"] is None): @@ -188,7 +198,8 @@ class PtzAutotrackConfig(FrigateBaseModel): title="Keep track of original state of autotracking." ) - @validator("movement_weights", pre=True) + @field_validator("movement_weights", mode="before") + @classmethod def validate_weights(cls, v): if v is None: return None @@ -209,8 +220,8 @@ class PtzAutotrackConfig(FrigateBaseModel): class OnvifConfig(FrigateBaseModel): host: str = Field(default="", title="Onvif Host") port: int = Field(default=8000, title="Onvif Port") - user: Optional[str] = Field(title="Onvif Username") - password: Optional[str] = Field(title="Onvif Password") + user: Optional[str] = Field(default=None, title="Onvif Username") + password: Optional[str] = Field(default=None, title="Onvif Password") autotracking: PtzAutotrackConfig = Field( default_factory=PtzAutotrackConfig, title="PTZ auto tracking config.", @@ -239,6 +250,7 @@ class EventsConfig(FrigateBaseModel): title="List of required zones to be entered in order to save the event.", ) objects: Optional[List[str]] = Field( + default=None, title="List of objects to be detected in order to save the event.", ) retain: RetainConfig = Field( @@ -276,7 +288,7 @@ class RecordConfig(FrigateBaseModel): default_factory=RecordExportConfig, title="Recording Export Config" ) enabled_in_config: Optional[bool] = Field( - title="Keep track of original state of recording." + default=None, title="Keep track of original state of recording." ) @@ -330,13 +342,11 @@ class RuntimeMotionConfig(MotionConfig): ret.pop("raw_mask") return ret - class Config: - arbitrary_types_allowed = True - extra = Extra.ignore + model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore") class StationaryMaxFramesConfig(FrigateBaseModel): - default: Optional[int] = Field(title="Default max frames.", ge=1) + default: Optional[int] = Field(default=None, title="Default max frames.", ge=1) objects: Dict[str, int] = Field( default_factory=dict, title="Object specific max frames." ) @@ -344,10 +354,12 @@ class StationaryMaxFramesConfig(FrigateBaseModel): class StationaryConfig(FrigateBaseModel): interval: Optional[int] = Field( + default=None, title="Frame interval for checking stationary objects.", gt=0, ) threshold: Optional[int] = Field( + default=None, title="Number of frames without a position change for an object to be considered stationary", ge=1, ) @@ -358,17 +370,23 @@ class StationaryConfig(FrigateBaseModel): class DetectConfig(FrigateBaseModel): - height: Optional[int] = Field(title="Height of the stream for the detect role.") - width: Optional[int] = Field(title="Width of the stream for the detect role.") + height: Optional[int] = Field( + default=None, title="Height of the stream for the detect role." + ) + width: Optional[int] = Field( + default=None, title="Width of the stream for the detect role." + ) fps: int = Field( default=5, title="Number of frames per second to process through detection." ) enabled: bool = Field(default=True, title="Detection Enabled.") min_initialized: Optional[int] = Field( - title="Minimum number of consecutive hits for an object to be initialized by the tracker." + default=None, + title="Minimum number of consecutive hits for an object to be initialized by the tracker.", ) max_disappeared: Optional[int] = Field( - title="Maximum number of frames the object can dissapear before detection ends." + default=None, + title="Maximum number of frames the object can disappear before detection ends.", ) stationary: StationaryConfig = Field( default_factory=StationaryConfig, @@ -402,6 +420,7 @@ class FilterConfig(FrigateBaseModel): default=0.5, title="Minimum detection confidence for object to be counted." ) mask: Optional[Union[str, List[str]]] = Field( + default=None, title="Detection area polygon mask for this filter configuration.", ) @@ -416,8 +435,8 @@ class AudioFilterConfig(FrigateBaseModel): class RuntimeFilterConfig(FilterConfig): - mask: Optional[np.ndarray] - raw_mask: Optional[Union[str, List[str]]] + mask: Optional[np.ndarray] = None + raw_mask: Optional[Union[str, List[str]]] = None def __init__(self, **config): mask = config.get("mask") @@ -435,9 +454,7 @@ class RuntimeFilterConfig(FilterConfig): ret.pop("raw_mask") return ret - class Config: - arbitrary_types_allowed = True - extra = Extra.ignore + model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore") # this uses the base model because the color is an extra attribute @@ -504,9 +521,11 @@ class AudioConfig(FrigateBaseModel): listen: List[str] = Field( default=DEFAULT_LISTEN_AUDIO, title="Audio to listen for." ) - filters: Optional[Dict[str, AudioFilterConfig]] = Field(title="Audio filters.") + filters: Optional[Dict[str, AudioFilterConfig]] = Field( + default=None, title="Audio filters." + ) enabled_in_config: Optional[bool] = Field( - title="Keep track of original state of audio detection." + default=None, title="Keep track of original state of audio detection." ) num_threads: int = Field(default=2, title="Number of detection threads", ge=1) @@ -625,7 +644,8 @@ class CameraInput(FrigateBaseModel): class CameraFfmpegConfig(FfmpegConfig): inputs: List[CameraInput] = Field(title="Camera inputs.") - @validator("inputs") + @field_validator("inputs") + @classmethod def validate_roles(cls, v): roles = [role for i in v for role in i.roles] roles_set = set(roles) @@ -655,7 +675,7 @@ class SnapshotsConfig(FrigateBaseModel): default_factory=list, title="List of required zones to be entered in order to save a snapshot.", ) - height: Optional[int] = Field(title="Snapshot image height.") + height: Optional[int] = Field(default=None, title="Snapshot image height.") retain: RetainConfig = Field( default_factory=RetainConfig, title="Snapshot retention." ) @@ -692,7 +712,9 @@ class TimestampStyleConfig(FrigateBaseModel): format: str = Field(default=DEFAULT_TIME_FORMAT, title="Timestamp format.") color: ColorConfig = Field(default_factory=ColorConfig, title="Timestamp color.") thickness: int = Field(default=2, title="Timestamp thickness.") - effect: Optional[TimestampEffectEnum] = Field(title="Timestamp effect.") + effect: Optional[TimestampEffectEnum] = Field( + default=None, title="Timestamp effect." + ) class CameraMqttConfig(FrigateBaseModel): @@ -724,8 +746,7 @@ class CameraLiveConfig(FrigateBaseModel): class RestreamConfig(BaseModel): - class Config: - extra = Extra.allow + model_config = ConfigDict(extra="allow") class CameraUiConfig(FrigateBaseModel): @@ -736,7 +757,9 @@ class CameraUiConfig(FrigateBaseModel): class CameraConfig(FrigateBaseModel): - name: Optional[str] = Field(title="Camera name.", regex=REGEX_CAMERA_NAME) + name: Optional[str] = Field( + default=None, title="Camera name.", pattern=REGEX_CAMERA_NAME + ) enabled: bool = Field(default=True, title="Enable camera.") ffmpeg: CameraFfmpegConfig = Field(title="FFmpeg configuration for the camera.") best_image_timeout: int = Field( @@ -744,6 +767,7 @@ class CameraConfig(FrigateBaseModel): title="How long to wait for the image with the highest confidence score.", ) webui_url: Optional[str] = Field( + default=None, title="URL to visit the camera directly from system page", ) zones: Dict[str, ZoneConfig] = Field( @@ -770,7 +794,9 @@ class CameraConfig(FrigateBaseModel): audio: AudioConfig = Field( default_factory=AudioConfig, title="Audio events configuration." ) - motion: Optional[MotionConfig] = Field(title="Motion detection configuration.") + motion: Optional[MotionConfig] = Field( + default=None, title="Motion detection configuration." + ) detect: DetectConfig = Field( default_factory=DetectConfig, title="Object detection configuration." ) @@ -1084,7 +1110,7 @@ class FrigateConfig(FrigateBaseModel): default_factory=AudioConfig, title="Global Audio events configuration." ) motion: Optional[MotionConfig] = Field( - title="Global motion detection configuration." + default=None, title="Global motion detection configuration." ) detect: DetectConfig = Field( default_factory=DetectConfig, title="Global object tracking configuration." @@ -1309,7 +1335,8 @@ class FrigateConfig(FrigateBaseModel): return config - @validator("cameras") + @field_validator("cameras") + @classmethod def ensure_zones_and_cameras_have_different_names(cls, v: Dict[str, CameraConfig]): zones = [zone for camera in v.values() for zone in camera.zones.keys()] for zone in zones: diff --git a/frigate/detectors/detector_config.py b/frigate/detectors/detector_config.py index ca1915449..861e34b2f 100644 --- a/frigate/detectors/detector_config.py +++ b/frigate/detectors/detector_config.py @@ -7,7 +7,7 @@ from typing import Dict, Optional, Tuple import matplotlib.pyplot as plt import requests -from pydantic import BaseModel, Extra, Field +from pydantic import ConfigDict, BaseModel, Field from pydantic.fields import PrivateAttr from frigate.plus import PlusApi @@ -35,8 +35,12 @@ class ModelTypeEnum(str, Enum): class ModelConfig(BaseModel): - path: Optional[str] = Field(title="Custom Object detection model path.") - labelmap_path: Optional[str] = Field(title="Label map for custom object detector.") + path: Optional[str] = Field( + default=None, title="Custom Object detection model path." + ) + labelmap_path: Optional[str] = Field( + default=None, title="Label map for custom object detector." + ) width: int = Field(default=320, title="Object detection model input width.") height: int = Field(default=320, title="Object detection model input height.") labelmap: Dict[int, str] = Field( @@ -132,8 +136,7 @@ class ModelConfig(BaseModel): for key, val in enumerate(enabled_labels): self._colormap[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3]) - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class BaseDetectorConfig(BaseModel): @@ -142,7 +145,4 @@ class BaseDetectorConfig(BaseModel): model: ModelConfig = Field( default=None, title="Detector specific model configuration." ) - - class Config: - extra = Extra.allow - arbitrary_types_allowed = True + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)