Catch case where field is defined multiple times and add test

This commit is contained in:
Nick Mowen 2022-11-07 17:43:40 -07:00
parent c26623a73f
commit 6e02261ca9
3 changed files with 51 additions and 1 deletions

View File

@ -22,6 +22,7 @@ from frigate.util import (
create_mask,
deep_merge,
escape_special_characters,
load_config_with_no_duplicates,
load_labels,
)
@ -1011,7 +1012,7 @@ class FrigateConfig(FrigateBaseModel):
raw_config = f.read()
if config_file.endswith(YAML_EXT):
config = yaml.safe_load(raw_config)
config = load_config_with_no_duplicates(raw_config)
elif config_file.endswith(".json"):
config = json.loads(raw_config)

View File

@ -1,11 +1,13 @@
import unittest
import numpy as np
from pydantic import ValidationError
from frigate.config import (
BirdseyeModeEnum,
FrigateConfig,
DetectorTypeEnum,
)
from frigate.util import load_config_with_no_duplicates
class TestConfig(unittest.TestCase):
@ -1452,6 +1454,23 @@ class TestConfig(unittest.TestCase):
self.assertRaises(ValueError, lambda: frigate_config.runtime_config.cameras)
def test_fails_duplicate_keys(self):
raw_config = """
cameras:
test:
ffmpeg:
inputs:
- one
- two
inputs:
- three
- four
"""
self.assertRaises(
ValueError, lambda: load_config_with_no_duplicates(raw_config)
)
def test_object_filter_ratios_work(self):
config = {
"mqtt": {"host": "mqtt"},

View File

@ -5,7 +5,10 @@ import re
import signal
import traceback
import urllib.parse
import yaml
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Mapping
from multiprocessing import shared_memory
from typing import AnyStr
@ -44,6 +47,33 @@ def deep_merge(dct1: dict, dct2: dict, override=False, merge_lists=False) -> dic
return merged
def load_config_with_no_duplicates(raw_config) -> dict:
"""Get config ensuring duplicate keys are not allowed."""
# https://stackoverflow.com/a/71751051
class PreserveDuplicatesLoader(yaml.loader.Loader):
pass
def map_constructor(loader, node, deep=False):
keys = [loader.construct_object(node, deep=deep) for node, _ in node.value]
vals = [loader.construct_object(node, deep=deep) for _, node in node.value]
key_count = Counter(keys)
data = {}
for key, val in zip(keys, vals):
if key_count[key] > 1:
raise ValueError(
f"Config input {key} is defined multiple times for the same field, this is not allowed."
)
else:
data[key] = val
return data
PreserveDuplicatesLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, map_constructor
)
return yaml.load(raw_config, PreserveDuplicatesLoader)
def draw_timestamp(
frame,
timestamp,