tensor input dump and UI to classify

This commit is contained in:
Michael Wei 2020-11-28 09:27:19 +00:00
parent 893e6b40a7
commit 410be5adaa
4 changed files with 133 additions and 8 deletions

View File

@ -238,6 +238,15 @@ objects:
# Optional: minimum decimal percentage for tracked object's computed score to be considered a true positive (default: shown below) # Optional: minimum decimal percentage for tracked object's computed score to be considered a true positive (default: shown below)
threshold: 0.85 threshold: 0.85
# Configuraiton for saving data for training.
# Set to TRUE to save inputs to tensorflow. Note that every input is saved, so this will generate a lot of data.
saveTensorInputs: false
# Path to save tensor inputs to.
saveTensorPath: /clips/tensorInputs
# Path the categorizer UI will save categorized inputs to.
# The categorizer UI can be accessed at http://<frigate url>/tensorInputUi
saveTensorCategorizedPath: /clips/tensorCategorized
# Required: configuration section for cameras # Required: configuration section for cameras
cameras: cameras:
# Required: name of the camera # Required: name of the camera

View File

@ -15,8 +15,9 @@ import multiprocessing as mp
import subprocess as sp import subprocess as sp
import numpy as np import numpy as np
import logging import logging
from flask import Flask, Response, make_response, jsonify, request from flask import Flask, Response, make_response, jsonify, request, send_from_directory
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import shutil
from frigate.video import capture_camera, track_camera, get_ffmpeg_input, get_frame_shape, CameraCapture, start_or_restart_ffmpeg from frigate.video import capture_camera, track_camera, get_ffmpeg_input, get_frame_shape, CameraCapture, start_or_restart_ffmpeg
from frigate.object_processing import TrackedObjectProcessor from frigate.object_processing import TrackedObjectProcessor
@ -121,8 +122,9 @@ class FrigateWatchdog(threading.Thread):
camera_process['process_fps'].value = 0.0 camera_process['process_fps'].value = 0.0
camera_process['detection_fps'].value = 0.0 camera_process['detection_fps'].value = 0.0
camera_process['read_start'].value = 0.0 camera_process['read_start'].value = 0.0
process = mp.Process(target=track_camera, args=(name, self.config, process = mp.Process(target=track_camera, args=(name, self.config[name],
self.detection_queue, self.out_events[name], self.tracked_objects_queue, camera_process, self.stop_event)) self.detection_queue, self.out_events[name], self.tracked_objects_queue, camera_process,
self.stop_event, CONFIG))
process.daemon = True process.daemon = True
camera_process['process'] = process camera_process['process'] = process
process.start() process.start()
@ -265,7 +267,7 @@ def main():
camera_process_info[name]['capture_process'] = capture_process camera_process_info[name]['capture_process'] = capture_process
camera_process = mp.Process(target=track_camera, args=(name, config, camera_process = mp.Process(target=track_camera, args=(name, config,
detection_queue, out_events[name], tracked_objects_queue, camera_process_info[name], stop_event)) detection_queue, out_events[name], tracked_objects_queue, camera_process_info[name], stop_event, CONFIG))
camera_process.daemon = True camera_process.daemon = True
camera_process_info[name]['process'] = camera_process camera_process_info[name]['process'] = camera_process
@ -417,7 +419,7 @@ def main():
return response return response
else: else:
return "Camera named {} not found".format(camera_name), 404 return "Camera named {} not found".format(camera_name), 404
def imagestream(camera_name, fps, height): def imagestream(camera_name, fps, height):
while True: while True:
# max out at specified FPS # max out at specified FPS
@ -433,6 +435,110 @@ def main():
yield (b'--frame\r\n' yield (b'--frame\r\n'
b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n') b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')
@app.route('/tensorInput/<path:path>')
def send_tensor_image(path):
return send_from_directory(CONFIG["saveTensorPath"], path)
@app.route('/tensorInputClassify/<path:path>/<category>')
def classify_tensor_image(path, category):
if category == "delete":
os.unlink(os.path.join(CONFIG["saveTensorPath"], path))
return "deleted"
else:
category_path = os.path.join(CONFIG["saveTensorCategorizedPath"], category)
os.makedirs(category_path, exist_ok=True)
shutil.move(os.path.join(CONFIG["saveTensorPath"], path), os.path.join(category_path, path))
return "ok"
@app.route('/tensorInputNext')
def get_next_tensor_image():
file = None
for root, dirs, files in os.walk(CONFIG["saveTensorPath"]):
for name in files:
file = name
break
return "No more inputs" if file is None else file
@app.route('/tensorInputUi')
def get_tensor_ui():
return """<html>
<head>
<title>Simple classification UI</title>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/water.css@2/out/water.css">
</head>
<script type="text/javascript">
var current = ""
let categories = {}
const getNext = async () => {
let response = await fetch('./tensorInputNext')
current = await response.text()
if (current === "No more inputs") {
// probably make this nicer in the future
alert("All done, no more inputs");
}
document.getElementById("image").src = `./tensorInput/${current}`
}
const addCategory = (name) => {
let row = document.getElementById("categories").insertRow(-1);
let i = -1;
let letter = ""
while (letter in categories || i == -1) {
i++;
if (i < name.length) {
letter = name.toLowerCase().charAt(i);
} else {
//out of letters, pick some random key until it works
const alphabet = "abcdefghijklmnopqrstuvwxyz1234567890,./;'[]\-=`'"
letter = alphabet[Math.floor(Math.random() * alphabet.length)]
}
}
categories[letter] = name
row.insertCell().appendChild(document.createTextNode(letter));
row.insertCell().appendChild(document.createTextNode(name));
}
const classify = async (key) => {
await fetch(`./tensorInputClassify/${current}/${categories[key].toLowerCase()}`)
await getNext()
}
window.onload=()=> {
getNext();
// add default delete category
addCategory("Delete");
document.getElementById("add_category").onclick = () => {
addCategory(document.getElementById("category").value)
document.getElementById("category").value = ""
return false;
}
}
document.addEventListener('keypress', event => {
if (event.key in categories && document.getElementById("category") != document.activeElement) {
classify(event.key);
}
});
</script>
<img style="height:300px;width:300px;" id="image"/>
<hr>
<form>
<label for="category">New Category</label>
<input type="text" id="category">
<button id="add_category">Add Category</button>
</form>
<table id="categories">
<thead>
<tr> <th>Key</th> <th>Category</th></thead><tbody>
</tbody> </table>
</html>
"""
app.run(host='0.0.0.0', port=WEB_PORT, debug=False) app.run(host='0.0.0.0', port=WEB_PORT, debug=False)
object_processor.join() object_processor.join()

View File

@ -10,6 +10,7 @@ import numpy as np
import tflite_runtime.interpreter as tflite import tflite_runtime.interpreter as tflite
from tflite_runtime.interpreter import load_delegate from tflite_runtime.interpreter import load_delegate
from frigate.util import EventsPerSecond, listen, SharedMemoryFrameManager from frigate.util import EventsPerSecond, listen, SharedMemoryFrameManager
import cv2
def load_labels(path, encoding='utf-8'): def load_labels(path, encoding='utf-8'):
"""Loads labels from file (with or without index numbers). """Loads labels from file (with or without index numbers).
@ -159,7 +160,7 @@ class EdgeTPUProcess():
self.detect_process.start() self.detect_process.start()
class RemoteObjectDetector(): class RemoteObjectDetector():
def __init__(self, name, labels, detection_queue, event): def __init__(self, name, labels, detection_queue, event, config):
self.labels = load_labels(labels) self.labels = load_labels(labels)
self.name = name self.name = name
self.fps = EventsPerSecond() self.fps = EventsPerSecond()
@ -169,6 +170,7 @@ class RemoteObjectDetector():
self.np_shm = np.ndarray((1,300,300,3), dtype=np.uint8, buffer=self.shm.buf) self.np_shm = np.ndarray((1,300,300,3), dtype=np.uint8, buffer=self.shm.buf)
self.out_shm = mp.shared_memory.SharedMemory(name=f"out-{self.name}", create=False) self.out_shm = mp.shared_memory.SharedMemory(name=f"out-{self.name}", create=False)
self.out_np_shm = np.ndarray((20,6), dtype=np.float32, buffer=self.out_shm.buf) self.out_np_shm = np.ndarray((20,6), dtype=np.float32, buffer=self.out_shm.buf)
self.config = config
def detect(self, tensor_input, threshold=.4): def detect(self, tensor_input, threshold=.4):
detections = [] detections = []
@ -177,6 +179,14 @@ class RemoteObjectDetector():
self.np_shm[:] = tensor_input[:] self.np_shm[:] = tensor_input[:]
self.event.clear() self.event.clear()
self.detection_queue.put(self.name) self.detection_queue.put(self.name)
if self.config["saveTensorInputs"]:
root_path = self.config['saveTensorPath']
os.makedirs(root_path, exist_ok=True)
file_path = os.path.join(root_path, f"{self.name}.{datetime.datetime.now().timestamp()}.jpg")
cv2.imwrite(file_path, tensor_input[0])
result = self.event.wait(timeout=10.0) result = self.event.wait(timeout=10.0)
# if it timed out # if it timed out

View File

@ -241,7 +241,7 @@ def capture_camera(name, config, process_info, stop_event):
camera_watchdog.start() camera_watchdog.start()
camera_watchdog.join() camera_watchdog.join()
def track_camera(name, config, detection_queue, result_connection, detected_objects_queue, process_info, stop_event): def track_camera(name, config, detection_queue, result_connection, detected_objects_queue, process_info, stop_event, global_config):
listen() listen()
frame_queue = process_info['frame_queue'] frame_queue = process_info['frame_queue']
@ -275,7 +275,7 @@ def track_camera(name, config, detection_queue, result_connection, detected_obje
mask[:] = 255 mask[:] = 255
motion_detector = MotionDetector(frame_shape, mask, resize_factor=6) motion_detector = MotionDetector(frame_shape, mask, resize_factor=6)
object_detector = RemoteObjectDetector(name, '/labelmap.txt', detection_queue, result_connection) object_detector = RemoteObjectDetector(name, '/labelmap.txt', detection_queue, result_connection, global_config)
object_tracker = ObjectTracker(10) object_tracker = ObjectTracker(10)