Use config to process

This commit is contained in:
Nicolas Mowen 2025-05-23 06:05:43 -06:00
parent 46e72227f9
commit e7427b47fb
2 changed files with 36 additions and 8 deletions

View File

@ -51,13 +51,18 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi):
self.tensor_output_details = self.interpreter.get_output_details() self.tensor_output_details = self.interpreter.get_output_details()
self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0) self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0)
def process_frame(self, obj_data, frame): def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera")
if camera not in self.model_config.state_config.cameras:
return
camera_config = self.model_config.state_config.cameras[camera]
x, y, x2, y2 = calculate_region( x, y, x2, y2 = calculate_region(
frame.shape, frame.shape,
obj_data["box"][0], camera_config.crop[0],
obj_data["box"][1], camera_config.crop[1],
obj_data["box"][2], camera_config.crop[2],
obj_data["box"][3], camera_config.crop[3],
224, 224,
1.0, 1.0,
) )
@ -71,17 +76,20 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi):
if input.shape != (224, 224): if input.shape != (224, 224):
input = cv2.resize(input, (224, 224)) input = cv2.resize(input, (224, 224))
cv2.imwrite("/media/frigate/frames/gate.jpg", input)
input = np.expand_dims(input, axis=0) input = np.expand_dims(input, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke() self.interpreter.invoke()
res: np.ndarray = self.interpreter.get_tensor( res: np.ndarray = self.interpreter.get_tensor(
self.tensor_output_details[0]["index"] self.tensor_output_details[0]["index"]
)[0] )[0]
print(f"the gate res is {res}")
probs = res / res.sum(axis=0) probs = res / res.sum(axis=0)
best_id = np.argmax(probs) best_id = np.argmax(probs)
score = round(probs[best_id], 2) score = round(probs[best_id], 2)
print(f"got ID of {best_id} with score {score}") print(f"got {self.labelmap[best_id]} with score {score}")
def handle_request(self, topic, request_data): def handle_request(self, topic, request_data):
return None return None

View File

@ -46,6 +46,10 @@ from frigate.data_processing.real_time.face import FaceRealTimeProcessor
from frigate.data_processing.real_time.license_plate import ( from frigate.data_processing.real_time.license_plate import (
LicensePlateRealTimeProcessor, LicensePlateRealTimeProcessor,
) )
from frigate.data_processing.real_time.teachable_machine import (
TeachableMachineObjectProcessor,
TeachableMachineStateProcessor,
)
from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum
from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum
from frigate.genai import get_genai_client from frigate.genai import get_genai_client
@ -143,6 +147,18 @@ class EmbeddingMaintainer(threading.Thread):
) )
) )
for model in self.config.classification.teachable_machine.values():
self.realtime_processors.append(
TeachableMachineStateProcessor(self.config, model, self.metrics)
if model.state_config != None
else TeachableMachineObjectProcessor(
self.config,
model,
self.event_metadata_publisher,
self.metrics,
)
)
# post processors # post processors
self.post_processors: list[PostProcessorApi] = [] self.post_processors: list[PostProcessorApi] = []
@ -463,11 +479,10 @@ class EmbeddingMaintainer(threading.Thread):
camera_config = self.config.cameras[camera] camera_config = self.config.cameras[camera]
custom_classification_enabled = True
if ( if (
camera_config.type != CameraTypeEnum.lpr camera_config.type != CameraTypeEnum.lpr
or "license_plate" in camera_config.objects.track or "license_plate" in camera_config.objects.track
) and not custom_classification_enabled: ) and len(self.config.classification.teachable_machine) == 0:
# no active features that use this data # no active features that use this data
return return
@ -488,6 +503,11 @@ class EmbeddingMaintainer(threading.Thread):
if isinstance(processor, LicensePlateRealTimeProcessor): if isinstance(processor, LicensePlateRealTimeProcessor):
processor.process_frame(camera, yuv_frame, True) processor.process_frame(camera, yuv_frame, True)
if isinstance(processor, TeachableMachineObjectProcessor) or isinstance(
processor, TeachableMachineStateProcessor
):
processor.process_frame({"camera": camera}, yuv_frame)
self.frame_manager.close(frame_name) self.frame_manager.close(frame_name)
def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]: def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]: