mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 07:35:27 +03:00
fix normalization
This commit is contained in:
parent
980a889001
commit
df94a941fc
@ -479,42 +479,75 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
else:
|
||||
event_ids = []
|
||||
|
||||
thumb_results = []
|
||||
desc_results = []
|
||||
thumb_ids = {}
|
||||
desc_ids = {}
|
||||
|
||||
if search_type == "similarity":
|
||||
try:
|
||||
search_event: Event = Event.get(Event.id == event_id)
|
||||
except DoesNotExist:
|
||||
return JSONResponse(
|
||||
content=(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Event not found",
|
||||
}
|
||||
),
|
||||
content={
|
||||
"success": False,
|
||||
"message": "Event not found",
|
||||
},
|
||||
status_code=404,
|
||||
)
|
||||
thumb_results = context.embeddings.search_thumbnail(search_event, limit)
|
||||
|
||||
# Get thumbnail results for the specific event
|
||||
thumb_result = context.embeddings.search_thumbnail(
|
||||
search_event, event_ids, limit
|
||||
)
|
||||
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in thumb_result],
|
||||
context.thumb_stats.normalize([result[1] for result in thumb_result]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
search_types = search_type.split(",")
|
||||
|
||||
if "thumbnail" in search_types:
|
||||
thumb_results = context.embeddings.search_thumbnail(query, limit)
|
||||
logger.info(f"thumb results: {thumb_results}")
|
||||
thumb_result = context.embeddings.search_thumbnail(query, event_ids, limit)
|
||||
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in thumb_result],
|
||||
context.thumb_stats.normalize(
|
||||
[result[1] for result in thumb_result]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if "description" in search_types:
|
||||
desc_results = context.embeddings.search_description(query, limit)
|
||||
desc_result = context.embeddings.search_description(query, event_ids, limit)
|
||||
|
||||
desc_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in desc_result],
|
||||
context.desc_stats.normalize([result[1] for result in desc_result]),
|
||||
)
|
||||
)
|
||||
|
||||
results = {}
|
||||
for result in thumb_results + desc_results:
|
||||
event_id, distance = result[0], result[1]
|
||||
if event_id in event_ids or not event_ids:
|
||||
if event_id not in results or distance < results[event_id]["distance"]:
|
||||
results[event_id] = {
|
||||
"distance": distance,
|
||||
"source": "thumbnail" if result in thumb_results else "description",
|
||||
}
|
||||
for event_id in thumb_ids.keys() | desc_ids.keys():
|
||||
thumb_distance = thumb_ids.get(event_id)
|
||||
desc_distance = desc_ids.get(event_id)
|
||||
|
||||
# Select the minimum distance from the available results
|
||||
if thumb_distance is not None and (
|
||||
desc_distance is None or thumb_distance < desc_distance
|
||||
):
|
||||
results[event_id] = {
|
||||
"distance": thumb_distance,
|
||||
"source": "thumbnail",
|
||||
}
|
||||
elif desc_distance is not None:
|
||||
results[event_id] = {
|
||||
"distance": desc_distance,
|
||||
"source": "description",
|
||||
}
|
||||
|
||||
if not results:
|
||||
return JSONResponse(content=[])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user