Catch case where field is defined multiple times and add test

Nick Mowen 2022-11-07 17:43:40 -07:00
parent c26623a73f
commit 6e02261ca9
3 changed files with 51 additions and 1 deletion
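For context, PyYAML's default loader keeps only the last value when a key appears twice in a mapping, so a field defined multiple times was previously dropped silently. A minimal before/after sketch (assuming PyYAML is installed; frigate.util only gains the new helper with this commit):

import yaml

raw = """
ffmpeg:
  inputs:
    - one
  inputs:
    - two
"""

# Default behaviour: the first "inputs" list is silently discarded.
print(yaml.safe_load(raw))  # {'ffmpeg': {'inputs': ['two']}}

# After this commit the same config is rejected instead:
# from frigate.util import load_config_with_no_duplicates
# load_config_with_no_duplicates(raw)  # raises ValueError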

View File: frigate/config.py

@@ -22,6 +22,7 @@ from frigate.util import (
     create_mask,
     deep_merge,
     escape_special_characters,
+    load_config_with_no_duplicates,
     load_labels,
 )
@@ -1011,7 +1012,7 @@ class FrigateConfig(FrigateBaseModel):
             raw_config = f.read()
         if config_file.endswith(YAML_EXT):
-            config = yaml.safe_load(raw_config)
+            config = load_config_with_no_duplicates(raw_config)
         elif config_file.endswith(".json"):
             config = json.loads(raw_config)
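The .json branch is left as-is; json.loads likewise keeps only the last value for a duplicated key. Purely as an illustration, not part of this commit, a comparable guard for JSON configs could hook pair decoding; _reject_duplicate_keys below is a hypothetical helper:

import json
from collections import Counter


def _reject_duplicate_keys(pairs):
    # json.loads passes the (key, value) pairs of each decoded object to this hook.
    duplicates = [key for key, count in Counter(key for key, _ in pairs).items() if count > 1]
    if duplicates:
        raise ValueError(f"Config keys defined multiple times: {duplicates}")
    return dict(pairs)


json.loads('{"host": "a", "host": "b"}', object_pairs_hook=_reject_duplicate_keys)  # raises ValueError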

View File: frigate/test/test_config.py

@@ -1,11 +1,13 @@
 import unittest
 import numpy as np
 from pydantic import ValidationError
 from frigate.config import (
     BirdseyeModeEnum,
     FrigateConfig,
     DetectorTypeEnum,
 )
+from frigate.util import load_config_with_no_duplicates


 class TestConfig(unittest.TestCase):
@@ -1452,6 +1454,23 @@ class TestConfig(unittest.TestCase):
         self.assertRaises(ValueError, lambda: frigate_config.runtime_config.cameras)

+    def test_fails_duplicate_keys(self):
+        raw_config = """
+        cameras:
+          test:
+            ffmpeg:
+              inputs:
+                - one
+                - two
+              inputs:
+                - three
+                - four
+        """
+
+        self.assertRaises(
+            ValueError, lambda: load_config_with_no_duplicates(raw_config)
+        )
+
     def test_object_filter_ratios_work(self):
         config = {
             "mqtt": {"host": "mqtt"},

View File: frigate/util.py

@@ -5,7 +5,10 @@ import re
 import signal
 import traceback
 import urllib.parse
+import yaml
 from abc import ABC, abstractmethod
+from collections import Counter
 from collections.abc import Mapping
 from multiprocessing import shared_memory
 from typing import AnyStr
@@ -44,6 +47,33 @@ def deep_merge(dct1: dict, dct2: dict, override=False, merge_lists=False) -> dict:
     return merged


+def load_config_with_no_duplicates(raw_config) -> dict:
+    """Get config ensuring duplicate keys are not allowed."""
+
+    # https://stackoverflow.com/a/71751051
+    class PreserveDuplicatesLoader(yaml.loader.Loader):
+        pass
+
+    def map_constructor(loader, node, deep=False):
+        keys = [loader.construct_object(node, deep=deep) for node, _ in node.value]
+        vals = [loader.construct_object(node, deep=deep) for _, node in node.value]
+        key_count = Counter(keys)
+        data = {}
+        for key, val in zip(keys, vals):
+            if key_count[key] > 1:
+                raise ValueError(
+                    f"Config input {key} is defined multiple times for the same field, this is not allowed."
+                )
+            else:
+                data[key] = val
+
+        return data
+
+    PreserveDuplicatesLoader.add_constructor(
+        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, map_constructor
+    )
+
+    return yaml.load(raw_config, PreserveDuplicatesLoader)
+
+
 def draw_timestamp(
     frame,
     timestamp,
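Since the constructor is registered for the default mapping tag, it runs for every mapping node in the document, which is how the nested duplicate in the test above is caught. A short usage sketch, assuming the helper is imported from frigate.util as added here:

from frigate.util import load_config_with_no_duplicates

good = """
mqtt:
  host: mqtt
"""
# Valid configs come back as an ordinary dict.
print(load_config_with_no_duplicates(good))  # {'mqtt': {'host': 'mqtt'}}

bad = """
mqtt:
  host: a
  host: b
"""
# A repeated key anywhere in the document raises instead of silently overriding.
load_config_with_no_duplicates(bad)  # ValueError: Config input host is defined multiple times ...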