mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-11 17:47:37 +03:00
Use weighted scoring for object classification
This commit is contained in:
parent
7e34097142
commit
9970e429a1
@ -230,7 +230,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.sub_label_publisher = sub_label_publisher
|
||||
self.tensor_input_details: dict[str, Any] | None = None
|
||||
self.tensor_output_details: dict[str, Any] | None = None
|
||||
self.detected_objects: dict[str, float] = {}
|
||||
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
||||
self.labelmap: dict[int, str] = {}
|
||||
self.classifications_per_second = EventsPerSecond()
|
||||
|
||||
@ -272,6 +272,56 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
if self.inference_speed:
|
||||
self.inference_speed.update(duration)
|
||||
|
||||
def get_weighted_score(
|
||||
self,
|
||||
object_id: str,
|
||||
current_label: str,
|
||||
current_score: float,
|
||||
current_time: float,
|
||||
) -> tuple[str | None, float]:
|
||||
"""
|
||||
Determine weighted score based on history to prevent false positives/negatives.
|
||||
Requires 60% of attempts to agree on a label before publishing.
|
||||
Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score.
|
||||
"""
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
|
||||
self.classification_history[object_id].append(
|
||||
(current_label, current_score, current_time)
|
||||
)
|
||||
|
||||
history = self.classification_history[object_id]
|
||||
|
||||
if len(history) < 3:
|
||||
return None, 0.0
|
||||
|
||||
label_counts = {}
|
||||
label_scores = {}
|
||||
total_attempts = len(history)
|
||||
|
||||
for label, score, timestamp in history:
|
||||
if label not in label_counts:
|
||||
label_counts[label] = 0
|
||||
label_scores[label] = []
|
||||
|
||||
label_counts[label] += 1
|
||||
label_scores[label].append(score)
|
||||
|
||||
best_label = max(label_counts, key=label_counts.get)
|
||||
best_count = label_counts[best_label]
|
||||
|
||||
consensus_threshold = total_attempts * 0.6
|
||||
if best_count < consensus_threshold:
|
||||
return None, 0.0
|
||||
|
||||
avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])
|
||||
|
||||
if best_label == "none":
|
||||
return None, 0.0
|
||||
|
||||
return best_label, avg_score
|
||||
|
||||
def process_frame(self, obj_data, frame):
|
||||
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||
self.metrics.classification_cps[
|
||||
@ -334,7 +384,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
probs = res / res.sum(axis=0)
|
||||
best_id = np.argmax(probs)
|
||||
score = round(probs[best_id], 2)
|
||||
previous_score = self.detected_objects.get(obj_data["id"], 0.0)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||
|
||||
write_classification_attempt(
|
||||
@ -350,30 +399,34 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
logger.debug(f"Score {score} is less than threshold.")
|
||||
return
|
||||
|
||||
if score <= previous_score:
|
||||
logger.debug(f"Score {score} is worse than previous score {previous_score}")
|
||||
return
|
||||
|
||||
sub_label = self.labelmap[best_id]
|
||||
self.detected_objects[obj_data["id"]] = score
|
||||
|
||||
if (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.sub_label
|
||||
):
|
||||
if sub_label != "none":
|
||||
consensus_label, consensus_score = self.get_weighted_score(
|
||||
obj_data["id"], sub_label, score, now
|
||||
)
|
||||
|
||||
if consensus_label is not None:
|
||||
if (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.sub_label
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(obj_data["id"], sub_label, score),
|
||||
(obj_data["id"], consensus_label, consensus_score),
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
)
|
||||
elif (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.attribute
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(obj_data["id"], self.model_config.name, sub_label, score),
|
||||
EventMetadataTypeEnum.attribute.value,
|
||||
)
|
||||
elif (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.attribute
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(
|
||||
obj_data["id"],
|
||||
self.model_config.name,
|
||||
consensus_label,
|
||||
consensus_score,
|
||||
),
|
||||
EventMetadataTypeEnum.attribute.value,
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
@ -391,8 +444,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id, camera):
|
||||
if object_id in self.detected_objects:
|
||||
self.detected_objects.pop(object_id)
|
||||
if object_id in self.classification_history:
|
||||
self.classification_history.pop(object_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Loading…
Reference in New Issue
Block a user