hard code limit

This commit is contained in:
Josh Hawkins 2024-10-05 05:45:06 -05:00
parent d189893459
commit 92298e6578

View File

@ -130,7 +130,7 @@ class Embeddings:
) )
def search_thumbnail( def search_thumbnail(
self, query: Union[Event, str], event_ids: List[str] = None, limit=10 self, query: Union[Event, str], event_ids: List[str] = None
) -> List[Tuple[str, float]]: ) -> List[Tuple[str, float]]:
if query.__class__ == Event: if query.__class__ == Event:
cursor = self.db.execute_sql( cursor = self.db.execute_sql(
@ -157,11 +157,11 @@ class Embeddings:
sql_query = """ sql_query = """
SELECT SELECT
vec_thumbnails.id, id,
distance distance
FROM vec_thumbnails FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ? WHERE thumbnail_embedding MATCH ?
AND k = ? AND k = 100
""" """
# Add the IN clause if event_ids is provided and not empty # Add the IN clause if event_ids is provided and not empty
@ -172,9 +172,9 @@ class Embeddings:
sql_query += " ORDER BY distance" sql_query += " ORDER BY distance"
parameters = ( parameters = (
[serialize(query_embedding), limit] + event_ids [serialize(query_embedding)] + event_ids
if event_ids if event_ids
else [serialize(query_embedding), limit] else [serialize(query_embedding)]
) )
results = self.db.execute_sql(sql_query, parameters).fetchall() results = self.db.execute_sql(sql_query, parameters).fetchall()
@ -182,18 +182,18 @@ class Embeddings:
return results return results
def search_description( def search_description(
self, query_text: str, event_ids: List[str] = None, limit=10 self, query_text: str, event_ids: List[str] = None
) -> List[Tuple[str, float]]: ) -> List[Tuple[str, float]]:
query_embedding = self.minilm_embedding([query_text])[0] query_embedding = self.minilm_embedding([query_text])[0]
# Prepare the base SQL query # Prepare the base SQL query
sql_query = """ sql_query = """
SELECT SELECT
vec_descriptions.id, id,
distance distance
FROM vec_descriptions FROM vec_descriptions
WHERE description_embedding MATCH ? WHERE description_embedding MATCH ?
AND k = ? AND k = 100
""" """
# Add the IN clause if event_ids is provided and not empty # Add the IN clause if event_ids is provided and not empty
@ -204,9 +204,9 @@ class Embeddings:
sql_query += " ORDER BY distance" sql_query += " ORDER BY distance"
parameters = ( parameters = (
[serialize(query_embedding), limit] + event_ids [serialize(query_embedding)] + event_ids
if event_ids if event_ids
else [serialize(query_embedding), limit] else [serialize(query_embedding)]
) )
results = self.db.execute_sql(sql_query, parameters).fetchall() results = self.db.execute_sql(sql_query, parameters).fetchall()