/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
auto op_lit_broadcast(std::string op, std::string x, std::string y)
{
return match::name(std::move(op))(match::either_arg(0, 1)(
lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
}
auto conv_const_weights()
{
return match::name("convolution")(
match::used_once(),
match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
}
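// Matches a value that is produced (possibly through reshape,
// dequantizelinear, or concat) by an unpack_int4 instruction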
auto from_int4()
{
return match::make_predicate_matcher([](instruction_ref start) {
return fix<bool>([&](auto self, instruction_ref ins) {
auto alias = instruction::get_output_alias(ins);
if(contains({"reshape", "dequantizelinear"}, alias->name()))
return self(alias->inputs().front());
if(alias->name() == "concat")
return all_of(alias->inputs(), self);
return alias->name() == "unpack_int4";
})(start);
});
}
auto not_from_int4() { return match::none_of(from_int4()); }
auto reduction() { return match::name_contains("reduce"); }
// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv
{
auto matcher() const
{
return match::name("mul")(
match::either_arg(0, 1)(conv_const_weights().bind("conv"),
match::name("broadcast", "multibroadcast").bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
const auto& a_input_lens = a_ins->inputs().front()->get_shape().lens();
std::size_t num_not_one_dims = std::count_if(
a_input_lens.cbegin(), a_input_lens.cend(), [](auto dim) { return dim != 1; });
if(num_not_one_dims > 1)
return;
// check that "a" is broadcast along the channel axis
const auto& a_lens = a_ins->get_shape().lens();
const auto& a_strides = a_ins->get_shape().strides();
auto is_broadcasted_axis = [](auto len, auto stride) { return len == 1 or stride == 0; };
if(a_strides.at(1) != 1)
return;
if(not is_broadcasted_axis(a_lens.front(), a_strides.front()))
return;
if(not std::equal(a_lens.begin() + 2,
a_lens.end(),
a_strides.begin() + 2,
a_strides.end(),
is_broadcasted_axis))
return;
auto sq = m.insert_instruction(ins, make_op("squeeze"), a_ins->inputs().front());
auto new_a = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), sq);
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
m.replace_instruction(ins, new_conv);
}
};
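// a * slice(conv(x, w)) => slice(conv(x, concat(w0, a * w1, w2)))
// Folds a constant multiply on a channel slice of a convolution into the
// corresponding output-channel slice of the weights; the other slices of the
// convolution are rewired to the new convolution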
struct find_mul_slice_conv
{
static auto conv()
{
return match::name("convolution")(
match::all_of[match::outputs()](match::name("slice")),
match::args(match::any(), match::is_constant().bind("w")));
}
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("slice")(match::used_once(), match::arg(0)(conv().bind("conv")))
.bind("slice"),
match::name("broadcast")(match::is_constant()).bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slice_ins = r.instructions["slice"];
auto conv_ins = r.instructions["conv"];
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
if(broadcast_op.axis != 1)
return;
auto slice_op = any_cast<op::slice>(slice_ins->get_operator());
if(slice_op.axes.size() != 1)
return;
if(slice_op.axes.front() != 1)
return;
auto slice_idx = std::distance(conv_ins, slice_ins);
if(std::any_of(conv_ins->outputs().begin(), conv_ins->outputs().end(), [&](auto i) {
if(i == slice_ins)
return false;
if(std::distance(conv_ins, i) < slice_idx)
return true;
auto sop = any_cast<op::slice>(i->get_operator());
if(sop.axes != slice_op.axes)
return true;
if(std::max(sop.starts.front(), slice_op.starts.front()) <
std::min(sop.ends.front(), slice_op.ends.front()))
return true;
return false;
}))
return;
auto w_slice_op = slice_op;
w_slice_op.axes = {0};
auto slice_w_ins = m.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0)
sliced_weights.push_back(m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins));
sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins));
auto new_weights =
m.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
assert(conv_ins->get_shape() == new_conv->get_shape());
auto slice1 = m.insert_instruction(ins, slice_op, new_conv);
assert(ins->get_shape().lens() == slice1->get_shape().lens());
m.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
auto outputs = conv_ins->outputs();
for(auto output : outputs)
if(output != slice_ins)
instruction::replace_argument(output, conv_ins, new_conv);
}
};
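// dot(a, b) * c => dot(a * c, b) or dot(a, b * c)
// when c is a broadcasted constant that varies along at most one axis and the
// chosen dot input can be const-folded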
struct find_mul_dot
{
auto matcher() const
{
auto constant = match::is_constant(not_from_int4());
auto is_dot_const_inputs = match::name("dot")(match::any_of[match::inputs()](constant));
return match::name("mul")(match::either_arg(0, 1)(
is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = dot_ins->inputs()[0];
auto b_ins = dot_ins->inputs()[1];
auto c_ins = r.instructions["c"];
const auto& c_strides = c_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
auto add_mul_const = [&](instruction_ref x_ins) {
if(not x_ins->can_eval())
return m.end();
auto broadcast_v = c_ins->get_operator().to_value();
broadcast_v["out_lens"] = x_ins->get_shape().lens();
auto cb_ins =
m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
};
if(c_strides.back() == 1)
{
b_ins = add_mul_const(b_ins);
}
else if(c_strides[c_strides.size() - 2] == 1)
{
a_ins = add_mul_const(a_ins);
}
else if(c_ins->get_shape().scalar())
{
if(a_ins->can_eval())
a_ins = add_mul_const(a_ins);
else
b_ins = add_mul_const(b_ins);
}
else
{
return;
}
if(contains({a_ins, b_ins}, m.end()))
return;
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
/*
Moves the slice on the output of the Dot operation to slices on the inputs of the Dot operation to
avoid computing redundant values.
e.g. slice(gemm(a, b)) --> gemm(slice(a), slice(b))
*/
struct find_dot_slice
{
auto matcher() const
{
return match::name("slice")(
match::args(match::name("dot", "quant_dot")(match::used_once()).bind("dot_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto slice_ins = r.result;
auto dot_ins = r.instructions["dot_ins"];
auto slice_op = slice_ins->normalized_operator().to_value();
auto axes = slice_op["axes"].to_vector<int64_t>();
auto starts = slice_op["starts"].to_vector<int64_t>();
auto ends = slice_op["ends"].to_vector<int64_t>();
assert(starts.size() == ends.size() and starts.size() == axes.size());
auto has_neg_vals = [](auto vec) {
return std::any_of(vec.begin(), vec.end(), [](auto i) { return i < 0; });
};
if(has_neg_vals(starts) or has_neg_vals(ends) or has_neg_vals(axes))
{
MIGRAPHX_THROW("FIND_DOT_SLICE: slice is not normalized.");
}
auto dot_inputs = dot_ins->inputs();
auto num_batch_dims = dot_ins->get_shape().lens().size() - 2;
std::vector<int64_t> slice_axes_1, starts_1, ends_1; // NOLINT
std::vector<int64_t> slice_axes_2, starts_2, ends_2; // NOLINT
for(auto i : range(axes.size()))
{
if(axes[i] < num_batch_dims)
{
slice_axes_1.push_back(axes[i]);
starts_1.push_back(starts[i]);
ends_1.push_back(ends[i]);
slice_axes_2.push_back(axes[i]);
starts_2.push_back(starts[i]);
ends_2.push_back(ends[i]);
}
else if(axes[i] == num_batch_dims)
{
slice_axes_1.push_back(axes[i]);
starts_1.push_back(starts[i]);
ends_1.push_back(ends[i]);
}
else if(axes[i] == num_batch_dims + 1)
{
slice_axes_2.push_back(axes[i]);
starts_2.push_back(starts[i]);
ends_2.push_back(ends[i]);
}
else
{
MIGRAPHX_THROW("FIND_DOT_SLICE: invalid case");
}
}
auto slice_1 = dot_inputs.at(0);
if(not slice_axes_1.empty())
{
slice_1 = m.insert_instruction(
slice_ins,
migraphx::make_op("slice",
{{"axes", slice_axes_1}, {"starts", starts_1}, {"ends", ends_1}}),
dot_inputs.at(0));
}
auto slice_2 = dot_inputs.at(1);
if(not slice_axes_2.empty())
{
slice_2 = m.insert_instruction(
slice_ins,
migraphx::make_op("slice",
{{"axes", slice_axes_2}, {"starts", starts_2}, {"ends", ends_2}}),
dot_inputs.at(1));
}
m.replace_instruction(slice_ins, dot_ins->get_operator(), {slice_1, slice_2});
}
};
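// dot(z * d, c) => dot(z, d' * c) where c and d are constants, so the
// multiply const-folds into c. d must broadcast along a single axis and is
// re-broadcast and transposed to line up with the matching axis of c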
struct find_dot_mul
{
auto matcher() const
{
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(
match::used_once(),
match::either_arg(0, 1)(const_broadcast.bind("d"),
match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(
match::either_arg(0, 1)(mul, match::is_constant(not_from_int4()).bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
auto d_ins = r.instructions["d"];
auto c_ins = r.instructions["c"];
auto z_ins = r.instructions["z"];
const auto& d_strides = d_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
if(not d_ins->get_shape().scalar())
{
if(d_strides.back() == 1 and not b_ins->can_eval())
return;
if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
return;
}
auto broadcast_v = d_ins->get_operator().to_value();
auto c_lens = c_ins->get_shape().lens();
std::vector<int64_t> permutation(c_lens.size());
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation.back(), permutation[permutation.size() - 2]);
c_lens = reorder_dims(c_lens, permutation);
broadcast_v["out_lens"] = c_lens;
auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto db_transpose_ins =
m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
if(c_ins == b_ins)
{
a_ins = z_ins;
b_ins = cd_ins;
}
else
{
a_ins = cd_ins;
b_ins = z_ins;
}
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of a constant, the
// additional add can be const-folded. Also, better fusions can be applied
// when the add comes after the multiply.
struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
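// dot(x + b, a) => dot(x, a) + dot(b, a) where a and b are constants
// (and symmetrically when the add is the second argument), so the
// constant dot can be const-folded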
struct find_dot_add
{
auto matcher() const
{
return match::name("dot")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
const bool flipped = a_ins == ins->inputs().back();
auto insert_dot = [&](auto x, auto y) {
if(flipped)
return m.insert_instruction(ins, make_op("dot"), y, x);
else
return m.insert_instruction(ins, make_op("dot"), x, y);
};
auto ax_ins = insert_dot(a_ins, x_ins);
auto ab_ins = insert_dot(a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
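// conv(x + a, w) => conv(x, w) + conv(a, w) where a and w are constants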
struct find_conv_add
{
auto matcher() const
{
auto add = match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("a")),
match::used_once());
return match::name("convolution")(match::used_once(),
match::args(add, match::is_constant().bind("w")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto x_ins = r.instructions["x"];
auto w_ins = r.instructions["w"];
auto conv1 = m.insert_instruction(ins, ins->get_operator(), a_ins, w_ins);
auto conv2 = m.insert_instruction(ins, ins->get_operator(), x_ins, w_ins);
m.replace_instruction(ins, make_op("add"), conv1, conv2);
}
};
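// (x + a) + b => x + (a + b) where a and b are literals or broadcasts,
// so the constant adds can fold together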
struct find_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), x_ins, sumab);
}
};
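// (x + a) + (y + b) => (x + y) + (a + b) where a and b are literals or
// broadcasts; when both are broadcasts of same-shaped inputs, the inner add
// is done before the broadcast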
struct find_double_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
instruction_ref sumab;
if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum = m.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = m.insert_instruction(ins, op, presum);
}
else
{
sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
}
auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins);
m.replace_instruction(ins, make_op("add"), sumxy, sumab);
}
};
/// Find elementwise operators whose inputs are all broadcasts. Rewrite the
/// elementwise op to do the computation on the non-broadcasted axes, and
/// then broadcast the result.
struct find_inner_broadcast
{
auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
static auto get_non_broadcast_input(instruction_ref ins)
{
if(ins->inputs().size() != 1)
return ins;
auto input = ins->inputs().front();
if(contains(input->name(), "broadcast"))
return get_non_broadcast_input(input);
return input;
}
static bool is_unsqueeze_needed_for_multibroadcast(const shape& input, const shape& output)
{
if(input.elements() == 1)
return false;
auto shift = output.ndim() - input.ndim();
if(shift == 0)
return false;
if(std::equal(input.lens().begin(),
input.lens().end(),
output.lens().begin() + shift,
output.lens().end()))
{
return std::all_of(output.lens().begin(), output.lens().begin() + shift, [](auto x) {
return x == 1;
});
}
return true;
}
// Simple case
void apply_same_broadcasts(module& m, instruction_ref ins) const
{
const auto& broadcasts = ins->inputs();
// Scalars can have different ndim, so find the largest ndim input
auto max_broadcast = *std::max_element(
broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref broadcast) {
return get_non_broadcast_input(broadcast)->get_shape().ndim();
}));
auto max_ndim = max_broadcast->get_shape().ndim();
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(inputs),
[&](instruction_ref broadcast) {
auto input = get_non_broadcast_input(broadcast);
auto s = input->get_shape();
// If the scalar's ndim doesn't match the other inputs then add a squeeze
if(s.elements() == 1 and s.ndim() > 1 and s.ndim() != max_ndim)
return m.insert_instruction(broadcast, make_op("squeeze"), input);
return input;
});
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
// Find broadcast op on a non-scalar instruction if it exists
auto first =
std::find_if(broadcasts.begin(), broadcasts.end(), [&](instruction_ref broadcast) {
return not broadcast->get_shape().scalar();
});
if(first != broadcasts.end())
{
m.replace_instruction(ins, (*first)->get_operator(), op);
}
else
{
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
}
void apply_diff_broadcasts(module& m, instruction_ref ins) const
{
const auto& broadcasts = ins->inputs();
auto ndim = ins->get_shape().ndim();
// Compute the inner dimensions and axes that the computation will
// use. Also compute the axes that will be broadcasted
std::vector<std::size_t> idims;
std::vector<std::int64_t> iaxes;
std::vector<std::int64_t> axes;
for(auto axis : range(ndim))
{
if(std::all_of(broadcasts.begin(), broadcasts.end(), [&](instruction_ref i) {
auto s = i->get_shape();
return s.lens()[axis] == 1 or s.strides()[axis] == 0;
}))
{
axes.push_back(axis);
}
else
{
iaxes.push_back(axis);
idims.push_back(ins->get_shape().lens()[axis]);
}
}
// If the inner axes are the same as the original operator then
// there is no reason to do this transformation.
if(iaxes.size() == ndim)
return;
std::vector<instruction_ref> inputs;
std::transform(
broadcasts.begin(),
broadcasts.end(),
std::back_inserter(inputs),
[&](instruction_ref broadcast) {
auto input = broadcast->inputs().front();
auto s = input->get_shape();
// If it's a single element then just return it as the input
if(s.elements() == 1)
{
if(s.lens().size() > 1)
return m.insert_instruction(broadcast, make_op("squeeze"), input);
return input;
}
// Find how the axes are shifted from the broadcast
std::int64_t shift = ndim - s.ndim();
if(broadcast->name() == "broadcast")
shift = broadcast->get_operator().to_value()["axis"].to<int64_t>();
// Compute the squeeze axes to be used by taking the inner
// axes and shifting to what the axes will be on the
// input
std::vector<std::int64_t> sq_axes;
for(auto axis : axes)
{
auto iaxis = axis - shift;
if(iaxis < 0)
continue;
if(iaxis >= s.ndim())
continue;
sq_axes.push_back(iaxis);
}
instruction_ref result = input;
if(not sq_axes.empty())
result = m.insert_instruction(
broadcast, make_op("squeeze", {{"axes", sq_axes}}), result);
// If the number of dimensions is still smaller than the
// number of inner axes, then we need to insert a
// broadcast so all inputs have the same dimensions.
if(result->get_shape().ndim() < iaxes.size())
{
// We find the first inner axis that can be mapped to the input
auto start_axis = std::find_if(iaxes.begin(),
iaxes.end(),
[&](auto x) { return x >= shift; }) -
iaxes.begin();
result = m.insert_instruction(
broadcast,
make_op("broadcast", {{"axis", start_axis}, {"out_lens", idims}}),
result);
}
return result;
});
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
if(iaxes.size() == 1)
{
m.replace_instruction(
ins,
make_op("broadcast",
{{"axis", iaxes.front()}, {"out_lens", ins->get_shape().lens()}}),
op);
}
else
{
auto unsqueeze =
is_unsqueeze_needed_for_multibroadcast(op->get_shape(), ins->get_shape())
? m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), op)
: op;
m.replace_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze);
}
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
if(ins->get_operator().name() == "layout")
return;
const auto& broadcasts = ins->inputs();
if(broadcasts.empty())
return;
// Skip if different data types are used
if(any_of(broadcasts, [&](auto i) {
return i->get_shape().type() != broadcasts.front()->get_shape().type();
}))
return;
// All broadcast inputs should have fewer elements than the output
if(not all_of(broadcasts, [&](instruction_ref broadcast) {
auto input = broadcast->inputs().front();
return input->get_shape().elements() < ins->get_shape().elements();
}))
return;
// Find first broadcast that is not a scalar
auto first =
std::find_if(broadcasts.begin(), broadcasts.end(), [&](instruction_ref broadcast) {
return not broadcast->get_shape().scalar();
});
// Try the simple case first: apply the op to the inputs of the
// broadcasts and then put that same broadcast after the op. For this
// we need each of the broadcasts to be the same and the inputs to
// have the same dimensions (or be scalars).
const bool same_broadcasts =
std::all_of(first, broadcasts.end(), [&](instruction_ref broadcast) {
if(broadcast->get_operator() != (*first)->get_operator())
return false;
auto s1 = get_non_broadcast_input(broadcast)->get_shape();
auto s2 = get_non_broadcast_input(*first)->get_shape();
if(s1.elements() == 1)
return true;
return s1.lens() == s2.lens();
});
if(same_broadcasts)
{
apply_same_broadcasts(m, ins);
}
// Skip if any input to the broadcasted inputs is already broadcasted,
// as the algorithm below may not be able to handle such a case.
else if(std::none_of(broadcasts.begin(), broadcasts.end(), [](instruction_ref broadcast) {
return broadcast->inputs().front()->get_shape().broadcasted();
}))
{
apply_diff_broadcasts(m, ins);
}
}
};
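// dot(broadcast(a), broadcast(b)) => multibroadcast(dot(a', b'))
// When the leading batch axes of both dot inputs are broadcasted, compute the
// dot on the smaller shapes and broadcast the result instead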
struct find_dot_broadcast
{
auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted
auto p =
std::mismatch(a_strides.begin(),
a_strides.begin() + nbatch_axes,
b_strides.begin(),
b_strides.begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a_strides.begin();
assert(naxes <= nbatch_axes);
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
auto insert_broadcast = [&](instruction_ref x_ins) -> instruction_ref {
auto input = x_ins->inputs()[0];
std::vector<std::size_t> lens(x_ins->get_shape().lens().begin() + naxes,
x_ins->get_shape().lens().end());
if(input->get_shape().lens() == lens)
return input;
auto input_naxis = input->get_shape().lens().size();
auto new_bc_naxis = lens.size();
if(input_naxis > new_bc_naxis)
{
std::vector<std::size_t> axes_to_sq(input_naxis - new_bc_naxis);
std::iota(axes_to_sq.begin(), axes_to_sq.end(), 0);
input =
m.insert_instruction(ins, make_op("squeeze", {{"axes", axes_to_sq}}), input);
}
if(x_ins->name() == "multibroadcast")
{
return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if(x_ins->name() == "broadcast")
{
auto v = x_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
}
assert(false);
return m.end();
};
auto a1 = insert_broadcast(a);
auto b1 = insert_broadcast(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
m.replace_instruction(ins, broadcast);
}
};
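// concat(op(x), op(y), ...) => op(concat(x, y, ...))
// Runs of adjacent concat inputs computed by the same pointwise, broadcast,
// or unpack_int4 op are replaced by one op applied per argument position to
// concatenated arguments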
struct find_concat_op
{
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(),
match::name("broadcast", "multibroadcast", "unpack_int4")),
match::used_once()));
}
template <class Iterator>
static std::vector<std::size_t> get_output_lens(Iterator start, Iterator last, std::size_t axis)
{
assert(start != last);
std::size_t dim = 0;
for(auto ins : range(start, last))
{
dim += ins->get_shape().lens().at(axis);
}
auto lens = (*start)->get_shape().lens();
lens[axis] = dim;
return lens;
}
static bool is_valid_op(const operation& op)
{
return contains({"broadcast", "multibroadcast", "unpack_int4"}, op.name()) or
op.attributes().contains("pointwise");
}
static bool is_valid_concat(std::vector<instruction_ref> ins, size_t axis)
{
auto concat_lens = ins.front()->get_shape().lens();
concat_lens.erase(concat_lens.begin() + axis);
return std::all_of(ins.begin(), ins.end(), [&](auto i) {
auto lens = i->get_shape().lens();
lens.erase(lens.begin() + axis);
return lens == concat_lens;
});
}
static bool rejected_inputs(const std::vector<instruction_ref>& inputs)
{
if(inputs.empty())
return true;
if(inputs.size() < 3)
return false;
auto nonconst = std::count_if(
inputs.begin(), inputs.end(), [](instruction_ref ins) { return not ins->can_eval(); });
return nonconst > 2;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
auto each = [&](auto start, auto last) -> std::vector<instruction_ref> {
if(std::distance(start, last) < 2)
return {start, last};
auto x = *start;
if(x->outputs().size() > 1 or rejected_inputs(x->inputs()))
return {start, last};
auto op = x->get_operator();
if(not is_valid_op(op))
return {start, last};
auto iaxis = axis;
// Adjust broadcast lens
if(op.name() == "broadcast")
{
auto b = any_cast<op::broadcast>(op);
if(b.axis != iaxis)
return {start, last};
b.broadcast_lens = get_output_lens(start, last, iaxis);
op = b;
iaxis = 0;
}
else if(op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
iaxis -= delta;
}
std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++)
{
std::vector<instruction_ref> inputs;
std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
return j->inputs().at(i);
});
if(not is_valid_concat(inputs, iaxis))
return {start, last};
auto concat =
m.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat);
}
auto y = m.insert_instruction(ins, op, concats);
return {y};
};
std::vector<instruction_ref> args;
auto update_args = [&](auto start, auto last) {
auto x = each(start, last);
args.insert(args.end(), x.begin(), x.end());
};
auto pred = [](auto i, auto j) {
return i->get_operator() == j->get_operator() and
i->inputs().size() == j->inputs().size() and
i->outputs().size() == j->outputs().size();
};
group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred);
if(args.size() == 1)
m.replace_instruction(ins, args.front());
else
m.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
}
};
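// concat(conv(x1, w1), conv(x2, w2), ...) on the channel axis is rewritten as
// a single grouped convolution over concatenated inputs and weights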
struct find_concat_conv
{
auto matcher() const
{
return match::name("concat")(
match::all_of[match::inputs()](match::used_once(), match::name("convolution")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto axis = ins->get_operator().to_value()["axis"].to<int>();
if(axis != 1)
return;
if(ins->inputs().empty())
return;
auto conv = ins->inputs().front()->get_operator();
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto conv_ins) {
return conv_ins->get_operator() != conv;
}))
return;
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[](auto conv_ins) { return conv_ins->inputs()[0]; });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input) {
return input->get_shape() != inputs.front()->get_shape();
}))
return;
std::vector<instruction_ref> weights;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(weights),
[](auto conv_ins) { return conv_ins->inputs()[1]; });
if(std::any_of(weights.begin(), weights.end(), [&](auto w) {
return w->get_shape() != weights.front()->get_shape();
}))
return;
auto original_group = from_value<int>(conv.to_value()["group"]);
auto x = m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), inputs);
auto w = m.insert_instruction(ins, make_op("concat", {{"axis", 0}}), weights);
conv.from_value({{"group", original_group * inputs.size()}});
m.replace_instruction(ins, conv, x, w);
}
};
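// Move the given instructions (and, recursively, their inputs) so they
// appear before pos in the module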
void move_instructions_back(module& m, instruction_ref pos, std::vector<instruction_ref> inss)
{
auto start = range(m.begin(), pos);
for(auto ins : iterator_for(start))
{
auto it = std::find(inss.begin(), inss.end(), ins);
if(it != inss.end())
inss.erase(it);
}
for(auto ins : inss)
{
if(not m.has_instruction(ins))
continue;
move_instructions_back(m, pos, ins->inputs());
m.move_instruction(ins, pos);
}
}
/** Search for multiple "slice" instructions in an instruction's outputs
* which are contiguous slices of the same tensor.
*/
std::vector<instruction_ref> get_splits(instruction_ref ins)
{
std::vector<instruction_ref> result;
std::copy_if(ins->outputs().begin(),
ins->outputs().end(),
std::back_inserter(result),
[&](auto i) { return i->name() == "slice"; });
if(result.size() < 2)
return {};
auto get_slice = [](auto& i) -> auto& { return any_cast<op::slice>(i->get_operator()); };
auto&& axes = get_slice(result.front()).axes;
// "slice" instructions must all have the same axes
if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; }))
return {};
auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; };
auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; };
// Sort the "slice" instructions in order of starts
std::sort(
result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); });
if(std::any_of(get_start(result.front()).begin(), get_start(result.front()).end(), [&](auto i) {
return i != 0;
}))
return {};
// each slice must "start" where the previous slice "ends"
auto it = std::adjacent_find(
result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); });
if(it != result.end())
return {};
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
if(ins->get_shape().lens()[axis] != get_slice(result.back()).ends[i])
return {};
}
return result;
}
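// When a tensor is split into contiguous slices and each slice feeds the same
// pointwise or reduce op, apply the op once to the unsplit tensor (with any
// constant arguments concatenated) and slice the result instead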
struct find_splits
{
auto matcher() const
{
auto pointwise_reduction = match::any_of[match::outputs()](
match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction());
return match::any(
match::any_of[match::outputs()](match::name("slice")(pointwise_reduction)));
}
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
{
std::unordered_set<instruction_ref> traversed;
return fix<bool>([&](auto self, auto ins) -> bool {
if(ins == ins2)
return true;
if(contains(traversed, ins))
return false;
traversed.insert(ins);
const auto& inputs = ins->inputs();
return std::any_of(inputs.begin(), inputs.end(), [&](auto in) {
return m.has_instruction(in) and self(in);
});
})(ins1);
}
static std::vector<std::vector<instruction_ref>>
get_split_groups(const module& m, const std::vector<instruction_ref>& splits)
{
std::vector<std::vector<instruction_ref>> groups;
for(auto out : splits.front()->outputs())
{
if(out->name() == "slice")
continue;
std::vector<instruction_ref> group;
for(auto split : splits)
{
auto it =
std::find_if(split->outputs().begin(), split->outputs().end(), [&](auto i) {
return i->get_operator() == out->get_operator();
});
if(it == split->outputs().end())
break;
assert((*it)->name() != "slice");
// If there is a duplicate, bail;
// there should be no dependencies between instructions in the group
if(std::any_of(group.begin(), group.end(), [&](auto i) {
return is_dependent(m, *it, i) or is_dependent(m, i, *it);
}))
{
return {};
}
group.push_back(*it);
}
if(group.size() != splits.size())
continue;
groups.push_back(group);
}
return groups;
}
bool is_fusable(instruction_ref start, instruction_ref split_front) const
{
auto op = start->get_operator();
if(contains(op.name(), "reduce"))
{
auto slc = any_cast<op::slice>(split_front->get_operator());
auto slc_axes = slc.axes;
auto reduce_axes = start->get_operator().to_value()["axes"].to_vector<int64_t>();
// the axes of the slice and the reduce op must not overlap
if(std::any_of(slc_axes.begin(), slc_axes.end(), [&](auto axis) {
return (std::find(reduce_axes.begin(), reduce_axes.end(), axis) !=
reduce_axes.end());
}))
{
return false;
}
}
else if(not op.attributes().contains("pointwise"))
{
return false;
}
return true;
}
int get_binary_op_split_idx(std::vector<instruction_ref> group,
std::vector<instruction_ref> splits) const
{
auto first_group_inputs = group.front()->inputs();
auto arg_it =
std::find_if(first_group_inputs.begin(), first_group_inputs.end(), [&](auto i) {
return std::find(splits.begin(), splits.end(), i) != splits.end();
});
auto split_idx = arg_it - first_group_inputs.begin();
// All splits are at the same input index
if(std::all_of(group.begin() + 1, group.end(), [&](auto i) {
auto split_idx_input = i->inputs().at(split_idx);
return std::find(splits.begin(), splits.end(), split_idx_input) != splits.end();
}))
return split_idx;
return -1;
}
void align_commutative_op_args(module& m,
std::vector<instruction_ref> group,
std::vector<instruction_ref> splits,
size_t split_idx) const
{
auto group_op = group.front()->get_operator();
assert(std::all_of(
group.begin(), group.end(), [&](auto i) { return i->get_operator() == group_op; }));
for(auto i : group)
{
if(std::find(splits.begin(), splits.end(), i->inputs().at(split_idx)) == splits.end())
{
auto args = i->inputs();
assert(args.size() == 2);
std::reverse(args.begin(), args.end());
m.replace_instruction(i, i->get_operator(), args);
}
}
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
return;
for(const auto& group : get_split_groups(m, splits))
{
auto start = group.front();
auto split_front = splits.front();
auto op = start->get_operator();
if(not is_fusable(start, split_front))
{
continue;
}
// Make sure there are no duplicates
assert(std::none_of(
std::next(group.begin()), group.end(), [&](auto i) { return i == start; }));
auto split_idx = 0;
instruction_ref c = m.end();
if(start->inputs().size() == 1)
{
c = m.insert_instruction(std::next(ins), op, ins);
}
else if(start->inputs().size() == 2)
{
assert(std::any_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
return i->name() == "slice";
}) and "one argument must be a split");
split_idx = get_binary_op_split_idx(group, splits);
assert(split_idx < 2);
size_t data_idx;
if(split_idx < 0 and op.attributes().contains("commutative"))
{
split_idx = 0;
data_idx = 1;
align_commutative_op_args(m, group, splits, split_idx);
}
else if(split_idx < 0)
{
return;
}
else
{
data_idx = split_idx == 0 ? 1 : 0;
}
std::vector<instruction_ref> data_args;
std::transform(group.begin(),
group.end(),
std::back_inserter(data_args),
[&](auto i) { return i->inputs()[data_idx]; });
// Data arguments must be constants
if(std::any_of(data_args.begin(), data_args.end(), [](auto i) {
return not i->can_eval();
}))
return;
move_instructions_back(m, ins, data_args);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
if(slice_op.axes.size() > 1)
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axes match
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
args[split_idx] = ins;
args[data_idx] = concat;
c = m.insert_instruction(std::next(ins), op, args);
}
if(c != m.end())
{
for(auto i : group)
{
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
m.replace_instruction(i, split->get_operator(), c);
}
}
}
}
};
/**
* Matcher for a sequence of "slice" operations whose outputs are put back
* together by a "concat".
*/
struct find_split_concat
{
auto matcher() const
{
auto concat = match::all_of[match::outputs()](match::name("concat"));
return match::any(match::any_of[match::outputs()](match::name("slice")(concat)));
}
void apply(module& m, const match::matcher_result& r) const
{
// Verifies that the slices meet several conditions: they must all output to the same
// concat instruction, slice on the same (1 only) axis, and the end of one slice
// must match the start of the next.
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
return;
// Each slice must output to only one instruction
if(std::any_of(
splits.begin(), splits.end(), [](auto i) { return i->outputs().size() != 1; }))
return;
// The single output instruction for all items in the list must be the same one
auto concat = splits.front()->outputs().front();
if(std::any_of(splits.begin(), splits.end(), [&](auto i) {
return i->outputs().front() != concat;
}))
return;
// The axis for the common output instruction must be the same as for the split ops
auto concat_op = any_cast<op::concat>(concat->get_operator());
auto split_op = any_cast<op::slice>(splits.front()->get_operator());
if(split_op.axes.size() != 1)
return;
if(split_op.axes.front() != concat_op.axis)
return;
// Find where the slices are in the concat instruction's inputs (concat can have
// any number of inputs)
auto args = concat->inputs();
auto it =
std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); });
// Verify the slices were found, and the list is long enough
if(std::distance(it, args.end()) < splits.size())
return;
// Don't do anything if the "slice" inputs to the concat op have other operations mixed in
// among them
if(std::any_of(it, it + splits.size(), [](instruction_ref x) {
return x->get_operator().name() != "slice";
}))
return;
// Check that the slices passed to concat are in order.
if(not std::is_sorted(it, it + splits.size(), [](instruction_ref x, instruction_ref y) {
auto xop = any_cast<op::slice>(x->get_operator());
auto yop = any_cast<op::slice>(y->get_operator());
return std::tie(xop.starts, xop.ends) < std::tie(yop.starts, yop.ends);
}))
return;
// Perform the substitution
*it = splits.front()->inputs().front();
args.erase(std::next(it), it + splits.size());
if(args.size() == 1)
m.replace_instruction(concat, args.front());
else
m.replace_instruction(concat, concat->get_operator(), args);
}
};
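// Returns true when x and y have the same rank and match on every dimension
// except axis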
bool axis_equal(const std::vector<std::size_t>& x,
const std::vector<std::size_t>& y,
std::size_t axis)
{
return x.size() == y.size() and x.size() > axis and
std::equal(x.begin(), x.begin() + axis, y.begin()) and
std::equal(x.begin() + axis + 1, x.end(), y.begin() + axis + 1);
}
bool axis_shape_equal(const shape& x, const shape& y, std::size_t axis)
{
// TODO: Check strides
return axis_equal(x.lens(), y.lens(), axis);
}
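// conv(x, w1) + conv(y, w2) => conv(concat(x, y), concat(w1, w2)) along the
// channel axis when the weights match except on that axis; differing
// symmetric strides over 1x1 kernels are reconciled by inserting a "step"
// op on the larger-stride input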
struct find_add_convs
{
auto matcher() const
{
return match::name("add")(
match::args(conv_const_weights().bind("a"), conv_const_weights().bind("b")));
}
static bool symmetrical_strides(const op::convolution& op)
{
return op.stride[0] == op.stride[1];
}
static std::size_t compute_stride_factor(const op::convolution& x, const op::convolution& y)
{
if(not symmetrical_strides(x))
return 0;
if(not symmetrical_strides(y))
return 0;
if((x.stride[0] % y.stride[0]) != 0)
return 0;
return x.stride[0] / y.stride[0];
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_conv = r.instructions["a"];
auto a_input = a_conv->inputs().at(0);
auto a_weights = a_conv->inputs().at(1);
auto b_conv = r.instructions["b"];
auto b_input = b_conv->inputs().at(0);
auto b_weights = b_conv->inputs().at(1);
if(not axis_shape_equal(a_weights->get_shape(), b_weights->get_shape(), 1))
return;
auto a_op = any_cast<op::convolution>(a_conv->get_operator());
auto b_op = any_cast<op::convolution>(b_conv->get_operator());
auto new_op = a_op;
if(a_op != b_op)
{
if(std::tie(a_op.padding, a_op.dilation, a_op.group) ==
std::tie(b_op.padding, b_op.dilation, b_op.group) and
a_weights->get_shape().lens()[2] == 1 and a_weights->get_shape().lens()[3] == 1)
{
if(a_op.stride < b_op.stride)
{
auto n = compute_stride_factor(b_op, a_op);
if(n == 0)
return;
new_op = a_op;
b_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input);
}
else if(b_op.stride < a_op.stride)
{
auto n = compute_stride_factor(a_op, b_op);
if(n == 0)
return;
new_op = b_op;
a_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input);
}
else
return;
}
else
return;
}
auto concat_input =
m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights =
m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
m.replace_instruction(ins, new_op, concat_input, concat_weights);
}
};
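// Matches an instruction feeding at least two dot, quant_dot, or convolution
// ops that use it as their first argument and have constant second arguments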
MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
{
auto pred = [&](auto name) {
return [=](auto i) {
return i->name() == name and i->inputs().front() == ins and
i->inputs().at(1)->can_eval();
};
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2 or qdots >= 2);
}
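// Horizontally fuses dot/convolution ops that share the same input by
// concatenating their constant weights, running one op, and slicing the
// result back apart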
struct find_conv_dot_horiz_fusion
{
auto matcher() const { return horiz_conv_dot(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator())
return false;
if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true;
auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens();
if(x.size() != y.size())
return false;
// Check that all dimensions except the fused axis match
int axis = 1;
if(i->name() == "dot" or i->name() == "quant_dot")
{
axis = x.size() - 1;
}
return axis_equal(x, y, axis);
};
auto each = [&](auto start, auto last) {
if(std::distance(start, last) < 2)
return;
auto&& name = (*start)->name();
if(not contains({"quant_dot", "dot", "convolution"}, name))
return;
auto op = (*start)->get_operator();
int group = 1;
if(name == "convolution")
group = any_cast<op::convolution>(op).group;
// Skip group convolution
if(group != 1)
return;
auto input = (*start)->inputs().front();
std::vector<instruction_ref> args;
std::transform(
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1;
int concat_axis = 0;
if(name == "dot" or name == "quant_dot")
{
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis;
}
move_instructions_back(m, input, args);
// TODO: Check if axes match
auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
offset += len;
}
};
auto outputs = ins->outputs();
group_by(outputs.begin(), outputs.end(), each, pred);
}
};
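// x / c => x * (1 / c) for non-integral types, replacing a division by a
// constant with a multiply by its reciprocal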
struct find_div_const
{
auto matcher() const
{
return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
if(shape::is_integral(ins->get_shape().type()))
return;
auto recip = m.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs();
m.replace_instruction(ins, make_op("mul"), args.front(), recip);
}
};
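// Removes identity operations: x * 1, x / 1, x + 0, and x - 0 => x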
struct find_unit_ops
{
auto matcher() const
{
auto mul_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(1.0f), match::any().bind("x")));
auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f, 0, 0), match::any().bind("x")));
auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f, 0, 0)));
return match::any_of(mul_1, div_1, add_0, sub_0);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_in = r.instructions["x"];
m.replace_instruction(ins, c_in);
}
};
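// x * -1, x / -1, and 0 - x => neg(x)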
struct find_neg_unit_ops
{
auto matcher() const
{
auto mul_neg_1 = match::name("mul")(
match::either_arg(0, 1)(match::has_value(-1.0f), match::any().bind("x")));
auto div_neg_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f)));
auto sub_0 =
match::name("sub")(match::args(match::has_value(0.0f, 0, 0), match::any().bind("x")));
return match::any_of(mul_neg_1, div_neg_1, sub_0);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_in = r.instructions["x"];
auto neg = m.insert_instruction(ins, make_op("neg"), c_in);
m.replace_instruction(ins, neg);
}
};
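// Drops the zero-point argument of quantizelinear/dequantizelinear when it
// is zero, since it has no effect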
struct eliminate_zero_point
{
auto get_qlinear_ops_names() const
{
static std::unordered_set<std::string> qdq_names = {"quantizelinear", "dequantizelinear"};
return qdq_names;
}
auto matcher() const
{
return match::name(get_qlinear_ops_names())(match::arg(0)(match::any().bind("x")),
match::arg(1)(match::any().bind("scale")),
match::arg(2)(match::has_value(0.0f, 0, 0)));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
auto scale = r.instructions["scale"];
auto op = ins->get_operator().to_value();
if(ins->get_operator().name() == "quantizelinear")
{
op["out_type"] = to_value(ins->get_shape().type());
}
auto qdq_ins = m.insert_instruction(ins, migraphx::make_op(ins->name(), op), {x, scale});
m.replace_instruction(ins, qdq_ins);
}
};
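// 0 * x => 0 and 0 / x => 0 (the zero operand replaces the result)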
struct find_zero_ops
{
auto matcher() const
{
auto mul_zero = match::name("mul")(
match::either_arg(0, 1)(match::has_value(0.0f, 0, 0).bind("x"), match::any()));
auto div_zero =
match::name("div")(match::args(match::has_value(0.0f, 0, 0).bind("x"), match::any()));
return match::any_of(mul_zero, div_zero);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto zero_ins = r.instructions["x"];
m.replace_instruction(ins, zero_ins);
}
};
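// x - c => x + neg(c) where c is constant, so the subtraction becomes an add
// that other simplifications can reassociate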
struct find_sub_const
{
auto matcher() const
{
return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto neg = m.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs();
m.replace_instruction(ins, make_op("add"), args.front(), neg);
}
};
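// recip(sqrt(x)) => rsqrt(x)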
struct find_rsqrt
{
auto matcher() const
{
auto bind_x = match::args(match::any().bind("x"));
return match::name("recip")(match::args(match::name("sqrt")(match::used_once(), bind_x)));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
m.replace_instruction(ins, make_op("rsqrt"), x_ins);
}
};
static bool same_ops(const std::vector<instruction_ref>& vec_ins)
{
return std::all_of(vec_ins.begin(), vec_ins.end(), [&](auto i) {
return i->get_operator() == vec_ins.front()->get_operator();
});
}
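// When every slice of a split is followed by contiguous + the same reshape,
// reshape the unsplit input once and replace each reshape with a slice of it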
struct find_split_reshape
{
auto matcher() const
{
auto slice_bind_slice = match::arg(0)(match::name("slice").bind("slice"));
return match::name("reshape")(match::arg(0)(match::name("contiguous")(slice_bind_slice)))
.bind("reshape");
}
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();
// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}
auto split_outputs = get_splits(input);
if(split_outputs.empty())
{
return;
}
// Find all the reshapes (similar to rsp) that can be simplified
std::vector<instruction_ref> conts;
std::vector<instruction_ref> vec_rsp;
// Iterate through slice and contiguous outputs to allow simplifications when
// slice is followed by multiple reshapes
for(auto& i : split_outputs)
{
std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
}
for(auto& i : conts)
{
std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
}
// No simplification needed if there is only one slice -> cont -> reshape
if(vec_rsp.size() <= 1)
{
return;
}
// ensure the reshape only regroups dimensions at or after the slice axis
auto axis = axes[0];
auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];
// search the reshape output (standard shape) to decide which axis in
// its output corresponds to slc_dim_size
auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
int rsp_axis = -1;
if(ait == rsp_strides.end())
{
return;
}
else if(ait == rsp_strides.end() - 1)
{
// Edge case: when slc_dim_size == 1 it can match the trailing stride of 1.
// Step back one axis to avoid going out of bounds.
assert(slc_dim_size == 1);
rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
}
else
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// Calculate reshape output shape
// Need to find a reshape such that data represented by instructions in vec_rsp can be
// written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());
// cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});
std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace each original reshape with a slice of the new reshape
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
m.replace_instruction(
vec_rsp[i],
make_op(
"slice",
{{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins);
}
}
};
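// When every slice of a split feeds the same transpose, transpose the
// unsplit input once and replace each transpose with a slice on the
// permuted axis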
struct find_split_transpose
{
auto matcher() const
{
return match::name("transpose")(match::arg(0)(match::name("slice").bind("slice")))
.bind("trans");
}
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"];
auto input = slc->inputs().front();
auto split_outputs = get_splits(input);
if(split_outputs.empty())
{
return;
}
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
return i->outputs().front();
});
// all transposes must be the same
auto perm = any_cast<op::transpose>(trans->get_operator()).dims;
if(not same_ops(vec_trans))
{
return;
}
// insert a transpose instruction after the input
auto tr = m.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the new slice axis under the permutation
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
auto it = std::find(perm.begin(), perm.end(), axis);
assert(it != perm.end());
int64_t axis_new = std::distance(perm.begin(), it);
for(auto in : split_outputs)
{
auto oper = any_cast<op::slice>(in->get_operator());
auto starts = oper.starts;
auto ends = oper.ends;
auto tr_orig = in->outputs().front();
m.replace_instruction(
tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr);
}
}
};
void simplify_algebra::apply(module& m) const
{
// Run simplifications multiple times
m.repeat_while_changes(8, [&] {
match::find_matches(m,
find_inner_broadcast{},
find_dot_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_add_convs{},
find_conv_dot_horiz_fusion{},
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
find_zero_ops{},
find_dot_add{},
find_conv_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
find_concat_conv{},
find_concat_op{},
find_split_concat{},
find_splits{},
find_split_reshape{},
find_split_transpose{});
dead_code_elimination{}.apply(m);
});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx