mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 17:14:26 +03:00
441 lines
16 KiB
C++
441 lines
16 KiB
C++
/*
|
|
* The MIT License (MIT)
|
|
*
|
|
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
|
*
|
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
* of this software and associated documentation files (the "Software"), to deal
|
|
* in the Software without restriction, including without limitation the rights
|
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
* copies of the Software, and to permit persons to whom the Software is
|
|
* furnished to do so, subject to the following conditions:
|
|
*
|
|
* The above copyright notice and this permission notice shall be included in
|
|
* all copies or substantial portions of the Software.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
* THE SOFTWARE.
|
|
*/
|
|
#include <migraphx/fuse_reduce.hpp>
|
|
#include <migraphx/check_shapes.hpp>
|
|
#include <migraphx/dead_code_elimination.hpp>
|
|
#include <migraphx/eliminate_common_subexpression.hpp>
|
|
#include <migraphx/instruction.hpp>
|
|
#include <migraphx/iterator_for.hpp>
|
|
#include <migraphx/make_op.hpp>
|
|
#include <migraphx/matcher.hpp>
|
|
#include <migraphx/pass_manager.hpp>
|
|
#include <migraphx/program.hpp>
|
|
#include <migraphx/ranges.hpp>
|
|
#include <migraphx/register_op.hpp>
|
|
#include <migraphx/rewrite_reshapes.hpp>
|
|
#include <migraphx/param_utils.hpp>
|
|
#include <iterator>
|
|
#include <map>
|
|
|
|
namespace migraphx {
|
|
inline namespace MIGRAPHX_INLINE_NS {
|
|
|
|
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
|
|
|
|
struct fused_reduce
|
|
{
|
|
std::vector<std::int64_t> axes{};
|
|
|
|
template <class Self, class F>
|
|
static auto reflect(Self& self, F f)
|
|
{
|
|
return pack(f(self.axes, "axes"));
|
|
}
|
|
|
|
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
|
|
{
|
|
if(mods.size() != 1)
|
|
MIGRAPHX_THROW("should have one submodule.");
|
|
const auto* sm = mods.front();
|
|
if(sm->get_output_shapes().size() != 1)
|
|
MIGRAPHX_THROW("Only one output supported");
|
|
if(not sm->bypass())
|
|
MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
|
|
auto names = sm->get_parameter_names();
|
|
check_shapes{inputs, *this}.has(names.size()).same_ndims();
|
|
std::sort(names.begin(), names.end());
|
|
auto shapes = sm->get_parameter_shapes();
|
|
// Check dimension matches for each input
|
|
if(not equal(names, inputs, [&](const auto& name, const auto& input) {
|
|
return shapes.at(name).lens() == input.lens();
|
|
}))
|
|
MIGRAPHX_THROW("Input dimension does not match the submodule.");
|
|
|
|
return shape::from_permutation(sm->get_output_shapes().front().type(),
|
|
sm->get_output_shapes().front().lens(),
|
|
find_permutation(inputs));
|
|
}
|
|
|
|
std::string name() const { return "fused_reduce"; }
|
|
};
|
|
MIGRAPHX_REGISTER_OP(fused_reduce);
|
|
|
|
/*
|
|
* Predicate matcher checks that input and output shapes have the same rank. This is assumed
|
|
* for broadcast instructions for these fusions.
|
|
*/
|
|
MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
|
|
{
|
|
auto input_shape = ins->inputs().front()->get_shape();
|
|
auto output_shape = ins->get_shape();
|
|
return input_shape.ndim() == output_shape.ndim();
|
|
}
|
|
|
|
static auto
|
|
insert_module_in_submodule(module_ref sm,
|
|
instruction_ref ins,
|
|
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
|
|
module::inserter insert = nullptr)
|
|
{
|
|
assert(ins->module_inputs().size() == 1);
|
|
return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
|
|
}
|
|
|
|
static void create_reduce_modules(module_pass_manager& mpm)
|
|
{
|
|
std::size_t n = 0;
|
|
for(auto ins : iterator_for(mpm.get_module()))
|
|
{
|
|
if(not ins->get_operator().attributes().get("reduce", false))
|
|
continue;
|
|
if(ins->inputs().size() != 1)
|
|
continue;
|
|
|
|
auto* rm =
|
|
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
|
|
rm->set_bypass();
|
|
|
|
rm->add_return(rm->fuse({ins}));
|
|
auto v = ins->get_operator().to_value();
|
|
mpm.get_module().replace_instruction(
|
|
ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
instruction_ref get_broadcast_output(instruction_ref broadcast)
|
|
{
|
|
if(broadcast->outputs().size() != 1)
|
|
return broadcast;
|
|
auto output = broadcast->outputs().front();
|
|
if(output->name() == "contiguous")
|
|
return get_broadcast_output(output);
|
|
return output;
|
|
}
|
|
|
|
MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins)
|
|
{
|
|
if(ins->outputs().size() == 1)
|
|
return true;
|
|
if(ins->outputs().size() == 2)
|
|
{
|
|
auto is_broadcast = [](instruction_ref output) {
|
|
return contains(output->name(), "broadcast");
|
|
};
|
|
auto broadcast = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
|
|
if(broadcast == ins->outputs().end())
|
|
return false;
|
|
auto non_broadcast =
|
|
std::find_if_not(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
|
|
if(non_broadcast == ins->outputs().end())
|
|
return false;
|
|
auto output = get_broadcast_output(*broadcast);
|
|
return output == *non_broadcast;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
} // namespace
|
|
template <class... Ms>
|
|
static auto match_broadcast(Ms... ms)
|
|
{
|
|
return match::skip(match::name("contiguous"))(
|
|
match::name("multibroadcast")(
|
|
match::arg(0)(ms...), match::used_once(), input_output_ndim_match())
|
|
.bind("broadcast"))
|
|
.bind("final_broadcast");
|
|
}
|
|
|
|
template <class... Ms>
|
|
static auto any_input(Ms... ms)
|
|
{
|
|
return match::any_of[match::inputs()](match::any(ms...).bind("input"));
|
|
}
|
|
|
|
bool is_valid_broadcast(const instruction_ref b, const std::vector<size_t>& reduce_axes)
|
|
{
|
|
std::vector<size_t> broadcast_axes;
|
|
auto bstrides = b->get_shape().strides();
|
|
|
|
for(size_t i = 0; i < bstrides.size(); ++i)
|
|
{
|
|
if(bstrides.at(i) == 0)
|
|
broadcast_axes.push_back(i);
|
|
}
|
|
|
|
return broadcast_axes == reduce_axes;
|
|
}
|
|
|
|
template <class M>
|
|
static auto match_broadcast_axes(M m)
|
|
{
|
|
return match::make_basic_fun_matcher(
|
|
[=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
|
|
optional<instruction_ref> result = m.match(ctx, ins);
|
|
if(contains(ctx.instructions, "broadcast"))
|
|
{
|
|
instruction_ref reduce;
|
|
if(ins->get_operator().name() == "fused_reduce")
|
|
{
|
|
reduce = ins;
|
|
}
|
|
else
|
|
{
|
|
assert(contains(ctx.instructions, "reduce"));
|
|
reduce = ctx.instructions["reduce"];
|
|
}
|
|
auto axes = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
|
|
auto broadcast = ctx.instructions["broadcast"];
|
|
if(not is_valid_broadcast(broadcast, axes))
|
|
return nullopt;
|
|
}
|
|
return result;
|
|
});
|
|
}
|
|
|
|
static auto match_broadcastable_input(const std::string& op, const std::string& name)
|
|
{
|
|
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
|
|
auto match_op_input = any_input(match_op, match::used_once());
|
|
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
|
|
return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
|
|
}
|
|
|
|
static void finalize_reduce_module(module_ref m)
|
|
{
|
|
eliminate_common_subexpression{}.apply(*m);
|
|
dead_code_elimination{}.apply(*m);
|
|
}
|
|
|
|
namespace {
|
|
struct find_pointwise_reduce
|
|
{
|
|
auto matcher() const
|
|
{
|
|
// fused_reduce instruction with pointwise inputs.
|
|
return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise"));
|
|
}
|
|
|
|
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
|
{
|
|
auto reduce = r.result;
|
|
auto input = r.instructions["pointwise"];
|
|
const auto* pm = input->module_inputs().front();
|
|
const auto* old_rm = reduce->module_inputs().front();
|
|
|
|
auto* rm = mpm.create_module(pm->name() + ":" + old_rm->name());
|
|
rm->set_bypass();
|
|
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
|
// Insert pointwise
|
|
auto rins = rm->fuse({input}, &map_ins).front();
|
|
map_ins[input] = rins;
|
|
|
|
if(contains(r.instructions, "broadcast"))
|
|
{
|
|
auto broadcast = r.instructions["broadcast"];
|
|
auto fbroadcast = r.instructions["final_broadcast"];
|
|
map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front();
|
|
if(fbroadcast != broadcast)
|
|
map_ins[fbroadcast] = map_ins[broadcast];
|
|
}
|
|
|
|
// Insert fused_reduce
|
|
rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
|
|
finalize_reduce_module(rm);
|
|
|
|
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
|
|
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
|
|
}
|
|
};
|
|
|
|
struct find_reduce_pointwise
|
|
{
|
|
|
|
auto matcher() const
|
|
{
|
|
return match::name("pointwise")(match_broadcastable_input("fused_reduce", "reduce"));
|
|
}
|
|
|
|
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
|
{
|
|
auto pw = r.result;
|
|
auto reduce = r.instructions["reduce"];
|
|
auto input = r.instructions["input"];
|
|
|
|
const auto* pm = pw->module_inputs().front();
|
|
const auto* old_rm = reduce->module_inputs().front();
|
|
auto* rm = mpm.create_module(old_rm->name() + ":" + pm->name());
|
|
rm->set_bypass();
|
|
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
|
// Copy module instructions
|
|
insert_module_in_submodule(rm, reduce, &map_ins);
|
|
if(contains(r.instructions, "broadcast"))
|
|
{
|
|
auto broadcast = r.instructions["broadcast"];
|
|
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
|
|
auto bout = rm->fuse({broadcast}, &map_ins);
|
|
map_ins[input] = bout.front();
|
|
}
|
|
else
|
|
{
|
|
map_ins[input] = rm->get_returns().front();
|
|
}
|
|
|
|
auto out = rm->fuse({pw}, &map_ins);
|
|
rm->replace_return(out);
|
|
finalize_reduce_module(rm);
|
|
|
|
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
|
|
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
|
|
}
|
|
};
|
|
|
|
struct find_reduce_reduce
|
|
{
|
|
auto matcher() const
|
|
{
|
|
return match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce"));
|
|
}
|
|
|
|
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
|
{
|
|
auto reduce1 = r.result;
|
|
auto reduce2 = r.instructions["reduce"];
|
|
auto input = r.instructions["input"];
|
|
|
|
if(reduce1->get_operator() != reduce2->get_operator())
|
|
return;
|
|
|
|
const auto* rm1 = reduce1->module_inputs().front();
|
|
const auto* rm2 = reduce2->module_inputs().front();
|
|
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
|
|
rm->set_bypass();
|
|
|
|
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
|
// Copy reduce1 instructions
|
|
insert_module_in_submodule(rm, reduce2, &map_ins);
|
|
if(contains(r.instructions, "broadcast"))
|
|
{
|
|
auto broadcast = r.instructions["broadcast"];
|
|
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
|
|
auto bout = rm->fuse({broadcast}, &map_ins);
|
|
map_ins[input] = bout.front();
|
|
}
|
|
else
|
|
{
|
|
map_ins[input] = rm->get_returns().front();
|
|
}
|
|
|
|
auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
|
|
rm->replace_return(out);
|
|
finalize_reduce_module(rm);
|
|
|
|
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
|
|
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
|
|
}
|
|
};
|
|
|
|
struct reduce_reshape : rewrite_reshapes_base
|
|
{
|
|
static std::string name() { return "fused_reduce"; }
|
|
|
|
template <class Transform>
|
|
static auto transform_op(Transform t)
|
|
{
|
|
return [=](module& m,
|
|
instruction_ref ins,
|
|
const operation& op,
|
|
const std::vector<instruction_ref>& inputs,
|
|
const std::vector<module_ref>& mod_args) {
|
|
auto new_op = t(op);
|
|
return m.insert_instruction(ins, new_op, inputs, mod_args);
|
|
};
|
|
}
|
|
|
|
template <class AxesMap>
|
|
static instruction_ref insert(module_pass_manager& mpm,
|
|
instruction_ref ins,
|
|
const std::vector<instruction_ref>& inputs,
|
|
const AxesMap& am)
|
|
{
|
|
auto op = any_cast<fused_reduce>(ins->get_operator());
|
|
std::vector<int64_t> axes;
|
|
for(auto axis : op.axes)
|
|
{
|
|
auto new_axes = am.at(axis);
|
|
axes.insert(axes.end(), new_axes.begin(), new_axes.end());
|
|
}
|
|
std::sort(axes.begin(), axes.end());
|
|
auto dims = base_dims(inputs);
|
|
auto* oldm = ins->module_inputs().front();
|
|
auto* sm = mpm.create_module(oldm->name() + "_reshape");
|
|
sm->set_bypass();
|
|
auto outs = sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
|
|
if(contains(sop.name(), "reduce"))
|
|
return make_op(sop.name(), {{"axes", axes}});
|
|
if(sop.name() == "multibroadcast")
|
|
return make_op("multibroadcast", {{"out_lens", dims}});
|
|
assert(sop.name() == "pointwise");
|
|
return sop;
|
|
}));
|
|
sm->add_return(outs);
|
|
return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
|
|
}
|
|
|
|
static std::vector<std::size_t> base_dims(const std::vector<instruction_ref>& inputs)
|
|
{
|
|
auto input = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto i) {
|
|
return i->get_shape().elements();
|
|
}));
|
|
return (*input)->get_shape().lens();
|
|
}
|
|
|
|
static std::vector<std::size_t> base_dims(instruction_ref ins)
|
|
{
|
|
return base_dims(ins->inputs());
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void fuse_reduce::apply(module_pass_manager& mpm) const
|
|
{
|
|
if(enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}))
|
|
return;
|
|
create_reduce_modules(mpm);
|
|
mpm.run_pass(dead_code_elimination{});
|
|
for(int i = 0; i < 4; i++)
|
|
{
|
|
if(enable_rewrite_reshapes)
|
|
mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
|
|
match::find_matches(
|
|
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
|
|
mpm.run_pass(dead_code_elimination{});
|
|
}
|
|
}
|
|
|
|
} // namespace MIGRAPHX_INLINE_NS
|
|
} // namespace migraphx
|