From 61870184df821d43c1299e9b3b42e74e623cd938 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 13 Jan 2025 07:24:59 -0700 Subject: [PATCH] Cleanup classification --- .../real_time/bird_processor.py | 44 +++++++++++++++++-- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index cf2c5bfea..aa9b11984 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -5,9 +5,10 @@ import os import cv2 import numpy as np +import requests from frigate.config import FrigateConfig -from frigate.const import MODEL_CACHE_DIR +from frigate.const import FRIGATE_LOCALHOST, MODEL_CACHE_DIR from frigate.util.object import calculate_region from ..types import DataProcessorMetrics @@ -28,6 +29,7 @@ class BirdProcessor(RealTimeProcessorApi): self.tensor_input_details: dict[str, any] = None self.tensor_output_details: dict[str, any] = None self.detected_birds: dict[str, float] = {} + self.labelmap: dict[int, str] = {} download_path = os.path.join(MODEL_CACHE_DIR, "bird") self.model_files = { @@ -73,6 +75,17 @@ class BirdProcessor(RealTimeProcessorApi): self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() + i = 0 + + with open(os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")) as f: + line = f.readline() + while line: + start = line.find("(") + end = line.find(")") + self.labelmap[i] = line[start + 1 : end] + i += 1 + line = f.readline() + def process_frame(self, obj_data, frame): if obj_data["label"] != "bird": return @@ -84,7 +97,7 @@ class BirdProcessor(RealTimeProcessorApi): obj_data["box"][2], obj_data["box"][3], 224, - 1.4, + 1.0, ) rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) @@ -99,11 +112,34 @@ class BirdProcessor(RealTimeProcessorApi): self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.invoke() - res: np.ndarray = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0] + res: np.ndarray = self.interpreter.get_tensor( + self.tensor_output_details[0]["index"] + )[0] probs = res / res.sum(axis=0) best_id = np.argmax(probs) + + if best_id == 964: + logger.debug("No bird classification was detected.") + return + score = round(probs[best_id], 2) - logger.info(f"the best scoring index is {best_id} {score}%") + previous_score = self.detected_birds.get(obj_data["id"], 0.0) + + if score <= previous_score: + logger.debug(f"Score {score} is worse than previous score {previous_score}") + return + + resp = requests.post( + f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label", + json={ + "camera": obj_data.get("camera"), + "subLabel": self.labelmap[best_id], + "subLabelScore": score, + }, + ) + + if resp.status_code == 200: + self.detected_birds[obj_data["id"]] = score def handle_request(self, request_data): return None