/* * The MIT License (MIT) * * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { struct parse_scan : op_parser { std::vector operators() const { return {{"Scan"}}; } std::vector parse(const op_desc& /*opd*/, onnx_parser& parser, onnx_parser::node_info info, std::vector args) const { if(parser.opset_version == 8) MIGRAPHX_THROW("Scan: Opset 8 version not supported"); check_for_required_attributes(info, {"body", "num_scan_inputs"}); const auto& body_graph = info.attributes["body"].g(); auto* body = parser.prog.create_module(info.name + "_scan"); parser.parse_graph(body, body_graph); // Scan has: // N + M inputs (N state variables, M scan inputs) // N + K outputs (N state variables, K scan outputs) // Same input and output counts apply for body auto body_outs = body->get_returns(); const auto m = info.attributes["num_scan_inputs"].i(); const auto n = args.size() - m; const auto k = body_outs.size() - n; std::vector body_params; transform(body->get_parameter_names(), std::back_inserter(body_params), [&](const auto& name) { return body->get_parameter(name); }); if(auto num_body_params = body_params.size(); num_body_params != n + m) MIGRAPHX_THROW("Scan: Number of inputs to body {" + std::to_string(num_body_params) + "} does not match number of inputs to Scan {" + std::to_string(n + m) + "}"); const auto scan_input_axes = parse_axes(info, "scan_input_axes", m, args.begin() + n, 0); const auto scan_input_directions = parse_dirs(info, "scan_input_directions", m); const auto scan_output_axes = parse_axes(info, "scan_output_axes", k, body_outs.begin() + n, 1); const auto scan_output_directions = parse_dirs(info, "scan_output_directions", k); // Check that scan axes lens are the same across all scan inputs size_t num_iters = args[n]->get_shape().lens()[scan_input_axes[0]]; for(auto i = 1; i < m; ++i) if(args[n + i]->get_shape().lens()[scan_input_axes[i]] != num_iters) MIGRAPHX_THROW( "Scan: Lengths of scan_input_axes do not match across all scan inputs.\n" "Scan input shapes: " + to_string_range( to_shapes(std::vector(args.begin() + n, args.end()))) + "\nScan input axes: " + to_string_range(scan_input_axes)); if(num_iters > parser.max_loop_iterations) MIGRAPHX_THROW("Scan: Number of required iterations {" + std::to_string(num_iters) + "} would exceed the maximum iteration limit {" + std::to_string(parser.max_loop_iterations) + "}"); // Check that state variable shapes match between the Scan node and its body attribute for(auto i = 0; i < n; ++i) if(args[i]->get_shape() != body_params[i]->get_shape()) MIGRAPHX_THROW("Scan: State input " + std::to_string(i) + " shape {" + to_string(args[i]->get_shape()) + "} does not match corresponding body input shape {" + to_string(body_params[i]->get_shape()) + "}"); // Check that the shapes of scan inputs sliced across scan input axes match the shapes of // the body attribute scan inputs for(auto i = 0; i < m; ++i) { auto node_shape = args[i + n]->get_shape(); auto node_lens = node_shape.lens(); node_lens.erase(node_lens.begin() + scan_input_axes[i]); auto slice_sh = shape(node_shape.type(), std::move(node_lens)); if(body_params[i + n]->get_shape() != slice_sh) MIGRAPHX_THROW("Slice: Sliced scan input " + std::to_string(i) + " shape {" + to_string(slice_sh) + "} does not match corresponding body input shape {" + to_string(body_params[i + n]->get_shape()) + "}"); } modify_body(body, args, n, m, scan_input_axes, scan_input_directions); auto max_iter_lit = info.add_literal(literal{shape{shape::int64_type}, {num_iters}}); auto cond_lit = info.add_literal(literal{shape{shape::bool_type}, {true}}); std::vector loop_args{max_iter_lit, cond_lit}; loop_args.insert(loop_args.end(), args.begin(), args.begin() + n); auto loop = info.add_instruction(make_op("loop", {{"max_iterations", num_iters}, {"scan_output_directions", scan_output_directions}}), loop_args, {body}); std::vector ret; ret.reserve(n + k); for(auto i = 0; i < n; ++i) ret.push_back(info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), loop)); for(auto i = 0; i < k; ++i) { auto o = info.add_instruction(make_op("get_tuple_elem", {{"index", i + n}}), loop); // Loop concatenates scan axes along axis 0 which is inserted/unsqueezed, e.g. a body // scan output(from a single iteration) of shape {2, 2} is first expanded to {1, 2, 2}, // and then concatenated with body scan outputs from previous iterations. For n // iterations of the loop, this will end up producing a scan output of shape {n, 2, 2}. // // The scan_output_axes attribute of Scan can define an axis other than zero as the // concatenation axis. Using the previous scenario, for a body scan output of // shape {2,2}, with the scan output axis being 1, it is unsqueezed to {2, 1, 2}. The // final concatenation is then of shape {2, n, 2}. // // Since Loop only concatenates along the unsqueezed axis 0, a transpose is necessary to // place axis 0 in the appropriate scan_output_axis position auto perm = make_perm_for_scan_out(o->get_shape().ndim(), scan_output_axes[i]); ret.push_back(info.add_instruction(make_op("transpose", {{"permutation", perm}}), o)); } return ret; } void check_for_required_attributes(onnx_parser::node_info& info, const std::vector& attribute_names) const { auto it = std::find_if( attribute_names.cbegin(), attribute_names.cend(), [&](const std::string& name) { return not contains(info.attributes, name); }); if(it != attribute_names.cend()) MIGRAPHX_THROW("Scan: " + *it + " attribute required"); } std::vector parse_vector_attribute(onnx_parser::node_info& info, const std::string& attr_name, size_t expected_size) const { if(not contains(info.attributes, attr_name)) return {}; std::vector res; auto&& attr = info.attributes[attr_name].ints(); if(attr.size() != expected_size) MIGRAPHX_THROW("Scan: " + attr_name + " size is " + to_string(attr.size()) + ", should be " + to_string(expected_size)); res.assign(attr.begin(), attr.end()); return res; } std::vector parse_dirs(onnx_parser::node_info& info, const std::string& name, size_t expected_size) const { auto dirs = parse_vector_attribute(info, name, expected_size); if(dirs.empty()) return std::vector(expected_size, 0); // NOLINT if(any_of(dirs, [](auto i) { return i != 0 and i != 1; })) MIGRAPHX_THROW("Scan: " + name + " may contain only 1s and 0s, actual values: " + to_string_range(dirs)); return dirs; } int64_t normalize_axis(int64_t axis, int64_t rank, const std::string& attr_name) const { if(axis < -rank or axis >= rank) MIGRAPHX_THROW("Scan: " + attr_name + " axis value {" + to_string(axis) + "} out of range [" + to_string(-rank) + ", " + to_string(rank) + ")"); return axis < 0 ? rank + axis : axis; } std::vector parse_axes(onnx_parser::node_info& info, const std::string& name, long expected_size, std::vector::iterator ins_begin, size_t rank_offset) const { auto axes = parse_vector_attribute(info, name, expected_size); if(axes.empty()) return std::vector(expected_size, 0); // NOLINT std::transform(axes.begin(), axes.end(), ins_begin, axes.begin(), [&](int64_t axis, instruction_ref arg) { return normalize_axis(axis, arg->get_shape().ndim() + rank_offset, name); }); return axes; } // Alter the Scan body to match a body that Loop would expect. // // Loop body inputs: iteration_num, condition, loop_state_variables // Scan body inputs: loop_state_variables, scan_input_slices // iteration_num and condition parameters are prepended to the Scan body parameter list, while // scan_input_slices are removed from parameters. // Instead, scan_inputs are used directly in Scan body(as values from enclosing scope), and // together with iteration_num passed to the scan_slice operator which produces slices that are // used instead of the scan_inputs_slices. // // Loop body outputs: condition, loop_state_variables, scan_output_slices // Scan body outputs: loop_state_variables, scan_output_slices // The inserted Scan body condition parameter is prepended to the Scan body returns void modify_body(module_ref mod, const std::vector& args, int64_t n, int64_t m, const std::vector& scan_input_axes, const std::vector& scan_input_directions) const { std::vector params; params.reserve(n + m); transform(mod->get_parameter_names(), std::back_inserter(params), [&](const std::string& name) { return mod->get_parameter(name); }); // iteration_num, condition, and duplicate loop_state_variables are appended to parameters. // References to the original loop_state_variables in other instructions are then replaced // with references to the duplicate ones, after which the originals are removed. // // References to the scan_input_slices are replaced with references to inserted // scan_slice->squeeze instructions, after which the scan_input_slices parameters are // removed. auto iter_param = mod->add_parameter("iter", shape{shape::int64_type}); auto cond_param = mod->add_parameter("cond", shape{shape::bool_type}); std::vector new_params; new_params.reserve(n); for(auto i = 0; i < n; ++i) new_params.push_back( mod->add_parameter("state_var" + std::to_string(i), params[i]->get_shape())); for(auto i = 0; i < params.size(); ++i) { if(i < n) { mod->replace_instruction(params[i], new_params[i]); } else { auto scan_axis = scan_input_axes[i - n]; auto scan_dir = scan_input_directions[i - n]; auto new_ins = mod->insert_instruction( params[i], make_op("scan_slice", {{"axis", scan_axis}, {"direction", scan_dir}}), args[i], iter_param); new_ins = mod->insert_instruction( params[i], make_op("squeeze", {{"axes", {scan_axis}}}), new_ins); mod->replace_instruction(params[i], new_ins); } mod->remove_instruction(params[i]); } auto returns = mod->get_returns(); returns.insert(returns.begin(), cond_param); mod->replace_return(returns); } // Creates permutation so that axis 0 will be permuted to position axis, while maintaining the // relative ordering of all the other axes. // e.g. for rank = 4, axis = 2, the created perm is: [1, 2, 0, 3] std::vector make_perm_for_scan_out(int64_t rank, int64_t axis) const { std::vector perm(rank); std::iota(perm.begin(), perm.end(), 0); std::copy(perm.begin() + 1, perm.begin() + 1 + axis, perm.begin()); perm[axis] = 0; return perm; } }; } // namespace onnx } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx