Replace functool lru cache with dict cache

This commit is contained in:
knoffelcut 2026-05-25 19:55:25 +02:00
parent 5a87e0180a
commit a2ecbc9d6b

View File

@ -1,6 +1,5 @@
"""Model Utils"""
import functools
import logging
import os
from typing import Any
@ -18,10 +17,13 @@ logger = logging.getLogger(__name__)
### Post Processing
@functools.lru_cache
def nanodet_center_priors(
def calculate_nanodet_center_priors(
input_height: int, input_width: int, strides: tuple, dtype: type
):
"""
Adapted from https://github.com/RangiLyu/nanodet/blob/be9b4a9001d7f9b6fc89c2df31ae8d428e35b4f0/nanodet/model/head/nanodet_plus_head.py
"""
def get_single_level_center_priors(featmap_size, stride, dtype):
"""Generate centers of a single stage feature map.
Args:
@ -63,6 +65,9 @@ def nanodet_center_priors(
return center_priors
nanodet_center_priors: dict[(int, int, tuple, type), np.ndarray] = {}
def post_process_dfine(
tensor_output: np.ndarray, width: int, height: int
) -> np.ndarray:
@ -332,6 +337,10 @@ def post_process_nanodet_plus(
width: int,
height: int,
):
"""
Adapted from https://github.com/RangiLyu/nanodet/blob/be9b4a9001d7f9b6fc89c2df31ae8d428e35b4f0/nanodet/model/head/nanodet_plus_head.py
"""
def distance2bbox(points, distance, max_shape=None):
"""Decode distance prediction to bounding box.
@ -357,14 +366,24 @@ def post_process_nanodet_plus(
predictions = predictions[0]
# TODO From parameters
# Below two parameters are consistent with all nanodet **plus** models
reg_max = 7
strides = (8, 16, 32, 64)
num_classes = predictions.shape[-1] - 4 * (reg_max + 1)
cls_scores, bbox_preds = predictions[:, :num_classes], predictions[:, num_classes:]
center_priors = nanodet_center_priors(height, width, strides, predictions[0].dtype)
try:
center_priors = nanodet_center_priors[
(height, width, strides, predictions[0].dtype)
]
except KeyError:
center_priors = calculate_nanodet_center_priors(
height, width, strides, predictions[0].dtype
)
nanodet_center_priors[(height, width, strides, predictions[0].dtype)] = (
center_priors
)
x = bbox_preds.reshape(bbox_preds.shape[0], 4, reg_max + 1)
x = scipy.special.softmax(x, axis=-1)