/* * The MIT License (MIT) * * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { struct grid_sampler { std::string m_padding; bool m_align_corners; instruction_ref m_input; instruction_ref m_grid; size_t m_batch{1}; size_t m_channel{1}; size_t m_in_height{1}; size_t m_in_width{1}; size_t m_out_height{1}; size_t m_out_width{1}; migraphx::shape m_nc_shape; instruction_ref m_one_l; instruction_ref m_two_l; instruction_ref m_zero_l; instruction_ref m_minus_half_l; instruction_ref m_width_l; instruction_ref m_width_max_l; instruction_ref m_height_l; instruction_ref m_height_max_l; instruction_ref m_unnorm_x; instruction_ref m_unnorm_y; grid_sampler(const instruction_ref& input, const instruction_ref& grid, bool align, std::string&& padding, const onnx_parser::node_info& info) : m_padding(std::move(padding)), m_align_corners(align), m_input(input), m_grid(grid) { auto i_lens = input->get_shape().lens(); m_batch = i_lens.at(0); m_channel = i_lens.at(1); m_in_height = i_lens.at(2); m_in_width = i_lens.at(3); auto g_lens = grid->get_shape().lens(); m_out_height = g_lens.at(1); m_out_width = g_lens.at(2); auto type = m_grid->get_shape().type(); m_nc_shape = migraphx::shape{type, {1, 2}}; m_zero_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {0.0f}}); m_one_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {1.0f}}); m_two_l = info.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type}, {2}}); m_minus_half_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {-0.5f}}); m_width_max_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {m_in_width - 1}}); m_width_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {m_in_width}}); m_height_max_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {m_in_height - 1}}); m_height_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {m_in_height}}); auto x_coords = info.add_instruction( make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {1}}}), m_grid); auto y_coords = info.add_instruction( make_op("slice", {{"axes", {3}}, {"starts", {1}}, {"ends", {2}}}), m_grid); x_coords = info.add_instruction(make_op("squeeze", {{"axes", {3}}}), x_coords); y_coords = info.add_instruction(make_op("squeeze", {{"axes", {3}}}), y_coords); m_unnorm_x = unnormalize(info, x_coords, m_in_width); m_unnorm_y = unnormalize(info, y_coords, m_in_height); if(m_padding == "reflection") { auto corner_start = m_align_corners ? m_zero_l : m_minus_half_l; m_unnorm_x = reflect_coordinates( info, m_unnorm_x, m_align_corners ? m_width_max_l : m_width_l, corner_start); m_unnorm_y = reflect_coordinates( info, m_unnorm_y, m_align_corners ? m_height_max_l : m_height_l, corner_start); m_unnorm_x = info.add_common_op("clip", m_unnorm_x, m_zero_l, m_width_max_l); m_unnorm_y = info.add_common_op("clip", m_unnorm_y, m_zero_l, m_height_max_l); } if(m_padding == "border") { m_unnorm_x = info.add_common_op("clip", m_unnorm_x, m_zero_l, m_width_max_l); m_unnorm_y = info.add_common_op("clip", m_unnorm_y, m_zero_l, m_height_max_l); } } instruction_ref reflect_coordinates(const onnx_parser::node_info& info, instruction_ref coords, instruction_ref size, instruction_ref corner_start) const { auto index_align_corner = info.add_common_op("sub", corner_start, coords); index_align_corner = info.add_common_op("abs", index_align_corner); auto size_times = info.add_common_op("floor", index_align_corner); size_times = info.add_common_op("div", size_times, size); size_times = info.add_common_op("floor", size_times); auto cond = info.add_common_op("mod", size_times, m_two_l); cond = info.add_common_op("equal", cond, m_zero_l); auto extra = info.add_common_op("mul", size_times, size); extra = info.add_common_op("sub", index_align_corner, extra); auto cond_true = info.add_common_op("add", extra, corner_start); auto cond_false = info.add_common_op("sub", size, extra); cond_false = info.add_common_op("add", cond_false, corner_start); return info.add_common_op("where", cond, cond_true, cond_false); } instruction_ref unnormalize(const onnx_parser::node_info& info, const instruction_ref& coords_t, float size) const { auto unnorm = info.add_common_op("add", coords_t, m_one_l); if(m_align_corners) { // unnorm_x = (x + 1) * (size - 1) / 2 auto mul_const = info.add_literal( migraphx::literal{migraphx::shape{coords_t->get_shape().type()}, {(size - 1) / 2}}); unnorm = info.add_common_op("mul", unnorm, mul_const); } else { // unnorm_x = -0.5 + (x + 1) * size / 2 auto mul_const = info.add_literal( migraphx::literal{migraphx::shape{coords_t->get_shape().type()}, {size / 2}}); unnorm = info.add_common_op("mul", unnorm, mul_const); unnorm = info.add_common_op("add", unnorm, m_minus_half_l); } return unnorm; } static instruction_ref concat_on_first_dim(const onnx_parser::node_info& info, std::vector instructions) { return std::accumulate( std::next(instructions.begin()), instructions.end(), instructions.front(), [&info](auto& ret, auto& ins) { return info.add_instruction(make_op("concat", {{"axis", 0}}), ret, ins); }); } static instruction_ref concat_on_dim(const onnx_parser::node_info& info, std::array instructions, int64_t dim) { return std::accumulate( std::next(instructions.begin()), instructions.end(), instructions.front(), [&info, &dim](auto& ret, auto& ins) { return info.add_instruction(make_op("concat", {{"axis", dim}}), ret, ins); }); } inline bool has_border_padding() const { return m_padding == "border"; } }; struct nearest_sampler : grid_sampler { instruction_ref m_round_x; instruction_ref m_round_y; nearest_sampler(const instruction_ref& input, const instruction_ref& grid, bool align, std::string&& padding, const onnx_parser::node_info& info) : grid_sampler(input, grid, align, std::move(padding), info), m_round_x(info.add_common_op("nearbyint", m_unnorm_x)), m_round_y(info.add_common_op("nearbyint", m_unnorm_y)) { } instruction_ref sample(const onnx_parser::node_info& info) { std::vector hw_indices; std::vector nc_values; const static auto nhw_shape = migraphx::shape{migraphx::shape::int64_type, {1, 3}}; bool validate = not has_border_padding(); dfor(m_batch, m_out_height, m_out_width)([&](auto n, auto h, auto w) { auto nhw = info.add_literal(migraphx::literal{nhw_shape, {n, h, w}}); for(size_t c = 0; c < m_channel; c++) { hw_indices.push_back(nhw); nc_values.push_back(info.add_literal(migraphx::literal{m_nc_shape, {n, c}})); } }); auto hw_indices_t = concat_on_first_dim(info, hw_indices); auto h_samples = info.add_instruction(make_op("gathernd"), m_round_y, hw_indices_t); auto w_samples = info.add_instruction(make_op("gathernd"), m_round_x, hw_indices_t); instruction_ref validation; if(validate) { auto h_clip = info.add_common_op("clip", h_samples, m_zero_l, m_height_max_l); auto w_clip = info.add_common_op("clip", w_samples, m_zero_l, m_width_max_l); auto h_valid = info.add_common_op("equal", h_samples, h_clip); auto w_valid = info.add_common_op("equal", w_samples, w_clip); validation = info.add_common_op("logical_and", h_valid, w_valid); h_samples = h_clip; w_samples = w_clip; } auto nc = concat_on_first_dim(info, nc_values); h_samples = info.add_instruction( make_op("reshape", {{"dims", {h_samples->get_shape().elements(), 1}}}), h_samples); w_samples = info.add_instruction( make_op("reshape", {{"dims", {w_samples->get_shape().elements(), 1}}}), w_samples); auto indices_t = info.add_instruction(make_op("concat", {{"axis", 1}}), h_samples, w_samples); indices_t = info.add_instruction(make_op("concat", {{"axis", 1}}), nc, indices_t); auto samples = info.add_instruction(make_op("gathernd"), m_input, indices_t); if(validate) { samples = info.add_common_op("where", validation, samples, m_zero_l); } samples = info.add_instruction( make_op("reshape", {{"dims", {m_batch, m_out_height, m_out_width, m_channel}}}), samples); samples = info.add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), samples); samples = info.add_instruction( make_op("convert", {{"target_type", m_input->get_shape().type()}}), samples); return samples; } }; struct linear_sampler : grid_sampler { instruction_ref m_floor_x; instruction_ref m_floor_y; instruction_ref m_ceil_x; instruction_ref m_ceil_y; std::array m_corner_weights; linear_sampler(const instruction_ref& input, const instruction_ref& grid, bool align, std::string&& padding, const onnx_parser::node_info& info) : grid_sampler(input, grid, align, std::move(padding), info), m_floor_x(info.add_common_op("floor", m_unnorm_x)), m_floor_y(info.add_common_op("floor", m_unnorm_y)), m_ceil_x(info.add_common_op("add", m_floor_x, m_one_l)), m_ceil_y(info.add_common_op("add", m_floor_y, m_one_l)) { auto fract_x = info.add_common_op("sub", m_unnorm_x, m_floor_x); auto fract_y = info.add_common_op("sub", m_unnorm_y, m_floor_y); auto one_minus_fract_x = info.add_common_op("sub", m_one_l, fract_x); auto one_minus_fract_y = info.add_common_op("sub", m_one_l, fract_y); m_corner_weights[0] = info.add_common_op("mul", one_minus_fract_y, one_minus_fract_x); m_corner_weights[1] = info.add_common_op("mul", one_minus_fract_y, fract_x); m_corner_weights[2] = info.add_common_op("mul", fract_y, one_minus_fract_x); m_corner_weights[3] = info.add_common_op("mul", fract_y, fract_x); } instruction_ref sample(const onnx_parser::node_info& info) { std::vector weight_indices; std::vector xy_indices; std::vector nc_values; const static auto nhw_shape = migraphx::shape{migraphx::shape::int64_type, {1, 3}}; dfor(m_batch, m_out_height, m_out_width)([&](auto n, auto h, auto w) { auto nhw = info.add_literal(migraphx::literal{nhw_shape, {n, h, w}}); weight_indices.push_back(nhw); for(size_t c = 0; c < m_channel; c++) { xy_indices.push_back(nhw); nc_values.push_back(info.add_literal(migraphx::literal{m_nc_shape, {n, c}})); } }); auto xy_indices_t = concat_on_first_dim(info, xy_indices); auto y0_samples = info.add_instruction(make_op("gathernd"), m_floor_y, xy_indices_t); auto x0_samples = info.add_instruction(make_op("gathernd"), m_floor_x, xy_indices_t); auto y1_samples = info.add_instruction(make_op("gathernd"), m_ceil_y, xy_indices_t); auto x1_samples = info.add_instruction(make_op("gathernd"), m_ceil_x, xy_indices_t); auto validate_samples = [&](auto& samples, auto& max) { auto clip = info.add_common_op("clip", samples, m_zero_l, max); auto validation = info.add_common_op("equal", samples, clip); samples = clip; return validation; }; auto y0_validation = validate_samples(y0_samples, m_height_max_l); auto x0_validation = validate_samples(x0_samples, m_width_max_l); auto y1_validation = validate_samples(y1_samples, m_height_max_l); auto x1_validation = validate_samples(x1_samples, m_width_max_l); y0_samples = info.add_instruction( make_op("reshape", {{"dims", {y0_samples->get_shape().elements(), 1}}}), y0_samples); x0_samples = info.add_instruction( make_op("reshape", {{"dims", {x0_samples->get_shape().elements(), 1}}}), x0_samples); y1_samples = info.add_instruction( make_op("reshape", {{"dims", {y1_samples->get_shape().elements(), 1}}}), y1_samples); x1_samples = info.add_instruction( make_op("reshape", {{"dims", {x1_samples->get_shape().elements(), 1}}}), x1_samples); auto nc = concat_on_first_dim(info, nc_values); auto make_corner_indices = [&](auto& x, auto& y) { auto hw = info.add_instruction(make_op("concat", {{"axis", 1}}), y, x); return info.add_instruction(make_op("concat", {{"axis", 1}}), nc, hw); }; std::array corner_indices{make_corner_indices(x0_samples, y0_samples), make_corner_indices(x1_samples, y0_samples), make_corner_indices(x0_samples, y1_samples), make_corner_indices(x1_samples, y1_samples)}; std::array corner_validations{ info.add_common_op("logical_and", x0_validation, y0_validation), info.add_common_op("logical_and", x1_validation, y0_validation), info.add_common_op("logical_and", x0_validation, y1_validation), info.add_common_op("logical_and", x1_validation, y1_validation)}; std::array corner_samples; auto weight_index_t = concat_on_first_dim(info, weight_indices); weight_index_t = info.add_instruction( make_op("reshape", {{"dims", {weight_indices.size(), 3}}}), weight_index_t); std::transform(corner_indices.begin(), corner_indices.end(), corner_validations.begin(), corner_samples.begin(), [&](const auto& indices, const auto& validations) { auto samples = info.add_instruction(make_op("gathernd"), m_input, indices); return info.add_common_op("where", validations, samples, m_zero_l); }); std::transform(corner_samples.begin(), corner_samples.end(), m_corner_weights.begin(), corner_samples.begin(), [&](const auto& samples, const auto& weights) { auto weights_t = info.add_instruction(make_op("gathernd"), weights, weight_index_t); return info.add_instruction(make_op("mul"), samples, weights_t); }); auto samples = std::accumulate( std::next(corner_samples.begin()), corner_samples.end(), corner_samples.front(), [&](auto acc, auto s) { return info.add_instruction(make_op("add"), acc, s); }); samples = info.add_instruction( make_op("reshape", {{"dims", {m_batch, m_out_height, m_out_width, m_channel}}}), samples); samples = info.add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), samples); samples = info.add_instruction( make_op("convert", {{"target_type", m_input->get_shape().type()}}), samples); return samples; } }; struct bicubic_sampler : grid_sampler { instruction_ref m_a_l; instruction_ref m_aplus2_l; instruction_ref m_aplus3_l; instruction_ref m_4a_l; instruction_ref m_5a_l; instruction_ref m_8a_l; std::array m_x_weights; std::array m_y_weights; std::array m_x_corners; std::array m_y_corners; bicubic_sampler(const instruction_ref& input, const instruction_ref& grid, bool align, std::string&& padding, const onnx_parser::node_info& info) : grid_sampler(input, grid, align, std::move(padding), info) { auto type = m_grid->get_shape().type(); m_a_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {-0.75}}); m_aplus2_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {1.25}}); m_aplus3_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {2.25}}); m_4a_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {-3.0}}); m_5a_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {-3.75}}); m_8a_l = info.add_literal(migraphx::literal{migraphx::shape{type}, {-6.0}}); auto floor_x = info.add_common_op("floor", m_unnorm_x); auto floor_y = info.add_common_op("floor", m_unnorm_y); auto fract_x = info.add_common_op("sub", m_unnorm_x, floor_x); auto fract_y = info.add_common_op("sub", m_unnorm_y, floor_y); m_x_weights[0] = cubic_weight_2(info, info.add_common_op("add", fract_x, m_one_l)); m_x_weights[1] = cubic_weight_1(info, fract_x); m_x_weights[2] = cubic_weight_1(info, info.add_common_op("sub", m_one_l, fract_x)); m_x_weights[3] = cubic_weight_2(info, info.add_common_op("sub", m_two_l, fract_x)); m_y_weights[0] = cubic_weight_2(info, info.add_common_op("add", fract_y, m_one_l)); m_y_weights[1] = cubic_weight_1(info, fract_y); m_y_weights[2] = cubic_weight_1(info, info.add_common_op("sub", m_one_l, fract_y)); m_y_weights[3] = cubic_weight_2(info, info.add_common_op("sub", m_two_l, fract_y)); m_x_corners[0] = info.add_common_op("sub", floor_x, m_one_l); m_x_corners[1] = floor_x; m_x_corners[2] = info.add_common_op("add", floor_x, m_one_l); m_x_corners[3] = info.add_common_op("add", floor_x, m_two_l); m_y_corners[0] = info.add_common_op("sub", floor_y, m_one_l); m_y_corners[1] = floor_y; m_y_corners[2] = info.add_common_op("add", floor_y, m_one_l); m_y_corners[3] = info.add_common_op("add", floor_y, m_two_l); if(m_padding == "reflection") { auto corner_start = m_align_corners ? m_zero_l : m_minus_half_l; std::transform( m_x_corners.begin(), m_x_corners.end(), m_x_corners.begin(), [&](const auto& corner) { auto tmp = reflect_coordinates( info, corner, m_align_corners ? m_width_max_l : m_width_l, corner_start); return info.add_common_op("clip", tmp, m_zero_l, m_width_max_l); }); std::transform( m_y_corners.begin(), m_y_corners.end(), m_y_corners.begin(), [&](const auto& corner) { auto tmp = reflect_coordinates( info, corner, m_align_corners ? m_height_max_l : m_height_l, corner_start); return info.add_common_op("clip", tmp, m_zero_l, m_height_max_l); }); } if(m_padding == "border") { std::transform( m_x_corners.begin(), m_x_corners.end(), m_x_corners.begin(), [&](auto& corner) { return info.add_common_op("clip", corner, m_zero_l, m_width_max_l); }); std::transform( m_y_corners.begin(), m_y_corners.end(), m_y_corners.begin(), [&](auto& corner) { return info.add_common_op("clip", corner, m_zero_l, m_height_max_l); }); } } instruction_ref cubic_weight_1(const onnx_parser::node_info& info, const instruction_ref& ins) const { //((A + 2) * fraction - (A + 3)) * fraction * fraction + 1 auto mul_1 = info.add_common_op("mul", m_aplus2_l, ins); auto sub = info.add_common_op("sub", mul_1, m_aplus3_l); auto mul_2 = info.add_common_op("mul", sub, ins); auto mul_3 = info.add_common_op("mul", mul_2, ins); return info.add_common_op("add", mul_3, m_one_l); } instruction_ref cubic_weight_2(const onnx_parser::node_info& info, const instruction_ref& ins) const { // ((A * fraction - 5 * A) * fraction + 8 * A) * fraction - (4 * A) auto mul_1 = info.add_common_op("mul", m_a_l, ins); auto sub_1 = info.add_common_op("sub", mul_1, m_5a_l); auto mul_2 = info.add_common_op("mul", sub_1, ins); auto add = info.add_common_op("add", mul_2, m_8a_l); auto mul_3 = info.add_common_op("mul", add, ins); return info.add_common_op("sub", mul_3, m_4a_l); } static instruction_ref compute_weights(const onnx_parser::node_info& info, const std::vector& weight_indices, const std::array& weights, const std::vector& out_lens, size_t gather_dim) { auto weight_indices_t = concat_on_first_dim(info, weight_indices); weight_indices_t = info.add_instruction( make_op( "reshape", {{"dims", {weight_indices_t->get_shape().elements() / gather_dim, gather_dim}}}), weight_indices_t); std::array corner_weights; std::transform(weights.cbegin(), weights.cend(), corner_weights.begin(), [&](auto& corner) { auto corner_weight = info.add_instruction(make_op("gathernd"), corner, weight_indices_t); return info.add_instruction( make_op("reshape", {{"dims", {corner_weight->get_shape().elements(), 1}}}), corner_weight); }); auto weights_t = std::accumulate( std::next(corner_weights.begin()), corner_weights.end(), corner_weights.front(), [&info](auto& acc, auto& ins) { return info.add_instruction(make_op("concat", {{"axis", 1}}), acc, ins); }); return info.add_instruction(make_op("reshape", {{"dims", out_lens}}), weights_t); } instruction_ref sample(const onnx_parser::node_info& info) { std::vector x_weight_indices; std::vector y_weight_indices; std::vector inner_x_indices; std::vector nc_values; std::vector inner_indices; const static auto nhw_shape = migraphx::shape{migraphx::shape::int64_type, {3}}; dfor(m_batch, m_out_height, m_out_width)([&](auto n, auto h, auto w) { auto nhw = info.add_literal(migraphx::literal{nhw_shape, {n, h, w}}); x_weight_indices.insert(x_weight_indices.end(), {nhw, nhw, nhw, nhw}); y_weight_indices.push_back(nhw); dfor(m_channel, m_y_corners.size())([&](auto c, auto) { inner_indices.push_back(nhw); auto nc = info.add_literal(migraphx::literal{m_nc_shape, {n, c}}); nc_values.insert(nc_values.end(), {nc, nc, nc, nc}); }); }); auto inner_indices_t = concat_on_first_dim(info, inner_indices); inner_indices_t = info.add_instruction( make_op("reshape", {{"dims", {inner_indices_t->get_shape().elements() / nhw_shape.elements(), nhw_shape.elements()}}}), inner_indices_t); std::array inner_y_samples; std::transform( m_y_corners.begin(), m_y_corners.end(), inner_y_samples.begin(), [&](auto corner) { auto sample = info.add_instruction(make_op("gathernd"), corner, inner_indices_t); return info.add_instruction( make_op("reshape", {{"dims", {sample->get_shape().elements(), 1}}}), sample); }); auto inner_y_t = concat_on_dim(info, inner_y_samples, 1); auto elements = inner_y_t->get_shape().elements(); inner_y_t = info.add_instruction(make_op("reshape", {{"dims", {elements / 16, 4, 4}}}), inner_y_t); inner_y_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1}}}), inner_y_t); inner_y_t = info.add_instruction(make_op("reshape", {{"dims", {elements}}}), inner_y_t); std::array inner_x_samples; std::transform( m_x_corners.begin(), m_x_corners.end(), inner_x_samples.begin(), [&](auto corner) { auto sample = info.add_instruction(make_op("gathernd"), corner, inner_indices_t); return info.add_instruction( make_op("reshape", {{"dims", {sample->get_shape().elements(), 1}}}), sample); }); auto inner_x_t = concat_on_dim(info, inner_x_samples, 1); inner_x_t = info.add_instruction( make_op("reshape", {{"dims", {inner_x_t->get_shape().elements()}}}), inner_x_t); auto validate_index = [&](auto& index, auto& max) { auto clip = info.add_common_op("clip", index, m_zero_l, max); auto validation = info.add_common_op("equal", index, clip); index = clip; return validation; }; auto y_validation = validate_index(inner_y_t, m_height_max_l); auto x_validation = validate_index(inner_x_t, m_width_max_l); inner_y_t = info.add_instruction( make_op("reshape", {{"dims", {inner_y_t->get_shape().elements(), 1}}}), inner_y_t); inner_x_t = info.add_instruction( make_op("reshape", {{"dims", {inner_x_t->get_shape().elements(), 1}}}), inner_x_t); auto nc_t = concat_on_first_dim(info, nc_values); auto indices_t = info.add_instruction(make_op("concat", {{"axis", 1}}), inner_y_t, inner_x_t); indices_t = info.add_instruction(make_op("concat", {{"axis", 1}}), nc_t, indices_t); auto samples = info.add_instruction(make_op("gathernd"), m_input, indices_t); auto validation_t = info.add_common_op("logical_and", y_validation, x_validation); samples = info.add_common_op("where", validation_t, samples, m_zero_l); auto x_weights_t = compute_weights( info, x_weight_indices, m_x_weights, samples->get_shape().lens(), nhw_shape.elements()); auto weighted_samples = info.add_common_op("mul", samples, x_weights_t); weighted_samples = info.add_instruction( make_op("reshape", {{"dims", {weighted_samples->get_shape().elements() / 4, 4}}}), weighted_samples); auto coefficients = info.add_instruction(make_op("reduce_sum", {{"axes", {1}}}), weighted_samples); coefficients = info.add_instruction(make_op("squeeze", {{"axes", {1}}}), coefficients); auto y_weights_t = compute_weights(info, y_weight_indices, m_y_weights, coefficients->get_shape().lens(), nhw_shape.elements()); auto weighted_coefficients = info.add_common_op("mul", coefficients, y_weights_t); weighted_coefficients = info.add_instruction( make_op("reshape", {{"dims", {weighted_coefficients->get_shape().elements() / 4, 4}}}), weighted_coefficients); auto res = info.add_instruction(make_op("reduce_sum", {{"axes", {1}}}), weighted_coefficients); auto expected_shape = migraphx::shape{migraphx::shape::int64_type, {m_batch, m_out_height, m_out_width, m_channel}}; res = info.add_instruction( make_op("reshape", {{"dims", {m_batch, m_out_height, m_out_width, m_channel}}}), res); res = info.add_instruction(make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), res); res = info.add_instruction( make_op("convert", {{"target_type", m_input->get_shape().type()}}), res); return res; } }; struct parse_gridsample : op_parser { std::vector operators() const { return {{"GridSample"}}; } instruction_ref parse(const op_desc& /*opd*/, const onnx_parser& parser, const onnx_parser::node_info& info, std::vector args) const { bool align_corners = false; // Note: defult mode can be linear or bilinear depending on the onnx version std::string mode = "linear"; std::string padding_mode = "zeros"; if(contains(info.attributes, "align_corners")) { align_corners = parser.parse_value(info.attributes.at("align_corners")).at(); } if(contains(info.attributes, "mode")) { mode = info.attributes.at("mode").s(); } if(contains(info.attributes, "padding_mode")) { padding_mode = info.attributes.at("padding_mode").s(); } const auto& grid = args.at(1); const auto& grid_shape = grid->get_shape(); if(not is_type_float(grid_shape.type())) { MIGRAPHX_THROW("PARSE_GRID_SAMPLE: grid input must have floating type"); } const auto& x = args.at(0); const auto& x_dims = x->get_shape().lens().size(); if(grid_shape.lens().size() != x_dims) { MIGRAPHX_THROW( "PARSE_GRID_SAMPLE: x and grid inputs must have same number of dimensions"); } if(x_dims != 4) { MIGRAPHX_THROW("PARSE_GRID_SAMPLE: only 4-D inputs are supported"); } return contains(mode, "nearest") ? nearest_sampler(x, grid, align_corners, std::move(padding_mode), info) .sample(info) : (contains(mode, "linear") ? linear_sampler(x, grid, align_corners, std::move(padding_mode), info) .sample(info) : bicubic_sampler(x, grid, align_corners, std::move(padding_mode), info) .sample(info)); } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx