mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 09:04:28 +03:00
Cleanup classification
This commit is contained in:
parent
f87e82481d
commit
61870184df
@ -5,9 +5,10 @@ import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.const import FRIGATE_LOCALHOST, MODEL_CACHE_DIR
|
||||
from frigate.util.object import calculate_region
|
||||
|
||||
from ..types import DataProcessorMetrics
|
||||
@ -28,6 +29,7 @@ class BirdProcessor(RealTimeProcessorApi):
|
||||
self.tensor_input_details: dict[str, any] = None
|
||||
self.tensor_output_details: dict[str, any] = None
|
||||
self.detected_birds: dict[str, float] = {}
|
||||
self.labelmap: dict[int, str] = {}
|
||||
|
||||
download_path = os.path.join(MODEL_CACHE_DIR, "bird")
|
||||
self.model_files = {
|
||||
@ -73,6 +75,17 @@ class BirdProcessor(RealTimeProcessorApi):
|
||||
self.tensor_input_details = self.interpreter.get_input_details()
|
||||
self.tensor_output_details = self.interpreter.get_output_details()
|
||||
|
||||
i = 0
|
||||
|
||||
with open(os.path.join(MODEL_CACHE_DIR, "bird/birdmap.txt")) as f:
|
||||
line = f.readline()
|
||||
while line:
|
||||
start = line.find("(")
|
||||
end = line.find(")")
|
||||
self.labelmap[i] = line[start + 1 : end]
|
||||
i += 1
|
||||
line = f.readline()
|
||||
|
||||
def process_frame(self, obj_data, frame):
|
||||
if obj_data["label"] != "bird":
|
||||
return
|
||||
@ -84,7 +97,7 @@ class BirdProcessor(RealTimeProcessorApi):
|
||||
obj_data["box"][2],
|
||||
obj_data["box"][3],
|
||||
224,
|
||||
1.4,
|
||||
1.0,
|
||||
)
|
||||
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||
@ -99,11 +112,34 @@ class BirdProcessor(RealTimeProcessorApi):
|
||||
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||
self.interpreter.invoke()
|
||||
res: np.ndarray = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
|
||||
res: np.ndarray = self.interpreter.get_tensor(
|
||||
self.tensor_output_details[0]["index"]
|
||||
)[0]
|
||||
probs = res / res.sum(axis=0)
|
||||
best_id = np.argmax(probs)
|
||||
|
||||
if best_id == 964:
|
||||
logger.debug("No bird classification was detected.")
|
||||
return
|
||||
|
||||
score = round(probs[best_id], 2)
|
||||
logger.info(f"the best scoring index is {best_id} {score}%")
|
||||
previous_score = self.detected_birds.get(obj_data["id"], 0.0)
|
||||
|
||||
if score <= previous_score:
|
||||
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
||||
return
|
||||
|
||||
resp = requests.post(
|
||||
f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label",
|
||||
json={
|
||||
"camera": obj_data.get("camera"),
|
||||
"subLabel": self.labelmap[best_id],
|
||||
"subLabelScore": score,
|
||||
},
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
self.detected_birds[obj_data["id"]] = score
|
||||
|
||||
def handle_request(self, request_data):
|
||||
return None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user