Only apply normalization to multi modal searches

This commit is contained in:
Josh Hawkins 2024-10-11 12:03:18 -05:00
parent cfc6205537
commit a8b5aebacb

View File

@ -473,12 +473,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
)
thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize([result[1] for result in thumb_result]),
)
)
thumb_ids = {result[0]: result[1] for result in thumb_result}
search_results = {
event_id: {"distance": distance, "source": "thumbnail"}
for event_id, distance in thumb_ids.items()
@ -486,15 +481,23 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
else:
search_types = search_type.split(",")
# only normalize multi-modal searches
apply_normalization = (
"thumbnail" in search_types and "description" in search_types
)
if "thumbnail" in search_types:
thumb_result = context.search_thumbnail(query)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
context.thumb_stats.normalize(
if apply_normalization:
thumb_distances = context.thumb_stats.normalize(
[result[1] for result in thumb_result]
),
)
else:
thumb_distances = [result[1] for result in thumb_result]
thumb_ids = dict(
zip([result[0] for result in thumb_result], thumb_distances)
)
search_results.update(
{
@ -505,12 +508,17 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
if "description" in search_types:
desc_result = context.search_description(query)
desc_ids = dict(
zip(
[result[0] for result in desc_result],
context.desc_stats.normalize([result[1] for result in desc_result]),
)
if apply_normalization:
desc_distances = context.desc_stats.normalize(
[result[1] for result in desc_result]
)
else:
desc_distances = [
result[1] for result in desc_result
] # Use raw distances
desc_ids = dict(zip([result[0] for result in desc_result], desc_distances))
for event_id, distance in desc_ids.items():
if (
event_id not in search_results