Compare commits

..

2 Commits

Author SHA1 Message Date
Josh Hawkins
4dcd2968b3
consolidate attribute filtering to match non-english and url encoded values (#22002)
Some checks failed
CI / AMD64 Build (push) Has been cancelled
CI / ARM Build (push) Has been cancelled
CI / Jetson Jetpack 6 (push) Has been cancelled
CI / AMD64 Extra Build (push) Has been cancelled
CI / ARM Extra Build (push) Has been cancelled
CI / Synaptics Build (push) Has been cancelled
CI / Assemble and push default build (push) Has been cancelled
2026-02-14 08:33:17 -06:00
Nicolas Mowen
73c1e12faf
Fix saving attributes for object to DB (#22000) 2026-02-14 07:40:08 -06:00
5 changed files with 109 additions and 18 deletions

View File

@ -69,6 +69,25 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=[Tags.events]) router = APIRouter(tags=[Tags.events])
def _build_attribute_filter_clause(attributes: str):
filtered_attributes = [
attr.strip() for attr in attributes.split(",") if attr.strip()
]
attribute_clauses = []
for attr in filtered_attributes:
attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*')
escaped_attr = json.dumps(attr, ensure_ascii=True)[1:-1]
if escaped_attr != attr:
attribute_clauses.append(Event.data.cast("text") % f'*:"{escaped_attr}"*')
if not attribute_clauses:
return None
return reduce(operator.or_, attribute_clauses)
@router.get( @router.get(
"/events", "/events",
response_model=list[EventResponse], response_model=list[EventResponse],
@ -193,13 +212,8 @@ def events(
if attributes != "all": if attributes != "all":
# Custom classification results are stored as data[model_name] = result_value # Custom classification results are stored as data[model_name] = result_value
filtered_attributes = attributes.split(",") attribute_clause = _build_attribute_filter_clause(attributes)
attribute_clauses = [] if attribute_clause is not None:
for attr in filtered_attributes:
attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*')
attribute_clause = reduce(operator.or_, attribute_clauses)
clauses.append(attribute_clause) clauses.append(attribute_clause)
if recognized_license_plate != "all": if recognized_license_plate != "all":
@ -508,7 +522,7 @@ def events_search(
cameras = params.cameras cameras = params.cameras
labels = params.labels labels = params.labels
sub_labels = params.sub_labels sub_labels = params.sub_labels
attributes = params.attributes attributes = unquote(params.attributes)
zones = params.zones zones = params.zones
after = params.after after = params.after
before = params.before before = params.before
@ -607,13 +621,9 @@ def events_search(
if attributes != "all": if attributes != "all":
# Custom classification results are stored as data[model_name] = result_value # Custom classification results are stored as data[model_name] = result_value
filtered_attributes = attributes.split(",") attribute_clause = _build_attribute_filter_clause(attributes)
attribute_clauses = [] if attribute_clause is not None:
event_filters.append(attribute_clause)
for attr in filtered_attributes:
attribute_clauses.append(Event.data.cast("text") % f'*:"{attr}"*')
event_filters.append(reduce(operator.or_, attribute_clauses))
if zones != "all": if zones != "all":
zone_clauses = [] zone_clauses = []

View File

@ -6,6 +6,7 @@ from typing import Dict
from frigate.comms.events_updater import EventEndPublisher, EventUpdateSubscriber from frigate.comms.events_updater import EventEndPublisher, EventUpdateSubscriber
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.config.classification import ObjectClassificationType
from frigate.events.types import EventStateEnum, EventTypeEnum from frigate.events.types import EventStateEnum, EventTypeEnum
from frigate.models import Event from frigate.models import Event
from frigate.util.builtin import to_relative_box from frigate.util.builtin import to_relative_box
@ -247,6 +248,18 @@ class EventProcessor(threading.Thread):
"recognized_license_plate" "recognized_license_plate"
][1] ][1]
# only overwrite attribute-type custom model fields in the database if they're set
for name, model_config in self.config.classification.custom.items():
if (
model_config.object_config
and model_config.object_config.classification_type
== ObjectClassificationType.attribute
):
value = event_data.get(name)
if value is not None:
event[Event.data][name] = value[0]
event[Event.data][f"{name}_score"] = value[1]
( (
Event.insert(event) Event.insert(event)
.on_conflict( .on_conflict(

View File

@ -168,6 +168,57 @@ class TestHttpApp(BaseTestHttp):
assert events[0]["id"] == id assert events[0]["id"] == id
assert events[1]["id"] == id2 assert events[1]["id"] == id2
def test_get_event_list_match_multilingual_attribute(self):
event_id = "123456.zh"
attribute = "中文标签"
with AuthTestClient(self.app) as client:
super().insert_mock_event(event_id, data={"custom_attr": attribute})
events = client.get("/events", params={"attributes": attribute}).json()
assert len(events) == 1
assert events[0]["id"] == event_id
events = client.get(
"/events", params={"attributes": "%E4%B8%AD%E6%96%87%E6%A0%87%E7%AD%BE"}
).json()
assert len(events) == 1
assert events[0]["id"] == event_id
def test_events_search_match_multilingual_attribute(self):
event_id = "123456.zh.search"
attribute = "中文标签"
mock_embeddings = Mock()
mock_embeddings.search_thumbnail.return_value = [(event_id, 0.05)]
self.app.frigate_config.semantic_search.enabled = True
self.app.embeddings = mock_embeddings
with AuthTestClient(self.app) as client:
super().insert_mock_event(event_id, data={"custom_attr": attribute})
events = client.get(
"/events/search",
params={
"search_type": "similarity",
"event_id": event_id,
"attributes": attribute,
},
).json()
assert len(events) == 1
assert events[0]["id"] == event_id
events = client.get(
"/events/search",
params={
"search_type": "similarity",
"event_id": event_id,
"attributes": "%E4%B8%AD%E6%96%87%E6%A0%87%E7%AD%BE",
},
).json()
assert len(events) == 1
assert events[0]["id"] == event_id
def test_get_good_event(self): def test_get_good_event(self):
id = "123456.random" id = "123456.random"

View File

@ -33,6 +33,7 @@ from frigate.config.camera.updater import (
CameraConfigUpdateEnum, CameraConfigUpdateEnum,
CameraConfigUpdateSubscriber, CameraConfigUpdateSubscriber,
) )
from frigate.config.classification import ObjectClassificationType
from frigate.const import ( from frigate.const import (
FAST_QUEUE_TIMEOUT, FAST_QUEUE_TIMEOUT,
UPDATE_CAMERA_ACTIVITY, UPDATE_CAMERA_ACTIVITY,
@ -759,8 +760,16 @@ class TrackedObjectProcessor(threading.Thread):
self.update_mqtt_motion(camera, frame_time, motion_boxes) self.update_mqtt_motion(camera, frame_time, motion_boxes)
attribute_model_names = [
name
for name, model_config in self.config.classification.custom.items()
if model_config.object_config
and model_config.object_config.classification_type
== ObjectClassificationType.attribute
]
tracked_objects = [ tracked_objects = [
o.to_dict() for o in camera_state.tracked_objects.values() o.to_dict(attribute_model_names=attribute_model_names)
for o in camera_state.tracked_objects.values()
] ]
# publish info on this frame # publish info on this frame

View File

@ -376,7 +376,10 @@ class TrackedObject:
) )
return (thumb_update, significant_change, path_update, autotracker_update) return (thumb_update, significant_change, path_update, autotracker_update)
def to_dict(self) -> dict[str, Any]: def to_dict(
self,
attribute_model_names: list[str] | None = None,
) -> dict[str, Any]:
event = { event = {
"id": self.obj_data["id"], "id": self.obj_data["id"],
"camera": self.camera_config.name, "camera": self.camera_config.name,
@ -411,6 +414,11 @@ class TrackedObject:
"path_data": self.path_data.copy(), "path_data": self.path_data.copy(),
"recognized_license_plate": self.obj_data.get("recognized_license_plate"), "recognized_license_plate": self.obj_data.get("recognized_license_plate"),
} }
if attribute_model_names is not None:
for name in attribute_model_names:
value = self.obj_data.get(name)
if value is not None:
event[name] = value
return event return event