mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-19 01:17:06 +03:00
154 lines
6.0 KiB
C++
154 lines
6.0 KiB
C++
/*
|
|
* The MIT License (MIT)
|
|
*
|
|
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
|
*
|
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
* of this software and associated documentation files (the "Software"), to deal
|
|
* in the Software without restriction, including without limitation the rights
|
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
* copies of the Software, and to permit persons to whom the Software is
|
|
* furnished to do so, subject to the following conditions:
|
|
*
|
|
* The above copyright notice and this permission notice shall be included in
|
|
* all copies or substantial portions of the Software.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
* THE SOFTWARE.
|
|
*
|
|
*/
|
|
#include <migraphx/rewrite_reduce.hpp>
|
|
#include <migraphx/module.hpp>
|
|
#include <migraphx/matcher.hpp>
|
|
#include <migraphx/make_op.hpp>
|
|
#include <migraphx/common.hpp>
|
|
|
|
namespace migraphx {
|
|
inline namespace MIGRAPHX_INLINE_NS {
|
|
|
|
namespace {
|
|
struct find_softmax
|
|
{
|
|
auto matcher() const { return match::name("softmax"); }
|
|
|
|
void apply(module& m, const match::matcher_result& r) const
|
|
{
|
|
auto ins = r.result;
|
|
auto op = ins->get_operator().to_value();
|
|
auto axis = op["axis"].to<std::int64_t>();
|
|
|
|
auto input = ins->inputs().front();
|
|
auto max = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {axis}}}), input);
|
|
auto maxb = m.insert_instruction(
|
|
ins, make_op("multibroadcast", {{"out_lens", input->get_shape().lens()}}), max);
|
|
auto sub = m.insert_instruction(ins, make_op("sub"), input, maxb);
|
|
auto exp = m.insert_instruction(ins, make_op("exp"), sub);
|
|
auto sum = m.insert_instruction(ins, make_op("reduce_sum", {{"axes", {axis}}}), exp);
|
|
auto sumb = m.insert_instruction(
|
|
ins, make_op("multibroadcast", {{"out_lens", input->get_shape().lens()}}), sum);
|
|
m.replace_instruction(ins, make_op("div"), exp, sumb);
|
|
}
|
|
};
|
|
|
|
struct find_reduce_mean_variance
|
|
{
|
|
auto matcher() const
|
|
{
|
|
auto reduce_mean = match::name("reduce_mean");
|
|
auto skip_broadcasts_mean = match::skip_broadcasts(reduce_mean.bind("mean"));
|
|
auto x_minus_mean = match::name("sub")(match::arg(0)(match::any().bind("x")),
|
|
match::arg(1)(skip_broadcasts_mean));
|
|
auto pow_x_minus_mean =
|
|
match::name("pow")(match::arg(0)(x_minus_mean), match::arg(1)(match::has_value(2.0f)));
|
|
auto mul_x_minus_mean =
|
|
match::name("mul")(match::arg(0)(x_minus_mean), match::arg(1)(x_minus_mean));
|
|
auto sqdiff = match::name("sqdiff")(
|
|
match::either_arg(0, 1)(match::any().bind("x"), skip_broadcasts_mean));
|
|
return reduce_mean(
|
|
match::arg(0)(match::any_of(pow_x_minus_mean, mul_x_minus_mean, sqdiff)));
|
|
}
|
|
|
|
void apply(module& m, const match::matcher_result& r) const
|
|
{
|
|
auto ins = r.result;
|
|
auto x_ins = r.instructions["x"];
|
|
auto mean = r.instructions["mean"];
|
|
|
|
if(ins->get_operator() != mean->get_operator())
|
|
return;
|
|
|
|
if(mean->inputs().front() != x_ins)
|
|
return;
|
|
|
|
auto x2 = m.insert_instruction(ins, make_op("mul"), x_ins, x_ins);
|
|
auto mean_x2 = m.insert_instruction(ins, mean->get_operator(), x2);
|
|
auto mean_x_2 = m.insert_instruction(ins, make_op("mul"), mean, mean);
|
|
m.replace_instruction(ins, make_op("sub"), mean_x2, mean_x_2);
|
|
}
|
|
};
|
|
|
|
struct find_reduce_mean
|
|
{
|
|
auto matcher() const { return match::name("reduce_mean"); }
|
|
|
|
void apply(module& m, const match::matcher_result& r) const
|
|
{
|
|
auto ins = r.result;
|
|
auto op = ins->get_operator().to_value();
|
|
auto axes = op["axes"].to_vector<std::int64_t>();
|
|
auto input = ins->inputs().front();
|
|
|
|
bool is_integral = false;
|
|
double max_n = 0;
|
|
std::size_t size = 0;
|
|
input->get_shape().visit_type([&](auto t) {
|
|
is_integral = t.is_integral();
|
|
max_n = t.max();
|
|
size = t.size();
|
|
});
|
|
|
|
auto n = input->get_shape().elements() / ins->get_shape().elements();
|
|
|
|
// Convert accumulator to float if <= 8bit type or if < 3 bytes and n >= max_n /4
|
|
if(size == 1 or (n >= max_n / 4 and size < 3))
|
|
{
|
|
shape::type_t t = is_integral ? shape::int32_type : shape::float_type;
|
|
input = m.insert_instruction(ins, make_op("convert", {{"target_type", t}}), input);
|
|
}
|
|
|
|
auto n_literal = m.add_literal(literal{{input->get_shape().type(), {1}}, {n}});
|
|
if(is_integral)
|
|
{
|
|
auto reduce_sum =
|
|
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), input);
|
|
auto div = insert_common_op(m, ins, make_op("div"), {reduce_sum, n_literal});
|
|
m.replace_instruction(
|
|
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), div);
|
|
}
|
|
else
|
|
{
|
|
auto new_input = insert_common_op(m, ins, make_op("div"), {input, n_literal});
|
|
auto reduce_sum =
|
|
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), new_input);
|
|
m.replace_instruction(
|
|
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), reduce_sum);
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void rewrite_reduce::apply(module& m) const
|
|
{
|
|
match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
|
|
match::find_matches(m, find_reduce_mean{});
|
|
}
|
|
|
|
} // namespace MIGRAPHX_INLINE_NS
|
|
} // namespace migraphx
|