mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 14:47:40 +03:00
Cleanup mypy for custom classification
This commit is contained in:
parent
4c72a210a9
commit
51397aeb1d
@ -24,7 +24,8 @@ from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||
from frigate.log import suppress_stderr_during
|
||||
from frigate.types import TrackedObjectUpdateTypesEnum
|
||||
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
|
||||
from frigate.util.object import box_overlaps, calculate_region
|
||||
from frigate.util.image import calculate_region
|
||||
from frigate.util.object import box_overlaps
|
||||
|
||||
from ..types import DataProcessorMetrics
|
||||
from .api import RealTimeProcessorApi
|
||||
@ -49,12 +50,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
):
|
||||
super().__init__(config, metrics)
|
||||
self.model_config = model_config
|
||||
|
||||
if not self.model_config.name:
|
||||
raise ValueError("Custom classification model name must be set.")
|
||||
|
||||
self.requestor = requestor
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Interpreter = None
|
||||
self.tensor_input_details: dict[str, Any] | None = None
|
||||
self.tensor_output_details: dict[str, Any] | None = None
|
||||
self.interpreter: Interpreter | None = None
|
||||
self.tensor_input_details: list[dict[str, Any]] | None = None
|
||||
self.tensor_output_details: list[dict[str, Any]] | None = None
|
||||
self.labelmap: dict[int, str] = {}
|
||||
self.classifications_per_second = EventsPerSecond()
|
||||
self.state_history: dict[str, dict[str, Any]] = {}
|
||||
@ -63,7 +68,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
self.metrics
|
||||
and self.model_config.name in self.metrics.classification_speeds
|
||||
):
|
||||
self.inference_speed = InferenceSpeed(
|
||||
self.inference_speed: InferenceSpeed | None = InferenceSpeed(
|
||||
self.metrics.classification_speeds[self.model_config.name]
|
||||
)
|
||||
else:
|
||||
@ -172,12 +177,20 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
|
||||
return None
|
||||
|
||||
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
||||
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray) -> None:
|
||||
if (
|
||||
not self.model_config.name
|
||||
or not self.model_config.state_config
|
||||
or not self.tensor_input_details
|
||||
or not self.tensor_output_details
|
||||
):
|
||||
return
|
||||
|
||||
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||
self.metrics.classification_cps[
|
||||
self.model_config.name
|
||||
].value = self.classifications_per_second.eps()
|
||||
camera = frame_data.get("camera")
|
||||
camera = str(frame_data.get("camera"))
|
||||
|
||||
if camera not in self.model_config.state_config.cameras:
|
||||
return
|
||||
@ -283,7 +296,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
logger.debug(
|
||||
f"{self.model_config.name} Ran state classification with probabilities: {probs}"
|
||||
)
|
||||
best_id = np.argmax(probs)
|
||||
best_id = int(np.argmax(probs))
|
||||
score = round(probs[best_id], 2)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||
|
||||
@ -319,7 +332,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
verified_state,
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
def handle_request(
|
||||
self, topic: str, request_data: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.__build_detector()
|
||||
@ -335,7 +350,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
else:
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id, camera):
|
||||
def expire_object(self, object_id: str, camera: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -350,13 +365,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
):
|
||||
super().__init__(config, metrics)
|
||||
self.model_config = model_config
|
||||
|
||||
if not self.model_config.name:
|
||||
raise ValueError("Custom classification model name must be set.")
|
||||
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Interpreter = None
|
||||
self.interpreter: Interpreter | None = None
|
||||
self.sub_label_publisher = sub_label_publisher
|
||||
self.requestor = requestor
|
||||
self.tensor_input_details: dict[str, Any] | None = None
|
||||
self.tensor_output_details: dict[str, Any] | None = None
|
||||
self.tensor_input_details: list[dict[str, Any]] | None = None
|
||||
self.tensor_output_details: list[dict[str, Any]] | None = None
|
||||
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
||||
self.labelmap: dict[int, str] = {}
|
||||
self.classifications_per_second = EventsPerSecond()
|
||||
@ -365,7 +384,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.metrics
|
||||
and self.model_config.name in self.metrics.classification_speeds
|
||||
):
|
||||
self.inference_speed = InferenceSpeed(
|
||||
self.inference_speed: InferenceSpeed | None = InferenceSpeed(
|
||||
self.metrics.classification_speeds[self.model_config.name]
|
||||
)
|
||||
else:
|
||||
@ -431,8 +450,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
label_counts = {}
|
||||
label_scores = {}
|
||||
label_counts: dict[str, int] = {}
|
||||
label_scores: dict[str, list[float]] = {}
|
||||
total_attempts = len(history)
|
||||
|
||||
for label, score, timestamp in history:
|
||||
@ -443,7 +462,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
label_counts[label] += 1
|
||||
label_scores[label].append(score)
|
||||
|
||||
best_label = max(label_counts, key=label_counts.get)
|
||||
best_label = max(label_counts, key=lambda k: label_counts[k])
|
||||
best_count = label_counts[best_label]
|
||||
|
||||
consensus_threshold = total_attempts * 0.6
|
||||
@ -470,7 +489,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
return best_label, avg_score
|
||||
|
||||
def process_frame(self, obj_data, frame):
|
||||
def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
|
||||
if (
|
||||
not self.model_config.name
|
||||
or not self.model_config.object_config
|
||||
or not self.tensor_input_details
|
||||
or not self.tensor_output_details
|
||||
):
|
||||
return
|
||||
|
||||
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||
self.metrics.classification_cps[
|
||||
self.model_config.name
|
||||
@ -555,7 +582,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
logger.debug(
|
||||
f"{self.model_config.name} Ran object classification with probabilities: {probs}"
|
||||
)
|
||||
best_id = np.argmax(probs)
|
||||
best_id = int(np.argmax(probs))
|
||||
score = round(probs[best_id], 2)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||
|
||||
@ -650,7 +677,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
),
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
def handle_request(self, topic: str, request_data: dict) -> dict | None:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.__build_detector()
|
||||
@ -666,12 +693,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
else:
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id, camera):
|
||||
def expire_object(self, object_id: str, camera: str) -> None:
|
||||
if object_id in self.classification_history:
|
||||
self.classification_history.pop(object_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def write_classification_attempt(
|
||||
folder: str,
|
||||
frame: np.ndarray,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user