add id filter

This commit is contained in:
Josh Hawkins 2024-10-04 17:18:56 -05:00
parent df94a941fc
commit cdafb318fd

View File

@@ -130,7 +130,7 @@ class Embeddings:
) )
def search_thumbnail( def search_thumbnail(
self, query: Union[Event, str], limit=10 self, query: Union[Event, str], event_ids: List[str] = None, limit=10
) -> List[Tuple[str, float]]: ) -> List[Tuple[str, float]]:
if query.__class__ == Event: if query.__class__ == Event:
cursor = self.db.execute_sql( cursor = self.db.execute_sql(
@@ -155,35 +155,62 @@ class Embeddings:
else: else:
query_embedding = self.clip_embedding([query])[0] query_embedding = self.clip_embedding([query])[0]
results = self.db.execute_sql( sql_query = """
"""
SELECT SELECT
vec_thumbnails.id, vec_thumbnails.id,
distance distance
FROM vec_thumbnails FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ? WHERE thumbnail_embedding MATCH ?
AND k = ? AND k = ?
ORDER BY distance """
""",
(serialize(query_embedding), limit), # Add the IN clause if event_ids is provided and not empty
).fetchall() # this is the only filter supported by sqlite-vec as of 0.1.3
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
sql_query += " ORDER BY distance"
logger.info(f"thumb query: {sql_query}")
parameters = (
[serialize(query_embedding), limit] + event_ids
if event_ids
else [serialize(query_embedding), limit]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results return results
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]: def search_description(
self, query_text: str, event_ids: List[str] = None, limit=10
) -> List[Tuple[str, float]]:
query_embedding = self.minilm_embedding([query_text])[0] query_embedding = self.minilm_embedding([query_text])[0]
results = self.db.execute_sql(
""" # Prepare the base SQL query
sql_query = """
SELECT SELECT
vec_descriptions.id, vec_descriptions.id,
distance distance
FROM vec_descriptions FROM vec_descriptions
WHERE description_embedding MATCH ? WHERE description_embedding MATCH ?
AND k = ? AND k = ?
ORDER BY distance """
""",
(serialize(query_embedding), limit), # Add the IN clause if event_ids is provided and not empty
).fetchall() # this is the only filter supported by sqlite-vec as of 0.1.3
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding), limit] + event_ids
if event_ids
else [serialize(query_embedding), limit]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results return results