/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/make_op.hpp>
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
#include <migraphx/gpu/ck.hpp>
#endif
#include <migraphx/op/group_query_attention.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_LAYERNORM_FUSION);

namespace {

template <class Derived, std::size_t N>
struct layernorm_base
{
    float epsilon = 1e-12f;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.epsilon, "epsilon"));
    }

    shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
    {
        std::size_t nargs = N;
        if(not mods.empty())
        {
            auto* pm = mods.front();
            nargs += pm->get_parameter_names().size() - 1;
        }
        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs);
        auto s = inputs.front();
        auto t = s.type();
        if(not mods.empty())
            t = mods.front()->get_output_shapes().front().type();

        // Scalar output if all inputs are scalar
        if(inputs.front().elements() == 1 and
           all_of(inputs, [](const auto& ss) { return ss.scalar(); }))
            return inputs.front();
        auto l_s = shape::from_permutation(
            t, s.lens(), find_permutation(std::vector<shape>(inputs.begin(), inputs.begin() + N)));
        // just prelayernorm or preadd_layernorm
        if(nargs <= N)
            return l_s;
        // else, layernorm + pointwise fusion, preserve layout of fused op
        std::vector<shape> lp_s(inputs.begin() + N, inputs.end());
        lp_s.insert(lp_s.begin(), l_s);
        return shape::from_permutation(t, s.lens(), find_permutation(lp_s));
    }
};

struct layernorm : layernorm_base<layernorm, 1>
{
    std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);

struct add_layernorm : layernorm_base<add_layernorm, 2>
{
    std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);

struct find_layernorm
{
    auto matcher() const { return match::layernorm(); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        float eps  = 0;
        if(contains(r.instructions, "eps"))
            eps = r.instructions["eps"]->eval().at<float>();

        m.replace_instruction(ins, layernorm{eps}, x_ins);
    }
};

struct find_add_layernorm
{
    auto matcher() const
    {
        return match::name("gpu::prelayernorm")(
            match::args(match::name("add")(match::used_once()).bind("add")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins     = r.result;
        auto add_ins = r.instructions["add"];
        auto op      = any_cast<layernorm>(ins->get_operator());

        m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
    }
};
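// Attention prefusion: pre_gemm_softmax_gemm and its matcher collapse the
// dot -> (scale / bias / where-mask) -> softmax -> dot pattern into a single
// placeholder op so that later compile steps can treat the block as one fused unit.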
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
    std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);

auto is_ck_gemm()
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
        if(not enabled(MIGRAPHX_ENABLE_CK{}))
            return false;
        if(ins->name() != "dot")
            return false;
        if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
            return false;
        return true;
#else
        (void)ins;
        return false;
#endif
    });
}

auto is_test_gemm(bool enable_attention)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(ins->name() != "dot")
            return false;
        return enable_attention;
    });
}

auto is_bias_supported()
{
    return match::make_basic_pred_matcher([=](instruction_ref) {
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
        return not enabled(MIGRAPHX_ENABLE_CK{});
#else
        return true;
#endif
    });
}

struct find_gemm_softmax_gemm
{
    bool enable_attention = false;

    auto matcher() const
    {
        auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(
            match::any_of(is_ck_gemm(), is_test_gemm(enable_attention)).bind("gemm1")));
        auto mul = match::name("mul")(
            match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
        auto where = match::name("where")(match::arg(2)(match::is_constant().bind("select_const")),
                                          match::arg(1)(mul),
                                          match::arg(0)(match::any().bind("select_cond")));
        auto add = match::name("add")(
            is_bias_supported(),
            match::nargs(2),
            match::either_arg(0, 1)(match::none_of(mul).bind("bias"), mul));
        auto softmax = match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1, where)))
                           .bind("softmax");

        return match::name("dot")(
            match::any_of(is_ck_gemm(), is_test_gemm(enable_attention)).bind("gemm2"))(
            match::arg(0)(softmax));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins       = r.result;
        auto gemm2_ins = r.instructions["gemm2"];
        auto gemm1_ins = r.instructions["gemm1"];

        float scale = 1.0;
        if(contains(r.instructions, "scale"))
        {
            auto scale_lit = r.instructions["scale"];
            scale_lit->eval().visit([&](const auto s) {
                // CK only supports single-valued scale
                if(not std::all_of(
                       s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
                    return;
                scale = s.front();
            });
        }

        auto inputs = gemm1_ins->inputs(); // A, B
        if(contains(r.instructions, "select_cond"))
        {
            inputs.push_back(r.instructions["select_cond"]);
            inputs.push_back(r.instructions["select_const"]);
        }
        if(contains(r.instructions, "bias"))
        {
            inputs.push_back(r.instructions["bias"]);
        }

        inputs.push_back(gemm2_ins->inputs().back()); // B1

        mpm.get_module().replace_instruction(
            ins, pre_gemm_softmax_gemm{gemm2_ins->get_operator(), scale}, inputs);
    }
};
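// The ops below decompose the ONNX group_query_attention operator into
// GPU-level stages: rotary embedding of the packed QKV, appending K/V to the
// present KV-cache, the Q x K' probability gemm, softmax, and the final
// probs x V gemm. Each struct inherits op::group_query_attention for its
// attributes and only overrides name() and compute_shape().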
struct gpu_compute_attention_probabilities : op::group_query_attention
{
    std::string name() const { return "gpu::compute_attention_probabilities"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        auto query_lens        = inputs.front().lens();
        auto present_kv_seqlen = inputs.at(1).lens().at(2);
        std::vector<std::size_t> output_lens{
            query_lens.at(0), num_heads, query_lens.at(2), present_kv_seqlen};
        shape output_shape{inputs.front().type(), output_lens};
        return output_shape;
    }
};
MIGRAPHX_REGISTER_OP(gpu_compute_attention_probabilities);

struct gpu_compute_attention_scores : op::group_query_attention
{
    std::string name() const { return "gpu::compute_attention_scores"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        auto query_lens = inputs.front().lens();
        std::size_t q_hidden_size =
            (query_lens[1] * query_lens[3] * num_heads) / (num_heads + 2 * kv_num_heads);
        std::vector<std::size_t> output_lens{query_lens.at(0), query_lens.at(2), q_hidden_size};
        shape output_shape{inputs.front().type(), output_lens};
        return output_shape;
    }
};
MIGRAPHX_REGISTER_OP(gpu_compute_attention_scores);

struct gpu_gqa_rotary_embedding : op::group_query_attention
{
    std::string name() const { return "gpu::gqa_rotary_embedding"; }

    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
};
MIGRAPHX_REGISTER_OP(gpu_gqa_rotary_embedding);

struct gpu_gqa_softmax : op::group_query_attention
{
    std::string name() const { return "gpu::gqa_softmax"; }

    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(2); }
};
MIGRAPHX_REGISTER_OP(gpu_gqa_softmax);

struct gpu_concat_past_present : op::group_query_attention
{
    std::string name() const { return "gpu::concat_past_present"; }

    shape compute_shape(std::vector<shape> inputs) const { return inputs[0]; }
};
MIGRAPHX_REGISTER_OP(gpu_concat_past_present);

struct find_group_query_attention
{
    auto matcher() const { return match::name("group_query_attention"); }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins    = r.result;
        auto inputs = ins->inputs();
        auto v      = ins->get_operator().to_value();

        auto num_heads          = v.at("num_heads").to<std::size_t>();
        auto kv_num_heads       = v.at("kv_num_heads").to<std::size_t>();
        auto do_rotary          = v.at("do_rotary").to<bool>();
        auto local_window_size  = v.at("local_window_size").to<int>();
        auto rotary_interleaved = v.at("rotary_interleaved").to<bool>();
        auto scale              = v.at("scale").to<float>();

        auto q_shape                      = inputs[0]->get_shape();
        auto q_lens                       = q_shape.lens();
        const std::size_t batch_size      = q_lens[0];
        const std::size_t sequence_length = q_lens[1];
        std::size_t q_hidden_size         = q_lens[2];
        std::size_t head_size             = q_hidden_size / (num_heads + 2 * kv_num_heads);
        std::vector<std::size_t> bsnh{
            batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size};

        // Reshape the packed QKV input to (batch, seq, q+kv heads, head_size),
        // then transpose to (batch, heads, seq, head_size)
        auto transposed_qkv = mpm.get_module().insert_instruction(
            ins, make_op("reshape", {{"dims", bsnh}}), inputs.at(0));
        transposed_qkv = mpm.get_module().insert_instruction(
            ins, make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), transposed_qkv);

        auto rotary_qkv = transposed_qkv;
        if(do_rotary)
        {
            std::vector<instruction_ref> rotary_inputs{
                transposed_qkv, inputs.at(5), inputs.at(7), inputs.at(8)};
            rotary_qkv = mpm.get_module().insert_instruction(
                ins,
                gpu_gqa_rotary_embedding{do_rotary,
                                         kv_num_heads,
                                         local_window_size,
                                         num_heads,
                                         rotary_interleaved,
                                         scale},
                rotary_inputs);
        }

        auto pres_k = inputs.at(3);
        auto pres_v = inputs.at(4);

        std::vector<instruction_ref> concat_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5)};
        auto concat = mpm.get_module().insert_instruction(
            ins,
            gpu_concat_past_present{
                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
            concat_inputs);

        auto id = mpm.get_module().insert_instruction(
            ins, make_op("identity"), concat, pres_k, pres_v);

        std::vector<instruction_ref> attn_probs_inputs{id, pres_k, pres_v, inputs.at(5)};
        auto attn_probs = mpm.get_module().insert_instruction(
            ins,
            gpu_compute_attention_probabilities{
                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
            attn_probs_inputs);

        std::vector<instruction_ref> softmax_inputs{rotary_qkv, pres_k, attn_probs, inputs.at(5)};
        auto softmax = mpm.get_module().insert_instruction(
            ins,
            gpu_gqa_softmax{
                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
            softmax_inputs);

        std::vector<instruction_ref> new_inputs{rotary_qkv, pres_k, pres_v, inputs.at(5), softmax};
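        // group_query_attention returns a tuple; the three get_tuple_element
        // consumers that follow it yield (output, present_key, present_value),
        // so rewire each of them to the decomposed pipeline built above.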
        auto get_tuple_elm_0 = std::next(ins);
        auto get_tuple_elm_1 = std::next(get_tuple_elm_0);
        auto get_tuple_elm_2 = std::next(get_tuple_elm_1);
        mpm.get_module().replace_instruction(get_tuple_elm_2, pres_v);
        mpm.get_module().replace_instruction(get_tuple_elm_1, pres_k);
        mpm.get_module().replace_instruction(
            get_tuple_elm_0,
            gpu_compute_attention_scores{
                do_rotary, kv_num_heads, local_window_size, num_heads, rotary_interleaved, scale},
            new_inputs);
    }
};

} // namespace

void prefuse_ops::apply(module_pass_manager& mpm) const
{
    if(not enabled(MIGRAPHX_DISABLE_LAYERNORM_FUSION{}))
    {
        match::find_matches(mpm.get_module(), find_layernorm{});
        mpm.run_pass(dead_code_elimination{});
        match::find_matches(mpm.get_module(), find_add_layernorm{});
    }
    match::find_matches(mpm, find_gemm_softmax_gemm{enable_attention});
    match::find_matches(mpm, find_group_query_attention{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx