add id filter

This commit is contained in:
Josh Hawkins 2024-10-04 17:18:56 -05:00
parent df94a941fc
commit cdafb318fd

View File

@ -130,7 +130,7 @@ class Embeddings:
)
def search_thumbnail(
self, query: Union[Event, str], limit=10
self, query: Union[Event, str], event_ids: List[str] = None, limit=10
) -> List[Tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
@ -155,35 +155,62 @@ class Embeddings:
else:
query_embedding = self.clip_embedding([query])[0]
results = self.db.execute_sql(
"""
sql_query = """
SELECT
vec_thumbnails.id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = ?
ORDER BY distance
""",
(serialize(query_embedding), limit),
).fetchall()
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
sql_query += " ORDER BY distance"
logger.info(f"thumb query: {sql_query}")
parameters = (
[serialize(query_embedding), limit] + event_ids
if event_ids
else [serialize(query_embedding), limit]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
    self, query_text: str, event_ids: List[str] = None, limit=10
) -> List[Tuple[str, float]]:
    """Search ``vec_descriptions`` for the embeddings closest to ``query_text``.

    Args:
        query_text: Text that is embedded with ``self.minilm_embedding`` and
            used as the MATCH operand of the vector query.
        event_ids: Optional ids used to restrict the search via an
            ``id IN (...)`` clause. ``None`` or an empty collection disables
            the filter. Any iterable is accepted (list, tuple, set, ...).
        limit: Value bound to the ``k = ?`` constraint of the query.

    Returns:
        The rows returned by the query, each an ``(id, distance)`` pair,
        ordered by ascending distance.
    """
    query_embedding = self.minilm_embedding([query_text])[0]

    # Materialize the ids exactly once: the same list drives both the number
    # of placeholders and the bound values, so they can never get out of sync,
    # and non-list iterables no longer break the parameter concatenation.
    id_filter = list(event_ids) if event_ids is not None else []

    # Prepare the base SQL query
    sql_query = """
        SELECT
            vec_descriptions.id,
            distance
        FROM vec_descriptions
        WHERE description_embedding MATCH ?
            AND k = ?
    """

    # Add the IN clause if event_ids is provided and not empty
    # this is the only filter supported by sqlite-vec as of 0.1.3
    # NOTE(review): a very large id list may exceed SQLite's limit on bound
    # parameters -- confirm whether callers can pass that many ids.
    if id_filter:
        placeholders = ",".join("?" * len(id_filter))
        sql_query += f" AND id IN ({placeholders})"

    sql_query += " ORDER BY distance"

    # Positional parameters follow the order of the "?" markers in the query:
    # embedding, k, then one value per id placeholder.
    parameters = [serialize(query_embedding), limit, *id_filter]

    return self.db.execute_sql(sql_query, parameters).fetchall()