// frigate/docker/rocm/migraphx/targets/ref/lowering.cpp
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution_backwards.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/pad_calc.hpp>
#include <unordered_map>
#include <utility>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {
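// Reference-target lowering: ref_apply (below) walks a module and replaces
// supported operators with "ref::" wrappers that compute their results
// directly on the host. Operators without a dedicated wrapper fall back to
// the generic ref_op, provided they are context-free.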
template <typename T>
T zero(const T&)
{
return T(0);
}
template <class T>
typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::type
make_signed(T x)
{
return x;
}
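// Local response normalization: each output element is the input scaled by
// (bias + alpha/size * sum of squares over a cross-channel window)^(-beta).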
struct ref_lrn
{
op::lrn op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::lrn"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
int n_batch = output_shape.lens()[0];
int channels = output_shape.lens()[1];
int height = output_shape.lens()[2];
int width = output_shape.lens()[3];
float alphaoverarea = op.alpha / float(op.size);
int radius_lower = (op.size - 1) / 2;
int radius_upper = op.size / 2 + 1;
par_dfor(n_batch, height, width)([&](int b, int h, int w) {
dfor(channels)([&](int c) {
// reset the accumulator for each output channel so every channel is
// normalized by the sum of squares over its own cross-channel window
float scale = 0;
auto start = (c - radius_lower) < 0 ? 0 : (c - radius_lower);
auto end = (c + radius_upper) > channels ? channels : (c + radius_upper);
for(auto k = start; k < end; ++k)
{
scale += std::pow(input(b, k, h, w), 2);
}
scale *= alphaoverarea;
scale += op.bias;
scale = std::pow(scale, -op.beta);
output(b, c, h, w) = input(b, c, h, w) * scale;
});
});
});
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_lrn)
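// Visit tensors whose element types may differ: the first tensor is visited
// on its own and the remaining tensors together, which is what quantized ops
// need when the output type is wider than the input type.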
template <class V, class T, class... Ts>
void visit_quantize_impl(V&& v, T&& x, Ts&&... xs)
{
x.visit([&](auto y) { visit_all(xs...)([&](auto... ys) { v(y, ys...); }); });
}
template <class T, class... Ts>
auto visit_quantize(T&& x, Ts&&... xs)
{
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
visit_quantize_impl(v, x, xs...);
};
}
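// im2col: unrolls convolution input patches into a matrix so the convolution
// can be computed as a gemm. Each row corresponds to one output pixel and
// each column to one (channel, kernel row, kernel column) position.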
struct ref_im2col
{
op::im2col op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
static std::string name() { return "ref::im2col"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto input_shape = args[0].get_shape();
auto weights_shape = args[1].get_shape();
visit_all(result, args[0])([&](auto col, auto input) {
const std::size_t& height = input_shape.lens()[2];
const std::size_t& width = input_shape.lens()[3];
const std::size_t& channels = weights_shape.lens()[1];
const std::size_t& kernel_h = weights_shape.lens()[2];
const std::size_t& kernel_w = weights_shape.lens()[3];
const std::size_t& pad_h = op.padding[0];
const std::size_t& pad_w = op.padding[1];
const std::size_t& stride_h = op.stride[0];
const std::size_t& stride_w = op.stride[1];
long kdiv2_h = long(kernel_h) / 2;
long kdiv2_w = long(kernel_w) / 2;
// calculate output sizes
const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1;
const std::size_t col_width = (width - kernel_w + 2 * pad_w) / stride_w + 1;
// account for padding for the starting position of the input pixels
long iinput = kdiv2_h - long(pad_h);
// loop over output pixels (ioutput, joutput)
for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h)
{
long jinput = kdiv2_w - long(pad_w);
for(std::size_t joutput = 0; joutput < col_width; joutput++, jinput += stride_w)
{
// compute linear index for output
std::size_t ldx = ioutput * col_width + joutput;
std::size_t p = 0;
dfor(channels,
kernel_h,
kernel_w)([&](std::size_t c, std::size_t koffset, std::size_t loffset) {
auto idx = iinput + long(koffset) - kdiv2_h;
auto jdx = jinput + long(loffset) - kdiv2_w;
col(ldx, p) =
((idx >= 0) and (idx < height) and (jdx >= 0) and (jdx < width))
? input(0, c, idx, jdx)
: 0;
p++;
});
}
}
});
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_im2col)
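// Generic fallback wrapper: forwards compute() of an arbitrary context-free
// operator while renaming it under the "ref::" prefix.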
struct ref_op
{
operation op = op::identity{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::op"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, const std::vector<argument>& args) const
{
return op.compute(output_shape, args);
}
value to_value() const
{
value v;
v["name"] = op.name();
v["operator"] = op.to_value();
return v;
}
void from_value(const value& v)
{
op = make_op(v.at("name").to<std::string>(), v.at("operator"));
}
friend std::ostream& operator<<(std::ostream& os, const ref_op& x)
{
os << "ref::" << x.op;
return os;
}
};
MIGRAPHX_REGISTER_OP(ref_op)
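// Pad: fills the output with the clamped pad value, then copies the input
// into the region offset by the leading pads.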
struct ref_pad
{
op::pad op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::pad"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
result.visit([&](auto output) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
});
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](const auto& idx) {
std::vector<std::size_t> new_idx(idx.size());
std::transform(
idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(), [](auto i, auto j) {
return i + j;
});
output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end());
});
});
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_pad)
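// Matrix multiply (op::dot) computed with the reference gemm routine,
// alpha = 1 and beta = 0.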
struct ref_gemm
{
op::dot op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_gemm)
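// Quantized matrix multiply (op::quant_dot): the accumulator/output type is
// wider than the input type, so the output is visited separately from the
// two inputs.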
struct ref_quant_gemm
{
op::quant_dot op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::quant_dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
result.visit([&](auto cmat) {
visit_all(args.at(0), args.at(1))(
[&](auto amat, auto bmat) { return gemm(cmat, amat, bmat, 1.0f, 0.0f); });
});
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_quant_gemm)
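// Shared implementation for softmax and logsoftmax: subtract the per-slice
// max for numerical stability, exponentiate, accumulate the sum along the
// reduction axis, then apply the op-specific finalization via op.output().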
template <class Op>
struct ref_softmax : auto_register_op<ref_softmax<Op>>
{
ref_softmax() = default;
ref_softmax(Op pop) : op(std::move(pop)) {}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
auto batch_lens = dyn_out.computed_shape.lens();
int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
std::size_t n_dims = batch_lens[tuned_axis];
batch_lens[tuned_axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
using value_type = accumulator_type<typename decltype(input)::value_type>;
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
par_for(batch_shape.elements(), [&](auto i) {
auto idx = batch_shape.multi(i);
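// pass 1: maximum along the reduction axis (for numerical stability)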
for(std::size_t j = 0; j < n_dims; ++j)
{
idx[tuned_axis] = j;
batch_max[i] =
std::max<value_type>(batch_max[i], input(idx.begin(), idx.end()));
}
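// pass 2: exponentiate the max-shifted values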
for(std::size_t j = 0; j < n_dims; ++j)
{
idx[tuned_axis] = j;
std::size_t index = dyn_out.computed_shape.index(idx);
output[index] = std::exp(input[index] - batch_max[i]);
}
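// pass 3: sum of exponentials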
for(std::size_t j = 0; j < n_dims; ++j)
{
idx[tuned_axis] = j;
batch_sum[i] += output(idx.begin(), idx.end());
}
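// pass 4: op-specific normalization (e.g. divide by the sum for softmax)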
for(std::size_t j = 0; j < n_dims; ++j)
{
idx[tuned_axis] = j;
output(idx.begin(), idx.end()) =
op.output()(output(idx.begin(), idx.end()), batch_sum[i]);
}
});
});
return result;
}
};
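// Extracts the last output of an RNN when sequence lengths vary per batch
// element: the reverse direction (or the second slot of a bidirectional run)
// reads time step 0, the forward direction reads step seq_lens[b] - 1.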
struct ref_rnn_var_sl_last_output
{
op::rnn_var_sl_last_output op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::rnn_var_sl_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto out_comp_lens = args[0].get_shape().lens();
out_comp_lens[0] = 1;
shape out_comp_s{output_shape.type(), out_comp_lens};
visit_all(result, args[0])([&](auto output, auto input) {
args[1].visit([&](auto seq_lens) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = out_comp_s.multi(i);
auto b = idx[2];
if(op.direction == op::rnn_direction::reverse or idx[1] == 1)
{
idx[0] = 0;
}
else
{
idx[0] = seq_lens[b] - 1;
}
output[i] = input(idx.begin(), idx.end());
});
});
});
return result;
}
};
MIGRAPHX_REGISTER_OP(ref_rnn_var_sl_last_output)
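// The lowering pass: init() builds a map from operator names to rewrites,
// apply() walks the module replacing matching instructions and lowers any
// remaining context-free operator through apply_ref_op().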
struct ref_apply
{
module* mod;
std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};
template <class T>
auto simple_op()
{
return [this](instruction_ref ins) { apply_simple_op<T>(ins); };
}
template <class T, class Op>
auto extend_op()
{
return [this](instruction_ref ins) { apply_extend_op<T, Op>(ins); };
}
void init()
{
apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
apply_map["pad"] = extend_op<ref_pad, op::pad>();
apply_map["softmax"] = extend_op<ref_softmax<op::softmax>, op::softmax>();
apply_map["rnn_var_sl_last_output"] =
extend_op<ref_rnn_var_sl_last_output, op::rnn_var_sl_last_output>();
}
void apply()
{
init();
for(auto it : iterator_for(*mod))
{
if(apply_map.count(it->name()) > 0)
{
apply_map.at(it->name())(it);
}
else if(is_context_free(it->get_operator()))
{
apply_ref_op(it);
}
}
}
void apply_ref_op(instruction_ref ins) const
{
mod->replace_instruction(ins, ref_op{ins->get_operator()}, ins->inputs());
}
template <class T>
void apply_simple_op(instruction_ref ins)
{
mod->replace_instruction(ins, T{}, ins->inputs());
}
template <class T, class Op>
void apply_extend_op(instruction_ref ins)
{
auto&& op = any_cast<Op>(ins->get_operator());
mod->replace_instruction(ins, T{op}, ins->inputs());
}
};
void lowering::apply(module& m) const { ref_apply{&m}.apply(); }
} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx