/* * The MIT License (MIT) * * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { /** * Convert broadcast_with_dims operators with a static input tensor and a constant `dims` input * into multibroadcast op with a static output shape attribute. 
* */ struct find_broadcast_with_dims_static { auto matcher() const { return match::name("broadcast_with_dims")(match::nargs(2), match::arg(0)(match::static_shape()), match::arg(1)(match::is_constant())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto inputs = ins->inputs(); // read the values of arg(1) to create input to multibroadcast std::vector sizes_vec; inputs.at(1)->eval().visit( [&](auto output) { sizes_vec.assign(output.begin(), output.end()); }); m.replace_instruction( ins, make_op("multibroadcast", {{"out_lens", sizes_vec}}), inputs.at(0)); } }; /** * Convert a Resize op. with Nearest mode to an implementation using Gather op. * From: resize[scales={...}/sizes={...},](static, constant) * To: * 0 = literal{ ... } computed_indices * ... * 2 = reshape[dims={45}](X) 1-dimensional * 3 = gather[axis=0](2,0) * * At the time of writing, this conversion is required for GPU targets because there * is not direct a GPU implementation of the Resize operation. * This matcher depends on a split_single_dyn_dim pass being run before it, which * will convert any dynamic-batch input to static inputs and make this conversion possible. * * At time of writing, Resize allows either 1 or 2 inputs * but the 1-input case is never created by Onnx parsing. */ struct find_resize_static { auto matcher() const { return match::name("resize")(match::nargs(2), match::arg(0)(match::static_shape()), match::arg(1)(match::is_constant())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto inputs = ins->inputs(); auto resize_op = any_cast(ins->get_operator()); auto in_lens = inputs.at(0)->get_shape().lens(); std::vector sizes_vec(inputs.at(0)->get_shape().ndim()); std::vector scales_vec(inputs.at(0)->get_shape().ndim()); // populate both scales and sizes for the benefit of the algorithm. 
inputs.at(1)->eval().visit([&](auto input) { using type = typename decltype(input)::value_type; if constexpr(std::is_integral{}) { // read output sizes and use them to compute scales sizes_vec.assign(input.begin(), input.end()); std::transform( input.begin(), input.end(), in_lens.begin(), scales_vec.begin(), [](auto sz, size_t in_len) { return static_cast(sz) / in_len; }); } else { // read scales and use them to compute output sizes scales_vec.assign(input.begin(), input.end()); std::transform( input.begin(), input.end(), in_lens.begin(), sizes_vec.begin(), [](auto sz, size_t in_len) { return static_cast(sz * in_len); }); } }); auto in_s = inputs.at(0)->get_shape(); shape out_s{in_s.type(), sizes_vec}; std::vector ind(out_s.elements()); // map out_idx to in_idx auto nearest_op = op::resize::get_nearest_op(resize_op.nearest_mode); auto idx_op = op::resize::get_original_idx_op(resize_op.coordinate_transformation_mode); shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) { std::vector in_idx(out_idx_v.size()); for(auto ii = 0; ii < in_lens.size(); ++ii) { auto idx_val = idx_op(in_lens[ii], sizes_vec[ii], out_idx_v[ii], scales_vec[ii]); in_idx[ii] = nearest_op(in_lens[ii], idx_val); } ind[out_idx] = static_cast(in_s.index(in_idx)); }); // reshape input to one-dimension std::vector rsp_lens = {static_cast(in_s.elements())}; auto reshape_op = make_op("reshape", {{"dims", rsp_lens}}); auto rsp = m.insert_instruction(ins, reshape_op, ins->inputs().at(0)); // Add our computed indices as a literal. // ins_ind is a multi dimensional index that will restore original rank shape ind_s{shape::int32_type, sizes_vec}; auto ins_ind = m.add_literal(literal(ind_s, ind)); m.replace_instruction(ins, make_op("gather", {{"axis", 0}}), rsp, ins_ind); } }; /** * Convert 2 input static shape broadcast/multibroadcast into 1 input version. * Some compiler passes (ex. simplify_algebra) only support the 1 input versions * of the broadcasting operators. 
* From: * broadcast_op(argument_with_static_shape, argument_with_static_shape) * To: * broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims */ struct find_static_2in_broadcasts { auto matcher() const { return match::broadcast(match::nargs(2), match::arg(0)(match::static_shape()), match::arg(1)(match::static_shape())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto out_lens = ins->get_shape().lens(); auto broadcast_op = ins->get_operator(); if(broadcast_op.name() == "broadcast") { broadcast_op.from_value({{"out_lens", out_lens}}); } else { broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}}); } m.replace_instruction(ins, broadcast_op, ins->inputs().at(0)); } }; /** * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant. * From: * slice(data, constant_input); two attributes set * To: * slice(data); slice.starts, slice.ends. slice.axes set */ struct find_const_2in_slice { auto matcher() const { return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto inputs = ins->inputs(); auto slice_op = any_cast(ins->get_operator()); auto set_attrs = slice_op.get_set_attributes(); std::vector starts_vec; std::vector ends_vec; std::vector axes_vec; if(set_attrs == op::slice::ends_axes) { // slice(data, starts) inputs.at(1)->eval().visit( [&](auto output) { starts_vec.assign(output.begin(), output.end()); }); ends_vec = slice_op.ends; axes_vec = slice_op.axes; } else if(set_attrs == op::slice::starts_axes) { // slice(data, ends) inputs.at(1)->eval().visit( [&](auto output) { ends_vec.assign(output.begin(), output.end()); }); starts_vec = slice_op.starts; axes_vec = slice_op.axes; } else { // slice(data, axes) inputs.at(1)->eval().visit( [&](auto output) { axes_vec.assign(output.begin(), output.end()); }); starts_vec = slice_op.starts; ends_vec = 
slice_op.ends; } m.replace_instruction( ins, make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), inputs.at(0)); } }; /** * Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant. * From: * slice(data, constant_input1, constant_input2); one attribute set * To: * slice(data); slice.starts, slice.ends. slice.axes set */ struct find_const_3in_slice { auto matcher() const { return match::name("slice")(match::nargs(3), match::arg(1)(match::is_constant()), match::arg(2)(match::is_constant())); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto inputs = ins->inputs(); auto slice_op = any_cast(ins->get_operator()); auto set_attrs = slice_op.get_set_attributes(); std::vector starts_vec; std::vector ends_vec; std::vector axes_vec; if(set_attrs == op::slice::axes_only) { // slice(data, starts, ends) inputs.at(1)->eval().visit( [&](auto output) { starts_vec.assign(output.begin(), output.end()); }); inputs.at(2)->eval().visit( [&](auto output) { ends_vec.assign(output.begin(), output.end()); }); axes_vec = slice_op.axes; } else if(set_attrs == op::slice::ends_only) { // slice(data, starts, axes) inputs.at(1)->eval().visit( [&](auto output) { starts_vec.assign(output.begin(), output.end()); }); inputs.at(2)->eval().visit( [&](auto output) { axes_vec.assign(output.begin(), output.end()); }); ends_vec = slice_op.ends; } else { // slice(data, ends, axes) inputs.at(1)->eval().visit( [&](auto output) { ends_vec.assign(output.begin(), output.end()); }); inputs.at(2)->eval().visit( [&](auto output) { axes_vec.assign(output.begin(), output.end()); }); starts_vec = slice_op.starts; } m.replace_instruction( ins, make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), inputs.at(0)); } }; /** * Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant. 
* From:
 *   slice(data, constant_starts, constant_ends, constant_axes)
 * To:
 *   slice(data); slice.starts, slice.ends, slice.axes set
 */
struct find_const_4in_slice
{
    auto matcher() const
    {
        return match::name("slice")(match::nargs(4),
                                    match::arg(1)(match::is_constant()),
                                    match::arg(2)(match::is_constant()),
                                    match::arg(3)(match::is_constant()));
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto slice_ins = mr.result;
        auto args      = slice_ins->inputs();

        argument starts_arg = args.at(1)->eval(false);
        argument ends_arg   = args.at(2)->eval(false);
        argument axes_arg   = args.at(3)->eval(false);
        // leave the instruction untouched if any of the evaluated arguments is empty
        if(starts_arg.empty() or ends_arg.empty() or axes_arg.empty())
        {
            return;
        }

        // copy an evaluated argument's elements into an int64 vector
        auto to_int64_vec = [](auto& arg) {
            std::vector<int64_t> values;
            arg.visit([&](auto output) { values.assign(output.begin(), output.end()); });
            return values;
        };

        m.replace_instruction(slice_ins,
                              make_op("slice",
                                      {{"starts", to_int64_vec(starts_arg)},
                                       {"ends", to_int64_vec(ends_arg)},
                                       {"axes", to_int64_vec(axes_arg)}}),
                              args.at(0));
    }
};

/**
 * Simplify dimensions_of to a literal when the input argument has a static shape
 * or the dynamic dimensions from `start` to `end` are fixed.
*/ struct find_static_dimensions_of { auto matcher() const { return match::name("dimensions_of")(); } void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; auto input = ins->inputs().at(0); auto dimensions_of_value = ins->get_operator().to_value(); auto start = dimensions_of_value.at("start").to(); auto end = dimensions_of_value.at("end").to(); if(input->get_shape().dynamic()) { // check if dynamic dimensions from start to end are fixed auto dds = input->get_shape().dyn_dims(); if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) { return not dd.is_fixed(); })) { return; } } std::size_t output_ndim = end - start; std::vector vec_shape(output_ndim); migraphx::shape s(migraphx::shape::int64_type, {output_ndim}); std::vector input_lens = input->get_shape().to_static(1).lens(); std::transform(input_lens.begin() + start, input_lens.begin() + end, vec_shape.begin(), [](auto i) { return int64_t(i); }); migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}}; auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape}); m.replace_instruction(ins, lit_ins); } }; /** * Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1 * argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes. * This matcher can be generalized to matching reshape(data, static_shape_output_tensor). 
* From: * x = allocate(constant_output_dims) -> reshape(data, x) * To: * reshape(data); reshape.dims = constant_output_dims */ struct find_const_alloc_reshapes { auto matcher() const { auto const_alloc = match::arg(1)(match::name("allocate")(match::is_constant())); return match::name("reshape")(match::nargs(2), const_alloc); } void apply(module& m, const match::matcher_result& mr) const { auto reshape_ins = mr.result; auto reshape_inputs = reshape_ins->inputs(); auto alloc_ins = reshape_inputs.at(1); argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false); std::vector output_dims_vec; output_dims_arg.visit( [&](auto output) { output_dims_vec.assign(output.begin(), output.end()); }); m.replace_instruction( reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0)); // have dead_code_elimination remove the previous allocate } }; /** * Simplify allocate into fill operator that has constant output dimensions and constant value. * The allocate into fill instructions is what is produced when parsing the ONNX * ConstantOfShape operator. This replacement could be handled with propagate_constant, but * would rather have the simplification happen earlier during compiling. * This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor). 
* From: * x = allocate(constant_ouptut_dims) -> fill(constant_value, x) * To: * literal */ struct find_const_alloc_fill { auto matcher() const { auto const_alloc = match::arg(1)(match::name("allocate")(match::is_constant())); return match::name("fill")(match::arg(0)(match::is_constant()), const_alloc); } void apply(module& m, const match::matcher_result& mr) const { auto fill_ins = mr.result; auto fill_arg = fill_ins->eval(false); auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data()); m.replace_instruction(fill_ins, l); } }; /** * Simplify broadcast_for_dot instructions with two static shaped arguments * From: * broadcast_for_dot(static_shape_arg, static_shape_arg) * To: * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape */ struct find_static_broadcast_for_dot { auto matcher() const { return match::name("broadcast_for_dot")(match::arg(0)(match::static_shape()), match::arg(1)(match::static_shape())); } void apply(module& m, const match::matcher_result& mr) const { auto broadcast_for_dot_ins = mr.result; auto inputs = broadcast_for_dot_ins->inputs(); auto s0 = inputs.at(0)->get_shape(); auto s1 = inputs.at(1)->get_shape(); auto l0_it = s0.lens().end() - 2; std::vector l0_broadcasted_lens(s0.lens().begin(), l0_it); auto l1_it = s1.lens().begin() + s1.ndim() - 2; std::vector l1_broadcasted_lens(s1.lens().begin(), l1_it); auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); output_lens.insert(output_lens.end(), l0_it, s0.lens().end()); m.replace_instruction(broadcast_for_dot_ins, make_op("multibroadcast", {{"out_lens", output_lens}}), inputs.at(0)); } }; /** * Simplify onehot instructions with static shape `indices` input and * a compile-time constant `depth` attribute or input. 
* From: * onehot(static_shape_arg, constant_arg, values) or * onehot(static_shape_arg, values) * To: * A = literal(shape = onehot_output_shape, value = 0) * B = unsqueeze(literal(lens = indices_lens, strides = broadcasted scalar, value = 1), * axis=onehot_axis) C = scatter(A, unsqueeze(indices, axis=onehot_axis), B) diff = on_value - * off_value D = mul(diff, C); return = add(D, off_value); * * NOTE: It might be cleaner to use some form of `fill` instead of * (on_value - off_value) * mask + off_value when we have `fill` working * on the GPU. */ struct find_static_onehot { auto matcher() const { auto match_2_args = match::nargs(2)(match::arg(0)(match::static_shape()), match::arg(1)(match::static_shape())); auto match_3_args = match::nargs(3)(match::arg(0)(match::static_shape()), match::arg(1)(match::is_constant()), match::arg(2)(match::static_shape())); return match::name("onehot")(match::any_of(match_2_args, match_3_args)); } void apply(module& m, const match::matcher_result& mr) const { auto onehot_ins = mr.result; auto onehot_inputs = onehot_ins->inputs(); auto onehot_op = any_cast(onehot_ins->get_operator()); auto indices_ins = onehot_inputs[0]; shape indices_shape = indices_ins->get_shape(); std::size_t depth_val; migraphx::instruction_ref values_ins; if(onehot_op.depth.has_value()) { assert(onehot_inputs.size() == 2); depth_val = onehot_op.depth.value(); values_ins = onehot_inputs[1]; } else { assert(onehot_inputs.size() == 3); auto depth_ins = onehot_inputs[1]; depth_ins->eval().visit([&](auto d) { depth_val = d[0]; }); values_ins = onehot_inputs[2]; } shape values_shape = values_ins->get_shape(); std::vector static_output_lens = indices_shape.lens(); auto normalized_axis = (onehot_op.axis < 0) ? 
onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis; static_output_lens.insert(static_output_lens.begin() + normalized_axis, depth_val); shape output_shape{values_shape.type(), static_output_lens}; std::vector zeros(output_shape.elements(), 0); auto zeros_lit = m.add_literal(literal(output_shape, zeros)); auto unsqueeze_inds = m.insert_instruction( onehot_ins, migraphx::make_op("unsqueeze", {{"axes", {normalized_axis}}}), indices_ins); // broadcast the one scalar to the correct shape auto ones_lit = m.add_literal(literal(shape{values_shape.type(), {1}, {0}}, {1})); auto mb_ones = m.insert_instruction( onehot_ins, migraphx::make_op("multibroadcast", {{"out_lens", unsqueeze_inds->get_shape().lens()}}), ones_lit); auto mask = m.insert_instruction( onehot_ins, make_op("scatter_none", {{"axis", normalized_axis}, {"skip_out_of_bounds", true}}), zeros_lit, unsqueeze_inds, mb_ones); auto off_val = m.insert_instruction(onehot_ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), values_ins); auto on_val = m.insert_instruction(onehot_ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), values_ins); auto diff_val = m.insert_instruction(onehot_ins, make_op("sub"), on_val, off_val); auto mul_diff_mask = insert_common_op(m, onehot_ins, make_op("mul"), {diff_val, mask}); auto mb_off_val = m.insert_instruction( onehot_ins, make_op("multibroadcast", {{"out_lens", output_shape.lens()}}), off_val); m.replace_instruction(onehot_ins, make_op("add"), mb_off_val, mul_diff_mask); } }; /** * Go through `select_module` instructions and update the `output_dyn_shapes` attribute. * Checks the submodule output shapes and determines an appropriate `output_dyn_shapes` attribute. * This version ignores dynamic_dimension opt values. * Intended to be run after the other simplify_dyn_ops passes. 
*/ struct simplify_select_module_output_shape { auto matcher() const { return match::name("select_module"); } void apply(module& m, const match::matcher_result& mr) const { auto sm_ins = mr.result; auto sm_module_inputs = sm_ins->module_inputs(); std::vector> all_output_shapes(sm_module_inputs.size()); std::transform(sm_module_inputs.begin(), sm_module_inputs.end(), all_output_shapes.begin(), [](auto submod) { return submod->get_output_shapes(); }); // check that all of the submodules have the same number of outputs and all respective // outputs have the same rank and type auto shapes_ndim = get_shapes_ndim(all_output_shapes.front()); auto shapes_types = get_shapes_types(all_output_shapes.front()); if(std::any_of( all_output_shapes.begin() + 1, all_output_shapes.end(), [&](auto out_shapes) { bool same_types = get_shapes_types(out_shapes) == shapes_types; bool same_ndim = get_shapes_ndim(out_shapes) == shapes_ndim; return not same_types or not same_ndim; })) { return; } auto num_out_shapes = shapes_ndim.size(); std::vector dyn_shapes(num_out_shapes); auto num_submod = sm_module_inputs.size(); // compare respective output shapes from each submodule to get a range for the output shape for(int i : range(num_out_shapes)) { std::vector shapes_at_index(num_submod); std::transform(all_output_shapes.begin(), all_output_shapes.end(), shapes_at_index.begin(), [&](auto output_shapes) { return output_shapes.at(i); }); dyn_shapes.at(i) = dyn_shape_from_shapes(shapes_at_index); } auto tuple_shape = shape{dyn_shapes}; m.replace_instruction( sm_ins, make_op("select_module", {{"output_dyn_shapes", to_value(tuple_shape)}}), sm_ins->inputs(), sm_module_inputs); } std::vector get_shapes_ndim(const std::vector& shapes) const { std::vector ret(shapes.size()); std::transform( shapes.cbegin(), shapes.cend(), ret.begin(), [](auto s) { return s.ndim(); }); return ret; } std::vector get_shapes_types(const std::vector& shapes) const { std::vector ret(shapes.size()); std::transform( 
shapes.cbegin(), shapes.cend(), ret.begin(), [](auto s) { return s.type(); }); return ret; } /** * Calculating an appropriate shape that encompasses all of the given vector of shapes. * Equivalent to creating a 2D matrix of shape lengths and do a reduce over each axis. * The shapes can be dynamic or static. * Assuming all shapes have the same ndim. */ shape dyn_shape_from_shapes(std::vector shape_vec) const { // making 2D matrices of min_lens and max_lens // specifically using uint64_t because we're going to put the values into a tensor_view // later std::vector all_min_lens; std::vector all_max_lens; for(const auto& s : shape_vec) { auto min_lens = s.min_lens(); auto max_lens = s.max_lens(); std::copy(min_lens.begin(), min_lens.end(), std::back_inserter(all_min_lens)); std::copy(max_lens.begin(), max_lens.end(), std::back_inserter(all_max_lens)); } assert(all_min_lens.size() == shape_vec.size() * shape_vec.front().ndim()); assert(all_max_lens.size() == shape_vec.size() * shape_vec.front().ndim()); auto num_rows = shape_vec.size(); auto num_cols = shape_vec.front().ndim(); shape tensor_shape{shape::uint64_type, {num_rows, num_cols}}; auto min_lens_matrix = make_view(tensor_shape, all_min_lens.data()); auto max_lens_matrix = make_view(tensor_shape, all_max_lens.data()); std::vector mins(num_cols); std::vector maxes(num_cols); // rearranging data into column vectors to reduce over // i = row, j = column for(int j : range(num_cols)) { std::vector reduce_min_vals(num_rows); std::vector reduce_max_vals(num_rows); for(int i : range(num_rows)) { reduce_min_vals.at(i) = min_lens_matrix(i, j); reduce_max_vals.at(i) = max_lens_matrix(i, j); } uint64_t max_int = std::numeric_limits::max(); uint64_t min_val = std::accumulate(reduce_min_vals.begin(), reduce_min_vals.end(), max_int, [](uint64_t x, uint64_t y) { return x < y ? x : y; }); uint64_t max_val = std::accumulate( reduce_max_vals.begin(), reduce_max_vals.end(), 0, [](uint64_t x, uint64_t y) { return x > y ? 
x : y; }); mins.at(j) = min_val; maxes.at(j) = max_val; } // fixed output shape case if(mins == maxes) { return shape{shape_vec.front().type(), mins}; } // dynamic output shape case return shape{shape_vec.front().type(), mins, maxes, {}}; } }; void simplify_dyn_ops::apply(module& m) const { match::find_matches(m, find_broadcast_with_dims_static{}, find_resize_static{}, find_static_dimensions_of{}, find_const_alloc_reshapes{}, find_static_2in_broadcasts{}, find_const_2in_slice{}, find_const_3in_slice{}, find_const_4in_slice{}, find_const_alloc_fill{}, find_static_broadcast_for_dot{}, find_static_onehot{}); match::find_matches(m, simplify_select_module_output_shape{}); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx