fix normalization

This commit is contained in:
Josh Hawkins 2024-10-04 17:18:50 -05:00
parent 980a889001
commit df94a941fc

View File

@ -479,42 +479,75 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else:
event_ids = []
thumb_results = []
desc_results = []
thumb_ids = {}
desc_ids = {}
if search_type == "similarity":
try:
search_event: Event = Event.get(Event.id == event_id)
except DoesNotExist:
return JSONResponse(
content=(
{
"success": False,
"message": "Event not found",
}
),
content={
"success": False,
"message": "Event not found",
},
status_code=404,
)
thumb_results = context.embeddings.search_thumbnail(search_event, limit)
# Get thumbnail results for the specific event
thumb_result = context.embeddings.search_thumbnail(
search_event, event_ids, limit
)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize([result[1] for result in thumb_result]),
)
)
else:
search_types = search_type.split(",")
if "thumbnail" in search_types:
thumb_results = context.embeddings.search_thumbnail(query, limit)
logger.info(f"thumb results: {thumb_results}")
thumb_result = context.embeddings.search_thumbnail(query, event_ids, limit)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize(
[result[1] for result in thumb_result]
),
)
)
if "description" in search_types:
desc_results = context.embeddings.search_description(query, limit)
desc_result = context.embeddings.search_description(query, event_ids, limit)
desc_ids = dict(
zip(
[result[0] for result in desc_result],
context.desc_stats.normalize([result[1] for result in desc_result]),
)
)
results = {}
for result in thumb_results + desc_results:
event_id, distance = result[0], result[1]
if event_id in event_ids or not event_ids:
if event_id not in results or distance < results[event_id]["distance"]:
results[event_id] = {
"distance": distance,
"source": "thumbnail" if result in thumb_results else "description",
}
for event_id in thumb_ids.keys() | desc_ids.keys():
thumb_distance = thumb_ids.get(event_id)
desc_distance = desc_ids.get(event_id)
# Select the minimum distance from the available results
if thumb_distance is not None and (
desc_distance is None or thumb_distance < desc_distance
):
results[event_id] = {
"distance": thumb_distance,
"source": "thumbnail",
}
elif desc_distance is not None:
results[event_id] = {
"distance": desc_distance,
"source": "description",
}
if not results:
return JSONResponse(content=[])