frigate/docker/rocm/migraphx/targets/gpu/fuse_mlir.cpp

1107 lines
42 KiB
C++
Raw Normal View History

2025-02-04 00:44:02 +03:00
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/common.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/param_utils.hpp>
#include <migraphx/match/softmax.hpp>
#include <migraphx/fp8_types.hpp>
#include <optional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
// Not read directly in this file.
// NOTE(review): presumably feeds the fuse_mlir::enable_extra flag checked in fuse_mlir::apply; confirm.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
// When set, find_pointwise_mlir may also fold multi-input pointwise producers that are not a
// lone dequantizelinear into a gpu::mlir_op.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION);
// When set, fuse_mlir::apply runs find_mlir_split_reduce (gpu::mlir_op + split_fused_reduce).
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION);
// When set, mlir_enabled() reports false even in builds compiled with MIGRAPHX_MLIR.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a comma-separated list of operation names. The names queried in this file are
* "convolution", "dot", "fused_convolution", "fused_dot" and "attention"; the
* entry "fused" additionally matches every queried name that contains "fused".
* Prefixing an entry with '!' or '~' rejects that operation instead of
* requesting it (e.g. "dot,!fused"). Empty entries are ignored. If the variable
* is not defined MIGraphX will decide by itself which operations to delegate to
* MLIR. The variable is intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
/// Reports whether MLIR offload can be used: the build must define MIGRAPHX_MLIR and
/// the MIGRAPHX_DISABLE_MLIR environment variable must not be enabled.
bool mlir_enabled()
{
#ifndef MIGRAPHX_MLIR
    return false;
#else
    return not enabled(MIGRAPHX_DISABLE_MLIR{});
#endif
}
namespace {
/// Tag type: selects the plain (positive) entries of MIGRAPHX_MLIR_USE_SPECIFIC_OPS.
struct requested {};
/// Tag type: selects the negated ('!' / '~' prefixed) entries of MIGRAPHX_MLIR_USE_SPECIFIC_OPS.
struct rejected {};
} // namespace
/// True when an option string starts with a negation marker ('!' or '~').
/// An empty string is never negated.
static bool is_negated_op(const std::string& s)
{
    return not s.empty() and (s.front() == '!' or s.front() == '~');
}
/// Returns the MIGRAPHX_MLIR_USE_SPECIFIC_OPS entries selected by `Action`.
/// `requested` yields the plain entries; `rejected` yields the entries prefixed with
/// '!' or '~', with that prefix stripped. Empty entries are skipped. The environment
/// variable is parsed once per instantiation (function-local static).
template <class Action>
static std::vector<std::string> get_usage()
{
    static const auto options =
        split_string(string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""), ',');
    static const bool want_negated = not std::is_same<Action, requested>::value;
    std::vector<std::string> selected;
    for(const std::string& option : options)
    {
        if(option.empty())
            continue;
        const bool negated = is_negated_op(option);
        if(negated != want_negated)
            continue;
        selected.push_back(negated ? option.substr(1) : option);
    }
    return selected;
}
/// Looks `option` up in the MIGRAPHX_MLIR_USE_SPECIFIC_OPS entries selected by `Action`.
/// Returns `fallback` when no entries of that kind were given. Otherwise returns true if
/// `option` is listed, or if `option` contains "fused" and the list has the "fused" wildcard.
template <class Action>
static bool specific_op(std::string_view option, bool fallback = false)
{
    static const auto options = get_usage<Action>();
    if(not options.empty())
    {
        const bool fused_wildcard = contains(option, "fused") and contains(options, "fused");
        return fused_wildcard or contains(options, option);
    }
    return fallback;
}
/// Decides whether attention subgraphs should be offloaded to MLIR.
/// Requires mlir_enabled(). An explicit rejection ("!attention") always wins. Otherwise
/// it is on by default for gfx94* devices (MI300), and elsewhere only when "attention"
/// is requested through MIGRAPHX_MLIR_USE_SPECIFIC_OPS.
bool mlir_attention_enabled(context* ctx)
{
#ifdef MIGRAPHX_MLIR
    if(not mlir_enabled() or specific_op<rejected>("attention"))
        return false;
    // Enable attention by default for mi300
    const bool default_on =
        ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94");
    return default_on or specific_op<requested>("attention");
#else
    return false;
#endif
}
#ifdef MIGRAPHX_MLIR
/// Operator wrapping a bypass submodule that is compiled by MLIR. `op` records the
/// GEMM-like operator (convolution/dot) the fused submodule was built around.
struct mlir_op
{
    std::string name() const { return "gpu::mlir_op"; }
    operation op = make_op("convolution");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    // Check if the shape can be created from a transpose/broadcast/slice
    static bool is_mlir_compatible(const shape& s)
    {
        if(s.standard() or s.packed() or s.scalar() or s.ndim() == 1)
            return true;
        auto ns = reorder_shape(s, find_permutation(s));
        std::vector<std::size_t> stride_ratios;
        // Only the strides before the first zero (broadcast) stride are examined.
        // NOTE(review): relies on the permutation placing zero strides last, and on the
        // all-zero case having been accepted above (otherwise std::prev(begin) is UB).
        auto last = std::find(ns.strides().begin(), ns.strides().end(), 0);
        // The innermost examined dimension must be unit-stride.
        if(*std::prev(last) != 1)
            return false;
        // Ratio of each stride to the next one; 0 marks a non-divisible pair.
        std::adjacent_difference(ns.strides().begin(),
                                 last,
                                 std::back_inserter(stride_ratios),
                                 [](auto y, auto x) -> std::size_t {
                                     assert(y != 0);
                                     if((x % y) != 0)
                                         return 0;
                                     return x / y;
                                 });
        // Each ratio must cover the length of the dimension it steps over.
        return std::equal(stride_ratios.begin() + 1,
                          stride_ratios.end(),
                          ns.lens().begin() + 1,
                          [](auto ratio, auto len) { return ratio >= len; });
    }

