Add files via upload

This commit is contained in:
WhiteWolf84 2025-02-03 20:53:47 +01:00 committed by GitHub
parent 8e5fe2703a
commit 58e9831aef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
90 changed files with 22996 additions and 0 deletions

View File

@ -0,0 +1,70 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Re-allocate output buffers whose allocated shape no longer matches the shape of the
/// instruction that writes into them.
///
/// For every instruction that has inputs and is not context free, the output alias is looked
/// up. When that alias is either an allocation of this pass' allocation model or a "@param",
/// and its shape differs from the instruction's shape, a fresh allocation with the
/// instruction's shape is inserted and replaces the alias. When the alias is a "@param", a
/// copy back into the parameter is inserted after the instruction and later users of the
/// instruction are rewired to consume the copy.
void adjust_allocation::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        // Nothing to adjust for instructions without inputs or for target-independent operators
        if(ins->inputs().empty() or ins->get_operator().is_context_free())
            continue;

        auto alias_ins = instruction::get_output_alias(ins, true);
        const bool is_candidate =
            alias_ins->name() == model.name() or alias_ins->name() == "@param";
        if(not is_candidate)
            continue;

        // The allocated shape already matches the instruction's shape
        if(alias_ins->get_shape() == ins->get_shape())
            continue;

        // Allocate a buffer of the right shape and use it in place of the old one
        auto alloc_ins = m.insert_instruction(ins, model.allocate(ins->get_shape()));
        m.replace_instruction(alias_ins, alloc_ins);

        // Only output parameters need the result copied back
        if(alias_ins->name() != "@param")
            continue;

        auto copy =
            m.insert_instruction(std::next(ins), make_op(model.copy()), ins, alias_ins);
        // Every instruction after the copy that reads `ins` must read the copy instead
        auto tail = range(std::next(copy), m.end());
        for(auto user : iterator_for(tail))
        {
            if(not contains(user->inputs(), ins))
                continue;
            instruction::replace_argument(user, ins, copy);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,114 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/analyze_streams.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/errors.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Vector-clock ordering test.
///
/// Returns true when `e1` is element-wise less than or equal to `e2` and strictly less in at
/// least one position. Clocks of different length (and two empty clocks) are never ordered.
bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::size_t>& e2)
{
    if(e1.size() != e2.size())
        return false;
    bool strictly_earlier = false;
    for(std::size_t i = 0; i < e1.size(); ++i)
    {
        // Any component that is ahead rules out "happens before"
        if(e1[i] > e2[i])
            return false;
        if(e1[i] < e2[i])
            strictly_earlier = true;
    }
    return strictly_earlier;
}
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm)
{
using vector_clock = std::vector<std::size_t>;
std::vector<stream_race> races;
auto nstream = strmm.get_nstream();
std::vector<vector_clock> vclock(nstream, vector_clock(nstream));
std::unordered_map<instruction_ref, vector_clock> timestamp;
std::unordered_map<std::size_t, vector_clock> events;
for(auto ins : iterator_for(m))
{
if(not strmm.has_stream(ins))
continue;
std::size_t s = strmm.get_stream(ins);
assert(s < nstream);
assert(vclock.size() == nstream);
assert(vclock[s].size() == nstream);
if(strmm.is_record(ins))
{
vclock[s][s]++;
auto event = strmm.get_event_id(ins);
events[event] = vclock[s];
}
else if(strmm.is_wait(ins))
{
auto event = strmm.get_event_id(ins);
if(not contains(events, event))
MIGRAPHX_THROW("Event is waited on before being recorded: " +
std::to_string(event));
auto payload = events.at(event);
assert(vclock[s].size() == payload.size());
std::transform(vclock[s].begin(),
vclock[s].end(),
payload.begin(),
vclock[s].begin(),
[&](auto x, auto y) { return std::max(x, y); });
vclock[s][s]++;
}
else
{
vclock[s][s]++;
}
timestamp[ins] = vclock[s];
}
for(auto ins : iterator_for(m))
{
if(not strmm.has_stream(ins))
continue;
if(ins->inputs().empty())
continue;
std::size_t s = strmm.get_stream(ins);
// Find inputs from different streams
std::vector<instruction_ref> inputs;
fix([&](auto self, auto start) {
for(auto input : start->inputs())
{
if(not strmm.has_stream(input))
self(input);
else if(strmm.get_stream(input) != s)
inputs.push_back(input);
}
})(ins);
auto it = std::find_if(inputs.begin(), inputs.end(), [&](auto input) {
return not happens_before(timestamp.at(input), timestamp.at(ins));
});
if(it != inputs.end())
{
races.push_back({ins, *it});
}
}
return races;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,79 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/apply_alpha_beta.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Insert `alpha * op(args[0], args[1]) + beta * args[2]` before `pos`.
///
/// - When `alpha` is not 1 the first argument is multiplied by the alpha literal and, if the
///   multiplication changed its type, converted back to the original type.
/// - `op` is then applied to the (possibly scaled) first argument and the second argument.
/// - When a third argument is present, `beta` is not 0 and the third argument has at least
///   one element, the third argument is broadcast to the result's lens if needed, scaled by
///   beta (converted back to its own type if the multiply changed it) and added.
///
/// Returns the final instruction (the add, or the result of `op` when no beta term applies).
instruction_ref insert_apply_alpha_beta(module& m,
                                        instruction_ref pos,
                                        const std::vector<instruction_ref>& args,
                                        const operation& op,
                                        const literal& alpha,
                                        const literal& beta)
{
    auto lhs      = args[0];
    auto rhs      = args[1];
    auto lhs_type = lhs->get_shape().type();
    if(not float_equal(alpha.at<float>(0), 1.0))
    {
        auto alpha_literal = m.add_literal(alpha);
        lhs = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, lhs});
        // The common-type multiply may have promoted the type; restore it
        if(lhs->get_shape().type() != lhs_type)
        {
            lhs = m.insert_instruction(
                pos, make_op("convert", {{"target_type", lhs_type}}), lhs);
        }
    }
    auto op_res = m.insert_instruction(pos, op, lhs, rhs);

    // No third operand: nothing to accumulate
    if(args.size() != 3)
        return op_res;
    // A zero beta or an empty third operand contributes nothing
    if(float_equal(beta.at<float>(0), 0.0) or args[2]->get_shape().elements() == 0)
        return op_res;

    auto out_lens = op_res->get_shape().lens();
    auto c        = args[2];
    auto c_type   = c->get_shape().type();
    if(c->get_shape().lens() != out_lens)
    {
        c = m.insert_instruction(
            pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
    }
    auto beta_literal = m.add_literal(beta);
    auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
    if(beta_c->get_shape().type() != c_type)
    {
        beta_c = m.insert_instruction(
            pos, migraphx::make_op("convert", {{"target_type", c_type}}), beta_c);
    }
    return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,210 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Construct an argument that owns a freshly allocated buffer of `s.bytes()` bytes.
argument::argument(const shape& s) : m_shape(s)
{
    auto storage = make_shared_array<char>(s.bytes());
    // The getter keeps the buffer alive for as long as the argument's data refers to it
    assign_buffer([storage]() mutable { return storage.get(); });
}
/// Construct an argument with shape `s` whose data getter always yields a null pointer.
argument::argument(shape s, std::nullptr_t)
    : m_shape(std::move(s)), m_data({[]() -> char* { return nullptr; }})
{
}
/// Construct an argument from a shape and an already-built data description.
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d)
{
}
// Bind the getter `d` (which yields the base pointer of one contiguous buffer) to this
// argument's shape. For a non-tuple shape the getter is stored as-is. For a tuple shape the
// buffer is carved up: every leaf sub-shape gets its own getter returning `d() + offset`,
// and the getters are arranged in a data_t tree mirroring the nesting of the shape.
void argument::assign_buffer(std::function<char*()> d)
{
const shape& s = m_shape;
if(s.type() != shape::tuple_type)
{
m_data = {std::move(d)};
return;
}
// Collect all shapes
// Leaves (shapes with no sub-shapes) are numbered in the order this traversal visits them
std::unordered_map<std::size_t, shape> shapes;
{
std::size_t i = 0;
fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
{
shapes[i] = ss;
i++;
}
else
{
for(auto&& child : ss.sub_shapes())
self(child);
}
})(s);
}
// Sort by type size
// NOTE(review): `by(std::greater<>{}, ...)` presumably orders leaves from largest to
// smallest element type (likely to keep larger types aligned) - confirm against `by`
std::vector<std::size_t> order(shapes.size());
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) {
return shapes[i].type_size();
}));
// Compute offsets
// Leaves are packed back to back in the sorted order; offsets are keyed by leaf number
std::unordered_map<std::size_t, std::size_t> offsets;
std::size_t offset = 0;
for(auto i : order)
{
offsets[i] = offset;
offset += shapes[i].bytes();
}
// The packed leaves must account for exactly the bytes of the tuple shape
assert(offset == s.bytes());
// Second traversal: build the data_t tree. The leaf counter `i` relies on this traversal
// visiting leaves in the same order as the collection pass above.
std::size_t i = 0;
m_data = fix<data_t>([&](auto self, auto ss) {
data_t result;
if(ss.sub_shapes().empty())
{
auto n = offsets[i];
// Each leaf getter holds its own copy of `d` and its byte offset into the buffer
result = {[d, n]() mutable { return d() + n; }};
i++;
return result;
}
std::vector<data_t> subs;
std::transform(ss.sub_shapes().begin(),
ss.sub_shapes().end(),
std::back_inserter(subs),
[&](auto child) { return self(child); });
result.sub = subs;
return result;
})(s);
}
/// Expand tuple arguments recursively so the result contains only non-tuple arguments, in
/// their original order.
std::vector<argument> flatten(const std::vector<argument>& args)
{
    std::vector<argument> result;
    for(const auto& arg : args)
    {
        if(arg.get_shape().type() != shape::tuple_type)
        {
            result.push_back(arg);
            continue;
        }
        // Tuples are replaced by the flattened list of their members
        auto members = flatten(arg.get_sub_objects());
        result.insert(result.end(), members.begin(), members.end());
    }
    return result;
}
std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes;
std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) {
return arg.get_shape();
});
return shapes;
}
/// Construct a tuple argument: its shape is built from the shapes of `args` and its data
/// holds the data of each argument as a sub-object.
argument::argument(const std::vector<argument>& args)
    : m_shape(to_shapes(args)),
      m_data(data_t::from_args(args))
{
}
/// Return the pointer produced by this argument's data getter.
/// Must not be called on a tuple argument or on an empty argument (checked by assertions).
char* argument::data() const
{
    assert(m_shape.type() != shape::tuple_type);
    assert(not this->empty());
    char* ptr = m_data.get();
    return ptr;
}
/// An argument is empty when it has neither a data getter nor any sub-objects.
bool argument::empty() const
{
    if(m_data.get)
        return false;
    return m_data.sub.empty();
}
/// Return the shape describing this argument.
const shape& argument::get_shape() const
{
    return m_shape;
}
/// Return an argument that refers to the same data but is described by shape `s`.
/// `s` must not need more element space than the current shape (checked by assertion).
argument argument::reshape(const shape& s) const
{
    assert(s.element_space() <= get_shape().element_space());
    return {s, m_data};
}
/// Build a data_t whose getter forwards to a shared copy of this data_t, and whose
/// sub-objects are the shared versions of this object's sub-objects.
argument::data_t argument::data_t::share() const
{
    data_t shared;
    if(this->get)
    {
        // One heap copy of *this is kept alive by the returned getter
        auto owner = std::make_shared<data_t>(*this);
        shared.get = [owner]() mutable { return owner->get(); };
    }
    for(const auto& child : sub)
        shared.sub.push_back(child.share());
    return shared;
}
/// Build a data_t with no getter of its own whose sub-objects are the data of `args`, in
/// order.
argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
{
    data_t result;
    for(const auto& arg : args)
        result.sub.push_back(arg.m_data);
    return result;
}
/// Return a new argument with the same shape and its own buffer holding a byte-for-byte
/// copy of this argument's `get_shape().bytes()` bytes.
argument argument::copy() const
{
    argument result{this->get_shape()};
    auto* src         = this->data();
    const auto nbytes = this->get_shape().bytes();
    std::copy_n(src, nbytes, result.data());
    return result;
}
/// Return an argument with the same shape whose data is the shared form of this
/// argument's data (see data_t::share).
argument argument::share() const
{
    return {m_shape, m_data.share()};
}
std::vector<argument> argument::get_sub_objects() const
{
std::vector<argument> result;
assert(m_shape.sub_shapes().size() == m_data.sub.size());
std::transform(m_shape.sub_shapes().begin(),
m_shape.sub_shapes().end(),
m_data.sub.begin(),
std::back_inserter(result),
[](auto&& s, auto&& d) {
return argument{s, d};
});
return result;
}
/// Return a scalar argument of this argument's type that points at element `i`.
/// The byte offset is `type_size() * get_shape().index(i)`. Not valid for shapes that
/// have sub-shapes (asserted).
argument argument::element(std::size_t i) const
{
    const auto& s = this->get_shape();
    assert(s.sub_shapes().empty());
    auto element_index = s.index(i);
    auto byte_offset   = s.type_size() * element_index;
    return argument{shape{s.type()}, this->data() + byte_offset};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,80 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Insert "contiguous" instructions where a standard-layout tensor is needed.
///
/// First pass (reverse order): for every instruction whose operator carries the
/// "require_std_shape" attribute, each input that is not already a "contiguous"
/// instruction is wrapped in one.
/// Second pass (forward order): every instruction other than "layout" and "@return" that
/// is either used or is the last instruction gets a "contiguous" appended after it when
/// its shape is static, has more than one element, and is either not standard or differs
/// from its normalized standard form.
void auto_contiguous::apply(module& m) const
{
    std::string std_shape_attr = "require_std_shape";
    for(auto ins : reverse_iterator_for(m))
    {
        auto&& attrs = ins->get_operator().attributes();
        if(not attrs.get(std_shape_attr, false))
            continue;
        auto old_args = ins->inputs();
        std::vector<instruction_ref> new_args;
        for(auto in : old_args)
        {
            if(in->name() == "contiguous")
                new_args.push_back(in);
            else
                new_args.push_back(m.insert_instruction(ins, make_op("contiguous"), in));
        }
        // Only rebuild the instruction when at least one input was wrapped
        if(new_args != old_args)
            m.replace_instruction(ins, ins->get_operator(), new_args);
    }

    // Captured before the loop: instructions inserted below do not move this marker
    auto last = std::prev(m.end());
    for(auto ins : iterator_for(m))
    {
        if(contains({"layout", "@return"}, ins->name()))
            continue;
        // Unused instructions are skipped, except the last instruction when it is not a return
        if(ins->outputs().empty() and ins != last)
            continue;
        shape s = ins->get_shape();
        if(s.dynamic())
            continue;
        // Already standard with in-sequence strides: nothing to do
        if(s.standard() and not(s.normalize_standard() != s))
            continue;
        if(not(s.elements() > 1))
            continue;
        auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
        m.replace_instruction(ins, c);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,82 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/autocast_fp8.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/fp8_types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Move fp8 types off the module boundary.
///
/// Every "@param" with an fp8 type is renamed to "<name>_old", a new parameter with the
/// original name and `target_type` is added, and a convert back to the fp8 type replaces
/// the old parameter's uses. For "@return", every fp8-typed input is converted to
/// `target_type` before being returned. The renamed fp8 parameters are removed at the end.
void autocast_fp8_pass::apply(module& m) const
{
    std::vector<instruction_ref> stale_params;
    for(auto ins : iterator_for(m))
    {
        const auto& ins_name = ins->name();
        if(ins_name == "@param" and contains(fp8_types{}.get(), ins->get_shape().type()))
        {
            shape::type_t original_type = ins->get_shape().type();
            migraphx::shape widened     = ins->get_shape().with_type(target_type);
            std::string param_name =
                ins->get_operator().to_value()["parameter"].to<std::string>();
            // Free up the name so the replacement parameter can take it
            m.rename_parameter(ins, param_name + "_old");
            auto new_param = m.add_parameter(param_name, widened);
            auto to_fp8    = migraphx::make_op(
                "convert", {{"target_type", migraphx::to_value(original_type)}});
            auto new_ins = m.insert_instruction(ins, to_fp8, new_param);
            m.replace_instruction(ins, new_ins);
            stale_params.push_back(ins);
        }
        if(ins_name == "@return")
        {
            std::vector<instruction_ref> inputs = ins->inputs();
            std::vector<instruction_ref> new_inputs;
            for(auto i : inputs)
            {
                if(not contains(fp8_types{}.get(), i->get_shape().type()))
                {
                    new_inputs.push_back(i);
                    continue;
                }
                new_inputs.push_back(m.insert_instruction(
                    ins,
                    migraphx::make_op("convert",
                                      {{"target_type", migraphx::to_value(target_type)}}),
                    i));
            }
            m.replace_return({new_inputs});
        }
    }
    // Remove unused parameters with fp8 type
    for(const auto& param : stale_params)
        m.remove_instruction(param);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,81 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/base64.hpp>
#include <vector>
#include <array>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
using byte = unsigned char;
/// Base64 alphabet, indexed by 6-bit value.
std::array<char, 64> constexpr b64_chars{
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
    'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
    'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
/// base64 encoder snippet altered from https://stackoverflow.com/a/37109258
///
/// Encodes `buf` as padded base64: every 3 input bytes become 4 output characters and a
/// trailing group of 1 or 2 bytes is padded with "==" or "=" respectively. The output is
/// written straight into a pre-reserved string, so no intermediate buffer is needed.
std::string encode(const std::vector<byte>& buf)
{
    const std::size_t len       = buf.size();
    const std::size_t remaining = len % 3;
    const std::size_t last      = len - remaining;
    std::string res;
    res.reserve((len + 2) / 3 * 4);
    // Full 3-byte groups: 24 bits split into four 6-bit indices
    for(std::size_t i = 0; i < last; i += 3)
    {
        const std::size_t n = static_cast<std::size_t>(buf.at(i)) << 16u |
                              static_cast<std::size_t>(buf.at(i + 1)) << 8u |
                              static_cast<std::size_t>(buf.at(i + 2));
        res.push_back(b64_chars.at(n >> 18u));
        res.push_back(b64_chars.at(n >> 12u & 0x3Fu));
        res.push_back(b64_chars.at(n >> 6u & 0x3Fu));
        res.push_back(b64_chars.at(n & 0x3Fu));
    }
    if(remaining == 1)
    {
        // 8 bits -> two characters (the second carries 2 data bits) plus two pads
        const std::size_t n = static_cast<std::size_t>(buf.at(last));
        res.push_back(b64_chars.at(n >> 2u));
        res.push_back(b64_chars.at(n << 4u & 0x3Fu));
        res.append("==");
    }
    else if(remaining == 2)
    {
        // 16 bits -> three characters (the third carries 4 data bits) plus one pad
        const std::size_t n = static_cast<std::size_t>(buf.at(last)) << 8u |
                              static_cast<std::size_t>(buf.at(last + 1));
        res.push_back(b64_chars.at(n >> 10u & 0x3Fu));
        res.push_back(b64_chars.at(n >> 4u & 0x3Fu));
        res.push_back(b64_chars.at(n << 2u & 0x3Fu));
        res.push_back('=');
    }
    return res;
}
} // namespace
std::string base64_encode(const std::string& str)
{
return encode(std::vector<byte>(str.begin(), str.end()));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,254 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Compute the lens that result from broadcasting `s0` against `s1`.
///
/// The shorter list is right-aligned against the longer one; aligned dimensions must be
/// equal or one of them must be 1, and the larger of the two is taken. Leading dimensions
/// of the longer list are kept as they are. Throws (via MIGRAPHX_THROW) on a mismatch.
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
    if(s0 == s1)
        return s0;
    // Make s1 the list with the most dimensions
    if(s0.size() > s1.size())
        s0.swap(s1);
    std::vector<std::size_t> out_lens(s1);
    auto offset = s1.size() - s0.size();
    for(std::size_t i = 0; i < s0.size(); ++i)
    {
        auto a = s0[i];
        auto b = s1[i + offset];
        if(a != b and a != 1 and b != 1)
        {
            MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
                           "} and {" + migraphx::to_string_range(s1) + "} mismatch!");
        }
        out_lens[i + offset] = std::max(a, b);
    }
    return out_lens;
}
/// Compute the dynamic dimensions that result from broadcasting `dds0` against `dds1`.
///
/// The shorter list is right-aligned against the longer one. For each aligned pair: equal
/// dimensions or a second dimension of 1 keep the first; a first dimension of 1 takes the
/// second; otherwise the intersection of the two is used. Throws (via MIGRAPHX_THROW) when
/// the two dimensions have no intersection.
std::vector<shape::dynamic_dimension>
compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
                             std::vector<shape::dynamic_dimension> dds1)
{
    // Make dds1 the list with the most dimensions
    if(dds0.size() > dds1.size())
        std::swap(dds0, dds1);
    auto offset = dds1.size() - dds0.size();
    std::vector<shape::dynamic_dimension> out_dims(dds1);
    for(std::size_t i = 0; i < dds0.size(); ++i)
    {
        auto a = dds0[i];
        auto b = dds1[i + offset];
        if(a == b or b == 1)
        {
            out_dims[i + offset] = a;
            continue;
        }
        if(a == 1)
        {
            out_dims[i + offset] = b;
            continue;
        }
        auto intersect = a.intersection(b);
        if(not intersect.has_value())
        {
            MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
                           migraphx::to_string_range(dds0) + "} and {" +
                           migraphx::to_string_range(dds1) + "} mismatch!");
        }
        out_dims[i + offset] = intersect.value();
    }
    return out_dims;
}
/// Shape overload: both shapes are first converted to their dynamic representation and
/// their dynamic dimensions are then broadcast against each other.
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
    auto dyn0 = s0.to_dynamic();
    auto dyn1 = s1.to_dynamic();
    return compute_broadcasted_dyn_dims(dyn0.dyn_dims(), dyn1.dyn_dims());
}
/// Fold all shapes, left to right, into one set of broadcast dynamic dimensions.
/// `shapes` must not be empty (`at(0)` throws otherwise); the type of the first shape is
/// carried through the fold.
std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
{
    auto common = shapes.at(0);
    for(auto it = shapes.cbegin() + 1; it != shapes.cend(); ++it)
    {
        auto next = *it;
        common    = shape{common.type(), compute_broadcasted_dyn_dims(common, next)};
    }
    return common.dyn_dims();
}
/// Broadcast the lens of all shapes together, starting from the first shape's lens.
/// `shapes` must be non-empty and contain only static shapes (both asserted).
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
    assert(not shapes.empty());
    assert(std::none_of(
        shapes.cbegin(), shapes.cend(), [](auto candidate) { return candidate.dynamic(); }));
    auto first_lens = shapes.front().lens();
    auto lens_of    = [](auto s) { return s.lens(); };
    return transform_accumulate(
        shapes.begin() + 1, shapes.end(), first_lens, &compute_broadcasted_lens, lens_of);
}
/// Return the shape type corresponding to `std::common_type_t` of the C++ types behind
/// `t1` and `t2`; identical types are returned unchanged.
shape::type_t compute_common_type(shape::type_t t1, shape::type_t t2)
{
    if(t1 == t2)
        return t1;
    shape::type_t common;
    shape::visit(t1, [&](auto lhs) {
        shape::visit(t2, [&](auto rhs) {
            // Workaround broken warning on gcc 5
            (void)lhs;
            (void)rhs;
            using lhs_type = decltype(lhs());
            using rhs_type = decltype(rhs());
            common         = shape::get_type<std::common_type_t<lhs_type, rhs_type>>{};
        });
    });
    return common;
}
/// Fold the types of all shapes into one common type, starting from the first shape's
/// type. `shapes` must be non-empty (asserted).
shape::type_t compute_common_types(const std::vector<shape>& shapes)
{
    assert(not shapes.empty());
    auto first_type = shapes.front().type();
    auto type_of    = [&](auto s) { return s.type(); };
    return transform_accumulate(
        shapes.begin() + 1, shapes.end(), first_type, &compute_common_type, type_of);
}
/// Return a shape with the common type and the broadcast lens of `shapes`, or a
/// default-constructed shape when `shapes` is empty.
shape common_shape(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};
    auto type = compute_common_types(shapes);
    auto lens = compute_common_lens(shapes);
    return {type, lens};
}
/// Insert, before `ins`, the instructions needed to bring `inputs` to a common shape
/// and/or type, and return the adjusted inputs.
///
/// - If any input has a dynamic shape: when `options.common_lens` is set every input gets a
///   "multibroadcast" to the common dynamic dimensions; when `options.common_type` is set
///   inputs whose type differs from the common type (computed from the original input
///   shapes) get a "convert".
/// - Otherwise the common static shape is computed once and each input gets a
///   "multibroadcast" and/or "convert" only where its lens or type differ, subject to the
///   same two options.
std::vector<instruction_ref> insert_common_args(module& m,
                                                instruction_ref ins,
                                                std::vector<instruction_ref> inputs,
                                                common_options options)
{
    if(std::any_of(
           inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
    {
        auto input_shapes = to_shapes(inputs);
        if(options.common_lens)
        {
            auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
            // always add both multibroadcast instructions for dynamic shapes
            // The first broadcast takes all inputs as arguments
            inputs[0] = m.insert_instruction(
                ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
            std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
                // uses previous input to avoid recalculating the common shape from the
                // full set of input shapes at runtime
                return m.insert_instruction(
                    ins,
                    make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
                    input,
                    inputs[0]);
            });
        }
        if(options.common_type)
        {
            auto c_type = compute_common_types(input_shapes);
            std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
                if(input->get_shape().type() != c_type)
                {
                    input = m.insert_instruction(
                        ins, make_op("convert", {{"target_type", c_type}}), input);
                }
                return input;
            });
        }
    }
    else
    {
        auto common = common_shape(to_shapes(inputs));
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
            if(options.common_lens and input->get_shape().lens() != common.lens())
            {
                input = m.insert_instruction(
                    ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
            }
            if(options.common_type and input->get_shape().type() != common.type())
            {
                input = m.insert_instruction(
                    ins, make_op("convert", {{"target_type", common.type()}}), input);
            }
            return input;
        });
    }
    return inputs;
}
/// Same as insert_common_args, with the new instructions placed at the end of the module.
std::vector<instruction_ref>
add_common_args(module& m, std::vector<instruction_ref> inputs, common_options options)
{
    auto at_end = m.end();
    return insert_common_args(m, at_end, std::move(inputs), options);
}
/// Bring `inputs` to a common shape/type (see insert_common_args) and insert `op` applied
/// to the adjusted inputs before `ins`. Returns the inserted instruction.
instruction_ref insert_common_op(module& m,
                                 instruction_ref ins,
                                 const operation& op,
                                 std::vector<instruction_ref> inputs,
                                 common_options options)
{
    auto common_inputs = insert_common_args(m, ins, std::move(inputs), options);
    return m.insert_instruction(ins, op, std::move(common_inputs));
}
/// Same as insert_common_op, with the new instructions placed at the end of the module.
instruction_ref add_common_op(module& m,
                              const operation& op,
                              std::vector<instruction_ref> inputs,
                              common_options options)
{
    auto at_end = m.end();
    return insert_common_op(m, at_end, op, std::move(inputs), options);
}
/// Build the shape of `input_shape` viewed with lens `bcast_lens`.
///
/// The input's dimensions are right-aligned against `bcast_lens`. Where the aligned lens
/// are equal the input's stride is kept; every other stride (including the extra leading
/// dimensions) is 0. `input_shape` must be static (asserted).
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
{
    assert(not input_shape.dynamic());
    auto offset = bcast_lens.size() - input_shape.ndim();
    std::vector<std::size_t> bcast_strides(bcast_lens.size(), 0);
    // Walk the input dimensions from last to first
    for(std::size_t remaining = input_shape.ndim(); remaining > 0; --remaining)
    {
        const std::size_t src = remaining - 1;
        const std::size_t dst = src + offset;
        if(bcast_lens.at(dst) != input_shape.lens()[src])
            continue;
        bcast_strides.at(dst) = input_shape.strides()[src];
    }
    return shape{input_shape.type(), bcast_lens, bcast_strides};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,203 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Finds the end of the leading run of values in [start, last) used to cover `dim`.
//
// The values are multiplied in order; scanning stops at the first element that
// pushes the running product above `dim`. The result is an iterator to that
// element, or `last` when no element does. If the product accumulated so far
// is still below `dim`, `start` is returned instead, which yields an empty run.
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
    std::size_t product = 1;
    Iterator stop       = start;
    while(stop != last)
    {
        product *= *stop;
        if(product > dim)
            break;
        ++stop;
    }
    if(product < dim)
        return start;
    return stop;
}
// Cursor over one of the two dimension lists handled by common_dims::compute.
//
// Holds the current position in `dims`, the divisor `rem` applied to the
// current dimension by get(), and a pointer to the axes map that receives one
// or more entries each time dimensions are consumed.
struct common_dim_state
{
    common_dim_state(const std::vector<std::size_t>& input_dims,
                     std::vector<std::vector<std::size_t>>& output_axes_map)
        : dims(&input_dims), axes_map(&output_axes_map), it(dims->begin())
    {
    }
    const std::vector<std::size_t>* dims            = nullptr;
    std::vector<std::vector<std::size_t>>* axes_map = nullptr;
    std::vector<std::size_t>::const_iterator it{};
    std::size_t rem = 1;

    // Current dimension divided by `rem`
    std::size_t get() const
    {
        return *it / rem;
    }

    // True once every dimension has been consumed
    bool is_end() const
    {
        return it == dims->end();
    }

    // Moves the cursor forward by `i` dimensions
    void next(std::size_t i = 1)
    {
        it += i;
    }

    // Range of dimensions, starting at the cursor, selected by compute_end_dim for `d`
    auto dims_for(std::size_t d) const
    {
        return range(it, compute_end_dim(it, dims->end(), d));
    }

    // Appends a single axes-map entry holding all computed axes
    void add_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
    {
        axes_map->push_back(compute_axes(naxes, start));
    }

    // Appends one single-axis entry to the axes map per computed axis
    void add_multi_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
    {
        for(auto axis : compute_axes(naxes, start))
            axes_map->push_back(std::vector<std::size_t>{axis});
    }

    // Consecutive axis indices beginning at `start`. When `rem` is not 1 the
    // list is extended by one axis at the front (start - 1).
    std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
    {
        const bool has_remainder = rem != 1;
        assert(not has_remainder or start > 0);
        std::vector<std::size_t> axes(has_remainder ? naxes + 1 : naxes);
        std::iota(axes.begin(), axes.end(), has_remainder ? start - 1 : start);
        return axes;
    }
};
// Consumes one group of dimensions when the current dimension of state1 is
// smaller than the current dimension of state2.
//
// The run of state1 dimensions selected by dims_for(d2) is appended to cd_dims,
// followed by the leftover factor d2 / n when that factor is not 1. Both axes
// maps are extended and both cursors advanced. Returns false, before modifying
// anything, when no state1 dimension is selected or when the product of the
// selected dimensions does not divide d2.
static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
common_dim_state& state1,
common_dim_state& state2)
{
assert(state1.get() < state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
// Product of the selected state1 dimensions and how many there are
auto n = elements(dims);
auto naxes = distance(dims);
if(naxes == 0)
return false;
// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return false;
auto rem = d2 / n;
// The axes are recorded before `rem` is updated below, so compute_axes still
// sees the values of `rem` left by the previous step.
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size());
state1.rem = rem;
state2.rem = 1;
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
// A leftover factor becomes an extra common dimension of its own
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state2.next();
return true;
}
// Computes a dimension list shared by `dims1` and `dims2`, together with the
// axes maps that relate each input to it.
//
// Both inputs must have the same, non-zero number of elements. A
// default-constructed common_dims is returned when compute_common_dim reports
// that no common dimension exists for some pair of dimensions.
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
                                 const std::vector<std::size_t>& dims2)
{
    assert(elements(dims1) > 0);
    assert(elements(dims1) == elements(dims2));
    common_dims result;
    common_dim_state lhs{dims1, result.axes_map1};
    common_dim_state rhs{dims2, result.axes_map2};
    while(not(lhs.is_end() or rhs.is_end()))
    {
        const auto a = lhs.get();
        const auto b = rhs.get();
        if(a != b)
        {
            // The side with the smaller current dimension is always passed first
            const bool ok = a < b ? compute_common_dim(result.dims, lhs, rhs)
                                  : compute_common_dim(result.dims, rhs, lhs);
            if(not ok)
                return {};
            continue;
        }
        // Equal dimensions map one-to-one onto a single common axis
        const auto axis = result.dims.size();
        lhs.add_axes(1, axis);
        rhs.add_axes(1, axis);
        lhs.rem = 1;
        rhs.rem = 1;
        result.dims.push_back(a);
        lhs.next();
        rhs.next();
    }
    assert(elements(dims1) == elements(result.dims));
    return result;
}
// Returns the axes map that has exactly `n` entries, preferring axes_map1 when
// both match, or nullptr when neither does.
const std::vector<std::vector<std::size_t>>* common_dims::get_axes_map(std::size_t n) const
{
    const bool first_matches  = axes_map1.size() == n;
    const bool second_matches = axes_map2.size() == n;
    if(first_matches)
        return &axes_map1;
    return second_matches ? &axes_map2 : nullptr;
}
std::vector<std::size_t>
common_dims::get_dimensions_for(const std::vector<std::size_t>& idims) const
{
if(dims.size() == idims.size())
return idims;
if(elements(dims) == elements(idims))
return dims;
// Bail for now since its ambiguous which axes map can be used
// TODO: Check for similiarity
if(axes_map1.size() == axes_map2.size())
return {};
const auto* axes_map = get_axes_map(idims.size());
if(axes_map == nullptr)
return {};
auto xdims = dims;
for(auto i : range(axes_map->size()))
{
auto dim = idims[i];
const auto& axes = (*axes_map)[i];
if(axes.size() == 1)
{
xdims[axes.front()] = dim;
}
else if(dim == 1)
{
for(auto axis : axes)
xdims[axis] = 1;
}
}
if(elements(xdims) == elements(idims))
return xdims;
return {};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,76 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/compile_src.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/fileutils.hpp>
#include <vector>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Writes `srcs` into a temporary directory, runs the compiler there and returns
// the contents of the produced output file.
//
// Every source whose extension is ".cpp" is added to the command line. When no
// output name is configured, it is derived from the stem of the first ".cpp"
// source plus `out_ext`. If a launcher is configured, the launcher is executed
// and the compiler becomes its first argument. Throws when the expected output
// file does not exist after execution.
std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
    assert(not srcs.empty());
    tmp_dir td{"compile"};
    std::vector<std::string> params{flags};
    params.emplace_back("-I.");
    auto out_name = output;
    for(const auto& src : srcs)
    {
        const fs::path dest = td.path / src.path;
        fs::create_directories(dest.parent_path());
        write_buffer(dest, src.content.data(), src.content.size());
        if(src.path.extension().string() != ".cpp")
            continue;
        params.emplace_back(src.path.filename().string());
        if(out_name.empty())
            out_name = src.path.stem().string() + out_ext;
    }
    params.emplace_back("-o " + out_name);

    const bool use_launcher = not launcher.empty();
    std::vector<std::string> args;
    if(use_launcher)
        args.push_back(compiler.string());
    args.insert(args.end(), params.begin(), params.end());
    td.execute(use_launcher ? launcher : compiler, args);

    const auto produced = td.path / out_name;
    if(not fs::exists(produced))
        MIGRAPHX_THROW("Output file missing: " + out_name);
    return read_buffer(produced);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,105 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <string>
#include <sstream>
#include <migraphx/errors.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/lexing.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Splits `s` into tokens for convert_to_json.
//
// The lexers are registered in this order: double-quoted strings (with
// backslash escapes), line comments starting with '#' or "//", runs of
// whitespace, single punctuation characters, and identifier/number runs made of
// alphanumerics, '_', '.' and '+'. The returned views refer to the storage of
// `s`.
std::vector<std::string_view> json_tokenize(const std::string& s)
{
    std::vector<lexer> lexers;
    // Quote
    lexers.push_back([](const char* start, const char* end) {
        if(*start != '\"')
            return start;
        ++start;
        while((start != end) and (*start != '\"'))
        {
            // Skip the escaped character, but never step over the end of the
            // buffer when the input finishes with a lone backslash
            if(*start == '\\' and (start + 1) != end)
                start++;
            start++;
        }
        // Unterminated string: the token stops at the end of the input rather
        // than one past it
        if(start == end)
            return end;
        return ++start;
    });
    // Line comments
    lexers.push_back([](const char* start, const char* end) {
        if(*start == '#')
            start++;
        else if((start + 1) < end and start[0] == '/' and start[1] == '/')
            start += 2;
        else
            return start;
        return std::find_if(start, end, [&](char c) { return c == '\n'; });
    });
    // Whitespace
    lexers.push_back(lex_while(&isspace));
    // Punctuation
    lexers.push_back(lex_if(&ispunct));
    // Identifier/number
    lexers.push_back(lex_while([](char c) {
        // The character classification functions require a value representable
        // as unsigned char
        return (isalnum(static_cast<unsigned char>(c)) != 0 or contains({'_', '.', '+'}, c));
    }));
    return tokenize(s.data(), s.data() + s.length(), lexers);
}
// Rewrites `str` into JSON text.
//
// Comment tokens are dropped. A token that starts with a letter is wrapped in
// double quotes unless it is one of null, nan, true, false or inf. Every other
// token is copied through unchanged.
std::string convert_to_json(const std::string& str)
{
    std::stringstream out;
    const auto tokens = json_tokenize(str);
    for(const auto& tok : tokens)
    {
        const std::string text(tok);
        if(starts_with(text, "#") or starts_with(text, "//"))
            continue;
        const bool needs_quotes = std::isalpha(text.front()) != 0 and
                                  not contains({"null", "nan", "true", "false", "inf"}, text);
        if(needs_quotes)
            out << "\"";
        out << text;
        if(needs_quotes)
            out << "\"";
    }
    return out.str();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,284 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/iterator_for.hpp>
#include <map>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Generates the function body from module `m`.
//
// Each instruction contributes a comment line with its operator and shape.
// Parameters are only given a name (their C identifier). Every other
// instruction becomes `auto <name> = <g(ins, names)>;`, named "zzreturn" for
// @return and "zz<N>" otherwise, where N is the number of names assigned so
// far. The body ends by returning the @return value, or the last instruction
// when the module has no @return.
cpp_generator::function&
cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g)
{
    const std::string prefix = "zz";
    std::unordered_map<migraphx::instruction_ref, std::string> names;
    std::stringstream code;
    // Used as the result when no @return instruction is encountered
    auto result_ins = std::prev(m.end());
    for(auto ins : iterator_for(m))
    {
        code << "// " << ins->get_operator() << " -> " << ins->get_shape() << "\n";
        if(ins->name() == "@param")
        {
            names[ins] = to_c_id(
                migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter);
            continue;
        }
        const bool is_return = ins->name() == "@return";
        // The counter is read before the new name is inserted
        const std::string var =
            is_return ? prefix + "return" : prefix + std::to_string(names.size());
        names[ins] = var;
        code << "auto " << var << " = " << g(ins, names) << ";\n";
        if(is_return)
            result_ins = ins;
    }
    code << "return " << names.at(result_ins) << ";\n";
    body = code.str();
    return *this;
}
// Sets parameter and return types from module `m`, spelling each shape as the
// C++ type of its element type.
cpp_generator::function& cpp_generator::function::set_types(const module& m)
{
    const auto element_type_name = [](auto s) { return shape::cpp_type(s.type()); };
    return cpp_generator::function::set_types(m, element_type_name);
}
// Sets parameter and return types from module `m`, using `parse` to turn a
// shape into a type string.
//
// Existing parameters are discarded. The new parameters are ordered by name,
// and the return type is taken from the first output shape of the module.
cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{
    this->params.clear();
    auto pmap = m.get_parameter_shapes();
    // std::map orders the parameters by name
    const std::map<std::string, shape> sorted_params(pmap.begin(), pmap.end());
    for(const auto& entry : sorted_params)
        this->params.push_back(param{entry.first, parse(entry.second)});
    auto output_shapes = m.get_output_shapes();
    assert(not output_shapes.empty());
    this->return_type = parse(output_shapes.front());
    return *this;
}
// Makes the function generic over the parameters of module `m`.
//
// Existing parameters are discarded. Each module parameter, in name order, gets
// the type "T<c-id>" and a matching "class T<c-id>" template parameter is
// appended. The return type becomes "auto".
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
    this->params.clear();
    auto pmap = m.get_parameter_shapes();
    // std::map orders the parameters by name
    const std::map<std::string, shape> sorted_params(pmap.begin(), pmap.end());
    for(const auto& entry : sorted_params)
        this->params.push_back(param{entry.first, "T" + to_c_id(entry.first)});
    for(const auto& entry : sorted_params)
        this->tparams.push_back("class T" + to_c_id(entry.first));
    this->return_type = "auto";
    return *this;
}
// Puts a `(void)<pname>;` statement at the very start of the body.
cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
{
    const std::string statement = "(void)" + pname + ";\n";
    body.insert(0, statement);
    return *this;
}
// Appends a parameter `pname` of type "T<pname>" and the matching
// "class T<pname>" template parameter.
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{
    const std::string type_name      = "T" + pname;
    const std::string template_param = "class T" + pname;
    params.push_back({pname, type_name});
    tparams.push_back(template_param);
    return *this;
}
// State owned by cpp_generator through its `impl` pointer.
struct cpp_generator_impl
{
// Text of every function written by create_function; returned by str()
std::stringstream fs{};
// Incremented on each create_function call; used to name functions without a name
std::size_t function_count = 0;
// Optional rewrite applied to names given as "function:<name>" in point-op code
std::function<std::string(std::string)> fmap = nullptr;
// Optional callback whose result is placed as a call around generated expressions
std::function<std::string(shape)> fresult = nullptr;
// Code registered per operator name; checked before the "point_op" attribute
std::unordered_map<std::string, std::string> point_op_map = {};
// When true, @return emits make_tuple(...) even for a single input
bool always_return_tuple = false;
};
// Creates a generator with a newly allocated, default-initialized implementation.
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
// Move construction is defaulted.
cpp_generator::cpp_generator(cpp_generator&&) noexcept = default;
// Assignment takes `rhs` by value and swaps implementation pointers with it, so
// the previous implementation is released when `rhs` goes out of scope.
cpp_generator& cpp_generator::operator=(cpp_generator rhs)
{
std::swap(impl, rhs.impl);
return *this;
}
// NOTE(review): presumably defined in this file so that cpp_generator_impl is a
// complete type where `impl` is destroyed -- confirm against the header.
cpp_generator::~cpp_generator() noexcept = default;
// Sets the callback applied to "function:<name>" keys in point-op code.
void cpp_generator::fmap(const std::function<std::string(std::string)>& f)
{
    impl->fmap = f;
}

// Sets the callback whose result is placed around generated expressions.
void cpp_generator::fresult(const std::function<std::string(shape)>& f)
{
    impl->fresult = f;
}

// Selects whether @return always produces a make_tuple call.
void cpp_generator::always_return_tuple(bool b)
{
    impl->always_return_tuple = b;
}

// Registers (or replaces) the point-op code used for operator `op_name`.
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
    auto& table    = impl->point_op_map;
    table[op_name] = code;
}
// Produces the code for a single pointwise operator.
//
// The code template comes from the registered point-op map when it has an
// entry for the operator, otherwise from the operator's "point_op" attribute
// (throws if that is missing). Each placeholder key in the template is then
// resolved in this order:
//  - "function:<name>": the name, passed through fmap when one is set;
//  - a key starting with a digit: the argument at that index (throws when out of
//    range);
//  - a key present in the operator's value: that value as a string;
//  - anything else: the key itself.
// An empty key throws.
std::string cpp_generator::generate_point_op(const operation& op,
                                             const std::vector<std::string>& args)
{
    auto v = op.to_value();
    std::string code;
    if(not contains(impl->point_op_map, op.name()))
    {
        auto attributes = op.attributes();
        if(not attributes.contains("point_op"))
            MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
        code = attributes["point_op"].to<std::string>();
    }
    else
    {
        code = impl->point_op_map.at(op.name());
    }
    return interpolate_string(code, [&](auto start, auto last) -> std::string {
        auto key = trim({start, last});
        if(key.empty())
            MIGRAPHX_THROW("Empty parameter");
        std::string fselector = "function:";
        if(starts_with(key, fselector))
        {
            auto fname = key.substr(fselector.size());
            if(impl->fmap != nullptr)
                return impl->fmap(fname);
            return fname;
        }
        if(with_char(::isdigit)(key[0]))
        {
            auto index = std::stoul(key);
            if(index >= args.size())
                MIGRAPHX_THROW("Invalid argument index: " + key);
            return args.at(index);
        }
        if(v.contains(key))
            return v[key].template to<std::string>();
        return key;
    });
}
// Returns all code written so far by create_function.
std::string cpp_generator::str() const
{
    return impl->fs.str();
}
// Builds a function from module `m`, using `g` to generate the expression for
// every instruction that is neither a literal nor @return.
//
// The function is named after the module (as a C identifier) and its types are
// set from the module before the body is generated.
cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g)
{
function f;
f.set_name(to_c_id(m.name()))
.set_types(m)
.set_body(m, [&](instruction_ref ins, const auto& names) -> std::string {
if(ins->name() == "@literal")
{
// Literals are emitted as `<cpp type>(<value>)`; only single-element
// literals are expected here (see the assert below).
std::string string_literal;
ins->get_literal().visit([&](auto v) {
assert(v.size() == 1);
auto x = v.front();
// Infinities and NaN are spelled with compiler builtins instead of the
// literal's own string form
if(std::isinf(static_cast<double>(x)))
{
string_literal = "__builtin_huge_val()";
if(x < 0)
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(static_cast<double>(x)))
string_literal = "__builtin_nan(\"0\")";
else
string_literal = ins->get_literal().to_string();
});
return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")";
}
if(ins->name() == "@return")
{
// TODO: Customize the make_tuple call
// A single returned value is passed through by name unless tuples are forced
if(impl->always_return_tuple or ins->inputs().size() != 1)
return "make_tuple(" + join_strings(to_args(ins->inputs(), names), ", ") + ")";
return names.at(ins->inputs().front());
}
auto s = g(ins, names);
// When fresult is set, its result for the instruction's shape is placed as a
// call around the generated expression
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
});
return f;
}
std::vector<std::string>
cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
const std::unordered_map<instruction_ref, std::string>& names)
{
std::vector<std::string> args;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) {
return names.at(i);
});
return args;
}
// Builds a function from module `m` in which every instruction is generated
// as a point op applied to the names of its inputs.
cpp_generator::function cpp_generator::generate_module(const module& m)
{
    const auto as_point_op = [this](auto ins, const auto& names) {
        return this->generate_point_op(ins->get_operator(), to_args(ins->inputs(), names));
    };
    return this->generate_module(m, as_point_op);
}
// Writes the definition of `f` to the generator's output and returns the name
// it was emitted under.
//
// A function without a name is called "f<N>", where N is the running function
// count. A template header is written first when `f` has template parameters.
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
    auto& out = impl->fs;
    impl->function_count++;
    if(not f.tparams.empty())
        out << "template<" << join_strings(f.tparams, ", ") << ">\n";
    std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
    out << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
    if(f.params.empty())
    {
        out << '(';
    }
    else
    {
        // The opening parenthesis doubles as the separator of the first parameter
        char sep = '(';
        for(auto&& p : f.params)
        {
            out << sep << p.type << " " << to_c_id(p.name);
            sep = ',';
        }
    }
    out << ") {\n" << f.body << "\n}\n";
    return name;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,86 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Program-level pass: removes the modules of `p` that are not used.
void dead_code_elimination::apply(program& p) const
{
    p.remove_unused_modules();
}
// Module-level pass: removes instructions whose results are not used.
//
// Instructions without outputs have their arguments cleared and are moved
// behind `last` (the final instruction on entry); their former arguments are
// then checked recursively. Everything that ended up behind `last` is erased at
// the end. "@param" instructions are never moved, so they are never erased.
void dead_code_elimination::apply(module& m) const
{
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// Skip the first instruction, since we always process the previous
// instruction
// NOTE(review): presumably this keeps the loop iterator valid while `i` is
// moved to the end of the module -- confirm.
if(ins == m.begin())
continue;
const auto i = std::prev(ins);
// Skip the last instruction
if(i == last)
break;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate or tuple_type]
if((not i->get_shape().dynamic() and
(i->get_shape().elements() == 0 and
i->get_shape().type() != migraphx::shape::tuple_type)) and
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined())
continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
// Recursive walk: a leaf with no outputs is dead, and removing it may make its
// inputs dead as well
fix([&](auto self, auto leaf) {
if(not m.has_instruction(leaf))
return;
if(leaf->outputs().empty())
{
// Don't visit inputs twice
if(not visited.insert(leaf).second)
return;
// Copy the inputs before clear_arguments() drops them from the instruction
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments();
assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins);
// Dead instructions are collected at the end of the module, behind `last`
if(leaf->name() != "@param")
m.move_instruction(leaf, m.end());
for(auto arg : args)
self(arg);
}
})(i);
}
// Erase everything that was moved behind the original last instruction
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,100 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Returns true when `ins1` appears on the immediate-dominator chain of `ins2`.
// An instruction never strictly dominates itself.
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) const
{
    if(ins1 == ins2)
        return false;
    // Follow ins2 -> idom(ins2) -> idom(idom(ins2)) ... until the chain ends
    for(auto pos = ins2idom.find(ins2); pos != ins2idom.end(); pos = ins2idom.find(pos->second))
    {
        if(pos->second == ins1)
            return true;
        // An instruction recorded as its own immediate dominator would loop forever
        assert(pos != ins2idom.find(pos->second));
    }
    return false;
}
// Adapter handed to compute_dominator_generic: the nodes are the module's
// instructions and the children of a node are its inputs.
struct module_visitor
{
    const module* mm;

    // The module whose instructions are iterated
    const module& get_nodes() const
    {
        return *mm;
    }

    // The inputs of `ins`
    const std::vector<instruction_ref>& get_children(instruction_ref ins)
    {
        return ins->inputs();
    }
};
// Computes the immediate dominator of every node supplied by visitor `v`.
//
// For each node a set is kept (instr2_doms) that holds the node itself plus the
// intersection of the sets of its children. A node with one child gets that
// child as immediate dominator; a node with several children gets the member of
// the intersected set that strictly dominates no other member. Nodes without
// children get no entry in ins2idom.
// NOTE(review): this assumes the nodes are visited so that children are
// processed before the nodes that use them -- confirm for new visitors.
template <class Visitor>
dominator_info compute_dominator_generic(Visitor v)
{
dominator_info info;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> instr2_doms;
for(instruction_ref ins : iterator_for(v.get_nodes()))
{
const std::vector<instruction_ref>& children = v.get_children(ins);
if(children.size() == 1)
{
// Single child: it is the immediate dominator, and its set is inherited
info.ins2idom[ins] = children.front();
instr2_doms[ins] = instr2_doms[children.front()];
}
else if(children.size() > 1)
{
// Start from the first child's set and keep only what every other child shares
auto&& doms = instr2_doms[ins];
doms = instr2_doms[children.front()];
std::for_each(children.begin() + 1, children.end(), [&](instruction_ref child) {
auto&& child_doms = instr2_doms[child];
erase_if(doms, [&](auto x) { return not contains(child_doms, x); });
});
// Pick the shared dominator that does not strictly dominate any other one
auto iter = std::find_if(doms.begin(), doms.end(), [&](auto dom1) {
return std::none_of(doms.begin(), doms.end(), [&](auto dom2) {
if(dom1 == dom2)
return false;
return info.strictly_dominate(dom1, dom2);
});
});
if(iter != doms.end())
info.ins2idom[ins] = *iter;
}
// Every node is a member of its own set
instr2_doms[ins].insert(ins);
}
return info;
}
// Computes dominator information for module `m`, treating the inputs of each
// instruction as its children.
dominator_info compute_dominator(const module& m)
{
    const module_visitor visitor{&m};
    return compute_dominator_generic(visitor);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,203 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/manage_ptr.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
#include <utility>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <dlfcn.h>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifndef _WIN32
// Reads the pending dlerror() message and throws with it when there is one.
// With `flush` set, the message is read but never reported.
void check_load_error(bool flush = false)
{
    // dlerror() is called unconditionally so that flushing also consumes the message
    const char* message = dlerror();
    if(flush or message == nullptr)
        return;
    MIGRAPHX_THROW("Dynamic loading or symbol lookup failed with " + std::string(message));
}
// Owns a library handle opened with dlopen and, optionally, the temporary
// directory that holds the library file.
struct dynamic_loader_impl
{
dynamic_loader_impl() = default;
// GCC warns about attributes being ignored on the dlclose function type used
// as a template argument below; the warning is silenced for this constructor only.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
// Opens the library at `p` with RTLD_GLOBAL | RTLD_NOW; the handle is closed
// with dlclose through the deleter. Throws when dlerror() reports a message
// after the dlopen call. `t` is kept alive alongside the handle.
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t))
{
check_load_error();
}
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
// Writes `image` to "libtmp.so" inside a new temporary directory and loads it
// from there; the directory is handed to the loader so both share a lifetime.
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{
auto t = std::make_shared<tmp_dir>("dloader");
auto f = t->path / "libtmp.so";
write_buffer(f, image, size);
return std::make_shared<dynamic_loader_impl>(f, t);
}
std::shared_ptr<void> handle = nullptr;
std::shared_ptr<tmp_dir> temp = nullptr;
};
// Returns the file name of the shared object that contains `address`, or an
// empty path when dladdr cannot resolve it.
fs::path dynamic_loader::path(void* address)
{
    Dl_info info;
    // dladdr reports failure with 0
    if(dladdr(address, &info) == 0)
        return fs::path();
    return fs::path(info.dli_fname);
}
#else
struct dynamic_loader_impl
{
dynamic_loader_impl() = default;
dynamic_loader_impl(const fs::path& p, tmp_dir t = {})
: handle{LoadLibrary(p.string().c_str())}, temp{std::move(t)}
{
if(handle == nullptr)
{
MIGRAPHX_THROW("Error loading DLL: " + p.string() + " (" +
std::to_string(GetLastError()) + ")");
}
}
dynamic_loader_impl(const dynamic_loader_impl&) = delete;
dynamic_loader_impl& operator=(const dynamic_loader_impl&) = delete;
dynamic_loader_impl(dynamic_loader_impl&&) = default;
~dynamic_loader_impl()
{
if(handle != nullptr)
{
FreeLibrary(handle);
}
}
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{
auto t = tmp_dir{"migx-dynload"};
auto f = t.path / "tmp.dll";
write_buffer(f, image, size);
return std::make_shared<dynamic_loader_impl>(f, std::move(t));
}
HMODULE handle = nullptr;
tmp_dir temp;
};
// Returns the file path of the module that contains `address`.
//
// Throws when the module handle cannot be obtained, when the file name cannot
// be read, or when the path does not fit into a MAX_PATH sized buffer.
fs::path dynamic_loader::path(void* address)
{
    HMODULE module = nullptr;
    if(GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
                             GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
                         static_cast<LPCSTR>(address),
                         &module) == 0)
    {
        auto err = GetLastError();
        MIGRAPHX_THROW("Unable to obtain module handle, error = " + std::to_string(err));
    }
    TCHAR buffer[MAX_PATH];
    // GetModuleFileName expects the buffer size in characters, not in bytes
    const DWORD capacity = sizeof(buffer) / sizeof(buffer[0]);
    const DWORD length   = GetModuleFileName(module, buffer, capacity);
    if(length == 0)
    {
        auto err = GetLastError();
        MIGRAPHX_THROW("Unable to read module file path, error = " + std::to_string(err));
    }
    // A return value equal to the buffer size signals truncation. The return
    // value is used instead of GetLastError(), which is only meaningful after a
    // failure and may still hold a code left by an earlier call.
    if(length >= capacity)
    {
        MIGRAPHX_THROW("Buffer too small (" + std::to_string(MAX_PATH) + ") to hold the path");
    }
    return {buffer};
}
#endif
// Non-throwing variant of loading a library: returns the loader for `p`, or
// nullopt when constructing it throws any std::exception.
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{
try
{
return dynamic_loader{p};
}
// The reason for the failure is intentionally discarded
catch(const std::exception&)
{
return nullopt;
}
}
// Loads the library stored at path `p`.
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
{
}
// Loads a library from an in-memory image of `size` bytes; the image is first
// written to a temporary file by dynamic_loader_impl::from_buffer.
dynamic_loader::dynamic_loader(const char* image, std::size_t size)
: impl(dynamic_loader_impl::from_buffer(image, size))
{
}
// Same as the (image, size) overload, taking the image from `buffer`.
dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
: impl(dynamic_loader_impl::from_buffer(buffer.data(), buffer.size()))
{
}
// Looks up symbol `name` in the loaded library.
//
// The returned pointer shares ownership with `impl` (shared_ptr aliasing
// constructor), so the library stays loaded for as long as the symbol pointer
// is alive. Throws when the lookup fails.
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{
#ifndef _WIN32
// flush any previous error messages
check_load_error(true);
void* symbol = dlsym(impl->handle.get(), name.c_str());
// A null result is only treated as an error when dlerror() has a message
if(symbol == nullptr)
check_load_error();
return {impl, symbol};
#else
FARPROC addr = GetProcAddress(impl->handle, name.c_str());
if(addr == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name + " (" + std::to_string(GetLastError()) + ")");
return {impl, reinterpret_cast<void*>(addr)};
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,66 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Replaces every `allocation_op` instruction in `m` with a "load" from a single
// "memory" parameter.
//
// Each allocation is assigned an offset in module order; its size is rounded up
// to a multiple of `alignment` before the next offset is computed. The "memory"
// parameter is an int8 buffer covering the total. Nothing is changed when the
// total size is zero.
void eliminate_allocation::apply(module& m) const
{
    assert(alignment > 0);
    std::vector<std::pair<instruction_ref, std::size_t>> placements;
    std::size_t total = 0;
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == allocation_op)
        {
            placements.emplace_back(ins, total);
            const std::size_t bytes = ins->get_shape().bytes();
            // Padding needed to reach the next multiple of `alignment` (0 if already aligned)
            const std::size_t pad = (alignment - (bytes % alignment)) % alignment;
            total += bytes + pad;
        }
    }
    if(total == 0)
        return;
    auto mem = m.add_parameter("memory", shape{shape::int8_type, {total}});
    for(auto&& entry : placements)
    {
        auto target = entry.first;
        auto s      = target->get_shape();
        auto offset = entry.second;
        m.replace_instruction(
            target, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,76 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/functional.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Common-subexpression elimination over the instructions yielded by `r`.
// Every visited instruction is recorded under its name; a later instruction
// that compares equal to a recorded one is replaced by it, and the surviving
// instruction's outputs are then re-examined recursively.
template <class Range>
void cse_range(module& m, Range&& r)
{
// Previously visited instructions, keyed by operator name
std::unordered_multimap<std::string, instruction_ref> instructions;
// Instructions that have already been replaced by an equivalent one
std::unordered_set<instruction_ref> processed_ins;
for(auto ins : r)
{
// Skip dead instructions
if(ins->outputs().empty())
continue;
// Find instruction with the same name
auto found_instructions = range(instructions.equal_range(ins->name()));
for(const auto& pp : found_instructions)
{
auto eq = pp.second;
// A candidate that was itself replaced earlier is not reused
if(contains(processed_ins, eq))
continue;
// Only candidates that compare equal to ins qualify
if(*eq != *ins)
continue;
m.replace_instruction(ins, eq);
processed_ins.emplace(ins);
// Collect the outputs of eq that belong to this module
std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(),
eq->outputs().end(),
std::back_inserter(outputs),
[&](auto x) { return m.has_instruction(x); });
// Order the outputs by their distance from eq before recursing
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y);
});
cse_range(m, outputs);
}
// Record ins so that later instructions with the same name can find it
instructions.emplace(ins->name(), ins);
}
}
// Pass entry point: run common-subexpression elimination over every
// instruction of the module.
void eliminate_common_subexpression::apply(module& m) const
{
    cse_range(m, iterator_for(m));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,110 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iterator>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Rewrites eligible concat instructions so that their inputs write directly
// into one shared allocation: the concat's own output allocation is moved in
// front of the earliest input allocation, each input allocation becomes a
// "load" at a running byte offset into it, and the concat itself is replaced
// by an "identity".
void eliminate_concat::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
auto concat_op = concat_opt.get_concat(ins->get_operator());
// Look for the concat operator
if(not concat_op.has_value())
continue;
// If any inputs are builtin or context free then abort
// If any inputs are used more than once, then abort since there could
// be errors due to aliasing
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto arg) {
return arg->name().front() == '@' or
(arg->get_operator().is_context_free() and
not contains({"concat", "identity"}, arg->name())) or
arg->outputs().size() > 1;
}))
continue;
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens();
std::size_t axis_index = tune_axis(lens.size(), concat_op->axis, concat_op->name());
if(axis_index == 0 or
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
auto last = ins->inputs().back();
if(last->name() != concat_opt.allocate())
continue;
// Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations;
std::transform(
ins->inputs().begin(),
std::prev(ins->inputs().end()),
std::back_inserter(allocations),
[&](instruction_ref x) { return instruction::get_output_alias(x, true); });
// Every aliased output must itself be an allocation, otherwise give up
if(std::any_of(allocations.begin(), allocations.end(), [&](auto x) {
return x->name() != concat_opt.allocate();
}))
continue;
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
// NOTE(review): front() below assumes at least one non-allocation input - confirm
auto sorted_allocations = allocations;
std::sort(sorted_allocations.begin(),
sorted_allocations.end(),
[&](instruction_ref x, instruction_ref y) {
return std::distance(m.begin(), x) < std::distance(m.begin(), y);
});
// Move "super" allocation to the front
auto first = sorted_allocations.front();
auto super = m.move_instruction(last, first);
// Replace each allocation with a load
// Offsets follow the original input order, not the sorted order
std::size_t offset = 0;
for(auto alloc : allocations)
{
op::load op{alloc->get_shape(), offset};
m.replace_instruction(alloc, op, {super});
offset += alloc->get_shape().bytes();
}
// The identity takes the super allocation first, then the original
// inputs without the trailing allocation
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
m.replace_instruction(ins, migraphx::make_op("identity"), args);
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,195 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS)
// Checks whether `ins` can be computed from the given input shapes, following
// its consumers when the resulting shape is non-standard and differs from the
// current one. Returns false for dynamic results, for a changed non-standard
// result with no consumers, when any consumer fails the same check, or when
// shape computation throws.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<shape>& inputs,
                              const std::vector<module_ref>& mods)
{
    try
    {
        const shape candidate = ins->get_operator().compute_shape(inputs, mods);
        // Cannot tell if a dynamic shape will need to be made contiguous
        if(candidate.dynamic())
            return false;
        // A standard shape, or an unchanged shape, needs no further checking
        if(candidate.standard() or candidate == ins->get_shape())
            return true;
        const auto consumers = ins->outputs();
        // No consumers: this instruction produces a non-standard shape that
        // differs from the one it had with the contiguous operator in place
        if(consumers.empty())
            return false;
        // Every consumer must accept the new shape in place of the old one
        return std::all_of(consumers.begin(), consumers.end(), [&](auto consumer) {
            const auto consumer_args = consumer->inputs();
            std::vector<shape> consumer_shapes(consumer_args.size());
            std::transform(consumer_args.begin(),
                           consumer_args.end(),
                           consumer_shapes.begin(),
                           [&](auto& arg) {
                               return (arg == ins) ? candidate : arg->get_shape();
                           });
            return try_compute_shape(consumer, consumer_shapes, consumer->module_inputs());
        });
    }
    catch(const std::exception& e)
    {
        if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
        {
            std::cout << "Exception: " << e.what() << std::endl;
        }
        return false;
    }
    catch(...)
    {
        if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
        {
            std::cout << "Unknown exception" << std::endl;
        }
        return false;
    }
}
// Convenience overload: takes the candidate arguments as instructions,
// converts them to shapes and defers to the shape-based overload.
static bool try_compute_shape(instruction_ref ins,
                              const std::vector<instruction_ref>& args,
                              const std::vector<module_ref>& mods)
{
    return try_compute_shape(ins, to_shapes(args), mods);
}
// Tries to bypass every `op_name` instruction that feeds an instruction
// accepted by the predicate `f`.
//
// For each accepted instruction, every argument named `op_name` is
// tentatively replaced by that argument's own input. If the shapes still
// compute (see try_compute_shape) the replacement is made permanent.
// Otherwise, when the argument's input can be evaluated, the `op_name`
// instruction is queued and later replaced by a literal holding the
// evaluated contiguous result.
template <class F>
static void remove_contiguous(const std::string& op_name, module& m, F f)
{
    auto last = std::prev(m.end());
    // `op_name` instructions to be folded into literals after the scan
    std::vector<instruction_ref> const_instructions;
    for(auto ins : iterator_for(m))
    {
        // return instruction should have inputs with standard shape
        if(ins->name() == "@return")
            continue;
        // Skip instructions without consumers, except the final instruction
        if(ins != last and ins->outputs().empty())
            continue;
        if(not f(ins))
            continue;
        auto mod_args = ins->module_inputs();
        for(auto arg : ins->inputs())
        {
            if(arg->name() != op_name)
                continue;
            if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
            {
                std::cout << "eliminate_contiguous: ";
                m.debug_print(ins);
            }
            auto prev = arg->inputs().front();
            // create copy of args each time as they are modified inside the loop
            auto new_args = ins->inputs();
            replace(new_args, arg, prev);
            if(try_compute_shape(ins, new_args, mod_args))
            {
                instruction::replace_argument(ins, arg, prev);
            }
            else if(prev->can_eval())
            {
                const_instructions.push_back(arg);
            }
        }
    }
    // Perform static contiguous evaluations in parallel
    std::vector<argument> literals(const_instructions.size());
    par_for(const_instructions.size(), 1, [&](const auto i) {
        auto c    = op::contiguous{};
        auto prev = const_instructions[i]->inputs().front();
        // compute the output contiguous shape from the previous instruction shape
        shape computed_shape                   = c.compute_shape({prev->get_shape()});
        const std::vector<argument>& prev_eval = {prev->eval()};
        // prev_eval should not be used in make_compute_output_shape() as computed_shape is static
        auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
        literals[i]   = c.compute(co_shape, prev_eval);
    });
    // Replace static contiguous operations with a literal
    for(size_t i = 0; i < const_instructions.size(); i++)
    {
        auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
        m.replace_instruction(const_instructions[i], l);
    }
}
// Runs contiguous removal twice: first leaving out "slice" instructions whose
// input has more than one consumer, then over every instruction.
void eliminate_contiguous::apply(module& m) const
{
    // Skip contiguous from splits first: a slice only qualifies when it is
    // the sole consumer of its input
    const auto not_a_split = [](auto ins) {
        return ins->name() != "slice" or ins->inputs().front()->outputs().size() == 1;
    };
    remove_contiguous(op_name, m, not_a_split);
    // Second sweep without any filtering
    const auto accept_all = [](auto) { return true; };
    remove_contiguous(op_name, m, accept_all);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,81 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_convert.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
 * Collapses a chain of back-to-back convert instructions.
 * The match is a convert whose first argument is also a convert; the chain is
 * then followed upwards to the first instruction that is not a convert.
 * When that instruction already has the shape of the matched convert, the
 * matched convert is replaced by it directly. Otherwise the matched convert
 * is re-applied straight to that instruction, skipping the intermediate ones.
 */
struct find_nested_convert
{
    auto matcher() const
    {
        auto inner_convert = match::arg(0)(match::name("convert"));
        return match::name("convert")(inner_convert);
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto outer = mr.result;
        // Start above the inner convert and climb past any further converts
        auto source = outer->inputs().front()->inputs().front();
        while(source->name() == "convert")
            source = source->inputs().front();

        if(outer->get_shape() == source->get_shape())
        {
            // The whole chain is a no-op
            m.replace_instruction(outer, source);
            return;
        }
        m.replace_instruction(outer, outer->get_operator(), source);
    }
};
// Removes a convert whose shape is the same as that of its first argument by
// replacing it with that argument.
struct find_nop_converts
{
    auto matcher() const
    {
        auto unchanged_shape = match::same_shape(match::arg(0));
        return match::name("convert")(unchanged_shape);
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto redundant = mr.result;
        auto source    = redundant->inputs().front();
        m.replace_instruction(redundant, source);
    }
};
// Pass entry point: applies the nested-convert and no-op-convert rewrites to
// the module.
void eliminate_convert::apply(module& m) const
{
    find_nested_convert nested_converts{};
    find_nop_converts nop_converts{};
    match::find_matches(m, nested_converts, nop_converts);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,127 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Re-creates `ins` so that none of its inputs has a type listed in
// `unsupported_types`: such inputs are converted to `target_type` first, the
// operator is applied to the converted inputs, and the result is converted
// back to the original type. Returns without changes when no input needed a
// conversion.
// Throws (MIGRAPHX_THROW) when `ins` has a tuple output that is consumed by
// anything other than "get_tuple_elem".
void insert_convert_to_supported_type(module& m,
instruction_ref ins,
migraphx::shape::type_t target_type,
std::set<migraphx::shape::type_t> unsupported_types)
{
migraphx::shape::type_t orig_type = ins->get_shape().type();
std::vector<instruction_ref> inputs = ins->inputs();
// Insert a convert in front of ins for every input with an unsupported type
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
if(contains(unsupported_types, i->get_shape().type()))
{
return m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
i);
}
else
{
return i;
}
});
// if no change
if(inputs == ins->inputs())
return;
auto op = ins->get_operator();
auto attributes = op.attributes();
// When the operator names a "general_data_type" operator, use that operator
// instead, built from the same operator value
if(attributes.contains("general_data_type"))
{
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
}
auto new_ins = m.insert_instruction(ins, op, inputs);
if(orig_type == shape::tuple_type)
{
// A tuple result cannot be converted back directly, so each
// get_tuple_elem consumer is rebuilt on top of new_ins instead
auto orig_outs = ins->outputs();
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
return out_ins->name() == "get_tuple_elem";
}))
MIGRAPHX_THROW(
"eliminate_data_type: Instruction with tuple output doesn't have all its "
"usages as get_tuple_elem instruction");
std::transform(
orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
auto orig_out_type = out_ins->get_shape().type();
// Convert the element back only when its original type was unsupported
if(contains(unsupported_types, orig_out_type))
{
auto gte_convert = m.insert_instruction(
ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
return m.replace_instruction(out_ins, gte_convert);
}
else
{
return m.replace_instruction(out_ins, gte_ins);
}
});
}
else
{
// Convert the result back to the type the original instruction produced
auto convert_back_ins = m.insert_instruction(
ins,
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
new_ins);
m.replace_instruction(ins, convert_back_ins);
}
}
// Inserts type conversions around the instructions selected by
// `unsupported_ops` ("all" selects every instruction). Builtin instructions
// (names starting with '@') are never touched, and the operators in the skip
// list below are only touched when they are named explicitly in
// `unsupported_ops`.
void eliminate_data_type::apply(module& m) const
{
    static const std::vector<std::string> skip_op_names = {"convert",
                                                           "get_tuple_elem",
                                                           "if",
                                                           "loop",
                                                           "roialign",
                                                           "nonmaxsuppression",
                                                           "scatternd_add",
                                                           "scatternd_mul",
                                                           "scatternd_none",
                                                           "select_module"};
    if(unsupported_types.empty())
        return;
    const bool convert_all = contains(unsupported_ops, "all");
    for(auto ins : iterator_for(m))
    {
        const std::string op_name = ins->name();
        if(op_name[0] == '@')
            continue;
        const bool listed = contains(unsupported_ops, op_name);
        // Skip-listed operators are only handled on explicit request
        if(not listed and contains(skip_op_names, op_name))
            continue;
        if(convert_all or listed)
            insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,70 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Removes "identity" instructions from the module. While walking the module,
// each identity found just before the current instruction is replaced by its
// input and moved to the end of the module; everything after `last` is then
// removed in one call at the end.
void eliminate_identity::apply(module& m) const
{
// Marks the last instruction to keep; everything after it is removed below
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == m.begin())
continue;
const auto i = std::prev(ins);
if(i->name() == "identity")
{
// Redirect the identity to its input, then park it after `last`
m.replace_instruction(i, i->inputs().front());
m.move_instruction(i, m.end());
}
if(ins == last)
{
if(ins->name() == "identity")
{
const instruction_ref& identity_input = ins->inputs().front();
// Only done when the identity is the sole consumer of its input
if(identity_input->outputs().size() == 1)
{
// NOTE(review): presumably places the input right before `last` - confirm
m.move_instruction(identity_input, last);
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
last = std::prev(last);
}
}
break;
}
}
// Erase everything after `last`
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,114 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Folds the pad instruction `input` into `ins`: the pad's spatial padding
// becomes the "padding" attribute of ins' operator, and ins is rewired to read
// from the pad's own input. The pads vector is read as
// [lower pads for every dimension..., upper pads for every dimension...], with
// the first two dimensions skipped in each half.
static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto pad_op = any_cast<op::pad>(input->get_operator());
    // Number of dimensions after the leading two
    auto kdims = input->get_shape().lens().size() - 2;

    const auto lower_first = pad_op.pads.begin() + 2;
    const auto upper_first = lower_first + kdims + 2;

    // Layout: [lower pads..., upper pads...]
    std::vector<size_t> padding(kdims * 2, 0);
    std::copy(lower_first, lower_first + kdims, padding.begin());
    std::copy(upper_first, pad_op.pads.end(), padding.begin() + kdims);

    auto op = ins->get_operator();
    op.from_value({{"padding", padding}});

    // Bypass the pad: feed ins from the pad's input
    auto new_inputs    = ins->inputs();
    new_inputs.front() = input->inputs().front();
    m.replace_instruction(ins, op, new_inputs);
}
// Folds the pad instruction `input` into the pooling instruction `ins` by
// adding the pad's spatial padding to the pooling operator's own padding and
// rewiring ins to read from the pad's input. Average pooling is left alone.
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op = any_cast<op::pooling>(ins->get_operator());
    if(op.mode == op::pooling_mode::average)
    {
        return;
    }
    auto pad_op = any_cast<op::pad>(input->get_operator());
    // Number of dimensions after the leading two
    auto kdims       = input->get_shape().lens().size() - 2;
    auto lower_first = pad_op.pads.begin() + 2;
    std::vector<size_t> lower(lower_first, lower_first + kdims);
    std::vector<size_t> upper(lower_first + kdims + 2, pad_op.pads.end());

    // Lower pads accumulate into the first kdims entries, upper pads into
    // the entries that follow
    for(std::size_t i = 0; i < lower.size(); ++i)
        op.padding[i] += lower[i];
    for(std::size_t i = 0; i < upper.size(); ++i)
        op.padding[kdims + i] += upper[i];

    // Bypass the pad: feed ins from the pad's input
    auto new_inputs    = ins->inputs();
    new_inputs.front() = input->inputs().front();
    m.replace_instruction(ins, op, new_inputs);
}
// Folds "pad" instructions into the convolution, im2col or pooling
// instruction that consumes them as first input.
void eliminate_pad::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        const std::string op_name = ins->name();
        const bool conv_like      = op_name == "convolution" or op_name == "im2col";
        const bool pooling        = op_name == "pooling";
        if(not conv_like and not pooling)
            continue;
        auto input = ins->inputs().front();
        // Only a pad directly feeding the first input can be absorbed
        if(input->name() != "pad")
            continue;
        if(conv_like)
            update_op(input, ins, m);
        else
            update_pooling(input, ins, m);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,73 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <cstdlib>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Returns true when the environment variable `name` is set to one of the
// recognised "on" spellings; false when it is unset or has any other value.
bool enabled(const char* name)
{
    const auto values = env(name);
    return not values.empty() and
           contains({"1", "enable", "enabled", "yes", "true"}, values.front());
}
// Returns true when the environment variable `name` is set to one of the
// recognised "off" spellings; false when it is unset or has any other value.
bool disabled(const char* name)
{
    const auto values = env(name);
    return not values.empty() and
           contains({"0", "disable", "disabled", "no", "false"}, values.front());
}
// Parses the environment variable `name` as an unsigned integer with
// std::stoul, or returns `fallback` when the variable is not set.
std::size_t value_of(const char* name, std::size_t fallback)
{
    const auto values = env(name);
    if(not values.empty())
        return std::stoul(values.front());
    return fallback;
}
// Returns the text of the environment variable `name`, or `fallback` when the
// variable is not set.
std::string string_value_of(const char* name, std::string fallback)
{
    const auto values = env(name);
    if(not values.empty())
        return values.front();
    return fallback;
}
// Looks up the environment variable `name` with std::getenv. Returns an empty
// vector when it is not set, otherwise a vector holding its value as the only
// element.
std::vector<std::string> env(const char* name)
{
    std::vector<std::string> values;
    const char* raw = std::getenv(name);
    if(raw != nullptr)
        values.emplace_back(raw);
    return values;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,85 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/file_buffer.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/fileutils.hpp>
#include <fstream>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reads `nbytes` bytes starting at byte `offset` of `filename` into a
// container of type T (constructed as T(count, 0) and filled through
// &buffer[0]).
//
// When `nbytes` is 0 everything from `offset` to the end of the file is read.
// Throws (MIGRAPHX_THROW) when the file cannot be opened, its size cannot be
// determined, `offset` lies beyond the end of the file, there is nothing to
// read, or the read itself fails.
template <class T>
T generic_read_file(const fs::path& filename, size_t offset = 0, size_t nbytes = 0)
{
    // Opened at the end so that tellg() reports the file size
    std::ifstream is(filename, std::ios::binary | std::ios::ate);
    if(not is.is_open())
        MIGRAPHX_THROW("Failure opening file: " + filename);
    if(nbytes == 0)
    {
        // if there is a non-zero offset and nbytes is not set,
        // calculate size of remaining bytes to read
        const std::streamoff file_size = is.tellg();
        // tellg() reports -1 on failure; converting that straight to size_t
        // would request an enormous buffer below
        if(file_size < 0)
            MIGRAPHX_THROW("Failure determining size of file: " + filename);
        nbytes = static_cast<size_t>(file_size);
        if(offset > nbytes)
            MIGRAPHX_THROW("offset is larger than file size");
        nbytes -= offset;
    }
    if(nbytes < 1)
        MIGRAPHX_THROW("Invalid size for: " + filename);
    is.seekg(offset, std::ios::beg);
    T buffer(nbytes, 0);
    if(not is.read(&buffer[0], nbytes))
        MIGRAPHX_THROW("Error reading file: " + filename);
    return buffer;
}
// Reads raw bytes from `filename` into a vector; `offset` and `nbytes` are
// forwarded unchanged to generic_read_file.
std::vector<char> read_buffer(const fs::path& filename, size_t offset, size_t nbytes)
{
    using byte_buffer = std::vector<char>;
    return generic_read_file<byte_buffer>(filename, offset, nbytes);
}
// Reads the whole of `filename` into a string.
std::string read_string(const fs::path& filename)
{
    // Offset 0 and nbytes 0 select the entire file
    return generic_read_file<std::string>(filename, 0, 0);
}
// Writes the contents of `buffer` to `filename` through the pointer/size
// overload of write_buffer.
void write_string(const fs::path& filename, const std::string& buffer)
{
    const char* bytes       = buffer.data();
    const std::size_t count = buffer.size();
    write_buffer(filename, bytes, count);
}
// Writes `size` bytes starting at `buffer` to `filename`, opened for binary
// output. No error is reported when opening or writing fails.
void write_buffer(const fs::path& filename, const char* buffer, std::size_t size)
{
    const auto mode = std::ios::out | std::ios::binary;
    std::ofstream output(filename, mode);
    output.write(buffer, size);
}
// Writes the contents of the byte vector `buffer` to `filename` through the
// pointer/size overload.
void write_buffer(const fs::path& filename, const std::vector<char>& buffer)
{
    const char* bytes       = buffer.data();
    const std::size_t count = buffer.size();
    write_buffer(filename, bytes, count);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,70 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fileutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// File-name affixes used by the make_*_filename helpers below, selected at
// compile time: one set when _WIN32 is defined, another set otherwise.
#ifdef _WIN32
// Suffix appended to executable names
constexpr std::string_view executable_postfix{".exe"};
// Prefix put in front of library names (none here)
constexpr std::string_view library_prefix{""};
// Suffix appended to shared library names
constexpr std::string_view library_postfix{".dll"};
// Suffix appended to static library names
constexpr std::string_view static_library_postfix{".lib"};
// Suffix appended to object file names
constexpr std::string_view object_file_postfix{".obj"};
#else
// Suffix appended to executable names (none here)
constexpr std::string_view executable_postfix{""};
// Prefix put in front of library names
constexpr std::string_view library_prefix{"lib"};
// Suffix appended to shared library names
constexpr std::string_view library_postfix{".so"};
// Suffix appended to static library names
constexpr std::string_view static_library_postfix{".a"};
// Suffix appended to object file names
constexpr std::string_view object_file_postfix{".o"};
#endif
// Returns `name` followed by the platform's executable postfix.
fs::path make_executable_filename(std::string_view name)
{
    std::string filename{name};
    filename += executable_postfix;
    return filename;
}
// Returns the platform's library prefix, then `name`, then the shared-library postfix.
fs::path make_shared_object_filename(std::string_view name)
{
    std::string filename{library_prefix};
    filename += name;
    filename += library_postfix;
    return filename;
}
// Returns `name` followed by the platform's object-file postfix.
fs::path make_object_file_filename(std::string_view name)
{
    std::string filename{name};
    filename += object_file_postfix;
    return filename;
}
// Returns the platform's library prefix, then `name`, then the static-library postfix.
fs::path make_static_library_filename(std::string_view name)
{
    std::string filename{library_prefix};
    filename += name;
    filename += static_library_postfix;
    return filename;
}
// Returns a copy of `path` whose extension is the existing extension with `ext`
// appended to it (the original extension is kept rather than replaced).
fs::path append_extension(const fs::path& path, std::string_view ext)
{
    std::string combined_ext = path.extension().string();
    combined_ext += ext;
    fs::path result{path};
    result.replace_extension(combined_ext);
    return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,178 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fp8_ocp_to_fnuz.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/match/dq_helpers.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
using fp8::fp8e4m3fnuz;
// Names of the operators whose dequantizelinear inputs are considered by this pass.
std::unordered_set<std::string> get_quantizable_op_names()
{
    return {"convolution", "dot"};
}
// Matcher + rewrite used by fp8_ocp_to_fnuz: for an operator named in
// get_quantizable_op_names() whose first two arguments come from dequantizelinear
// instructions, rewrites both dequantizelinear instructions so they consume
// fp8e4m3fnuz data instead of fp8e4m3fn data.
struct match_fp8ocp_convert_to_fp8fnuz
{
// Builds the matcher. Argument 0 and argument 1 must each match a dequantizelinear
// (after whatever skip_post_dq_ops looks through). apply() later reads the results
// under the names "dq1"/"scale1"/"zp1" and "dq2"/"scale2"/"zp2".
auto matcher() const
{
auto dq1 = match::arg(0)(
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
auto dq2 = match::arg(1)(
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
return match::name(get_quantizable_op_names())(dq1, dq2);
}
// Inserts, in front of `dq`, a bit_cast of `x` to fp8e4m3fnuz followed by two
// `where` selections that remap special bit patterns (described inline below).
// The four *_lit instructions are multibroadcast to the dimensions of `x` first.
// Returns the last inserted instruction, i.e. the remapped fp8e4m3fnuz value.
static auto bit_cast_and_handle_specials(module& m,
const instruction_ref dq,
const instruction_ref x,
const instruction_ref bits_0x80_lit,
const instruction_ref bits_0x7f_lit,
const instruction_ref bits_0xff_lit,
const instruction_ref bits_0x00_lit)
{
auto x_lens = x->get_shape().lens();
auto cast_input = m.insert_instruction(
dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x);
auto mb_bits_0x80_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit);
auto mb_bits_0x7f_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit);
auto mb_bits_0xff_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit);
auto mb_zero_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit);
// negative zero in fp8e4m3fn to zero in fp8e4m3fnuz
// a == 0x80 ? 0x0 : a
auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit);
auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input);
// positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz
// (a == 0x7f or a == 0xff) ? 0x80 : a
auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit);
auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit);
auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff);
ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret);
return ret;
}
// Add the same broadcast instructions after adjusted scales or
// adjusted zero points from after the originals. Similar to
// propagate_quantized_ins in simplify_qdq.
//
// adj       - adjusted value that replaces `ori`
// ori       - original (matched) scale or zero point instruction
// start     - instruction the walk begins at; the walk follows the first input of
//             each instruction until it reaches `ori`
// insert_pt - the re-created instructions are inserted before this instruction
// Returns the last re-created instruction, or `adj` when `start` is already `ori`.
// NOTE(review): the walk only terminates if `ori` is reachable from `start` through
// first inputs — presumably guaranteed by the matcher; confirm.
static auto propagate_broadcasts(module& m,
const instruction_ref adj,
const instruction_ref ori,
const instruction_ref start,
const instruction_ref insert_pt)
{
auto prev_ins = start;
std::vector<instruction_ref> ins_between;
// matcher skips contiguous, multi/broadcasts and transposes, collect all those
// instructions
while(prev_ins != ori)
{
ins_between.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto ret = adj;
// ins_between was collected from `start` back towards `ori`, so replay it in reverse
for(auto ins : reverse_iterator_for(ins_between))
{
ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret});
}
return ret;
}
// Rewrites `dq` in place into a dequantizelinear over fp8e4m3fnuz data:
//  - `input` and `dq_zp` are bit-cast and have their special values remapped,
//  - `dq_scale` is multiplied by 2,
//  - the instructions between the matched scale / zero point and `dq` are re-applied
//    to the adjusted values via propagate_broadcasts.
// There is no return statement, so the deduced return type is void.
static auto cast_to_fnuz(module& m,
const instruction_ref dq,
const instruction_ref input,
const instruction_ref dq_scale,
const instruction_ref dq_zp)
{
auto x = input;
// Literals holding the raw bit patterns 0x80, 0x7f, 0xff and 0x00
std::vector<fp8e4m3fnuz> bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())};
auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80);
std::vector<fp8e4m3fnuz> bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())};
auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f);
std::vector<fp8e4m3fnuz> bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())};
auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff);
std::vector<fp8e4m3fnuz> bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())};
auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00);
x = bit_cast_and_handle_specials(
m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);
auto adj_dq_zp = bit_cast_and_handle_specials(
m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);
// adj_scale = 2 * scale
// NOTE(review): presumably compensates for the exponent bias of fp8e4m3fnuz being one
// larger than that of fp8e4m3fn, so identical bits decode to half the value — confirm.
auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}});
two_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit);
auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit);
adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq);
adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq);
m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp);
}
// Match callback: converts both matched dequantizelinear instructions, but only when
// the first input of each of them has type fp8e4m3fn; otherwise nothing is changed.
auto apply(module& m, const match::matcher_result& r) const
{
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fn_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;
cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1);
cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2);
}
};
} // namespace
// Runs the fp8e4m3fn -> fp8e4m3fnuz rewrite over the module held by the pass manager.
void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const
{
    auto& mod = mpm.get_module();
    match::find_matches(mod, match_fp8ocp_convert_to_fp8fnuz{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,40 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fp_to_double.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_convert.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Runs three passes in order on the pass manager's module:
//  1. eliminate_data_type configured with `convert_fp_types` and double as the
//     target type (`convert_fp_types` is declared outside this file's visible part —
//     presumably the set of floating-point types to convert; confirm in the header),
//  2. eliminate_convert,
//  3. dead_code_elimination.
void fp_to_double::apply(module_pass_manager& mpm) const
{
mpm.run_pass(eliminate_data_type{convert_fp_types, shape::type_t::double_type});
mpm.run_pass(eliminate_convert{});
mpm.run_pass(migraphx::dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,242 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_concat.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Returns a running number: 0 on the first call, then 1, 2, ... on later calls.
// Used in this file to give generated no-op modules distinct names.
unsigned int get_noop_counter()
{
    static unsigned int next_id = 0;
    const unsigned int current = next_id;
    ++next_id;
    return current;
}
// Operator that concatenates the results of several submodules along `axis` and then
// runs a trailing "post" submodule on the concatenated value. It is created by the
// matchers in this file from `concat` instructions.
struct fused_concat
{
    // Axis along which the inputs are concatenated.
    int64_t axis = 0;

    std::string name() const { return "fused_concat"; }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    // Computes the output shape.
    //
    // inputs - shapes of all instruction inputs: first the parameters of every
    //          pre-concat module in order, then the extra parameters of the post module
    // mods   - the pre-concat modules followed by the post module (always last)
    //
    // The result has the type of the post module's last instruction, the dimensions of
    // the first concat input with dimension `axis` replaced by the sum over all concat
    // inputs, and a layout derived from find_permutation(inputs).
    // Throws if fewer than two modules are given, if there are fewer inputs than
    // modules require, or if the concat inputs disagree on a non-axis dimension.
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        // At least one pre-concat module plus the post module is needed; with fewer,
        // `mods.end() - 1` and `concat_inputs.front()` below would be undefined.
        if(mods.size() < 2)
            MIGRAPHX_THROW("FUSED_CONCAT: expected at least one input module and one post module");
        // original concat can have multiple inputs. Let's say it has `n` input args.
        // Each of those `n` input args are converted into pointwise modules that take atleast 1
        // input parameter. Fused concat will have `n+1` module arguments. `n+1`th module is the
        // post pointwise module which can take 0 or more input arguments.
        if((inputs.size() + 1) < mods.size())
            MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules inputs parameters");
        auto input_iter = inputs.begin();
        std::vector<shape> concat_inputs;
        // The first parameter shape of each pre-concat module stands in for that
        // module's contribution to the concat.
        for(module_ref mod : range(mods.begin(), mods.end() - 1))
        {
            concat_inputs.push_back(*input_iter);
            input_iter += mod->get_parameter_names().size();
        }
        module_ref post_mod = mods.back();
        // post_mod has one input argument that is result of concat and will get generated from
        // pre-mods internally. Therefore deduct 1 from post_mod params while asserting.
        assert(input_iter + post_mod->get_parameter_names().size() - 1 == inputs.end());
        auto type                    = std::prev(post_mod->end())->get_shape().type();
        const auto& first_shape_lens = concat_inputs.front().lens();
        // Take the shape by const reference: the predicate only reads its dimensions.
        auto mismatch_it =
            std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](const auto& s) {
                const auto& lens = s.lens();
                return std::equal(lens.begin(),
                                  lens.begin() + axis,
                                  first_shape_lens.begin(),
                                  first_shape_lens.begin() + axis) and
                       std::equal(lens.begin() + axis + 1,
                                  lens.end(),
                                  first_shape_lens.begin() + axis + 1,
                                  first_shape_lens.end());
            });
        if(mismatch_it != concat_inputs.end())
            MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
                           std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
                           "} != {" + to_string_range(mismatch_it->lens()) + "}");
        // Accumulate in std::size_t so the running sum is not narrowed to int.
        std::size_t new_dim_axis = transform_accumulate(concat_inputs.begin(),
                                                        concat_inputs.end(),
                                                        std::size_t{0},
                                                        std::plus<>{},
                                                        [&](const auto& input) {
                                                            return input.lens()[axis];
                                                        });
        auto new_lens  = concat_inputs.front().lens();
        new_lens[axis] = new_dim_axis;
        return shape::from_permutation(type, new_lens, find_permutation(inputs));
    }
};
MIGRAPHX_REGISTER_OP(fused_concat);
namespace {
// Rewrites a `concat` that has at least one single-use `pointwise` input into a
// `fused_concat` whose post module is an identity.
struct find_concat_pointwise
{
// Matches a "concat" that is used once and has at least one input that is a
// "pointwise" instruction used once.
auto matcher() const
{
auto pointwise_used_once = match::name("pointwise")(match::used_once());
return match::name("concat")(match::used_once(),
match::any_of[match::inputs()](pointwise_used_once));
}
// Replaces the matched concat with a fused_concat instruction.
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto concat_ins = r.result;
std::vector<instruction_ref> inputs;
size_t num_noops = 0;
// Flatten the inputs: a single-use pointwise input contributes its own inputs,
// every other input is passed through unchanged and counted as a no-op.
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
}
else
{
num_noops++;
inputs.push_back(input);
}
}
// Give up when more than max(1, number of concat inputs / 4) inputs are no-ops.
if(num_noops > std::max(size_t{1}, concat_ins->inputs().size() / 4))
{
return;
}
// One module per concat input: a copy of the pointwise module (named
// "concat:<original name>") or a generated identity module "concat:noop<N>"
// with a single parameter "x0" that is returned as is.
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
pm->add_return({x});
return pm;
});
// Post module: identity over a single parameter "!x0" that has the concat's type.
auto* post_pm = mpm.create_module("noop:concat" + std::to_string(get_noop_counter()));
auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()});
post_pm->add_return({x});
module_inputs.push_back(post_pm);
// The fused op is built from the concat's normalized operator value.
mpm.get_module().replace_instruction(
concat_ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};
// Rewrites pointwise(concat(..., pointwise, ...)) into a single `fused_concat` whose
// post module is a copy of the outer pointwise module.
struct find_pointwise_concat_pointwise
{
// Matches a "pointwise" with at least one input that is a single-use "concat"
// (bound as "concat") which itself has at least one single-use "pointwise" input.
auto matcher() const
{
auto pointwise = match::name("pointwise")(match::used_once());
auto concat =
match::name("concat")(match::used_once(), match::any_of[match::inputs()](pointwise));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
}
// Replaces the matched outer pointwise instruction with a fused_concat.
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto concat_ins = r.instructions["concat"];
// Position of the concat among the outer pointwise instruction's inputs
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
ins->inputs().begin();
std::vector<instruction_ref> inputs;
// Flatten the concat inputs: single-use pointwise inputs contribute their own inputs.
for(auto input : concat_ins->inputs())
{
if(input->name() == "pointwise" and input->outputs().size() == 1)
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
else
inputs.push_back(input);
}
// Then append the outer pointwise instruction's remaining inputs (all but the concat).
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != concat_ins; });
// One module per concat input: a copy of the pointwise module (named
// "concat:<original name>") or a generated identity module "concat:noop<N>".
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise" and input->outputs().size() == 1)
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat:noop" +
std::to_string(get_noop_counter()));
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
pm->add_return({x});
return pm;
});
// Post module: a copy of the outer pointwise module in which the parameter that
// received the concat result is replaced by a new parameter whose name has "!" prepended.
auto* post_pm = ins->module_inputs().front();
auto* rm = mpm.create_module(post_pm->name() + ":concat", *post_pm);
std::vector<std::string> names = rm->get_parameter_names();
std::sort(names.begin(), names.end());
// NOTE(review): this assumes the lexicographically sorted parameter names line up
// with the instruction's input positions — confirm this also holds when a module has
// enough parameters for names such as "x10" to sort before "x2".
auto concat_param_name = names[concat_arg];
auto concat_param = rm->get_parameter(concat_param_name);
auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
rm->replace_instruction(concat_param, param);
rm->remove_instruction(concat_param);
module_inputs.push_back(rm);
mpm.get_module().replace_instruction(
ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};
} // namespace
// Pass entry point. First fuses pointwise -> concat -> pointwise chains, removes the
// instructions that became dead, and only then fuses the remaining concat-of-pointwise
// patterns (which get an identity post module).
void fuse_concat::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_pointwise_concat_pointwise{});
mpm.run_pass(migraphx::dead_code_elimination{});
match::find_matches(mpm, find_concat_pointwise{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,262 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/param_utils.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <iterator>
// Environment variable checked in fuse_pointwise::apply: when enabled, the pass stops
// after creating the per-instruction pointwise modules and performs no fusion.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Tries to evaluate `ins` to a single-value literal.
// Looks through contiguous/broadcast/multibroadcast instructions first. Returns an
// empty literal when the underlying instruction neither has exactly one element nor a
// scalar shape, or when it cannot be evaluated.
static literal get_scalar(instruction_ref ins)
{
    // Walk up through instructions that do not change the underlying value
    while(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
        ins = ins->inputs().front();

    const auto& ins_shape         = ins->get_shape();
    const bool holds_single_value = ins_shape.elements() == 1 or ins_shape.scalar();
    if(not holds_single_value or not ins->can_eval())
        return {};

    auto evaluated = ins->eval();
    // needed for bool as visit_at invokes as() which promotes bool to int8
    // Without this we'll break type checks for logical ops that are fused.
    if(evaluated.get_shape().type() == shape::bool_type)
        return literal{evaluated.at<bool>()};

    literal result{};
    evaluated.visit_at([&](auto value) { result = literal{value}; });
    return result;
}
// Wraps every instruction whose operator has the "pointwise" attribute (except
// "layout") in its own "pointwise" instruction backed by a new bypass submodule that
// contains just that operator.
static void create_pointwise_modules(module_pass_manager& mpm)
{
// Running suffix used to build the submodule names
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
if(ins->get_operator().name() == "layout")
continue;
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
// Maps each distinct input of `ins` to the parameter or literal that stands for it in pm
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0;
for(auto input : ins->inputs())
{
// An input that appears more than once is mapped only once
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
// Not a constant scalar: becomes a module parameter carrying only the element type
pointwise_inputs.push_back(input);
param_map[input] =
pm->add_parameter(param_name(i), shape{input->get_shape().type()});
i++;
}
else
{
// Constant scalar: embedded in the module as a literal
param_map[input] = pm->add_literal(scalar);
}
}
// Don't create pointwise module if no inputs are detected
// NOTE(review): `pm` has already been created at this point and is left unused —
// presumably removed later by dead code elimination; confirm.
if(pointwise_inputs.empty())
continue;
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return param_map[input]; });
auto r = pm->add_instruction(ins->get_operator(), inputs);
pm->add_return({r});
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
}
}
// Builds the fusion of two pointwise instructions where `ins` is an input of `output`.
// Starts from a copy of the module of `ins`, appends the instructions of the module of
// `output` before the return, and makes the appended result the new return value.
// Returns the fused module together with the instruction inputs it expects: the
// inputs of `ins` followed by those inputs of `output` that were not already present.
static module::with_inputs append_pointwise_module(instruction_ref ins, instruction_ref output)
{
assert(contains(output->inputs(), ins));
// pm is a copy that gets extended; xm is the consumer's module and is only read
module pm = *ins->module_inputs().at(0);
module_ref xm = output->module_inputs().at(0);
auto last = std::prev(pm.end());
assert(last->name() == "@return");
assert(last->inputs().size() == 1);
assert(pm.get_parameter_names().size() == ins->inputs().size());
assert(xm->get_parameter_names().size() == output->inputs().size());
std::vector<instruction_ref> inputs = ins->inputs();
// map_ins: parameter of xm -> instruction inside pm that replaces it
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// input_map: instruction in the outer module -> instruction inside pm that stands for it
std::unordered_map<instruction_ref, instruction_ref> input_map;
// Copy inputs to input_map
for(auto i : range(inputs.size()))
{
auto input = inputs[i];
auto param = pm.get_parameter(param_name(i));
assert(param != pm.end());
input_map[input] = param;
}
// Add the new parameter and additional inputs
for(auto i : range(output->inputs().size()))
{
auto input = output->inputs()[i];
auto param = xm->get_parameter(param_name(i));
assert(param != xm->end());
if(input == ins)
{
// The consumer's parameter fed by `ins` is wired to the value pm used to return
map_ins[param] = last->inputs().front();
input_map[input] = map_ins[param];
}
// Avoid duplicate parameter inputs
else if(contains(input_map, input))
{
map_ins[param] = input_map[input];
}
else
{
// A new outer input: add a parameter for it and extend the input list
map_ins[param] =
pm.add_parameter(param_name(inputs.size()), {input->get_shape().type()});
inputs.push_back(input);
input_map[input] = map_ins[param];
}
}
pm.replace_return(pm.insert_instructions(last, xm, &map_ins));
return {std::move(pm), inputs};
}
// Performs one fusion sweep over the module: every "pointwise" instruction that has a
// single-use "pointwise" input is replaced by one instruction that runs the fused
// module of both. Returns true when at least one fusion was made.
static bool find_pointwise_modules(module_pass_manager& mpm)
{
bool changed = false;
auto last = std::prev(mpm.get_module().end());
for(auto ins : iterator_for(mpm.get_module()))
{
if(ins->name() != "pointwise")
continue;
// Skip instructions without users, unless it is the last instruction of the module
if(ins->outputs().empty() and ins != last)
continue;
// First input that is a pointwise instruction used only by `ins`
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto input = *it;
auto fused = append_pointwise_module(input, ins);
auto name = fused.mod.name();
// Rename the module currently registered under `name` (suffix "-deleted") so that
// the fused module can be created under that name.
mpm.rename_module(name, name + ":" + ins->module_inputs().front()->name() + "-deleted");
auto* new_pm = mpm.create_module(name, std::move(fused.mod));
mpm.get_module().replace_instruction(ins, input->get_operator(), fused.inputs, {new_pm});
changed = true;
}
return changed;
}
namespace {
// Type passed as the template argument of rewrite_reshapes<> in fuse_pointwise::apply;
// it reports "pointwise" as its name.
struct pointwise_reshape : rewrite_reshapes_base
{
static std::string name() { return "pointwise"; }
};
// Moves a multibroadcast that sits between two pointwise instructions in front of the
// first one: pointwise(multibroadcast(pointwise(a, b))) becomes
// pointwise(pointwise(multibroadcast(a), multibroadcast(b))).
struct pointwise_broadcast_pointwise
{
// Matches a "pointwise" with at least one input that is a single-use "multibroadcast"
// (bound as "broadcast") whose only argument is a single-use "pointwise" (bound as "x").
auto matcher() const
{
auto broadcast_pointwise =
match::name("multibroadcast")(
match::used_once(),
match::args(match::name("pointwise")(match::used_once()).bind("x")))
.bind("broadcast");
return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
}
// Applies the broadcast operator to every input of "x" and replaces the broadcast
// instruction with the operator and module of "x" running on those broadcast inputs.
void apply(module& m, const match::matcher_result& r) const
{
auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"];
auto broadcast = broadcast_ins->get_operator();
// Copy of the inner pointwise inputs; each entry is overwritten with its broadcast
auto x_inputs = x_ins->inputs();
std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), [&](auto input) {
return m.insert_instruction(broadcast_ins, broadcast, input);
});
m.replace_instruction(
broadcast_ins, x_ins->get_operator(), x_inputs, x_ins->module_inputs());
}
};
} // namespace
// Applies the pointwise_broadcast_pointwise rewrite to the module and then runs dead
// code elimination to drop the instructions the rewrite left unused.
static void rewrite_broadcasts(module_pass_manager& mpm)
{
match::find_matches(mpm.get_module(), pointwise_broadcast_pointwise{});
mpm.run_pass(dead_code_elimination{});
}
// Pass entry point.
//  1. Removes identities and wraps each pointwise operator in its own module.
//  2. Unless MIGRAPHX_DISABLE_POINTWISE_FUSION is enabled, repeatedly fuses adjacent
//     pointwise instructions. Each round optionally rewrites reshapes and broadcasts
//     first (controlled by the enable_rewrite_reshapes / enable_rewrite_broadcasts
//     members, which are declared outside this file's visible part).
void fuse_pointwise::apply(module_pass_manager& mpm) const
{
mpm.run_pass(eliminate_identity{});
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
{
return;
}
// At most 8 rounds; stop early as soon as a round fuses nothing.
for(int i = 0; i < 8; i++)
{
if(enable_rewrite_reshapes)
mpm.run_pass(rewrite_reshapes<pointwise_reshape>{});
if(enable_rewrite_broadcasts)
rewrite_broadcasts(mpm);
if(not find_pointwise_modules(mpm))
break;
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,55 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#include <migraphx/fuse_pointwise_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/split_reduce.hpp>
#include <migraphx/env.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Environment variable read by get_split_size: when its string value is non-empty it
// replaces the default split size handed to the split_reduce pass.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SPLIT_REDUCE_SIZE);
// Returns the value of MIGRAPHX_SPLIT_REDUCE_SIZE parsed with std::stoul, or
// `default_split` when the variable's string value is empty.
static std::size_t get_split_size(std::size_t default_split)
{
    const std::string env_value = string_value_of(MIGRAPHX_SPLIT_REDUCE_SIZE{});
    return env_value.empty() ? default_split : std::stoul(env_value);
}
// Pass entry point: runs pointwise and reduce fusion in a fixed sequence.
void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
{
// First round: fuse without rewriting reshapes
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = false});
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false});
// Second round: fuse again, this time with reshape rewriting enabled
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true});
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true});
// Split reductions; the `split_size` member is the default and can be overridden
// through MIGRAPHX_SPLIT_REDUCE_SIZE (see get_split_size)
mpm.run_pass(split_reduce{.split_size = get_split_size(split_size)});
// Final pointwise fusion with broadcast rewriting enabled; the other options keep
// the defaults of fuse_pointwise
mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,440 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <migraphx/param_utils.hpp>
#include <iterator>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Environment variable declared for this file.
// NOTE(review): its use is outside the visible part of the file — presumably it turns
// off the reduce fusion step of fuse_reduce::apply when enabled; confirm.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
// Operator that runs a single bypass submodule containing a reduction (and whatever
// has been fused into it). `axes` holds the reduction axes.
struct fused_reduce
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    std::string name() const { return "fused_reduce"; }

    // Validates the submodule and the inputs, then returns the output shape.
    // Requirements, checked in this order (each failure throws):
    //  - exactly one submodule,
    //  - the submodule has exactly one output,
    //  - the submodule's bypass flag is set,
    //  - there is one input per submodule parameter and all inputs have the same rank,
    //  - the dimensions of each input equal those of the parameter at the same position
    //    in the sorted list of parameter names.
    // The result uses the type and dimensions of the submodule's output and a layout
    // derived from find_permutation(inputs).
    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        const auto* submod = mods.front();
        if(submod->get_output_shapes().size() != 1)
            MIGRAPHX_THROW("Only one output supported");
        if(not submod->bypass())
            MIGRAPHX_THROW("fused_reduce: bypass flag is not set");

        auto param_names = submod->get_parameter_names();
        check_shapes{inputs, *this}.has(param_names.size()).same_ndims();
        std::sort(param_names.begin(), param_names.end());
        auto param_shapes = submod->get_parameter_shapes();

        // Check dimension matches for each input
        const bool dims_match =
            equal(param_names, inputs, [&](const auto& pname, const auto& input_shape) {
                return param_shapes.at(pname).lens() == input_shape.lens();
            });
        if(not dims_match)
            MIGRAPHX_THROW("Input dimension does not match the submodule.");

        const auto out_shape = submod->get_output_shapes().front();
        return shape::from_permutation(
            out_shape.type(), out_shape.lens(), find_permutation(inputs));
    }
};
MIGRAPHX_REGISTER_OP(fused_reduce);
/*
 * Predicate matcher: true when the first input of the instruction and the instruction's
 * own output have the same number of dimensions. The broadcast handling of these
 * fusions assumes this.
 */
MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
{
    const auto in_rank  = ins->inputs().front()->get_shape().ndim();
    const auto out_rank = ins->get_shape().ndim();
    return in_rank == out_rank;
}
// Fuses the single submodule of `ins` into `sm`, using the inputs of `ins` as the
// arguments, and returns whatever module::fuse returns.
// map_ins - optional instruction map forwarded to module::fuse
// insert  - optional inserter forwarded to module::fuse
// `ins` must have exactly one module input (asserted).
static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
module::inserter insert = nullptr)
{
assert(ins->module_inputs().size() == 1);
return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
}
// Wraps every single-input instruction whose operator has the "reduce" attribute in a
// "fused_reduce" instruction backed by a new bypass submodule holding that reduction.
static void create_reduce_modules(module_pass_manager& mpm)
{
// Running suffix used to build the submodule names
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("reduce", false))
continue;
if(ins->inputs().size() != 1)
continue;
// Submodule name: "<module name>:<instruction name><n>"
auto* rm =
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass();
// Fuse the reduction into the new module and return its result
rm->add_return(rm->fuse({ins}));
auto v = ins->get_operator().to_value();
// Carry the reduction axes over to the fused_reduce operator
mpm.get_module().replace_instruction(
ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
}
}
namespace {
// Follows the single user of `broadcast`, continuing through any chain of single-use
// "contiguous" instructions. Returns the first user that is not a "contiguous", or the
// instruction the walk is currently at when it does not have exactly one user.
instruction_ref get_broadcast_output(instruction_ref broadcast)
{
    auto current = broadcast;
    while(true)
    {
        if(current->outputs().size() != 1)
            return current;
        auto user = current->outputs().front();
        if(user->name() != "contiguous")
            return user;
        current = user;
    }
}
// Predicate matcher: true when `ins` has exactly one user, or exactly two users where
// one has "broadcast" in its name and get_broadcast_output of that user is the other.
MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins)
{
    const auto& users = ins->outputs();
    if(users.size() == 1)
        return true;
    if(users.size() != 2)
        return false;

    const auto has_broadcast_name = [](instruction_ref user) {
        return contains(user->name(), "broadcast");
    };
    const auto broadcast_it = std::find_if(users.begin(), users.end(), has_broadcast_name);
    if(broadcast_it == users.end())
        return false;
    const auto other_it = std::find_if_not(users.begin(), users.end(), has_broadcast_name);
    if(other_it == users.end())
        return false;

    return get_broadcast_output(*broadcast_it) == *other_it;
}
} // namespace
// Builds a matcher for a "multibroadcast" that is used once, keeps the rank of its input, and
// whose argument 0 satisfies `ms...`; it is bound as "broadcast". "contiguous" instructions
// are skipped around it and the overall match is bound as "final_broadcast".
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
    auto multibroadcast =
        match::name("multibroadcast")(
            match::arg(0)(ms...), match::used_once(), input_output_ndim_match())
            .bind("broadcast");
    return match::skip(match::name("contiguous"))(multibroadcast).bind("final_broadcast");
}
// Builds a matcher that succeeds when any input of the instruction satisfies `ms...`; the
// satisfying input is bound as "input".
template <class... Ms>
static auto any_input(Ms... ms)
{
    auto bound_input = match::any(ms...).bind("input");
    return match::any_of[match::inputs()](bound_input);
}
// Returns true when the axes on which the shape of `b` has a zero stride are exactly
// `reduce_axes` (same axes, same order).
bool is_valid_broadcast(const instruction_ref b, const std::vector<size_t>& reduce_axes)
{
    const auto strides = b->get_shape().strides();
    std::vector<size_t> zero_stride_axes;
    size_t axis = 0;
    for(const auto& stride : strides)
    {
        if(stride == 0)
            zero_stride_axes.push_back(axis);
        ++axis;
    }
    return zero_stride_axes == reduce_axes;
}
// Wraps matcher `m`: after running it, if a "broadcast" instruction has been bound in the
// context, the match is rejected unless that broadcast's zero-stride axes equal the "axes" of
// the reduce involved (the matched instruction itself when it is a fused_reduce, otherwise
// the instruction bound as "reduce").
template <class M>
static auto match_broadcast_axes(M m)
{
    return match::make_basic_fun_matcher(
        [=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
            optional<instruction_ref> matched = m.match(ctx, ins);
            if(not contains(ctx.instructions, "broadcast"))
                return matched;

            instruction_ref reduce;
            if(ins->get_operator().name() == "fused_reduce")
            {
                reduce = ins;
            }
            else
            {
                assert(contains(ctx.instructions, "reduce"));
                reduce = ctx.instructions["reduce"];
            }

            auto axes      = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
            auto broadcast = ctx.instructions["broadcast"];
            if(not is_valid_broadcast(broadcast, axes))
                return nullopt;
            return matched;
        });
}
// Builds a matcher for an input produced by operator `op` (bound as `name`), either consumed
// directly or through a multibroadcast whose broadcast axes are checked against the reduce.
static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
    auto producer          = match::name(op)(used_once_except_broadcast()).bind(name);
    auto direct_input      = any_input(producer, match::used_once());
    auto broadcasted_input = any_input(match_broadcast(producer), match::used_once());
    return match::any_of(direct_input, match_broadcast_axes(broadcasted_input));
}
// Cleans up a freshly built reduce module: common-subexpression elimination first, then
// dead-code elimination.
static void finalize_reduce_module(module_ref m)
{
    auto& mod = *m;
    eliminate_common_subexpression cse{};
    cse.apply(mod);
    dead_code_elimination dce{};
    dce.apply(mod);
}
namespace {
// Rewrite rule: a fused_reduce whose input comes from a pointwise instruction (optionally
// through a multibroadcast) is replaced by a fused_reduce over a new submodule that contains
// the pointwise work followed by the original reduce submodule.
struct find_pointwise_reduce
{
    auto matcher() const
    {
        // fused_reduce instruction with pointwise inputs.
        auto pointwise_input = match_broadcastable_input("pointwise", "pointwise");
        return match::name("fused_reduce")(pointwise_input);
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto reduce    = r.result;
        auto pointwise = r.instructions["pointwise"];

        const auto* pw_mod     = pointwise->module_inputs().front();
        const auto* reduce_mod = reduce->module_inputs().front();
        auto* fused            = mpm.create_module(pw_mod->name() + ":" + reduce_mod->name());
        fused->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Bring the pointwise instruction into the new module first
        auto fused_pw      = fused->fuse({pointwise}, &map_ins).front();
        map_ins[pointwise] = fused_pw;

        if(contains(r.instructions, "broadcast"))
        {
            auto broadcast       = r.instructions["broadcast"];
            auto final_broadcast = r.instructions["final_broadcast"];
            auto fused_broadcast = fused->fuse({broadcast}, &map_ins).front();
            map_ins[broadcast]   = fused_broadcast;
            // A skipped "contiguous" after the broadcast maps to the same fused instruction
            if(final_broadcast != broadcast)
                map_ins[final_broadcast] = fused_broadcast;
        }

        // Append the original reduce submodule and return its outputs
        fused->add_return(insert_module_in_submodule(fused, reduce, &map_ins));
        finalize_reduce_module(fused);

        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), fused);
        mpm.get_module().replace_instruction(
            reduce, reduce->get_operator(), new_inputs, {fused});
    }
};
// Rewrite rule: a pointwise instruction fed by a fused_reduce (optionally through a
// multibroadcast) is replaced by a fused_reduce over a new submodule that contains the reduce
// submodule followed by the pointwise work.
struct find_reduce_pointwise
{
    auto matcher() const
    {
        auto reduce_input = match_broadcastable_input("fused_reduce", "reduce");
        return match::name("pointwise")(reduce_input);
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto pointwise = r.result;
        auto reduce    = r.instructions["reduce"];
        auto input     = r.instructions["input"];

        const auto* pw_mod     = pointwise->module_inputs().front();
        const auto* reduce_mod = reduce->module_inputs().front();
        auto* fused            = mpm.create_module(reduce_mod->name() + ":" + pw_mod->name());
        fused->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Start with a copy of the reduce submodule
        insert_module_in_submodule(fused, reduce, &map_ins);

        if(not contains(r.instructions, "broadcast"))
        {
            map_ins[input] = fused->get_returns().front();
        }
        else
        {
            // Re-create the broadcast inside the new module, fed by the reduce result
            auto broadcast                       = r.instructions["broadcast"];
            map_ins[broadcast->inputs().front()] = fused->get_returns().front();
            auto fused_broadcast                 = fused->fuse({broadcast}, &map_ins);
            map_ins[input]                       = fused_broadcast.front();
        }

        // The pointwise outputs become the outputs of the fused module
        auto outputs = fused->fuse({pointwise}, &map_ins);
        fused->replace_return(outputs);
        finalize_reduce_module(fused);

        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), fused);
        mpm.get_module().replace_instruction(
            pointwise, reduce->get_operator(), new_inputs, {fused});
    }
};
// Rewrite rule: a fused_reduce fed by another fused_reduce with an equal operator (optionally
// through a multibroadcast) is replaced by one fused_reduce over a submodule containing both.
struct find_reduce_reduce
{
    auto matcher() const
    {
        auto reduce_input = match_broadcastable_input("fused_reduce", "reduce");
        return match::name("fused_reduce")(reduce_input);
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto consumer = r.result;
        auto producer = r.instructions["reduce"];
        auto input    = r.instructions["input"];
        // Only reduces with equal operators are merged
        if(consumer->get_operator() != producer->get_operator())
            return;

        const auto* consumer_mod = consumer->module_inputs().front();
        const auto* producer_mod = producer->module_inputs().front();
        auto* fused = mpm.create_module(consumer_mod->name() + ":" + producer_mod->name());
        fused->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Copy the submodule of the producing reduce (the one bound as "reduce")
        insert_module_in_submodule(fused, producer, &map_ins);

        if(not contains(r.instructions, "broadcast"))
        {
            map_ins[input] = fused->get_returns().front();
        }
        else
        {
            // Re-create the broadcast inside the new module, fed by the producer's result
            auto broadcast                       = r.instructions["broadcast"];
            map_ins[broadcast->inputs().front()] = fused->get_returns().front();
            auto fused_broadcast                 = fused->fuse({broadcast}, &map_ins);
            map_ins[input]                       = fused_broadcast.front();
        }

        // Append the consuming reduce's submodule; its outputs become the module outputs
        auto outputs = insert_module_in_submodule(fused, consumer, &map_ins);
        fused->replace_return(outputs);
        finalize_reduce_module(fused);

        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), fused);
        mpm.get_module().replace_instruction(
            consumer, consumer->get_operator(), new_inputs, {fused});
    }
};
// Policy type handed to rewrite_reshapes<> (see fuse_reduce::apply) for fused_reduce
// instructions: it names the operator, rebuilds a fused_reduce for a new axes mapping, and
// reports the dimensions to use as the base shape.
struct reduce_reshape : rewrite_reshapes_base
{
    static std::string name() { return "fused_reduce"; }

    // Adapts an operator transform `t` into an inserter that inserts the transformed operator
    // before `ins`.
    template <class Transform>
    static auto transform_op(Transform t)
    {
        return [=](module& m,
                   instruction_ref ins,
                   const operation& op,
                   const std::vector<instruction_ref>& inputs,
                   const std::vector<module_ref>& mod_args) {
            auto transformed = t(op);
            return m.insert_instruction(ins, transformed, inputs, mod_args);
        };
    }

    // Inserts, before `ins`, a fused_reduce over `inputs` whose axes are the original axes
    // mapped through `am`, backed by a copy of the original submodule with its reduce and
    // multibroadcast operators rebuilt for the new axes/dimensions.
    template <class AxesMap>
    static instruction_ref insert(module_pass_manager& mpm,
                                  instruction_ref ins,
                                  const std::vector<instruction_ref>& inputs,
                                  const AxesMap& am)
    {
        auto reduce_op = any_cast<fused_reduce>(ins->get_operator());
        std::vector<int64_t> axes;
        for(auto axis : reduce_op.axes)
        {
            auto mapped = am.at(axis);
            axes.insert(axes.end(), mapped.begin(), mapped.end());
        }
        std::sort(axes.begin(), axes.end());

        auto dims         = base_dims(inputs);
        auto* original    = ins->module_inputs().front();
        auto* reshaped    = mpm.create_module(original->name() + "_reshape");
        reshaped->set_bypass();

        auto rebuild_op = [&](const operation& sop) {
            if(contains(sop.name(), "reduce"))
                return make_op(sop.name(), {{"axes", axes}});
            if(sop.name() == "multibroadcast")
                return make_op("multibroadcast", {{"out_lens", dims}});
            assert(sop.name() == "pointwise");
            return sop;
        };
        auto outs = reshaped->fuse(*original, inputs, nullptr, transform_op(rebuild_op));
        reshaped->add_return(outs);
        return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {reshaped});
    }

    // Dimensions of the input with the largest element count.
    static std::vector<std::size_t> base_dims(const std::vector<instruction_ref>& inputs)
    {
        auto element_count = [](auto i) { return i->get_shape().elements(); };
        auto largest =
            std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, element_count));
        return (*largest)->get_shape().lens();
    }

    static std::vector<std::size_t> base_dims(instruction_ref ins)
    {
        return base_dims(ins->inputs());
    }
};
} // namespace
// Pass entry point. Does nothing when MIGRAPHX_DISABLE_REDUCE_FUSION is enabled. Otherwise
// wraps reduce instructions in fused_reduce and then runs a fixed number of fusion rounds,
// each optionally preceded by reshape rewriting and followed by dead-code elimination.
void fuse_reduce::apply(module_pass_manager& mpm) const
{
    if(enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}))
        return;

    create_reduce_modules(mpm);
    mpm.run_pass(dead_code_elimination{});

    constexpr int fusion_rounds = 4;
    for(int round = 0; round != fusion_rounds; ++round)
    {
        if(enable_rewrite_reshapes)
            mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
        match::find_matches(
            mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
        mpm.run_pass(dead_code_elimination{});
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,106 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Builds an argument of shape `s` filled via fill_tensor_data with `value`. For a tuple shape
// the result is an argument holding one recursively filled argument per sub-shape.
argument fill_argument(shape s, double value)
{
    if(s.type() != shape::tuple_type)
    {
        argument filled;
        s.visit_type([&](auto as) {
            using type = typename decltype(as)::type;
            auto data  = fill_tensor_data<type>(s, value);
            filled     = {s, data};
        });
        return filled;
    }

    std::vector<argument> sub_args;
    for(const auto& sub_shape : s.sub_shapes())
        sub_args.push_back(fill_argument(sub_shape, value));
    return argument(sub_args);
}
// Builds an argument of shape `s` with data from generate_tensor_data(s, seed, m). For a
// tuple shape the result holds one recursively generated argument per sub-shape.
argument generate_argument(shape s, unsigned long seed, random_mode m)
{
    if(s.type() == shape::tuple_type)
    {
        std::vector<argument> sub_args;
        for(const auto& sub_shape : s.sub_shapes())
            sub_args.push_back(generate_argument(sub_shape, seed, m));
        return argument(sub_args);
    }

    argument generated;
    s.visit_type([&](auto as) {
        // we use char type to store bool type internally, so bool_type
        // needs special processing to generate data
        if(s.type() != shape::bool_type)
        {
            using type = typename decltype(as)::type;
            auto data  = generate_tensor_data<type>(s, seed, m);
            generated  = {s, data};
        }
        else
        {
            auto data = generate_tensor_data<bool>(s, seed, m);
            generated = {s, data};
        }
    });
    return generated;
}
// Builds a literal of shape `s` from data produced by generate_tensor_data(s, seed); the
// generated buffer is handed to the literal as a char pointer.
literal generate_literal(shape s, unsigned long seed)
{
    literal generated;
    auto fill = [&](auto as) {
        using type = typename decltype(as)::type;
        auto data  = generate_tensor_data<type>(s, seed);
        char* raw  = reinterpret_cast<char*>(data.get());
        generated  = {s, raw};
    };
    s.visit_type(fill);
    return generated;
}
// TODO: Move to literal.cpp
// Applies std::fabs to every element of the literal through transform().
literal abs(literal l)
{
    auto magnitude = [](auto x) { return std::fabs(x); };
    return transform(std::move(l), magnitude);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,68 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/inline_module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Inlines one branch of `ins` into `m`: submodule 0 when `cond` is true, submodule 1
// otherwise. Every output of `ins` is then replaced by the inlined result selected by that
// output's "index" value.
static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
    const auto& branches = ins->module_inputs();
    module_ref chosen    = branches.at(cond ? 0 : 1);
    auto inlined_outputs = m.insert_instructions(ins, chosen);

    // Work on a copy: replace_instruction is called on these outputs while we iterate
    auto consumers = ins->outputs();
    assert(inlined_outputs.size() >= consumers.size());
    for(const auto& consumer : consumers)
    {
        auto attrs = consumer->get_operator().to_value();
        assert(attrs.contains("index"));
        auto position = attrs.at("index").to<std::size_t>();
        m.replace_instruction(consumer, inlined_outputs.at(position));
    }
}
// For every "if" instruction whose condition input evaluates to a non-empty argument, inline
// the branch selected by that condition.
void inline_module::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "if")
            continue;

        auto arg_cond = ins->inputs().front()->eval();
        if(arg_cond.empty())
            continue;

        inline_submodule(m, ins, arg_cond.at<bool>());
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,124 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/insert_pad.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Moves asymmetric padding out of an operator that exposes a "padding" value and into an
// explicit pad instruction inserted in front of it.
//
// The "padding" list is read as [left pads..., right pads...] with one entry per spatial
// dimension on each side (kdims = rank - 2). When both halves are equal nothing is changed.
// Otherwise a pad instruction carrying those amounts is inserted before `ins`, the operator's
// own padding is zeroed, and `ins` is replaced so that its first input is the pad.
//
// Nothing is done when the input shape is dynamic, when the input has fewer than two
// dimensions, or when the padding list does not hold exactly 2 * kdims entries. The latter
// two guards keep the iterator arithmetic below in range. A list of kdims entries presumably
// denotes symmetric padding, which needs no rewrite -- TODO confirm against the operator
// definitions.
static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op         = ins->get_operator();
    auto val        = op.to_value();
    auto op_padding = val.at("padding").to_vector<size_t>();

    // skip if shape is dynamic
    if(input->get_shape().dynamic())
    {
        return;
    }

    const auto ndim = input->get_shape().lens().size();
    // rank - 2 would wrap around for fewer than two dimensions
    if(ndim < 2)
        return;
    const auto kdims = ndim - 2;

    // Only a full [left..., right...] list can be split into two halves safely
    if(op_padding.size() != 2 * kdims)
        return;

    const auto pads_mid = op_padding.begin() + kdims;
    if(std::equal(op_padding.begin(), pads_mid, pads_mid, op_padding.end()))
        return;

    std::vector<int64_t> padding(ndim * 2, 0);
    std::vector<size_t> pads_l(op_padding.begin(), pads_mid);
    std::vector<size_t> pads_r(pads_mid, op_padding.end());

    // The operator itself no longer pads
    op_padding = std::vector<size_t>(kdims * 2, 0);
    op.from_value({{"padding", op_padding}});

    // The first two dimensions receive no padding: left pads start at index 2, right pads at
    // index ndim + 2
    std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
    std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);

    auto pad_op        = m.insert_instruction(ins, op::pad{padding}, input);
    auto new_inputs    = ins->inputs();
    new_inputs.front() = pad_op;
    m.replace_instruction(ins, op, new_inputs);
}
// Moves asymmetric padding out of a pooling operator and into an explicit pad instruction
// inserted in front of it. Average pooling is left untouched.
//
// `op.padding` is read as [left pads..., right pads...] with one entry per spatial dimension
// on each side (kdims = rank - 2). When both halves are equal nothing is changed. Otherwise a
// pad instruction is inserted before `ins`, the pooling padding is zeroed, and `ins` is
// replaced so that its first input is the pad. The pad value is the lowest float for max
// pooling and 0 otherwise.
//
// Nothing is done when the input has fewer than two dimensions or when the padding list does
// not hold exactly 2 * kdims entries. Both guards keep the iterator arithmetic below in
// range. A list of kdims entries presumably denotes symmetric padding, which needs no
// rewrite -- TODO confirm against op::pooling.
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{
    auto op = any_cast<op::pooling>(ins->get_operator());
    if(op.mode == op::pooling_mode::average)
    {
        return;
    }

    const auto ndim = input->get_shape().ndim();
    // rank - 2 would wrap around for fewer than two dimensions
    if(ndim < 2)
        return;
    const auto kdims = ndim - 2;

    // Only a full [left..., right...] list can be split into two halves safely
    if(op.padding.size() != 2 * kdims)
        return;

    const auto pads_mid = op.padding.begin() + kdims;
    if(std::equal(op.padding.begin(), pads_mid, pads_mid, op.padding.end()))
        return;

    std::vector<int64_t> padding(ndim * 2, 0);
    std::vector<size_t> pads_l(op.padding.begin(), pads_mid);
    std::vector<size_t> pads_r(pads_mid, op.padding.end());

    // The pooling operator itself no longer pads
    op.padding = std::vector<size_t>(kdims * 2, 0);

    // The first two dimensions receive no padding: left pads start at index 2, right pads at
    // index ndim + 2
    std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
    std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);

    float pad_val = 0.0f; // for the lpnorm
    if(op.mode == op::pooling_mode::max)
    {
        // maxpool uses lowest value for padding
        pad_val = std::numeric_limits<float>::lowest();
    }

    auto pad_op        = m.insert_instruction(ins, op::pad{padding, pad_val}, input);
    auto new_inputs    = ins->inputs();
    new_inputs.front() = pad_op;
    m.replace_instruction(ins, op, new_inputs);
}
// Visits every instruction whose operator name is listed in `ops` and dispatches it to the
// matching padding rewrite: update_pooling for "pooling", update_op for "convolution" and
// "im2col". Other names are left alone.
void insert_pad::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        const std::string op_name = ins->name();
        if(not contains(ops, op_name))
            continue;

        auto input = ins->inputs().front();
        if(op_name == "pooling")
        {
            update_pooling(input, ins, m);
        }
        else if(op_name == "convolution" or op_name == "im2col")
        {
            update_op(input, ins, m);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,564 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/output_iterator.hpp>
#include <queue>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Returns a unary predicate that compares its argument with `x` using std::equal_to<T>.
// `x` is captured by reference, so the predicate must not outlive it.
template <class T>
auto equal_to(const T& x)
{
    return [&x](const T& candidate) {
        const std::equal_to<T> same{};
        return same(x, candidate);
    };
}
// Builds an instruction from an operator, its result shape and its input instructions.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
// Same as the constructor above, additionally storing the submodules passed in `modules`.
instruction::instruction(operation o,
shape r,
std::vector<instruction_ref> args,
std::vector<module_ref> modules)
: op(std::move(o)),
result(std::move(r)),
arguments(std::move(args)),
module_args(std::move(modules))
{
}
// Builds a literal instruction: the operator is builtin::literal and the result shape is
// taken from the literal.
// NOTE(review): `l` is read for its shape and moved from in the same initializer list, so
// this presumably relies on `result` being declared before `lit` -- confirm in the header.
instruction::instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
{
}
// Comparator for the priority queue in instruction::replace(const shape&). Instructions are
// ranked by their iterator distance from `start`; operator() reports x as ordered before y
// when x lies farther from `start`, so a std::priority_queue using it yields the nearest
// instruction first.
struct replace_shape_order
{
    instruction_ref start;

    std::size_t location(instruction_ref x) const { return std::distance(start, x); }

    bool operator()(instruction_ref x, instruction_ref y) const
    {
        return location(y) < location(x);
    }
};
// Sets this instruction's result shape to `r` and, when that changes the shape, recomputes
// the shapes of the instructions that consume it, following their outputs for as long as the
// recomputed shapes keep differing from the stored ones.
void instruction::replace(const shape& r)
{
if(r != result)
{
result = r;
if(output.empty())
{
return;
}
// Find this instruction among the inputs of its first output; the instruction_ref stored
// there serves as the origin for ordering the pending instructions below.
auto start = std::find_if(output.front()->inputs().begin(),
output.front()->inputs().end(),
[&](instruction_ref x) { return this == as_address(x); });
assert(as_address(*start) == this);
// replace_shape_order compares with '>' on distance from the origin, so q.top() is the
// pending instruction nearest to it.
std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(
output.begin(), output.end(), replace_shape_order{*start});
while(not q.empty())
{
instruction_ref ins = q.top();
q.pop();
assert(ins->name() == "@return" or ins->name().front() != '@');
shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
// Only keep propagating when the recomputed shape actually differs
if(new_r != ins->result)
{
ins->result = new_r;
std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q));
}
}
}
}
// Swaps in a new operator: clears the normalized flag, stores the operator, then recomputes
// the result shape from the current arguments and propagates it.
void instruction::replace(operation o)
{
    normalized = false;
    op         = std::move(o);
    // Same work as recompute_shape()
    replace(compute_shape(op, arguments, module_args));
}
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
void instruction::clear_arguments()
{
for(auto&& arg : arguments)
{
arg->remove_output(*this);
}
arguments.clear();
module_args.clear();
}
// An instruction equals an instruction_ref when the ref points at that very object.
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* lhs = std::addressof(i);
    const instruction* rhs = std::addressof(*ref);
    return lhs == rhs;
}
// Extended validity check: valid() must hold and every argument must list this instruction
// among its outputs. With `check_order`, each argument must additionally sit closer to
// `start` than this instruction does.
bool instruction::valid(instruction_ref start, bool check_order) const
{
    if(not valid())
        return false;

    auto argument_ok = [&](instruction_ref arg) {
        const auto& users = arg->outputs();
        auto self         = std::find(users.begin(), users.end(), *this);
        if(self == users.end())
            return false;
        if(not check_order)
            return true;
        // check arguments for this instruction before this instruction
        return std::distance(start, arg) < std::distance(start, *self);
    };
    return std::all_of(arguments.begin(), arguments.end(), argument_ok);
}
// Basic validity check: the stored result shape must equal the expected one (the literal's
// shape for "@literal", the stored shape for "@param", compute_shape() otherwise) and every
// output must list this instruction among its inputs. A migraphx::exception thrown while
// computing the shape makes the instruction invalid.
bool instruction::valid() const
{
    shape computed;
    const auto op_name = op.name();
    if(op_name == "@literal")
    {
        computed = lit.get_shape();
    }
    else if(op_name == "@param")
    {
        computed = result;
    }
    else
    {
        try
        {
            computed = compute_shape(op, arguments, module_args);
        }
        catch(migraphx::exception&)
        {
            return false;
        }
    }

    if(not(result == computed))
        return false;

    return std::all_of(output.begin(), output.end(), [&](instruction_ref user) {
        const auto& user_inputs = user->inputs();
        return std::find(user_inputs.begin(), user_inputs.end(), *this) != user_inputs.end();
    });
}
// Result shape of the instruction (returned by value).
shape instruction::get_shape() const
{
    return result;
}

// Literal held by the instruction; only meaningful for "@literal" instructions.
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return lit;
}

// Operator of the instruction.
const operation& instruction::get_operator() const
{
    return op;
}

// Name of the instruction's operator.
std::string instruction::name() const
{
    return op.name();
}

// Input instructions.
const std::vector<instruction_ref>& instruction::inputs() const
{
    return arguments;
}

// Submodule arguments.
const std::vector<module_ref>& instruction::module_inputs() const
{
    return module_args;
}

// Instructions recorded as consumers of this one.
const std::vector<instruction_ref>& instruction::outputs() const
{
    return output;
}
// Two instructions are equal when their argument lists match element-wise, their result
// shape, operator and module arguments are equal and, for "@literal", their literals match.
bool operator==(const instruction& x, const instruction& y)
{
    const bool same_arguments = std::equal(x.arguments.begin(),
                                           x.arguments.end(),
                                           y.arguments.begin(),
                                           y.arguments.end(),
                                           std::equal_to<instruction_ref>{});
    if(not same_arguments)
        return false;
    if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args))
        return false;
    return x.name() != "@literal" or x.lit == y.lit;
}
// Remaining comparison operators, all expressed through the operator== overloads.
bool operator!=(const instruction& x, const instruction& y)
{
    return not(x == y);
}

bool operator==(instruction_ref ref, const instruction& i)
{
    return i == ref;
}

bool operator!=(const instruction& i, instruction_ref ref)
{
    return not(i == ref);
}

bool operator!=(instruction_ref ref, const instruction& i)
{
    return not(i == ref);
}
// Records `ins` as an output of this instruction unless it is already recorded.
void instruction::add_output(instruction_ref ins)
{
    const bool already_recorded = std::any_of(output.begin(), output.end(), equal_to(ins));
    if(not already_recorded)
        output.push_back(ins);
}
// Registers `ref` as an output on each of its input instructions.
void instruction::backreference(instruction_ref ref)
{
    const auto& args = ref->inputs();
    std::for_each(args.begin(), args.end(), [&](auto&& arg) { arg->add_output(ref); });
}
// Replaces argument `old` of `ins` with `new_ins`, then re-registers `ins` on its inputs and
// recomputes its shape.
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    auto& target = *ins;
    target.replace_argument(old, new_ins);
    backreference(ins);
    target.recompute_shape();
}
// Replaces module argument `old` of `ins` with `new_mod`, then re-registers `ins` on its
// inputs and recomputes its shape.
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
{
    auto& target = *ins;
    target.replace_mod_argument(old, new_mod);
    backreference(ins);
    target.recompute_shape();
}
// Replaces the operator, result shape and arguments of `ins`, then registers `ins` as an
// output on its new inputs.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    auto& target = *ins;
    target.replace(std::move(o), r, std::move(args));
    backreference(ins);
}

// Same as above, also replacing the module arguments of `ins`.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args,
                          std::vector<module_ref> module_args)
{
    auto& target = *ins;
    target.replace(std::move(o), r, std::move(args), std::move(module_args));
    backreference(ins);
}
// Member-level replacement of operator, shape and arguments. The normalized flag is cleared
// because the operator changes. The shape is applied before the arguments are swapped.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    normalized = false;
    op         = std::move(o);
    this->replace(r);
    this->replace(std::move(args));
}
// Member-level replacement of operator, shape, arguments and module arguments. The shape is
// applied before the arguments are swapped.
//
// The normalized flag is cleared because the operator changes, matching the overload without
// module arguments and replace(operation). Without the reset, a new operator would inherit
// the normalization state of the operator it replaces.
void instruction::replace(operation o,
                          const shape& r,
                          std::vector<instruction_ref> args,
                          std::vector<module_ref> mdl_args)
{
    normalized = false;
    op         = std::move(o);
    replace(r);
    replace(std::move(args), std::move(mdl_args));
}
// Rewrites the references held by `ins`: every input found in `map_insts` and every module
// argument found in `map_mods` is replaced by its mapped counterpart.
void instruction::replace_refs(
    instruction_ref ins,
    const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
    const std::unordered_map<module_ref, module_ref>& map_mods)
{
    // The argument list is edited in place while it is traversed
    const auto& args = ins->inputs();
    for(const auto& arg : args)
    {
        auto mapped = map_insts.find(arg);
        if(mapped != map_insts.end())
            instruction::replace_argument(ins, arg, mapped->second);
    }

    const auto& module_args = ins->module_inputs();
    for(const auto& mod : module_args)
    {
        auto mapped = map_mods.find(mod);
        if(mapped != map_mods.end())
            instruction::replace_mod_argument(ins, mod, mapped->second);
    }
}
void instruction::replace(std::vector<instruction_ref> args)
{
clear_arguments();
arguments = std::move(args);
}
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
{
clear_arguments();
arguments = std::move(args);
module_args = std::move(mdl_args);
}
// Replaces every occurrence of `old` in the argument list with `new_ins` and removes this
// instruction from the outputs of `old`.
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old)));
    for(auto& arg : arguments)
    {
        if(arg == old)
            arg = new_ins;
    }
    old->remove_output(*this);
}
// Replaces every occurrence of `old` in the module argument list with `new_mod`.
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
{
    assert(std::any_of(module_args.begin(), module_args.end(), [&](auto i) { return i == old; }));
    for(auto& mod : module_args)
    {
        if(mod == old)
            mod = new_mod;
    }
}
// An instruction is undefined when its operator is named "undefined", or when it has inputs
// and all of them are undefined. An instruction without inputs is otherwise defined.
bool instruction::is_undefined() const
{
    if(op.name() == "undefined")
        return true;

    const auto& args = this->inputs();
    if(args.empty())
        return false;

    return std::all_of(args.begin(), args.end(), [](auto arg) { return arg->is_undefined(); });
}
// Literals can always be evaluated. Any other instruction can be evaluated only when its
// operator is context free and all of its inputs can be evaluated.
bool instruction::can_eval() const
{
    if(op.name() == "@literal")
        return true;
    if(not is_context_free(op))
        return false;

    const auto& args = this->inputs();
    return std::all_of(args.begin(), args.end(), [](auto arg) { return arg->can_eval(); });
}
// Evaluates the instruction when possible. Literals return their stored argument; context
// free operators are computed from their evaluated inputs. An empty argument is returned when
// the operator is not context free, or when `check_eval` is set and can_eval() fails.
argument instruction::eval(bool check_eval) const
{
    if(op.name() == "@literal")
        return this->get_literal().get_argument();

    if(not is_context_free(op))
        return {};

    if(check_eval and not this->can_eval())
        return {};

    // Inputs are evaluated without repeating the can_eval() check
    std::vector<argument> evaluated;
    for(auto arg : this->inputs())
        evaluated.push_back(arg->eval(false));

    return normalized_operator().compute(result, evaluated);
}
// Calls the operator's finalize hook, when it has one, with the context, the result shape and
// the shapes of the inputs.
void instruction::finalize(context& ctx)
{
    if(not has_finalize(this->op))
        return;
    this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
void instruction::print(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(not ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
os << delim << arg_name;
delim = ',';
}
os << ")";
}
// print module inputs
if(not ins->module_inputs().empty())
{
std::string delim = ", [";
for(const const_module_ref& mod_arg : ins->module_inputs())
{
os << delim << mod_arg->name();
delim = ", ";
}
os << "]";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
// print tid
if(ins->target_id != 0)
os << ", target_id=" << ins->target_id;
}
// Writes a short label for `ins`: the operator for ordinary instructions, or "@literal"
// followed by the literal contents (elided beyond 10 elements).
static void debug_name(std::ostream& os, const instruction& ins)
{
    if(ins.name() != "@literal")
    {
        os << ins.get_operator();
        return;
    }

    os << "@literal";
    const bool elide = ins.get_literal().get_shape().elements() > 10;
    if(elide)
        os << "{ ... }";
    else
        os << "{" << ins.get_literal() << "}";
}
// Prints this instruction to std::cout as "<label>(<input labels>) -> <shape>" followed by a
// flushed newline.
void instruction::debug_print() const
{
    debug_name(std::cout, *this);

    const auto& args = this->inputs();
    bool first       = true;
    for(auto arg : args)
    {
        std::cout << (first ? "(" : ", ");
        debug_name(std::cout, *arg);
        first = false;
    }
    if(not args.empty())
        std::cout << ")";

    std::cout << " -> " << this->get_shape() << std::endl;
}
// Resolves the instruction whose output `ins` aliases. A negative alias index from the
// operator means no alias, so `ins` itself is returned. Otherwise the aliased input is
// returned directly when `shallow` is set, or resolved further with the default shallow
// setting.
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
    const auto alias_index = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(alias_index < 0)
        return ins;

    auto aliased = ins->inputs().at(alias_index);
    if(shallow)
        return aliased;
    return get_output_alias(aliased);
}
// Sets the normalized flag.
void instruction::set_normalized(bool value)
{
    normalized = value;
}

// Reads the normalized flag.
bool instruction::is_normalized() const
{
    return normalized;
}

// True when the operator asks for normalization and the flag has not been set yet.
bool instruction::need_normalization() const
{
    if(normalized)
        return false;
    return this->get_operator().need_normalization();
}
// Returns a copy of the operator, with its attributes normalized against the first input's
// shape when normalization is still needed. If normalize_attributes reports failure, a fresh
// copy of the stored operator is returned instead.
operation instruction::normalized_operator() const
{
    operation tuned = this->get_operator();
    if(not this->need_normalization())
        return tuned;

    auto first_input_shape = this->inputs().front()->get_shape();
    if(normalize_attributes(tuned, first_input_shape))
        return tuned;
    return this->get_operator();
}
// Reads the target id stored on the instruction.
std::size_t instruction::get_target_id() const
{
    return target_id;
}

// Stores a new target id on the instruction.
void instruction::set_target_id(std::size_t tid)
{
    this->target_id = tid;
}
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
}
// Asks `op` for its output shape given the shapes of `args`.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    const auto input_shapes = to_shapes(args);
    return op.compute_shape(input_shapes);
}

// Same as above; the module-aware compute_shape overload of the operator is used only when
// `mods` is not empty.
shape compute_shape(const operation& op,
                    const std::vector<instruction_ref>& args,
                    const std::vector<module_ref>& mods)
{
    const auto input_shapes = to_shapes(args);
    if(not mods.empty())
        return op.compute_shape(input_shapes, mods);
    return op.compute_shape(input_shapes);
}
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs)
{
shape new_shape;
try
{
new_shape = op.compute_shape(inputs);
}
catch(...)
{
return {};
}
return {new_shape};
}
// Address of the instruction object that `ins` refers to.
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
{
    auto& referenced = *ins;
    return std::addressof(referenced);
}
// Returns true when `start` can be reached from `end` by repeatedly following input edges
// (including the case end == start). Each instruction is expanded at most once.
bool reaches(instruction_ref start, instruction_ref end)
{
    std::unordered_set<instruction_ref> visited;
    std::vector<instruction_ref> pending{end};
    while(not pending.empty())
    {
        auto ins = pending.back();
        pending.pop_back();
        if(ins == start)
            return true;
        if(not visited.insert(ins).second)
            continue;
        const auto& args = ins->inputs();
        pending.insert(pending.end(), args.begin(), args.end());
    }
    return false;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,177 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
#include <nlohmann/json.hpp>
#include <migraphx/json.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using json = nlohmann::json;
void value_to_json(const value& val, json& j);
migraphx::value value_from_json(const json& j);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace nlohmann {
// Serializer specialization that routes conversions between json and
// migraphx::value through value_to_json / value_from_json.
template <>
struct adl_serializer<migraphx::value>
{
    // json <- value
    static void to_json(json& j, const migraphx::value& val)
    {
        migraphx::value_to_json(val, j);
    }

    // value <- json
    static void from_json(const json& j, migraphx::value& val)
    {
        val = migraphx::value_from_json(j);
    }
};
} // namespace nlohmann
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using json = nlohmann::json;
// Generic fallback: assign `x` to `j` and let json's own conversion handle it.
template <class T>
void value_to_json(const T& x, json& j)
{
    j = x;
}
// Binary data is written as an object with a single "bytes" member that
// holds the bytes as a list of integers.
void value_to_json(const value::binary& x, json& j)
{
    const std::vector<int> bytes(x.begin(), x.end());
    j          = json::object();
    j["bytes"] = bytes;
}
// Writes a list of values into `j`: elements that carry a key are stored
// under that key, elements without a key are appended.
void value_to_json(const std::vector<value>& x, json& j)
{
    for(const auto& element : x)
    {
        const auto& key = element.get_key();
        if(not key.empty())
        {
            j[key] = element.without_key();
            continue;
        }
        j.push_back(element);
    }
}
// Null value: reset `j` with an empty braced initializer.
void value_to_json(std::nullptr_t&, json& j)
{
    j = {};
}
// Converts a migraphx value to json by dispatching on the stored type.
void value_to_json(const value& val, json& j)
{
    // Pre-seed the container kind so that empty arrays/objects keep their type
    if(val.is_array())
        j = json::array();
    if(val.is_object())
        j = json::object();

    val.visit([&](auto stored) { value_to_json(stored, j); });
}
// Converts a json node to a migraphx value.
// An object whose only member is "bytes" is read back as binary data; json's
// own binary and discarded node types are rejected with an exception.
migraphx::value value_from_json(const json& j)
{
    migraphx::value result;
    switch(j.type())
    {
    case json::value_t::null: result = nullptr; break;
    case json::value_t::boolean: result = j.get<bool>(); break;
    case json::value_t::number_float: result = j.get<double>(); break;
    case json::value_t::number_integer: result = j.get<int64_t>(); break;
    case json::value_t::number_unsigned: result = j.get<uint64_t>(); break;
    case json::value_t::string: result = j.get<std::string>(); break;
    case json::value_t::array: {
        result = migraphx::value::array{};
        for(const json& element : j)
            result.push_back(element.get<value>());
        break;
    }
    case json::value_t::object: {
        const bool is_binary = j.size() == 1 and j.contains("bytes");
        if(is_binary)
        {
            result = migraphx::value::binary{j["bytes"].get<std::vector<std::uint8_t>>()};
            break;
        }
        result = migraphx::value::object{};
        for(const auto& member : j.items())
        {
            const json& node    = member.value();
            result[member.key()] = node.get<value>();
        }
        break;
    }
    case json::value_t::binary: MIGRAPHX_THROW("Convert JSON to Value: binary type not supported!");
    case json::value_t::discarded:
        MIGRAPHX_THROW("Convert JSON to Value: discarded type not supported!");
    }
    return result;
}
std::string to_json_string(const value& val)
{
json j = val;
return j.dump();
}
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
// Parses the json text in [str, str + size) and converts it to a value.
migraphx::value from_json_string(const char* str, std::size_t size)
{
    const char* last  = str + size;
    const json parsed = json::parse(str, last);
    return parsed.get<value>();
}
// Parses the json text in `str` and converts it to a value.
migraphx::value from_json_string(const std::string& str)
{
    const json parsed = json::parse(str);
    return parsed.get<value>();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,132 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_convolution.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
// Picks the permutation used to lay out the arguments of `ins`.
// With channels_last the result is {0, 2, 3, ..., 1}; otherwise it is the
// permutation found for the shape of the first input.
std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
{
    if(not lc.channels_last)
        return find_permutation(ins->inputs().front()->get_shape());

    const auto rank = ins->get_shape().ndim();
    std::vector<int64_t> perm(rank);
    // perm[0] stays 0, the middle entries count up from 2, the last entry is 1
    std::iota(perm.begin() + 1, perm.end() - 1, 2);
    perm.back() = 1;
    return perm;
}
// Shapes that never get a layout instruction: 1-d, dynamic or tuple shapes.
bool skip_layout(const shape& s)
{
    if(s.ndim() == 1)
        return true;
    if(s.dynamic())
        return true;
    return s.type() == shape::tuple_type;
}
// Pins the layout of the module outputs by appending a "layout" instruction
// that uses the permutation each output currently has.
// For a module ending in @return every returned value is wrapped (unless
// skip_layout applies); otherwise only the last instruction is wrapped.
void preserve_output_layout(module& m)
{
    // An empty module has no output; std::prev(m.end()) would be undefined
    if(m.begin() == m.end())
        return;
    auto last = std::prev(m.end());
    if(last->name() == "@return")
    {
        std::vector<instruction_ref> outputs;
        std::transform(last->inputs().begin(),
                       last->inputs().end(),
                       std::back_inserter(outputs),
                       [&](instruction_ref ins) {
                           if(skip_layout(ins->get_shape()))
                               return ins;
                           auto permutation = find_permutation(ins->get_shape());
                           // Insert before @return so the new instruction can be returned
                           return m.insert_instruction(
                               last, make_op("layout", {{"permutation", permutation}}), ins);
                       });
        m.replace_return(outputs);
    }
    else if(not skip_layout(last->get_shape()))
    {
        auto permutation = find_permutation(last->get_shape());
        m.add_instruction(make_op("layout", {{"permutation", permutation}}), last);
    }
}
// Rewrites each eligible convolution so that its inputs go through a "layout"
// instruction, followed by a "contiguous" on the new convolution's result.
// Eligible means: convolution/quant_convolution, static 4-d output, group <= 1.
void transform_convolutions(module& m, const layout_convolution& lc)
{
    const auto is_candidate = [](instruction_ref ins) {
        if(not contains({"convolution", "quant_convolution"}, ins->name()))
            return false;
        if(ins->get_shape().dynamic())
            return false;
        if(ins->get_shape().lens().size() != 4)
            return false;
        auto attributes = ins->get_operator().to_value();
        return not(attributes.at("group").to<int>() > 1);
    };

    for(auto ins : iterator_for(m))
    {
        if(not is_candidate(ins))
            continue;

        const auto original_args = ins->inputs();
        auto perm                = get_permutation(ins, lc);

        // Lay out every argument right before the convolution
        std::vector<instruction_ref> laid_out;
        for(const auto& arg : original_args)
        {
            laid_out.push_back(
                m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), arg));
        }

        auto new_conv   = m.insert_instruction(ins, ins->get_operator(), laid_out);
        auto contiguous = m.insert_instruction(ins, make_op("contiguous"), new_conv);
        m.replace_instruction(ins, contiguous);
    }
}
// Replaces every "layout" instruction whose output shape equals the shape of
// its input with that input.
void remove_layout(module& m)
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "layout")
            continue;
        auto input = ins->inputs().front();
        if(ins->get_shape() != input->get_shape())
            continue;
        m.replace_instruction(ins, input);
    }
}
} // namespace
// Runs the layout transformation: pin the output layout, rewrite the
// convolutions, then clean up with dead code / contiguous elimination and
// finally drop layout instructions that do not change the shape.
void layout_convolution::apply(module_pass_manager& mpm) const
{
    module& m = mpm.get_module();

    preserve_output_layout(m);
    transform_convolutions(m, *this);

    mpm.run_pass(dead_code_elimination{});
    mpm.run_pass(eliminate_contiguous{"contiguous"});
    mpm.run_pass(dead_code_elimination{});

    remove_layout(m);
    mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,71 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/lexing.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Builds a lexer that matches the literal string `s`.
// The returned callable takes a [start, end) character range and returns
// start + s.size() when the range begins with `s`, or `start` when it does not
// (including when the range is shorter than `s` or is reversed).
// The lexer keeps its own copy of `s`, so it may outlive the argument.
std::function<const char*(const char*, const char*)> lex_equal(const std::string& s)
{
    return [pattern = s](const char* start, const char* end) {
        // A reversed range has no characters; checking this first also keeps
        // the length comparison below unsigned on both sides
        if(end < start)
            return start;
        const auto available = static_cast<std::size_t>(end - start);
        if(available < pattern.size())
            return start;
        if(std::equal(start, start + pattern.size(), pattern.data()))
            return start + pattern.size();
        return start;
    };
}
// Splits [start, end) into tokens. At each position the lexers are tried in
// order and the first one that advances produces the next token.
// Throws when no lexer advances at some position.
std::vector<std::string_view>
tokenize(const char* start, const char* end, const std::vector<lexer>& lexers)
{
    std::vector<std::string_view> tokens;
    while(start != end)
    {
        const char* next = start;
        for(const auto& lex : lexers)
        {
            next = lex(start, end);
            if(next != start)
                break;
        }
        // Still at the same position: nothing matched
        if(next == start)
        {
            MIGRAPHX_THROW("TOKENIZE: no token found!");
        }
        tokens.emplace_back(start, next - start);
        start = next;
    }
    return tokens;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,105 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
#include <migraphx/msgpack.hpp>
#include <fstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reads `filename` into memory and deserializes a program from it.
program load(const std::string& filename, const file_options& options)
{
    const auto contents = read_buffer(filename);
    return load_buffer(contents, options);
}
// Deserializes a program from an in-memory buffer.
program load_buffer(const std::vector<char>& buffer, const file_options& options)
{
    const char* data       = buffer.data();
    const std::size_t size = buffer.size();
    return load_buffer(data, size, options);
}
// Deserializes a program from `size` bytes at `buffer`, decoding them
// according to options.format ("msgpack" or "json").
// Throws for any other format.
program load_buffer(const char* buffer, std::size_t size, const file_options& options)
{
    program p;
    const auto& format = options.format;
    if(format == "msgpack")
        p.from_value(from_msgpack(buffer, size));
    else if(format == "json")
        p.from_value(from_json_string(buffer, size));
    else
        MIGRAPHX_THROW("Unknown format: " + options.format);
    return p;
}
// Serializes `p` and writes the result to `filename`.
void save(const program& p, const std::string& filename, const file_options& options)
{
    const auto serialized = save_buffer(p, options);
    write_buffer(filename, serialized);
}
// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs, so warn
// on stdout when any module of the program contains a gpu::miopen_fusion
// instruction.
void print_miopen_warning(const program& p)
{
    const auto is_fusion = [](const instruction& i) { return i.name() == "gpu::miopen_fusion"; };

    auto mods  = p.get_modules();
    bool found = false;
    for(const auto* mod : mods)
    {
        if(std::any_of(mod->begin(), mod->end(), is_fusion))
        {
            found = true;
            break;
        }
    }
    if(not found)
        return;

    std::cout << "[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
                 "are not stored inside serialized MIGraphX program. Consider serializing with "
                 "MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
              << std::endl;
}
std::vector<char> save_buffer(const program& p, const file_options& options)
{
value v = p.to_value();
print_miopen_warning(p);
std::vector<char> buffer;
if(options.format == "msgpack")
{
buffer = to_msgpack(v);
}
else if(options.format == "json")
{
std::string s = to_json_string(v);
buffer = std::vector<char>(s.begin(), s.end());
}
else
{
MIGRAPHX_THROW("Unknown format: " + options.format);
}
return buffer;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,73 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Creates the operator registered under `name` with its default attributes.
operation make_op(const std::string& name)
{
    return load_op(name);
}
// Creates the operator `name` and overrides its attributes.
// `for_each` is called with a callback taking (key, value); every key it
// reports must already exist in the operator's value, otherwise this throws.
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{
    auto op = load_op(name);
    // Start from the operator's own attributes and overwrite the given ones
    value merged      = op.to_value();
    const auto assign = [&](const auto& key, const auto& x) {
        if(not merged.contains(key))
            // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
            MIGRAPHX_THROW("No key '" + key + "' in " + name);
        merged.at(key) = x;
    };
    for_each(assign);
    op.from_value(merged);
    return op;
}
// Creates the operator `name`, overriding attributes from a list of
// key/value pairs.
operation make_op(const std::string& name,
                  const std::initializer_list<std::pair<std::string, value>>& v)
{
    return make_op_generic(name, [&](auto assign) {
        for(const auto& entry : v)
            assign(entry.first, entry.second);
    });
}
// Creates the operator `name`, overriding attributes from `v`.
// `v` must be an object, or an empty array; anything else throws.
operation make_op_from_value(const std::string& name, const value& v)
{
    if(not v.is_object() and not(v.empty() and v.is_array()))
        MIGRAPHX_THROW("Value is not an object for make_op: " + name);
    return make_op_generic(name, [&](auto assign) {
        for(auto&& element : v)
            assign(element.get_key(), element.without_key());
    });
}
// Creates the operator `name` with attributes given as text: `s` is passed
// through convert_to_json and then parsed as json.
operation make_json_op(const std::string& name, const std::string& s)
{
    const auto attributes = from_json_string(convert_to_json(s));
    return make_op(name, attributes);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,372 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/liveness.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);
using instruction_set = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;
// Builds the conflict table (interference graph): a map from each allocation
// instruction to the set of other allocation instructions that are live at
// the same time. Only instructions named `allocation_op` with a non-zero
// byte size take part.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
    // Non-empty allocations are the only nodes of the graph
    const auto is_tracked = [&](auto ins) {
        if(ins->name() != allocation_op)
            return false;
        return ins->get_shape().bytes() != 0;
    };

    instruction_set_map conflict_table;
    liveness(m, [&](auto ins, auto live_set) {
        if(not is_tracked(ins))
            return;
        // Make sure the allocation has an entry even when nothing conflicts
        conflict_table[ins];
        for(auto other : live_set)
        {
            if(other == ins)
                continue;
            if(not is_tracked(other))
                continue;
            // Conflicts are symmetric
            conflict_table[other].insert(ins);
            conflict_table[ins].insert(other);
        }
    });
    assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
        return pp.second.count(pp.first) == 0;
    }));
    return conflict_table;
}
// Check if two half-open intervals [first, second) overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
    const std::size_t latest_start = std::max(x.first, y.first);
    const std::size_t earliest_end = std::min(x.second, y.second);
    return latest_start < earliest_end;
}
// Maps each allocation instruction to a segment [first, second). Segments are
// measured in units of the alignment passed to build()/next_segment(), and
// allocations that conflict with each other get non-overlapping segments.
struct allocation_segment
{
    using segment = std::pair<std::size_t, std::size_t>;
    std::unordered_map<instruction_ref, segment> ins2segment;

    // Store (or overwrite) the segment for an instruction and return a pointer
    // to the stored copy
    const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }

    // Look up the segment for an instruction, nullptr if it has none
    const segment* get_segment(instruction_ref ins) const
    {
        auto it = ins2segment.find(ins);
        if(it == ins2segment.end())
            return nullptr;
        return &it->second;
    }

    // Remove segment for an instruction (no-op if it has none)
    void remove(instruction_ref ins) { ins2segment.erase(ins); }

    // Largest segment end over all stored segments, 0 if there are none
    std::size_t max() const
    {
        std::size_t n = 0;
        for(const auto& pp : ins2segment)
        {
            n = std::max(n, pp.second.second);
        }
        return n;
    }

    // True if `s` overlaps any segment in [first, last)
    template <class Iterator>
    static bool overlaps(Iterator first, Iterator last, const segment& s)
    {
        return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
    }

    static bool overlaps(const std::set<segment>& segments, const segment& s)
    {
        return overlaps(segments.begin(), segments.end(), s);
    }

    // Find the first segment that is followed by a gap of at least `n` units
    // before the next segment. Returns segments.end() if there is no such gap.
    static auto find_gap(const std::set<segment>& segments, std::size_t n)
    {
        // Furthest end seen so far; a segment ending before it lies inside an
        // earlier segment, so the space after it is not free
        std::size_t max_end = 0;
        return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
            if(x.second < max_end)
                return false;
            max_end = x.second;
            if(is_overlap(x, y))
                return false;
            assert(y.first >= x.second);
            auto k = y.first - x.second;
            return (k >= n);
        });
    }

    // Largest type_size() found in the shape or any of its nested sub shapes
    static std::size_t max_type_size(const shape& s)
    {
        return std::accumulate(
            s.sub_shapes().begin(),
            s.sub_shapes().end(),
            s.type_size(),
            [](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
    }

    // Alignment in bytes for the allocation: the largest element size, scaled
    // by 4 or 2 when the estimated element count allows it
    static std::size_t compute_alignment(instruction_ref ins)
    {
        auto alignment = max_type_size(ins->get_shape());
        // A rough estimate for the total number of elements
        auto n = ins->get_shape().bytes() / alignment;
        // Check for vectorized alignment
        if(n > 4)
        {
            auto d = n % 4;
            if(d == 0)
                alignment *= 4;
            if(d == 2)
                alignment *= 2;
        }
        return alignment;
    }

    // Pick a segment for `ins` that does not overlap any segment in `segments`,
    // insert it into `segments` and return it
    static segment
    next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
    {
        assert(ins->get_shape().bytes() > 0);
        // Size in units of `alignment`, rounded up
        std::size_t n = 1 + (ins->get_shape().bytes() - 1) / alignment;
        assert(n > 0);
        std::size_t start = 0;
        // Insert in a gap or at the end if it can't fit at the beginning
        if(segments.empty() or segments.begin()->first <= n)
        {
            auto it = find_gap(segments, n);
            // No gap is large enough, so go after the segment that ends last
            if(it == segments.end())
                it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
                    return x.second < y.second;
                });
            if(it != segments.end())
                start = it->second;
        }
        auto s = segment{start, start + n};
        assert(not overlaps(segments, s));
        segments.insert(s);
        return s;
    }

    // Number the allocations in the conflict table by their order in the module
    static std::unordered_map<instruction_ref, int>
    create_allocation_index(const module& m, const instruction_set_map& conflict_table)
    {
        std::unordered_map<instruction_ref, int> result;
        int i = 0;
        for(auto ins : iterator_for(m))
        {
            if(not contains(conflict_table, ins))
                continue;
            result[ins] = i++;
        }
        return result;
    }

    // Build the allocation_segment from the conflict_table
    static allocation_segment
    build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
    {
        allocation_segment as{};
        std::vector<instruction_ref> conflict_queue;
        // Add all allocations to the conflict_queue
        std::transform(conflict_table.begin(),
                       conflict_table.end(),
                       std::back_inserter(conflict_queue),
                       [](auto&& pp) { return pp.first; });
        auto alloc_index = create_allocation_index(m, conflict_table);
        // Sort the conflict queue so we process the allocation with the most
        // number of adjacent allocations first. Size and module order break ties.
        std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
                      return std::make_tuple(
                          conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
                  }));
        // Process the conflict_queue, we refer to the current allocation as
        // the parent and the adjacent allocations as children
        for(auto parent : conflict_queue)
        {
            // Sort children by size
            std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
                                                  conflict_table.at(parent).end());
            std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
                          return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
                      }));
            assert(not contains(children, parent));
            // This set is to track the segments already processed
            std::set<segment> segments;
            // Add all segments for the children to the segments already processed
            transform_if(
                children.begin(),
                children.end(),
                std::inserter(segments, segments.begin()),
                [&](auto child) { return as.get_segment(child); },
                [&](auto child) { return *as.get_segment(child); });
            assert(as.get_segment(parent) == nullptr);
            as.add_segment(parent, next_segment(segments, parent, alignment));
        }
        // Reduce the number of segments: retry each allocation a few times and
        // keep the new segment when it does not grow the total
        for(std::size_t n = 0; n < 3; n++)
        {
            for(auto parent : conflict_queue)
            {
                // Only read here, so refer to the set instead of copying it
                const auto& children = conflict_table.at(parent);
                // This set is to track the segments already processed
                std::set<segment> segments;
                // Add all segments for the children to the segments already processed
                transform_if(
                    children.begin(),
                    children.end(),
                    std::inserter(segments, segments.begin()),
                    [&](auto child) { return as.get_segment(child); },
                    [&](auto child) { return *as.get_segment(child); });
                // Get the segment for the parent
                const auto* parent_segment = as.get_segment(parent);
                assert(parent_segment != nullptr);
                auto s = next_segment(segments, parent, alignment);
                if(s != *parent_segment and s.second <= as.max())
                {
                    as.add_segment(parent, s);
                }
            }
        }
        return as;
    }
};
// Largest alignment required by any `allocation_op` instruction in the
// module; 1 when there are none.
static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
    std::size_t largest = 1;
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == allocation_op)
            largest = std::max(allocation_segment::compute_alignment(ins), largest);
    }
    return largest;
}
// Replaces every allocation in the module with a "load" from one shared
// "scratch" parameter, placing each allocation at the offset of its segment.
void memory_coloring::apply(module& m) const
{
    const std::size_t alignment = find_max_alignment(m, allocation_op);
    auto conflict_table         = build_conflict_table(m, allocation_op);
    auto segments               = allocation_segment::build(m, conflict_table, alignment);

    // All allocations should have a segment
    assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
        return segments.get_segment(pp.first);
    }));
    // Adjacent allocations should not have overlapping segments
    assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
        auto* x = segments.get_segment(pp.first);
        return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
            auto* y = segments.get_segment(ins);
            assert(x and y);
            return is_overlap(*x, *y);
        });
    }));

    // Print out segments
    if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
    {
        for(const auto& [alloc, neighbors] : conflict_table)
        {
            std::cout << "------- conflict -------" << std::endl;
            const auto own = segments.ins2segment.at(alloc);
            std::cout << own.first << ", " << own.second << ": ";
            m.debug_print(alloc);
            for(auto neighbor : neighbors)
            {
                const auto other = segments.ins2segment.at(neighbor);
                std::cout << other.first << ", " << other.second << ": ";
                m.debug_print(neighbor);
            }
        }
    }

    // Total memory in bytes
    const std::size_t total_bytes = segments.max() * alignment;

    // Replace allocations with loads from the scratch buffer
    auto scratch = m.add_parameter("scratch", shape{shape::int8_type, {total_bytes}});
    for(auto&& [ins, seg] : segments.ins2segment)
    {
        assert(ins->name() == allocation_op);
        auto alloc_shape         = ins->get_shape();
        std::size_t byte_offset = seg.first * alignment;
        assert(byte_offset < total_bytes);
        m.replace_instruction(
            ins,
            make_op("load", {{"shape", to_value(alloc_shape)}, {"offset", byte_offset}}),
            scratch);
    }

    // Replace zero allocation: anything still named allocation_op has no segment
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != allocation_op)
            continue;
        assert(ins->get_shape().bytes() == 0);
        m.replace_instruction(
            ins, make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", 0}}), scratch);
    }

    // Remove scratch parameter if it's not used
    if(scratch->outputs().empty())
    {
        m.remove_instruction(scratch);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,256 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/msgpack.hpp>
#include <migraphx/serialize.hpp>
#include <msgpack.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Maximum number of bytes written per msgpack bin chunk.
// Leave an extra byte for error checking
constexpr std::size_t msgpack_size_limit = std::numeric_limits<uint32_t>::max() - 1;

// Number of chunks msgpack_chunk_for_each produces for a range of this size.
// An empty range still yields one (empty) chunk, so the result is never 0.
template <class Range>
std::size_t msgpack_chunk_size(const Range& r)
{
    const std::size_t size = r.size();
    // Avoid the unsigned wrap-around of size - 1 for an empty range
    if(size == 0)
        return 1;
    return 1 + (size - 1) / msgpack_size_limit;
}

// Calls f(first, last) for consecutive chunks of [start, last), each at most
// msgpack_size_limit elements long. f is always called at least once; for an
// empty range it receives an empty chunk.
template <class Iterator, class F>
void msgpack_chunk_for_each(Iterator start, Iterator last, F f)
{
    while(static_cast<std::size_t>(std::distance(start, last)) > msgpack_size_limit)
    {
        auto next = std::next(start, msgpack_size_limit);
        f(start, next);
        start = next;
    }
    f(start, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace msgpack {
MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
namespace adaptor {
// Converts a msgpack object into a migraphx::value.
// Binary data is accepted either as a single BIN object or as a non-empty
// array consisting only of BIN chunks, which are concatenated.
// EXT objects are not supported and throw.
template <>
struct convert<migraphx::value>
{
    const msgpack::object& operator()(const msgpack::object& o, migraphx::value& v) const
    {
        switch(o.type)
        {
        case msgpack::type::NIL: {
            v = nullptr;
            break;
        }
        case msgpack::type::BOOLEAN: {
            v = o.as<bool>();
            break;
        }
        case msgpack::type::POSITIVE_INTEGER: {
            v = o.as<std::uint64_t>();
            break;
        }
        case msgpack::type::NEGATIVE_INTEGER: {
            v = o.as<std::int64_t>();
            break;
        }
        case msgpack::type::FLOAT32:
        case msgpack::type::FLOAT64: {
            v = o.as<double>();
            break;
        }
        case msgpack::type::STR: {
            v = o.as<std::string>();
            break;
        }
        case msgpack::type::BIN: {
            // For backwards compatibility
            v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
            break;
        }
        case msgpack::type::ARRAY: {
            const auto* first = o.via.array.ptr;
            const auto* last  = o.via.array.ptr + o.via.array.size;
            // Only treat the array as chunked binary when every element is a
            // BIN object; via.bin must not be read on any other object type
            const bool is_chunked_binary =
                o.via.array.size != 0 and
                std::all_of(first, last, [](const msgpack::object& so) {
                    return so.type == msgpack::type::BIN;
                });
            if(is_chunked_binary)
            {
                auto bin = migraphx::value::binary{};
                std::for_each(first, last, [&](const msgpack::object& so) {
                    bin.insert(bin.end(), so.via.bin.ptr, so.via.bin.ptr + so.via.bin.size);
                });
                v = bin;
            }
            else
            {
                migraphx::value r = migraphx::value::array{};
                std::for_each(first, last, [&](const msgpack::object& so) {
                    r.push_back(so.as<migraphx::value>());
                });
                v = r;
            }
            break;
        }
        case msgpack::type::MAP: {
            migraphx::value r = migraphx::value::object{};
            std::for_each(o.via.map.ptr,
                          o.via.map.ptr + o.via.map.size,
                          [&](const msgpack::object_kv& p) {
                              r[p.key.as<std::string>()] = p.val.as<migraphx::value>();
                          });
            v = r;
            break;
        }
        case msgpack::type::EXT: {
            MIGRAPHX_THROW("msgpack EXT type not supported.");
        }
        }
        return o;
    }
};
// Packs binary data as an array of bin chunks, each holding at most
// migraphx::msgpack_size_limit bytes.
template <>
struct pack<migraphx::value::binary>
{
    template <class Stream>
    packer<Stream>& operator()(msgpack::packer<Stream>& o,
                               const migraphx::value::binary& x) const
    {
        const auto* first = reinterpret_cast<const char*>(x.data());
        const auto* last  = first + x.size();
        // Array header: one entry per chunk
        o.pack_array(migraphx::msgpack_chunk_size(x));
        migraphx::msgpack_chunk_for_each(
            first, last, [&](const char* chunk_begin, const char* chunk_end) {
                const auto chunk_bytes = chunk_end - chunk_begin;
                o.pack_bin(chunk_bytes);
                o.pack_bin_body(chunk_begin, chunk_bytes);
            });
        return o;
    }
};
// Packs a migraphx::value by visiting the stored value and writing it with
// the matching write() overload.
template <>
struct pack<migraphx::value>
{
    // Null is written as msgpack nil
    template <class Stream>
    void write(msgpack::packer<Stream>& o, const std::nullptr_t&) const
    {
        o.pack_nil();
    }

    // Everything without a dedicated overload goes through the packer
    template <class Stream, class T>
    void write(msgpack::packer<Stream>& o, const T& x) const
    {
        o.pack(x);
    }

    // A list of values becomes a map when its first element has a key, and an
    // array otherwise. Throws when the list is too long for msgpack.
    template <class Stream>
    void write(msgpack::packer<Stream>& o, const std::vector<migraphx::value>& v) const
    {
        if(v.empty())
        {
            o.pack_array(0);
            return;
        }
        if(v.size() > migraphx::msgpack_size_limit)
            MIGRAPHX_THROW("Size is too large for msgpack");

        const bool keyed = not v.front().get_key().empty();
        if(keyed)
        {
            o.pack_map(v.size());
            for(const auto& element : v)
            {
                o.pack(element.get_key());
                o.pack(element.without_key());
            }
            return;
        }
        o.pack_array(v.size());
        for(const auto& element : v)
            o.pack(element);
    }

    template <class Stream>
    packer<Stream>& operator()(msgpack::packer<Stream>& o, const migraphx::value& v) const
    {
        v.visit_value([&](auto&& stored) { this->write(o, stored); });
        return o;
    }
};
} // namespace adaptor
} // MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
} // namespace msgpack
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Output stream that collects everything written to it in `buffer`.
struct vector_stream
{
    std::vector<char> buffer{};

    // Append n bytes starting at b; returns *this so writes can be chained
    vector_stream& write(const char* b, std::size_t n)
    {
        const char* last = b + n;
        buffer.insert(buffer.end(), b, last);
        return *this;
    }
};
// Output stream that forwards every write to the `writer` callback.
struct writer_stream
{
    std::function<void(const char*, std::size_t)> writer;

    // Pass n bytes starting at b to the callback; returns *this for chaining
    writer_stream& write(const char* b, std::size_t n)
    {
        this->writer(b, n);
        return *this;
    }
};
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer)
{
writer_stream ws{std::move(writer)};
msgpack::pack(ws, v);
}
// Serializes `v` to msgpack and returns the bytes.
std::vector<char> to_msgpack(const value& v)
{
    vector_stream vs;
    msgpack::pack(vs, v);
    // A member of a local is not moved implicitly on return, so move the
    // buffer out explicitly instead of copying the whole serialized program
    return std::move(vs.buffer);
}
// Unpacks `size` bytes of msgpack data at `buffer` and converts the result
// to a value.
value from_msgpack(const char* buffer, std::size_t size)
{
    msgpack::object_handle handle = msgpack::unpack(buffer, size);
    const msgpack::object& root   = handle.get();
    return root.as<value>();
}
// Unpacks the msgpack data held in `buffer` and converts it to a value.
value from_msgpack(const std::vector<char>& buffer)
{
    const char* data       = buffer.data();
    const std::size_t size = buffer.size();
    return from_msgpack(data, size);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,278 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/netron_output.hpp>
#include <migraphx/json.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/base64.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
// from https://onnx.ai/onnx/intro/concepts.html
// Translates a MIGraphX shape element type into the integer code that is written to the
// "dataType" / "elemType" fields of the generated JSON.
// NOTE(review): the codes presumably follow the ONNX element type numbering from the page
// linked above -- confirm against it when adding a new type.
// tuple_type is mapped to 0. If the switch falls through without returning,
// MIGRAPHX_THROW reports the type as not supported.
int get_onnx_type(shape::type_t s_type)
{
switch(s_type)
{
case shape::float_type: return 1;
case shape::uint8_type: return 2;
case shape::int8_type: return 3;
case shape::uint16_type: return 4;
case shape::int16_type: return 5;
case shape::int32_type: return 6;
case shape::int64_type: return 7;
case shape::bool_type: return 9;
case shape::half_type: return 10;
case shape::double_type: return 11;
case shape::uint32_type: return 12;
case shape::uint64_type: return 13;
case shape::bf16_type: return 16;
case shape::fp8e4m3fn_type: return 17;
case shape::fp8e4m3fnuz_type: return 18;
case shape::fp8e5m2_type: return 19;
case shape::fp8e5m2fnuz_type: return 20;
case shape::tuple_type: return 0;
}
// Reached only when none of the cases above returned
MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported");
}
/// Builds the JSON object describing one operator attribute: its key goes into "name"
/// and its string form into "docString".
auto make_attribute(const migraphx::value& val)
{
    const std::string key_prefix = val.get_key() + ":";
    auto text                    = val.to<std::string>();
    const auto prefix_pos        = text.find(key_prefix);
    if(prefix_pos != std::string::npos)
    {
        // Keep only what follows "<key>:" plus one further character
        // NOTE(review): presumably that extra character is a separating space -- confirm
        text = text.substr(prefix_pos + key_prefix.length() + 1);
    }

    value attribute   = value(std::unordered_map<std::string, value>());
    attribute["name"] = val.get_key();
    // TODO: doesn't work for some reason with Netron now
    // attribute["s"] = base64_encode(val_string);
    // attribute["type"] = "STRING";
    attribute["docString"] = text;
    return attribute;
}
/// Returns a value with the JSON structure needed for a node
auto make_onnx_json_node(instruction_ref ins,
std::unordered_map<instruction_ref, std::string> ins_uids)
{
value node;
// TODO add support for module inputs
value input_arr = value({});
for(instruction_ref input_ins : ins->inputs())
{
auto name = input_ins->name();
if(name == "@literal" or name == "@param")
{
input_arr.push_back(ins_uids.at(input_ins));
}
// TODO make a better process for handling nodes to ignore
else if(name.find("hip::hip_allocate_memory") != std::string::npos)
{
continue;
}
else
{
input_arr.push_back(ins_uids.at(input_ins) + "->" + ins_uids.at(ins));
}
}
value output_arr = value({});
for(instruction_ref output_ins : ins->outputs())
{
if(output_ins->name() == "@return")
{
output_arr.push_back(ins_uids.at(output_ins));
}
else
{
output_arr.push_back(ins_uids.at(ins) + "->" + ins_uids.at(output_ins));
}
}
node["input"] = input_arr;
node["output"] = output_arr;
node["name"] = ins_uids.at(ins);
node["opType"] = ins->name();
value op_attribute_arr = value({});
auto op_value = ins->get_operator().to_value();
std::for_each(op_value.begin(), op_value.end(), [&](auto v) {
const std::string& attr_key = v.get_key();
if(v.is_binary() or attr_key == "code_object")
{
return;
}
else if(attr_key == "symbol_name" or attr_key == "name")
{
node["opType"] = migraphx::from_value<std::string>(v);
}
else
{
op_attribute_arr.push_back(make_attribute(v));
}
});
node["attribute"] = op_attribute_arr;
return node;
}
// ONNX graph constant data called "initializer"
auto make_onnx_json_literal(instruction_ref ins,
std::unordered_map<instruction_ref, std::string> ins_uids)
{
value lit;
lit["dims"] = ins->get_shape().lens();
lit["dataType"] = get_onnx_type(ins->get_shape().type());
lit["name"] = ins_uids.at(ins);
// ignoring literal data, setting to "NULL" in base64
lit["rawData"] = "TlVMTA==";
return lit;
}
// TODO handle dynamic shapes
// TODO handle subshapes
auto make_onnx_json_shape(const shape& s)
{
value ret;
value dim = value({});
for(std::size_t len : s.lens())
{
// cppcheck-suppress useStlAlgorithm
dim.push_back({{"dimValue", len}});
}
ret["dim"] = dim;
return ret;
}
// ONNX graph edges called "valueType"
auto make_onnx_json_edge(instruction_ref ins,
instruction_ref out_ins,
std::unordered_map<instruction_ref, std::string> ins_uids)
{
value ret;
shape ins_shape = ins->get_shape();
ret["name"] = ins_uids.at(ins) + "->" + ins_uids.at(out_ins);
value type = {{"tensorType",
{{"elemType", get_onnx_type(ins_shape.type())},
{"shape", make_onnx_json_shape(ins_shape)}}}};
ret["type"] = type;
return ret;
}
auto make_onnx_json_in_out(instruction_ref ins,
std::unordered_map<instruction_ref, std::string> ins_uids)
{
value ret;
shape ins_shape = ins->get_shape();
ret["name"] = ins_uids.at(ins);
value type = {{"tensorType",
{{"elemType", get_onnx_type(ins_shape.type())},
{"shape", make_onnx_json_shape(ins_shape)}}}};
ret["type"] = type;
return ret;
}
std::unordered_map<instruction_ref, std::string> make_ins_uids(const module& mod)
{
std::unordered_map<instruction_ref, std::string> ret;
int count = 0;
for(auto ins : iterator_for(mod))
{
std::string var_name;
var_name = mod.name() + ":";
var_name.append(ins->name() + ":");
var_name.append("@" + std::to_string(count));
count++;
ret.emplace(ins, var_name);
}
return ret;
}
// Builds the JSON "graph" object of one module. Literals go to "initializer", parameters
// to "input", returns to "output"; every other instruction becomes a "node" and each of
// its outputs that is not a return becomes a "valueInfo" edge.
value make_graph(const module* mod)
{
    value graph   = {{"node", value({})},
                     {"initializer", value({})},
                     {"input", value({})},
                     {"output", value({})},
                     {"valueInfo", value({})}};
    auto ins_uids = make_ins_uids(*mod);
    for(auto ins : iterator_for(*mod))
    {
        const auto& name = ins->name();
        if(name == "@literal")
        {
            graph["initializer"].push_back(make_onnx_json_literal(ins, ins_uids));
            continue;
        }
        if(name == "@param")
        {
            graph["input"].push_back(make_onnx_json_in_out(ins, ins_uids));
            continue;
        }
        if(name == "@return")
        {
            graph["output"].push_back(make_onnx_json_in_out(ins, ins_uids));
            continue;
        }
        // allocation instructions are left out of the picture
        if(name.find("hip::hip_allocate_memory") != std::string::npos)
            continue;

        graph["node"].push_back(make_onnx_json_node(ins, ins_uids));
        for(auto out_ins : ins->outputs())
        {
            if(out_ins->name() == "@return")
                continue;
            graph["valueInfo"].push_back(make_onnx_json_edge(ins, out_ins, ins_uids));
        }
    }
    return graph;
}
} // namespace
// Renders the program as a pretty-printed JSON string (indent 4) holding the producer
// information and a "graph" object.
std::string make_netron_output(const program& prog)
{
    auto prog_value = prog.to_value();
    value output;
    // ONNX IR version 6
    // TODO: investigate how this affects things
    output["irVersion"]       = 6;
    output["producerName"]    = "AMDMIGraphX";
    output["producerVersion"] = prog_value.at("migraphx_version").to<std::string>();
    // NOTE(review): "graph" is reassigned for every module, so only the module visited
    // last is present in the output -- confirm this is intended for multi-module programs
    for(auto& mod : prog.get_modules())
    {
        output["graph"] = make_graph(mod);
    }
    return to_pretty_json_string(output, 4);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,283 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
 * Normalizes the entries of a vector attribute against an input shape: every entry is
 * range checked or clipped according to the options in `val`, then negative entries are
 * wrapped by adding the upper bound of their position.
 *
 * Parameters:
 *  vec: the vector attribute to normalize
 *  axes: the operator's axes attribute if it exists, empty otherwise
 *  val: the normalize_axes key and options. Ex: normalize["axes"] =
 *       value::array{normalize_attribute::include_min};
 *  input_shape: input shape passed when calling
 *       normalize_attributes(op&, input_shape)
 *  m: callable returning the prefix for error messages
 *
 * Returns the normalized copy of `vec`. `vec` is returned unchanged when it is empty, or
 * when use_len is requested on a dynamic shape whose dimensions at `axes` are not fixed.
 * Throws when an entry is out of range and the matching clip option is not set.
 *
 * See normalize_attribute.hpp for explaining the options.
 */
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec,
                    const std::vector<int64_t>& axes,
                    const value& val,
                    const shape& input_shape,
                    Message m)
{
    std::vector<int64_t> result(vec);
    if(result.empty())
    {
        return result;
    }
    int64_t n_rank                                 = input_shape.ndim();
    std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
    if(contains(vec_attrs, op::normalize_attribute::use_output))
    {
        n_rank = n_rank + vec.size();
    }
    // Upper bound of every entry: the rank by default, the dimension length at the
    // matching axis when use_len is set
    std::vector<int64_t> max_vals(vec.size(), n_rank);
    if(contains(vec_attrs, op::normalize_attribute::use_len))
    {
        if(input_shape.dynamic())
        {
            // return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
            if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
                   return not input_shape.dyn_dims().at(ax).is_fixed();
               }))
            {
                return vec;
            }
            std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
                return input_shape.dyn_dims().at(i).max;
            });
        }
        else
        {
            std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
                return input_shape.lens().at(i);
            });
        }
    }

    // Upper bound: clip or check; include_max decides whether the bound itself is valid
    if(contains(vec_attrs, op::normalize_attribute::clip_max))
    {
        if(contains(vec_attrs, op::normalize_attribute::include_max))
        {
            std::transform(result.begin(),
                           result.end(),
                           max_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v > mv ? mv : v; });
        }
        else
        {
            std::transform(result.begin(),
                           result.end(),
                           max_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v >= mv ? mv - 1 : v; });
        }
    }
    else
    {
        if(contains(vec_attrs, op::normalize_attribute::include_max))
        {
            if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
            {
                MIGRAPHX_THROW(m() + "value out of range!");
            }
        }
        else
        {
            if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
            {
                MIGRAPHX_THROW(m() + "value out of range!");
            }
        }
    }

    // Lower bound is the negated upper bound; include_min decides whether it is valid
    std::vector<int64_t> min_vals = max_vals;
    std::transform(min_vals.begin(), min_vals.end(), min_vals.begin(), [](auto v) { return -v; });
    if(contains(vec_attrs, op::normalize_attribute::clip_min))
    {
        if(contains(vec_attrs, op::normalize_attribute::include_min))
        {
            std::transform(result.begin(),
                           result.end(),
                           min_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v < mv ? mv : v; });
        }
        else
        {
            std::transform(result.begin(),
                           result.end(),
                           min_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v < mv + 1 ? mv + 1 : v; });
        }
    }
    else
    {
        if(contains(vec_attrs, op::normalize_attribute::include_min))
        {
            if(not std::equal(
                   min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
            {
                MIGRAPHX_THROW(m() + "attribute out of range!");
            }
        }
        else
        {
            // Exclusive lower bound: every entry must be strictly greater than its
            // minimum (min < value). The operands were previously swapped, which
            // required value < min and therefore rejected every in-range entry.
            if(not std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less<>{}))
            {
                MIGRAPHX_THROW(m() + "attribute out of range!");
            }
        }
    }

    // Wrap negative entries around by adding the upper bound of their position
    std::transform(
        result.begin(), result.end(), max_vals.begin(), result.begin(), [](auto v, auto mv) {
            return v < 0 ? v + mv : v;
        });
    return result;
}
// Returns the padding entries of `val` followed by a second copy of the same entries,
// i.e. a vector twice as long as the input.
auto tune_pad_attribute(const value& val)
{
    const std::vector<size_t> pads = val.to_vector<size_t>();
    std::vector<size_t> result;
    result.reserve(2 * pads.size());
    for(int copy = 0; copy < 2; ++copy)
    {
        result.insert(result.end(), pads.begin(), pads.end());
    }
    return result;
}
/**
 * Normalizes the attributes of `op` that it declares through its "normalize_padding" and
 * "normalize_axes" attributes, updating `op` in place.
 *
 * Returns true when an attribute was processed. Throws when the padding vector has an
 * unexpected size or when a key listed under "normalize_axes" is missing from the
 * operator's value.
 *
 * Assumptions:
 *  Dimensions to pad start from the third dimension (index 2).
 *  Called by compute_shape_op() with the shape of the first input.
 */
bool normalize_attributes(operation& op, const shape& input_shape)
{
    bool tuned = false;
    auto attrs = op.attributes();
    auto val   = op.to_value();
    if(attrs.contains("normalize_padding"))
    {
        bool use_auto_padding =
            (val.contains("padding_mode") and
             (val.at("padding_mode").to<int>() != migraphx::op::padding_mode_t::default_));
        if(not use_auto_padding)
        {
            // Name of the attribute holding the padding. The normalized padding is
            // written back under this same key (it used to be written to the hard-coded
            // key "padding" regardless of the name it was read from).
            const auto padding_key = attrs.at("normalize_padding").to<std::string>();
            auto padding           = val.at(padding_key);
            auto padding_size      = padding.size();
            auto padding_start     = 2;
            if(padding_size == 2 * (input_shape.ndim() - padding_start))
                // already has one entry per side of every padded dimension
                tuned = true;
            else if(padding_size != (input_shape.ndim() - padding_start))
            {
                MIGRAPHX_THROW("normalize_attributes: inconsistent padding vector size ");
            }
            else
            {
                // one entry per dimension: duplicate it so both sides are listed
                auto result      = tune_pad_attribute(padding);
                val[padding_key] = result;
                op.from_value(val);
                tuned = true;
            }
        }
    }
    if(not attrs.contains("normalize_axes"))
    {
        return tuned;
    }
    auto attr_v = attrs.at("normalize_axes").without_key();
    for(const auto& rv : attr_v)
    {
        const auto& key = rv.get_key();
        if(val.contains(key))
        {
            auto message = [&] { return op.name() + ": " + key + ": "; };
            auto vv      = val.at(key).without_key();
            if(vv.is_array())
            {
                std::vector<int64_t> axes;
                if(val.contains("axes"))
                {
                    axes = val.at("axes").without_key().to_vector<int64_t>();
                }
                auto vec    = vv.to_vector<int64_t>();
                auto result = tune_attribute(vec, axes, rv.without_key(), input_shape, message);
                val[key]    = result;
                op.from_value(val);
                // re-read so later keys see the attributes as stored by the operator
                val   = op.to_value();
                tuned = true;
            }
            else
            {
                // scalar attribute: normalized as a one element vector
                auto num    = vv.to<int64_t>();
                auto result = tune_attribute({num}, {num}, rv.without_key(), input_shape, message);
                val[key]    = result.front();
                op.from_value(val);
                val   = op.to_value();
                tuned = true;
            }
        }
        else
        {
            MIGRAPHX_THROW("NORMALIZE_ATTR : op " + op.name() + " attribute \"" + key +
                           "\" not exist!");
        }
    }
    return tuned;
}
// Normalizes `axes` against input_shape using the options in attr_val. No separate axes
// attribute is supplied, and `prefix` starts every error message.
std::vector<int64_t> normalize_axes(const std::vector<int64_t>& axes,
                                    const shape& input_shape,
                                    const value& attr_val,
                                    const std::string& prefix)
{
    auto message = [&] { return prefix; };
    const std::vector<int64_t> no_axes{};
    return tune_attribute(axes, no_axes, attr_val, input_shape, message);
}
// Normalizes `indices` against input_shape using the options in attr_val, with `axes`
// passed along as the operator's axes. `prefix` starts every error message.
std::vector<int64_t> normalize_indices(const std::vector<int64_t>& indices,
                                       const std::vector<int64_t>& axes,
                                       const shape& input_shape,
                                       const value& attr_val,
                                       const std::string& prefix)
{
    auto message = [&] { return prefix; };
    return tune_attribute(indices, axes, attr_val, input_shape, message);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,57 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <unordered_set>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/value.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// For every instruction that has inputs, normalizes a copy of its operator against the
// shape of the first input. When normalize_attributes() reports a change the instruction
// is replaced with the tuned operator and marked as normalized.
void normalize_ops::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        auto inputs = ins->inputs();
        if(inputs.empty())
            continue;

        auto first_shape              = inputs.front()->get_shape();
        migraphx::operation candidate = ins->get_operator();
        if(not normalize_attributes(candidate, first_shape))
            continue;

        m.replace_instruction(ins, candidate, inputs);
        ins->set_normalized();
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,55 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static const std::vector<std::string> pooling_mode_str = {"average", "max", "lpnorm"};
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
return os;
}
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
static const std::vector<std::string> rnn_direction_str = {
"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,41 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Stores an operation into `v` as two entries: "name" holds the operator name and
// "operator" holds the value form of the operator.
void migraphx_to_value(value& v, const operation& op)
{
    auto op_name   = op.name();
    v["name"]      = op_name;
    auto op_fields = op.to_value();
    v["operator"]  = op_fields;
}
// Rebuilds an operation from the "name" and "operator" entries of `v`.
void migraphx_from_value(const value& v, operation& op)
{
    const auto op_name    = v.at("name").to<std::string>();
    const auto& op_fields = v.at("operator");
    op                    = make_op(op_name, op_fields);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,55 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_convert.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Runs the module optimization pipeline on the module owned by `mpm`.
// The whole pipeline is wrapped in repeat_while_changes(2, ...) and the simplification
// passes in a nested repeat_while_changes(4, ...).
// NOTE(review): presumably the integer is an iteration cap and the loop ends early once a
// round leaves the module unchanged -- confirm against module::repeat_while_changes.
void optimize_module::apply(module_pass_manager& mpm) const
{
mpm.get_module().repeat_while_changes(2, [&] {
// loop to further optimize after initial transformations
mpm.get_module().repeat_while_changes(4, [&] {
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(eliminate_convert{});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(simplify_algebra{});
});
// after the simplification rounds: remove common subexpressions, then propagate
// constants, with a dead code elimination after each of the two
mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{propagate_constant_skip_ops});
mpm.run_pass(dead_code_elimination{});
});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,154 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pad_calc.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Computes the total padding one spatial dimension needs so that the output length is
// ceil(input_dim / stride), and stores it split in two halves in `pads`: slot `idx` and
// slot `idx + pads.size() / 2`. When the total is odd the larger half goes to the second
// slot if is_same_upper is set and to slot `idx` otherwise.
void calculate_padding(int64_t idx,
                       std::vector<int64_t>& pads,
                       int64_t input_dim,
                       int64_t stride,
                       int64_t dilation,
                       int64_t weight_dim,
                       bool is_same_upper)
{
    const int64_t out_len = (input_dim + stride - 1) / stride; // round up result
    // kernel extent once dilation is taken into account
    const int64_t eff_kernel = weight_dim + (weight_dim - 1) * (dilation - 1);
    const int64_t shortfall  = (out_len - 1) * stride + eff_kernel - input_dim;
    const int64_t total      = shortfall > 0 ? shortfall : static_cast<int64_t>(0);
    const int64_t small_half = total / 2;
    const int64_t large_half = total - small_half;
    const auto half_count    = pads.size() / 2;
    if(not is_same_upper)
    {
        pads[idx + half_count] = small_half;
        pads[idx]              = large_half;
        return;
    }
    pads[idx]              = small_half;
    pads[idx + half_count] = large_half;
}
/**
 * Computes the padding of every spatial dimension (all dimensions from index 2 on) from
 * the input lengths, kernel lengths (wei_lens), strides and dilations.
 *
 * The result has two entries per spatial dimension: entry i and entry
 * i + number_of_spatial_dimensions. When the total padding of a dimension is odd, the
 * larger half goes to the second entry if use_upper is set and to the first otherwise.
 */
std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
                                           const std::vector<std::size_t>& wei_lens,
                                           const std::vector<std::size_t>& strides,
                                           const std::vector<std::size_t>& dilations,
                                           bool use_upper)
{
    assert(input_lens.size() >= 3);
    assert(input_lens.size() == wei_lens.size());
    const std::size_t spatial_rank = input_lens.size() - 2;
    std::vector<std::size_t> padding;
    padding.resize(2 * spatial_rank);
    const auto half_count = padding.size() / 2;
    for(std::size_t axis = 0; axis < spatial_rank; ++axis)
    {
        const std::ptrdiff_t in_len  = input_lens[axis + 2];
        const std::ptrdiff_t step    = strides[axis];
        const std::ptrdiff_t k_len   = wei_lens[axis + 2];
        const std::ptrdiff_t spacing = dilations[axis];
        const std::ptrdiff_t out_len = (in_len + step - 1) / step; // round up result
        // kernel extent once dilation is taken into account
        const std::ptrdiff_t eff_kernel = k_len + (k_len - 1) * (spacing - 1);
        const std::ptrdiff_t shortfall  = (out_len - 1) * step + eff_kernel - in_len;
        const std::size_t total      = std::max(static_cast<std::ptrdiff_t>(0), shortfall);
        const std::size_t small_half = total / 2;
        const std::size_t large_half = total - small_half;
        padding[axis]              = use_upper ? small_half : large_half;
        padding[axis + half_count] = use_upper ? large_half : small_half;
    }
    return padding;
}
/**
 * Calculate the correct output shape for a convolution with
 * a given input size and other parameters.
 *
 * The first two output lengths are input.lens()[0] and weights.lens()[0]. Every spatial
 * length is ((W - K + P) / S) + 1, where K is the dilated kernel extent and P the sum of
 * the two padding entries of that dimension, clamped to a minimum of 1.
 */
shape compute_padded_shape(const shape& input,
                           const shape& weights,
                           const std::vector<std::size_t>& padding,
                           const std::vector<std::size_t>& stride,
                           const std::vector<std::size_t>& dilation)
{
    const size_t num_spatial_dims = input.lens().size() - 2;

    std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
    // calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
    for(size_t i = 0; i < num_spatial_dims; ++i)
    {
        // Work in signed arithmetic: when the dilated kernel is larger than the padded
        // input, W - K + P is negative. In unsigned arithmetic that wrapped around and,
        // for a stride above 1, produced a huge length instead of being clamped to 1.
        const auto in_len  = static_cast<std::ptrdiff_t>(input.lens()[i + 2]);
        const auto k_len   = static_cast<std::ptrdiff_t>(weights.lens()[i + 2]);
        const auto spacing = static_cast<std::ptrdiff_t>(dilation[i]);
        const auto step    = static_cast<std::ptrdiff_t>(stride[i]);
        const auto padding_factor =
            static_cast<std::ptrdiff_t>(padding[i] + padding[i + num_spatial_dims]);
        const std::ptrdiff_t eff_kernel = 1 + spacing * (k_len - 1);
        const std::ptrdiff_t out_len    = (in_len - eff_kernel + padding_factor) / step + 1;
        output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, out_len)));
    }
    return input.with_lens(output_lens);
}
/**
 * Calculate the correct output shape for a pooling with
 * a given input size and other parameters. This uses
 * the same formula for pooling that compute_padded_shape() uses
 * for convolutions, but takes slightly different inputs.
 *
 * The first two output lengths are copied from the input. The kernel shape is indexed
 * from 0, i.e. it lists spatial lengths only. Every spatial length is clamped to a
 * minimum of 1.
 */
shape compute_padded_pool_shape(const shape& input,
                                const shape& kernel,
                                const std::vector<std::size_t>& padding,
                                const std::vector<std::size_t>& stride,
                                const std::vector<std::size_t>& dilation)
{
    const size_t num_spatial_dims = input.lens().size() - 2;

    std::vector<size_t> output_lens{input.lens()[0], input.lens()[1]};
    // calculate the output shape of the pooling: ((W - K + 2P) / S) + 1
    for(size_t i = 0; i < num_spatial_dims; ++i)
    {
        // Work in signed arithmetic: when the dilated kernel is larger than the padded
        // input, W - K + P is negative. In unsigned arithmetic that wrapped around and,
        // for a stride above 1, produced a huge length instead of being clamped to 1.
        const auto in_len  = static_cast<std::ptrdiff_t>(input.lens()[i + 2]);
        const auto k_len   = static_cast<std::ptrdiff_t>(kernel.lens()[i]);
        const auto spacing = static_cast<std::ptrdiff_t>(dilation[i]);
        const auto step    = static_cast<std::ptrdiff_t>(stride[i]);
        const auto padding_factor =
            static_cast<std::ptrdiff_t>(padding[i] + padding[i + num_spatial_dims]);
        const std::ptrdiff_t eff_kernel = 1 + spacing * (k_len - 1);
        const std::ptrdiff_t out_len    = (in_len - eff_kernel + padding_factor) / step + 1;
        output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, out_len)));
    }
    return input.with_lens(output_lens);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,82 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#include <migraphx/param_utils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <map>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Builds the name of the i-th parameter.
//  i < 10:  prefix followed by the digit, e.g. "x3"
//  i >= 10: prefix, ':' and the index left-padded with zeros to 5 digits, e.g. "x:00012"
// Throws when the index has more than 5 digits.
std::string param_name(std::size_t i, const std::string& prefix)
{
    if(i < 10)
        return prefix + std::to_string(i);
    const std::size_t max_digits = 5;
    // Count digits on the decimal string instead of going through floating-point
    // log10/pow, which is exact for every index and avoids rounding concerns.
    const std::string digits = std::to_string(i);
    if(digits.size() > max_digits)
        MIGRAPHX_THROW("Too many parameters.");
    return prefix + ":" + std::string(max_digits - digits.size(), '0') + digits;
}
// Sorts parameter instructions in place, ordered by the `parameter` member of their
// builtin::param operator.
void sort_params(std::vector<instruction_ref>& params)
{
    auto parameter_of = [](instruction_ref ins) {
        const auto& param = any_cast<const builtin::param&>(ins->get_operator());
        return param.parameter;
    };
    std::sort(params.begin(), params.end(), by(std::less<>{}, parameter_of));
}
std::vector<instruction_ref>
find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
const_module_ref parent,
const_module_ref sub)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param] : map_ins)
{
if(sub != nullptr and not sub->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
if(parent != nullptr and not parent->has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(not sub or result.size() == sub->get_parameter_shapes().size());
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,45 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pass.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// No-op pass: reports the name "id" and leaves the module untouched.
/// Returned by enable_pass() when the requested pass is not enabled.
struct id_pass
{
    /// Name of this pass.
    std::string name() const
    {
        return "id";
    }

    /// Intentionally empty: the module is not inspected or modified.
    void apply(const module&) const
    {
    }
};
/// Return `p` when `enabled` is true, otherwise a do-nothing id_pass.
pass enable_pass(bool enabled, pass p)
{
    if(not enabled)
        return id_pass{};
    return p;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,209 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PASSES);
/// True when `name` appears in the comma-separated MIGRAPHX_DISABLE_PASSES
/// value. The list is read and split once, on the first call.
static bool is_pass_disabled(const std::string& name)
{
    static const auto disabled = [] {
        return split_string(string_value_of(MIGRAPHX_DISABLE_PASSES{}, ""), ',');
    }();
    return contains(disabled, name);
}
/// Debug-only check that pass `p` left `mod` valid.
///
/// When NDEBUG is defined this does nothing. Otherwise it calls
/// mod.validate() and throws, naming the pass and the index and name of the
/// offending instruction, if an invalid instruction is reported.
void validate_pass(module& mod, const pass& p, tracer trace)
{
#ifdef NDEBUG
    // Release build: parameters are intentionally unused.
    (void)mod;
    (void)p;
    (void)trace;
#else
    trace("Validate ...");
    const auto bad = mod.validate();
    if(bad != mod.end())
    {
        const auto position = std::distance(mod.begin(), bad);
        MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
                       std::to_string(position) + ": " + bad->name());
    }
    trace();
#endif
}
/// Apply pass `p` to the whole program, tracing the pass name before and the
/// resulting program after.
void run_pass(program& prog, const pass& p, tracer trace)
{
    const auto pass_name = p.name();
    trace("Pass: ", pass_name);

    p.apply(prog);

    trace(prog);
}
/// module_pass_manager implementation used by run_passes(): applies a pass to
/// one module, with tracing, optional timing and post-pass validation.
struct module_pm : module_pass_manager
{
// Module the passes are applied to (returned by get_module()).
module* mod = nullptr;
// Optional explicit root module; when null, get_root_module() falls back to
// the program's main module.
module* root_mod = nullptr;
// Tracer used by trace() and passed to validate_pass(); must be non-null
// whenever run_pass() is called.
tracer* t = nullptr;
// Value returned by get_common_parent(); set by the caller after construction.
module* common_parent = nullptr;
// Owning program; required by the create/rename module operations.
program* prog = nullptr;
// NOTE(review): both constructors default every parameter, so a call with
// zero or one argument would be ambiguous. The call sites visible in this
// file always pass three arguments.
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
: mod(pmod), root_mod(rmod), t(pt)
{
}
// Forward the arguments to the tracer (passed as lvalues).
template <class... Ts>
void trace(Ts&&... xs) const
{
assert(t);
(*t)(xs...);
}
virtual module& get_module() override
{
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual module* create_module(const std::string& name, module m) override
{
assert(prog);
return prog->create_module(name, std::move(m));
}
// Rename a module through the program; debug builds assert that `old_name`
// is one of the current module's submodules.
virtual void rename_module(const std::string& old_name, const std::string& new_name) override
{
assert(prog);
assert(mod);
assert(
any_of(mod->get_sub_modules(), [&](module_ref sm) { return sm->name() == old_name; }));
prog->rename_module(old_name, new_name);
}
virtual module* get_common_parent() override { return common_parent; }
// Prefer the explicitly supplied root module, else the program's main module.
virtual module* get_root_module() override
{
if(root_mod != nullptr)
return root_mod;
assert(prog);
return prog->get_main_module();
}
// Apply `p` to this manager unless its name is listed in
// MIGRAPHX_DISABLE_PASSES. Debug builds assert the module is valid before the
// pass runs. When MIGRAPHX_TIME_PASSES is enabled the elapsed time in
// milliseconds is printed to std::cout. Afterwards the module is traced and
// validate_pass() is called.
virtual void run_pass(const pass& p) override
{
if(is_pass_disabled(p.name()))
return;
trace("Pass: ", p.name());
assert(mod);
assert(mod->validate() == mod->end());
if(enabled(MIGRAPHX_TIME_PASSES{}))
{
using milliseconds = std::chrono::duration<double, std::milli>;
auto ms = time<milliseconds>([&] { p.apply(*this); });
std::cout << p.name() << ": " << ms << "ms\n";
}
else
{
p.apply(*this);
}
trace(*mod);
validate_pass(*mod, p, *t);
}
};
/// Free-function accessor for the module held by a module_pass_manager.
module& get_module(module_pass_manager& mpm)
{
    module& held = mpm.get_module();
    return held;
}
/// Run each pass over `root_mod` and its submodules, then over the program.
///
/// For every pass the modules are visited in reverse of the order
/// [root_mod, submodules...], so `root_mod` is processed last. Modules that
/// report bypass() and modules already seen for the current pass are skipped.
/// Tracing goes to std::cout when MIGRAPHX_TRACE_PASSES is enabled.
void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& passes, tracer trace)
{
    if(enabled(MIGRAPHX_TRACE_PASSES{}))
        trace = tracer{std::cout};

    std::unordered_set<module_ref> seen;
    for(const auto& current : passes)
    {
        auto parent_tree = prog.get_module_tree();

        // Root first, followed by its submodules.
        std::vector<module_ref> ordered{root_mod};
        const std::vector<module_ref> subs = root_mod->get_sub_modules();
        ordered.insert(ordered.end(), subs.begin(), subs.end());

        seen.clear();
        for(auto it = ordered.rbegin(); it != ordered.rend(); ++it)
        {
            module_ref mod = *it;
            if(mod->bypass())
                continue;
            const bool first_visit = seen.insert(mod).second;
            if(not first_visit)
                continue;

            module_pm mpm{mod, root_mod, &trace};
            mpm.prog = &prog;

            auto parents            = range(parent_tree.equal_range(mod));
            const auto parent_count = distance(parents);
            if(parent_count == 1)
            {
                mpm.common_parent = parents.begin()->second;
            }
            else if(parent_count == 0)
            {
                mpm.common_parent = nullptr;
            }
            else
            {
                // Just use the main module when there are multiple parents for now
                // TODO: Compute the common parent
                mpm.common_parent = prog.get_main_module();
            }
            mpm.run_pass(current);
        }
        run_pass(prog, current, trace);
    }
}
/// Run each pass on a single module, which also acts as its own root module.
/// Tracing goes to std::cout when MIGRAPHX_TRACE_PASSES is enabled.
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
    if(enabled(MIGRAPHX_TRACE_PASSES{}))
        trace = tracer{std::cout};

    for(const auto& current : passes)
    {
        // A fresh manager is created for every pass.
        module_pm manager{&mod, &mod, &trace};
        manager.run_pass(current);
    }
}
/// Run the passes over the program, using its main module as the root.
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
    auto* main_module = prog.get_main_module();
    run_passes(prog, main_module, passes, trace);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,88 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <map>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Build a shape with the same type as `s` whose lens and strides have been
/// passed through reorder_dims() with `permutation`.
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
    auto new_lens    = reorder_dims(s.lens(), permutation);
    auto new_strides = reorder_dims(s.strides(), permutation);
    return {s.type(), new_lens, new_strides};
}
/// Invert a permutation by asking sort_permutation() for the ordering that
/// sorts it ascending.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    const std::less<> ascending{};
    return sort_permutation(permutation, ascending);
}
std::vector<int64_t> find_permutation(const shape& s)
{
std::vector<std::int64_t> result(s.lens().size());
std::iota(result.begin(), result.end(), 0);
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(s.strides()[x], s.lens()[x]);
}));
return result;
}
/// Pick the permutation shared by the most non-broadcasted shapes.
///
/// Returns an empty vector for empty input. When every shape is broadcasted
/// the identity permutation for the first shape's rank is returned. Ties go
/// to the permutation that compares smallest (first in the ordered map).
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};

    // Tally how often each permutation occurs, ignoring broadcasted shapes.
    std::map<std::vector<int64_t>, std::size_t> votes;
    for(const auto& s : shapes)
    {
        if(not s.broadcasted())
            ++votes[find_permutation(s)];
    }

    if(votes.empty())
    {
        std::vector<int64_t> identity(shapes.front().lens().size());
        std::iota(identity.begin(), identity.end(), 0);
        return identity;
    }

    const auto winner =
        std::max_element(votes.begin(), votes.end(), [](const auto& lhs, const auto& rhs) {
            return lhs.second < rhs.second;
        });
    assert(winner != votes.end());
    return winner->first;
}
std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
{
auto result = shapes;
auto perm = find_permutation(shapes);
std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
return reorder_shape(s, perm);
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,52 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/preallocate_param.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Replace the "@param" instruction whose parameter name equals `param` with
// the instruction produced by `model.preallocate(...)`.
// (`param` and `model` are not declared here; presumably members of the pass.)
void preallocate_param::apply(module& m) const
{
// Remember the last instruction of the module as it is on entry; everything
// that ends up after it was moved there by the loop below.
// NOTE(review): assumes the module is not empty, since std::prev(m.end()) is
// taken unconditionally.
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "@param")
continue;
if(param != any_cast<builtin::param>(ins->get_operator()).parameter)
continue;
// The preallocation id is "<module name>:<parameter name>".
std::string id = m.name() + ":" + param;
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
m.replace_instruction(ins, r);
// Park the old parameter at the end of the module so it can be removed in
// bulk after the loop instead of being erased during iteration.
m.move_instruction(ins, m.end());
}
// Drop every instruction that now sits after the original last instruction.
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,461 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/env.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/process.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/fileutils.hpp>
#include <algorithm>
#include <numeric>
#include <functional>
#include <iostream>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <cstring>
#include <sstream>
#include <optional>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE)
#ifndef _WIN32
/// Build a callback that streams each C string it receives into `os`.
/// The callback refers to `os` and must not outlive it.
std::function<void(const char*)> redirect_to(std::ostream& os)
{
    std::ostream* sink = &os;
    return [sink](const char* text) { *sink << text; };
}
/// Run `cmd` through popen() in mode `type` ("r" or "w"), hand the stream to
/// `f`, then close it and return the child's exit status.
///
/// Returns the exit code when the child exits normally and 128 + signal
/// number when it is killed by a signal, so a crashed child is no longer
/// reported as success. Any other pclose() result yields 0, as before.
/// Throws when popen() fails.
template <class F>
int exec(const std::string& cmd, const char* type, F f)
{
    int ec = 0;
    if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
        std::cout << cmd << std::endl;
    auto closer = [&](FILE* stream) {
        auto status = pclose(stream);
        if(WIFEXITED(status)) // NOLINT
            ec = WEXITSTATUS(status); // NOLINT
        else if(WIFSIGNALED(status)) // NOLINT
            ec = 128 + WTERMSIG(status); // NOLINT
        else
            ec = 0;
    };
    {
        // TODO: Use execve instead of popen
        // The pipe is closed, and `ec` set, when this scope ends.
        std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), type), closer); // NOLINT
        if(not pipe)
            MIGRAPHX_THROW("popen() failed: " + cmd);
        f(pipe.get());
    }
    return ec;
}
/// Run `cmd` and pass its standard output to `std_out` in chunks of at most
/// 127 characters. Returns the command's exit status.
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
    auto drain = [&std_out](FILE* stream) {
        std::array<char, 128> chunk;
        for(;;)
        {
            if(fgets(chunk.data(), chunk.size(), stream) == nullptr)
                break;
            std_out(chunk.data());
        }
    };
    return exec(cmd, "r", drain);
}
/// Run `cmd` and let `std_in` feed its standard input: `std_in` receives a
/// writer that copies a byte range into the child's stdin. Returns the
/// command's exit status.
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
{
    auto feed = [&std_in](FILE* stream) {
        auto to_child = [stream](const char* data, std::size_t size) {
            std::fwrite(data, 1, size, stream);
        };
        std_in(to_child);
    };
    return exec(cmd, "w", feed);
}
#else
// Size of the buffer used when copying the child's output (see the
// HANDLE-based exec overload).
constexpr std::size_t MIGRAPHX_PROCESS_BUFSIZE = 4096;
// Selects which end of a pipe stays private to the parent process:
// `input` feeds the child's stdin, `output` collects its stdout/stderr.
enum class direction
{
input,
output
};
/// Owner of both ends of an anonymous Win32 pipe.
///
/// Both handles are created inheritable; the end the parent keeps (the read
/// end for an output pipe, the write end for an input pipe) is then marked
/// non-inheritable. Handles still open are closed by the destructor.
/// Construction failures are reported by throwing the GetLastError() code
/// as a DWORD.
template <direction dir>
class pipe
{
    public:
    explicit pipe()
    {
        SECURITY_ATTRIBUTES attrs;
        attrs.nLength              = sizeof(SECURITY_ATTRIBUTES);
        attrs.bInheritHandle       = TRUE;
        attrs.lpSecurityDescriptor = nullptr;

        if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE)
            throw GetLastError();

        // Do not inherit the read handle for the output pipe, nor the write
        // handle for the input pipe.
        HANDLE parent_end = (dir == direction::output) ? m_read : m_write;
        if(SetHandleInformation(parent_end, HANDLE_FLAG_INHERIT, 0) == 0)
        {
            // The destructor does not run when the constructor throws, so
            // release the handles here. Read the error code first because
            // CloseHandle may overwrite it.
            const DWORD error = GetLastError();
            close_write_handle();
            close_read_handle();
            throw error;
        }
    }

    pipe(const pipe&) = delete;
    pipe& operator=(const pipe&) = delete;

    /// Transfer ownership; the moved-from pipe keeps no handles, so each
    /// handle is closed exactly once.
    pipe(pipe&& other) noexcept : m_write(other.m_write), m_read(other.m_read)
    {
        other.m_write = nullptr;
        other.m_read  = nullptr;
    }

    ~pipe()
    {
        close_write_handle();
        close_read_handle();
    }

    /// Close the write end if it is open; returns false if CloseHandle failed.
    bool close_write_handle()
    {
        auto result = true;
        if(m_write != nullptr)
        {
            result  = CloseHandle(m_write) == TRUE;
            m_write = nullptr;
        }
        return result;
    }

    /// Close the read end if it is open; returns false if CloseHandle failed.
    bool close_read_handle()
    {
        auto result = true;
        if(m_read != nullptr)
        {
            result = CloseHandle(m_read) == TRUE;
            m_read = nullptr;
        }
        return result;
    }

    /// Read up to `length` bytes. The first member is true only when ReadFile
    /// failed with ERROR_MORE_DATA; the second is the number of bytes read.
    std::pair<bool, DWORD> read(LPVOID buffer, DWORD length) const
    {
        DWORD bytes_read = 0;
        if(ReadFile(m_read, buffer, length, &bytes_read, nullptr) == FALSE and
           GetLastError() == ERROR_MORE_DATA)
        {
            return {true, bytes_read};
        }
        return {false, bytes_read};
    }

    HANDLE get_read_handle() const { return m_read; }

    /// Write `length` bytes; returns true when WriteFile succeeds.
    bool write(LPCVOID buffer, DWORD length) const
    {
        DWORD bytes_written = 0;
        return WriteFile(m_write, buffer, length, &bytes_written, nullptr) == TRUE;
    }

    HANDLE get_write_handle() const { return m_write; }

    private:
    HANDLE m_write = nullptr, m_read = nullptr;
};
// Spawn `cmd` with CreateProcess, wiring the child's stdin to an input pipe
// and its stdout/stderr to an output pipe, call `f(input, output)` to move
// data, then wait for the child and return its exit code.
// `envs` is split on whitespace into the environment block; an empty `envs`
// or `cwd` passes nullptr for that CreateProcess argument.
// Errors are reported through MIGRAPHX_THROW.
// clang-format off
template <typename F>
int exec(const std::string& cmd, const std::string& cwd, const std::string& args,
const std::string& envs, F f)
// clang-format on
{
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
{
std::cout << "[cwd=" << cwd << "]; cmd='" << cmd << "\'; args='" << args << "'; envs='"
<< envs << "'\n";
}
// See CreateProcess() WIN32 documentation for details.
constexpr std::size_t CMDLINE_LENGTH = 32767;
// Build lpCommandLine parameter.
std::string cmdline = quote_string(cmd);
if(not args.empty())
cmdline += " " + args;
// clang-format off
if(cmdline.size() > CMDLINE_LENGTH)
MIGRAPHX_THROW("Command line too long, required maximum " +
std::to_string(CMDLINE_LENGTH) + " characters.");
// clang-format on
// Pad the mutable command-line buffer with NULs up to the maximum length.
if(cmdline.size() < CMDLINE_LENGTH)
cmdline.resize(CMDLINE_LENGTH, '\0');
// Build lpEnvironment parameter.
std::vector<TCHAR> environment{};
if(not envs.empty())
{
std::istringstream iss{envs};
std::string str;
// Each whitespace-separated token becomes one NUL-terminated entry.
while(iss >> str)
{
environment.insert(environment.end(), str.begin(), str.end());
environment.push_back('\0');
}
// An extra NUL terminates the whole block.
environment.push_back('\0');
}
try
{
STARTUPINFO info;
PROCESS_INFORMATION process_info;
// The pipe constructors throw a DWORD error code, handled by the catch below.
pipe<direction::input> input{};
pipe<direction::output> output{};
ZeroMemory(&info, sizeof(STARTUPINFO));
info.cb = sizeof(STARTUPINFO);
// The child's stderr and stdout share the output pipe.
info.hStdError = output.get_write_handle();
info.hStdOutput = output.get_write_handle();
info.hStdInput = input.get_read_handle();
info.dwFlags |= STARTF_USESTDHANDLES;
ZeroMemory(&process_info, sizeof(process_info));
if(CreateProcess(cmd.c_str(),
cmdline.data(),
nullptr,
nullptr,
TRUE,
0,
environment.empty() ? nullptr : environment.data(),
cwd.empty() ? nullptr : static_cast<LPCSTR>(cwd.c_str()),
&info,
&process_info) == FALSE)
{
MIGRAPHX_THROW("Error creating process (" + std::to_string(GetLastError()) + ")");
}
CloseHandle(process_info.hThread);
// Close the parent's copies of the child-side pipe ends before exchanging data.
if(not output.close_write_handle())
MIGRAPHX_THROW("Error closing STDOUT handle for writing (" +
std::to_string(GetLastError()) + ")");
if(not input.close_read_handle())
MIGRAPHX_THROW("Error closing STDIN handle for reading (" +
std::to_string(GetLastError()) + ")");
f(input, output);
// Close the write end of the child's stdin once `f` has finished with it.
if(not input.close_write_handle())
MIGRAPHX_THROW("Error closing STDIN handle for writing (" +
std::to_string(GetLastError()) + ")");
WaitForSingleObject(process_info.hProcess, INFINITE);
DWORD status{};
GetExitCodeProcess(process_info.hProcess, &status);
CloseHandle(process_info.hProcess);
return static_cast<int>(status);
}
// cppcheck-suppress catchExceptionByValue
catch(DWORD error)
{
MIGRAPHX_THROW("Error spawning process (" + std::to_string(error) + ")");
}
}
// clang-format off
/// Run `cmd` and copy everything the child writes to stdout/stderr into the
/// `std_out` handle. Returns the child's exit code, or the current
/// GetLastError() value when `std_out` is not a usable handle.
///
/// The copy loop runs until a read yields no bytes (which covers the child
/// closing its end of the pipe) or a write fails. ERROR_MORE_DATA is not
/// expected from the pipe created here, so the loop must not stop after the
/// first chunk; doing so would cut the output off at one buffer.
int exec(const std::string& cmd, const std::string& cwd, const std::string& args,
         const std::string& envs, HANDLE std_out)
{
    if(std_out == nullptr or std_out == INVALID_HANDLE_VALUE)
        return static_cast<int>(GetLastError());

    TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE];
    return exec(cmd, cwd, args, envs,
        [&](const pipe<direction::input>&, const pipe<direction::output>& out) {
            for(;;)
            {
                const DWORD bytes_read = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE).second;
                if(bytes_read == 0)
                    break;
                // lpNumberOfBytesWritten must not be null when lpOverlapped is null.
                DWORD bytes_written = 0;
                if(WriteFile(std_out, buffer, bytes_read, &bytes_written, nullptr) == FALSE)
                    break;
            }
        });
}
int exec(const std::string& cmd, const std::string& cwd, const std::string& args,
const std::string& envs, std::function<void(process::writer)> std_in)
{
return exec(cmd, cwd, args, envs,
[&](const pipe<direction::input>& input, const pipe<direction::output>&) {
std_in([&](const char* buffer, std::size_t n) { input.write(buffer, n); });
});
}
// clang-format on
#endif
/// State behind migraphx::process: the command, its arguments, environment
/// assignments and working directory.
struct process_impl
{
    std::string args{};
    std::string envs{};
    std::string command{};
    fs::path cwd{};

    /// Compose the command line as "cd <cwd>; <envs> <command> <args>",
    /// leaving out any part that is empty.
    std::string get_command() const
    {
        std::string line;
        if(not cwd.empty())
        {
            line.append("cd ");
            line.append(cwd.string());
            line.append("; ");
        }
        if(not envs.empty())
        {
            line.append(envs);
            line.append(" ");
        }
        line.append(command);
        if(not args.empty())
        {
            line.append(" ");
            line.append(args);
        }
        return line;
    }

    /// Forward the arguments to migraphx::exec and throw if the returned
    /// status is not zero.
    template <class... Ts>
    void check_exec(Ts&&... xs) const
    {
        const int status = migraphx::exec(std::forward<Ts>(xs)...);
        if(status == 0)
            return;
        MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
                       std::to_string(status));
    }
};
/// Create a process description for `cmd`; `args` are joined with single
/// spaces into the argument string.
process::process(const std::string& cmd, const std::vector<std::string>& args)
    : impl(std::make_unique<process_impl>())
{
    impl->command = cmd;
    if(args.empty())
        return;
    impl->args = join_strings(args, " ");
}
/// Move construction transfers the implementation object.
process::process(process&&) noexcept = default;

/// Assignment takes `rhs` by value and swaps implementations with it; the
/// previous implementation is released when `rhs` goes out of scope.
process& process::operator=(process rhs)
{
    impl.swap(rhs.impl);
    return *this;
}

/// Defined here, where process_impl is a complete type.
process::~process() noexcept = default;
/// Set the working directory used when the command is run; returns *this so
/// calls can be chained.
process& process::cwd(const fs::path& p)
{
    auto& state = *impl;
    state.cwd   = p;
    return *this;
}
/// Set the environment assignments, joined with single spaces. An empty
/// list leaves the current value unchanged. Returns *this for chaining.
process& process::env(const std::vector<std::string>& envs)
{
    if(envs.empty())
        return *this;
    impl->envs = join_strings(envs, " ");
    return *this;
}
/// Run the command and pass everything it wrote to standard output to
/// `output` in a single call.
///
/// On Windows the output is captured in a file named "stdout" inside a
/// temporary directory and read back afterwards; if that file cannot be
/// created, the child writes to the parent's standard output and the
/// read-back step throws. Elsewhere the output is collected in a
/// std::stringstream. Throws when the command exits with a non-zero status.
void process::read(const writer& output) const
{
#ifdef _WIN32
    // clang-format off
    constexpr std::string_view filename = "stdout";
    auto tmp = tmp_dir{};
    HANDLE handle = CreateFile((tmp.path / filename).string().c_str(),
                               GENERIC_READ | GENERIC_WRITE,
                               0,
                               nullptr,
                               CREATE_ALWAYS,
                               FILE_ATTRIBUTE_NORMAL,
                               nullptr);
    const bool capture_ok = handle != nullptr and handle != INVALID_HANDLE_VALUE;
    impl->check_exec(impl->command, impl->cwd.string(), impl->args, impl->envs,
                     capture_ok ? handle : GetStdHandle(STD_OUTPUT_HANDLE));
    // Only close a handle that was actually opened.
    if(capture_ok)
        CloseHandle(handle);
    handle = CreateFile((tmp.path / filename).string().c_str(),
                        GENERIC_READ | GENERIC_WRITE,
                        0,
                        nullptr,
                        OPEN_EXISTING,
                        FILE_ATTRIBUTE_NORMAL,
                        nullptr);
    if(handle == nullptr or handle == INVALID_HANDLE_VALUE)
        MIGRAPHX_THROW("Unable to open file: " + (tmp.path / filename));
    const DWORD size = GetFileSize(handle, nullptr);
    if(size == INVALID_FILE_SIZE)
    {
        CloseHandle(handle);
        MIGRAPHX_THROW("Failed reading file: " + (tmp.path / filename));
    }
    std::string result(size, '\0');
    // lpNumberOfBytesRead must not be null when lpOverlapped is null.
    DWORD bytes_read = 0;
    if(ReadFile(handle, result.data(), size, &bytes_read, nullptr) == FALSE)
    {
        CloseHandle(handle);
        MIGRAPHX_THROW("Failed reading file: " + (tmp.path / filename));
    }
    CloseHandle(handle);
    // Keep only the bytes that were actually read.
    result.resize(bytes_read);
    // clang-format on
#else
    std::stringstream ss;
    impl->check_exec(impl->get_command(), redirect_to(ss));
    auto result = ss.str();
#endif
    output(result.data(), result.size());
}
/// Run the command, sending its output to the parent's standard output.
/// Throws when the command exits with a non-zero status.
void process::exec()
{
#ifdef _WIN32
    // clang-format off
    impl->check_exec(impl->command, impl->cwd.string(), impl->args, impl->envs,
                     GetStdHandle(STD_OUTPUT_HANDLE));
    // clang-format on
#else
    impl->check_exec(impl->get_command(), redirect_to(std::cout));
#endif
}
/// Run the command and let `pipe_in` supply its standard input through the
/// writer it is given. Throws when the command exits with a non-zero status.
void process::write(std::function<void(writer)> pipe_in)
{
#ifdef _WIN32
    // clang-format off
    impl->check_exec(impl->command, impl->cwd.string(),
                     impl->args, impl->envs, std::move(pipe_in));
    // clang-format on
#else
    impl->check_exec(impl->get_command(), std::move(pipe_in));
#endif
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,55 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// For a non-root module, add a copy of every "@literal" to the root module
/// and point the literal's users at that copy. The root module itself is
/// left unchanged.
void promote_literals::apply(module_pass_manager& mpm) const
{
    module& m              = mpm.get_module();
    module_ref root_module = mpm.get_root_module();
    if(m == *root_module)
        return;

    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "@literal")
            continue;

        auto promoted = root_module->add_literal(ins->get_literal());
        // Work on a copy of the user list, since replacing arguments below
        // may modify the literal's own output list.
        auto users = ins->outputs();
        for(auto user : users)
        {
            user->replace_argument(user, ins, promoted);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,138 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/propagate_constant.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/simple_par_for.hpp>
#include <migraphx/env.hpp>
#include <thread>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
/// Decide whether `ins` must be left out of constant propagation.
///
/// "contiguous", "dequantizelinear" and "reshape" defer to their first
/// input. "unpack_int4" is always skipped, as are broadcasted shapes whose
/// element space is smaller than their element count. An instruction that
/// aliases another defers to that alias; otherwise it is skipped only when
/// it is undefined.
bool skip_propagate(instruction_ref ins)
{
    if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
    {
        auto first_input = ins->inputs().front();
        return skip_propagate(first_input);
    }
    if(ins->name() == "unpack_int4")
    {
        return true;
    }

    auto&& out_shape = ins->get_shape();
    if(out_shape.broadcasted() and out_shape.element_space() < out_shape.elements())
    {
        return true;
    }

    auto aliased = instruction::get_output_alias(ins, true);
    if(aliased != ins)
    {
        return skip_propagate(aliased);
    }
    return ins->is_undefined();
}
/// True when `ins` can be evaluated, is not excluded by skip_propagate(),
/// and its name is not listed in `skip_ops`.
bool is_const_ins(instruction_ref ins, const std::unordered_set<std::string>& skip_ops)
{
    if(not ins->can_eval())
        return false;
    if(skip_propagate(ins))
        return false;
    return skip_ops.count(ins->name()) == 0;
}
/// Return `c` unchanged when its shape is packed; otherwise copy its
/// elements into a new argument whose shape is rebuilt from the same lens
/// via with_lens().
argument as_packed(const argument& c)
{
    const auto& current = c.get_shape();
    if(current.packed())
        return c;

    auto rebuilt = current.with_lens(current.lens());
    argument copy;
    c.visit([&](auto elements) {
        copy = literal{rebuilt, elements.begin(), elements.end()}.get_argument();
    });
    return copy;
}
/// Replace instructions that can be evaluated ahead of time with literals.
///
/// 1. Collect the constant instructions to fold: constant inputs (other
///    than "@literal") of non-constant instructions, plus the last
///    instruction itself when it is constant.
/// 2. Evaluate them with simple_par_for and convert the results with
///    as_packed().
/// 3. Replace each instruction that produced a non-empty result with a
///    literal holding it.
void propagate_constant::apply(module& m) const
{
    std::unordered_set<instruction_ref> const_instrs;
    auto last = std::prev(m.end());

    // Find instructions that can be evaluated to a literal
    for(auto i : iterator_for(m))
    {
        const bool is_const = is_const_ins(i, skip_ops);
        // A constant instruction is folded through its non-constant user,
        // except for the last instruction, which has no user in this module.
        if(is_const and i != last)
            continue;

        if(i == last and is_const)
        {
            const_instrs.insert(i);
        }
        else
        {
            std::copy_if(i->inputs().begin(),
                         i->inputs().end(),
                         std::inserter(const_instrs, const_instrs.begin()),
                         [&](const instruction_ref ins) {
                             return is_const_ins(ins, skip_ops) and ins->name() != "@literal";
                         });
        }
    }

    // Compute literals in parallel
    std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
    std::vector<argument> literals(const_instrs_vec.size());
    std::size_t grainsize = 1;
#if !MIGRAPHX_HAS_EXECUTORS
    // hardware_concurrency() may return 0 when the value cannot be
    // determined; clamp it to 1 so the division below is always defined.
    const std::size_t hw_threads = std::max<std::size_t>(std::thread::hardware_concurrency(), 1);
    std::size_t n                = std::max<std::size_t>(2048 / hw_threads, 1);
    grainsize                    = const_instrs_vec.size() / n;
#endif
    simple_par_for(const_instrs_vec.size(), grainsize, [&](const auto i) {
        literals[i] = as_packed(const_instrs_vec[i]->eval());
    });

    // Replace instructions in m
    for(size_t i = 0; i < const_instrs_vec.size(); i++)
    {
        if(not literals[i].empty())
        {
            if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
            {
                std::cout << "Constant replace: " << std::endl;
                // Gather the instruction and its transitive inputs, inputs
                // first, for printing.
                std::vector<instruction_ref> inss;
                fix([&](auto self, auto ins) {
                    if(contains(inss, ins))
                        return;
                    for(auto input : ins->inputs())
                        self(input);
                    inss.push_back(ins);
                })(const_instrs_vec[i]);
                m.debug_print(inss);
            }
            assert(literals[i].get_shape().lens() == const_instrs_vec[i]->get_shape().lens());
            assert(literals[i].get_shape().bytes() <= const_instrs_vec[i]->get_shape().bytes());
            auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
            m.replace_instruction(const_instrs_vec[i], l);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,211 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/truncate_float.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/normalize_ops.hpp>
#include <set>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_QUANTIZATION)
/// Tracer for the quantization passes: writes to std::cout when
/// MIGRAPHX_TRACE_QUANTIZATION is enabled, otherwise default-constructed.
tracer quant_tracer()
{
    return enabled(MIGRAPHX_TRACE_QUANTIZATION{}) ? tracer{std::cout} : tracer{};
}
// Convert the instructions named in `ins_names` to float16 by running
// truncate_float_pass with shape::half_type.
// The conversion can overflow or underflow, although that is uncommon.
// optimize_module runs before the truncation so that constants are evaluated
// and folded in FP32, avoiding loss of precision, and again afterwards.
// Both optimize_module runs are given "quantizelinear" and "dequantizelinear".
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
    const std::vector<pass> pipeline = {
        normalize_ops{},
        optimize_module{{"quantizelinear", "dequantizelinear"}},
        truncate_float_pass{ins_names, shape::half_type},
        optimize_module{{"quantizelinear", "dequantizelinear"}},
    };
    run_passes(prog, pipeline, quant_tracer());
}
// Convert the instructions named in `ins_names` to bf16 by running
// truncate_float_pass with shape::bf16_type, with the same surrounding
// normalize_ops / optimize_module passes as quantize_fp16.
void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)
{
    const std::vector<pass> pipeline = {
        normalize_ops{},
        optimize_module{{"quantizelinear", "dequantizelinear"}},
        truncate_float_pass{ins_names, shape::bf16_type},
        optimize_module{{"quantizelinear", "dequantizelinear"}},
    };
    run_passes(prog, pipeline, quant_tracer());
}
// Quantizes the operators named in `ins_names` to the 8-bit type `precision`,
// using scales derived from calibration runs.
//
// Steps, in order:
//   1. normalize and optimize the program in its original precision;
//   2. insert capture ops in front of the inputs of the selected instructions
//      (capture_arguments_pass); each capture forwards the value it sees to
//      calc_quant_params;
//   3. compile a copy of the program for target `t` and evaluate it once per
//      calibration entry, accumulating the maximum absolute value per capture;
//   4. rewrite the program with the resulting scales (quantize_8bits_pass) and
//      run dead code elimination.
//
// prog        - program to quantize in place
// t           - target used to compile and run the calibration copy
// precision   - int8_type, fp8e4m3fnuz_type or fp8e4m3fn_type; any other value
//               makes the range lookup (std::map::at) throw std::out_of_range
// calibration - parameter maps to run; a program parameter missing from an
//               entry is replaced by a buffer from t.allocate()
// ins_names   - names of the operators whose inputs are captured
void quantize_8bits(program& prog,
                    const target& t,
                    shape::type_t precision,
                    const std::vector<parameter_map>& calibration,
                    const std::unordered_set<std::string>& ins_names)
{
    // Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
    // avoid loss of precision.
    run_passes(prog, {normalize_ops{}, optimize_module{}}, quant_tracer());

    std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
        std::make_shared<std::vector<std::pair<float, float>>>();
    std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();

    // Largest magnitude representable by each supported 8-bit type.
    std::map<shape::type_t, float> type_ranges = {{shape::type_t::int8_type, 127.0},
                                                  {shape::type_t::fp8e4m3fnuz_type, 240.0},
                                                  {shape::type_t::fp8e4m3fn_type, 448.0}};
    float quantized_range = type_ranges.at(precision);

    // Callback invoked by each capture op: updates the running max-abs value of
    // capture `ins_index` and derives its scale from it.
    auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
        std::pair<float, float> param_pair{64.0f, 0.0f};
        // Only the scale is calibrated; the shift is not considered and stays 0.
        std::vector<float> vec_val;
        argument arg = t.copy_from(args.front());
        arg.visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
        // An argument with no elements carries no range information, and
        // dereferencing max_element/min_element on an empty range is undefined.
        // Keep the parameters recorded so far.
        if(vec_val.empty())
            return;
        auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
        auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
        auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
        max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
        // if all values are 0, no need to do scaling
        if(float_equal(max_abs_vals->at(ins_index), 0.0f))
        {
            param_pair.first = 1.0f;
        }
        else
        {
            param_pair.first = quantized_range / max_abs_vals->at(ins_index);
        }
        quant_8bit_params->at(ins_index) = param_pair;
    };

    // pass to add capture argument op
    std::size_t param_num = 0;
    run_passes(
        prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}}, quant_tracer());
    // param_num now holds the number of captures inserted; size the tables to match.
    quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
    max_abs_vals->resize(param_num, 0.0f);

    // use the calibration data to compute the quantization scale
    auto capture_prog = prog;
    capture_prog.compile(t);

    // use all calibration data to run the program to calculate the
    // quantization scale and shift
    for(auto&& arg : calibration)
    {
        parameter_map m;
        for(auto&& x : capture_prog.get_parameter_shapes())
        {
            if(arg.count(x.first) > 0)
            {
                assert(x.second == arg.at(x.first).get_shape());
                m[x.first] = t.copy_to(arg.at(x.first));
            }
            else
            {
                m[x.first] = t.allocate(x.second);
            }
        }
        capture_prog.eval(m);
    }

    // print the quantization parameters in only the main module
    if(enabled(MIGRAPHX_8BITS_QUANTIZATION_PARAMS{}))
    {
        for(std::size_t i = 0; i < quant_8bit_params->size(); ++i)
        {
            auto param = quant_8bit_params->at(i);
            std::cout << "ins_index = " << i << ", scale = " << param.first
                      << ", shift = " << param.second << std::endl;
        }
        std::cout << std::endl;
    }

    run_passes(prog,
               {quantize_8bits_pass{precision, *quant_8bit_params}, dead_code_elimination{}},
               quant_tracer());
}
// Quantizes `prog` to int8 using the given calibration data.
//
// prog        - program to quantize in place
// t           - target used for the calibration runs
// calibration - parameter maps evaluated to derive the quantization scales
// ins_names   - operators to quantize; every name must be "dot" or
//               "convolution", otherwise MIGRAPHX_THROW is raised
//
// The check is a membership test rather than an exact set comparison, so a
// subset such as {"dot"} is accepted; the full {"convolution", "dot"} set
// behaves as before.
void quantize_int8(program& prog,
                   const target& t,
                   const std::vector<parameter_map>& calibration,
                   const std::unordered_set<std::string>& ins_names)
{
    const std::unordered_set<std::string> op_names = {"convolution", "dot"};
    for(const auto& name : ins_names)
    {
        if(op_names.count(name) == 0)
        {
            MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
        }
    }
    quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
}
// Runs normalize_ops, optimize_module and quantize_int4_pass, in that order,
// over `prog`, tracing through quant_tracer().
void quantize_int4_weights(program& prog)
{
    auto trace = quant_tracer();
    run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}}, trace);
}
// Quantizes `prog` to fp8 (shape::fp8e4m3fn_type) using the given calibration
// data. The set of operators to quantize is collected from the main module:
// every instruction name except "convert" and names starting with "@".
//
// prog        - program to quantize in place
// t           - target used for the calibration runs
// calibration - parameter maps evaluated to derive the quantization scales
void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
{
    std::unordered_set<std::string> fp8_op_names;
    auto* main_mod = prog.get_main_module();
    for(auto ins : iterator_for(*main_mod))
    {
        const auto& op_name = ins->name();
        // Skip converts and "@"-prefixed instructions.
        if(op_name == "convert" or starts_with(op_name, "@"))
            continue;
        fp8_op_names.insert(op_name);
    }
    quantize_8bits(prog, t, shape::fp8e4m3fn_type, calibration, fp8_op_names);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,117 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantable_types = {
shape::float_type, shape::double_type, shape::half_type};
return quantable_types;
}
// Replaces every "capture" instruction whose input has a quantizable type
// (and is not already `precision`) with a quantizelinear -> dequantizelinear
// pair. The scale is the reciprocal of the calibrated parameter
// quant_params[ins_index].first and the zero point is .second; both are
// multibroadcast to the input's dimensions. Captures that do not qualify are
// left untouched.
void quantize_8bits_pass::apply(module& m) const // NOLINT
{
    const auto& quantizable_types = get_quantizable_type();
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "capture")
            continue;

        // The capture op stores the index of its calibration slot.
        auto op_val = ins->get_operator().to_value();
        assert(op_val.contains("ins_index"));
        auto slot        = op_val.at("ins_index").to<std::size_t>();
        auto scale_shift = quant_params[slot];

        auto input = ins->inputs().front();
        auto s     = input->get_shape();
        const bool should_quantize =
            contains(quantizable_types, s.type()) and s.type() != precision;
        if(not should_quantize)
            continue;

        auto zero_point =
            m.add_literal(migraphx::literal{migraphx::shape{precision}, {scale_shift.second}});
        auto scale       = m.add_literal(literal({s.type()}, {1.0f / scale_shift.first}));
        const auto& lens = s.lens();
        const auto broadcast = make_op("multibroadcast", {{"out_lens", lens}});
        scale                = m.insert_instruction(ins, broadcast, scale);
        zero_point           = m.insert_instruction(ins, broadcast, zero_point);

        auto quantized =
            m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
        auto dequantized =
            m.insert_instruction(ins, make_op("dequantizelinear"), quantized, scale, zero_point);
        m.replace_instruction(ins, dequantized);
    }
}
// For every instruction whose name is in `ins_names` (except "convert"),
// inserts an op::capture in front of each input of a quantizable type and
// rewires the instruction to read from the captures. Each capture receives the
// next value of the shared counter `*param_index` (post-incremented) and the
// callback `f`; inputs of other types are passed through unchanged.
void capture_arguments_pass::apply(module& m) const // NOLINT
{
    assert(param_index != nullptr);
    const auto& quantizable_types = get_quantizable_type();
    for(auto ins : iterator_for(m))
    {
        const auto& op_name = ins->name();
        if(op_name == "convert" or not contains(ins_names, op_name))
            continue;

        const auto original_inputs = ins->inputs();
        std::vector<instruction_ref> captured_inputs;
        for(auto input : original_inputs)
        {
            const bool quantizable = contains(quantizable_types, input->get_shape().type());
            captured_inputs.push_back(
                quantizable ? m.insert_instruction(ins, op::capture{(*param_index)++, f}, input)
                            : input);
        }
        m.replace_instruction(ins, ins->get_operator(), captured_inputs);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,108 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Rewrites every "dot" and "convolution" instruction in `m` so that each
// eligible input goes through
//   quantizelinear -> pack_int4 -> unpack_int4 -> dequantizelinear
// before reaching the instruction. An input is eligible when its shape is not
// broadcasted, its type is float or half, its last dimension is even, and it
// can be evaluated at compile time (can_eval()). Other inputs are passed
// through unchanged.
static void int4_quantize_module(module& m)
{
std::vector<std::string> int4_instrs{"dot", "convolution"};
for(auto ins : iterator_for(m))
{
if(not(contains(int4_instrs, ins->name())))
continue;
if(ins->inputs().empty())
continue;
// NOTE(review): `s` is the shape of the dot/convolution result; its type is
// used below as the type of the scale literals - confirm this is intended
// rather than the type of the individual input.
auto s = ins->get_shape();
auto mod_inputs = ins->module_inputs();
// Convert each of the inputs that are fp32 or fp16 to int4
auto inputs = ins->inputs();
// The lambda inserts instructions into `m` as a side effect and returns the
// instruction that should replace the input.
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto inp) {
auto sh = inp->get_shape();
if(sh.broadcasted())
return inp;
auto input_type = sh.type();
if(input_type != shape::float_type and input_type != shape::half_type)
return inp;
auto lens = sh.lens();
if(lens[lens.size() - 1] % 2)
return inp; // even sized dimensions to pack
if(not inp->can_eval())
return inp;
// Evaluate the constant input and copy its values out as float.
std::vector<float> val;
inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); });
auto [min, max] = std::minmax_element(val.begin(), val.end());
// Widen the observed range so that it always contains zero (this writes
// through the iterators into `val`).
*min = *min > 0 ? 0 : *min;
*max = *max < 0 ? 0 : *max;
float fscale4 = (*max - *min) / 15; // INT4 range is [0-15]
// Zero point is the quantized position of 0; use 0 when the scale is zero.
int zp4 = float_equal(fscale4, 0) ? 0 : std::round(-*min / fscale4);
auto scale = m.add_literal(literal({s.type()}, {fscale4}));
scale =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
auto zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
zp = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), zp);
auto q_in = m.insert_instruction(ins, make_op("quantizelinear"), inp, scale, zp);
// Pack and immediately unpack along the last axis.
auto pk = m.insert_instruction(ins, make_op("pack_int4", {{"axis", -1}}), q_in);
auto unpk = m.insert_instruction(ins, make_op("unpack_int4", {{"axis", -1}}), pk);
// The dequantize side gets its own scale / zero-point literals with the same
// values as the quantize side.
auto dq_scale = m.add_literal(literal({s.type()}, {fscale4}));
dq_scale = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
auto dq_zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
dq_zp =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_zp);
return m.insert_instruction(ins, make_op("dequantizelinear"), unpk, dq_scale, dq_zp);
});
// Re-create the instruction with the rewritten inputs and swap it in.
auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
m.replace_instruction(ins, converted_ins);
}
}
// Pass entry point: applies the int4 weight rewrite to module `m`.
void quantize_int4_pass::apply(module& m) const
{
    int4_quantize_module(m);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,153 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
{
std::vector<std::size_t> new_lens;
for(const auto& s : shapes)
{
assert(n < s.lens().size());
if((n + 1) >= s.lens().size())
return false;
auto astride = s.strides()[n];
auto alen = s.lens()[n];
auto bstride = s.strides()[n + 1];
auto blen = s.lens()[n + 1];
if(astride == bstride * blen or alen == 1)
new_lens.push_back(alen * blen);
}
if(new_lens.size() != shapes.size())
return false;
std::size_t i = 0;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.erase(lens.begin() + n);
strides.erase(strides.begin() + n);
lens[n] = new_lens[i];
s = shape{s.type(), lens, strides};
i++;
}
return true;
}
// Drops the last dimension from every shape, but only when every shape has at
// least two dimensions and a last dimension of length 1. Otherwise the shapes
// are left unchanged.
void reduce_dim1(std::vector<shape>& shapes)
{
    const bool all_end_in_one = std::all_of(shapes.begin(), shapes.end(), [](const auto& s) {
        return s.lens().size() >= 2 and s.lens().back() == 1;
    });
    if(not all_end_in_one)
        return;

    for(auto& s : shapes)
    {
        auto trimmed_lens    = s.lens();
        auto trimmed_strides = s.strides();
        trimmed_lens.pop_back();
        trimmed_strides.pop_back();
        s = shape{s.type(), trimmed_lens, trimmed_strides};
    }
}
// Repeatedly merges dimension `n` with its successor (reduce_dim) and returns
// the next dimension index to process, n + 1.
//
// NOTE(review): the loop also stops once n >= shapes.size(), i.e. it compares
// the dimension index against the number of shapes. That looks unintended
// (reduce_dim already returns false when n is the last dimension) and limits
// the merge to a single step for such n - confirm before changing.
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
    for(;;)
    {
        if(not reduce_dim(shapes, n))
            break;
        if(n >= shapes.size())
            break;
    }
    return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
{
std::size_t n = 0;
while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
}
// Computes the element-wise maximum of the dimension lengths of all shapes.
// When two length vectors differ in size, only the leading entries common to
// both are combined, so the result is as long as the shortest one.
// `shapes` must not be empty.
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
{
    std::vector<std::size_t> acc = shapes.front().lens();
    for(auto it = shapes.begin() + 1; it != shapes.end(); ++it)
    {
        const auto& cur     = it->lens();
        const bool cur_long = cur.size() > acc.size();
        const auto& shorter = cur_long ? acc : cur;
        const auto& longer  = cur_long ? cur : acc;

        std::vector<std::size_t> merged;
        std::transform(shorter.begin(),
                       shorter.end(),
                       longer.begin(),
                       std::back_inserter(merged),
                       [](auto a, auto b) { return std::max(a, b); });
        acc.swap(merged);
    }
    return acc;
}
// Builds a shape with lengths `lens` and strides describing which dimensions
// of `s` match `lens`. Walking from the last dimension to the first:
//   - where lens[i] == s.lens()[i] the stride is the running product of the
//     matching lengths seen so far;
//   - where they differ and one of them is 1 the stride stays 0;
//   - where they differ and neither is 1 a default-constructed shape is
//     returned.
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
{
    assert(s.lens().size() == lens.size());
    std::vector<std::size_t> mask_strides(lens.size());
    std::size_t running = 1;
    for(std::size_t remaining = lens.size(); remaining > 0; --remaining)
    {
        const std::size_t i = remaining - 1;
        const auto own_len  = s.lens()[i];
        if(lens[i] == own_len)
        {
            mask_strides[i] = running;
            running *= lens[i];
            continue;
        }
        if(lens[i] != 1 and own_len != 1)
            return shape{};
    }
    return shape{s.type(), lens, mask_strides};
}
std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
{
if(shapes.empty())
return {};
auto result = shapes;
auto base = base_lens(shapes);
for(auto&& s : shapes)
{
if(s.lens().size() != base.size())
return shapes;
if(s.lens() == base)
continue;
auto mshape = mask_shape(s, base);
if(mshape.lens().size() != base.size())
return shapes;
result.push_back(mshape);
}
reduce_dim_all(result);
result.erase(result.begin() + shapes.size(), result.end());
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,65 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/register_op.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<std::string, operation>& op_map()
{
static std::unordered_map<std::string, operation> m; // NOLINT
return m;
}
// Touches op_map() so that the registry's static storage is constructed.
void register_op_init()
{
    (void)op_map();
}
// Stores `op` in the registry under op.name(), replacing any previous entry
// with the same name.
void register_op(const operation& op)
{
    auto& registry      = op_map();
    registry[op.name()] = op;
}
// Removes the operator called `op_name` from the registry. The operator is
// expected to be registered (checked with assert only).
void unregister_op(const std::string& op_name)
{
    auto& registry = op_map();
    assert(registry.count(op_name));
    registry.erase(op_name);
}
// Looks up the operator called `name` in the registry through at(), passing
// "Operator not found: <name>" as the error text for a missing entry.
operation load_op(const std::string& name)
{
    const std::string not_found_msg = "Operator not found: " + name;
    return at(op_map(), name, not_found_msg);
}
// Returns true when an operator called `name` is registered.
bool has_op(const std::string& name)
{
    const auto& registry = op_map();
    return registry.find(name) != registry.end();
}
// Returns the names of all registered operators, sorted ascending.
std::vector<std::string> get_operators()
{
    std::vector<std::string> names;
    for(const auto& entry : op_map())
    {
        names.push_back(entry.first);
    }
    std::sort(names.begin(), names.end());
    return names;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,101 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <string>
#include <unordered_map>
#include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/fileutils.hpp>
#include <migraphx/version.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void store_target_lib(const dynamic_loader& lib)
{
static std::vector<dynamic_loader> target_loader;
target_loader.emplace_back(lib);
}
std::unordered_map<std::string, target>& target_map()
{
static std::unordered_map<std::string, target> m; // NOLINT
return m;
}
// Touches target_map() so that the registry's static storage is constructed.
void register_target_init()
{
    (void)target_map();
}
// Removes the target called `name` from the registry. The target is expected
// to be registered (checked with assert only).
void unregister_target(const std::string& name)
{
    auto& registry = target_map();
    assert(registry.count(name));
    registry.erase(name);
}
// Stores `t` in the registry under t.name(), replacing any previous entry with
// the same name.
void register_target(const target& t)
{
    auto& registry     = target_map();
    registry[t.name()] = t;
}
// Returns the target registered as `name`.
//
// If the target is not registered yet, the shared library "migraphx_<name>"
// (decorated by make_shared_object_filename) is loaded first: the file name
// with "." + MIGRAPHX_SO_MAJOR_VERSION appended is tried, and if that fails
// for any reason the unversioned file name is loaded instead. For example, if
// "libmigraphx_ref.so.2010000" is not found, "libmigraphx_ref.so" is tried.
// A failure of the second load propagates to the caller.
//
// Throws via MIGRAPHX_THROW when the target is still not registered after the
// load attempt.
target make_target(const std::string& name)
{
    auto& registry = target_map();
    if(not contains(registry, name))
    {
        std::string version_suffix = "." + std::to_string(MIGRAPHX_SO_MAJOR_VERSION);
        auto lib_name              = make_shared_object_filename("migraphx_" + name);
        try
        {
            // Preferred: library name carrying the so major version.
            store_target_lib(dynamic_loader(lib_name + version_suffix));
        }
        catch(...)
        {
            // Fallback: library name without the so major version.
            store_target_lib(dynamic_loader(lib_name));
        }
    }
    const auto found = registry.find(name);
    if(found == registry.end())
    {
        MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported");
    }
    return found->second;
}
// Returns the names of all registered targets, sorted ascending.
std::vector<std::string> get_targets()
{
    std::vector<std::string> names;
    for(const auto& entry : target_map())
    {
        names.push_back(entry.first);
    }
    std::sort(names.begin(), names.end());
    return names;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,123 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pass_manager.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/allocate.hpp>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Maps the instructions that alias the outputs of `mod` to the parameter names
// used for them.
//
// When the last instruction is "@return", the alias of its i-th input is named
// "<module name>:#output_<i>"; if two inputs share an alias the higher index
// wins. Otherwise the alias of the last instruction is named "output".
// `mod` must contain at least one instruction.
std::unordered_map<instruction_ref, std::string> create_output_names(const module& mod)
{
    std::unordered_map<instruction_ref, std::string> names{};
    auto last = std::prev(mod.end());
    if(last->name() != "@return")
    {
        names[instruction::get_output_alias(last)] = "output";
        return names;
    }

    std::size_t index = 0;
    for(const auto& out : last->inputs())
    {
        auto alias   = instruction::get_output_alias(out);
        names[alias] = mod.name() + ":#output_" + std::to_string(index);
        ++index;
    }
    return names;
}
// Appends one allocation per distinct submodule parameter to the inputs of
// `ins`.
//
// The parameter shapes of all module inputs of `ins` are gathered into a
// std::map keyed by parameter name (the first shape seen for a name is kept),
// so the allocations are inserted in name order. Each allocation is created
// with model.allocate() and inserted before `ins`, which is then replaced by
// the same operator with the extended input list.
void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model)
{
    std::vector<instruction_ref> args     = ins->inputs();
    std::vector<module_ref> submodules    = ins->module_inputs();
    std::map<std::string, shape> param_shapes;
    for(const auto& smod : submodules)
    {
        auto shapes = smod->get_parameter_shapes();
        param_shapes.insert(shapes.begin(), shapes.end());
    }

    for(const auto& entry : param_shapes)
    {
        args.push_back(mod.insert_instruction(ins, model.allocate(entry.second)));
    }
    mod.replace_instruction(ins, ins->get_operator(), args, submodules);
}
// Replaces the generic "allocate" instructions of the current module.
//
//   - "if" instructions get allocations for their submodule parameters
//     appended (insert_submod_allocations).
//   - An "allocate" that aliases a module output becomes a module parameter
//     named by create_output_names(), provided the allocation model needs
//     output parameters and offload_copy does not apply. offload_copy is only
//     honoured for the root module.
//   - Every other "allocate" is replaced by the allocation model's own
//     operator, built from the instruction's shape.
void replace_allocate::apply(module_pass_manager& mpm) const
{
    module& m               = mpm.get_module();
    auto output_names       = create_output_names(m);
    const bool is_root      = (*mpm.get_root_module() == m);
    bool root_offload_copy  = is_root and this->offload_copy;
    for(auto ins : iterator_for(m))
    {
        const auto op_name = ins->get_operator().name();

        // check if allocations from submodules need to be inserted
        // for now, only the "if" operator is affected
        if(op_name == "if")
        {
            insert_submod_allocations(ins, m, model);
            continue;
        }
        if(op_name != "allocate")
            continue;

        auto s = ins->get_shape();
        const bool use_out_param =
            not root_offload_copy and model.needs_out_params() and contains(output_names, ins);
        if(use_out_param)
        {
            auto out_param = m.add_parameter(output_names[ins], s);
            m.replace_instruction(ins, out_param);
        }
        else
        {
            m.replace_instruction(ins,
                                  make_op(model.name(), migraphx::value{{"shape", to_value(s)}}));
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,110 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/match/gelu_tanh.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
 * Replaces the matched GELU (r.result) with an approximation built from exp.
 * The approximation is equivalent to:
 * GELU(x) ~= 0.5 * x * ( 1 + tanh( sqrt(2/M_PI) * (x + 0.044715 * x^3)))
 * Because 1 + tanh(z) = (2) / (1 + exp(-2 * z)), this can be rearranged to
 * x / (1 + exp(-2 * sqrt(2/M_PI) * (x + 0.044715 * x^3)))
 * which is the form emitted here.
 * The fitting constant 0.044715 is from
 * A. Choudhury, A simple approximation to the area under standard normal curve, Mathematics and
 * Statistics, vol. 2, no. 3, pp. 147-149, 2014.
 */
void replace_with_tanh_exp_gelu(module& m, const match::matcher_result& r)
{
    auto ins         = r.result;
    auto x           = r.instructions["x"];
    const auto dtype = x->get_shape().type();

    // k0 = -2 * sqrt(2/pi) multiplies x, k1 = 0.044715 * k0 multiplies x^3.
    double k0 = -2. * sqrt(M_2_PI);
    double k1 = 0.044715 * k0;

    // Scalar literal of the same element type as x.
    auto scalar  = [&](double v) { return m.add_literal(literal{shape{dtype}, {v}}); };
    auto lit_k0  = scalar(k0);
    auto lit_k1  = scalar(k1);
    auto lit_one = scalar(1.0);

    // exponent = x * (k1 * x^2 + k0) = k1 * x^3 + k0 * x
    auto k1_x     = insert_common_op(m, ins, make_op("mul"), {x, lit_k1});
    auto k1_x2    = m.insert_instruction(ins, make_op("mul"), x, k1_x);
    auto poly     = insert_common_op(m, ins, make_op("add"), {k1_x2, lit_k0});
    auto exponent = m.insert_instruction(ins, make_op("mul"), x, poly);

    // result = x / (1 + exp(exponent))
    auto exp_term = m.insert_instruction(ins, make_op("exp"), exponent);
    auto denom    = insert_common_op(m, ins, make_op("add"), {lit_one, exp_term});
    auto result   = m.insert_instruction(ins, make_op("div"), x, denom);
    m.replace_instruction(ins, result);
}
/**
 * Matches GELU blocks in either the erf or the tanh formulation and replaces them with the
 * tanh_exp approximation, but only when the input "x" is fp16; other types are left untouched.
 * TODO consider also for fp8 datatype.
 */
struct find_gelu_erf
{
    auto matcher() const { return match::any_of(match::gelu_erf(), match::gelu_tanh()); }

    void apply(module& m, const match::matcher_result& r) const
    {
        const auto input_type = r.instructions["x"]->get_shape().type();
        if(input_type == migraphx::shape::half_type)
            replace_with_tanh_exp_gelu(m, r);
    }
};
/**
 * Matches tanhGELU blocks and replaces every match, regardless of data type, with the
 * rearranged tanh_exp form, which is less likely to overflow and is more performant.
 */
struct find_tanh_fast_gelu
{
    auto matcher() const
    {
        return match::gelu_tanh();
    }

    void apply(module& m, const match::matcher_result& r) const { replace_with_tanh_exp_gelu(m, r); }
};
// Rewrites GELU blocks in `m`. With fast_math enabled both erf and tanh GELU
// are candidates (find_gelu_erf); without it only tanh GELU is rewritten
// (find_tanh_fast_gelu).
void rewrite_gelu::apply(module& m) const
{
    if(not fast_math)
    {
        match::find_matches(m, find_tanh_fast_gelu{});
        return;
    }
    match::find_matches(m, find_gelu_erf{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,68 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_low_precision.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Matches (x squared) / n, where n is a constant and the square is written
// either as pow(x, 2) or as mul(x, x), and rewrites it to (x / n) * x when x
// is fp16. Presumably this avoids overflowing fp16 in the intermediate x^2 -
// confirm the motivation with the pass's header.
struct find_pow2_div
{
    // Matcher for the squared operand; binds the base as "x".
    auto pow2() const
    {
        auto x_pow_two = match::name("pow")(match::arg(0)(match::any().bind("x")),
                                            match::arg(1)(match::has_value(2.0f)));
        auto x_times_x =
            match::name("mul")(match::same_inputs(), match::arg(0)(match::any().bind("x")));
        return match::any_of(x_pow_two, x_times_x);
    }

    // Matcher for the whole division; binds the constant divisor as "n".
    auto matcher() const
    {
        return match::name("div")(match::arg(0)(pow2()),
                                  match::arg(1)(match::is_constant().bind("n")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto div_ins = r.result;
        auto divisor = r.instructions["n"];
        auto base    = r.instructions["x"];
        // Only half precision is rewritten.
        if(base->get_shape().type() != migraphx::shape::half_type)
            return;

        auto quotient = m.insert_instruction(div_ins, make_op("div"), {base, divisor});
        auto product  = m.insert_instruction(div_ins, make_op("mul"), {quotient, base});
        m.replace_instruction(div_ins, product);
    }
};
// Pass entry point: applies the find_pow2_div rewrite to module `m`.
void rewrite_low_precision::apply(module& m) const
{
    match::find_matches(m, find_pow2_div{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,185 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Replaces a pooling instruction by a single reduction over every axis after
// the first two of its input: reduce_mean for average pooling, reduce_max for
// any other pooling mode.
static void replace_with_reduce(module& m, instruction_ref ins)
{
    auto&& input_shape = ins->inputs().front()->get_shape();
    auto&& pool        = any_cast<op::pooling>(ins->get_operator());
    auto input_lens    = input_shape.lens();

    // Axes 2, 3, ..., rank - 1
    std::vector<std::int64_t> axes(input_lens.size() - 2);
    std::iota(axes.begin(), axes.end(), 2);

    const bool is_average   = (pool.mode == op::pooling_mode::average);
    const char* reduce_name = is_average ? "reduce_mean" : "reduce_max";
    m.replace_instruction(ins, make_op(reduce_name, {{"axes", axes}}), ins->inputs());
}
// Rewrites a dilated pooling as (optional pad) -> one gather per spatial axis ->
// undilated pooling.
//
// For every axis after N and C, the indices touched by each dilated window are
// listed window after window. Gathering with those indices lays the elements of
// each window out contiguously, so a pooling whose stride equals the kernel
// length and whose dilation is 1 produces the same windows.
//
// If the dilated window does not fit into an (already padded) axis, the
// instruction is left untouched and the module is not modified.
static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
    // TODO remove this when MIOpen supports dilated pooling
    auto&& s  = ins->inputs().front()->get_shape();
    auto&& op = any_cast<op::pooling>(ins->get_operator());
    // Ignore N, C axes
    std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};

    bool default_padding =
        std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });

    if(not default_padding)
    {
        for(size_t idx{0}; idx < op.padding.size(); ++idx)
        {
            // We need to pad both ends
            dims[idx] += op.padding.at(idx) * 2;
        }
    }
    std::vector<size_t> kernels   = op.lengths;
    std::vector<size_t> strides   = op.stride;
    std::vector<size_t> dilations = op.dilations;

    std::vector<std::vector<int>> axis_indices;
    axis_indices.resize(dims.size());

    for(size_t idx{0}; idx < dims.size(); ++idx)
    {
        // Offset of the last element of a dilated window relative to its first
        const size_t window_span = dilations.at(idx) * (kernels.at(idx) - 1);
        // All quantities are unsigned: subtracting a span that is not smaller
        // than the axis would wrap around and make the loop below run
        // (practically) forever. No window fits in that case, so bail out
        // before anything has been inserted into the module.
        if(dims.at(idx) <= window_span)
            return;
        const size_t start_limit = dims.at(idx) - window_span;

        // Only consider window starts for which the whole window fits
        for(size_t stride{0}; stride < start_limit; stride += strides.at(idx))
        {
            for(size_t step{0}; step < kernels.at(idx); ++step)
            {
                axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
            }
        }
    }

    auto elements = ins->inputs().front();
    if(not default_padding)
    {
        // Pad supports asym, we need to provide both ends
        std::vector<size_t> padding(2 * s.lens().size(), 0);
        // Format will be e.g {N, C, P1, P2, N, C, P1, P2}
        for(size_t idx{0}; idx < op.padding.size(); ++idx)
        {
            // Ignore N, C axes
            padding.at(2 + idx)                   = op.padding.at(idx);
            padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
        }

        // Default value needed for Max pooling
        elements = m.insert_instruction(
            ins,
            make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
            elements);
    }

    for(size_t idx{0}; idx < axis_indices.size(); ++idx)
    {
        migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
        auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
        // Spatial axis idx lives at position idx + 2 (ignore N, C); the
        // attribute is kept as an int, as it was before.
        const int axis = static_cast<int>(idx) + 2;
        elements =
            m.insert_instruction(ins, make_op("gather", {{"axis", axis}}), elements, indices);
    }

    // Ignore padding
    std::vector<size_t> new_padding(kernels.size(), 0);
    // The kernel window elements are placed next to each other. E.g. {x1, y1, x2, y2, ...}
    // We need to skip them to not overlap
    std::vector<size_t> new_strides(kernels);
    // Ignore dilations
    std::vector<size_t> new_dilations(kernels.size(), 1);
    m.replace_instruction(ins,
                          make_op("pooling",
                                  {{"mode", op.mode},
                                   {"padding", new_padding},
                                   {"stride", new_strides},
                                   {"lengths", kernels},
                                   {"dilations", new_dilations}}),
                          elements);
}
// Rewrites pooling instructions of the module:
//  - a pooling whose kernel equals the spatial dims of its input and that uses
//    unit strides, zero padding and unit dilations becomes a reduction;
//  - any other pooling with non-unit dilations is rewritten through gathers,
//    unless it is an average pooling with padding or its input has more than
//    100000 elements.
// All other pooling instructions are left as they are.
void rewrite_pooling::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "pooling")
            continue;
        if(ins->inputs().empty())
            continue;
        auto&& s                  = ins->inputs().front()->get_shape();
        auto&& op                 = any_cast<op::pooling>(ins->get_operator());
        bool same_kernel_as_shape = std::equal(
            s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
        bool default_strides =
            std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
        bool default_padding =
            std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
        bool default_dilations =
            std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
        if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
        {
            replace_with_reduce(m, ins);
        }
        else if(not default_dilations)
        {
            // Dilated AvgPool with padding is not supported
            if(not default_padding and op.mode == op::pooling_mode::average)
            {
                continue;
            }
            // The accumulator type of std::accumulate is the type of the init
            // value, so it must be size_t: an int accumulator would overflow
            // for large shapes and could slip past the size limit below.
            auto size = std::accumulate(
                s.lens().cbegin(), s.lens().cend(), std::size_t{1}, std::multiplies<size_t>());
            // Can't handle too much size because of literal size
            if(size > 100000)
            {
                continue;
            }

            replace_dilations_with_gather_pooling(m, ins);
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,124 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Environment flag consulted by apply_quantizelinear: when it is enabled the
// clip bounds are created as literals with the full shape of the clipped value
// instead of as single-element literals.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
// Lowers a quantizelinear instruction into primitive operators:
//   convert(clip(nearbyint(x / scale) [+ zero_point], min, max))
// x and the optional zero point are converted to the type of the scale, and
// the clip bounds are the min/max of the instruction's output type.
void apply_quantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "quantizelinear");
    auto value = ins->inputs()[0];
    auto scale = ins->inputs()[1];

    // Bring the data to the scale's type before dividing
    if(value->get_shape().type() != scale->get_shape().type())
    {
        value = m.insert_instruction(
            ins, make_op("convert", {{"target_type", scale->get_shape().type()}}), value);
    }
    auto quotient = m.insert_instruction(ins, make_op("div"), value, scale);
    auto rounded  = m.insert_instruction(ins, make_op("nearbyint"), quotient);

    // A third input, when present, is the zero point
    const bool has_zero_point = (ins->inputs().size() == 3);
    if(has_zero_point)
    {
        auto zero_point =
            m.insert_instruction(ins,
                                 make_op("convert", {{"target_type", scale->get_shape().type()}}),
                                 ins->inputs()[2]);
        rounded = m.insert_instruction(ins, make_op("add"), rounded, zero_point);
    }

    // Representable range of the quantized output type
    double upper_bound = 0;
    double lower_bound = 0;
    ins->get_shape().visit_type([&](auto qt) {
        upper_bound = qt.max();
        lower_bound = qt.min();
    });

    auto clip_shape = rounded->get_shape();
    instruction_ref lower_lit;
    instruction_ref upper_lit;
    if(not enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
    {
        // Single-element bounds
        lower_lit = m.add_literal(literal{shape{clip_shape.type()}, {lower_bound}});
        upper_lit = m.add_literal(literal{shape{clip_shape.type()}, {upper_bound}});
    }
    else
    {
        // Workaround path: bounds with the full shape of the clipped value
        std::vector<double> lower_data(clip_shape.elements(), lower_bound);
        std::vector<double> upper_data(clip_shape.elements(), upper_bound);
        lower_lit = m.add_literal(literal(clip_shape, lower_data));
        upper_lit = m.add_literal(literal(clip_shape, upper_data));
    }

    auto clipped = insert_common_op(m, ins, make_op("clip"), {rounded, lower_lit, upper_lit});
    m.replace_instruction(
        ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), clipped);
}
// Lowers a dequantizelinear instruction into primitive operators:
//   (convert(x) [- convert(zero_point)]) * scale
// where both conversions target the type of the scale.
void apply_dequantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "dequantizelinear");
    auto scale = ins->inputs()[1];

    auto value = m.insert_instruction(
        ins, make_op("convert", {{"target_type", scale->get_shape().type()}}), ins->inputs()[0]);

    // A third input, when present, is the zero point to subtract
    const bool has_zero_point = (ins->inputs().size() == 3);
    if(has_zero_point)
    {
        auto zero_point =
            m.insert_instruction(ins,
                                 make_op("convert", {{"target_type", scale->get_shape().type()}}),
                                 ins->inputs()[2]);
        value = m.insert_instruction(ins, make_op("sub"), value, zero_point);
    }

    m.replace_instruction(ins, make_op("mul"), value, scale);
}
// Pass entry point: lowers every quantizelinear / dequantizelinear instruction
// of the module into primitive operators.
void rewrite_quantization::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "quantizelinear")
        {
            apply_quantizelinear(m, ins);
            continue;
        }
        if(ins->name() == "dequantizelinear")
            apply_dequantizelinear(m, ins);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,153 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#include <migraphx/rewrite_reduce.hpp>
#include <migraphx/module.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
// Expands softmax along its axis into
//   exp(x - max(x)) / sum(exp(x - max(x)))
// with the reductions broadcast back to the input's lens.
struct find_softmax
{
    auto matcher() const { return match::name("softmax"); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto softmax_ins = r.result;
        auto attrs       = softmax_ins->get_operator().to_value();
        auto axis        = attrs["axis"].to<std::int64_t>();
        auto data        = softmax_ins->inputs().front();

        // Subtract the per-axis maximum before exponentiating
        auto axis_max =
            m.insert_instruction(softmax_ins, make_op("reduce_max", {{"axes", {axis}}}), data);
        auto axis_max_b = m.insert_instruction(
            softmax_ins,
            make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}),
            axis_max);
        auto shifted   = m.insert_instruction(softmax_ins, make_op("sub"), data, axis_max_b);
        auto exponents = m.insert_instruction(softmax_ins, make_op("exp"), shifted);

        // Normalize by the per-axis sum of the exponentials
        auto axis_sum = m.insert_instruction(
            softmax_ins, make_op("reduce_sum", {{"axes", {axis}}}), exponents);
        auto axis_sum_b = m.insert_instruction(
            softmax_ins,
            make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}),
            axis_sum);
        m.replace_instruction(softmax_ins, make_op("div"), exponents, axis_sum_b);
    }
};
// Matches reduce_mean applied to a squared deviation from a (possibly
// broadcast) reduce_mean - written as pow(x - mean, 2), (x - mean) * (x - mean)
// or sqdiff(x, mean) - and rewrites it as mean(x * x) - mean(x) * mean(x).
struct find_reduce_mean_variance
{
    auto matcher() const
    {
        auto mean_reduce    = match::name("reduce_mean");
        auto broadcast_mean = match::skip_broadcasts(mean_reduce.bind("mean"));
        auto deviation      = match::name("sub")(match::arg(0)(match::any().bind("x")),
                                            match::arg(1)(broadcast_mean));
        // The three accepted spellings of the squared deviation
        auto squared_by_pow =
            match::name("pow")(match::arg(0)(deviation), match::arg(1)(match::has_value(2.0f)));
        auto squared_by_mul =
            match::name("mul")(match::arg(0)(deviation), match::arg(1)(deviation));
        auto squared_by_sqdiff = match::name("sqdiff")(
            match::either_arg(0, 1)(match::any().bind("x"), broadcast_mean));
        return mean_reduce(
            match::arg(0)(match::any_of(squared_by_pow, squared_by_mul, squared_by_sqdiff)));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto outer_mean = r.result;
        auto data       = r.instructions["x"];
        auto inner_mean = r.instructions["mean"];

        // Both reductions must be the same operator, and the inner mean must
        // be taken over the very same x
        if(outer_mean->get_operator() != inner_mean->get_operator() or
           inner_mean->inputs().front() != data)
            return;

        auto data_squared = m.insert_instruction(outer_mean, make_op("mul"), data, data);
        auto mean_of_squares =
            m.insert_instruction(outer_mean, inner_mean->get_operator(), data_squared);
        auto square_of_mean =
            m.insert_instruction(outer_mean, make_op("mul"), inner_mean, inner_mean);
        m.replace_instruction(outer_mean, make_op("sub"), mean_of_squares, square_of_mean);
    }
};
// Lowers reduce_mean to reduce_sum plus a division by the number of reduced
// elements n, converting the result back to the original output type.
// Integral inputs are summed first and divided afterwards; other inputs are
// divided first and summed afterwards.
struct find_reduce_mean
{
    auto matcher() const { return match::name("reduce_mean"); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins   = r.result;
        auto attrs = ins->get_operator().to_value();
        auto axes  = attrs["axes"].to_vector<std::int64_t>();
        auto input = ins->inputs().front();

        // Properties of the input's element type
        bool is_integral = false;
        double max_n     = 0;
        std::size_t size = 0;
        input->get_shape().visit_type([&](auto t) {
            is_integral = t.is_integral();
            max_n       = t.max();
            size        = t.size();
        });

        // Number of input elements folded into each output element
        auto n = input->get_shape().elements() / ins->get_shape().elements();

        // Convert accumulator to float if <= 8bit type or if < 3 bytes and n >= max_n /4
        const bool widen = (size == 1) or (n >= max_n / 4 and size < 3);
        if(widen)
        {
            shape::type_t wide = is_integral ? shape::int32_type : shape::float_type;
            input = m.insert_instruction(ins, make_op("convert", {{"target_type", wide}}), input);
        }

        auto n_literal = m.add_literal(literal{{input->get_shape().type(), {1}}, {n}});

        instruction_ref mean;
        if(is_integral)
        {
            auto total =
                m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), input);
            mean = insert_common_op(m, ins, make_op("div"), {total, n_literal});
        }
        else
        {
            auto scaled = insert_common_op(m, ins, make_op("div"), {input, n_literal});
            mean = m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), scaled);
        }
        m.replace_instruction(
            ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), mean);
    }
};
} // namespace
// Pass entry point. The softmax and mean-variance rewrites run first; the
// generic reduce_mean lowering runs in a second sweep over the module.
void rewrite_reduce::apply(module& m) const
{
    // First sweep: softmax and variance patterns
    match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
    // Second sweep: remaining reduce_mean instructions
    match::find_matches(m, find_reduce_mean{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,633 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/simple_par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
#include <unordered_map>
#include <unordered_set>
#include <queue>
#include <thread>
#include <mutex>
#include <migraphx/make_op.hpp>
#include <set>
#include <deque>
#include <chrono>
#include <iomanip>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Environment flag consulted by schedule::apply: when it is enabled the module
// is printed with per-instruction weight and stream annotations.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SCHEDULE)
// Returns a selector that yields the inputs() of the node it is given.
auto get_inputs()
{
    return [](auto node) { return node->inputs(); };
}
// Returns a selector that yields the outputs() of the node it is given.
auto get_outputs()
{
    return [](auto node) { return node->outputs(); };
}
struct stream_info
{
// Stream assigned to an instruction (written by set_stream)
std::unordered_map<instruction_ref, std::size_t> ins2stream;
// Accumulated weight: the instruction's own weight plus the accumulated
// weights of its inputs and implicit dependencies (see accumulate_weights)
std::unordered_map<instruction_ref, std::size_t> weights;
// Weight of the instruction on its own
std::unordered_map<instruction_ref, std::size_t> iweights;
// Implicit dependencies of the module, filled in by calc_implicit_deps
ins_dep_map mod_implicit_deps;
// Stores the implicit dependency map computed by the module.
void calc_implicit_deps(const module& m)
{
    mod_implicit_deps = m.calc_implicit_deps();
}
// Fills `iweights` and `weights` for `last` and everything reachable from it
// through inputs and implicit dependencies. Results are memoized in `weights`.
void accumulate_weights(instruction_ref last, const schedule_model& model)
{
    fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
        // Already visited: reuse the stored total
        if(contains(weights, ins))
            return weights[ins];

        // Own weight: zero for context-free operators and for operators whose
        // name starts with '@', otherwise whatever the model reports
        auto&& op              = ins->get_operator();
        std::size_t own_weight = 0;
        if(not is_context_free(op) and op.name()[0] != '@')
            own_weight = model.weight(op);
        // This will ensure a stream will be assigned to return
        if(op.name() == "@return")
            own_weight = 1;
        iweights[ins] = own_weight;

        // Explicit inputs followed by implicit dependencies, if any
        auto deps = ins->inputs();
        if(contains(mod_implicit_deps, ins))
        {
            const auto& implicit = mod_implicit_deps.at(ins);
            deps.insert(deps.end(), implicit.begin(), implicit.end());
        }

        std::size_t total = own_weight;
        for(instruction_ref dep : deps)
            total = total + self(dep);
        weights[ins] = total;
        return weights[ins];
    })(last);
}
// Sorts `args` with `compare` applied to the key
// (accumulated weight, number of inputs, address of the instruction).
// Ranges with fewer than two elements are left untouched.
template <class Compare>
void sort_args_by_weight(std::vector<instruction_ref>& args, Compare compare) const
{
    if(args.size() < 2)
        return;
    auto sort_key = [this](auto x) {
        return std::make_tuple(this->weights.at(x), x->inputs().size(), std::addressof(*x));
    };
    std::sort(args.begin(), args.end(), by(compare, sort_key));
}
// Sorts `args` by descending weight and returns an iterator to the first
// argument after the front one whose accumulated weight is not greater than
// the partition threshold (2). Returns args.end() when there are fewer than
// two arguments or when no such argument exists.
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{
    const std::size_t min_partition_threshold = 2;
    if(args.size() >= 2)
    {
        sort_args_by_weight(args, std::greater<>{});
        // The front argument is never part of the search
        auto heavier_than = [&](auto i, std::size_t w) { return this->weights[i] > w; };
        auto it =
            std::lower_bound(std::next(args.begin()), args.end(), min_partition_threshold, heavier_than);
        assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
        assert(it == args.end() or std::prev(it) == args.begin() or
               this->weights[*std::prev(it)] > min_partition_threshold);
        return it;
    }
    return args.end();
}
// A group of instructions together with the sum of the weights they were
// added with.
struct partition
{
    std::size_t weight = 0;
    std::vector<instruction_ref> instructions{};

    // Appends `ins` to the partition and adds `w` to its total weight
    void add(instruction_ref ins, std::size_t w)
    {
        weight = weight + w;
        instructions.push_back(ins);
    }
};
// Partitions the instructions reachable from the last instruction of `m` and
// assigns a stream to every partition, for at most `n` streams.
//
// The traversal starts at the last instruction with the `critical` partition.
// For each instruction the arguments are sorted by sort_args(): the front
// argument continues the current partition, every argument between the second
// one and the threshold iterator starts a new partition (stored under the
// consuming instruction), and the arguments from the threshold iterator on
// stay in the current partition. Each visited instruction is moved to the end
// of the module after its arguments were visited.
//
// The critical partition always goes to stream 0. With n == 1 every other
// partition goes to stream 0 as well; otherwise each partition goes to the
// currently least loaded of the remaining n - 1 streams.
//
// Returns the number of streams used: 1 when n == 1, otherwise 1 plus the
// number of additional streams that received a non-zero weight.
std::size_t assign_streams(module& m, std::size_t n)
{
assert(n > 0);
partition critical;
// Partitions started at an instruction, keyed by that instruction. An entry
// (possibly empty) also marks the instruction as visited.
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) {
assert(not is_end(ins, m.end()));
// Instructions that are not part of this module are not visited
if(not m.has_instruction(ins))
return;
if(contains(partitions, ins))
return;
// Add an entry so we know the instruction was visited
partitions[ins];
part.add(ins, this->iweights[ins]);
auto args = ins->inputs();
auto threshold_it = this->sort_args(args);
if(not args.empty())
{
assert(threshold_it != args.begin());
// The front argument continues the current partition
self(args.front(), part);
// Each of these arguments starts a partition of its own
for(auto i : range(std::next(args.begin()), threshold_it))
{
partitions[ins].emplace_back();
self(i, partitions[ins].back());
}
// The remaining arguments stay in the current partition
for(auto i : range(threshold_it, args.end()))
{
self(i, part);
}
}
// Sort instructions
m.move_instruction(ins, m.end());
})(std::prev(m.end()), critical);
// Set the critical partition to stream 0
set_stream(critical, 0);
if(n == 1)
{
// Assign streams for the other partitions
for(auto&& ins_part : partitions)
for(auto&& part : ins_part.second)
set_stream(part, 0);
return 1;
}
else
{
// Accumulated weight per additional stream; index i stands for stream i + 1
std::vector<std::size_t> streams(n - 1);
// Assign streams for the other partitions
for(auto&& ins_part : partitions)
{
// Largest partitions (by weight, then instruction count) first
std::sort(ins_part.second.begin(),
ins_part.second.end(),
by(std::greater<>{}, [](auto&& x) {
return std::make_tuple(x.weight, x.instructions.size());
}));
for(auto&& part : ins_part.second)
{
// Pick the stream with the smallest accumulated weight so far
auto stream =
std::min_element(streams.begin(), streams.end()) - streams.begin();
set_stream(part, stream + 1);
streams[stream] += part.weight;
}
}
return 1 + std::count_if(streams.begin(), streams.end(), [](auto x) { return x > 0; });
}
}
// (priority, instruction) entry used by sort()
using weight_ins = std::pair<std::size_t, instruction_ref>;
// Orders entries by priority first and by the address of the instruction
// second.
struct compare_weight_ins
{
    bool operator()(const weight_ins& x, const weight_ins& y) const
    {
        const auto lhs = std::make_pair(x.first, std::addressof(*x.second));
        const auto rhs = std::make_pair(y.first, std::addressof(*y.second));
        return lhs < rhs;
    }
};
// Reorders the instructions of `m`, starting from the last instruction and
// following inputs and implicit dependencies.
//
// Candidates are kept in a set ordered by (priority, address). The smallest
// entry is popped and moved to the beginning of the module, then its inputs
// and implicit dependencies become candidates. An instruction may be queued
// several times; each new entry uses a priority scaled by how many times the
// instruction has been queued. The second parameter is unused.
void sort(module& m, std::size_t)
{
std::set<weight_ins, compare_weight_ins> children;
// Number of times each instruction has been queued
std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(m.end());
// Accumulated weight of the last instruction, and that weight divided by
// the number of instructions plus one
auto mw = this->weights.at(last);
auto nw = mw / (m.size() + 1);
auto add_child = [&](auto ins) {
// Factor that grows as the accumulated weight of `ins` shrinks relative to mw
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
auto w = x * this->iweights.at(ins);
auto& v = visited[ins];
// Queue a new entry unless the one from the previous visit is still pending
auto it = children.find(std::make_pair(v * w, ins));
if(it == children.end())
{
v++;
children.insert(std::make_pair(v * w, ins));
}
};
add_child(last);
while(not children.empty())
{
// Pop the first element
auto top = children.begin()->second;
children.erase(children.begin());
m.move_instruction(top, m.begin());
for(auto ins : top->inputs())
{
// Inputs that are not part of this module are skipped
if(not m.has_instruction(ins))
continue;
add_child(ins);
}
if(contains(mod_implicit_deps, top))
{
for(auto ins : mod_implicit_deps.at(top))
{
assert(m.has_instruction(ins));
add_child(ins);
}
}
}
// Move dangling parameters (those left after `last`) to the front so they
// are not removed
auto ins = std::next(last);
while(ins != m.end())
{
// Fetch the successor first: `ins` may be moved below
auto next = std::next(ins);
if(ins->name() == "@param")
{
m.move_instruction(ins, m.begin());
}
ins = next;
}
}
// Assigns stream `n` to every instruction of the partition that has a
// non-zero own weight.
void set_stream(const partition& p, std::size_t n)
{
    for(auto ins : p.instructions)
    {
        const bool weighted = iweights[ins] > 0;
        if(not weighted)
            continue;
        set_stream(ins, n);
    }
}
// Records `n` as the stream of `ins`. The instruction is expected to have a
// non-zero own weight.
void set_stream(instruction_ref ins, std::size_t n)
{
    assert(iweights[ins] > 0);
    ins2stream[ins] = n;
}
// Returns the stream assigned to `ins`; uses at(), so it throws when no
// stream was assigned.
std::size_t get_stream(instruction_ref ins) const
{
    return ins2stream.at(ins);
}
// Tells whether a stream has been assigned to `ins`.
bool has_stream(instruction_ref ins) const
{
    return ins2stream.find(ins) != ins2stream.end();
}
// Returns true when `f` reports at least one stream that differs from
// `stream`.
//
// `f` is called with a callback that receives one stream at a time; the
// callback returns false as soon as a differing stream is seen and true
// otherwise.
template <class F>
bool different(F f, std::size_t stream) const
{
bool result = false;
f([&](auto s) {
if(s != stream)
{
result = true;
return false;
}
// cppcheck-suppress uselessAssignmentArg
stream = s;
return true;
});
return result;
}
// Returns true when the streams reported by `f` are not all the same.
//
// The first stream reported by `f` is used as the reference for the
// two-argument overload; the callback then returns false so that only this
// first stream is inspected here. The result stays false when `f` reports no
// stream at all.
template <class F>
bool different(F f) const
{
bool result = false;
f([&](auto s) {
result = this->different(f, s);
return false;
});
return result;
}
// Builds a callable that reports the streams reachable from `start`.
//
// `select(ins)` yields the neighbours to follow (e.g. inputs or outputs). For
// each neighbour that has a stream, the callback `f` is called with that
// stream; neighbours without a stream are searched recursively. The walk goes
// through all_of, so it ends at the first `f` call that returns false, and
// the returned callable yields the all_of result.
template <class Selector>
auto get_streams_from(instruction_ref start, Selector select) const
{
return [=](auto f) {
return fix<bool>([&](auto self, auto ins) {
return all_of(select(ins), [&](auto i) {
if(has_stream(i))
return f(this->get_stream(i));
else
return self(i);
});
})(start);
};
}
// Returns the stream of `ins` when it has one; otherwise the set of streams
// found by following its inputs.
std::unordered_set<std::size_t> get_streams(instruction_ref ins) const
{
    std::unordered_set<std::size_t> found;
    if(has_stream(ins))
    {
        found.insert(get_stream(ins));
        return found;
    }
    auto visit_input_streams = get_streams_from(ins, get_inputs());
    visit_input_streams([&](auto s) {
        found.insert(s);
        // Keep going so that every stream is collected
        return true;
    });
    return found;
}
// True when the streams reached through the inputs of `ins` differ - from
// each other, or from the stream passed in `xs` when one is given.
template <class... Ts>
bool is_merge_point(instruction_ref ins, Ts... xs) const
{
    auto input_streams = get_streams_from(ins, get_inputs());
    return different(input_streams, xs...);
}
// True when the streams reached through the outputs of `ins` differ - from
// each other, or from the stream passed in `xs` when one is given.
template <class... Ts>
bool is_split_point(instruction_ref ins, Ts... xs) const
{
    auto output_streams = get_streams_from(ins, get_outputs());
    return different(output_streams, xs...);
}
// Collects, for every stream feeding `start`, the one instruction on that
// stream with the smallest std::distance to `start`. Inputs with a zero own
// weight are looked through recursively.
std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
{
    // Chosen instruction per stream
    std::unordered_map<std::size_t, instruction_ref> per_stream;
    fix([&](auto self, auto ins) {
        for(auto arg : ins->inputs())
        {
            if(iweights.at(arg) == 0)
            {
                self(arg);
                continue;
            }
            auto stream = this->get_stream(arg);
            if(contains(per_stream, stream))
            {
                per_stream[stream] =
                    std::min(per_stream[stream], arg, by(std::less<>{}, [&](auto x) {
                                 return std::distance(x, start);
                             }));
            }
            else
            {
                per_stream[stream] = arg;
            }
        }
    })(start);

    std::vector<instruction_ref> result;
    for(auto&& entry : per_stream)
        result.push_back(entry.second);
    return result;
}
// For every merge point, collects the instructions that lead into it, grouped
// by stream.
//
// The result maps a merge point to a vector indexed by stream; each entry
// lists the instructions on that stream that reach the merge point, plus
// their stream-less inputs that are neither context free nor builtins.
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(module& m) const
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
// Merge points reachable from an instruction through its outputs
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(m);
result.reserve(m.size());
merge_from.reserve(m.size());
// Walk backwards so that the outputs of an instruction are processed before
// the instruction itself
for(auto ins : reverse_iterator_for(m))
{
for(auto&& arg : ins->outputs())
{
if(not m.has_instruction(arg))
continue;
if(is_merge_point(arg))
merge_from[ins].insert(arg);
// Inherit the merge points already found for the output
merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
}
if(is_split_point(ins))
{
// Drop the merge points for which strictly_dominate(ins, merge) holds
erase_if(merge_from[ins],
[&](auto merge) { return di.strictly_dominate(ins, merge); });
}
auto streams = this->get_streams(ins);
// Collect concur instructions for each merge point.
for(const auto& merge : merge_from[ins])
{
for(auto stream : streams)
{
// Grow the per-stream vector on demand
if(result[merge].size() <= stream)
result[merge].resize(stream + 1);
auto&& r = result[merge][stream];
r.push_back(ins);
// Copy inputs if they don't have a stream (and are not a builtin and context
// free). Inputs without a stream can have an implicit dependency
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(r),
[&](auto x) {
return not this->has_stream(x) and
not is_context_free(x->get_operator()) and
x->name().front() != '@';
});
}
}
}
return result;
}
// Builds the conflict table of the module: for every merge point found by
// find_concurrent_instructions(), instructions that sit on different streams
// are recorded as conflicting with each other.
//
// The table maps an instruction to the conflicting instructions that have a
// smaller position in the module. The merge points are processed through
// simple_par_for, each worker writing to its own table; the per-thread tables
// are merged afterwards.
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& m)
{
    using conflict_table_type =
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
    conflict_table_type conflict_table;
    auto concur_ins = this->find_concurrent_instructions(m);

    // Compute an index for each instruction
    std::unordered_map<instruction_ref, std::size_t> ins2index;
    std::size_t index_total = 0;
    for(auto ins : iterator_for(m))
        ins2index[ins] = index_total++;

    // hardware_concurrency() is allowed to return 0 when the value is not
    // computable. The workers index this vector with at(tid), so it must
    // always hold at least one table.
    const std::size_t hw_threads = std::thread::hardware_concurrency();
    std::vector<conflict_table_type> thread_conflict_tables(hw_threads == 0 ? 1 : hw_threads);

    // Give every merge point a dense index for simple_par_for
    std::vector<instruction_ref> index_to_ins;
    index_to_ins.reserve(concur_ins.size());
    std::transform(concur_ins.begin(),
                   concur_ins.end(),
                   std::back_inserter(index_to_ins),
                   [](auto&& it) { return it.first; });

    simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
        auto merge_first = index_to_ins[ins_index];
        assert(concur_ins.count(merge_first) > 0);
        auto& merge_second = concur_ins.at(merge_first);

        // ensure there are enough elements for different threads
        assert(tid < thread_conflict_tables.size());
        auto& thrd_table = thread_conflict_tables.at(tid);

        std::unordered_set<instruction_ref> checked_ins_set;
        // Pair every stream with all the streams that come after it
        auto range_i = range(merge_second.begin(), std::prev(merge_second.end()));
        for(auto it_i : iterator_for(range_i))
        {
            std::unordered_set<instruction_ref> ins1_set;
            std::copy_if(it_i->begin(),
                         it_i->end(),
                         std::inserter(ins1_set, ins1_set.end()),
                         [&](auto i) { return not contains(checked_ins_set, i); });
            checked_ins_set.insert(ins1_set.begin(), ins1_set.end());

            auto range_j = range(std::next(it_i), merge_second.end());
            std::unordered_set<instruction_ref> ins2_set;
            for(auto it_j : iterator_for(range_j))
            {
                std::copy_if(it_j->begin(),
                             it_j->end(),
                             std::inserter(ins2_set, ins2_set.end()),
                             [&](auto i) { return not contains(checked_ins_set, i); });
            }

            for(auto ins1 : ins1_set)
            {
                auto p1 = ins2index.at(ins1);
                for(auto ins2 : ins2_set)
                {
                    if(ins1 == ins2)
                        continue;
                    // The conflict is stored under the later instruction
                    auto p2 = ins2index.at(ins2);
                    if(p2 > p1)
                        thrd_table[ins2].insert(ins1);
                    else
                        thrd_table[ins1].insert(ins2);
                }
            }
        }
    });

    // merge thread_conflict_tables together
    for(auto& tbl : thread_conflict_tables)
    {
        for(auto& it : tbl)
        {
            conflict_table[it.first].insert(it.second.begin(), it.second.end());
        }
    }

    // Remove instructions from the conflict table of an earlier instruction.
    // Use find() rather than operator[]: operator[] would insert an entry for
    // a missing key, and inserting into the map while iterating over it can
    // rehash it and invalidate the iterators of this loop.
    for(auto&& ip : conflict_table)
    {
        auto ins1 = ip.first;
        for(auto ins2 : ip.second)
        {
            auto other = conflict_table.find(ins2);
            if(other != conflict_table.end())
                other->second.erase(ins1);
        }
    }

    return conflict_table;
}
};
// Pass entry point: assigns instructions to streams and inserts the
// synchronization needed between them.
//
// Steps: compute implicit dependencies and weights, assign streams, reorder
// the module, optionally print the annotated module, then - only when more
// than one stream is in use - schedule every instruction on its stream,
// record/wait on events at merge points, and finally add identity
// instructions that tie conflicting instructions together.
void schedule::apply(module& m) const
{
if(not enable)
return;
stream_info si;
si.calc_implicit_deps(m);
auto last = std::prev(m.end());
si.accumulate_weights(last, model);
auto nstreams = si.assign_streams(m, model.concurrency());
si.sort(m, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{
// Print each instruction with its weight, the streams of its inputs and
// its own stream, if it has one
m.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins))
return;
std::cout << ":";
std::cout << " weight=" << si.weights.at(ins);
std::cout << " input={";
si.get_streams_from(ins, get_inputs())([&](auto s) {
std::cout << s << ",";
return true;
});
std::cout << "}";
if(si.has_stream(ins))
std::cout << " stream=" << si.get_stream(ins);
});
std::cout << std::endl;
}
// No concurrency
if(nstreams < 2)
return;
// Schedule instructions
// Next event id to hand out
std::size_t wait_id = 0;
// Event recorded after an instruction
std::unordered_map<instruction_ref, std::size_t> ins2wait;
// Events already waited for, per stream
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
// Events that had been waited for on the stream of a split point when it was scheduled
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(m.size());
ins2waited.reserve(m.size());
for(auto ins : iterator_for(m))
{
// Only schedule instructions that have a stream
if(not si.has_stream(ins))
continue;
assert(si.weights[ins] > 0);
// Schedule instruction on the stream
auto stream = si.get_stream(ins);
assert(stream < model.concurrency());
model.sched(m, ins, stream);
// Insert wait instructions
if(si.is_merge_point(ins, stream))
{
for(auto i : si.get_recorded_instructions(ins))
{
// Nothing to wait for when the producer is on the same stream
if(not si.has_stream(i) or si.get_stream(i) == stream)
continue;
// Create a new event if it hasn't been recorded
if(not contains(ins2wait, i))
{
ins2wait[i] = wait_id;
model.record(m, i, wait_id);
wait_id++;
}
auto w = ins2wait.at(i);
// If we already waited for the event on this stream then don't
// insert another wait event
if(not contains(waited_for[stream], w))
model.wait(m, ins, w);
// Store the event as waited
waited_for[stream].insert(w);
// Store all wait events that have been waited on prior to the recorded instruction
waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end());
}
}
// Store wait events that have already been waited on
if(si.is_split_point(ins, stream))
{
ins2waited[ins] = waited_for[stream];
}
}
// Add memory conflicts
auto conflict_table = si.get_conflicts(m);
for(auto&& ip : conflict_table)
{
if(ip.second.empty())
continue;
// An identity placed right after the instruction takes the instruction
// and everything it conflicts with as arguments
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
m.insert_instruction(std::next(ip.first), make_op("identity"), args);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,66 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Serializes a raw-data object into `v`:
//   "shape" - always present
//   "sub"   - the sub objects, when the shape is a tuple
//   "data"  - the raw bytes, when it is not a tuple and not empty
template <class RawData>
void raw_data_to_value(value& v, const RawData& rd)
{
    value serialized;
    serialized["shape"] = migraphx::to_value(rd.get_shape());

    const bool is_tuple = (rd.get_shape().type() == shape::tuple_type);
    if(is_tuple)
    {
        serialized["sub"] = migraphx::to_value(rd.get_sub_objects());
    }
    else if(not rd.empty())
    {
        serialized["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
    }
    v = serialized;
}
// Serializes a literal through raw_data_to_value.
void migraphx_to_value(value& v, const literal& l)
{
    raw_data_to_value(v, l);
}
// Rebuilds a literal from its "shape" and "data" entries. Both are read with
// at().
void migraphx_from_value(const value& v, literal& l)
{
    auto lit_shape   = migraphx::from_value<shape>(v.at("shape"));
    const auto& blob = v.at("data").get_binary();
    l                = literal(lit_shape, blob.data());
}
// Serializes an argument through raw_data_to_value.
void migraphx_to_value(value& v, const argument& a)
{
    raw_data_to_value(v, a);
}
void migraphx_from_value(const value& v, argument& a)
{
if(v.contains("data"))
{
literal l = migraphx::from_value<literal>(v);
a = l.get_argument();
}
else if(v.contains("sub"))
{
a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,857 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
#include <unordered_map>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl
{
static std::shared_ptr<shape_impl> default_shape()
{
static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
return result;
}
shape_impl() : m_type(shape::float_type) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
{
assert(t != shape::tuple_type);
}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
assert(t != shape::tuple_type);
this->calculate_strides();
}
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size());
// Calculate standard shape flag for these lens/strides. Strides on size-1
// axes are ignored to support an MLIR rule.
std::vector<size_t> filtered_strides;
for(size_t ind = 0; ind < m_strides.size(); ind++)
if(m_lens[ind] != 1)
filtered_strides.push_back(m_strides[ind]);
m_standard = this->elements() == this->element_space() and not skips() and
std::is_sorted(filtered_strides.rbegin(), filtered_strides.rend());
}
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
: m_type(t), m_dyn_dims(std::move(dims))
{
}
shape_impl(shape::type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::set<std::size_t>> optimals_list)
: m_type(t)
{
if(optimals_list.empty())
{
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
}
}
else
{
assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
}
}
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type;
std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {};
bool m_standard = false;
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides()
{
m_strides.clear();
m_strides.resize(m_lens.size(), 0);
if(m_strides.empty())
return;
m_strides.back() = 1;
std::partial_sum(m_lens.rbegin(),
m_lens.rend() - 1,
m_strides.rbegin() + 1,
std::multiplies<std::size_t>());
}
std::size_t element_space() const
{
if(not m_dyn_dims.empty())
{
auto maxes = max_lens();
std::size_t max_val = std::numeric_limits<std::size_t>::max();
return std::accumulate(
maxes.begin(), maxes.end(), std::size_t{1}, [&](std::size_t x, std::size_t y) {
// overflow check and clip
if(x != 0 and y > max_val / x)
{
return max_val;
}
return x * y;
});
}
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::inner_product(m_lens.begin(),
m_lens.end(),
m_strides.begin(),
std::size_t{0},
std::plus<std::size_t>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
1;
}
std::size_t elements() const
{
if(not m_dyn_dims.empty())
{
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
}
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::size_t get_index(size_t i) const
{
std::size_t result = 0;
std::size_t s = 1;
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
{
std::size_t stride = m_strides[k];
std::size_t len = m_lens[k];
std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
std::vector<std::size_t> min_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](const shape::dynamic_dimension& x) { return x.min; });
return ret;
}
std::vector<std::size_t> max_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](const shape::dynamic_dimension& x) { return x.max; });
return ret;
}
std::vector<std::set<std::size_t>> opt_lens() const
{
std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](const shape::dynamic_dimension& x) { return x.optimals; });
return ret;
}
// Does the shape skip over elements?
bool skips() const
{
assert(m_lens.size() == m_strides.size());
if(elements() == 1)
return false;
return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};
std::string shape::to_sizes_string(const std::vector<shape>& shapes)
{
std::vector<std::string> sizes;
std::transform(shapes.begin(), shapes.end(), std::back_inserter(sizes), [&](const shape& s) {
std::string r = to_string_range(s.lens(), "x");
if(not s.standard())
r += ":" + to_string_range(s.strides(), "x");
return r;
});
return join_strings(sizes, ", ");
}
/// Returns every type enumerated by MIGRAPHX_SHAPE_VISIT_TYPES followed by
/// tuple_type. The list is built once and shared by all callers.
const std::vector<shape::type_t>& shape::types()
{
    static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
// Drop the helper macro so it does not leak into the rest of the file,
// matching the other generator macros in this file.
#undef MIGRAPHX_GENERATE_TYPE_VECTOR
    return result;
}
/// Returns the enumerator name of `t` as a string ("tuple_type" for tuples).
/// Throws for a value that is not a known enumerator.
std::string shape::name(shape::type_t t)
{
    switch(t)
    {
// One `case` per type listed by MIGRAPHX_SHAPE_VISIT_TYPES.
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(enum_id, cxx_type) \
    case enum_id: return #enum_id;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
    case tuple_type: return "tuple_type";
    }
    MIGRAPHX_THROW("Invalid type");
}
/// Returns the spelling of the C++ type paired with `t` in
/// MIGRAPHX_SHAPE_VISIT_TYPES. Throws for tuple_type and for unknown values.
std::string shape::cpp_type(shape::type_t t)
{
    switch(t)
    {
// One `case` per type listed by MIGRAPHX_SHAPE_VISIT_TYPES.
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(enum_id, cxx_type) \
    case enum_id: return #cxx_type;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
    case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
    }
    MIGRAPHX_THROW("Invalid type");
}
/// Reports whether `t` is an integral type, as answered by the type visitor.
bool shape::is_integral(shape::type_t t)
{
    bool integral = false;
    visit(t, [&integral](auto as) { integral = as.is_integral(); });
    return integral;
}
/// Decide whether `actual` can stand in for `expected`.
///
/// Tuples are compared sub-shape by sub-shape. Otherwise the types must match;
/// a dynamic `expected` only requires the same rank, while a dynamic `actual`
/// is never compatible with a static `expected`. For two static shapes the
/// lens must match, and strides must match on every axis whose length is not 1.
bool shape::is_compatible(const shape& actual, const shape& expected)
{
    // Check subshapes
    if(expected.type() == shape::tuple_type)
        return migraphx::equal(actual.sub_shapes(), expected.sub_shapes(), &is_compatible);

    if(actual == expected)
        return true;
    if(actual.type() != expected.type())
        return false;

    // Only the expected can be dynamic
    if(expected.dynamic())
        return actual.ndim() == expected.ndim();
    if(actual.dynamic())
        return false;

    if(actual.lens() != expected.lens())
        return false;

    // Check strides from dimensions that are not 1
    const std::size_t rank = actual.lens().size();
    for(std::size_t axis = 0; axis < rank; ++axis)
    {
        if(actual.lens()[axis] == 1)
            continue;
        if(actual.strides()[axis] != expected.strides()[axis])
            return false;
    }
    return true;
}
/// Reports whether `t` is an unsigned type, as answered by the type visitor.
bool shape::is_unsigned(shape::type_t t)
{
    bool is_unsigned_type = false;
    visit(t, [&is_unsigned_type](auto as) { is_unsigned_type = as.is_unsigned(); });
    return is_unsigned_type;
}
// Default shape: shares the single default shape_impl instance.
shape::shape() : impl(shape_impl::default_shape()) {}
// Shape of type `t`; forwards to the single-argument shape_impl constructor.
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
// Static shape from lens; strides are computed by shape_impl.
shape::shape(type_t t, std::vector<std::size_t> l)
: impl(std::make_shared<shape_impl>(t, std::move(l)))
{
}
// Static shape from explicit lens and strides.
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{
}
// Static shape from a braced list of lens; delegates to the vector overload.
shape::shape(type_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
// Dynamic shape from a list of dynamic dimensions.
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
{
}
// Dynamic shape from parallel min / max / optimals lists.
shape::shape(type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::set<std::size_t>> optimals_list)
: impl(std::make_shared<shape_impl>(
t, std::move(mins), std::move(maxes), std::move(optimals_list)))
{
}
// Tuple shape holding the given sub-shapes.
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
// Adopt an existing implementation object.
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
/// Build a shape of type `t` whose lens equal `l` and whose layout follows
/// `perm`: the lens are reordered by `perm`, turned into a shape, and that
/// shape is reordered back with the inverse permutation.
shape shape::from_permutation(type_t t,
                              const std::vector<std::size_t>& l,
                              const std::vector<int64_t>& perm)
{
    auto permuted_lens = reorder_dims(l, perm);
    shape reordered    = reorder_shape({t, permuted_lens}, invert_permutation(perm));
    assert(reordered.lens() == l);
    return reordered;
}
// Returns the element type stored in the implementation object.
shape::type_t shape::type() const { return impl->m_type; }
/// Returns the static dimension lengths. Throws for a dynamic shape.
const std::vector<std::size_t>& shape::lens() const
{
    if(not this->dynamic())
        return impl->m_lens;
    MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
}
/// Returns the static strides. Throws for a dynamic shape.
const std::vector<std::size_t>& shape::strides() const
{
    if(not this->dynamic())
        return impl->m_strides;
    MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
}
/// Number of dimensions: the count of dynamic dimensions for a dynamic shape,
/// otherwise the count of lens.
std::size_t shape::ndim() const
{
    return this->dynamic() ? dyn_dims().size() : lens().size();
}
// Element count; forwards to shape_impl::elements(), which throws for dynamic shapes.
std::size_t shape::elements() const { return impl->elements(); }
/// Size in bytes: element size times element_space() for a plain shape, or the
/// sum of the sub-shapes' bytes for a tuple.
std::size_t shape::bytes() const
{
    const auto& subs = this->sub_shapes();
    if(not subs.empty())
    {
        std::size_t total = 0;
        for(const auto& sub : subs)
            total = total + sub.bytes();
        return total;
    }

    std::size_t elem_size = 0;
    this->visit_type([&](auto as) { elem_size = as.size(); });
    return elem_size * this->element_space();
}
/// Size in bytes of one element, or 0 when the shape has sub-shapes.
std::size_t shape::type_size() const
{
    if(not this->sub_shapes().empty())
        return 0;

    std::size_t elem_size = 0;
    this->visit_type([&](auto as) { elem_size = as.size(); });
    return elem_size;
}
/// Memory offset of the multi-index `l`: the sum of each coordinate times the
/// stride of its axis. `l` may be shorter than the rank. Throws for a dynamic
/// shape.
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");

    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());

    std::size_t offset = 0;
    auto stride_it     = this->strides().begin();
    for(auto coord : l)
    {
        offset += coord * (*stride_it);
        ++stride_it;
    }
    return offset;
}
/// Memory offset of the multi-index `l`: the sum of each coordinate times the
/// stride of its axis. `l` may be shorter than the rank. Throws for a dynamic
/// shape.
std::size_t shape::index(const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");

    assert(l.size() <= this->lens().size());
    assert(this->lens().size() == this->strides().size());

    const auto& axis_strides = this->strides();
    std::size_t offset       = 0;
    for(std::size_t axis = 0; axis < l.size(); ++axis)
        offset += l[axis] * axis_strides[axis];
    return offset;
}
/// Memory offset of the linear element index `i`. A standard shape maps `i`
/// to itself; any other layout goes through the stride computation.
/// Throws for a dynamic shape.
std::size_t shape::index(std::size_t i) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");

    assert(this->lens().size() == this->strides().size());
    return this->standard() ? i : impl->get_index(i);
}
/// Convert the linear element index `idx` into one coordinate per axis.
std::vector<std::size_t> shape::multi(std::size_t idx) const
{
    assert(idx < elements());
    const std::size_t rank = lens().size();
    std::vector<std::size_t> coords(rank);
    multi_copy(idx, coords.data(), coords.data() + rank);
    return coords;
}
/// Write the per-axis coordinates of the linear element index `idx` into the
/// buffer [start, end).
///
/// @param idx    linear element index; must be less than elements()
/// @param start  first slot of the output buffer (one slot per axis)
/// @param end    one past the last slot; only used for the debug size check
///
/// Coordinates are peeled from the last axis towards the first. The first axis
/// receives whatever quotient remains and is not reduced modulo its length.
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
{
    (void)end;
    assert(idx < elements());
    assert(lens().size() <= (end - start));

    const auto& dims = lens();
    // A shape with no axes has nothing to write. Without this guard
    // `dims.size() - 1` would wrap around and the loop would write far
    // outside the buffer in builds where the asserts above are compiled out.
    if(dims.empty())
        return;

    std::size_t remaining = idx;
    for(std::size_t axis = dims.size() - 1; axis > 0; axis--)
    {
        start[axis] = remaining % dims[axis];
        remaining   = remaining / dims[axis];
    }
    *start = remaining;
}
/// Returns true when every coordinate in `multi` is strictly less than the
/// length of its axis.
///
/// `multi` is expected to have one entry per axis. The four-iterator form of
/// std::equal is used so that a size mismatch yields false instead of reading
/// past the end of lens() when the assert is compiled out.
bool shape::multi_within_bounds(std::vector<std::size_t> multi) const
{
    assert(this->lens().size() == multi.size());
    const auto& dims = this->lens();
    return std::equal(multi.begin(), multi.end(), dims.begin(), dims.end(), std::less<>{});
}
/// True for a static, non-tuple shape that skips no elements and whose element
/// count equals its element space. Always false for dynamic shapes.
bool shape::packed() const
{
    if(this->dynamic())
        return false;
    if(not this->sub_shapes().empty())
        return false;
    if(impl->skips())
        return false;
    return this->elements() == this->element_space();
}
/// True when the strides are not in non-increasing order. For a broadcasted
/// shape the zero strides are ignored before checking the order. Always false
/// for dynamic shapes.
bool shape::transposed() const
{
    if(this->dynamic())
        return false;

    const auto& all_strides = this->strides();
    if(not this->broadcasted())
        return not std::is_sorted(all_strides.rbegin(), all_strides.rend());

    // TODO: Use a filter_iterator instead
    std::vector<std::size_t> nonzero;
    nonzero.reserve(all_strides.size());
    for(std::size_t st : all_strides)
    {
        if(st != 0)
            nonzero.push_back(st);
    }
    return not std::is_sorted(nonzero.rbegin(), nonzero.rend());
}
/// True when at least one stride is zero. Always false for dynamic shapes.
bool shape::broadcasted() const
{
    if(this->dynamic())
        return false;

    assert(this->lens().size() == this->strides().size());
    const auto& all_strides = this->strides();
    return std::find(all_strides.begin(), all_strides.end(), std::size_t{0}) != all_strides.end();
}
/// True for a static, non-tuple shape whose strides sum to zero.
/// Always false for dynamic shapes.
bool shape::scalar() const
{
    if(this->dynamic())
        return false;

    assert(this->lens().size() == this->strides().size());
    if(not this->sub_shapes().empty())
        return false;

    // if any stride > 0, then the sum is non-zero
    std::size_t stride_sum = 0;
    for(std::size_t st : this->strides())
        stride_sum += st;
    return stride_sum == 0;
}
// Returns the standard-layout flag stored in the implementation object.
bool shape::standard() const { return impl->m_standard; }
/// For a standard shape, returns a shape rebuilt from type and lens only (so
/// the strides are recomputed); any other shape is returned unchanged.
shape shape::normalize_standard() const
{
    if(not this->standard())
        return *this;
    return {this->type(), this->lens()};
}
/// For a static shape, returns a shape rebuilt from type and lens only;
/// a dynamic shape is returned unchanged.
shape shape::as_standard() const
{
    if(this->dynamic())
        return *this;
    return {this->type(), this->lens()};
}
/// Returns a shape of type `t` with lens `l`, laid out using the permutation
/// found for this shape. Throws for a dynamic shape.
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");

    assert(l.size() == this->lens().size());
    return shape::from_permutation(t, l, find_permutation(*this));
}
/// Same as with_lens(type, l), keeping this shape's own type.
/// Throws for a dynamic shape.
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
    if(this->dynamic())
        MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");

    const auto current_type = this->type();
    return this->with_lens(current_type, l);
}
/// Returns a copy of this shape whose element type is replaced by `t`.
/// The implementation object is cloned, so this shape is not modified.
shape shape::with_type(type_t t) const
{
    auto cloned    = impl->copy();
    cloned->m_type = t;
    return shape{cloned};
}
/// Returns a dynamic version of this shape.
///
/// A tuple converts each sub-shape in turn; an already dynamic shape is
/// returned unchanged; a static shape becomes a dynamic one whose min and max
/// are both the current lens, with no optimals.
shape shape::to_dynamic() const
{
    const auto& children = sub_shapes();
    if(not children.empty())
    {
        std::vector<shape> converted;
        for(const auto& child : children)
            converted.push_back(child.to_dynamic());
        return shape(converted);
    }

    if(this->dynamic())
        return *this;

    return {type(), lens(), lens(), {}};
}
/// Returns a static version of this shape.
///
/// A tuple converts each sub-shape in turn; an already static shape is
/// returned unchanged. For a dynamic shape, every fixed dimension keeps its
/// max value and every non-fixed dimension is set to `x`.
shape shape::to_static(std::size_t x) const
{
    const auto& children = sub_shapes();
    if(not children.empty())
    {
        std::vector<shape> converted;
        for(const auto& child : children)
            converted.push_back(child.to_static(x));
        return shape(converted);
    }

    if(not this->dynamic())
        return *this;

    auto static_lens = this->max_lens();
    const auto& dims = this->dyn_dims();
    for(std::size_t axis = 0; axis < static_lens.size(); ++axis)
    {
        if(not dims[axis].is_fixed())
            static_lens[axis] = x;
    }
    return {type(), static_lens};
}
// Forwards to shape_impl::element_space().
std::size_t shape::element_space() const { return impl->element_space(); }
// Name of this shape's type, as produced by shape::name().
std::string shape::type_string() const { return name(this->type()); }
// A shape is dynamic when it holds at least one dynamic dimension.
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
/// True when this shape is dynamic or any sub-shape (recursively) is dynamic.
bool shape::any_of_dynamic() const
{
    if(this->dynamic())
        return true;

    for(const auto& sub : this->sub_shapes())
    {
        if(sub.any_of_dynamic())
            return true;
    }
    return false;
}
/// Returns the dynamic dimensions. Throws for a static shape.
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
{
    if(this->dynamic())
        return impl->m_dyn_dims;
    MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
}
/// Per-axis minimum lengths: the dynamic minimums for a dynamic shape,
/// otherwise a copy of lens().
std::vector<std::size_t> shape::min_lens() const
{
    if(this->dynamic())
        return impl->min_lens();
    return this->lens();
}
/// Per-axis maximum lengths: the dynamic maximums for a dynamic shape,
/// otherwise a copy of lens().
std::vector<std::size_t> shape::max_lens() const
{
    if(this->dynamic())
        return impl->max_lens();
    return this->lens();
}
// Optimal-value sets of every dynamic dimension; forwards to shape_impl::opt_lens().
std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }
// A dynamic dimension is fixed when its min and max coincide.
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
// True when at least one optimal value is recorded.
bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }
/// Shift min, max and every optimal value up by `x`.
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
    this->min += x;
    this->max += x;

    std::set<std::size_t> shifted;
    for(const auto& opt : this->optimals)
        shifted.insert(opt + x);
    this->optimals = shifted;

    return *this;
}
/// Shift min, max and every optimal value down by `x`.
/// Each of them is expected to be at least `x` (checked by assert).
shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
{
    assert(this->min >= x);
    assert(this->max >= x);
    this->min -= x;
    this->max -= x;

    std::set<std::size_t> shifted;
    for(const auto& opt : this->optimals)
    {
        assert(opt >= x);
        shifted.insert(opt - x);
    }
    this->optimals = shifted;

    return *this;
}
/// Two dynamic dimensions are equal when min and max match and, unless both
/// are fixed, their optimals match as well.
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
    if(x.min != y.min or x.max != y.max)
        return false;
    // don't check optimals if both are fixed
    if(x.is_fixed() and y.is_fixed())
        return true;
    return x.optimals == y.optimals;
}
// Negation of operator== for two dynamic dimensions.
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return not(x == y);
}
/// Print a dynamic dimension as "[ min, max, {optimals} ]".
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
    os << "[ " << x.min << ", " << x.max;
    os << ", {" << migraphx::to_string_range(x.optimals) << "} ]";
    return os;
}
// A dynamic dimension equals a plain size when both its min and max are that size.
bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
{
return x.min == y and x.max == y;
}
// Mirrored and negated forms; all defer to the comparison above.
bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
/// Returns a copy of `x` shifted up by `y` (see operator+=).
shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
{
    shape::dynamic_dimension shifted = x;
    shifted += y;
    return shifted;
}

/// Mirrored form of the addition above.
shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
{
    return y + x;
}

/// Returns a copy of `x` shifted down by `y` (see operator-=).
shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
{
    shape::dynamic_dimension shifted = x;
    shifted -= y;
    return shifted;
}
/// Shape equality.
///
/// Shapes sharing one implementation object are equal. Otherwise both must
/// agree on being dynamic or static and on type and sub-shapes; dynamic shapes
/// additionally compare dynamic dimensions, static shapes compare lens and
/// strides.
bool operator==(const shape& x, const shape& y)
{
    if(x.impl == y.impl)
        return true;
    if(x.dynamic() != y.dynamic())
        return false;
    if(x.type() != y.type())
        return false;

    if(x.dynamic())
        return x.dyn_dims() == y.dyn_dims() and x.sub_shapes() == y.sub_shapes();

    return x.lens() == y.lens() and x.strides() == y.strides() and
           x.sub_shapes() == y.sub_shapes();
}
// Negation of operator== for two shapes.
bool operator!=(const shape& x, const shape& y) { return not(x == y); }
/// Print a shape.
///
/// A tuple prints its sub-shapes in square brackets. A dynamic shape prints
/// "dynamic, <type>, {dyn dims}". A static shape prints
/// "<type>, {lens}, {strides}".
std::ostream& operator<<(std::ostream& os, const shape& x)
{
    if(not x.sub_shapes().empty())
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
        return os;
    }

    if(x.dynamic())
    {
        os << "dynamic, ";
        os << x.type_string() << ", ";
        os << "{" << to_string_range(x.dyn_dims()) << "}";
        return os;
    }

    os << x.type_string() << ", ";
    os << "{" << to_string_range(x.lens()) << "}, ";
    os << "{" << to_string_range(x.strides()) << "}";
    return os;
}
/// Look up a type by name.
///
/// Every type listed by MIGRAPHX_SHAPE_VISIT_TYPES is accepted under both its
/// enumerator name and its C++ type spelling; "tuple_type" and "tuple" map to
/// tuple_type. An unknown name makes the map lookup throw.
shape::type_t shape::parse_type(const std::string& s)
{
    static const std::unordered_map<std::string, shape::type_t> m = {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
                                                                            tuple_type},
        {"tuple", tuple_type}};
// Drop the helper macro so it does not leak into the rest of the file,
// matching the other generator macros in this file.
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP
    return m.at(s);
}
// Returns the sub-shapes stored in the implementation object.
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
std::vector<shape> flatten(const std::vector<shape>& shapes)
{
std::vector<shape> result;
for(const auto& s : shapes)
{
if(s.type() == shape::tuple_type)
{
auto subs = flatten(s.sub_shapes());
result.insert(result.end(), subs.begin(), subs.end());
}
else
{
result.push_back(s);
}
}
return result;
}
/// Serialize a shape into `v`.
///
/// Always writes "type", "sub_shapes", "lens", "strides" and
/// "dynamic_dimensions", in that order. A dynamic shape leaves lens/strides
/// empty; a static shape leaves dynamic_dimensions empty.
void migraphx_to_value(value& v, const shape& s)
{
    value out;
    out["type"]       = migraphx::to_value(s.type_string());
    out["sub_shapes"] = migraphx::to_value(s.sub_shapes());

    // avoid calling functions that will throw
    if(not s.dynamic())
    {
        out["lens"]               = migraphx::to_value(s.lens());
        out["strides"]            = migraphx::to_value(s.strides());
        out["dynamic_dimensions"] = {};
    }
    else
    {
        out["lens"]               = {};
        out["strides"]            = {};
        out["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
    }

    v = out;
}
void migraphx_from_value(const value& v, shape& s)
{
auto t = v.at("type").get_string();
if(t == "tuple_type")
{
s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
}
else
{
if(v.at("dynamic_dimensions").empty())
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
}
else
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(
v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
return from_value<shape::dynamic_dimension>(x);
});
s = shape{shape::parse_type(t), dyn_dims};
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,722 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/onehot.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/resize.hpp>
#include <migraphx/common.hpp>
#include <migraphx/tensor_view.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert broadcast_with_dims operators with a static input tensor and a constant `dims` input
* into multibroadcast op with a static output shape attribute.
*
*/
struct find_broadcast_with_dims_static
{
    // Matches a two-input broadcast_with_dims whose data input has a static
    // shape and whose dims input is constant.
    auto matcher() const
    {
        auto static_data   = match::arg(0)(match::static_shape());
        auto constant_dims = match::arg(1)(match::is_constant());
        return match::name("broadcast_with_dims")(match::nargs(2), static_data, constant_dims);
    }

    // Replaces the matched instruction with a one-input multibroadcast whose
    // out_lens attribute carries the evaluated dims.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins    = mr.result;
        auto inputs = ins->inputs();

        // read the values of arg(1) to create input to multibroadcast
        std::vector<size_t> out_lens;
        auto dims_arg = inputs.at(1)->eval();
        dims_arg.visit([&](auto dims) { out_lens.assign(dims.begin(), dims.end()); });

        auto bcast_op = make_op("multibroadcast", {{"out_lens", out_lens}});
        m.replace_instruction(ins, bcast_op, inputs.at(0));
    }
};
/**
* Convert a Resize op. with Nearest mode to an implementation using Gather op.
* From: resize[scales={...}/sizes={...},](static, constant)
* To:
* 0 = literal{ ... } computed_indices
* ...
* 2 = reshape[dims={45}](X) 1-dimensional
* 3 = gather[axis=0](2,0)
*
* At the time of writing, this conversion is required for GPU targets because there
* is no direct GPU implementation of the Resize operation.
* This matcher depends on a split_single_dyn_dim pass being run before it, which
* will convert any dynamic-batch input to static inputs and make this conversion possible.
*
* At time of writing, Resize allows either 1 or 2 inputs
* but the 1-input case is never created by Onnx parsing.
*/
struct find_resize_static
{
    // Matches a two-input resize whose data input has a static shape and whose
    // second input (sizes or scales) is constant.
    auto matcher() const
    {
        return match::name("resize")(match::nargs(2),
                                     match::arg(0)(match::static_shape()),
                                     match::arg(1)(match::is_constant()));
    }

    // Replaces the matched resize with reshape-to-1D followed by a gather over
    // a literal of precomputed input offsets.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins       = mr.result;
        auto inputs    = ins->inputs();
        auto resize_op = any_cast<op::resize>(ins->get_operator());

        auto in_lens = inputs.at(0)->get_shape().lens();
        std::vector<size_t> sizes_vec(inputs.at(0)->get_shape().ndim());
        std::vector<float> scales_vec(inputs.at(0)->get_shape().ndim());

        // populate both scales and sizes for the benefit of the algorithm.
        // An integral second input holds output sizes; any other type holds scales.
        inputs.at(1)->eval().visit([&](auto input) {
            using type = typename decltype(input)::value_type;
            if constexpr(std::is_integral<type>{})
            {
                // read output sizes and use them to compute scales
                sizes_vec.assign(input.begin(), input.end());
                std::transform(
                    input.begin(),
                    input.end(),
                    in_lens.begin(),
                    scales_vec.begin(),
                    [](auto sz, size_t in_len) { return static_cast<float>(sz) / in_len; });
            }
            else
            {
                // read scales and use them to compute output sizes
                scales_vec.assign(input.begin(), input.end());
                std::transform(
                    input.begin(),
                    input.end(),
                    in_lens.begin(),
                    sizes_vec.begin(),
                    [](auto sz, size_t in_len) { return static_cast<size_t>(sz * in_len); });
            }
        });

        auto in_s = inputs.at(0)->get_shape();
        shape out_s{in_s.type(), sizes_vec};

        // One gather index per output element; stored as int to match the
        // int32 literal created below.
        std::vector<int> ind(out_s.elements());

        // map out_idx to in_idx
        auto nearest_op = op::resize::get_nearest_op(resize_op.nearest_mode);
        auto idx_op     = op::resize::get_original_idx_op(resize_op.coordinate_transformation_mode);

        shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
            std::vector<size_t> in_idx(out_idx_v.size());
            // Unsigned counter: in_lens.size() is a size_t.
            for(std::size_t ii = 0; ii < in_lens.size(); ++ii)
            {
                auto idx_val = idx_op(in_lens[ii], sizes_vec[ii], out_idx_v[ii], scales_vec[ii]);
                in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
            }
            // Cast to the element type of `ind` rather than to a wider type
            // that is then silently narrowed on assignment.
            ind[out_idx] = static_cast<int>(in_s.index(in_idx));
        });

        // reshape input to one-dimension
        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
        auto reshape_op               = make_op("reshape", {{"dims", rsp_lens}});
        auto rsp                      = m.insert_instruction(ins, reshape_op, ins->inputs().at(0));

        // Add our computed indices as a literal.
        // ins_ind is a multi dimensional index that will restore original rank
        shape ind_s{shape::int32_type, sizes_vec};
        auto ins_ind = m.add_literal(literal(ind_s, ind));
        m.replace_instruction(ins, make_op("gather", {{"axis", 0}}), rsp, ins_ind);
    }
};
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
* From:
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/
struct find_static_2in_broadcasts
{
    // Matches a two-input broadcast operator whose inputs both have static shapes.
    auto matcher() const
    {
        auto first_static  = match::arg(0)(match::static_shape());
        auto second_static = match::arg(1)(match::static_shape());
        return match::broadcast(match::nargs(2), first_static, second_static);
    }

    // Rewrites the operator to its one-input form by storing the already known
    // output lens as an attribute and dropping the second input.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins      = mr.result;
        auto out_lens = ins->get_shape().lens();
        auto bcast_op = ins->get_operator();

        if(bcast_op.name() != "broadcast")
        {
            // Any other matched operator also has its out_dyn_dims cleared.
            bcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
        }
        else
        {
            bcast_op.from_value({{"out_lens", out_lens}});
        }

        m.replace_instruction(ins, bcast_op, ins->inputs().at(0));
    }
};
/**
* Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
* From:
* slice(data, constant_input); two attributes set
* To:
* slice(data); slice.starts, slice.ends, slice.axes set
*/
struct find_const_2in_slice
{
    // Matches a two-input slice whose second input is constant.
    auto matcher() const
    {
        return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
    }

    // Folds the constant input into the attribute it stands for and replaces
    // the instruction with a one-input slice.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins       = mr.result;
        auto inputs    = ins->inputs();
        auto slice_op  = any_cast<op::slice>(ins->get_operator());
        auto set_attrs = slice_op.get_set_attributes();

        // Start from the operator's attributes, then overwrite the one that
        // was supplied as an input.
        std::vector<int64_t> starts_vec = slice_op.starts;
        std::vector<int64_t> ends_vec   = slice_op.ends;
        std::vector<int64_t> axes_vec   = slice_op.axes;

        std::vector<int64_t> from_input;
        inputs.at(1)->eval().visit(
            [&](auto output) { from_input.assign(output.begin(), output.end()); });

        if(set_attrs == op::slice::ends_axes)
        {
            // slice(data, starts)
            starts_vec = from_input;
        }
        else if(set_attrs == op::slice::starts_axes)
        {
            // slice(data, ends)
            ends_vec = from_input;
        }
        else
        {
            // slice(data, axes)
            axes_vec = from_input;
        }

        m.replace_instruction(
            ins,
            make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
            inputs.at(0));
    }
};
/**
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
* From:
* slice(data, constant_input1, constant_input2); one attribute set
* To:
* slice(data); slice.starts, slice.ends, slice.axes set
*/
struct find_const_3in_slice
{
    // Matches a three-input slice whose second and third inputs are constant.
    auto matcher() const
    {
        auto first_const  = match::arg(1)(match::is_constant());
        auto second_const = match::arg(2)(match::is_constant());
        return match::name("slice")(match::nargs(3), first_const, second_const);
    }

    // Folds both constant inputs into the attributes they stand for and
    // replaces the instruction with a one-input slice.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins       = mr.result;
        auto inputs    = ins->inputs();
        auto slice_op  = any_cast<op::slice>(ins->get_operator());
        auto set_attrs = slice_op.get_set_attributes();

        // Start from the operator's attributes, then overwrite the two that
        // were supplied as inputs.
        std::vector<int64_t> starts_vec = slice_op.starts;
        std::vector<int64_t> ends_vec   = slice_op.ends;
        std::vector<int64_t> axes_vec   = slice_op.axes;

        std::vector<int64_t> first_input;
        std::vector<int64_t> second_input;
        inputs.at(1)->eval().visit(
            [&](auto output) { first_input.assign(output.begin(), output.end()); });
        inputs.at(2)->eval().visit(
            [&](auto output) { second_input.assign(output.begin(), output.end()); });

        if(set_attrs == op::slice::axes_only)
        {
            // slice(data, starts, ends)
            starts_vec = first_input;
            ends_vec   = second_input;
        }
        else if(set_attrs == op::slice::ends_only)
        {
            // slice(data, starts, axes)
            starts_vec = first_input;
            axes_vec   = second_input;
        }
        else
        {
            // slice(data, ends, axes)
            ends_vec = first_input;
            axes_vec = second_input;
        }

        m.replace_instruction(
            ins,
            make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
            inputs.at(0));
    }
};
/**
* Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
* From:
* slice(data, constant_starts, constant_ends, constant_axes)
* To:
* slice(data); slice.starts, slice.ends, slice.axes set
*/
struct find_const_4in_slice
{
    // Matches a four-input slice whose starts, ends and axes inputs are all constant.
    auto matcher() const
    {
        return match::name("slice")(match::nargs(4),
                                    match::arg(1)(match::is_constant()),
                                    match::arg(2)(match::is_constant()),
                                    match::arg(3)(match::is_constant()));
    }

    // Evaluates the three constant inputs and, if all of them produced data,
    // replaces the instruction with a one-input slice carrying them as
    // attributes. Otherwise the instruction is left untouched.
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins            = mr.result;
        auto inputs         = ins->inputs();
        argument starts_arg = inputs.at(1)->eval(false);
        argument ends_arg   = inputs.at(2)->eval(false);
        argument axes_arg   = inputs.at(3)->eval(false);

        if(starts_arg.empty() or ends_arg.empty() or axes_arg.empty())
            return;

        std::vector<int64_t> starts_vec;
        std::vector<int64_t> ends_vec;
        std::vector<int64_t> axes_vec;
        starts_arg.visit([&](auto data) { starts_vec.assign(data.begin(), data.end()); });
        ends_arg.visit([&](auto data) { ends_vec.assign(data.begin(), data.end()); });
        axes_arg.visit([&](auto data) { axes_vec.assign(data.begin(), data.end()); });

        m.replace_instruction(
            ins,
            make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
            inputs.at(0));
    }
};
/**
 * Simplify dimensions_of to a literal when the input argument has a static shape
 * or the dynamic dimensions from `start` to `end` are fixed.
 */
struct find_static_dimensions_of
{
    auto matcher() const { return match::name("dimensions_of")(); }

    /**
     * Replace the matched dimensions_of instruction with an int64 literal holding the
     * input's dimensions in the range [start, end). Does nothing when the input is
     * dynamic and any dimension in that range is not fixed.
     */
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins                 = mr.result;
        auto input               = ins->inputs().at(0);
        auto dimensions_of_value = ins->get_operator().to_value();
        auto start               = dimensions_of_value.at("start").to<std::size_t>();
        auto end                 = dimensions_of_value.at("end").to<std::size_t>();
        if(input->get_shape().dynamic())
        {
            // check if dynamic dimensions from start to end are fixed
            auto dds = input->get_shape().dyn_dims();
            if(std::any_of(dds.begin() + start, dds.begin() + end, [](const auto& dd) {
                   return not dd.is_fixed();
               }))
            {
                return;
            }
        }
        // NOTE(review): assumes start <= end <= rank; presumably enforced by the
        // dimensions_of operator itself -- confirm.
        const std::size_t output_ndim = end - start;
        std::vector<int64_t> vec_shape(output_ndim);
        // to_static(1) yields static lens for the (now known to be fixed) input shape
        std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
        std::transform(input_lens.begin() + start,
                       input_lens.begin() + end,
                       vec_shape.begin(),
                       [](auto i) { return int64_t(i); });
        migraphx::shape output_shape{migraphx::shape::int64_type, {output_ndim}};
        auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
        m.replace_instruction(ins, lit_ins);
    }
};
/**
 * Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
 * argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
 * This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
 * From:
 * x = allocate(constant_output_dims) -> reshape(data, x)
 * To:
 * reshape(data); reshape.dims = constant_output_dims
 */
struct find_const_alloc_reshapes
{
    auto matcher() const
    {
        auto constant_allocate = match::name("allocate")(match::is_constant());
        return match::name("reshape")(match::nargs(2), match::arg(1)(constant_allocate));
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto reshape_ins = mr.result;
        auto args        = reshape_ins->inputs();
        // The allocate's first input carries the requested output dimensions.
        auto dims_source  = args.at(1)->inputs().at(0);
        argument dims_arg = dims_source->eval(false);

        std::vector<int64_t> dims;
        dims_arg.visit([&](auto output) { dims.assign(output.begin(), output.end()); });

        // The now-unused allocate is left for dead_code_elimination to remove.
        auto static_reshape = make_op("reshape", {{"dims", dims}});
        m.replace_instruction(reshape_ins, static_reshape, args.at(0));
    }
};
/**
 * Simplify allocate into fill operator that has constant output dimensions and constant value.
 * The allocate into fill instructions is what is produced when parsing the ONNX
 * ConstantOfShape operator. This replacement could be handled with propagate_constant, but
 * would rather have the simplification happen earlier during compiling.
 * This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
 * From:
 * x = allocate(constant_output_dims) -> fill(constant_value, x)
 * To:
 * literal
 */
struct find_const_alloc_fill
{
    auto matcher() const
    {
        auto constant_value    = match::arg(0)(match::is_constant());
        auto constant_allocate = match::arg(1)(match::name("allocate")(match::is_constant()));
        return match::name("fill")(constant_value, constant_allocate);
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto target = mr.result;
        // Evaluate the fill itself and store the result as a literal.
        auto filled  = target->eval(false);
        auto literal = m.add_literal(filled.get_shape(), filled.data());
        m.replace_instruction(target, literal);
    }
};
/**
 * Simplify broadcast_for_dot instructions with two static shaped arguments
 * From:
 * broadcast_for_dot(static_shape_arg, static_shape_arg)
 * To:
 * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape
 */
struct find_static_broadcast_for_dot
{
    auto matcher() const
    {
        auto static_arg0 = match::arg(0)(match::static_shape());
        auto static_arg1 = match::arg(1)(match::static_shape());
        return match::name("broadcast_for_dot")(static_arg0, static_arg1);
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto target = mr.result;
        auto args   = target->inputs();

        const auto shape0 = args.at(0)->get_shape();
        const auto shape1 = args.at(1)->get_shape();
        const auto& lens0 = shape0.lens();
        const auto& lens1 = shape1.lens();

        // Leading dimensions: everything except the trailing two of each argument.
        const auto trailing0 = lens0.end() - 2;
        const auto trailing1 = lens1.begin() + shape1.ndim() - 2;
        std::vector<std::size_t> leading0(lens0.begin(), trailing0);
        std::vector<std::size_t> leading1(lens1.begin(), trailing1);

        // Broadcast the leading dimensions together, then re-attach the first
        // argument's own trailing two dimensions.
        auto out_lens = compute_broadcasted_lens(leading0, leading1);
        out_lens.insert(out_lens.end(), trailing0, lens0.end());

        m.replace_instruction(
            target, make_op("multibroadcast", {{"out_lens", out_lens}}), args.at(0));
    }
};
/**
 * Simplify onehot instructions with static shape `indices` input and
 * a compile-time constant `depth` attribute or input.
 * From:
 * onehot(static_shape_arg, constant_arg, values) or
 * onehot(static_shape_arg, values)
 * To:
 * A = literal(shape = onehot_output_shape, value = 0)
 * B = unsqueeze(literal(lens = indices_lens, strides = broadcasted scalar, value = 1),
 *               axis = onehot_axis)
 * C = scatter(A, unsqueeze(indices, axis = onehot_axis), B)
 * diff = on_value - off_value
 * D = mul(diff, C)
 * return = add(D, off_value)
 *
 * In the code below B is produced by multibroadcasting a one-element literal to the
 * lens of the unsqueezed indices, and the scatter is `scatter_none` with
 * `skip_out_of_bounds` set.
 *
 * NOTE: It might be cleaner to use some form of `fill` instead of
 * (on_value - off_value) * mask + off_value when we have `fill` working
 * on the GPU.
 */
struct find_static_onehot
{
// Accepts either the 2-argument form (indices, values) or the 3-argument form
// (indices, depth, values); in the latter the depth argument must be constant.
auto matcher() const
{
auto match_2_args = match::nargs(2)(match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
auto match_3_args = match::nargs(3)(match::arg(0)(match::static_shape()),
match::arg(1)(match::is_constant()),
match::arg(2)(match::static_shape()));
return match::name("onehot")(match::any_of(match_2_args, match_3_args));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto onehot_ins = mr.result;
auto onehot_inputs = onehot_ins->inputs();
auto onehot_op = any_cast<op::onehot>(onehot_ins->get_operator());
auto indices_ins = onehot_inputs[0];
shape indices_shape = indices_ins->get_shape();
// depth_val is assigned in exactly one of the two branches below
std::size_t depth_val;
migraphx::instruction_ref values_ins;
if(onehot_op.depth.has_value())
{
// depth given as an operator attribute: inputs are (indices, values)
assert(onehot_inputs.size() == 2);
depth_val = onehot_op.depth.value();
values_ins = onehot_inputs[1];
}
else
{
// depth given as a constant input: inputs are (indices, depth, values)
assert(onehot_inputs.size() == 3);
auto depth_ins = onehot_inputs[1];
// only the first element of the evaluated depth argument is used
depth_ins->eval().visit([&](auto d) { depth_val = d[0]; });
values_ins = onehot_inputs[2];
}
shape values_shape = values_ins->get_shape();
std::vector<std::size_t> static_output_lens = indices_shape.lens();
// a negative axis counts back from the output rank, which is indices rank + 1
auto normalized_axis =
(onehot_op.axis < 0) ? onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis;
// output lens = indices lens with `depth` inserted at the one-hot axis
static_output_lens.insert(static_output_lens.begin() + normalized_axis, depth_val);
shape output_shape{values_shape.type(), static_output_lens};
// A: all-zero literal with the full output shape
std::vector<float> zeros(output_shape.elements(), 0);
auto zeros_lit = m.add_literal(literal(output_shape, zeros));
auto unsqueeze_inds = m.insert_instruction(
onehot_ins, migraphx::make_op("unsqueeze", {{"axes", {normalized_axis}}}), indices_ins);
// broadcast the one scalar to the correct shape
auto ones_lit = m.add_literal(literal(shape{values_shape.type(), {1}, {0}}, {1}));
auto mb_ones = m.insert_instruction(
onehot_ins,
migraphx::make_op("multibroadcast", {{"out_lens", unsqueeze_inds->get_shape().lens()}}),
ones_lit);
// C: scatter the ones into the zero tensor along the one-hot axis
auto mask = m.insert_instruction(
onehot_ins,
make_op("scatter_none", {{"axis", normalized_axis}, {"skip_out_of_bounds", true}}),
zeros_lit,
unsqueeze_inds,
mb_ones);
// `values` is sliced along axis 0: element [0, 1) is used as off_value and
// element [1, 2) as on_value
auto off_val =
m.insert_instruction(onehot_ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
values_ins);
auto on_val =
m.insert_instruction(onehot_ins,
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}),
values_ins);
// result = (on_value - off_value) * mask + off_value
auto diff_val = m.insert_instruction(onehot_ins, make_op("sub"), on_val, off_val);
auto mul_diff_mask = insert_common_op(m, onehot_ins, make_op("mul"), {diff_val, mask});
auto mb_off_val = m.insert_instruction(
onehot_ins, make_op("multibroadcast", {{"out_lens", output_shape.lens()}}), off_val);
m.replace_instruction(onehot_ins, make_op("add"), mb_off_val, mul_diff_mask);
}
};
/**
 * Go through `select_module` instructions and update the `output_dyn_shapes` attribute.
 * Checks the submodule output shapes and determines an appropriate `output_dyn_shapes` attribute.
 * This version ignores dynamic_dimension opt values.
 * Intended to be run after the other simplify_dyn_ops passes.
 */
struct simplify_select_module_output_shape
{
    auto matcher() const { return match::name("select_module"); }

    /**
     * Recompute `output_dyn_shapes` for the matched select_module from its submodules'
     * output shapes. The instruction is left untouched when it has no submodules or when
     * the submodules disagree on the number, rank or type of their outputs.
     */
    void apply(module& m, const match::matcher_result& mr) const
    {
        auto sm_ins           = mr.result;
        auto sm_module_inputs = sm_ins->module_inputs();
        // Without submodules there is nothing to infer; this also protects the
        // front() and begin() + 1 accesses below.
        if(sm_module_inputs.empty())
            return;
        std::vector<std::vector<shape>> all_output_shapes(sm_module_inputs.size());
        std::transform(sm_module_inputs.begin(),
                       sm_module_inputs.end(),
                       all_output_shapes.begin(),
                       [](auto submod) { return submod->get_output_shapes(); });
        // check that all of the submodules have the same number of outputs and all respective
        // outputs have the same rank and type
        auto shapes_ndim  = get_shapes_ndim(all_output_shapes.front());
        auto shapes_types = get_shapes_types(all_output_shapes.front());
        if(std::any_of(all_output_shapes.begin() + 1,
                       all_output_shapes.end(),
                       [&](const auto& out_shapes) {
                           bool same_types = get_shapes_types(out_shapes) == shapes_types;
                           bool same_ndim  = get_shapes_ndim(out_shapes) == shapes_ndim;
                           return not same_types or not same_ndim;
                       }))
        {
            return;
        }
        auto num_out_shapes = shapes_ndim.size();
        std::vector<shape> dyn_shapes(num_out_shapes);
        auto num_submod = sm_module_inputs.size();
        // compare respective output shapes from each submodule to get a range for the output shape
        for(std::size_t i = 0; i < num_out_shapes; ++i)
        {
            std::vector<shape> shapes_at_index(num_submod);
            std::transform(all_output_shapes.begin(),
                           all_output_shapes.end(),
                           shapes_at_index.begin(),
                           [&](const auto& output_shapes) { return output_shapes.at(i); });
            dyn_shapes.at(i) = dyn_shape_from_shapes(shapes_at_index);
        }
        auto tuple_shape = shape{dyn_shapes};
        m.replace_instruction(
            sm_ins,
            make_op("select_module", {{"output_dyn_shapes", to_value(tuple_shape)}}),
            sm_ins->inputs(),
            sm_module_inputs);
    }

    /// Rank (ndim) of every shape in `shapes`, in order.
    std::vector<std::size_t> get_shapes_ndim(const std::vector<shape>& shapes) const
    {
        std::vector<std::size_t> ret(shapes.size());
        std::transform(
            shapes.cbegin(), shapes.cend(), ret.begin(), [](const auto& s) { return s.ndim(); });
        return ret;
    }

    /// Element type of every shape in `shapes`, in order.
    std::vector<shape::type_t> get_shapes_types(const std::vector<shape>& shapes) const
    {
        std::vector<shape::type_t> ret(shapes.size());
        std::transform(
            shapes.cbegin(), shapes.cend(), ret.begin(), [](const auto& s) { return s.type(); });
        return ret;
    }

    /**
     * Calculating an appropriate shape that encompasses all of the given vector of shapes.
     * Equivalent to creating a 2D matrix of shape lengths and do a reduce over each axis.
     * The shapes can be dynamic or static.
     * Assuming all shapes have the same ndim and that `shape_vec` is not empty.
     *
     * Returns a static shape when the per-axis minimum and maximum coincide on every
     * axis, otherwise a dynamic shape spanning [min, max] on each axis. The element type
     * is taken from the first shape.
     */
    shape dyn_shape_from_shapes(std::vector<shape> shape_vec) const
    {
        // making 2D matrices of min_lens and max_lens
        // specifically using uint64_t because we're going to put the values into a tensor_view
        // later
        std::vector<uint64_t> all_min_lens;
        std::vector<uint64_t> all_max_lens;
        for(const auto& s : shape_vec)
        {
            auto min_lens = s.min_lens();
            auto max_lens = s.max_lens();
            std::copy(min_lens.begin(), min_lens.end(), std::back_inserter(all_min_lens));
            std::copy(max_lens.begin(), max_lens.end(), std::back_inserter(all_max_lens));
        }
        assert(all_min_lens.size() == shape_vec.size() * shape_vec.front().ndim());
        assert(all_max_lens.size() == shape_vec.size() * shape_vec.front().ndim());
        auto num_rows = shape_vec.size();
        auto num_cols = shape_vec.front().ndim();
        shape tensor_shape{shape::uint64_type, {num_rows, num_cols}};
        auto min_lens_matrix = make_view(tensor_shape, all_min_lens.data());
        auto max_lens_matrix = make_view(tensor_shape, all_max_lens.data());
        std::vector<uint64_t> mins(num_cols);
        std::vector<uint64_t> maxes(num_cols);
        // rearranging data into column vectors to reduce over
        // i = row, j = column
        for(std::size_t j = 0; j < num_cols; ++j)
        {
            std::vector<uint64_t> reduce_min_vals(num_rows);
            std::vector<uint64_t> reduce_max_vals(num_rows);
            for(std::size_t i = 0; i < num_rows; ++i)
            {
                reduce_min_vals.at(i) = min_lens_matrix(i, j);
                reduce_max_vals.at(i) = max_lens_matrix(i, j);
            }
            // The seed's type is the accumulator type of std::accumulate, so both seeds
            // must be uint64_t; a plain `0` would make the running maximum an int and
            // truncate large dimension values.
            const uint64_t min_seed = std::numeric_limits<uint64_t>::max();
            const uint64_t max_seed = 0;
            uint64_t min_val =
                std::accumulate(reduce_min_vals.begin(),
                                reduce_min_vals.end(),
                                min_seed,
                                [](uint64_t x, uint64_t y) { return x < y ? x : y; });
            uint64_t max_val =
                std::accumulate(reduce_max_vals.begin(),
                                reduce_max_vals.end(),
                                max_seed,
                                [](uint64_t x, uint64_t y) { return x > y ? x : y; });
            mins.at(j)  = min_val;
            maxes.at(j) = max_val;
        }
        // fixed output shape case
        if(mins == maxes)
        {
            return shape{shape_vec.front().type(), mins};
        }
        // dynamic output shape case
        return shape{shape_vec.front().type(), mins, maxes, {}};
    }
};
// Runs the simplification finders over module `m` in two sweeps.
void simplify_dyn_ops::apply(module& m) const
{
// First sweep: the finders that replace operators with static/constant-attribute forms.
match::find_matches(m,
find_broadcast_with_dims_static{},
find_resize_static{},
find_static_dimensions_of{},
find_const_alloc_reshapes{},
find_static_2in_broadcasts{},
find_const_2in_slice{},
find_const_3in_slice{},
find_const_4in_slice{},
find_const_alloc_fill{},
find_static_broadcast_for_dot{},
find_static_onehot{});
// Second sweep: refresh select_module output shapes; run separately because that
// finder is intended to run after the other simplifications.
match::find_matches(m, simplify_select_module_output_shape{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,463 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/fp8_types.hpp>
#include <migraphx/match/dq_helpers.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
// Names of the operators this pass tries to replace with a quantized version.
std::unordered_set<std::string> get_quantizable_op_names()
{
    return {"convolution", "dot"};
}
// Finds a convolution or dot whose two arguments each come from a dequantizelinear
// (possibly through intermediate ops skipped by skip_post_dq_ops) and rewrites it as
// quant_convolution / quant_dot on the quantized inputs, followed by one
// dequantizelinear on the result.
struct match_find_quantizable_ops
{
// A quantization parameter is accepted when it has a single element or exactly
// lens[axis] elements.
static bool
is_valid_qparam(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
return qparam->get_shape().elements() == 1 or
qparam->get_shape().elements() == lens.at(axis);
}
// True only when the zero point can be evaluated and every element compares equal
// to 0; a zero point that cannot be evaluated is reported as not symmetric.
static bool is_symmetric_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
return all_zeros;
}
// Chooses the operator that expands `qparam` to `lens`: multibroadcast for a scalar
// shape, otherwise broadcast along `axis`.
static auto
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
if(qparam->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto propagate_quantized_ins(module& m,
const instruction_ref dqins,
const instruction_ref qop_arg,
bool is_fp16_model = false)
{
auto prev_ins = qop_arg;
std::vector<instruction_ref> ins_inbetween;
// matcher skips contiguous, multi/broadcasts and transposes, collect all those
// instructions
// (walks backwards from the op argument to the dequantizelinear via each first input)
while(prev_ins != dqins)
{
ins_inbetween.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
// replay the collected ops, in original order, on the quantized input of the
// dequantizelinear; every new instruction is inserted before dqins
auto qinp = dqins->inputs().front();
for(auto ins : reverse_iterator_for(ins_inbetween))
{
// in the fp16 case the converts between dequantizelinear and the op are dropped
if((*ins)->name() == "convert" and is_fp16_model)
{
continue;
}
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
}
return qinp;
}
// Binds dq1/scale1/zp1 for argument 0 and dq2/scale2/zp2 for argument 1.
auto matcher() const
{
auto dq1 = match::arg(0)(
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
auto dq2 = match::arg(1)(
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
return match::name(get_quantizable_op_names())(dq1, dq2);
}
// NOTE: the quantized inputs and the quantized op are inserted before the
// quantization parameters are validated, so an early return below leaves those
// instructions in the module without users.
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
std::set<migraphx::shape::type_t> supported_types = fp8_types{}.get();
supported_types.insert(migraphx::shape::int8_type);
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
// fp16 case: the op produces half while the dequantizelinear output has another
// type (asserted to be float)
bool is_fp16_model = false;
if(dq1->get_shape().type() != qop->get_shape().type() and
qop->get_shape().type() == migraphx::shape::half_type)
{
assert(dq1->get_shape().type() == migraphx::shape::float_type);
is_fp16_model = true;
}
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model);
auto arg1_lens = qop_args[0]->get_shape().lens();
auto arg2_lens = qop_args[1]->get_shape().lens();
// dq first holds the quantized op, later the final dequantized replacement
instruction_ref dq;
instruction_ref out_scale;
instruction_ref out_zp;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Ensure input and weight quantization parameters are of a proper form
// Input is of shape [n, c, x1, ..., xn]. Only scalar quantization allowed
// Weight is of shape [k, c, y1, ... , yn]. Valid quantization axis is k
if(not(scale1->get_shape().elements() == 1 and zp1->get_shape().elements() == 1 and
is_valid_qparam(scale2, arg2_lens, 0) and is_valid_qparam(zp2, arg2_lens, 0)))
return;
// This implementation supports affine quantization for both input and weight
// In practice, weight is quantized symmetrically
// output scale = scale1 * scale2, both broadcast to the output lens along axis 1
auto s1_bcast =
m.insert_instruction(qop, qparam_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, qparam_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
// Compute the zero-point terms; initialize as 0 and add relevant terms
auto zero_lit = m.add_literal(literal{shape{dq->get_shape().type()}, {0}});
out_zp = m.insert_instruction(
qop, make_op("multibroadcast", {{"out_lens", dq->get_shape().lens()}}), zero_lit);
auto inp_zp_bc = m.insert_instruction(qop, qparam_broadcast_op(zp1, arg1_lens, 1), zp1);
auto w_zp_bc = m.insert_instruction(qop, qparam_broadcast_op(zp2, arg2_lens, 0), zp2);
// out_zp += conv(zp1, w) when the input zero point is not all zeros
if(not is_symmetric_zero_point(zp1))
{
auto out_zp_1 = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), inp_zp_bc, qop_args[1]);
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_1);
}
// out_zp += conv(x, zp2) when the weight zero point is not all zeros
if(not is_symmetric_zero_point(zp2))
{
auto out_zp_2 = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args[0], w_zp_bc);
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_2);
}
// out_zp -= conv(zp1, zp2) when neither zero point is all zeros
if(not is_symmetric_zero_point(zp1) and not is_symmetric_zero_point(zp2))
{
auto out_zp_3 = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), inp_zp_bc, w_zp_bc);
out_zp = m.insert_instruction(qop, migraphx::make_op("sub"), out_zp, out_zp_3);
}
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, valid quantization axes are M for input1 and K for
// input 2
if(not(is_valid_qparam(scale1, out_lens, out_lens.size() - 2) and
is_valid_qparam(zp1, out_lens, out_lens.size() - 2) and
is_valid_qparam(scale2, out_lens, out_lens.size() - 1) and
is_valid_qparam(zp2, out_lens, out_lens.size() - 1)))
{
return;
}
// This implementation supports both arguments being per-axis affine quantized
// In practice, inputs are per-tensor affine and weights are per-axis symmetric
auto s1_bcast = m.insert_instruction(
qop, qparam_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, qparam_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
// Compute the zero-point terms; initialize as 0 and add relevant terms
auto zero_lit = m.add_literal(literal{shape{dq->get_shape().type()}, {0}});
out_zp = m.insert_instruction(
qop, make_op("multibroadcast", {{"out_lens", dq->get_shape().lens()}}), zero_lit);
auto zp1_bc = m.insert_instruction(
qop, qparam_broadcast_op(zp1, arg1_lens, arg1_lens.size() - 2), zp1);
auto zp2_bc = m.insert_instruction(
qop, qparam_broadcast_op(zp2, arg2_lens, arg2_lens.size() - 1), zp2);
// same three zero-point terms as the convolution branch, using quant_dot
if(not is_symmetric_zero_point(zp1))
{
auto out_zp_1 =
m.insert_instruction(qop, migraphx::make_op("quant_dot"), zp1_bc, qop_args[1]);
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_1);
}
if(not is_symmetric_zero_point(zp2))
{
auto out_zp_2 =
m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args[0], zp2_bc);
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_2);
}
if(not is_symmetric_zero_point(zp1) and not is_symmetric_zero_point(zp2))
{
auto out_zp_3 =
m.insert_instruction(qop, migraphx::make_op("quant_dot"), zp1_bc, zp2_bc);
out_zp = m.insert_instruction(qop, migraphx::make_op("sub"), out_zp, out_zp_3);
}
}
// dequantize the quantized result with the combined scale and zero point
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale, out_zp);
if(is_fp16_model)
{
// convert back to half so the replacement has the type of the original op
dq = m.insert_instruction(
qop, make_op("convert", {{"target_type", migraphx::shape::half_type}}), dq);
}
m.replace_instruction(qop, dq);
}
};
/**
 * Decide whether two instructions hold the same constant value.
 *
 * A leading broadcast or multibroadcast on either side is looked through first.
 * Returns false when either side evaluates to an empty argument. Otherwise returns
 * true when the evaluated arguments compare equal, or when every element of both
 * literals matches the first element of the first literal (two infinities also count
 * as a match), which covers equal values stored with different shapes.
 */
bool compare_literals(instruction_ref ins1, instruction_ref ins2)
{
    if(ins1->name() == "broadcast" or ins1->name() == "multibroadcast")
        ins1 = ins1->inputs().front();
    auto x = ins1->eval();
    if(x.empty())
        return false;
    if(ins2->name() == "broadcast" or ins2->name() == "multibroadcast")
        ins2 = ins2->inputs().front();
    auto y = ins2->eval();
    if(y.empty())
        return false;
    bool diff_shapes_equal_vals = false;
    // The literals are visited in place; no local copies are needed.
    visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
        // An element matches when it equals the first element of l1, or both are infinite.
        auto matches_first = [&](auto v) {
            return ((float_equal(v, l1.front())) or
                    (std::isinf(static_cast<double>(l1.front())) and
                     std::isinf(static_cast<double>(v))));
        };
        diff_shapes_equal_vals = std::all_of(l1.begin() + 1, l1.end(), matches_first) and
                                 std::all_of(l2.begin(), l2.end(), matches_first);
    });
    return (x == y) or diff_shapes_equal_vals;
}
// Returns true when `y` is found among the positions strictly after `x` and before `last`.
template <class Iterator>
bool precedes(Iterator x, Iterator y, Iterator last)
{
    auto after_x = range(std::next(x), last);
    for(auto it : iterator_for(after_x))
    {
        if(it == y)
            return true;
    }
    return false;
}
// Handles a quantizelinear (used once) whose input `x` has other users as well: a
// pointwise user of `x` placed after the quantizelinear is rewired to read a
// dequantizelinear of the quantized value instead of `x`.
struct match_qlinear_reused
{
    auto matcher() const
    {
        auto shared_input = match::arg(0)(match::none_of(match::used_once()).bind("x"));
        return match::name("quantizelinear")(match::used_once(), shared_input);
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto q_ins = r.result;
        auto x_ins = r.instructions["x"];
        assert(q_ins != x_ins);

        // Only the case of exactly two users of x is handled.
        auto users = x_ins->outputs();
        if(users.size() != 2)
            return;

        // The dequantizelinear reuses the quantizelinear's scale/zero point and takes
        // the quantizelinear itself as data input.
        auto dq_args = q_ins->inputs();
        dq_args[0]   = q_ins;

        for(auto user : users)
        {
            const bool eligible = user->name() != "quantizelinear" and
                                  user->get_operator().attributes().contains("pointwise") and
                                  precedes(q_ins, user, m.end());
            if(not eligible)
                continue;
            auto dq = m.insert_instruction(std::next(q_ins), make_op("dequantizelinear"), dq_args);
            instruction::replace_argument(user, x_ins, dq);
        }
    }
};
// Two instructions carry the same value when they are the same instruction or
// compare_literals reports them as equal.
bool is_same_value(instruction_ref a, instruction_ref b)
{
    return a == b or compare_literals(a, b);
}
// Compares input 1 (and input 2 when present) of two instructions; both must also have
// the same number of inputs.
bool is_same_scale_zero(instruction_ref a, instruction_ref b)
{
    const auto& a_args = a->inputs();
    const auto& b_args = b->inputs();
    if(a_args.size() != b_args.size())
        return false;
    const bool same_scale = is_same_value(a_args.at(1), b_args.at(1));
    if(not same_scale)
        return false;
    // With only two inputs there is no third input left to compare.
    return a_args.size() == 2 or is_same_value(a_args.at(2), b_args.at(2));
}
// An input counts as int4 when it is produced by an unpack_int4 operator, possibly
// behind a chain of shape-related ops (and convert), which are followed through their
// first input.
bool is_any_input_int4(instruction_ref a)
{
    static std::set<std::string> ign = {"unsqueeze",
                                        "broadcast",
                                        "multibroadcast",
                                        "contiguous",
                                        "transpose",
                                        "reshape",
                                        "convert"};
    for(auto producer : a->inputs())
    {
        // step past every ignorable op in the chain
        while(ign.count(producer->name()) != 0)
            producer = producer->inputs()[0];
        if(producer->name() == "unpack_int4")
            return true;
    }
    return false;
}
// For every instruction fed by dequantizelinear(quantizelinear(x)) where both use the
// same scale and zero point, reads x directly instead.
void remove_qdq_pairs(module& m)
{
    for(auto ins : iterator_for(m))
    {
        // iterate over a copy because the instruction's arguments are rewritten below
        auto args = ins->inputs();
        for(auto&& dq : args)
        {
            if(dq->name() != "dequantizelinear")
                continue;
            auto q = dq->inputs().front();
            if(q->name() != "quantizelinear" or not is_same_scale_zero(dq, q))
                continue;
            instruction::replace_argument(ins, dq, q->inputs().front());
        }
    }
}
// Rewrites three-input dequantizelinear instructions whose zero point evaluates to all
// zeros into the two-input form (data, scale).
void remove_zero_point(module& m)
{
    for(auto ins : iterator_for(m))
    {
        const bool has_zero_point =
            ins->name() == "dequantizelinear" and ins->inputs().size() == 3;
        if(not has_zero_point)
            continue;
        auto zero_point = ins->inputs().at(2);
        if(not zero_point->can_eval())
            continue;
        bool all_zero = false;
        zero_point->eval().visit([&](auto vals) {
            all_zero =
                std::all_of(vals.begin(), vals.end(), [](auto v) { return float_equal(v, 0); });
        });
        if(all_zero)
        {
            m.replace_instruction(
                ins, ins->get_operator(), ins->inputs().at(0), ins->inputs().at(1));
        }
    }
}
// Inserts pack_int4 -> unpack_int4 between a quantizelinear with an int4 input and the
// dequantizelinear that consumes it.
void add_int4_pack_unpack_pair(module& m)
{
    for(auto dq : iterator_for(m))
    {
        if(dq->name() != "dequantizelinear")
            continue;
        for(auto&& arg : dq->inputs())
        {
            const bool int4_quantize = (arg->name() == "quantizelinear") and is_any_input_int4(arg);
            if(not int4_quantize)
                continue;
            auto packed   = m.insert_instruction(dq, make_op("pack_int4"), arg);
            auto unpacked = m.insert_instruction(dq, make_op("unpack_int4"), packed);
            instruction::replace_argument(dq, arg, unpacked);
        }
    }
}
} // namespace
// Runs the quantize/dequantize simplification steps on `m` in a fixed order, with
// dead_code_elimination between the rewriting steps to remove the instructions each
// one leaves unused.
void simplify_qdq::apply(module& m) const
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
// replace convolution/dot fed by dequantizelinear with their quantized versions
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
// bypass quantizelinear -> dequantizelinear pairs that share scale and zero point
remove_qdq_pairs(m);
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
// let pointwise users of a quantized input read the dequantized value instead
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
// drop all-zero zero points from dequantizelinear
remove_zero_point(m);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,243 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#include <migraphx/split_reduce.hpp>
#include <migraphx/dom_info.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/liveness.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/param_utils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Operator wrapping a single reduce submodule. `axes` and `assign` are the reflected
// attributes; the output shape is computed from the submodule.
struct split_fused_reduce
{
    std::vector<std::int64_t> axes{};
    std::string assign = "assign_none";

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"), f(self.assign, "assign"));
    }

    std::string name() const { return "split_fused_reduce"; }

    value attributes() const { return {{"prefill", 0}}; }

    // Requires exactly one submodule and one input per submodule parameter, all inputs
    // having the same rank. Returns the submodule's single output shape, or a tuple
    // shape when it has several outputs.
    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        const auto* submodule = mods.front();
        const auto param_names = submodule->get_parameter_names();
        check_shapes{inputs, *this}.has(param_names.size()).same_ndims();
        auto out_shapes = submodule->compute_shapes(
            inputs, {.name = name(), .strict_type = true, .strict_lens = true});
        return out_shapes.size() == 1 ? out_shapes.front() : shape{out_shapes};
    }
};
MIGRAPHX_REGISTER_OP(split_fused_reduce);
// An instruction counts as a reduction when its name contains "reduce".
static bool is_reduce(const instruction& ins)
{
    const auto op_name = ins.name();
    return contains(op_name, "reduce");
}
namespace {
// Helper that analyses the module `rm`: it selects the reduce instructions to split on
// and determines which instructions are still live after a split.
struct splitter
{
// module being analysed
const_module_ref rm;
// Returns whether `a` strictly dominates `b` in rm; the dominator information is
// computed on first use and cached in `dom`.
bool strictly_dominate(instruction_ref a, instruction_ref b)
{
if(not dom.has_value())
dom = compute_dominator(*rm);
return dom->strictly_dominate(a, b);
}
// Returns the reduce instructions to split on, or an empty vector when the module is
// not handled: more than two reductions, a reduction other than reduce_sum, or two
// reductions where the first reaches the second.
std::vector<instruction_ref> find_splits() const
{
std::vector<instruction_ref> result;
copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins) {
return is_reduce(*ins);
});
if(result.size() > 2)
return {};
// Only handle reduce_sum for now
// TODO: Support other reduction types
if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
return ins->name() == "reduce_sum";
}))
return {};
// zero or one reduction: nothing more to check
if(result.size() < 2)
return result;
if(reaches(result[0], result[1]))
return {};
return result;
}
// Returns the live instructions reported by the liveness callback at the first
// invocation whose preceding instruction (after output-alias resolution) is one of
// `splits`; parameters and the splits themselves are left out.
std::vector<instruction_ref> find_alive(const std::vector<instruction_ref>& splits)
{
std::vector<instruction_ref> result;
// set once a live set has been collected so later callbacks are ignored
bool stop = false;
liveness(*rm, [&](auto rins, const auto& live_set) {
if(stop)
return;
// the first instruction has no predecessor to inspect
if(rins == rm->begin())
return;
// We want to know what instructions are live after the split instruction
auto ins = instruction::get_output_alias(std::prev(rins));
if(not contains(splits, ins))
return;
std::copy_if(live_set.begin(),
live_set.end(),
std::back_inserter(result),
[&](instruction_ref live) {
if(live->name() == "@param")
return false;
if(contains(splits, live))
return false;
// with more than one split, keep only instructions that strictly
// dominate at least one of the splits
if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) {
return this->strictly_dominate(live, split);
}))
return false;
return true;
});
stop = true;
});
return result;
}
// lazily filled cache used by strictly_dominate
std::optional<dominator_info> dom = std::nullopt;
};
} // namespace
// Maps the name of the first split instruction to the name of the matching
// assignment mode. Throws std::out_of_range for an unlisted name.
static std::string assign_op(const std::vector<instruction_ref>& splits)
{
    static const std::unordered_map<std::string, std::string> reduce_to_assign = {
        {"reduce_sum", "assign_add"},
        {"reduce_mean", "assign_add"},
        {"reduce_prod", "assign_mul"},
        {"reduce_max", "assign_max"},
        {"reduce_min", "assign_min"},
    };
    auto first_split = splits.front();
    return reduce_to_assign.at(first_split->name());
}
// Inserts the instructions of mwi.mod into `m` before `ins`, using the
// parameter map that get_ins_param_map builds from mwi.inputs.
static std::vector<instruction_ref>
insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& mwi)
{
    const auto* source_mod = &mwi.mod;
    auto param_map         = source_mod->get_ins_param_map(mwi.inputs, true);
    return m.insert_instructions(ins, source_mod, &param_map);
}
// Ratio of input elements to output elements for the first reduce
// instruction in `rm`. The module is expected to contain one.
static std::size_t get_reduce_size(const_module_ref rm)
{
    auto reduce_ins = std::find_if(rm->begin(), rm->end(), &is_reduce);
    assert(reduce_ins != rm->end());
    const auto input_elements  = reduce_ins->inputs().front()->get_shape().elements();
    const auto output_elements = reduce_ins->get_shape().elements();
    return input_elements / output_elements;
}
// For every "fused_reduce" instruction whose reduce size is at least
// `split_size`, replaces it with a "split_fused_reduce" instruction (running
// the first part of the split submodule) followed by the remaining
// instructions inserted inline into the parent module.
void split_reduce::apply(module_pass_manager& mpm) const
{
for(auto ins : iterator_for(mpm.get_module()))
{
if(ins->name() != "fused_reduce")
continue;
auto* rm = ins->module_inputs().front();
// Small reductions are left untouched
if(get_reduce_size(rm) < split_size)
continue;
splitter s{rm};
auto splits = s.find_splits();
if(splits.empty())
continue;
// Only use split reduce with float for now
// TODO: Support other data types
if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) {
return contains({shape::float_type, shape::half_type}, split->get_shape().type());
}))
continue;
auto v = ins->get_operator().to_value();
auto axes = v["axes"].to_vector<std::int64_t>();
auto alive = s.find_alive(splits);
std::array<module::with_inputs, 2> mods;
if(not alive.empty())
{
// Three-way split: the first part is inserted inline before `ins` and its
// results replace the `alive` instructions in the other two parts.
auto mods3 = rm->split(ins->inputs(), alive, splits);
auto r = insert_module_inline(mpm.get_module(), ins, mods3[0]);
mods3[1].replace(alive, r);
mods3[2].replace(alive, r);
mods = {std::move(mods3[1]), std::move(mods3[2])};
}
else
{
mods = rm->split(ins->inputs(), splits);
}
// mods[0] becomes the submodule of the new split_fused_reduce instruction
auto* splitm = mpm.create_module(rm->name() + "_split", std::move(mods[0].mod));
splitm->set_bypass();
// Insert split reduce
auto split_reduce = mpm.get_module().insert_instruction(
ins,
make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}),
mods[0].inputs,
{splitm});
// One result instruction per split: the instruction itself for a single
// split, otherwise a get_tuple_elem for each index.
std::vector<instruction_ref> split_reduce_each;
if(splits.size() == 1)
{
split_reduce_each = {split_reduce};
}
else
{
transform(range(splits.size()), std::back_inserter(split_reduce_each), [&](auto i) {
return mpm.get_module().insert_instruction(
ins, make_op("get_tuple_elem", {{"index", i}}), split_reduce);
});
}
// Feed the split results into the remaining part and inline it in place of `ins`
mods[1].replace(splits, split_reduce_each);
auto replaced = insert_module_inline(mpm.get_module(), ins, mods[1]);
assert(replaced.size() == 1);
mpm.get_module().replace_instruction(ins, replaced.front());
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,173 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Pairs a parameter name with one of its non-fixed dynamic dimensions.
// Brace-initialized as {name, dimension}, so member order matters.
struct dynamic_dimensions_check
{
// Name of the dynamic-shaped parameter
std::string dyn_param_str;
// The non-fixed dynamic dimension found on that parameter
shape::dynamic_dimension dd;
};
/**
* Returns value if the parameters contain non-fixed dynamic_dimensions that are the same between
* all of the dynamic shape parameters.
* In other words, each parameter can have one non-fixed dynamic_dimension `x` where `x` is the same
* between all of the parameters with a non-fixed dynamic_dimension.
* Returns the parameters and the dynamic dimension in a vector of dynamic_dimensions_check objects.
*/
optional<std::vector<dynamic_dimensions_check>>
has_one_unique_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
    std::vector<dynamic_dimensions_check> checks{};
    // Gather the non-fixed dynamic_dimension of every dynamic parameter
    for(const auto& param : param_shapes)
    {
        if(not param.second.dynamic())
            continue;
        std::size_t non_fixed_count = 0;
        for(const auto& dd : param.second.dyn_dims())
        {
            if(dd.is_fixed())
                continue;
            checks.push_back(dynamic_dimensions_check{param.first, dd});
            ++non_fixed_count;
        }
        // A parameter with more than one non-fixed dimension is not supported
        if(non_fixed_count > 1)
            return std::nullopt;
    }
    // No dynamic parameter, or none with a non-fixed dimension
    if(checks.empty())
        return std::nullopt;
    // Every collected dimension has to match the first one
    const auto& reference_dd = checks.front().dd;
    for(auto it = std::next(checks.begin()); it != checks.end(); ++it)
    {
        if(not(it->dd == reference_dd))
            return std::nullopt;
    }
    return checks;
}
/**
* Check the parameters in std::vector<dynamic_dimensions_check> object to see if any of the
* parameters outputs to a select_module operator.
*/
bool any_sm_next(const_module_ref mm, const std::vector<dynamic_dimensions_check>& ddcs)
{
    // True as soon as one listed parameter feeds a select_module instruction
    return std::any_of(ddcs.cbegin(), ddcs.cend(), [&](const auto& ddc) {
        auto consumers = mm->get_parameter(ddc.dyn_param_str)->outputs();
        return std::any_of(consumers.cbegin(), consumers.cend(), [](auto ins) {
            return ins->name() == "select_module";
        });
    });
}
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions. Skips if the dynamic parameter outputs to a select_module operator.
*/
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
    module_ref mm     = &mpm.get_module();
    auto param_names  = mm->get_parameter_names();
    auto param_shapes = mm->get_parameter_shapes();
    optional<std::vector<dynamic_dimensions_check>> dd_check_vec =
        has_one_unique_dyn_dim(param_shapes);
    // Nothing to do without a single shared non-fixed dynamic dimension
    if(not dd_check_vec.has_value())
        return;
    // Skip when a dynamic parameter already feeds a select_module
    if(any_sm_next(mm, dd_check_vec.value()))
        return;

    const auto& checks = dd_check_vec.value();
    // all dynamic dimension objects are the same, so the first one is representative
    auto dyn_dim = checks.at(0).dd;

    // Build one static-shaped submodule per size in [min, max]
    std::vector<module_ref> submodules;
    for(size_t dim_size : migraphx::range(dyn_dim.min, dyn_dim.max + 1))
    {
        auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
        // Maps each dynamic parameter of the main module to its static counterpart
        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        for(const auto& dd_check : checks)
        {
            const auto& param_name = dd_check.dyn_param_str;
            const auto& dyn_param  = mm->get_parameter(param_name);
            auto dyn_param_shape   = mm->get_parameter_shape(param_name);
            auto static_shape      = dyn_param_shape.to_static(dim_size);
            map_ins[dyn_param]     = submod->add_parameter(param_name, static_shape);
        }
        auto outputs = submod->add_instructions(mm, &map_ins);
        submod->add_return({outputs});
        submodules.push_back(submod);
    }

    // sort parameters by name for consistency (vs. parameter order attr)
    std::sort(param_names.begin(), param_names.end());

    // All parameters, in sorted name order, become the select_module inputs
    std::vector<instruction_ref> sm_inputs;
    for(const auto& pn : param_names)
        sm_inputs.push_back(mm->get_parameter(pn));

    auto output_shapes       = mm->get_output_shapes();
    migraphx::shape out_attr = migraphx::shape{output_shapes};
    auto sm_ins              = mm->add_instruction(
        migraphx::make_op("select_module", {{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
        sm_inputs,
        submodules);

    // Unpack the select_module tuple and make it the new return
    std::vector<instruction_ref> outputs(output_shapes.size());
    for(size_t i = 0; i < output_shapes.size(); ++i)
    {
        auto elem_op  = migraphx::make_op("get_tuple_elem", {{"index", i}});
        outputs.at(i) = mm->add_instruction(elem_op, sm_ins);
    }
    mm->replace_return(outputs);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,112 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/sqlite.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/errors.hpp>
#include <sqlite3.h>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using sqlite3_ptr = MIGRAPHX_MANAGE_PTR(sqlite3*, sqlite3_close);
/// Owns a sqlite3 connection handle and wraps the few C calls used on it.
struct sqlite_impl
{
    sqlite3_ptr ptr;

    sqlite3* get() const { return ptr.get(); }

    /// Latest sqlite error text for this handle, prefixed with "sqlite3: ".
    std::string error_message() const
    {
        return std::string{"sqlite3: "} + sqlite3_errmsg(get());
    }

    /// Opens the database at `p` with the given sqlite open flags; throws on failure.
    void open(const fs::path& p, int flags)
    {
        sqlite3* raw_handle = nullptr;
        const int status    = sqlite3_open_v2(p.string().c_str(), &raw_handle, flags, nullptr);
        // Take ownership before checking the status so error_message() can use the handle
        ptr = sqlite3_ptr{raw_handle};
        if(status == 0)
            return;
        MIGRAPHX_THROW("error opening " + p.string() + ": " + error_message());
    }

    /// Runs `sql`, calling `f(column_count, texts, names)` once per result row.
    /// An exception thrown by `f` is turned into a non-zero callback result.
    template <class F>
    void exec(const char* sql, F f)
    {
        // cppcheck-suppress constParameterPointer
        auto trampoline = [](void* obj, auto... xs) -> int {
            try
            {
                const auto* fn = static_cast<const F*>(obj);
                (*fn)(xs...);
            }
            catch(...)
            {
                return -1;
            }
            return 0;
        };
        if(sqlite3_exec(get(), sql, trampoline, &f, nullptr) != 0)
            MIGRAPHX_THROW(error_message());
    }
};
// Opens the database at `p` read-only.
sqlite sqlite::read(const fs::path& p)
{
    auto handle = std::make_shared<sqlite_impl>();
    handle->open(p, SQLITE_OPEN_READONLY);
    sqlite db;
    db.impl = std::move(handle);
    return db;
}
// Opens the database at `p` for reading and writing, creating it if needed.
sqlite sqlite::write(const fs::path& p)
{
    // Using '+' instead of bitwise '|' to avoid compilation warning
    const int open_flags = SQLITE_OPEN_READWRITE + SQLITE_OPEN_CREATE;
    auto handle          = std::make_shared<sqlite_impl>();
    handle->open(p, open_flags);
    sqlite db;
    db.impl = std::move(handle);
    return db;
}
// Executes the SQL in `s` and returns one map per result row, keyed by
// column name. If a row repeats a column name, the first occurrence wins.
// A SQL NULL column is returned as an empty string.
// Throws if sqlite reports an error.
std::vector<std::unordered_map<std::string, std::string>> sqlite::execute(const std::string& s)
{
    std::vector<std::unordered_map<std::string, std::string>> result;
    impl->exec(s.c_str(), [&](int n, char** texts, char** names) {
        std::unordered_map<std::string, std::string> row;
        row.reserve(n);
        for(int i = 0; i < n; ++i)
        {
            // sqlite3_exec passes a null pointer for SQL NULL values; building a
            // std::string from a null pointer is undefined behaviour.
            const char* text = texts[i];
            row.emplace(names[i], text == nullptr ? "" : text);
        }
        // Move instead of copying the freshly built row
        result.push_back(std::move(row));
    });
    return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,37 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Serializes a target by storing its name under the "name" key.
void migraphx_to_value(value& v, const target& t)
{
    v["name"] = t.name();
}
// Rebuilds a target from the name stored under the "name" key.
void migraphx_from_value(const value& v, target& t)
{
    const auto target_name = v.at("name").to<std::string>();
    t                      = make_target(target_name);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,108 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/tmp_dir.hpp>
#include <migraphx/env.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <random>
#include <thread>
#include <sstream>
#include <iostream>
#include <string>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#undef getpid
// cppcheck-suppress [definePrefix, defineUpperCase]
#define getpid _getpid
#else
#include <unistd.h>
#include <sys/types.h>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_SAVE_TEMP_DIR)
// Returns `length` characters drawn uniformly from [0-9a-zA-Z], using a
// generator freshly seeded from std::random_device on every call.
std::string random_string(std::string::size_type length)
{
    static const std::string alphabet = "0123456789"
                                        "abcdefghijklmnopqrstuvwxyz"
                                        "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    std::mt19937 engine{std::random_device{}()};
    std::uniform_int_distribution<std::string::size_type> index(0, alphabet.length() - 1);
    std::string out;
    out.reserve(length);
    for(std::string::size_type i = 0; i < length; ++i)
        out.push_back(alphabet[index(engine)]);
    return out;
}
// Builds "<prefix>-<pid>-<thread id>-<steady clock ticks>-<16 random chars>",
// with std::hex set on the stream for the whole output.
std::string unique_string(const std::string& prefix)
{
    const auto process_id = getpid();
    const auto thread_id  = std::this_thread::get_id();
    const auto ticks      = std::chrono::steady_clock::now().time_since_epoch().count();
    std::stringstream out;
    out << std::hex;
    out << prefix << "-" << process_id;
    out << "-" << thread_id;
    out << "-" << ticks;
    out << "-" << random_string(16);
    return out.str();
}
// Creates a uniquely named directory under the system temp directory. The
// name starts with "migraphx" or, for a non-empty prefix, "migraphx-<prefix>".
tmp_dir::tmp_dir(std::string_view prefix)
    : path(fs::temp_directory_path() / unique_string([&]() -> std::string {
          if(prefix.empty())
              return "migraphx";
          return "migraphx-" + std::string{prefix};
      }()))
{
    fs::create_directories(path);
}
// Runs `cmd` with `args`, using this temporary directory as the working directory.
void tmp_dir::execute(std::string_view cmd, const std::vector<std::string>& args) const
{
    process{cmd, args}
        .cwd(this->path)
        .exec();
}
// Removes the directory unless MIGRAPHX_DEBUG_SAVE_TEMP_DIR is enabled.
// Removal is tried up to five times, logging each failure to stderr and
// pausing 125ms before the next attempt.
tmp_dir::~tmp_dir()
{
    if(enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{}))
        return;
    constexpr int max_retries_count = 5;
    for(int attempt = 0; attempt < max_retries_count; ++attempt)
    {
        std::error_code ec;
        fs::remove_all(path, ec);
        if(not ec)
            return;
        std::cerr << "Failed to remove " << path << ": " << ec.message() << std::endl;
        std::this_thread::sleep_for(std::chrono::milliseconds(125));
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,103 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/truncate_float.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Re-creates each selected instruction of `m` with its float/double inputs
// converted to `float_type`, then converts the result back to the original
// output type. An instruction is selected when its name is listed in
// `ins_names` or when `ins_names` contains "all".
static void
quantize_module(module& m, const std::vector<std::string>& ins_names, shape::type_t float_type)
{
for(auto ins : iterator_for(m))
{
// instructions are not in the set to be quantized
if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
continue;
// skip return and convert instructions
if(contains({"@return", "convert"}, ins->name()))
continue;
if(ins->inputs().empty())
continue;
auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape();
// Convert each of the inputs that are floating point to float type
// (only float_type and double_type inputs are converted; others pass through)
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
auto input_type = input->get_shape().type();
if(input_type != shape::float_type and input_type != shape::double_type)
return input;
return m.insert_instruction(
ins, make_op("convert", {{"target_type", float_type}}), input);
});
// Insert quantized ins
auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
// tuple can't be directly converted, get_tuple_elem needs conversion
if(ins->get_shape().type() == shape::tuple_type)
{
// NOTE(review): assumes every output of a tuple-typed instruction is a
// single-input get_tuple_elem; each output's operator is re-applied to
// `converted_ins` alone - confirm.
auto outputs = ins->outputs();
std::transform(
outputs.begin(), outputs.end(), outputs.begin(), [&](const auto gte_ins) {
auto gte_ins_float_type =
m.insert_instruction(ins, gte_ins->get_operator(), converted_ins);
// Convert back to output type after quantizing
auto gte_converted = m.insert_instruction(
ins,
make_op("convert", {{"target_type", gte_ins->get_shape().type()}}),
gte_ins_float_type);
// Replace output instruction
return m.replace_instruction(gte_ins, gte_converted);
});
}
else
{
// Convert back to original type after quantizing
// (skipped when the instruction carries module inputs)
if(mod_inputs.empty())
{
converted_ins = m.insert_instruction(
ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
}
// Replace original instruction
m.replace_instruction(ins, converted_ins);
}
}
}
// Runs the float truncation over `m` using this pass's instruction names and target type.
void truncate_float_pass::apply(module& m) const
{
    quantize_module(m, ins_names, float_type);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,573 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cassert>
#include <iostream>
#include <migraphx/cloneable.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/hash.hpp>
#include <unordered_map>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Polymorphic base for the storage behind a `value`. Every accessor defaults
// to "not this kind" (null_type / nullptr); concrete holders override the
// ones that apply to them.
struct value_base_impl : cloneable<value_base_impl>
{
// Type tag of the stored data; the base reports null_type
virtual value::type_t get_type() { return value::null_type; }
// Generates one `if_<vt>()` accessor per scalar value type, each returning nullptr
#define MIGRAPHX_VALUE_GENERATE_BASE_FUNCTIONS(vt, cpp_type) \
virtual const cpp_type* if_##vt() const { return nullptr; }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_BASE_FUNCTIONS)
// Element storage for array-like holders, nullptr otherwise
virtual std::vector<value>* if_array() { return nullptr; }
// Key -> element index lookup for object holders, nullptr otherwise
virtual std::unordered_map<std::string, std::size_t>* if_object() { return nullptr; }
virtual value_base_impl* if_value() const { return nullptr; }
value_base_impl() = default;
value_base_impl(const value_base_impl&) = default;
value_base_impl& operator=(const value_base_impl&) = default;
virtual ~value_base_impl() override {}
};
// Generates, for each scalar value type `vt` with C++ type `cpp_type`, a
// `<vt>_value_holder` that stores one `cpp_type`, reports `value::<vt>_type`
// and exposes the stored data through `if_<vt>()`.
#define MIGRAPHX_VALUE_GENERATE_BASE_TYPE(vt, cpp_type) \
struct vt##_value_holder : value_base_impl::share \
{ \
vt##_value_holder(cpp_type d) : data(std::move(d)) {} \
virtual value::type_t get_type() override { return value::vt##_type; } \
virtual const cpp_type* if_##vt() const override { return &data; } \
cpp_type data; \
};
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_BASE_TYPE)
// Holder for an array of values.
struct array_value_holder : value_base_impl::derive<array_value_holder>
{
    std::vector<value> data;

    array_value_holder() = default;
    array_value_holder(std::vector<value> d) : data(std::move(d)) {}

    value::type_t get_type() override { return value::array_type; }
    std::vector<value>* if_array() override { return &data; }
};
// Holder for an object: the values are stored in `data`, and `lookup`
// maps each key to the index of its value in `data`.
struct object_value_holder : value_base_impl::derive<object_value_holder>
{
    std::vector<value> data;
    std::unordered_map<std::string, std::size_t> lookup;

    object_value_holder() = default;
    object_value_holder(std::vector<value> d, std::unordered_map<std::string, std::size_t> l)
        : data(std::move(d)), lookup(std::move(l))
    {
    }

    value::type_t get_type() override { return value::object_type; }
    // An object also exposes its elements through the array accessor
    std::vector<value>* if_array() override { return &data; }
    std::unordered_map<std::string, std::size_t>* if_object() override { return &lookup; }
};
// Deep copy: clones the stored data (if any) and copies the key.
value::value(const value& rhs)
    : x(rhs.x ? rhs.x->clone() : nullptr),
      key(rhs.key)
{
}
// Takes over the data of `rhs`. The key is only replaced when `rhs` has a
// non-empty key, so assigning a keyless value keeps the existing key.
value& value::operator=(value rhs)
{
    std::swap(x, rhs.x);
    if(rhs.key.empty())
        return *this;
    std::swap(key, rhs.key);
    return *this;
}
void set_vector(std::shared_ptr<value_base_impl>& x,
const std::vector<value>& v,
bool array_on_empty = true)
{
if(v.empty())
{
if(array_on_empty)
x = std::make_shared<array_value_holder>();
else
x = std::make_shared<object_value_holder>();
return;
}
if(v.front().get_key().empty())
{
x = std::make_shared<array_value_holder>(v);
}
else
{
std::unordered_map<std::string, std::size_t> lookup;
std::size_t i = 0;
for(auto&& e : v)
{
lookup[e.get_key()] = i;
i++;
}
x = std::make_shared<object_value_holder>(v, lookup);
}
}
// A two-element list whose first element is a keyless string is read as a
// {key, value} pair; any other list becomes an array or object.
value::value(const std::initializer_list<value>& i) : x(nullptr)
{
    const bool is_key_value_pair =
        i.size() == 2 and i.begin()->is_string() and i.begin()->get_key().empty();
    if(not is_key_value_pair)
    {
        set_vector(x, std::vector<value>(i.begin(), i.end()));
        return;
    }
    const auto* key_item = i.begin();
    key                  = key_item->get_string();
    auto held            = (key_item + 1)->x;
    x                    = held ? held->clone() : nullptr;
}
// Builds an array or object from `v` (see set_vector for the rules).
value::value(const std::vector<value>& v, bool array_on_empty)
    : x(nullptr)
{
    set_vector(x, v, array_on_empty);
}

// Builds an object from a map; an empty map gives an empty object.
value::value(const std::unordered_map<std::string, value>& m)
    : value(std::vector<value>(m.begin(), m.end()), false)
{
}
// Keyed array or object built from `v` (see set_vector for the rules).
value::value(const std::string& pkey, const std::vector<value>& v, bool array_on_empty)
    : x(nullptr),
      key(pkey)
{
    set_vector(x, v, array_on_empty);
}

// Keyed object built from a map; an empty map gives an empty object.
value::value(const std::string& pkey, const std::unordered_map<std::string, value>& m)
    : value(pkey, std::vector<value>(m.begin(), m.end()), false)
{
}
// Keyed null value.
value::value(const std::string& pkey, std::nullptr_t)
    : x(nullptr),
      key(pkey)
{
}

// Null value.
value::value(std::nullptr_t)
    : x(nullptr)
{
}

// Copy of the data of `rhs` under the key `pkey`.
value::value(const std::string& pkey, const value& rhs)
    : x(rhs.x ? rhs.x->clone() : nullptr),
      key(pkey)
{
}

// C strings are stored as std::string.
value::value(const std::string& pkey, const char* i)
    : value(pkey, std::string(i))
{
}

value::value(const char* i)
    : value(std::string(i))
{
}
// Generates, for each scalar value type `vt` with C++ type `cpp_type`:
// the keyless and keyed constructors, assignment from `cpp_type` (which
// leaves the key untouched), `is_<vt>()`, `get_<vt>()` (asserts that the
// value holds that type) and `if_<vt>()` (nullptr when it does not).
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
value::value(cpp_type i) : x(std::make_shared<vt##_value_holder>(std::move(i))) {} \
value::value(const std::string& pkey, cpp_type i) \
: x(std::make_shared<vt##_value_holder>(std::move(i))), key(pkey) \
{ \
} \
value& value::operator=(cpp_type rhs) \
{ \
x = std::make_shared<vt##_value_holder>(std::move(rhs)); \
return *this; \
} \
bool value::is_##vt() const { return x ? x->get_type() == vt##_type : false; } \
const cpp_type& value::get_##vt() const \
{ \
auto* r = this->if_##vt(); \
assert(r); \
return *r; \
} \
const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
// Assigns a C string by first converting it to std::string.
value& value::operator=(const char* c)
{
    return *this = std::string{c};
}
// Resets the value to null; the key is kept.
value& value::operator=(std::nullptr_t)
{
    x.reset();
    return *this;
}
// Replaces the data with the value built from the list; the key is kept.
value& value::operator=(const std::initializer_list<value>& i)
{
    value built = i;
    std::swap(x, built.x);
    return *this;
}
// True only when the stored type is array_type (objects do not count).
bool value::is_array() const
{
    if(not x)
        return false;
    return x->get_type() == array_type;
}

// Element storage; the value must hold an array or object.
const std::vector<value>& value::get_array() const
{
    const auto* elements = this->if_array();
    assert(elements);
    return *elements;
}

// Element storage, or nullptr when the value holds neither array nor object.
const std::vector<value>* value::if_array() const
{
    if(not x)
        return nullptr;
    return x->if_array();
}
// True only when the stored type is object_type.
bool value::is_object() const
{
    if(not x)
        return false;
    return x->get_type() == object_type;
}

// Element storage of an object; asserts that storage is available.
const std::vector<value>& value::get_object() const
{
    const auto* elements = this->if_object();
    assert(elements);
    return *elements;
}

// Returns the element storage obtained through the array accessor, asserting
// that every element carries a key.
// NOTE(review): this does not check that the holder is actually an object, so
// an empty array also yields a non-null result - confirm this is intended.
const std::vector<value>* value::if_object() const
{
    auto* elements = x ? x->if_array() : nullptr;
    assert(elements == nullptr or
           std::none_of(elements->begin(), elements->end(), [](auto&& item) {
               return item.get_key().empty();
           }));
    return elements;
}
// A value is null when it holds no data.
bool value::is_null() const
{
    return not x;
}

// Key of this value; empty when it has none.
const std::string& value::get_key() const
{
    return key;
}
// Element storage of `x`, or nullptr when `x` is empty or not array-like.
std::vector<value>* if_array_impl(const std::shared_ptr<value_base_impl>& x)
{
    return x ? x->if_array() : nullptr;
}
// Element storage of `x`; only checked by an assert, so the caller must
// already know that `x` is array-like.
std::vector<value>& get_array_impl(const std::shared_ptr<value_base_impl>& x)
{
    auto* elements = if_array_impl(x);
    assert(elements);
    return *elements;
}
// Element storage of `x`; throws when `x` is empty or not array-like.
std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
{
    auto* elements = if_array_impl(x);
    if(not elements)
        MIGRAPHX_THROW("Expected an array or object");
    return *elements;
}
// Looks up `key` in the object held by `x`. Returns a pointer to the matching
// element, or `end` when `x` is not an object or has no such key.
template <class T>
T* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key, T* end)
{
    auto* elements = if_array_impl(x);
    if(elements == nullptr)
        return end;
    auto* index = x->if_object();
    if(index == nullptr)
        return end;
    auto found = index->find(key);
    if(found != index->end())
        return std::addressof((*elements)[found->second]);
    return end;
}
// Pointer to the element stored under `pkey`, or end() when there is none.
value* value::find(const std::string& pkey)
{
    return find_impl(x, pkey, this->end());
}

const value* value::find(const std::string& pkey) const
{
    return find_impl(x, pkey, this->end());
}
// True when an element is stored under `pkey`.
bool value::contains(const std::string& pkey) const
{
    const auto* found = find(pkey);
    return found != nullptr and found != end();
}
// Number of elements; 0 for anything that is not an array or object.
std::size_t value::size() const
{
    const auto* elements = if_array_impl(x);
    return elements == nullptr ? 0 : elements->size();
}

bool value::empty() const
{
    return size() == 0;
}
// Pointer to the first element, or nullptr when not an array or object.
const value* value::data() const
{
    auto* elements = if_array_impl(x);
    if(not elements)
        return nullptr;
    return elements->data();
}

value* value::data()
{
    auto* elements = if_array_impl(x);
    if(not elements)
        return nullptr;
    return elements->data();
}
// Iteration is done with raw pointers into the element storage.
value* value::begin()
{
    // cppcheck-suppress assertWithSideEffect
    assert(data() or empty());
    return data();
}

const value* value::begin() const
{
    assert(data() or empty());
    return data();
}

// One past the last element.
value* value::end()
{
    return begin() + size();
}

const value* value::end() const
{
    return begin() + size();
}
// First element; the value must not be empty.
value& value::front()
{
    assert(this->size() > 0);
    return *begin();
}

const value& value::front() const
{
    assert(this->size() > 0);
    return *begin();
}

// Last element; the value must not be empty.
value& value::back()
{
    assert(this->size() > 0);
    return *(end() - 1);
}

const value& value::back() const
{
    assert(this->size() > 0);
    return *(end() - 1);
}
// Bounds-checked element access; throws when the value is not array-like.
value& value::at(std::size_t i)
{
    auto* elements = if_array_impl(x);
    if(not elements)
        MIGRAPHX_THROW("Not an array");
    return elements->at(i);
}

const value& value::at(std::size_t i) const
{
    auto* elements = if_array_impl(x);
    if(not elements)
        MIGRAPHX_THROW("Not an array");
    return elements->at(i);
}
// Element stored under `pkey`; throws when the lookup yields nothing.
value& value::at(const std::string& pkey)
{
    auto* found = find(pkey);
    // find() yields nullptr when there is no element storage at all
    if(not found)
        MIGRAPHX_THROW("Not an object");
    if(found == end())
        MIGRAPHX_THROW("Key not found: " + pkey);
    return *found;
}

const value& value::at(const std::string& pkey) const
{
    const auto* found = find(pkey);
    // find() yields nullptr when there is no element storage at all
    if(not found)
        MIGRAPHX_THROW("Not an object for field: " + pkey);
    if(found == end())
        MIGRAPHX_THROW("Key not found: " + pkey);
    return *found;
}
// Unchecked element access (asserted only).
value& value::operator[](std::size_t i)
{
    assert(i < this->size());
    return begin()[i];
}

const value& value::operator[](std::size_t i) const
{
    assert(i < this->size());
    return begin()[i];
}

// Element stored under `pkey`, inserting a null element first when missing.
value& value::operator[](const std::string& pkey)
{
    auto inserted = emplace(pkey, nullptr);
    return *inserted.first;
}
// Removes all elements; throws when the value is not an array or object.
void value::clear()
{
    auto& elements = get_array_throw(x);
    elements.clear();
}

// Resizes an array; throws for anything that is not an array.
void value::resize(std::size_t n)
{
    if(is_array())
    {
        get_array_impl(x).resize(n);
        return;
    }
    MIGRAPHX_THROW("Expected an array.");
}

// Resizes an array, filling new slots with copies of `v`.
void value::resize(std::size_t n, const value& v)
{
    if(is_array())
    {
        get_array_impl(x).resize(n, v);
        return;
    }
    MIGRAPHX_THROW("Expected an array.");
}
// Adds v to this value.
//
// Keyed v:   object storage is created if this value has none yet. The key is
//            emplaced together with the current element count into the lookup
//            returned by x->if_object(); v is appended only when that emplace
//            reports a new key.
// Unkeyed v: array storage is created if this value has none yet, and v is
//            appended.
//
// Returns a pointer to the stored element (the already-present one when the
// key existed) and a flag telling whether an element was appended.
std::pair<value*, bool> value::insert(const value& v)
{
    if(not v.key.empty())
    {
        if(not x)
            x = std::make_shared<object_value_holder>();
        auto slot        = x->if_object()->emplace(v.key, get_array_impl(x).size());
        const bool added = slot.second;
        if(added)
            get_array_impl(x).push_back(v);
        assert(this->if_object());
        // slot.first->second is the element index recorded under the key.
        return std::make_pair(&get_array_impl(x)[slot.first->second], added);
    }
    if(not x)
        x = std::make_shared<array_value_holder>();
    get_array_impl(x).push_back(v);
    assert(this->if_array());
    return std::make_pair(&back(), true);
}
// Inserts the unkeyed value v in front of the element pos points at.
// Array storage is created first if this value has none yet.
//
// pos: position inside this value's element range, measured from begin().
// Returns a pointer to the newly inserted element.
value* value::insert(const value* pos, const value& v)
{
    assert(v.key.empty());
    if(not x)
        x = std::make_shared<array_value_holder>();
    auto&& elems      = get_array_impl(x);
    const auto offset = pos - begin();
    auto where        = elems.insert(elems.begin() + offset, v);
    return std::addressof(*where);
}
// Returns a copy of this value whose key is the empty string.
value value::without_key() const
{
    // Parentheses on purpose: plain copy construction.
    value stripped(*this);
    stripped.key = "";
    return stripped;
}
// Returns a copy of this value whose key is replaced by pkey.
value value::with_key(const std::string& pkey) const
{
    // Parentheses on purpose: plain copy construction.
    value renamed(*this);
    renamed.key = pkey;
    return renamed;
}
// Normalizes a payload before comparison: any ordinary payload is passed
// through by reference unchanged.
template <class T>
auto compare_decay(const T& x) -> const T&
{
    return x;
}
// A nullptr payload is mapped to the integer 0.
int compare_decay(std::nullptr_t)
{
    return 0;
}
// Applies the binary predicate f to the (key, payload) pairs of x and y.
//
// Both payloads are obtained through visit_value. When they have the same
// type, f receives two tuples of references built from get_key() and the
// compare_decay'd payload, and its result is returned. A type mismatch is a
// debug-assert failure; in that case the result stays false.
template <class F>
bool compare(const value& x, const value& y, F f)
{
    bool outcome = false;
    x.visit_value([&](auto&& lhs) {
        y.visit_value([&](auto&& rhs) {
            if constexpr(not std::is_same_v<decltype(lhs), decltype(rhs)>)
            {
                assert(false); // NOLINT
            }
            else
            {
                // The tuples hold references, so they are built and consumed
                // within this single expression.
                outcome = f(std::forward_as_tuple(x.get_key(), compare_decay(lhs)),
                            std::forward_as_tuple(y.get_key(), compare_decay(rhs)));
            }
        });
    });
    return outcome;
}
// Type tag of this value: null_type when no storage is attached, otherwise
// whatever the attached storage reports.
value::type_t value::get_type() const
{
    return x ? x->get_type() : null_type;
}
// Two values are equal when their type tags match and compare() with
// std::equal_to reports equality of their (key, payload) pairs.
bool operator==(const value& x, const value& y)
{
    return x.get_type() == y.get_type() and compare(x, y, std::equal_to<>{});
}
// Negation of operator==.
bool operator!=(const value& x, const value& y)
{
    const bool same = (x == y);
    return not same;
}
// Ordering: values of different type are ordered by their type tags; values
// of the same type are ordered by compare() with std::less on their
// (key, payload) pairs.
bool operator<(const value& x, const value& y)
{
    const auto lhs_type = x.get_type();
    const auto rhs_type = y.get_type();
    if(lhs_type == rhs_type)
        return compare(x, y, std::less<>{});
    return lhs_type < rhs_type;
}
// The remaining relational operators are all expressed through operator<.

// x <= y  <=>  not (y < x)
bool operator<=(const value& x, const value& y)
{
    return not(y < x);
}
// x > y  <=>  y < x
bool operator>(const value& x, const value& y)
{
    return y < x;
}
// x >= y  <=>  not (x < y)
bool operator>=(const value& x, const value& y)
{
    return not(x < y);
}
// A null payload is printed as the literal text "null".
void print_value(std::ostream& os, std::nullptr_t)
{
    os << "null";
}
// Fallback: any other payload is streamed with its own operator<<.
template <class T>
void print_value(std::ostream& os, const T& x)
{
    os << x;
}
// A pair is printed as "<first>: <second>", where the second member goes
// back through print_value so that it gets its own formatting.
template <class T, class U>
void print_value(std::ostream& os, const std::pair<T, U>& x)
{
    os << x.first << ": ";
    print_value(os, x.second);
}
// A sequence of values is printed as the to_string_range rendering of its
// elements wrapped in curly braces.
void print_value(std::ostream& os, const std::vector<value>& x)
{
    os << "{" << to_string_range(x) << "}";
}
// Binary payloads are streamed with their own operator<<.
void print_value(std::ostream& os, const value::binary& x)
{
    os << x;
}
// Streams a value by handing whatever visit() exposes to the matching
// print_value overload. Returns os to allow chaining.
std::ostream& operator<<(std::ostream& os, const value& d)
{
    auto emit = [&](auto&& item) { print_value(os, item); };
    d.visit(emit);
    return os;
}
// Hash of a keyed payload: starts from the hash of the key and mixes in the
// payload with hash_combine.
template <class T>
std::size_t value_hash(const std::string& key, const T& x)
{
    std::size_t seed = hash_value(key);
    hash_combine(seed, x);
    return seed;
}
// A null payload contributes nothing; only the key is hashed.
std::size_t value_hash(const std::string& key, std::nullptr_t)
{
    return hash_value(key);
}
// A sequence of values: every element is mixed in, in order.
std::size_t value_hash(const std::string& key, const std::vector<value>& x)
{
    std::size_t seed = hash_value(key);
    for(const auto& element : x)
        hash_combine(seed, element);
    return seed;
}
// A binary payload: every element is mixed in, in order.
std::size_t value_hash(const std::string& key, const value::binary& x)
{
    std::size_t seed = hash_value(key);
    for(const auto& unit : x)
        hash_combine(seed, unit);
    return seed;
}
// Hash of this value: the payload exposed by visit_value is combined with the
// key through the matching value_hash overload. Stays 0 if the visitor is
// never invoked.
std::size_t value::hash() const
{
    std::size_t digest = 0;
    this->visit_value(
        [&](const auto& payload) { digest = value_hash(this->get_key(), payload); });
    return digest;
}
// Prints this value to std::cout followed by a newline (with flush).
//
// show_type: when true, the output is prefixed with the name of the value's
//            type tag and ": " (for example "array: ").
void value::debug_print(bool show_type) const
{
if(show_type)
{
switch(get_type())
{
// Generates one `case <vt>_type:` label that prints the stringified tag name.
// The second macro argument is accepted but not used by the expansion.
#define MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(vt, cpp_type) \
case vt##_type: std::cout << #vt << ": "; break;
// Cases for every type listed by MIGRAPHX_VISIT_VALUE_TYPES ...
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE)
// ... plus the null, array and object tags, which are added by hand here.
MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(null, )
MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(array, )
MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(object, )
}
}
// The value itself is rendered through operator<<.
std::cout << *this << std::endl;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,110 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/verify_args.hpp>
// Environment switch consulted by verify_args: when enabled, the ref and
// target buffers are printed on a failed comparison even if they hold 32 or
// more elements.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_DUMP_DIFF);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Compares a target result against a reference result and writes a
// diagnostic report to std::cout.
//
// name:       label printed after "FAILED: " when the comparison fails.
// target_arg: result under test.
// ref_arg:    expected (reference) result.
// tols:       tolerances forwarded to verify::verify_range_with_tolerance.
//
// Returns the outcome of verify::verify_range_with_tolerance, or true when
// the visitor is never invoked.
//
// On failure the report contains the RMS error, the buffers themselves (only
// when they have fewer than 32 elements or MIGRAPHX_VERIFY_DUMP_DIFF is
// enabled), all-zero notices, the maximum difference, the first mismatching
// element and the first non-finite element of each buffer. On success only
// the all-zero and non-finite notices are printed.
bool verify_args(const std::string& name,
                 const argument& target_arg,
                 const verify::expected<argument>& ref_arg,
                 verify::tolerance tols)
{
    bool passed = true;
    visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
        // Notes when either buffer consists solely of zeros.
        auto report_all_zeros = [&] {
            if(verify::range_zero(ref))
                std::cout << "Ref data is all zeros" << std::endl;
            if(verify::range_zero(target))
                std::cout << "Target data is all zeros" << std::endl;
        };
        // Notes the first non-finite element of `range`, if there is one.
        auto report_non_finite = [](const char* label, auto& range) {
            auto idx = find_idx(range, verify::not_finite);
            if(idx >= 0)
                std::cout << "Non finite number found in " << label << " at " << idx << ": "
                          << range[idx] << std::endl;
        };

        // Initialized so the report never prints an indeterminate value should
        // the callee leave it untouched.
        double rms_error = 0.0;
        passed =
            verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
        if(not passed)
        {
            std::cout << "FAILED: " << name << std::endl;
            std::cout << "RMS Error: " << rms_error << std::endl;
            // Large buffers are only dumped on request.
            if(ref.size() < 32 or enabled(MIGRAPHX_VERIFY_DUMP_DIFF{}))
                std::cout << "ref:" << ref << std::endl;
            if(target.size() < 32 or enabled(MIGRAPHX_VERIFY_DUMP_DIFF{}))
                std::cout << "target:" << target << std::endl;
            report_all_zeros();
            auto mxdiff = verify::max_diff(ref, target);
            std::cout << "Max diff: " << mxdiff << std::endl;
            auto idx = verify::mismatch_idx(ref, target, float_equal);
            if(idx < verify::range_distance(ref))
            {
                std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
                          << std::endl;
            }
            report_non_finite("ref", ref);
            report_non_finite("target", target);
            std::cout << std::endl;
        }
        else
        {
            report_all_zeros();
            report_non_finite("ref", ref);
            report_non_finite("target", target);
        }
    });
    return passed;
}
// Convenience wrapper around verify_args: derives the RMS tolerance from the
// target's contents via verify::get_rms_tol(target, tolerance) and runs the
// comparison with it.
//
// name, target_arg, ref_arg: forwarded unchanged to verify_args.
// tolerance: multiplier-style argument handed to verify::get_rms_tol.
//
// If the visitor is never invoked, the RMS tolerance stays at 0.001.
bool verify_args_with_tolerance(const std::string& name,
                                const argument& target_arg,
                                const verify::expected<argument>& ref_arg,
                                std::size_t tolerance)
{
    double rms_tol = 0.001;
    auto pick_tol  = [&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); };
    target_arg.visit(pick_tol);
    return verify_args(name, target_arg, ref_arg, verify::tolerance{rms_tol});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

View File

@ -0,0 +1,33 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
// clang-format off
// Individual version components. The at-sign placeholders are presumably
// replaced by the build system when the real header is generated from this
// template - confirm against the build scripts.
#define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
#define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@
#define MIGRAPHX_VERSION_PATCH @PROJECT_VERSION_PATCH@
#define MIGRAPHX_VERSION_TWEAK "@PROJECT_VERSION_TWEAK@"
// Single integer encoding major * 1000000 + minor * 1000 + patch.
// The whole sum is parenthesized so the macro behaves as one operand when it
// appears inside a larger expression (for example next to * or /).
#define MIGRAPHX_SO_MAJOR_VERSION \
    (@PROJECT_VERSION_MAJOR@ * 1000 * 1000 + \
     @PROJECT_VERSION_MINOR@ * 1000 + \
     @PROJECT_VERSION_PATCH@)
// clang-format on