mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 14:47:40 +03:00
Cleanup mypy for custom classification
This commit is contained in:
parent
4c72a210a9
commit
51397aeb1d
@ -24,7 +24,8 @@ from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
|||||||
from frigate.log import suppress_stderr_during
|
from frigate.log import suppress_stderr_during
|
||||||
from frigate.types import TrackedObjectUpdateTypesEnum
|
from frigate.types import TrackedObjectUpdateTypesEnum
|
||||||
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
|
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
|
||||||
from frigate.util.object import box_overlaps, calculate_region
|
from frigate.util.image import calculate_region
|
||||||
|
from frigate.util.object import box_overlaps
|
||||||
|
|
||||||
from ..types import DataProcessorMetrics
|
from ..types import DataProcessorMetrics
|
||||||
from .api import RealTimeProcessorApi
|
from .api import RealTimeProcessorApi
|
||||||
@ -49,12 +50,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
):
|
):
|
||||||
super().__init__(config, metrics)
|
super().__init__(config, metrics)
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
|
|
||||||
|
if not self.model_config.name:
|
||||||
|
raise ValueError("Custom classification model name must be set.")
|
||||||
|
|
||||||
self.requestor = requestor
|
self.requestor = requestor
|
||||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||||
self.interpreter: Interpreter = None
|
self.interpreter: Interpreter | None = None
|
||||||
self.tensor_input_details: dict[str, Any] | None = None
|
self.tensor_input_details: list[dict[str, Any]] | None = None
|
||||||
self.tensor_output_details: dict[str, Any] | None = None
|
self.tensor_output_details: list[dict[str, Any]] | None = None
|
||||||
self.labelmap: dict[int, str] = {}
|
self.labelmap: dict[int, str] = {}
|
||||||
self.classifications_per_second = EventsPerSecond()
|
self.classifications_per_second = EventsPerSecond()
|
||||||
self.state_history: dict[str, dict[str, Any]] = {}
|
self.state_history: dict[str, dict[str, Any]] = {}
|
||||||
@ -63,7 +68,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.metrics
|
self.metrics
|
||||||
and self.model_config.name in self.metrics.classification_speeds
|
and self.model_config.name in self.metrics.classification_speeds
|
||||||
):
|
):
|
||||||
self.inference_speed = InferenceSpeed(
|
self.inference_speed: InferenceSpeed | None = InferenceSpeed(
|
||||||
self.metrics.classification_speeds[self.model_config.name]
|
self.metrics.classification_speeds[self.model_config.name]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -172,12 +177,20 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray) -> None:
|
||||||
|
if (
|
||||||
|
not self.model_config.name
|
||||||
|
or not self.model_config.state_config
|
||||||
|
or not self.tensor_input_details
|
||||||
|
or not self.tensor_output_details
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||||
self.metrics.classification_cps[
|
self.metrics.classification_cps[
|
||||||
self.model_config.name
|
self.model_config.name
|
||||||
].value = self.classifications_per_second.eps()
|
].value = self.classifications_per_second.eps()
|
||||||
camera = frame_data.get("camera")
|
camera = str(frame_data.get("camera"))
|
||||||
|
|
||||||
if camera not in self.model_config.state_config.cameras:
|
if camera not in self.model_config.state_config.cameras:
|
||||||
return
|
return
|
||||||
@ -283,7 +296,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"{self.model_config.name} Ran state classification with probabilities: {probs}"
|
f"{self.model_config.name} Ran state classification with probabilities: {probs}"
|
||||||
)
|
)
|
||||||
best_id = np.argmax(probs)
|
best_id = int(np.argmax(probs))
|
||||||
score = round(probs[best_id], 2)
|
score = round(probs[best_id], 2)
|
||||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||||
|
|
||||||
@ -319,7 +332,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
verified_state,
|
verified_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_request(self, topic, request_data):
|
def handle_request(
|
||||||
|
self, topic: str, request_data: dict[str, Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||||
if request_data.get("model_name") == self.model_config.name:
|
if request_data.get("model_name") == self.model_config.name:
|
||||||
self.__build_detector()
|
self.__build_detector()
|
||||||
@ -335,7 +350,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def expire_object(self, object_id, camera):
|
def expire_object(self, object_id: str, camera: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -350,13 +365,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
):
|
):
|
||||||
super().__init__(config, metrics)
|
super().__init__(config, metrics)
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
|
|
||||||
|
if not self.model_config.name:
|
||||||
|
raise ValueError("Custom classification model name must be set.")
|
||||||
|
|
||||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||||
self.interpreter: Interpreter = None
|
self.interpreter: Interpreter | None = None
|
||||||
self.sub_label_publisher = sub_label_publisher
|
self.sub_label_publisher = sub_label_publisher
|
||||||
self.requestor = requestor
|
self.requestor = requestor
|
||||||
self.tensor_input_details: dict[str, Any] | None = None
|
self.tensor_input_details: list[dict[str, Any]] | None = None
|
||||||
self.tensor_output_details: dict[str, Any] | None = None
|
self.tensor_output_details: list[dict[str, Any]] | None = None
|
||||||
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
||||||
self.labelmap: dict[int, str] = {}
|
self.labelmap: dict[int, str] = {}
|
||||||
self.classifications_per_second = EventsPerSecond()
|
self.classifications_per_second = EventsPerSecond()
|
||||||
@ -365,7 +384,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.metrics
|
self.metrics
|
||||||
and self.model_config.name in self.metrics.classification_speeds
|
and self.model_config.name in self.metrics.classification_speeds
|
||||||
):
|
):
|
||||||
self.inference_speed = InferenceSpeed(
|
self.inference_speed: InferenceSpeed | None = InferenceSpeed(
|
||||||
self.metrics.classification_speeds[self.model_config.name]
|
self.metrics.classification_speeds[self.model_config.name]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -431,8 +450,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
)
|
)
|
||||||
return None, 0.0
|
return None, 0.0
|
||||||
|
|
||||||
label_counts = {}
|
label_counts: dict[str, int] = {}
|
||||||
label_scores = {}
|
label_scores: dict[str, list[float]] = {}
|
||||||
total_attempts = len(history)
|
total_attempts = len(history)
|
||||||
|
|
||||||
for label, score, timestamp in history:
|
for label, score, timestamp in history:
|
||||||
@ -443,7 +462,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
label_counts[label] += 1
|
label_counts[label] += 1
|
||||||
label_scores[label].append(score)
|
label_scores[label].append(score)
|
||||||
|
|
||||||
best_label = max(label_counts, key=label_counts.get)
|
best_label = max(label_counts, key=lambda k: label_counts[k])
|
||||||
best_count = label_counts[best_label]
|
best_count = label_counts[best_label]
|
||||||
|
|
||||||
consensus_threshold = total_attempts * 0.6
|
consensus_threshold = total_attempts * 0.6
|
||||||
@ -470,7 +489,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
)
|
)
|
||||||
return best_label, avg_score
|
return best_label, avg_score
|
||||||
|
|
||||||
def process_frame(self, obj_data, frame):
|
def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
|
||||||
|
if (
|
||||||
|
not self.model_config.name
|
||||||
|
or not self.model_config.object_config
|
||||||
|
or not self.tensor_input_details
|
||||||
|
or not self.tensor_output_details
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||||
self.metrics.classification_cps[
|
self.metrics.classification_cps[
|
||||||
self.model_config.name
|
self.model_config.name
|
||||||
@ -555,7 +582,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"{self.model_config.name} Ran object classification with probabilities: {probs}"
|
f"{self.model_config.name} Ran object classification with probabilities: {probs}"
|
||||||
)
|
)
|
||||||
best_id = np.argmax(probs)
|
best_id = int(np.argmax(probs))
|
||||||
score = round(probs[best_id], 2)
|
score = round(probs[best_id], 2)
|
||||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||||
|
|
||||||
@ -650,7 +677,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_request(self, topic, request_data):
|
def handle_request(self, topic: str, request_data: dict) -> dict | None:
|
||||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||||
if request_data.get("model_name") == self.model_config.name:
|
if request_data.get("model_name") == self.model_config.name:
|
||||||
self.__build_detector()
|
self.__build_detector()
|
||||||
@ -666,12 +693,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def expire_object(self, object_id, camera):
|
def expire_object(self, object_id: str, camera: str) -> None:
|
||||||
if object_id in self.classification_history:
|
if object_id in self.classification_history:
|
||||||
self.classification_history.pop(object_id)
|
self.classification_history.pop(object_id)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def write_classification_attempt(
|
def write_classification_attempt(
|
||||||
folder: str,
|
folder: str,
|
||||||
frame: np.ndarray,
|
frame: np.ndarray,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user