    /// Computes the output shape by evaluating the single submodule on `inputs`.
    /// Throws when the submodule count is not exactly one or an input layout is not
    /// MLIR compatible. Multiple submodule outputs are returned as a tuple shape.
    shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.has_at_least(1);
        // Validate before indexing: reading mods[0] of an empty vector is undefined behavior.
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        module_ref mod = mods.front();
        if(not std::all_of(inputs.begin(), inputs.end(), &is_mlir_compatible))
            MIGRAPHX_THROW("Shape is not mlir compatible.");
        auto result =
            mod->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
        if(result.size() == 1)
            return result.front();
        return shape{result};
    }
};
MIGRAPHX_REGISTER_OP(mlir_op);
namespace {
/// Names of the shape-manipulation ("reshaper") ops that the fusions in this file walk
/// through. The set is built on first use and returned by reference.
const auto& reshaper_names()
{
    static const std::unordered_set<std::string> reshapers{"slice",
                                                          "transpose",
                                                          "multibroadcast",
                                                          "broadcast",
                                                          "contiguous",
                                                          "reshape",
                                                          "lazy_reshape",
                                                          "squeeze",
                                                          "flatten",
                                                          "unsqueeze"};
    return reshapers;
}
/// Walks upward from `lower_input` through consecutive reshaper ops (see reshaper_names()),
/// always following the first input. Returns the first non-reshaper instruction reached,
/// together with the walked operators ordered from `lower_input` upward.
std::tuple<instruction_ref, std::vector<operation>>
get_fusable_input_op_stream(instruction_ref lower_input)
{
    std::vector<operation> walked_ops;
    auto cursor = lower_input;
    for(; contains(reshaper_names(), cursor->name()); cursor = cursor->inputs().at(0))
        walked_ops.push_back(cursor->get_operator());
    return {cursor, walked_ops};
}
/// For each instruction in `inputs` not yet in `*map_ins`, creates (or reuses) a parameter
/// in `mm` for the instruction's first non-reshaper ancestor and replays the reshaper ops
/// on it. The resulting instruction is recorded as the mapping of that input.
/// New parameters are numbered after the parameters `mm` already has.
void fuse_input_ops(module_ref mm,
                    const std::vector<instruction_ref>& inputs,
                    std::unordered_map<instruction_ref, instruction_ref>* map_ins)
{
    assert(map_ins != nullptr);
    auto& mapping     = *map_ins;
    size_t next_param = mm->get_parameters().size();
    for(auto input : inputs)
    {
        if(contains(mapping, input))
            continue;
        auto [source, reshapes] = get_fusable_input_op_stream(input);
        if(not contains(mapping, source))
            mapping[source] =
                mm->add_parameter(param_name(next_param++), source->get_shape().as_standard());
        auto chained = mapping.at(source);
        // Replay the reshapers from the top-most one down to `input`.
        for(const auto& op : reverse(reshapes))
            chained = mm->add_instruction(op, {chained});
        mapping[input] = chained;
    }
}
/// Builds `gemm_based_op` inside `mm`. For every operand, a parameter named with the "y"
/// prefix is created for the operand's first non-reshaper ancestor, and the reshaper
/// chain is replayed on it. Returns the new op inside `mm` and the ancestor instructions
/// from the outer module, which become the fused op's external inputs (in order).
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm,
                                 const std::vector<instruction_ref>& gemm_based_op_inputs,
                                 const operation& gemm_based_op)
{
    std::vector<instruction_ref> top_inputs;
    std::vector<instruction_ref> imm_inputs;
    for(std::size_t idx = 0; idx < gemm_based_op_inputs.size(); ++idx)
    {
        auto [source, reshapes] = get_fusable_input_op_stream(gemm_based_op_inputs[idx]);
        top_inputs.push_back(source);
        auto chained =
            mm->add_parameter(param_name(idx, "y"), source->get_shape().as_standard());
        for(const auto& op : reverse(reshapes))
            chained = mm->add_instruction(op, {chained});
        imm_inputs.push_back(chained);
    }
    auto fused_op = mm->add_instruction(gemm_based_op, imm_inputs);
    return {fused_op, top_inputs};
}
// How aggressively GEMM-like ops (dot / convolution) are offloaded to MLIR.
// NOTE: declaration order is significant. fuse_mlir::apply combines two modes with
// std::max, which selects the later-declared enumerator.
enum class mlir_mode
{
// Offload every op the matchers recognise.
all,
// Offload only the subset filtered by is_mlir_dot / is_mlir_conv for this mode.
fast,
// NOTE(review): not referenced by the visible code; presumably an int8-only mode — confirm.
int8,
// Never offload (the matchers reject everything).
none
};
/// Matcher predicate accepting `dot` / `quant_dot` instructions to offload to MLIR.
/// Always rejects everything in mlir_mode::none and fp8 results. In mlir_mode::fast it
/// also rejects GEMMs whose K dimension (innermost length of the first operand) exceeds
/// 1024. Other modes accept every remaining dot.
auto is_mlir_dot(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(mode == mlir_mode::none)
            return false;
        if(ins->name() != "dot" and ins->name() != "quant_dot")
            return false;
        // dot operation where (FP8 * FP8 = FP8) is not available in MLIR. rocBLAS/hipBLASLt should
        // have the support for it.
        if(contains(fp8_types{}.get(), ins->get_shape().type()))
            return false;
        if(mode != mlir_mode::fast)
            return true;
        // Skipping GEMMs with a K dimension greater than 1024 is a coarse-grained strategy
        // to avoid poor-performing GEMM kernels from MLIR.
        // To-do: Investigate a more precise strategy (e.g. also taking M and N into account).
        constexpr std::size_t max_fast_k = 1024;
        const auto k = ins->inputs().front()->get_shape().lens().back();
        return k <= max_fast_k;
    });
}
/// Matcher predicate accepting `convolution` / `quant_convolution` instructions to offload
/// to MLIR under `mode`. Grouped convolutions whose output is not 4-D are always rejected.
/// int8/fp8 inputs, mlir_mode::all and grouped convolutions are accepted. Otherwise
/// (fast mode) square 4-D kernels whose size is a multiple of 3 are rejected.
/// NOTE(review): presumably those are left to winograd-capable kernels elsewhere — confirm.
auto is_mlir_conv(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        const bool is_conv =
            ins->name() == "convolution" or ins->name() == "quant_convolution";
        if(mode == mlir_mode::none or not is_conv)
            return false;
        const auto group = ins->get_operator().to_value().at("group").to<int>();
        // Avoid MLIR assertion: Index < Length && "Invalid index!"
        if(group > 1 and ins->get_shape().lens().size() != 4)
            return false;
        const auto input_type = ins->inputs().front()->get_shape().type();
        if(input_type == shape::int8_type or contains(fp8_types{}.get(), input_type))
            return true;
        // No winograd for group convolution
        if(mode == mlir_mode::all or group > 1)
            return true;
        const auto w         = ins->inputs().at(1)->get_shape();
        const auto& w_lens   = w.lens();
        const bool square_4d = w_lens.size() == 4 and w_lens[2] == w_lens[3];
        if(not square_4d)
            return true;
        return (w_lens[3] % 3) != 0;
    });
}
/// For every @literal in `pm`, adds a copy of that literal to `mm`, multibroadcast to
/// `shape`'s lengths. Returns a map from each literal in `pm` to its broadcast replacement
/// in `mm`.
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
    std::unordered_map<instruction_ref, instruction_ref> literal_map;
    const auto out_lens = shape.lens();
    for(auto ins : iterator_for(*pm))
    {
        if(ins->name() == "@literal")
        {
            auto lit = mm->add_literal(ins->get_literal());
            literal_map[ins] =
                mm->add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), lit);
        }
    }
    return literal_map;
}
/// Inserter callback for module::fuse. A "pointwise" op has its submodule inlined at `pos`;
/// the submodule's literals are mapped to multibroadcast copies of the pointwise output
/// shape. Any other op is inserted unchanged. Returns the (first) resulting instruction.
instruction_ref unroll_pointwise(module& main_mod,
                                 instruction_ref pos,
                                 const operation& op,
                                 const std::vector<instruction_ref>& inputs,
                                 const std::vector<module_ref>& mod_args)
{
    if(op.name() != "pointwise")
        return main_mod.insert_instruction(pos, op, inputs, mod_args);
    auto* sub_pm         = mod_args.front();
    const auto out_shape = op.compute_shape(to_shapes(inputs), mod_args);
    auto literal_map     = create_param_map_with_literals(&main_mod, sub_pm, out_shape);
    const auto inlined   = main_mod.insert_inline(pos, *sub_pm, inputs, &literal_map);
    return inlined.front();
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i)
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::bf16_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e5m2fnuz_type,
type_t::fp8e4m3fn_type,
type_t::fp8e5m2_type,
type_t::int8_type,
type_t::uint8_type,
type_t::int32_type,
type_t::uint32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {
"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"relu",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg",
};
const std::initializer_list<std::string> fp_only_ops = {
"ceil",
"erf",
"exp",
"floor",
"log",
"recip",
"sqrt",
"rsqrt",
"sigmoid",
"softmax",
"tanh",
};
std::set<shape::type_t> float_types = {type_t::float_type,
type_t::half_type,
type_t::bf16_type,
type_t::fp8e4m3fnuz_type,
type_t::fp8e5m2fnuz_type,
type_t::fp8e4m3fn_type,
type_t::fp8e5m2_type};
bool is_float = contains(float_types, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true;
if(is_float and contains(fp_only_ops, name))
return true;
// Only conversions between floating types are known to be unambigiously
// supported.
if(is_float and name == "convert")
{
if(contains(fp8_types{}.get(), result_type))
{
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type, type_t::bf16_type},
arg->get_shape().type());
});
}
return false;
}
/// True when `i` is a reduction MLIR can fuse: reduce_mean or reduce_sum producing a
/// floating-point (float/half/bf16/fp8) result.
bool is_reduce_op_supported_by_mlir(const instruction& i)
{
    using type_t           = shape::type_t;
    const auto result_type = i.get_shape().type();
    const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                         type_t::half_type,
                                                         type_t::bf16_type,
                                                         type_t::fp8e4m3fnuz_type,
                                                         type_t::fp8e5m2fnuz_type,
                                                         type_t::fp8e4m3fn_type,
                                                         type_t::fp8e5m2_type};
    // Preliminary type check.
    if(not contains(allowed_types, result_type))
    {
        return false;
    }
    const std::initializer_list<std::string> reduce_ops = {"reduce_mean", "reduce_sum"};
    return contains(reduce_ops, i.name());
}
// A separate function so we can remove operators that are supported by mlir
// but not supported for an input fusion.
// Predicate used by mlir_input_pointwise (and so by find_pointwise_mlir) for every
// instruction of a pointwise submodule that would be fused into a gpu::mlir_op as an
// input. It currently accepts exactly the set accepted by is_pointwise_op_supported_by_mlir.
bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i)
{
return is_pointwise_op_supported_by_mlir(i);
}
// Matches a split_fused_reduce whose submodule contains only MLIR-supported reductions,
// MLIR-supported pointwise submodules, builtins (@param/@literal/@return) and reshapers
// other than slice.
MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins)
{
    if(ins->name() != "split_fused_reduce")
        return false;
    auto* mod_arg = ins->module_inputs().front();
    // slice is not supported inside the fused MLIR module (same restriction as
    // find_mlir_fused_ops). The original code built this set but then checked against the
    // unfiltered reshaper_names(), which let slice through.
    auto supported_reshapes = reshaper_names();
    supported_reshapes.erase("slice");
    const std::unordered_set<std::string> builtins = {"@param", "@literal", "@return"};
    for(const auto i : iterator_for(*mod_arg))
    {
        if(is_reduce(*i))
        {
            if(not is_reduce_op_supported_by_mlir(*i))
                return false;
        }
        else if(i->name() == "pointwise")
        {
            auto* pm = i->module_inputs().front();
            if(not std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir))
                return false;
        }
        else if(not contains(supported_reshapes, i->name()) and not contains(builtins, i->name()))
        {
            return false;
        }
    }
    return true;
}
// Matches a "pointwise" instruction whose submodule contains only instructions accepted
// by is_pointwise_op_supported_by_mlir.
MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins)
{
    if(ins->name() == "pointwise")
    {
        auto* submod = ins->module_inputs().front();
        return std::all_of(submod->begin(), submod->end(), &is_pointwise_op_supported_by_mlir);
    }
    return false;
}
// Matches a "pointwise" instruction that may be fused into a gpu::mlir_op as an input:
// every instruction of its submodule must pass is_pointwise_op_supported_by_mlir_for_input.
MIGRAPHX_PRED_MATCHER(mlir_input_pointwise, instruction_ref ins)
{
    if(ins->name() == "pointwise")
    {
        auto* submod = ins->module_inputs().front();
        return std::all_of(
            submod->begin(), submod->end(), &is_pointwise_op_supported_by_mlir_for_input);
    }
    return false;
}
/// Returns `inputs` in the same order, except that every operand whose shape is neither
/// packed nor broadcasted is replaced by a new `contiguous` instruction inserted right
/// after it in the current module.
std::vector<instruction_ref> mlir_contiguous(module_pass_manager& mpm,
                                             const std::vector<instruction_ref>& inputs)
{
    std::vector<instruction_ref> result;
    result.reserve(inputs.size());
    for(auto input : inputs)
    {
        const auto& s = input->get_shape();
        if(s.packed() or s.broadcasted())
        {
            result.push_back(input);
            continue;
        }
        result.push_back(mpm.get_module().insert_instruction(
            std::next(input), make_op("contiguous"), input));
    }
    return result;
}
/// Folds an existing gpu::mlir_op (bound as "gemm") into the split_fused_reduce that
/// consumes it. The new gpu::mlir_op submodule is a copy of the gemm submodule followed by
/// the reduction submodule, with pointwise submodules unrolled.
struct find_mlir_split_reduce
{
    // NOTE(review): set by fuse_mlir::apply but not read by matcher()/apply().
    mlir_mode conv_mode = mlir_mode::none;
    mlir_mode dot_mode  = mlir_mode::none;

