Cleanup mypy for custom classification

This commit is contained in:
Nicolas Mowen 2026-03-26 11:13:47 -06:00
parent 4c72a210a9
commit 51397aeb1d

View File

@ -24,7 +24,8 @@ from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
from frigate.log import suppress_stderr_during
from frigate.types import TrackedObjectUpdateTypesEnum
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
from frigate.util.object import box_overlaps, calculate_region
from frigate.util.image import calculate_region
from frigate.util.object import box_overlaps
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
@ -49,12 +50,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
):
super().__init__(config, metrics)
self.model_config = model_config
if not self.model_config.name:
raise ValueError("Custom classification model name must be set.")
self.requestor = requestor
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.tensor_input_details: dict[str, Any] | None = None
self.tensor_output_details: dict[str, Any] | None = None
self.interpreter: Interpreter | None = None
self.tensor_input_details: list[dict[str, Any]] | None = None
self.tensor_output_details: list[dict[str, Any]] | None = None
self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond()
self.state_history: dict[str, dict[str, Any]] = {}
@ -63,7 +68,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self.metrics
and self.model_config.name in self.metrics.classification_speeds
):
self.inference_speed = InferenceSpeed(
self.inference_speed: InferenceSpeed | None = InferenceSpeed(
self.metrics.classification_speeds[self.model_config.name]
)
else:
@ -172,12 +177,20 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
return None
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray) -> None:
if (
not self.model_config.name
or not self.model_config.state_config
or not self.tensor_input_details
or not self.tensor_output_details
):
return
if self.metrics and self.model_config.name in self.metrics.classification_cps:
self.metrics.classification_cps[
self.model_config.name
].value = self.classifications_per_second.eps()
camera = frame_data.get("camera")
camera = str(frame_data.get("camera"))
if camera not in self.model_config.state_config.cameras:
return
@ -283,7 +296,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
logger.debug(
f"{self.model_config.name} Ran state classification with probabilities: {probs}"
)
best_id = np.argmax(probs)
best_id = int(np.argmax(probs))
score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now)
@ -319,7 +332,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
verified_state,
)
def handle_request(self, topic, request_data):
def handle_request(
self, topic: str, request_data: dict[str, Any]
) -> dict[str, Any] | None:
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name:
self.__build_detector()
@ -335,7 +350,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
else:
return None
def expire_object(self, object_id, camera):
def expire_object(self, object_id: str, camera: str) -> None:
pass
@ -350,13 +365,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
):
super().__init__(config, metrics)
self.model_config = model_config
if not self.model_config.name:
raise ValueError("Custom classification model name must be set.")
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None
self.interpreter: Interpreter | None = None
self.sub_label_publisher = sub_label_publisher
self.requestor = requestor
self.tensor_input_details: dict[str, Any] | None = None
self.tensor_output_details: dict[str, Any] | None = None
self.tensor_input_details: list[dict[str, Any]] | None = None
self.tensor_output_details: list[dict[str, Any]] | None = None
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond()
@ -365,7 +384,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.metrics
and self.model_config.name in self.metrics.classification_speeds
):
self.inference_speed = InferenceSpeed(
self.inference_speed: InferenceSpeed | None = InferenceSpeed(
self.metrics.classification_speeds[self.model_config.name]
)
else:
@ -431,8 +450,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)
return None, 0.0
label_counts = {}
label_scores = {}
label_counts: dict[str, int] = {}
label_scores: dict[str, list[float]] = {}
total_attempts = len(history)
for label, score, timestamp in history:
@ -443,7 +462,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
label_counts[label] += 1
label_scores[label].append(score)
best_label = max(label_counts, key=label_counts.get)
best_label = max(label_counts, key=lambda k: label_counts[k])
best_count = label_counts[best_label]
consensus_threshold = total_attempts * 0.6
@ -470,7 +489,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)
return best_label, avg_score
def process_frame(self, obj_data, frame):
def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
if (
not self.model_config.name
or not self.model_config.object_config
or not self.tensor_input_details
or not self.tensor_output_details
):
return
if self.metrics and self.model_config.name in self.metrics.classification_cps:
self.metrics.classification_cps[
self.model_config.name
@ -555,7 +582,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
logger.debug(
f"{self.model_config.name} Ran object classification with probabilities: {probs}"
)
best_id = np.argmax(probs)
best_id = int(np.argmax(probs))
score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now)
@ -650,7 +677,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
),
)
def handle_request(self, topic, request_data):
def handle_request(self, topic: str, request_data: dict) -> dict | None:
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name:
self.__build_detector()
@ -666,12 +693,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
else:
return None
def expire_object(self, object_id, camera):
def expire_object(self, object_id: str, camera: str) -> None:
if object_id in self.classification_history:
self.classification_history.pop(object_id)
@staticmethod
def write_classification_attempt(
folder: str,
frame: np.ndarray,