Cleanup mypy for custom classification

This commit is contained in:
Nicolas Mowen 2026-03-26 11:13:47 -06:00
parent 4c72a210a9
commit 51397aeb1d

View File

@ -24,7 +24,8 @@ from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
from frigate.log import suppress_stderr_during from frigate.log import suppress_stderr_during
from frigate.types import TrackedObjectUpdateTypesEnum from frigate.types import TrackedObjectUpdateTypesEnum
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
from frigate.util.object import box_overlaps, calculate_region from frigate.util.image import calculate_region
from frigate.util.object import box_overlaps
from ..types import DataProcessorMetrics from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi from .api import RealTimeProcessorApi
@ -49,12 +50,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
): ):
super().__init__(config, metrics) super().__init__(config, metrics)
self.model_config = model_config self.model_config = model_config
if not self.model_config.name:
raise ValueError("Custom classification model name must be set.")
self.requestor = requestor self.requestor = requestor
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None self.interpreter: Interpreter | None = None
self.tensor_input_details: dict[str, Any] | None = None self.tensor_input_details: list[dict[str, Any]] | None = None
self.tensor_output_details: dict[str, Any] | None = None self.tensor_output_details: list[dict[str, Any]] | None = None
self.labelmap: dict[int, str] = {} self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond() self.classifications_per_second = EventsPerSecond()
self.state_history: dict[str, dict[str, Any]] = {} self.state_history: dict[str, dict[str, Any]] = {}
@ -63,7 +68,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self.metrics self.metrics
and self.model_config.name in self.metrics.classification_speeds and self.model_config.name in self.metrics.classification_speeds
): ):
self.inference_speed = InferenceSpeed( self.inference_speed: InferenceSpeed | None = InferenceSpeed(
self.metrics.classification_speeds[self.model_config.name] self.metrics.classification_speeds[self.model_config.name]
) )
else: else:
@ -172,12 +177,20 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
return None return None
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray) -> None:
if (
not self.model_config.name
or not self.model_config.state_config
or not self.tensor_input_details
or not self.tensor_output_details
):
return
if self.metrics and self.model_config.name in self.metrics.classification_cps: if self.metrics and self.model_config.name in self.metrics.classification_cps:
self.metrics.classification_cps[ self.metrics.classification_cps[
self.model_config.name self.model_config.name
].value = self.classifications_per_second.eps() ].value = self.classifications_per_second.eps()
camera = frame_data.get("camera") camera = str(frame_data.get("camera"))
if camera not in self.model_config.state_config.cameras: if camera not in self.model_config.state_config.cameras:
return return
@ -283,7 +296,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
logger.debug( logger.debug(
f"{self.model_config.name} Ran state classification with probabilities: {probs}" f"{self.model_config.name} Ran state classification with probabilities: {probs}"
) )
best_id = np.argmax(probs) best_id = int(np.argmax(probs))
score = round(probs[best_id], 2) score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now) self.__update_metrics(datetime.datetime.now().timestamp() - now)
@ -319,7 +332,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
verified_state, verified_state,
) )
def handle_request(self, topic, request_data): def handle_request(
self, topic: str, request_data: dict[str, Any]
) -> dict[str, Any] | None:
if topic == EmbeddingsRequestEnum.reload_classification_model.value: if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name: if request_data.get("model_name") == self.model_config.name:
self.__build_detector() self.__build_detector()
@ -335,7 +350,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
else: else:
return None return None
def expire_object(self, object_id, camera): def expire_object(self, object_id: str, camera: str) -> None:
pass pass
@ -350,13 +365,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
): ):
super().__init__(config, metrics) super().__init__(config, metrics)
self.model_config = model_config self.model_config = model_config
if not self.model_config.name:
raise ValueError("Custom classification model name must be set.")
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
self.interpreter: Interpreter = None self.interpreter: Interpreter | None = None
self.sub_label_publisher = sub_label_publisher self.sub_label_publisher = sub_label_publisher
self.requestor = requestor self.requestor = requestor
self.tensor_input_details: dict[str, Any] | None = None self.tensor_input_details: list[dict[str, Any]] | None = None
self.tensor_output_details: dict[str, Any] | None = None self.tensor_output_details: list[dict[str, Any]] | None = None
self.classification_history: dict[str, list[tuple[str, float, float]]] = {} self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
self.labelmap: dict[int, str] = {} self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond() self.classifications_per_second = EventsPerSecond()
@ -365,7 +384,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.metrics self.metrics
and self.model_config.name in self.metrics.classification_speeds and self.model_config.name in self.metrics.classification_speeds
): ):
self.inference_speed = InferenceSpeed( self.inference_speed: InferenceSpeed | None = InferenceSpeed(
self.metrics.classification_speeds[self.model_config.name] self.metrics.classification_speeds[self.model_config.name]
) )
else: else:
@ -431,8 +450,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
) )
return None, 0.0 return None, 0.0
label_counts = {} label_counts: dict[str, int] = {}
label_scores = {} label_scores: dict[str, list[float]] = {}
total_attempts = len(history) total_attempts = len(history)
for label, score, timestamp in history: for label, score, timestamp in history:
@ -443,7 +462,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
label_counts[label] += 1 label_counts[label] += 1
label_scores[label].append(score) label_scores[label].append(score)
best_label = max(label_counts, key=label_counts.get) best_label = max(label_counts, key=lambda k: label_counts[k])
best_count = label_counts[best_label] best_count = label_counts[best_label]
consensus_threshold = total_attempts * 0.6 consensus_threshold = total_attempts * 0.6
@ -470,7 +489,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
) )
return best_label, avg_score return best_label, avg_score
def process_frame(self, obj_data, frame): def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
if (
not self.model_config.name
or not self.model_config.object_config
or not self.tensor_input_details
or not self.tensor_output_details
):
return
if self.metrics and self.model_config.name in self.metrics.classification_cps: if self.metrics and self.model_config.name in self.metrics.classification_cps:
self.metrics.classification_cps[ self.metrics.classification_cps[
self.model_config.name self.model_config.name
@ -555,7 +582,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
logger.debug( logger.debug(
f"{self.model_config.name} Ran object classification with probabilities: {probs}" f"{self.model_config.name} Ran object classification with probabilities: {probs}"
) )
best_id = np.argmax(probs) best_id = int(np.argmax(probs))
score = round(probs[best_id], 2) score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now) self.__update_metrics(datetime.datetime.now().timestamp() - now)
@ -650,7 +677,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
), ),
) )
def handle_request(self, topic, request_data): def handle_request(self, topic: str, request_data: dict) -> dict | None:
if topic == EmbeddingsRequestEnum.reload_classification_model.value: if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name: if request_data.get("model_name") == self.model_config.name:
self.__build_detector() self.__build_detector()
@ -666,12 +693,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
else: else:
return None return None
def expire_object(self, object_id, camera): def expire_object(self, object_id: str, camera: str) -> None:
if object_id in self.classification_history: if object_id in self.classification_history:
self.classification_history.pop(object_id) self.classification_history.pop(object_id)
@staticmethod
def write_classification_attempt( def write_classification_attempt(
folder: str, folder: str,
frame: np.ndarray, frame: np.ndarray,