mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
add id filter
This commit is contained in:
parent
df94a941fc
commit
cdafb318fd
@ -130,7 +130,7 @@ class Embeddings:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def search_thumbnail(
|
def search_thumbnail(
|
||||||
self, query: Union[Event, str], limit=10
|
self, query: Union[Event, str], event_ids: List[str] = None, limit=10
|
||||||
) -> List[Tuple[str, float]]:
|
) -> List[Tuple[str, float]]:
|
||||||
if query.__class__ == Event:
|
if query.__class__ == Event:
|
||||||
cursor = self.db.execute_sql(
|
cursor = self.db.execute_sql(
|
||||||
@ -155,35 +155,62 @@ class Embeddings:
|
|||||||
else:
|
else:
|
||||||
query_embedding = self.clip_embedding([query])[0]
|
query_embedding = self.clip_embedding([query])[0]
|
||||||
|
|
||||||
results = self.db.execute_sql(
|
sql_query = """
|
||||||
"""
|
|
||||||
SELECT
|
SELECT
|
||||||
vec_thumbnails.id,
|
vec_thumbnails.id,
|
||||||
distance
|
distance
|
||||||
FROM vec_thumbnails
|
FROM vec_thumbnails
|
||||||
WHERE thumbnail_embedding MATCH ?
|
WHERE thumbnail_embedding MATCH ?
|
||||||
AND k = ?
|
AND k = ?
|
||||||
ORDER BY distance
|
"""
|
||||||
""",
|
|
||||||
(serialize(query_embedding), limit),
|
# Add the IN clause if event_ids is provided and not empty
|
||||||
).fetchall()
|
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||||
|
if event_ids:
|
||||||
|
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||||
|
|
||||||
|
sql_query += " ORDER BY distance"
|
||||||
|
logger.info(f"thumb query: {sql_query}")
|
||||||
|
|
||||||
|
parameters = (
|
||||||
|
[serialize(query_embedding), limit] + event_ids
|
||||||
|
if event_ids
|
||||||
|
else [serialize(query_embedding), limit]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
|
def search_description(
|
||||||
|
self, query_text: str, event_ids: List[str] = None, limit=10
|
||||||
|
) -> List[Tuple[str, float]]:
|
||||||
query_embedding = self.minilm_embedding([query_text])[0]
|
query_embedding = self.minilm_embedding([query_text])[0]
|
||||||
results = self.db.execute_sql(
|
|
||||||
"""
|
# Prepare the base SQL query
|
||||||
|
sql_query = """
|
||||||
SELECT
|
SELECT
|
||||||
vec_descriptions.id,
|
vec_descriptions.id,
|
||||||
distance
|
distance
|
||||||
FROM vec_descriptions
|
FROM vec_descriptions
|
||||||
WHERE description_embedding MATCH ?
|
WHERE description_embedding MATCH ?
|
||||||
AND k = ?
|
AND k = ?
|
||||||
ORDER BY distance
|
"""
|
||||||
""",
|
|
||||||
(serialize(query_embedding), limit),
|
# Add the IN clause if event_ids is provided and not empty
|
||||||
).fetchall()
|
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||||
|
if event_ids:
|
||||||
|
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||||
|
|
||||||
|
sql_query += " ORDER BY distance"
|
||||||
|
|
||||||
|
parameters = (
|
||||||
|
[serialize(query_embedding), limit] + event_ids
|
||||||
|
if event_ids
|
||||||
|
else [serialize(query_embedding), limit]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user