diff --git a/docs/docs/integrations/mqtt.md b/docs/docs/integrations/mqtt.md index 535e1bb4b..05b0ccecd 100644 --- a/docs/docs/integrations/mqtt.md +++ b/docs/docs/integrations/mqtt.md @@ -173,7 +173,8 @@ Message published when [object classification](/configuration/custom_classificat "timestamp": 1607123958.748393, "model": "person_classifier", "sub_label": "delivery_person", - "score": 0.87 + "score": 0.87, + "zones": ["front_yard", "driveway"] } ``` @@ -187,10 +188,15 @@ Message published when [object classification](/configuration/custom_classificat "timestamp": 1607123958.748393, "model": "helmet_detector", "attribute": "yes", - "score": 0.92 + "score": 0.92, + "zones": ["front_yard"] } ``` +:::note +The `zones` field is only included if the tracked object is currently in one or more zones. +::: + ### `frigate/reviews` Message published for each changed review item. The first message is published when the `detection` or `alert` is initiated. diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 229383d9f..a3650e523 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -613,19 +613,20 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): (object_id, consensus_label, consensus_score), EventMetadataTypeEnum.sub_label, ) + classification_data = { + "type": TrackedObjectUpdateTypesEnum.classification, + "id": object_id, + "camera": camera, + "timestamp": now, + "model": self.model_config.name, + "sub_label": consensus_label, + "score": consensus_score, + } + if obj_data.get("current_zones"): + classification_data["zones"] = obj_data["current_zones"] self.requestor.send_data( "tracked_object_update", - json.dumps( - { - "type": TrackedObjectUpdateTypesEnum.classification, - "id": object_id, - "camera": camera, - "timestamp": now, - "model": self.model_config.name, - "sub_label": consensus_label, - "score": consensus_score, - } - ), + json.dumps(classification_data), ) elif ( self.model_config.object_config.classification_type @@ -640,19 +641,20 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ), EventMetadataTypeEnum.attribute.value, ) + classification_data = { + "type": TrackedObjectUpdateTypesEnum.classification, + "id": object_id, + "camera": camera, + "timestamp": now, + "model": self.model_config.name, + "attribute": consensus_label, + "score": consensus_score, + } + if obj_data.get("current_zones"): + classification_data["zones"] = obj_data["current_zones"] self.requestor.send_data( "tracked_object_update", - json.dumps( - { - "type": TrackedObjectUpdateTypesEnum.classification, - "id": object_id, - "camera": camera, - "timestamp": now, - "model": self.model_config.name, - "attribute": consensus_label, - "score": consensus_score, - } - ), + json.dumps(classification_data), ) def handle_request(self, topic, request_data): diff --git a/frigate/test/test_custom_classification.py b/frigate/test/test_custom_classification.py new file mode 100644 index 000000000..116d10694 --- /dev/null +++ b/frigate/test/test_custom_classification.py @@ -0,0 +1,166 @@ +import json +import unittest +from unittest.mock import MagicMock + + +class TestCustomObjectClassificationZones(unittest.TestCase): + """Test that zone information is correctly added to custom classification MQTT messages""" + + def test_sub_label_message_includes_zones_when_present(self): + """Test that zones are included in sub_label classification messages when object is in zones""" + # Create a simple mock requestor + requestor = MagicMock() + + # Create mock obj_data with zones + obj_data = { + "id": "test_object_123", + "camera": "front_door", + "current_zones": ["driveway", "front_yard"], + } + + # Simulate what the processor does when publishing sub_label classification + classification_data = { + "type": "classification", + "id": obj_data["id"], + "camera": obj_data["camera"], + "timestamp": 1234567890.0, + "model": "test_classifier", + "sub_label": "person_walking", + "score": 0.89, + } + if obj_data.get("current_zones"): + classification_data["zones"] = obj_data["current_zones"] + + requestor.send_data("tracked_object_update", json.dumps(classification_data)) + + # Verify that send_data was called + requestor.send_data.assert_called_once() + + # Get the actual call arguments + call_args = requestor.send_data.call_args + topic = call_args[0][0] + data_json = call_args[0][1] + + # Verify the topic + self.assertEqual(topic, "tracked_object_update") + + # Parse and verify the data + data = json.loads(data_json) + self.assertEqual(data["type"], "classification") + self.assertEqual(data["id"], "test_object_123") + self.assertEqual(data["camera"], "front_door") + self.assertEqual(data["model"], "test_classifier") + self.assertEqual(data["sub_label"], "person_walking") + self.assertIn("zones", data) + self.assertEqual(data["zones"], ["driveway", "front_yard"]) + + def test_sub_label_message_excludes_zones_when_empty(self): + """Test that zones are not included when object is not in any zones""" + requestor = MagicMock() + + # Create mock obj_data without zones + obj_data = { + "id": "test_object_456", + "camera": "back_door", + "current_zones": [], + } + + # Simulate what the processor does when publishing sub_label classification + classification_data = { + "type": "classification", + "id": obj_data["id"], + "camera": obj_data["camera"], + "timestamp": 1234567890.0, + "model": "test_classifier", + "sub_label": "person_running", + "score": 0.87, + } + if obj_data.get("current_zones"): + classification_data["zones"] = obj_data["current_zones"] + + requestor.send_data("tracked_object_update", json.dumps(classification_data)) + + # Get the actual call arguments + call_args = requestor.send_data.call_args + data_json = call_args[0][1] + + # Parse and verify the data + data = json.loads(data_json) + self.assertNotIn("zones", data) + + def test_attribute_message_includes_zones_when_present(self): + """Test that zones are included in attribute classification messages when object is in zones""" + requestor = MagicMock() + + # Create mock obj_data with zones + obj_data = { + "id": "test_object_789", + "camera": "construction_site", + "current_zones": ["site_entrance"], + } + + # Simulate what the processor does when publishing attribute classification + classification_data = { + "type": "classification", + "id": obj_data["id"], + "camera": obj_data["camera"], + "timestamp": 1234567890.0, + "model": "helmet_detector", + "attribute": "wearing_helmet", + "score": 0.92, + } + if obj_data.get("current_zones"): + classification_data["zones"] = obj_data["current_zones"] + + requestor.send_data("tracked_object_update", json.dumps(classification_data)) + + # Get the actual call arguments + call_args = requestor.send_data.call_args + data_json = call_args[0][1] + + # Parse and verify the data + data = json.loads(data_json) + self.assertEqual(data["type"], "classification") + self.assertEqual(data["id"], "test_object_789") + self.assertEqual(data["camera"], "construction_site") + self.assertEqual(data["model"], "helmet_detector") + self.assertEqual(data["attribute"], "wearing_helmet") + self.assertIn("zones", data) + self.assertEqual(data["zones"], ["site_entrance"]) + + def test_attribute_message_excludes_zones_when_missing(self): + """Test that zones are not included when current_zones key is missing""" + requestor = MagicMock() + + # Create mock obj_data without current_zones key + obj_data = { + "id": "test_object_999", + "camera": "parking_lot", + } + + # Simulate what the processor does when publishing attribute classification + classification_data = { + "type": "classification", + "id": obj_data["id"], + "camera": obj_data["camera"], + "timestamp": 1234567890.0, + "model": "vehicle_type", + "attribute": "sedan", + "score": 0.95, + } + if obj_data.get("current_zones"): + classification_data["zones"] = obj_data["current_zones"] + + requestor.send_data("tracked_object_update", json.dumps(classification_data)) + + # Get the actual call arguments + call_args = requestor.send_data.call_args + data_json = call_args[0][1] + + # Parse and verify the data + data = json.loads(data_json) + self.assertNotIn("zones", data) + + +if __name__ == "__main__": + unittest.main()