fix normalization

This commit is contained in:
Josh Hawkins 2024-10-04 17:18:50 -05:00
parent 980a889001
commit df94a941fc

View File

@@ -479,42 +479,75 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else: else:
event_ids = [] event_ids = []
thumb_results = [] thumb_ids = {}
desc_results = [] desc_ids = {}
if search_type == "similarity": if search_type == "similarity":
try: try:
search_event: Event = Event.get(Event.id == event_id) search_event: Event = Event.get(Event.id == event_id)
except DoesNotExist: except DoesNotExist:
return JSONResponse( return JSONResponse(
content=( content={
{ "success": False,
"success": False, "message": "Event not found",
"message": "Event not found", },
}
),
status_code=404, status_code=404,
) )
thumb_results = context.embeddings.search_thumbnail(search_event, limit)
# Get thumbnail results for the specific event
thumb_result = context.embeddings.search_thumbnail(
search_event, event_ids, limit
)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize([result[1] for result in thumb_result]),
)
)
else: else:
search_types = search_type.split(",") search_types = search_type.split(",")
if "thumbnail" in search_types: if "thumbnail" in search_types:
thumb_results = context.embeddings.search_thumbnail(query, limit) thumb_result = context.embeddings.search_thumbnail(query, event_ids, limit)
logger.info(f"thumb results: {thumb_results}")
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize(
[result[1] for result in thumb_result]
),
)
)
if "description" in search_types: if "description" in search_types:
desc_results = context.embeddings.search_description(query, limit) desc_result = context.embeddings.search_description(query, event_ids, limit)
desc_ids = dict(
zip(
[result[0] for result in desc_result],
context.desc_stats.normalize([result[1] for result in desc_result]),
)
)
results = {} results = {}
for result in thumb_results + desc_results: for event_id in thumb_ids.keys() | desc_ids.keys():
event_id, distance = result[0], result[1] thumb_distance = thumb_ids.get(event_id)
if event_id in event_ids or not event_ids: desc_distance = desc_ids.get(event_id)
if event_id not in results or distance < results[event_id]["distance"]:
results[event_id] = { # Select the minimum distance from the available results
"distance": distance, if thumb_distance is not None and (
"source": "thumbnail" if result in thumb_results else "description", desc_distance is None or thumb_distance < desc_distance
} ):
results[event_id] = {
"distance": thumb_distance,
"source": "thumbnail",
}
elif desc_distance is not None:
results[event_id] = {
"distance": desc_distance,
"source": "description",
}
if not results: if not results:
return JSONResponse(content=[]) return JSONResponse(content=[])