Improve handling of classes starting with digits

This commit is contained in:
Nicolas Mowen 2026-01-30 14:39:45 -07:00
parent f66c4f53e0
commit d9b420929e
2 changed files with 5 additions and 4 deletions

View File

@ -97,7 +97,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(labelmap_path, prefill=0)
self.labelmap = load_labels(labelmap_path, prefill=0, indexed=False)
self.classifications_per_second.start()
def __update_metrics(self, duration: float) -> None:
@ -398,7 +398,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(labelmap_path, prefill=0)
self.labelmap = load_labels(labelmap_path, prefill=0, indexed=False)
def __update_metrics(self, duration: float) -> None:
self.classifications_per_second.update()

View File

@ -129,7 +129,7 @@ def get_ffmpeg_arg_list(arg: Any) -> list:
return arg if isinstance(arg, list) else shlex.split(arg)
def load_labels(path: Optional[str], encoding="utf-8", prefill=91):
def load_labels(path: Optional[str], encoding="utf-8", prefill=91, indexed: bool | None = None):
"""Loads labels from file (with or without index numbers).
Args:
path: path to label file.
@ -146,11 +146,12 @@ def load_labels(path: Optional[str], encoding="utf-8", prefill=91):
if not lines:
return {}
if lines[0].split(" ", maxsplit=1)[0].isdigit():
if indexed != False and lines[0].split(" ", maxsplit=1)[0].isdigit():
pairs = [line.split(" ", maxsplit=1) for line in lines]
labels.update({int(index): label.strip() for index, label in pairs})
else:
labels.update({index: line.strip() for index, line in enumerate(lines)})
return labels