mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-22 08:08:22 +03:00
Improve handling of classes starting with digits
This commit is contained in:
parent
f66c4f53e0
commit
d9b420929e
@ -97,7 +97,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.interpreter.allocate_tensors()
|
self.interpreter.allocate_tensors()
|
||||||
self.tensor_input_details = self.interpreter.get_input_details()
|
self.tensor_input_details = self.interpreter.get_input_details()
|
||||||
self.tensor_output_details = self.interpreter.get_output_details()
|
self.tensor_output_details = self.interpreter.get_output_details()
|
||||||
self.labelmap = load_labels(labelmap_path, prefill=0)
|
self.labelmap = load_labels(labelmap_path, prefill=0, indexed=False)
|
||||||
self.classifications_per_second.start()
|
self.classifications_per_second.start()
|
||||||
|
|
||||||
def __update_metrics(self, duration: float) -> None:
|
def __update_metrics(self, duration: float) -> None:
|
||||||
@ -398,7 +398,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.interpreter.allocate_tensors()
|
self.interpreter.allocate_tensors()
|
||||||
self.tensor_input_details = self.interpreter.get_input_details()
|
self.tensor_input_details = self.interpreter.get_input_details()
|
||||||
self.tensor_output_details = self.interpreter.get_output_details()
|
self.tensor_output_details = self.interpreter.get_output_details()
|
||||||
self.labelmap = load_labels(labelmap_path, prefill=0)
|
self.labelmap = load_labels(labelmap_path, prefill=0, indexed=False)
|
||||||
|
|
||||||
def __update_metrics(self, duration: float) -> None:
|
def __update_metrics(self, duration: float) -> None:
|
||||||
self.classifications_per_second.update()
|
self.classifications_per_second.update()
|
||||||
|
|||||||
@ -129,7 +129,7 @@ def get_ffmpeg_arg_list(arg: Any) -> list:
|
|||||||
return arg if isinstance(arg, list) else shlex.split(arg)
|
return arg if isinstance(arg, list) else shlex.split(arg)
|
||||||
|
|
||||||
|
|
||||||
def load_labels(path: Optional[str], encoding="utf-8", prefill=91):
|
def load_labels(path: Optional[str], encoding="utf-8", prefill=91, indexed: bool | None = None):
|
||||||
"""Loads labels from file (with or without index numbers).
|
"""Loads labels from file (with or without index numbers).
|
||||||
Args:
|
Args:
|
||||||
path: path to label file.
|
path: path to label file.
|
||||||
@ -146,11 +146,12 @@ def load_labels(path: Optional[str], encoding="utf-8", prefill=91):
|
|||||||
if not lines:
|
if not lines:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
if lines[0].split(" ", maxsplit=1)[0].isdigit():
|
if indexed != False and lines[0].split(" ", maxsplit=1)[0].isdigit():
|
||||||
pairs = [line.split(" ", maxsplit=1) for line in lines]
|
pairs = [line.split(" ", maxsplit=1) for line in lines]
|
||||||
labels.update({int(index): label.strip() for index, label in pairs})
|
labels.update({int(index): label.strip() for index, label in pairs})
|
||||||
else:
|
else:
|
||||||
labels.update({index: line.strip() for index, line in enumerate(lines)})
|
labels.update({index: line.strip() for index, line in enumerate(lines)})
|
||||||
|
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user