    auto matcher() const
    {
        auto dot_or_conv = match::name("gpu::mlir_op");
        // TODO: Handle reshapes inbetween
        return mlir_split_reduce()(match::any_of[match::inputs()](dot_or_conv.bind("gemm")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce_ins = r.result;
        auto gemm_ins   = r.instructions["gemm"];
        assert(gemm_ins->get_shape().sub_shapes().empty());
        auto* rm               = reduce_ins->module_inputs().front();
        module_ref gemm_old_mm = gemm_ins->module_inputs().front();
        // Start from a copy of the gemm submodule and append the reduction to it.
        module_ref mm = mpm.create_module(gemm_old_mm->name() + "_" + rm->name(), *gemm_old_mm);
        // remove last return instruction
        if(std::prev(mm->end())->name() == "@return")
        {
            mm->remove_instruction(std::prev(mm->end()));
        }
        mm->set_bypass();
        std::unordered_map<instruction_ref, instruction_ref> param_map;
        // After dropping @return, the last instruction of mm stands for the gemm result.
        param_map[gemm_ins]      = std::prev(mm->end());
        bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1;
        auto return_vals = mm->fuse(*rm, reduce_ins->inputs(), &param_map, &unroll_pointwise);
        if(gemm_has_multi_outs)
        {
            // Also expose the gemm result (as the last tuple element) for its other users.
            return_vals.insert(return_vals.end(), param_map[gemm_ins]);
        }
        mm->add_return(return_vals);
        std::vector<instruction_ref> inputs;
        std::copy_if(reduce_ins->inputs().begin(),
                     reduce_ins->inputs().end(),
                     std::back_inserter(inputs),
                     [&](auto input) { return input != gemm_ins; });
        inputs.insert(inputs.end(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
        if(gemm_has_multi_outs)
        {
            auto fused_ins = mpm.get_module().insert_instruction(
                reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
            auto dot_ins = mpm.get_module().insert_instruction(
                reduce_ins,
                migraphx::make_op("get_tuple_elem", {{"index", return_vals.size() - 1}}),
                fused_ins);
            mpm.get_module().replace_instruction(gemm_ins, dot_ins);
            // Snapshot the users first. replace_instruction rewires each user's arguments
            // away from reduce_ins, which can remove it from reduce_ins->outputs(); a
            // range-for over that live vector would then be invalidated mid-loop.
            const std::vector<instruction_ref> reduce_users = reduce_ins->outputs();
            for(const auto& outs : reduce_users)
            {
                assert(outs->get_operator().name() == "get_tuple_elem");
                mpm.get_module().replace_instruction(outs, outs->get_operator(), fused_ins);
            }
        }
        else
        {
            mpm.get_module().replace_instruction(
                reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
        }
    }
};
// Fuses a dot/convolution accepted by is_mlir_dot/is_mlir_conv (optionally followed by
// reshaper ops other than slice) into the MLIR-supported pointwise op consuming it. The
// pointwise instruction is replaced by a gpu::mlir_op whose submodule holds the GEMM-like
// op, the replayed reshapers and the pointwise body.
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
// Binds "gemm_based_op" to the dot/conv and "x" to the pointwise input reached from it
// through the skipped reshapers.
auto matcher() const
{
auto reshapes = reshaper_names();
// slice is not supported
reshapes.erase("slice");
auto dot_or_conv = match::skip(match::name(reshapes))(
match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto pw_ins = r.result;
auto gemm_based_op = r.instructions["gemm_based_op"];
auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream
auto* pm = pw_ins->module_inputs().front();
auto pw_inputs = pw_ins->inputs();
// only one of the inputs to pointwise module should be dependent on conv/gemm that is
// being fused, otherwise it can create invalid graph transformation
if(std::any_of(pw_inputs.begin(), pw_inputs.end(), [&](const auto& i) {
return i != x_ins and reaches(gemm_based_op, i);
}))
return;
// NOTE(review): `names` is computed and sorted but never used below.
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
module_ref mm = mpm.create_module("mlir_" + pm->name());
mm->set_bypass();
// Rebuild the GEMM-like op (with its own fusable input reshapers) inside mm.
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(mm, pm, pw_ins->get_shape());
// Replay the reshapers between the GEMM-like op and x_ins on top of anchor_op.
auto [upper_input, op_stream] = get_fusable_input_op_stream(x_ins);
assert(upper_input == gemm_based_op);
auto prev_input = anchor_op;
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
assert(prev_input->get_shape().lens() == x_ins->get_shape().lens());
param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped
// input to pointwise in new fused module
// The GEMM result must also be returned if it, or any reshaper on the path to x_ins,
// has users other than the next instruction on that path.
bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1;
auto reshaped_gemm = x_ins;
std::vector<instruction_ref> reshapes_vec;
while(reshaped_gemm != gemm_based_op)
{
reshapes_vec.push_back(reshaped_gemm);
gemm_has_multi_outs = gemm_has_multi_outs or reshaped_gemm->outputs().size() > 1;
reshaped_gemm = reshaped_gemm->inputs().at(0);
}
reshapes_vec.push_back(reshaped_gemm);
auto return_vals = mm->fuse(*pm, pw_ins->inputs(), &param_map);
// With extra users, the GEMM result becomes tuple element 0 and the pointwise result(s) follow.
if(gemm_has_multi_outs)
{
return_vals.insert(return_vals.begin(), anchor_op);
}
mm->add_return(return_vals);
// External inputs: the non-x pointwise inputs, then the GEMM-like op's top inputs.
std::vector<instruction_ref> inputs;
std::copy_if(pw_ins->inputs().begin(),
pw_ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != x_ins; });
inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
if(gemm_has_multi_outs)
{
auto fused_ins = mpm.get_module().insert_instruction(
pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
mpm.get_module().replace_instruction(
pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins);
auto dot_ins = mpm.get_module().insert_instruction(
pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
// move all the reshape instructions and original GEMM instruction after the fused op to
// avoid generating invalid migraphx program
for(const auto& orig_i : reverse(reshapes_vec))
{
mpm.get_module().move_instruction(orig_i, pw_ins);
}
// Remaining users of the original GEMM now read tuple element 0 of the fused op.
mpm.get_module().replace_instruction(gemm_based_op, dot_ins);
}
else
{
mpm.get_module().replace_instruction(
pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
}
}
};
/// Wraps a single convolution/dot accepted by `Matcher(mode)`, plus any reshaper ops
/// feeding its operands, in its own bypass submodule. The instruction is then replaced by
/// a gpu::mlir_op. Submodule names are made unique with the shared `*counter`.
template <auto Matcher>
struct find_mlir_standalone_op
{
    mlir_mode mode       = mlir_mode::none;
    std::size_t* counter = nullptr;

    auto matcher() const { return Matcher(mode); }

    /// Returns the current value of `*counter` as a string and post-increments it.
    std::string get_count() const
    {
        if(counter == nullptr)
            MIGRAPHX_THROW("Invalid counter");
        const auto current = (*counter)++;
        return std::to_string(current);
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto gemm_based_op = r.result;
        // enable only for fp32/fp16/i8/fp8 types
        const auto has_supported_type = [](instruction_ref operand) {
            return contains({shape::type_t::float_type,
                             shape::type_t::half_type,
                             shape::type_t::bf16_type,
                             shape::type_t::int8_type,
                             shape::type_t::fp8e4m3fnuz_type,
                             shape::type_t::fp8e5m2fnuz_type,
                             shape::type_t::fp8e4m3fn_type,
                             shape::type_t::fp8e5m2_type},
                            operand->get_shape().type());
        };
        const auto& operands = gemm_based_op->inputs();
        if(not std::all_of(operands.begin(), operands.end(), has_supported_type))
            return;

        std::string module_name = "mlir_" + gemm_based_op->name() + get_count();
        const auto& parent_name = mpm.get_module().name();
        if(parent_name != "main")
            module_name = parent_name + ":" + module_name;

        module_ref mm = mpm.create_module(module_name);
        mm->set_bypass();
        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
            mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
        mm->add_return({anchor_op});
        mpm.get_module().replace_instruction(gemm_based_op,
                                             mlir_op{gemm_based_op->get_operator()},
                                             mlir_contiguous(mpm, top_inputs),
                                             {mm});
    }
};
using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op         = find_mlir_standalone_op<&is_mlir_dot>;
/// Fuses the attention subgraph
///     gemm1 -> [reshape] -> fused_reduce (ending in a softmax pattern) -> gemm2
/// into a single gpu::mlir_op. Its bypass submodule holds both dots, the unrolled
/// pointwise ops in between and an explicit `softmax` op. When the match also binds
/// "trailing_pm" (see find_mlir_attention_fused_ops), that pointwise op is fused as well.
struct find_mlir_standalone_attention_op
{
    mlir_mode dot_mode = mlir_mode::none;

    /// gemm2: a dot accepted by is_mlir_dot whose first argument is a single-use
    /// fused_reduce. One fused_reduce input is gemm1: a single-use dot, possibly behind
    /// `contiguous`, optionally reached through a `reshape` bound as "rsp".
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::used_once(), is_mlir_dot(dot_mode))
                .bind("gemm1");
        auto fused_reduce =
            match::name("fused_reduce")(match::used_once(),
                                        match::any_of[match::inputs()](
                                            match::skip(match::name("reshape").bind("rsp"))(gemm1)))
                .bind("fused_reduce");
        return is_mlir_dot(dot_mode)(match::arg(0)(fused_reduce)).bind("gemm2");
    }

    /// Swaps keys and values of `map_ins`. Values are expected to be unique (asserted).
    std::unordered_map<instruction_ref, instruction_ref>
    invert_map_ins(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const
    {
        std::unordered_map<instruction_ref, instruction_ref> inverse_map;
        for(auto const& [key, value] : map_ins)
        {
            assert(not contains(inverse_map, value));
            inverse_map[value] = key;
        }
        return inverse_map;
    }

    /// Cleans up the assembled attention module: common-subexpression elimination, then DCE.
    void finalize_attention_module(module_ref m) const
    {
        eliminate_common_subexpression{}.apply(*m);
        dead_code_elimination{}.apply(*m);
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto gemm2        = r.instructions["gemm2"];
        auto fused_reduce = r.instructions["fused_reduce"];
        auto gemm1        = r.instructions["gemm1"];
        // The explicit softmax inserted below takes exactly one axis.
        auto axes = fused_reduce->get_operator().to_value()["axes"];
        if(axes.size() != 1)
            return;

        // m_attn is a local module. The early returns below only discard local state;
        // it is handed to mpm (create_module) only at the very end.
        module m_attn;
        std::unordered_map<instruction_ref, instruction_ref> map_main_to_mattn;

        // Add first gemm and fuse any input shape ops
        module fuse_gemm1;
        auto [anchor_op, top_inputs] =
            fuse_input_ops_and_gemm_based_op(&fuse_gemm1, gemm1->inputs(), gemm1->get_operator());
        fuse_gemm1.add_return({anchor_op});
        m_attn.add_params(top_inputs, &map_main_to_mattn);
        std::unordered_map<instruction_ref, instruction_ref> map_gemm1_to_mattn(map_main_to_mattn);
        auto m_gemm1 = m_attn.fuse(fuse_gemm1, top_inputs, &map_gemm1_to_mattn).front();
        map_main_to_mattn[gemm1] = m_gemm1;

        if(contains(r.instructions, "rsp"))
        {
            auto rsp   = r.instructions["rsp"];
            auto m_rsp = m_attn.add_instruction(rsp->get_operator(), {m_gemm1});
            map_main_to_mattn[rsp] = m_rsp;
        }

        // Add pointwise-softmax, unroll any pointwise modules back to base ops
        m_attn.add_params(fused_reduce->inputs(), &map_main_to_mattn);
        std::unordered_map<instruction_ref, instruction_ref> map_mfr_to_mattn(map_main_to_mattn);
        auto pw_softmax = m_attn
                              .fuse(*fused_reduce->module_inputs().front(),
                                    fused_reduce->inputs(),
                                    &map_mfr_to_mattn,
                                    &unroll_pointwise)
                              .front();

        // fused_reduce submodule should end with a softmax
        auto result = match::match_instruction(m_attn, pw_softmax, match::softmax());
        if(result.result != pw_softmax)
            return;
        // Insert explict softmax op - required for MLIR
        auto softmax_in = result.instructions["x"];
        auto softmax    = m_attn.insert_instruction(
            std::next(softmax_in), make_op("softmax", {{"axis", axes.front()}}), softmax_in);
        map_main_to_mattn[fused_reduce] = softmax;

        // all preceeding ops should be fusable ops
        // Take each instruction by const reference: copying an `instruction` duplicates its
        // operator, shape and argument/output lists for every element visited.
        if(not std::all_of(m_gemm1, softmax, [](const auto& i) {
               return (is_pointwise_op_supported_by_mlir(i) or
                       contains(reshaper_names(), i.name()));
           }))
            return;

        // Add second gemm and fuse any input shape ops
        module fuse_gemm2;
        auto [anchor_op2, top_inputs2] =
            fuse_input_ops_and_gemm_based_op(&fuse_gemm2, gemm2->inputs(), gemm2->get_operator());
        fuse_gemm2.add_return({anchor_op2});
        m_attn.add_params(top_inputs2, &map_main_to_mattn);
        std::unordered_map<instruction_ref, instruction_ref> map_gemm2_to_mattn(map_main_to_mattn);
        auto m_gemm2 = m_attn.fuse(fuse_gemm2, top_inputs2, &map_gemm2_to_mattn).front();
        map_main_to_mattn[gemm2] = m_gemm2;

        // Fuse any succeeding pointwise module
        if(contains(r.instructions, "trailing_pm"))
        {
            auto trailing_pm_ins = r.instructions["trailing_pm"];
            auto lit_map         = create_param_map_with_literals(
                &m_attn, trailing_pm_ins->module_inputs().front(), trailing_pm_ins->get_shape());
            m_attn.add_params(trailing_pm_ins->inputs(), &map_main_to_mattn);
            map_main_to_mattn.insert(lit_map.begin(), lit_map.end());
            std::unordered_map<instruction_ref, instruction_ref> map_pm_to_mattn(map_main_to_mattn);
            auto fused_pw_outs = m_attn
                                     .fuse(*trailing_pm_ins->module_inputs().front(),
                                           trailing_pm_ins->inputs(),
                                           &map_pm_to_mattn)
                                     .front();
            map_main_to_mattn[trailing_pm_ins] = fused_pw_outs;
            m_attn.add_return({fused_pw_outs});
        }
        else
        {
            m_attn.add_return({m_gemm2});
        }

        finalize_attention_module(&m_attn);
        // Translate m_attn's parameters back to instructions of the outer module.
        auto map_mattn_to_main = invert_map_ins(map_main_to_mattn);
        auto new_inputs        = m_attn.get_inputs(map_mattn_to_main);
        module_ref mpm_attn    = mpm.create_module(
            "mlir_attn_" + fused_reduce->module_inputs().front()->name(), std::move(m_attn));
        mpm_attn->set_bypass();
        mpm.get_module().replace_instruction(
            r.result, mlir_op{gemm1->get_operator()}, mlir_contiguous(mpm, new_inputs), {mpm_attn});
    }
};
/// Attention fusion that also absorbs an MLIR-supported pointwise op consuming the
/// attention result. The pointwise instruction is bound as "trailing_pm" and is the match
/// result; apply() is inherited from find_mlir_standalone_attention_op.
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{
    auto matcher() const
    {
        auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
        return mlir_pointwise()(
            match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
    }
};
// Folds pointwise producers of an existing gpu::mlir_op into that op's submodule. A
// producer qualifies when it is used once and every instruction of its submodule passes
// is_pointwise_op_supported_by_mlir_for_input.
struct find_pointwise_mlir
{
auto supported_pointwise() const { return mlir_input_pointwise(match::used_once()); }
auto matcher() const
{
return match::name("gpu::mlir_op")(match::any_of[match::inputs()](supported_pointwise()));
}
// True when `pm` returns a single value and the first instruction that is not a
// @param/@literal/broadcast/multibroadcast is that returned value, named in `op_names`.
static bool is_simple_op(const_module_ref pm, std::initializer_list<std::string> op_names)
{
auto last = std::prev(pm->end());
assert(last->name() == "@return");
if(last->inputs().size() != 1)
return false;
auto rins = last->inputs().front();
auto op_ins = std::find_if(pm->begin(), pm->end(), [](const instruction& x) {
return not contains({"@param", "@literal", "broadcast", "multibroadcast"}, x.name());
});
if(op_ins != rins)
return false;
return contains(op_names, op_ins->name());
}
// Inserter passed to module::fuse when unrolling a pointwise submodule: each op is
// inserted through insert_common_op with common_type disabled.
// NOTE(review): presumably this adds broadcasts so operand shapes agree without type
// promotion — confirm against insert_common_op.
static instruction_ref insert_pointwise(module& m,
instruction_ref ins,
const operation& op,
const std::vector<instruction_ref>& inputs,
const std::vector<module_ref>& mod_args)
{
// Only used in assert
(void)mod_args;
assert(mod_args.empty());
return insert_common_op(m, ins, op, inputs, {.common_type = false});
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto* mm = ins->module_inputs().front();
// Select the fusable pointwise inputs. Multi-input producers are taken only when they
// are a lone dequantizelinear or MIGRAPHX_ENABLE_MLIR_INPUT_FUSION is enabled.
std::vector<instruction_ref> pws;
std::copy_if(
ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(pws),
[&](instruction_ref input) {
if(not match::instruction_matches(mpm.get_module(), input, supported_pointwise()))
return false;
auto* pm = input->module_inputs().front();
if(input->inputs().size() > 1 and not is_simple_op(pm, {"dequantizelinear"}))
{
if(not enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{}))
return false;
}
return true;
});
if(pws.empty())
return;
// Name: "<pw1>:<pw2>:...:<mlir submodule>".
std::string module_name;
std::transform(
pws.begin(), pws.end(), join_back_inserter(module_name), [](instruction_ref pw) {
return pw->module_inputs().front()->name() + ":";
});
module_name += mm->name();
module_ref m = mpm.create_module(module_name);
m->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Unroll each selected pointwise (with its input reshapers) into m, then append the
// original MLIR submodule, wiring the pointwise results in place of those inputs.
for(auto pw : pws)
{
auto* pm = pw->module_inputs().front();
fuse_input_ops(m, pw->inputs(), &map_ins);
auto rins = m->fuse(*pm, pw->inputs(), &map_ins, &insert_pointwise).front();
map_ins[pw] = rins;
}
auto ret = m->fuse(*mm, ins->inputs(), &map_ins);
m->add_return({ret});
auto inputs = find_inputs(map_ins, &mpm.get_module(), m);
mpm.get_module().replace_instruction(
ins, ins->get_operator(), mlir_contiguous(mpm, inputs), {m});
}
};
/// Moves unpack_int4 producers of a gpu::mlir_op inside a new "int4:"-prefixed submodule.
/// The packed tensor becomes the external input; the unpack runs inside the MLIR module.
struct find_unpack_int4_mlir_op
{
    auto matcher() const
    {
        return match::name("gpu::mlir_op")(
            match::any_of[match::inputs()](match::name("unpack_int4").bind("unpack_int4")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto mlir_ins      = r.result;
        auto* old_mod      = mlir_ins->module_inputs().front();
        module_ref new_mod = mpm.create_module("int4:" + old_mod->name());
        new_mod->set_bypass();

        std::vector<instruction_ref> new_inputs;
        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        int param_idx = 0;
        for(auto input : mlir_ins->inputs())
        {
            const bool is_unpack = input->name() == "unpack_int4";
            // For unpack_int4, the packed source is what the outer module passes in.
            auto source = is_unpack ? input->inputs()[0] : input;
            auto param  = new_mod->add_parameter(param_name(++param_idx),
                                                source->get_shape().as_standard());
            map_ins[input] =
                is_unpack ? new_mod->add_instruction(input->get_operator(), param) : param;
            new_inputs.push_back(source);
        }
        auto ret = new_mod->fuse(*old_mod, mlir_ins->inputs(), &map_ins);
        new_mod->add_return({ret});
        mpm.get_module().replace_instruction(
            mlir_ins, mlir_ins->get_operator(), new_inputs, {new_mod});
    }
};
} // namespace
#endif // MIGRAPHX_MLIR
// Pass entry point: runs the MLIR fusion matchers over the module managed by `mpm`.
// Does nothing unless built with MIGRAPHX_MLIR.
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
// Shared by the standalone convolution/dot finders to give each new submodule a unique name.
std::size_t counter = 0;
const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
const bool is_navi = starts_with(device_name, "gfx11") or starts_with(device_name, "gfx12");
// Resolves the mlir_mode for one MIGRAPHX_MLIR_USE_SPECIFIC_OPS option name:
//   rejected -> none; requested -> all; gfx11*/gfx12* device -> all;
//   otherwise std::max(m1, m2), i.e. the later-declared mlir_mode enumerator.
auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(specific_op<rejected>(option))
return mlir_mode::none;
if(specific_op<requested>(option))
return mlir_mode::all;
if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
// Attention offloads: run when mlir_attention_enabled(ctx) (on by default for gfx94*,
// or when "attention" is requested via MIGRAPHX_MLIR_USE_SPECIFIC_OPS) or when
// enable_extra is set. The variant with a trailing pointwise op is matched first.
if(mlir_attention_enabled(ctx) or enable_extra)
{
match::find_matches(mpm, find_mlir_attention_fused_ops{mlir_mode::all});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm, find_mlir_standalone_attention_op{mlir_mode::all});
mpm.run_pass(dead_code_elimination{});
}
// Fuse dot/convolution with a consuming pointwise op.
match::find_matches(
mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});
// Wrap the remaining convolutions/dots in their own gpu::mlir_op.
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{.mode = get_mode("convolution", mlir_mode::fast),
.counter = &counter},
find_mlir_standalone_dot_op{.mode = get_mode("dot", mlir_mode::fast), .counter = &counter});
mpm.run_pass(dead_code_elimination{});
// Opt-in: fold a gpu::mlir_op into the split_fused_reduce that consumes it.
if(enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{}))
{
match::find_matches(
mpm,
find_mlir_split_reduce{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});
}
// Fold pointwise producers, then unpack_int4 producers, into gpu::mlir_op inputs.
match::find_matches(mpm, find_pointwise_mlir{});
match::find_matches(mpm, find_unpack_int4_mlir_op{});
#else
(void)mpm;
#endif
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx