Only apply normalization to multi-modal searches

This commit is contained in:
Josh Hawkins 2024-10-11 12:03:18 -05:00
parent cfc6205537
commit a8b5aebacb

View File

@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
) )
thumb_result = context.search_thumbnail(search_event) thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict( thumb_ids = {result[0]: result[1] for result in thumb_result}
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize([result[1] for result in thumb_result]),
)
)
search_results = { search_results = {
event_id: {"distance": distance, "source": "thumbnail"} event_id: {"distance": distance, "source": "thumbnail"}
for event_id, distance in thumb_ids.items() for event_id, distance in thumb_ids.items()
@ -486,15 +481,23 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else: else:
search_types = search_type.split(",") search_types = search_type.split(",")
# only normalize multi-modal searches
apply_normalization = (
"thumbnail" in search_types and "description" in search_types
)
if "thumbnail" in search_types: if "thumbnail" in search_types:
thumb_result = context.search_thumbnail(query) thumb_result = context.search_thumbnail(query)
thumb_ids = dict(
zip( if apply_normalization:
[result[0] for result in thumb_result], thumb_distances = context.thumb_stats.normalize(
context.thumb_stats.normalize(
[result[1] for result in thumb_result] [result[1] for result in thumb_result]
),
) )
else:
thumb_distances = [result[1] for result in thumb_result]
thumb_ids = dict(
zip([result[0] for result in thumb_result], thumb_distances)
) )
search_results.update( search_results.update(
{ {
@ -505,12 +508,17 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
if "description" in search_types: if "description" in search_types:
desc_result = context.search_description(query) desc_result = context.search_description(query)
desc_ids = dict( if apply_normalization:
zip( desc_distances = context.desc_stats.normalize(
[result[0] for result in desc_result], [result[1] for result in desc_result]
context.desc_stats.normalize([result[1] for result in desc_result]),
)
) )
else:
desc_distances = [
result[1] for result in desc_result
] # Use raw distances
desc_ids = dict(zip([result[0] for result in desc_result], desc_distances))
for event_id, distance in desc_ids.items(): for event_id, distance in desc_ids.items():
if ( if (
event_id not in search_results event_id not in search_results