/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
// NOTE: the header names below are inferred from the identifiers used in this file.
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/env.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)

// Operator that carries a reduction (plus any fused elementwise instructions) in a bypassed
// submodule.
struct fused_reduce
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        const auto* sm = mods.front();
        if(sm->get_output_shapes().size() != 1)
            MIGRAPHX_THROW("Only one output supported");
        if(not sm->bypass())
            MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
        auto names = sm->get_parameter_names();
        check_shapes{inputs, *this}.has(names.size()).same_ndims();
        std::sort(names.begin(), names.end());
        auto shapes = sm->get_parameter_shapes();
        // Check dimension matches for each input
        if(not equal(names, inputs, [&](const auto& name, const auto& input) {
               return shapes.at(name).lens() == input.lens();
           }))
            MIGRAPHX_THROW("Input dimension does not match the submodule.");

        return shape::from_permutation(sm->get_output_shapes().front().type(),
                                       sm->get_output_shapes().front().lens(),
                                       find_permutation(inputs));
    }

    std::string name() const { return "fused_reduce"; }
};
MIGRAPHX_REGISTER_OP(fused_reduce);
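/*
 * Illustrative sketch of the representation (spellings are approximate, not exact IR output):
 * create_reduce_modules below rewrites a standalone, single-input reduction such as
 *
 *     y = reduce_sum[axes={1}](x)
 *
 * into a fused_reduce instruction whose reduction lives in a bypassed submodule:
 *
 *     y = fused_reduce[axes={1}](x), submodule: { x0 -> reduce_sum[axes={1}] -> return }
 *
 * The matchers further down then grow that submodule by pulling in neighboring pointwise
 * instructions and adjacent fused_reduce instructions.
 */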
/*
 * Predicate matcher checks that input and output shapes have the same rank. This is assumed
 * for broadcast instructions for these fusions.
 */
MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
{
    auto input_shape  = ins->inputs().front()->get_shape();
    auto output_shape = ins->get_shape();
    return input_shape.ndim() == output_shape.ndim();
}

static auto insert_module_in_submodule(module_ref sm,
                                       instruction_ref ins,
                                       std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
                                       module::inserter insert = nullptr)
{
    assert(ins->module_inputs().size() == 1);
    return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
}

// Outline every single-input reduction into its own bypassed submodule held by a fused_reduce.
static void create_reduce_modules(module_pass_manager& mpm)
{
    std::size_t n = 0;
    for(auto ins : iterator_for(mpm.get_module()))
    {
        if(not ins->get_operator().attributes().get("reduce", false))
            continue;
        if(ins->inputs().size() != 1)
            continue;

        auto* rm =
            mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
        rm->set_bypass();
        rm->add_return(rm->fuse({ins}));
        auto v = ins->get_operator().to_value();
        mpm.get_module().replace_instruction(
            ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
    }
}

namespace {

instruction_ref get_broadcast_output(instruction_ref broadcast)
{
    if(broadcast->outputs().size() != 1)
        return broadcast;
    auto output = broadcast->outputs().front();
    if(output->name() == "contiguous")
        return get_broadcast_output(output);
    return output;
}

MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins)
{
    if(ins->outputs().size() == 1)
        return true;
    if(ins->outputs().size() == 2)
    {
        auto is_broadcast = [](instruction_ref output) {
            return contains(output->name(), "broadcast");
        };
        auto broadcast = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
        if(broadcast == ins->outputs().end())
            return false;
        auto non_broadcast =
            std::find_if_not(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
        if(non_broadcast == ins->outputs().end())
            return false;
        auto output = get_broadcast_output(*broadcast);
        return output == *non_broadcast;
    }
    return false;
}
} // namespace

template <class... Ms>
static auto match_broadcast(Ms... ms)
{
    return match::skip(match::name("contiguous"))(
               match::name("multibroadcast")(
                   match::arg(0)(ms...), match::used_once(), input_output_ndim_match())
                   .bind("broadcast"))
        .bind("final_broadcast");
}

template <class... Ms>
static auto any_input(Ms... ms)
{
    return match::any_of[match::inputs()](match::any(ms...).bind("input"));
}
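/*
 * Sketch of the input patterns accepted below (illustrative, not exact IR): given an op name
 * such as "pointwise" or "fused_reduce", match_broadcastable_input matches an instruction that
 * either takes that op directly as an input, or takes it through a multibroadcast (optionally
 * followed by a contiguous). In the broadcast case, is_valid_broadcast additionally requires
 * that the broadcast (stride-0) axes are exactly the reduction axes, which is the condition
 * under which the broadcast can be fused into the reduce submodule.
 */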
bool is_valid_broadcast(const instruction_ref b, const std::vector<std::int64_t>& reduce_axes)
{
    std::vector<std::int64_t> broadcast_axes;
    auto bstrides = b->get_shape().strides();
    for(size_t i = 0; i < bstrides.size(); ++i)
    {
        if(bstrides.at(i) == 0)
            broadcast_axes.push_back(i);
    }
    return broadcast_axes == reduce_axes;
}

template <class M>
static auto match_broadcast_axes(M m)
{
    return match::make_basic_fun_matcher(
        [=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
            optional<instruction_ref> result = m.match(ctx, ins);
            if(contains(ctx.instructions, "broadcast"))
            {
                instruction_ref reduce;
                if(ins->get_operator().name() == "fused_reduce")
                {
                    reduce = ins;
                }
                else
                {
                    assert(contains(ctx.instructions, "reduce"));
                    reduce = ctx.instructions["reduce"];
                }
                auto axes = reduce->get_operator().to_value().at("axes").to_vector<std::int64_t>();
                auto broadcast = ctx.instructions["broadcast"];
                if(not is_valid_broadcast(broadcast, axes))
                    return nullopt;
            }
            return result;
        });
}

static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
    auto match_op                 = match::name(op)(used_once_except_broadcast()).bind(name);
    auto match_op_input           = any_input(match_op, match::used_once());
    auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
    return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
}

static void finalize_reduce_module(module_ref m)
{
    eliminate_common_subexpression{}.apply(*m);
    dead_code_elimination{}.apply(*m);
}

namespace {
// Fuses a pointwise producer into the submodule of the fused_reduce that consumes it.
struct find_pointwise_reduce
{
    auto matcher() const
    {
        // fused_reduce instruction with pointwise inputs.
        return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce = r.result;
        auto input  = r.instructions["pointwise"];

        const auto* pm     = input->module_inputs().front();
        const auto* old_rm = reduce->module_inputs().front();
        auto* rm           = mpm.create_module(pm->name() + ":" + old_rm->name());
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Insert pointwise
        auto rins      = rm->fuse({input}, &map_ins).front();
        map_ins[input] = rins;

        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast     = r.instructions["broadcast"];
            auto fbroadcast    = r.instructions["final_broadcast"];
            map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front();
            if(fbroadcast != broadcast)
                map_ins[fbroadcast] = map_ins[broadcast];
        }

        // Insert fused_reduce
        rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
        finalize_reduce_module(rm);

        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
        mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
    }
};

// Fuses a pointwise consumer into the submodule of the fused_reduce that feeds it.
struct find_reduce_pointwise
{
    auto matcher() const
    {
        return match::name("pointwise")(match_broadcastable_input("fused_reduce", "reduce"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto pw     = r.result;
        auto reduce = r.instructions["reduce"];
        auto input  = r.instructions["input"];

        const auto* pm     = pw->module_inputs().front();
        const auto* old_rm = reduce->module_inputs().front();
        auto* rm           = mpm.create_module(old_rm->name() + ":" + pm->name());
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Copy module instructions
        insert_module_in_submodule(rm, reduce, &map_ins);
        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast                       = r.instructions["broadcast"];
            map_ins[broadcast->inputs().front()] = rm->get_returns().front();
            auto bout                            = rm->fuse({broadcast}, &map_ins);
            map_ins[input] = bout.front();
        }
        else
        {
            map_ins[input] = rm->get_returns().front();
        }

        auto out = rm->fuse({pw}, &map_ins);
        rm->replace_return(out);
        finalize_reduce_module(rm);

        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
        mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
    }
};

// Fuses two chained fused_reduce instructions that use the same operator (same axes).
struct find_reduce_reduce
{
    auto matcher() const
    {
        return match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce"));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce1 = r.result;
        auto reduce2 = r.instructions["reduce"];
        auto input   = r.instructions["input"];

        if(reduce1->get_operator() != reduce2->get_operator())
            return;

        const auto* rm1 = reduce1->module_inputs().front();
        const auto* rm2 = reduce2->module_inputs().front();
        auto* rm        = mpm.create_module(rm1->name() + ":" + rm2->name());
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Copy the upstream reduce (reduce2) instructions
        insert_module_in_submodule(rm, reduce2, &map_ins);
        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast                       = r.instructions["broadcast"];
            map_ins[broadcast->inputs().front()] = rm->get_returns().front();
            auto bout                            = rm->fuse({broadcast}, &map_ins);
            map_ins[input]                       = bout.front();
        }
        else
        {
            map_ins[input] = rm->get_returns().front();
        }

        auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
        rm->replace_return(out);
        finalize_reduce_module(rm);

        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
        mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
    }
};

// Hook for rewrite_reshapes: rebuilds the fused_reduce submodule with the reduction axes
// remapped through the reshape's axes map and broadcasts resized to the new base dimensions.
struct reduce_reshape : rewrite_reshapes_base
{
    static std::string name() { return "fused_reduce"; }

    template <class Transform>
    static auto transform_op(Transform t)
    {
        return [=](module& m,
                   instruction_ref ins,
                   const operation& op,
                   const std::vector<instruction_ref>& inputs,
                   const std::vector<module_ref>& mod_args) {
            auto new_op = t(op);
            return m.insert_instruction(ins, new_op, inputs, mod_args);
        };
    }

    template <class AxesMap>
    static instruction_ref insert(module_pass_manager& mpm,
                                  instruction_ref ins,
                                  const std::vector<instruction_ref>& inputs,
                                  const AxesMap& am)
    {
        auto op = any_cast<fused_reduce>(ins->get_operator());
        std::vector<std::int64_t> axes;
        for(auto axis : op.axes)
        {
            auto new_axes = am.at(axis);
            axes.insert(axes.end(), new_axes.begin(), new_axes.end());
        }
        std::sort(axes.begin(), axes.end());
        auto dims  = base_dims(inputs);
        auto* oldm = ins->module_inputs().front();
        auto* sm   = mpm.create_module(oldm->name() + "_reshape");
        sm->set_bypass();
        auto outs = sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
            if(contains(sop.name(), "reduce"))
                return make_op(sop.name(), {{"axes", axes}});
            if(sop.name() == "multibroadcast")
                return make_op("multibroadcast", {{"out_lens", dims}});
            assert(sop.name() == "pointwise");
            return sop;
        }));
        sm->add_return(outs);
        return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
    }

    static std::vector<std::size_t> base_dims(const std::vector<instruction_ref>& inputs)
    {
        auto input = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto i) {
                                          return i->get_shape().elements();
                                      }));
        return (*input)->get_shape().lens();
    }

    static std::vector<std::size_t> base_dims(instruction_ref ins)
    {
        return base_dims(ins->inputs());
    }
};

} // namespace
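/*
 * Pass driver: apply() first outlines every single-input reduction into its own fused_reduce
 * submodule, then runs a fixed number of rounds (4) in which reshapes around reductions are
 * rewritten (when enable_rewrite_reshapes is set) and the three fusion matchers above are
 * applied, eliminating dead code after each round.
 */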
void fuse_reduce::apply(module_pass_manager& mpm) const
{
    if(enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}))
        return;
    create_reduce_modules(mpm);
    mpm.run_pass(dead_code_elimination{});
    for(int i = 0; i < 4; i++)
    {
        if(enable_rewrite_reshapes)
            mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
        match::find_matches(
            mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
        mpm.run_pass(dead_code_elimination{});
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx