This commit is contained in:
Nicolas Mowen 2025-04-15 06:52:12 -06:00
parent 0ce2f6d411
commit a5f6c754de
2 changed files with 3 additions and 0 deletions

View File

@ -26,6 +26,7 @@ class InputTensorEnum(str, Enum):
nchw = "nchw" nchw = "nchw"
nhwc = "nhwc" nhwc = "nhwc"
hwnc = "hwnc" hwnc = "hwnc"
hwcn = "hwcn"
class InputDTypeEnum(str, Enum): class InputDTypeEnum(str, Enum):

View File

@ -73,3 +73,5 @@ def tensor_transform(desired_shape: InputTensorEnum):
return (0, 3, 1, 2) return (0, 3, 1, 2)
elif desired_shape == InputTensorEnum.hwnc: elif desired_shape == InputTensorEnum.hwnc:
return (1, 2, 0, 3) return (1, 2, 0, 3)
elif desired_shape == InputTensorEnum.hwcn:
return (1, 2, 3, 0)