diff --git a/frigate/util/model.py b/frigate/util/model.py index 64e1a28d62..218b3a4dd5 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -1,6 +1,5 @@ """Model Utils""" -import functools import logging import os from typing import Any @@ -18,10 +17,13 @@ logger = logging.getLogger(__name__) ### Post Processing -@functools.lru_cache -def nanodet_center_priors( +def calculate_nanodet_center_priors( input_height: int, input_width: int, strides: tuple, dtype: type ): + """ + Adapted from https://github.com/RangiLyu/nanodet/blob/be9b4a9001d7f9b6fc89c2df31ae8d428e35b4f0/nanodet/model/head/nanodet_plus_head.py + """ + def get_single_level_center_priors(featmap_size, stride, dtype): """Generate centers of a single stage feature map. Args: @@ -63,6 +65,9 @@ def nanodet_center_priors( return center_priors +nanodet_center_priors: dict[(int, int, tuple, type), np.ndarray] = {} + + def post_process_dfine( tensor_output: np.ndarray, width: int, height: int ) -> np.ndarray: @@ -332,6 +337,10 @@ def post_process_nanodet_plus( width: int, height: int, ): + """ + Adapted from https://github.com/RangiLyu/nanodet/blob/be9b4a9001d7f9b6fc89c2df31ae8d428e35b4f0/nanodet/model/head/nanodet_plus_head.py + """ + def distance2bbox(points, distance, max_shape=None): """Decode distance prediction to bounding box. @@ -357,14 +366,24 @@ def post_process_nanodet_plus( predictions = predictions[0] - # TODO From parameters + # Below two parameters are consistent with all nanodet **plus** models reg_max = 7 strides = (8, 16, 32, 64) num_classes = predictions.shape[-1] - 4 * (reg_max + 1) cls_scores, bbox_preds = predictions[:, :num_classes], predictions[:, num_classes:] - center_priors = nanodet_center_priors(height, width, strides, predictions[0].dtype) + try: + center_priors = nanodet_center_priors[ + (height, width, strides, predictions[0].dtype) + ] + except KeyError: + center_priors = calculate_nanodet_center_priors( + height, width, strides, predictions[0].dtype + ) + nanodet_center_priors[(height, width, strides, predictions[0].dtype)] = ( + center_priors + ) x = bbox_preds.reshape(bbox_preds.shape[0], 4, reg_max + 1) x = scipy.special.softmax(x, axis=-1)