mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 07:35:27 +03:00
fix normalization
This commit is contained in:
parent
980a889001
commit
df94a941fc
@ -479,42 +479,75 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
|||||||
else:
|
else:
|
||||||
event_ids = []
|
event_ids = []
|
||||||
|
|
||||||
thumb_results = []
|
thumb_ids = {}
|
||||||
desc_results = []
|
desc_ids = {}
|
||||||
|
|
||||||
if search_type == "similarity":
|
if search_type == "similarity":
|
||||||
try:
|
try:
|
||||||
search_event: Event = Event.get(Event.id == event_id)
|
search_event: Event = Event.get(Event.id == event_id)
|
||||||
except DoesNotExist:
|
except DoesNotExist:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=(
|
content={
|
||||||
{
|
"success": False,
|
||||||
"success": False,
|
"message": "Event not found",
|
||||||
"message": "Event not found",
|
},
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=404,
|
status_code=404,
|
||||||
)
|
)
|
||||||
thumb_results = context.embeddings.search_thumbnail(search_event, limit)
|
|
||||||
|
# Get thumbnail results for the specific event
|
||||||
|
thumb_result = context.embeddings.search_thumbnail(
|
||||||
|
search_event, event_ids, limit
|
||||||
|
)
|
||||||
|
|
||||||
|
thumb_ids = dict(
|
||||||
|
zip(
|
||||||
|
[result[0] for result in thumb_result],
|
||||||
|
context.thumb_stats.normalize([result[1] for result in thumb_result]),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
search_types = search_type.split(",")
|
search_types = search_type.split(",")
|
||||||
|
|
||||||
if "thumbnail" in search_types:
|
if "thumbnail" in search_types:
|
||||||
thumb_results = context.embeddings.search_thumbnail(query, limit)
|
thumb_result = context.embeddings.search_thumbnail(query, event_ids, limit)
|
||||||
logger.info(f"thumb results: {thumb_results}")
|
|
||||||
|
thumb_ids = dict(
|
||||||
|
zip(
|
||||||
|
[result[0] for result in thumb_result],
|
||||||
|
context.thumb_stats.normalize(
|
||||||
|
[result[1] for result in thumb_result]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if "description" in search_types:
|
if "description" in search_types:
|
||||||
desc_results = context.embeddings.search_description(query, limit)
|
desc_result = context.embeddings.search_description(query, event_ids, limit)
|
||||||
|
|
||||||
|
desc_ids = dict(
|
||||||
|
zip(
|
||||||
|
[result[0] for result in desc_result],
|
||||||
|
context.desc_stats.normalize([result[1] for result in desc_result]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
for result in thumb_results + desc_results:
|
for event_id in thumb_ids.keys() | desc_ids.keys():
|
||||||
event_id, distance = result[0], result[1]
|
thumb_distance = thumb_ids.get(event_id)
|
||||||
if event_id in event_ids or not event_ids:
|
desc_distance = desc_ids.get(event_id)
|
||||||
if event_id not in results or distance < results[event_id]["distance"]:
|
|
||||||
results[event_id] = {
|
# Select the minimum distance from the available results
|
||||||
"distance": distance,
|
if thumb_distance is not None and (
|
||||||
"source": "thumbnail" if result in thumb_results else "description",
|
desc_distance is None or thumb_distance < desc_distance
|
||||||
}
|
):
|
||||||
|
results[event_id] = {
|
||||||
|
"distance": thumb_distance,
|
||||||
|
"source": "thumbnail",
|
||||||
|
}
|
||||||
|
elif desc_distance is not None:
|
||||||
|
results[event_id] = {
|
||||||
|
"distance": desc_distance,
|
||||||
|
"source": "description",
|
||||||
|
}
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return JSONResponse(content=[])
|
return JSONResponse(content=[])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user