From e7427b47fb61a2f420108669fa8472f0a50cddf7 Mon Sep 17 00:00:00 2001
From: Nicolas Mowen
Date: Fri, 23 May 2025 06:05:43 -0600
Subject: [PATCH] Use config to process

---
 .../real_time/teachable_machine.py            | 20 +++++++++++-----
 frigate/embeddings/maintainer.py              | 24 +++++++++++++++++--
 2 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/frigate/data_processing/real_time/teachable_machine.py b/frigate/data_processing/real_time/teachable_machine.py
index 6a8aff330..bf5b2d30f 100644
--- a/frigate/data_processing/real_time/teachable_machine.py
+++ b/frigate/data_processing/real_time/teachable_machine.py
@@ -51,13 +51,18 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi):
         self.tensor_output_details = self.interpreter.get_output_details()
         self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0)
 
-    def process_frame(self, obj_data, frame):
+    def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
+        camera = frame_data.get("camera")
+        if camera not in self.model_config.state_config.cameras:
+            return
+
+        camera_config = self.model_config.state_config.cameras[camera]
         x, y, x2, y2 = calculate_region(
             frame.shape,
-            obj_data["box"][0],
-            obj_data["box"][1],
-            obj_data["box"][2],
-            obj_data["box"][3],
+            camera_config.crop[0],
+            camera_config.crop[1],
+            camera_config.crop[2],
+            camera_config.crop[3],
             224,
             1.0,
         )
@@ -71,17 +76,20 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi):
         if input.shape != (224, 224):
             input = cv2.resize(input, (224, 224))
 
+        cv2.imwrite("/media/frigate/frames/gate.jpg", input)
+
         input = np.expand_dims(input, axis=0)
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
         self.interpreter.invoke()
         res: np.ndarray = self.interpreter.get_tensor(
             self.tensor_output_details[0]["index"]
         )[0]
+        print(f"the gate res is {res}")
         probs = res / res.sum(axis=0)
 
         best_id = np.argmax(probs)
         score = round(probs[best_id], 2)
-        print(f"got ID of {best_id} with score {score}")
+        print(f"got {self.labelmap[best_id]} with score {score}")
 
     def handle_request(self, topic, request_data):
         return None
diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py
index d52ad46fd..8a745117a 100644
--- a/frigate/embeddings/maintainer.py
+++ b/frigate/embeddings/maintainer.py
@@ -46,6 +46,10 @@ from frigate.data_processing.real_time.face import FaceRealTimeProcessor
 from frigate.data_processing.real_time.license_plate import (
     LicensePlateRealTimeProcessor,
 )
+from frigate.data_processing.real_time.teachable_machine import (
+    TeachableMachineObjectProcessor,
+    TeachableMachineStateProcessor,
+)
 from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum
 from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum
 from frigate.genai import get_genai_client
@@ -143,6 +147,18 @@ class EmbeddingMaintainer(threading.Thread):
                 )
             )
 
+        for model in self.config.classification.teachable_machine.values():
+            self.realtime_processors.append(
+                TeachableMachineStateProcessor(self.config, model, self.metrics)
+                if model.state_config != None
+                else TeachableMachineObjectProcessor(
+                    self.config,
+                    model,
+                    self.event_metadata_publisher,
+                    self.metrics,
+                )
+            )
+
         # post processors
         self.post_processors: list[PostProcessorApi] = []
 
@@ -463,11 +479,10 @@ class EmbeddingMaintainer(threading.Thread):
 
         camera_config = self.config.cameras[camera]
 
-        custom_classification_enabled = True
         if (
             camera_config.type != CameraTypeEnum.lpr
             or "license_plate" in camera_config.objects.track
-        ) and not custom_classification_enabled:
+        ) and len(self.config.classification.teachable_machine) == 0:
             # no active features that use this data
             return
 
@@ -488,6 +503,11 @@ class EmbeddingMaintainer(threading.Thread):
             if isinstance(processor, LicensePlateRealTimeProcessor):
                 processor.process_frame(camera, yuv_frame, True)
 
+            if isinstance(processor, TeachableMachineObjectProcessor) or isinstance(
+                processor, TeachableMachineStateProcessor
+            ):
+                processor.process_frame({"camera": camera}, yuv_frame)
+
         self.frame_manager.close(frame_name)
 
     def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]: