Handle case when no classification model exists (#20257)
Some checks failed
CI / AMD64 Build (push) Has been cancelled
CI / ARM Build (push) Has been cancelled
CI / Jetson Jetpack 6 (push) Has been cancelled
CI / AMD64 Extra Build (push) Has been cancelled
CI / ARM Extra Build (push) Has been cancelled
CI / Synaptics Build (push) Has been cancelled
CI / Assemble and push default build (push) Has been cancelled

This commit is contained in:
Nicolas Mowen 2025-09-28 15:03:44 -06:00 committed by GitHub
parent 12f8c3feac
commit 9fdce80729
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -48,9 +48,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self.requestor = requestor
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
self.interpreter: Interpreter | None = None
self.tensor_input_details: dict[str, Any] | None = None
self.tensor_output_details: dict[str, Any] | None = None
self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond()
self.inference_speed = InferenceSpeed(
@ -61,17 +61,24 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
@redirect_output_to_logger(logger, logging.DEBUG)
def __build_detector(self) -> None:
model_path = os.path.join(self.model_dir, "model.tflite")
labelmap_path = os.path.join(self.model_dir, "labelmap.txt")
if not os.path.exists(model_path) or not os.path.exists(labelmap_path):
self.interpreter = None
self.tensor_input_details = None
self.tensor_output_details = None
self.labelmap = {}
return
self.interpreter = Interpreter(
model_path=os.path.join(self.model_dir, "model.tflite"),
model_path=model_path,
num_threads=2,
)
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(
os.path.join(self.model_dir, "labelmap.txt"),
prefill=0,
)
self.labelmap = load_labels(labelmap_path, prefill=0)
self.classifications_per_second.start()
def __update_metrics(self, duration: float) -> None:
@ -140,6 +147,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
logger.warning("Failed to resize image for state classification")
return
if self.interpreter is None:
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
"unknown",
0.0,
)
return
input = np.expand_dims(frame, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()
@ -197,10 +214,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.model_config = model_config
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.interpreter: Interpreter | None = None
self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None
self.tensor_output_details: dict[str, Any] = None
self.tensor_input_details: dict[str, Any] | None = None
self.tensor_output_details: dict[str, Any] | None = None
self.detected_objects: dict[str, float] = {}
self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond()
@ -211,17 +228,24 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
@redirect_output_to_logger(logger, logging.DEBUG)
def __build_detector(self) -> None:
model_path = os.path.join(self.model_dir, "model.tflite")
labelmap_path = os.path.join(self.model_dir, "labelmap.txt")
if not os.path.exists(model_path) or not os.path.exists(labelmap_path):
self.interpreter = None
self.tensor_input_details = None
self.tensor_output_details = None
self.labelmap = {}
return
self.interpreter = Interpreter(
model_path=os.path.join(self.model_dir, "model.tflite"),
model_path=model_path,
num_threads=2,
)
self.interpreter.allocate_tensors()
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(
os.path.join(self.model_dir, "labelmap.txt"),
prefill=0,
)
self.labelmap = load_labels(labelmap_path, prefill=0)
def __update_metrics(self, duration: float) -> None:
self.classifications_per_second.update()
@ -265,6 +289,16 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
logger.warning("Failed to resize image for state classification")
return
if self.interpreter is None:
write_classification_attempt(
self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
now,
"unknown",
0.0,
)
return
input = np.expand_dims(crop, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()