From cdafb318fd11de33a6a049cf344a2dd02bcf16f4 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Fri, 4 Oct 2024 17:18:56 -0500 Subject: [PATCH] add id filter --- frigate/embeddings/embeddings.py | 55 ++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 90e97efa7..55ac6447b 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -130,7 +130,7 @@ class Embeddings: ) def search_thumbnail( - self, query: Union[Event, str], limit=10 + self, query: Union[Event, str], event_ids: List[str] = None, limit=10 ) -> List[Tuple[str, float]]: if query.__class__ == Event: cursor = self.db.execute_sql( @@ -155,35 +155,62 @@ class Embeddings: else: query_embedding = self.clip_embedding([query])[0] - results = self.db.execute_sql( - """ + sql_query = """ SELECT vec_thumbnails.id, distance FROM vec_thumbnails WHERE thumbnail_embedding MATCH ? AND k = ? - ORDER BY distance - """, - (serialize(query_embedding), limit), - ).fetchall() + """ + + # Add the IN clause if event_ids is provided and not empty + # this is the only filter supported by sqlite-vec as of 0.1.3 + if event_ids: + sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids))) + + sql_query += " ORDER BY distance" + logger.info(f"thumb query: {sql_query}") + + parameters = ( + [serialize(query_embedding), limit] + event_ids + if event_ids + else [serialize(query_embedding), limit] + ) + + results = self.db.execute_sql(sql_query, parameters).fetchall() return results - def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]: + def search_description( + self, query_text: str, event_ids: List[str] = None, limit=10 + ) -> List[Tuple[str, float]]: query_embedding = self.minilm_embedding([query_text])[0] - results = self.db.execute_sql( - """ + + # Prepare the base SQL query + sql_query = """ SELECT vec_descriptions.id, distance FROM vec_descriptions WHERE description_embedding MATCH ? AND k = ? - ORDER BY distance - """, - (serialize(query_embedding), limit), - ).fetchall() + """ + + # Add the IN clause if event_ids is provided and not empty + # this is the only filter supported by sqlite-vec as of 0.1.3 + if event_ids: + sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids))) + + sql_query += " ORDER BY distance" + + parameters = ( + [serialize(query_embedding), limit] + event_ids + if event_ids + else [serialize(query_embedding), limit] + ) + + results = self.db.execute_sql(sql_query, parameters).fetchall() return results