mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-19 01:17:06 +03:00
Add files via upload
This commit is contained in:
parent
8e5fe2703a
commit
58e9831aef
70
docker/rocm/migraphx/adjust_allocation.cpp
Normal file
70
docker/rocm/migraphx/adjust_allocation.cpp
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/adjust_allocation.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void adjust_allocation::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
// skip instruction with no input
|
||||
if(ins->inputs().empty())
|
||||
continue;
|
||||
|
||||
// Skip target-independent operators
|
||||
if(ins->get_operator().is_context_free())
|
||||
continue;
|
||||
|
||||
auto alias_ins = instruction::get_output_alias(ins, true);
|
||||
if(alias_ins->name() != model.name() and alias_ins->name() != "@param")
|
||||
continue;
|
||||
// shape allocated is different from actual shape
|
||||
// of the instruction, reallocate and replace the previous one
|
||||
if(alias_ins->get_shape() == ins->get_shape())
|
||||
continue;
|
||||
auto alloc_ins = m.insert_instruction(ins, model.allocate(ins->get_shape()));
|
||||
m.replace_instruction(alias_ins, alloc_ins);
|
||||
// If the memory is an output parameter then copy the memory to the parameter
|
||||
if(alias_ins->name() == "@param")
|
||||
{
|
||||
auto copy = m.insert_instruction(std::next(ins), make_op(model.copy()), ins, alias_ins);
|
||||
auto tail = range(std::next(copy), m.end());
|
||||
for(auto i : iterator_for(tail))
|
||||
{
|
||||
if(contains(i->inputs(), ins))
|
||||
instruction::replace_argument(i, ins, copy);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
114
docker/rocm/migraphx/analyze_streams.cpp
Normal file
114
docker/rocm/migraphx/analyze_streams.cpp
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/analyze_streams.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
bool happens_before(const std::vector<std::size_t>& e1, const std::vector<std::size_t>& e2)
|
||||
{
|
||||
return std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::less_equal<>{}) and
|
||||
not std::equal(e1.begin(), e1.end(), e2.begin(), e2.end(), std::greater_equal<>{});
|
||||
}
|
||||
|
||||
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm)
|
||||
{
|
||||
using vector_clock = std::vector<std::size_t>;
|
||||
std::vector<stream_race> races;
|
||||
auto nstream = strmm.get_nstream();
|
||||
std::vector<vector_clock> vclock(nstream, vector_clock(nstream));
|
||||
std::unordered_map<instruction_ref, vector_clock> timestamp;
|
||||
std::unordered_map<std::size_t, vector_clock> events;
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(not strmm.has_stream(ins))
|
||||
continue;
|
||||
std::size_t s = strmm.get_stream(ins);
|
||||
assert(s < nstream);
|
||||
assert(vclock.size() == nstream);
|
||||
assert(vclock[s].size() == nstream);
|
||||
if(strmm.is_record(ins))
|
||||
{
|
||||
vclock[s][s]++;
|
||||
auto event = strmm.get_event_id(ins);
|
||||
events[event] = vclock[s];
|
||||
}
|
||||
else if(strmm.is_wait(ins))
|
||||
{
|
||||
auto event = strmm.get_event_id(ins);
|
||||
if(not contains(events, event))
|
||||
MIGRAPHX_THROW("Event is waited on before being recorded: " +
|
||||
std::to_string(event));
|
||||
auto payload = events.at(event);
|
||||
assert(vclock[s].size() == payload.size());
|
||||
std::transform(vclock[s].begin(),
|
||||
vclock[s].end(),
|
||||
payload.begin(),
|
||||
vclock[s].begin(),
|
||||
[&](auto x, auto y) { return std::max(x, y); });
|
||||
vclock[s][s]++;
|
||||
}
|
||||
else
|
||||
{
|
||||
vclock[s][s]++;
|
||||
}
|
||||
timestamp[ins] = vclock[s];
|
||||
}
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(not strmm.has_stream(ins))
|
||||
continue;
|
||||
if(ins->inputs().empty())
|
||||
continue;
|
||||
std::size_t s = strmm.get_stream(ins);
|
||||
// Find inputs from different streams
|
||||
std::vector<instruction_ref> inputs;
|
||||
fix([&](auto self, auto start) {
|
||||
for(auto input : start->inputs())
|
||||
{
|
||||
if(not strmm.has_stream(input))
|
||||
self(input);
|
||||
else if(strmm.get_stream(input) != s)
|
||||
inputs.push_back(input);
|
||||
}
|
||||
})(ins);
|
||||
auto it = std::find_if(inputs.begin(), inputs.end(), [&](auto input) {
|
||||
return not happens_before(timestamp.at(input), timestamp.at(ins));
|
||||
});
|
||||
if(it != inputs.end())
|
||||
{
|
||||
races.push_back({ins, *it});
|
||||
}
|
||||
}
|
||||
|
||||
return races;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
79
docker/rocm/migraphx/apply_alpha_beta.cpp
Normal file
79
docker/rocm/migraphx/apply_alpha_beta.cpp
Normal file
@ -0,0 +1,79 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/common.hpp>
|
||||
#include <migraphx/apply_alpha_beta.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
instruction_ref insert_apply_alpha_beta(module& m,
|
||||
instruction_ref pos,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const operation& op,
|
||||
const literal& alpha,
|
||||
const literal& beta)
|
||||
{
|
||||
auto a = args[0];
|
||||
auto b = args[1];
|
||||
auto input_type = a->get_shape().type();
|
||||
if(not float_equal(alpha.at<float>(0), 1.0))
|
||||
{
|
||||
auto alpha_literal = m.add_literal(alpha);
|
||||
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
|
||||
if(a->get_shape().type() != input_type)
|
||||
{
|
||||
a = m.insert_instruction(pos, make_op("convert", {{"target_type", input_type}}), a);
|
||||
}
|
||||
}
|
||||
auto op_res = m.insert_instruction(pos, op, a, b);
|
||||
if(args.size() == 3)
|
||||
{
|
||||
if(not float_equal(beta.at<float>(0), 0.0) and args[2]->get_shape().elements() > 0)
|
||||
{
|
||||
auto out_lens = op_res->get_shape().lens();
|
||||
auto c = args[2];
|
||||
auto c_lens = c->get_shape().lens();
|
||||
input_type = c->get_shape().type();
|
||||
if(out_lens != c_lens)
|
||||
{
|
||||
c = m.insert_instruction(
|
||||
pos, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
|
||||
}
|
||||
auto beta_literal = m.add_literal(beta);
|
||||
auto beta_c = insert_common_op(m, pos, migraphx::make_op("mul"), {c, beta_literal});
|
||||
if(beta_c->get_shape().type() != input_type)
|
||||
{
|
||||
beta_c = m.insert_instruction(
|
||||
pos, migraphx::make_op("convert", {{"target_type", input_type}}), beta_c);
|
||||
}
|
||||
return m.insert_instruction(pos, migraphx::make_op("add"), op_res, beta_c);
|
||||
}
|
||||
}
|
||||
return op_res;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
210
docker/rocm/migraphx/argument.cpp
Normal file
210
docker/rocm/migraphx/argument.cpp
Normal file
@ -0,0 +1,210 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/argument.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
argument::argument(const shape& s) : m_shape(s)
|
||||
{
|
||||
auto buffer = make_shared_array<char>(s.bytes());
|
||||
assign_buffer({[=]() mutable { return buffer.get(); }});
|
||||
}
|
||||
|
||||
argument::argument(shape s, std::nullptr_t)
|
||||
: m_shape(std::move(s)), m_data({[] { return nullptr; }})
|
||||
{
|
||||
}
|
||||
|
||||
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}
|
||||
|
||||
void argument::assign_buffer(std::function<char*()> d)
|
||||
{
|
||||
const shape& s = m_shape;
|
||||
if(s.type() != shape::tuple_type)
|
||||
{
|
||||
m_data = {std::move(d)};
|
||||
return;
|
||||
}
|
||||
// Collect all shapes
|
||||
std::unordered_map<std::size_t, shape> shapes;
|
||||
{
|
||||
std::size_t i = 0;
|
||||
fix([&](auto self, auto ss) {
|
||||
if(ss.sub_shapes().empty())
|
||||
{
|
||||
shapes[i] = ss;
|
||||
i++;
|
||||
}
|
||||
else
|
||||
{
|
||||
for(auto&& child : ss.sub_shapes())
|
||||
self(child);
|
||||
}
|
||||
})(s);
|
||||
}
|
||||
// Sort by type size
|
||||
std::vector<std::size_t> order(shapes.size());
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) {
|
||||
return shapes[i].type_size();
|
||||
}));
|
||||
// Compute offsets
|
||||
std::unordered_map<std::size_t, std::size_t> offsets;
|
||||
std::size_t offset = 0;
|
||||
for(auto i : order)
|
||||
{
|
||||
offsets[i] = offset;
|
||||
offset += shapes[i].bytes();
|
||||
}
|
||||
assert(offset == s.bytes());
|
||||
|
||||
std::size_t i = 0;
|
||||
m_data = fix<data_t>([&](auto self, auto ss) {
|
||||
data_t result;
|
||||
if(ss.sub_shapes().empty())
|
||||
{
|
||||
auto n = offsets[i];
|
||||
result = {[d, n]() mutable { return d() + n; }};
|
||||
i++;
|
||||
return result;
|
||||
}
|
||||
std::vector<data_t> subs;
|
||||
std::transform(ss.sub_shapes().begin(),
|
||||
ss.sub_shapes().end(),
|
||||
std::back_inserter(subs),
|
||||
[&](auto child) { return self(child); });
|
||||
result.sub = subs;
|
||||
return result;
|
||||
})(s);
|
||||
}
|
||||
|
||||
std::vector<argument> flatten(const std::vector<argument>& args)
|
||||
{
|
||||
std::vector<argument> result;
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
if(arg.get_shape().type() == shape::tuple_type)
|
||||
{
|
||||
auto subs = flatten(arg.get_sub_objects());
|
||||
result.insert(result.end(), subs.begin(), subs.end());
|
||||
}
|
||||
else
|
||||
{
|
||||
result.push_back(arg);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<shape> to_shapes(const std::vector<argument>& args)
|
||||
{
|
||||
std::vector<shape> shapes;
|
||||
std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) {
|
||||
return arg.get_shape();
|
||||
});
|
||||
return shapes;
|
||||
}
|
||||
|
||||
argument::argument(const std::vector<argument>& args)
|
||||
: m_shape(to_shapes(args)), m_data(data_t::from_args(args))
|
||||
{
|
||||
}
|
||||
|
||||
char* argument::data() const
|
||||
{
|
||||
assert(m_shape.type() != shape::tuple_type);
|
||||
assert(not this->empty());
|
||||
return m_data.get();
|
||||
}
|
||||
|
||||
bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
|
||||
|
||||
const shape& argument::get_shape() const { return this->m_shape; }
|
||||
|
||||
argument argument::reshape(const shape& s) const
|
||||
{
|
||||
assert(s.element_space() <= this->get_shape().element_space());
|
||||
return {s, this->m_data};
|
||||
}
|
||||
|
||||
argument::data_t argument::data_t::share() const
|
||||
{
|
||||
data_t result;
|
||||
if(this->get)
|
||||
{
|
||||
auto self = std::make_shared<data_t>(*this);
|
||||
result.get = [self]() mutable { return self->get(); };
|
||||
}
|
||||
std::transform(sub.begin(), sub.end(), std::back_inserter(result.sub), [](const auto& d) {
|
||||
return d.share();
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
|
||||
{
|
||||
data_t result;
|
||||
std::transform(args.begin(), args.end(), std::back_inserter(result.sub), [](auto&& arg) {
|
||||
return arg.m_data;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
argument argument::copy() const
|
||||
{
|
||||
argument result{this->get_shape()};
|
||||
auto* src = this->data();
|
||||
std::copy(src, src + this->get_shape().bytes(), result.data());
|
||||
return result;
|
||||
}
|
||||
|
||||
argument argument::share() const { return {m_shape, m_data.share()}; }
|
||||
|
||||
std::vector<argument> argument::get_sub_objects() const
|
||||
{
|
||||
std::vector<argument> result;
|
||||
assert(m_shape.sub_shapes().size() == m_data.sub.size());
|
||||
std::transform(m_shape.sub_shapes().begin(),
|
||||
m_shape.sub_shapes().end(),
|
||||
m_data.sub.begin(),
|
||||
std::back_inserter(result),
|
||||
[](auto&& s, auto&& d) {
|
||||
return argument{s, d};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
argument argument::element(std::size_t i) const
|
||||
{
|
||||
assert(this->get_shape().sub_shapes().empty());
|
||||
auto idx = this->get_shape().index(i);
|
||||
auto offset = this->get_shape().type_size() * idx;
|
||||
return argument{shape{this->get_shape().type()}, this->data() + offset};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
80
docker/rocm/migraphx/auto_contiguous.cpp
Normal file
80
docker/rocm/migraphx/auto_contiguous.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/auto_contiguous.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void auto_contiguous::apply(module& m) const
|
||||
{
|
||||
std::string key = "require_std_shape";
|
||||
for(auto ins : reverse_iterator_for(m))
|
||||
{
|
||||
auto&& attr = ins->get_operator().attributes();
|
||||
if((attr.get(key, false)))
|
||||
{
|
||||
auto args = ins->inputs();
|
||||
auto new_args = args;
|
||||
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
|
||||
if(in->name() == "contiguous")
|
||||
{
|
||||
return in;
|
||||
}
|
||||
return m.insert_instruction(ins, make_op("contiguous"), in);
|
||||
});
|
||||
|
||||
if(new_args != args)
|
||||
{
|
||||
m.replace_instruction(ins, ins->get_operator(), new_args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto last = std::prev(m.end());
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(contains({"layout", "@return"}, ins->name()))
|
||||
continue;
|
||||
// for last instruction that is NOT a return
|
||||
if(ins->outputs().empty() and ins != last)
|
||||
continue;
|
||||
shape s = ins->get_shape();
|
||||
// If s is not standard layout or has out of sequence strides, insert "contiguous" op
|
||||
// to make a standard shape
|
||||
if(not s.dynamic() and (not s.standard() or s.normalize_standard() != s) and
|
||||
s.elements() > 1)
|
||||
{
|
||||
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
|
||||
m.replace_instruction(ins, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
82
docker/rocm/migraphx/autocast_fp8.cpp
Normal file
82
docker/rocm/migraphx/autocast_fp8.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/autocast_fp8.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/fp8_types.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void autocast_fp8_pass::apply(module& m) const
|
||||
{
|
||||
std::vector<instruction_ref> remove_parameters;
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
const auto& ins_name = ins->name();
|
||||
if(ins_name == "@param" and contains(fp8_types{}.get(), ins->get_shape().type()))
|
||||
{
|
||||
shape::type_t fp8_type = ins->get_shape().type();
|
||||
migraphx::shape new_shape = ins->get_shape().with_type(target_type);
|
||||
std::string param_name = ins->get_operator().to_value()["parameter"].to<std::string>();
|
||||
m.rename_parameter(ins, param_name + "_old");
|
||||
auto new_param = m.add_parameter(param_name, new_shape);
|
||||
auto new_ins = m.insert_instruction(
|
||||
ins,
|
||||
migraphx::make_op("convert", {{"target_type", migraphx::to_value(fp8_type)}}),
|
||||
new_param);
|
||||
m.replace_instruction(ins, new_ins);
|
||||
remove_parameters.push_back(ins);
|
||||
}
|
||||
|
||||
if(ins_name == "@return")
|
||||
{
|
||||
std::vector<instruction_ref> inputs = ins->inputs();
|
||||
std::vector<instruction_ref> new_inputs;
|
||||
std::transform(
|
||||
inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [&](auto i) {
|
||||
if(contains(fp8_types{}.get(), i->get_shape().type()))
|
||||
{
|
||||
return m.insert_instruction(
|
||||
ins,
|
||||
migraphx::make_op("convert",
|
||||
{{"target_type", migraphx::to_value(target_type)}}),
|
||||
i);
|
||||
}
|
||||
else
|
||||
return i;
|
||||
});
|
||||
m.replace_return({new_inputs});
|
||||
}
|
||||
}
|
||||
// Remove unused parameters with fp8 type
|
||||
for(const auto& i : remove_parameters)
|
||||
m.remove_instruction(i);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
81
docker/rocm/migraphx/base64.cpp
Normal file
81
docker/rocm/migraphx/base64.cpp
Normal file
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/base64.hpp>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
namespace {
|
||||
using byte = unsigned char;
|
||||
|
||||
std::array<char, 64> constexpr b64_chars{
|
||||
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
|
||||
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
|
||||
'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
|
||||
'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
|
||||
|
||||
/// base64 encoder snippet altered from https://stackoverflow.com/a/37109258
|
||||
std::string encode(const std::vector<byte>& buf)
|
||||
{
|
||||
std::size_t len = buf.size();
|
||||
std::vector<byte> res_vec((len + 2) / 3 * 4, '=');
|
||||
std::size_t j = 0;
|
||||
std::size_t remaining = len % 3;
|
||||
const size_t last = len - remaining;
|
||||
|
||||
for(size_t i = 0; i < last; i += 3)
|
||||
{
|
||||
std::size_t n = static_cast<std::size_t>(buf.at(i)) << 16u |
|
||||
static_cast<std::size_t>(buf.at(i + 1)) << 8u |
|
||||
static_cast<std::size_t>(buf.at(i + 2));
|
||||
res_vec.at(j++) = b64_chars.at(n >> 18u);
|
||||
res_vec.at(j++) = b64_chars.at(n >> 12u & 0x3Fu);
|
||||
res_vec.at(j++) = b64_chars.at(n >> 6u & 0x3Fu);
|
||||
res_vec.at(j++) = b64_chars.at(n & 0x3Fu);
|
||||
}
|
||||
// Set padding
|
||||
if(remaining != 0)
|
||||
{
|
||||
std::size_t n = --remaining == 0 ? static_cast<std::size_t>(buf.at(last))
|
||||
: static_cast<std::size_t>(buf.at(last)) << 8u |
|
||||
static_cast<std::size_t>(buf.at(last + 1));
|
||||
res_vec.at(j++) = b64_chars.at(remaining == 0 ? n >> 2u : n >> 10u & 0x3Fu);
|
||||
res_vec.at(j++) = b64_chars.at(remaining == 0 ? n << 4u & 0x3Fu : n >> 4u & 0x03Fu);
|
||||
res_vec.at(j++) = remaining == 0 ? '=' : b64_chars.at(n << 2u & 0x3Fu);
|
||||
}
|
||||
return {res_vec.begin(), res_vec.end()};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string base64_encode(const std::string& str)
|
||||
{
|
||||
return encode(std::vector<byte>(str.begin(), str.end()));
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
254
docker/rocm/migraphx/common.cpp
Normal file
254
docker/rocm/migraphx/common.cpp
Normal file
@ -0,0 +1,254 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/common.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/algorithm.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
|
||||
std::vector<std::size_t> s1)
|
||||
{
|
||||
if(s0 == s1)
|
||||
return s0;
|
||||
if(s0.size() > s1.size())
|
||||
s0.swap(s1);
|
||||
std::vector<std::size_t> out_lens(s1);
|
||||
auto offset = s1.size() - s0.size();
|
||||
std::transform(
|
||||
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
|
||||
if(a != b and a != 1 and b != 1)
|
||||
{
|
||||
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
|
||||
"} and {" + migraphx::to_string_range(s1) + "} mismatch!");
|
||||
}
|
||||
return std::max(a, b);
|
||||
});
|
||||
return out_lens;
|
||||
}
|
||||
std::vector<shape::dynamic_dimension>
|
||||
compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
|
||||
std::vector<shape::dynamic_dimension> dds1)
|
||||
{
|
||||
if(dds0.size() > dds1.size())
|
||||
{
|
||||
std::swap(dds0, dds1);
|
||||
}
|
||||
auto offset = dds1.size() - dds0.size();
|
||||
std::vector<shape::dynamic_dimension> out_dims(dds1);
|
||||
std::transform(dds0.cbegin(),
|
||||
dds0.cend(),
|
||||
dds1.cbegin() + offset,
|
||||
out_dims.begin() + offset,
|
||||
[&](auto a, auto b) {
|
||||
if(a == b or b == 1)
|
||||
{
|
||||
return a;
|
||||
}
|
||||
else if(a == 1)
|
||||
{
|
||||
return b;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto intersect = a.intersection(b);
|
||||
if(intersect.has_value())
|
||||
{
|
||||
return intersect.value();
|
||||
}
|
||||
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
|
||||
migraphx::to_string_range(dds0) + "} and {" +
|
||||
migraphx::to_string_range(dds1) + "} mismatch!");
|
||||
}
|
||||
});
|
||||
return out_dims;
|
||||
}
|
||||
|
||||
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
|
||||
{
|
||||
// change both shapes to dynamic_dimension representation
|
||||
s0 = s0.to_dynamic();
|
||||
s1 = s1.to_dynamic();
|
||||
return compute_broadcasted_dyn_dims(s0.dyn_dims(), s1.dyn_dims());
|
||||
}
|
||||
|
||||
std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
|
||||
{
|
||||
auto ret_shape = shapes.at(0);
|
||||
std::for_each(shapes.cbegin() + 1, shapes.cend(), [&](auto s) {
|
||||
ret_shape = shape{ret_shape.type(), compute_broadcasted_dyn_dims(ret_shape, s)};
|
||||
});
|
||||
return ret_shape.dyn_dims();
|
||||
}
|
||||
|
||||
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
|
||||
{
|
||||
assert(not shapes.empty());
|
||||
assert(
|
||||
std::none_of(shapes.cbegin(), shapes.cend(), [](auto shape) { return shape.dynamic(); }));
|
||||
return transform_accumulate(shapes.begin() + 1,
|
||||
shapes.end(),
|
||||
shapes.front().lens(),
|
||||
&compute_broadcasted_lens,
|
||||
[](auto s) { return s.lens(); });
|
||||
}
|
||||
|
||||
shape::type_t compute_common_type(shape::type_t t1, shape::type_t t2)
|
||||
{
|
||||
if(t1 == t2)
|
||||
return t1;
|
||||
shape::type_t result;
|
||||
shape::visit(t1, [&](auto x) {
|
||||
shape::visit(t2, [&](auto y) {
|
||||
// Workaround broken warning on gcc 5
|
||||
(void)x;
|
||||
(void)y;
|
||||
using type = std::common_type_t<decltype(x()), decltype(y())>;
|
||||
result = shape::get_type<type>{};
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
shape::type_t compute_common_types(const std::vector<shape>& shapes)
|
||||
{
|
||||
assert(not shapes.empty());
|
||||
return transform_accumulate(
|
||||
shapes.begin() + 1, shapes.end(), shapes.front().type(), &compute_common_type, [&](auto s) {
|
||||
return s.type();
|
||||
});
|
||||
}
|
||||
|
||||
shape common_shape(const std::vector<shape>& shapes)
|
||||
{
|
||||
if(shapes.empty())
|
||||
return {};
|
||||
return {compute_common_types(shapes), compute_common_lens(shapes)};
|
||||
}
|
||||
|
||||
std::vector<instruction_ref> insert_common_args(module& m,
|
||||
instruction_ref ins,
|
||||
std::vector<instruction_ref> inputs,
|
||||
common_options options)
|
||||
{
|
||||
if(std::any_of(
|
||||
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
|
||||
{
|
||||
auto input_shapes = to_shapes(inputs);
|
||||
if(options.common_lens)
|
||||
{
|
||||
auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
|
||||
|
||||
auto s0 = inputs[0]->get_shape();
|
||||
// always add both multibroadcast instructions for dynamic shapes
|
||||
inputs[0] = m.insert_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
|
||||
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
|
||||
// uses previous input to avoid recalculating the common shape from the
|
||||
// full set of input shapes at runtime
|
||||
auto s = input->get_shape();
|
||||
return m.insert_instruction(
|
||||
ins,
|
||||
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
|
||||
input,
|
||||
inputs[0]);
|
||||
});
|
||||
}
|
||||
if(options.common_type)
|
||||
{
|
||||
auto c_type = compute_common_types(input_shapes);
|
||||
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
|
||||
if(input->get_shape().type() != c_type)
|
||||
{
|
||||
input = m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", c_type}}), input);
|
||||
}
|
||||
return input;
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto common = common_shape(to_shapes(inputs));
|
||||
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
|
||||
if(options.common_lens and input->get_shape().lens() != common.lens())
|
||||
{
|
||||
input = m.insert_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
|
||||
}
|
||||
if(options.common_type and input->get_shape().type() != common.type())
|
||||
{
|
||||
input = m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", common.type()}}), input);
|
||||
}
|
||||
return input;
|
||||
});
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::vector<instruction_ref>
|
||||
add_common_args(module& m, std::vector<instruction_ref> inputs, common_options options)
|
||||
{
|
||||
return insert_common_args(m, m.end(), std::move(inputs), options);
|
||||
}
|
||||
|
||||
instruction_ref insert_common_op(module& m,
|
||||
instruction_ref ins,
|
||||
const operation& op,
|
||||
std::vector<instruction_ref> inputs,
|
||||
common_options options)
|
||||
{
|
||||
return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs), options));
|
||||
}
|
||||
|
||||
instruction_ref add_common_op(module& m,
|
||||
const operation& op,
|
||||
std::vector<instruction_ref> inputs,
|
||||
common_options options)
|
||||
{
|
||||
return insert_common_op(m, m.end(), op, std::move(inputs), options);
|
||||
}
|
||||
|
||||
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
|
||||
{
|
||||
assert(not input_shape.dynamic());
|
||||
auto offset = bcast_lens.size() - input_shape.ndim();
|
||||
std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
|
||||
for(std::ptrdiff_t i : reverse(range(input_shape.ndim())))
|
||||
{
|
||||
if(bcast_lens.at(i + offset) == input_shape.lens()[i])
|
||||
{
|
||||
bcast_strides.at(i + offset) = input_shape.strides()[i];
|
||||
}
|
||||
}
|
||||
return shape{input_shape.type(), bcast_lens, bcast_strides};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
203
docker/rocm/migraphx/common_dims.cpp
Normal file
203
docker/rocm/migraphx/common_dims.cpp
Normal file
@ -0,0 +1,203 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/common_dims.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class Iterator>
|
||||
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
|
||||
{
|
||||
std::size_t x = 1;
|
||||
auto it = std::find_if(start, last, [&](auto i) {
|
||||
x *= i;
|
||||
return x > dim;
|
||||
});
|
||||
if(x < dim)
|
||||
return start;
|
||||
return it;
|
||||
}
|
||||
|
||||
struct common_dim_state
|
||||
{
|
||||
common_dim_state(const std::vector<std::size_t>& pdims,
|
||||
std::vector<std::vector<std::size_t>>& paxes_map)
|
||||
: dims(&pdims), axes_map(&paxes_map), it(dims->begin())
|
||||
{
|
||||
}
|
||||
const std::vector<std::size_t>* dims = nullptr;
|
||||
std::vector<std::vector<std::size_t>>* axes_map = nullptr;
|
||||
std::vector<std::size_t>::const_iterator it{};
|
||||
std::size_t rem = 1;
|
||||
std::size_t get() const { return *it / rem; }
|
||||
bool is_end() const { return it == dims->end(); }
|
||||
void next(std::size_t i = 1) { it += i; }
|
||||
auto dims_for(std::size_t d) const
|
||||
{
|
||||
auto dim_end = compute_end_dim(it, dims->end(), d);
|
||||
return range(it, dim_end);
|
||||
}
|
||||
void add_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
|
||||
{
|
||||
auto axes = compute_axes(naxes, start);
|
||||
axes_map->push_back(std::move(axes));
|
||||
}
|
||||
|
||||
void add_multi_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
|
||||
{
|
||||
auto axes = compute_axes(naxes, start);
|
||||
std::transform(axes.begin(),
|
||||
axes.end(),
|
||||
std::back_inserter(*axes_map),
|
||||
[&](auto axis) -> std::vector<std::size_t> { return {axis}; });
|
||||
}
|
||||
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
|
||||
{
|
||||
if(rem != 1)
|
||||
{
|
||||
assert(start > 0);
|
||||
naxes++;
|
||||
start--;
|
||||
}
|
||||
std::vector<std::size_t> axes(naxes);
|
||||
std::iota(axes.begin(), axes.end(), start);
|
||||
return axes;
|
||||
}
|
||||
};
|
||||
|
||||
static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
|
||||
common_dim_state& state1,
|
||||
common_dim_state& state2)
|
||||
{
|
||||
assert(state1.get() < state2.get());
|
||||
auto d2 = state2.get();
|
||||
auto dims = state1.dims_for(d2);
|
||||
auto n = elements(dims);
|
||||
auto naxes = distance(dims);
|
||||
if(naxes == 0)
|
||||
return false;
|
||||
// If not divisible then we can't compute a common dim
|
||||
if((d2 % n) != 0)
|
||||
return false;
|
||||
auto rem = d2 / n;
|
||||
state1.add_multi_axes(naxes, cd_dims.size());
|
||||
state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size());
|
||||
|
||||
state1.rem = rem;
|
||||
state2.rem = 1;
|
||||
|
||||
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
|
||||
if(state1.rem != 1)
|
||||
cd_dims.push_back(state1.rem);
|
||||
state1.next(distance(dims));
|
||||
state2.next();
|
||||
return true;
|
||||
}
|
||||
|
||||
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
|
||||
const std::vector<std::size_t>& dims2)
|
||||
{
|
||||
assert(elements(dims1) > 0);
|
||||
assert(elements(dims1) == elements(dims2));
|
||||
common_dims cd;
|
||||
common_dim_state state1{dims1, cd.axes_map1};
|
||||
common_dim_state state2{dims2, cd.axes_map2};
|
||||
while(not state1.is_end() and not state2.is_end())
|
||||
{
|
||||
auto d1 = state1.get();
|
||||
auto d2 = state2.get();
|
||||
if(d1 == d2)
|
||||
{
|
||||
state1.add_axes(1, cd.dims.size());
|
||||
state2.add_axes(1, cd.dims.size());
|
||||
state1.rem = 1;
|
||||
state2.rem = 1;
|
||||
cd.dims.push_back(d1);
|
||||
state1.next();
|
||||
state2.next();
|
||||
}
|
||||
else if(d1 < d2)
|
||||
{
|
||||
if(not compute_common_dim(cd.dims, state1, state2))
|
||||
return {};
|
||||
}
|
||||
else // if(d1 > d2)
|
||||
{
|
||||
if(not compute_common_dim(cd.dims, state2, state1))
|
||||
return {};
|
||||
}
|
||||
}
|
||||
assert(elements(dims1) == elements(cd.dims));
|
||||
return cd;
|
||||
}
|
||||
|
||||
const std::vector<std::vector<std::size_t>>* common_dims::get_axes_map(std::size_t n) const
|
||||
{
|
||||
if(axes_map1.size() == n)
|
||||
return &axes_map1;
|
||||
if(axes_map2.size() == n)
|
||||
return &axes_map2;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<std::size_t>
|
||||
common_dims::get_dimensions_for(const std::vector<std::size_t>& idims) const
|
||||
{
|
||||
if(dims.size() == idims.size())
|
||||
return idims;
|
||||
if(elements(dims) == elements(idims))
|
||||
return dims;
|
||||
// Bail for now since its ambiguous which axes map can be used
|
||||
// TODO: Check for similiarity
|
||||
if(axes_map1.size() == axes_map2.size())
|
||||
return {};
|
||||
const auto* axes_map = get_axes_map(idims.size());
|
||||
if(axes_map == nullptr)
|
||||
return {};
|
||||
auto xdims = dims;
|
||||
for(auto i : range(axes_map->size()))
|
||||
{
|
||||
auto dim = idims[i];
|
||||
const auto& axes = (*axes_map)[i];
|
||||
if(axes.size() == 1)
|
||||
{
|
||||
xdims[axes.front()] = dim;
|
||||
}
|
||||
else if(dim == 1)
|
||||
{
|
||||
for(auto axis : axes)
|
||||
xdims[axis] = 1;
|
||||
}
|
||||
}
|
||||
if(elements(xdims) == elements(idims))
|
||||
return xdims;
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
76
docker/rocm/migraphx/compile_src.cpp
Normal file
76
docker/rocm/migraphx/compile_src.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/compile_src.hpp>
|
||||
#include <migraphx/file_buffer.hpp>
|
||||
#include <migraphx/tmp_dir.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/fileutils.hpp>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
|
||||
{
|
||||
assert(not srcs.empty());
|
||||
tmp_dir td{"compile"};
|
||||
std::vector<std::string> params{flags};
|
||||
|
||||
params.emplace_back("-I.");
|
||||
|
||||
auto out = output;
|
||||
|
||||
for(const auto& src : srcs)
|
||||
{
|
||||
fs::path full_path = td.path / src.path;
|
||||
fs::path parent_path = full_path.parent_path();
|
||||
fs::create_directories(parent_path);
|
||||
write_buffer(full_path, src.content.data(), src.content.size());
|
||||
if(src.path.extension().string() == ".cpp")
|
||||
{
|
||||
params.emplace_back(src.path.filename().string());
|
||||
if(out.empty())
|
||||
out = src.path.stem().string() + out_ext;
|
||||
}
|
||||
}
|
||||
|
||||
params.emplace_back("-o " + out);
|
||||
|
||||
std::vector<std::string> args;
|
||||
if(not launcher.empty())
|
||||
args.push_back(compiler.string());
|
||||
args.insert(args.end(), params.begin(), params.end());
|
||||
td.execute(launcher.empty() ? compiler : launcher, args);
|
||||
|
||||
auto out_path = td.path / out;
|
||||
if(not fs::exists(out_path))
|
||||
MIGRAPHX_THROW("Output file missing: " + out);
|
||||
|
||||
return read_buffer(out_path);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
105
docker/rocm/migraphx/convert_to_json.cpp
Normal file
105
docker/rocm/migraphx/convert_to_json.cpp
Normal file
@ -0,0 +1,105 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/convert_to_json.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/lexing.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::vector<std::string_view> json_tokenize(const std::string& s)
|
||||
{
|
||||
std::vector<lexer> lexers;
|
||||
|
||||
// Quote
|
||||
lexers.push_back([](const char* start, const char* end) {
|
||||
if(*start != '\"')
|
||||
return start;
|
||||
++start;
|
||||
while((start != end) and (*start != '\"'))
|
||||
{
|
||||
if(*start == '\\')
|
||||
start++;
|
||||
start++;
|
||||
}
|
||||
|
||||
return ++start;
|
||||
});
|
||||
|
||||
// Line comments
|
||||
lexers.push_back([](const char* start, const char* end) {
|
||||
if(*start == '#')
|
||||
start++;
|
||||
else if((start + 1) < end and start[0] == '/' and start[1] == '/')
|
||||
start += 2;
|
||||
else
|
||||
return start;
|
||||
return std::find_if(start, end, [&](char c) { return c == '\n'; });
|
||||
});
|
||||
|
||||
// Whitespace
|
||||
lexers.push_back(lex_while(&isspace));
|
||||
|
||||
// Punctation
|
||||
lexers.push_back(lex_if(&ispunct));
|
||||
|
||||
// Identifier/number
|
||||
lexers.push_back(lex_while([](char c) {
|
||||
return (isalnum(c) != 0 or contains({'_', '.', '+'}, c));
|
||||
}));
|
||||
|
||||
return tokenize(s.data(), s.data() + s.length(), lexers);
|
||||
}
|
||||
|
||||
std::string convert_to_json(const std::string& str)
|
||||
{
|
||||
auto tokens = json_tokenize(str);
|
||||
std::stringstream ss;
|
||||
|
||||
for(auto& token : tokens)
|
||||
{
|
||||
std::string s(token);
|
||||
if(starts_with(s, "#") or starts_with(s, "//"))
|
||||
continue;
|
||||
if(std::isalpha(s.front()) != 0 and
|
||||
not contains({"null", "nan", "true", "false", "inf"}, s))
|
||||
{
|
||||
ss << "\"" << s << "\"";
|
||||
}
|
||||
else
|
||||
{
|
||||
ss << s;
|
||||
}
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
284
docker/rocm/migraphx/cpp_generator.cpp
Normal file
284
docker/rocm/migraphx/cpp_generator.cpp
Normal file
@ -0,0 +1,284 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/cpp_generator.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/builtin.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
cpp_generator::function&
|
||||
cpp_generator::function::set_body(const module& m, const cpp_generator::generate_module_callback& g)
|
||||
{
|
||||
const std::string prefix = "zz";
|
||||
std::unordered_map<migraphx::instruction_ref, std::string> names;
|
||||
std::stringstream ss;
|
||||
|
||||
auto return_ins = std::prev(m.end());
|
||||
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
ss << "// " << ins->get_operator() << " -> " << ins->get_shape() << "\n";
|
||||
if(ins->name() == "@param")
|
||||
{
|
||||
names[ins] = to_c_id(
|
||||
migraphx::any_cast<migraphx::builtin::param>(ins->get_operator()).parameter);
|
||||
}
|
||||
else if(ins->name() == "@return")
|
||||
{
|
||||
names[ins] = prefix + "return";
|
||||
ss << "auto " << names[ins] << " = " << g(ins, names) << ";\n";
|
||||
return_ins = ins;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string n = prefix + std::to_string(names.size());
|
||||
names[ins] = n;
|
||||
ss << "auto " << n << " = " << g(ins, names) << ";\n";
|
||||
}
|
||||
}
|
||||
ss << "return " << names.at(return_ins) << ";\n";
|
||||
body = ss.str();
|
||||
return *this;
|
||||
}
|
||||
|
||||
cpp_generator::function& cpp_generator::function::set_types(const module& m)
|
||||
{
|
||||
return cpp_generator::function::set_types(m, [](auto s) { return shape::cpp_type(s.type()); });
|
||||
}
|
||||
cpp_generator::function&
|
||||
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
|
||||
{
|
||||
this->params.clear();
|
||||
auto pmap = m.get_parameter_shapes();
|
||||
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
|
||||
std::transform(
|
||||
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
|
||||
return param{p.first, parse(p.second)};
|
||||
});
|
||||
auto output_shapes = m.get_output_shapes();
|
||||
assert(not output_shapes.empty());
|
||||
this->return_type = parse(output_shapes.front());
|
||||
return *this;
|
||||
}
|
||||
|
||||
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
|
||||
{
|
||||
this->params.clear();
|
||||
auto pmap = m.get_parameter_shapes();
|
||||
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
|
||||
std::transform(
|
||||
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
|
||||
return param{p.first, "T" + to_c_id(p.first)};
|
||||
});
|
||||
|
||||
std::transform(input_map.begin(),
|
||||
input_map.end(),
|
||||
std::back_inserter(this->tparams),
|
||||
[&](auto&& p) { return "class T" + to_c_id(p.first); });
|
||||
this->return_type = "auto";
|
||||
return *this;
|
||||
}
|
||||
|
||||
cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
|
||||
{
|
||||
body.insert(0, "(void)" + pname + ";\n");
|
||||
return *this;
|
||||
}
|
||||
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
|
||||
{
|
||||
params.push_back({pname, "T" + pname});
|
||||
tparams.push_back("class T" + pname);
|
||||
return *this;
|
||||
}
|
||||
|
||||
struct cpp_generator_impl
|
||||
{
|
||||
std::stringstream fs{};
|
||||
std::size_t function_count = 0;
|
||||
std::function<std::string(std::string)> fmap = nullptr;
|
||||
std::function<std::string(shape)> fresult = nullptr;
|
||||
std::unordered_map<std::string, std::string> point_op_map = {};
|
||||
bool always_return_tuple = false;
|
||||
};
|
||||
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
|
||||
|
||||
cpp_generator::cpp_generator(cpp_generator&&) noexcept = default;
|
||||
|
||||
cpp_generator& cpp_generator::operator=(cpp_generator rhs)
|
||||
{
|
||||
std::swap(impl, rhs.impl);
|
||||
return *this;
|
||||
}
|
||||
|
||||
cpp_generator::~cpp_generator() noexcept = default;
|
||||
|
||||
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
|
||||
|
||||
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
|
||||
|
||||
void cpp_generator::always_return_tuple(bool b) { impl->always_return_tuple = b; }
|
||||
|
||||
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
|
||||
{
|
||||
impl->point_op_map[op_name] = code;
|
||||
}
|
||||
|
||||
std::string cpp_generator::generate_point_op(const operation& op,
|
||||
const std::vector<std::string>& args)
|
||||
{
|
||||
auto v = op.to_value();
|
||||
std::string code;
|
||||
if(contains(impl->point_op_map, op.name()))
|
||||
{
|
||||
code = impl->point_op_map.at(op.name());
|
||||
}
|
||||
else
|
||||
{
|
||||
auto attributes = op.attributes();
|
||||
if(not attributes.contains("point_op"))
|
||||
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
|
||||
code = attributes["point_op"].to<std::string>();
|
||||
}
|
||||
return interpolate_string(code, [&](auto start, auto last) -> std::string {
|
||||
auto key = trim({start, last});
|
||||
if(key.empty())
|
||||
MIGRAPHX_THROW("Empty parameter");
|
||||
std::string fselector = "function:";
|
||||
if(starts_with(key, fselector))
|
||||
{
|
||||
auto fname = key.substr(fselector.size());
|
||||
if(impl->fmap == nullptr)
|
||||
return fname;
|
||||
else
|
||||
return impl->fmap(fname);
|
||||
}
|
||||
else if(with_char(::isdigit)(key[0]))
|
||||
{
|
||||
auto i = std::stoul(key);
|
||||
if(i >= args.size())
|
||||
MIGRAPHX_THROW("Invalid argument index: " + key);
|
||||
return args.at(i);
|
||||
}
|
||||
else if(v.contains(key))
|
||||
{
|
||||
return v[key].template to<std::string>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return key;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::string cpp_generator::str() const { return impl->fs.str(); }
|
||||
|
||||
cpp_generator::function cpp_generator::generate_module(const module& m,
|
||||
const generate_module_callback& g)
|
||||
{
|
||||
function f;
|
||||
f.set_name(to_c_id(m.name()))
|
||||
.set_types(m)
|
||||
.set_body(m, [&](instruction_ref ins, const auto& names) -> std::string {
|
||||
if(ins->name() == "@literal")
|
||||
{
|
||||
std::string string_literal;
|
||||
ins->get_literal().visit([&](auto v) {
|
||||
assert(v.size() == 1);
|
||||
auto x = v.front();
|
||||
if(std::isinf(static_cast<double>(x)))
|
||||
{
|
||||
string_literal = "__builtin_huge_val()";
|
||||
if(x < 0)
|
||||
string_literal = "-__builtin_huge_val()";
|
||||
}
|
||||
else if(std::isnan(static_cast<double>(x)))
|
||||
string_literal = "__builtin_nan(\"0\")";
|
||||
else
|
||||
string_literal = ins->get_literal().to_string();
|
||||
});
|
||||
return shape::cpp_type(ins->get_shape().type()) + "(" + string_literal + ")";
|
||||
}
|
||||
if(ins->name() == "@return")
|
||||
{
|
||||
// TODO: Customize the make_tuple call
|
||||
if(impl->always_return_tuple or ins->inputs().size() != 1)
|
||||
return "make_tuple(" + join_strings(to_args(ins->inputs(), names), ", ") + ")";
|
||||
return names.at(ins->inputs().front());
|
||||
}
|
||||
auto s = g(ins, names);
|
||||
if(impl->fresult)
|
||||
return impl->fresult(ins->get_shape()) + '(' + s + ')';
|
||||
else
|
||||
return s;
|
||||
});
|
||||
return f;
|
||||
}
|
||||
|
||||
std::vector<std::string>
|
||||
cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
|
||||
const std::unordered_map<instruction_ref, std::string>& names)
|
||||
{
|
||||
std::vector<std::string> args;
|
||||
std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) {
|
||||
return names.at(i);
|
||||
});
|
||||
return args;
|
||||
}
|
||||
|
||||
cpp_generator::function cpp_generator::generate_module(const module& m)
|
||||
{
|
||||
return this->generate_module(m, [&](auto ins, const auto& names) {
|
||||
return this->generate_point_op(ins->get_operator(), to_args(ins->inputs(), names));
|
||||
});
|
||||
}
|
||||
|
||||
std::string cpp_generator::create_function(const cpp_generator::function& f)
|
||||
{
|
||||
impl->function_count++;
|
||||
if(not f.tparams.empty())
|
||||
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
|
||||
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
|
||||
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
|
||||
char delim = '(';
|
||||
if(f.params.empty())
|
||||
impl->fs << delim;
|
||||
for(auto&& p : f.params)
|
||||
{
|
||||
impl->fs << delim << p.type << " " << to_c_id(p.name);
|
||||
delim = ',';
|
||||
}
|
||||
impl->fs << ") {\n" << f.body << "\n}\n";
|
||||
return name;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
86
docker/rocm/migraphx/dead_code_elimination.cpp
Normal file
86
docker/rocm/migraphx/dead_code_elimination.cpp
Normal file
@ -0,0 +1,86 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
|
||||
|
||||
void dead_code_elimination::apply(module& m) const
|
||||
{
|
||||
auto last = std::prev(m.end());
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
// Skip the first instruction, since we always process the previous
|
||||
// instruction
|
||||
if(ins == m.begin())
|
||||
continue;
|
||||
const auto i = std::prev(ins);
|
||||
// Skip the last instruction
|
||||
if(i == last)
|
||||
break;
|
||||
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
|
||||
// identity, allocate or tuple_type]
|
||||
if((not i->get_shape().dynamic() and
|
||||
(i->get_shape().elements() == 0 and
|
||||
i->get_shape().type() != migraphx::shape::tuple_type)) and
|
||||
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
|
||||
not i->is_undefined())
|
||||
continue;
|
||||
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
|
||||
std::unordered_set<instruction_ref> visited;
|
||||
fix([&](auto self, auto leaf) {
|
||||
if(not m.has_instruction(leaf))
|
||||
return;
|
||||
|
||||
if(leaf->outputs().empty())
|
||||
{
|
||||
// Dont visit inputs twice
|
||||
if(not visited.insert(leaf).second)
|
||||
return;
|
||||
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
|
||||
leaf->inputs().end());
|
||||
leaf->clear_arguments();
|
||||
assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
|
||||
assert(leaf != ins);
|
||||
if(leaf->name() != "@param")
|
||||
m.move_instruction(leaf, m.end());
|
||||
for(auto arg : args)
|
||||
self(arg);
|
||||
}
|
||||
})(i);
|
||||
}
|
||||
m.remove_instructions(std::next(last), m.end());
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
100
docker/rocm/migraphx/dom_info.cpp
Normal file
100
docker/rocm/migraphx/dom_info.cpp
Normal file
@ -0,0 +1,100 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/dom_info.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/erase.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2) const
|
||||
{
|
||||
if(ins1 == ins2)
|
||||
return false;
|
||||
auto iter = ins2idom.find(ins2);
|
||||
while(iter != ins2idom.end())
|
||||
{
|
||||
if(ins1 == iter->second)
|
||||
return true;
|
||||
assert(iter != ins2idom.find(iter->second));
|
||||
iter = ins2idom.find(iter->second);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
struct module_visitor
|
||||
{
|
||||
const module* mm;
|
||||
const module& get_nodes() const { return *mm; }
|
||||
|
||||
const std::vector<instruction_ref>& get_children(instruction_ref ins) { return ins->inputs(); }
|
||||
};
|
||||
|
||||
template <class Visitor>
|
||||
dominator_info compute_dominator_generic(Visitor v)
|
||||
{
|
||||
dominator_info info;
|
||||
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> instr2_doms;
|
||||
for(instruction_ref ins : iterator_for(v.get_nodes()))
|
||||
{
|
||||
const std::vector<instruction_ref>& children = v.get_children(ins);
|
||||
if(children.size() == 1)
|
||||
{
|
||||
info.ins2idom[ins] = children.front();
|
||||
instr2_doms[ins] = instr2_doms[children.front()];
|
||||
}
|
||||
else if(children.size() > 1)
|
||||
{
|
||||
auto&& doms = instr2_doms[ins];
|
||||
|
||||
doms = instr2_doms[children.front()];
|
||||
std::for_each(children.begin() + 1, children.end(), [&](instruction_ref child) {
|
||||
auto&& child_doms = instr2_doms[child];
|
||||
erase_if(doms, [&](auto x) { return not contains(child_doms, x); });
|
||||
});
|
||||
auto iter = std::find_if(doms.begin(), doms.end(), [&](auto dom1) {
|
||||
return std::none_of(doms.begin(), doms.end(), [&](auto dom2) {
|
||||
if(dom1 == dom2)
|
||||
return false;
|
||||
return info.strictly_dominate(dom1, dom2);
|
||||
});
|
||||
});
|
||||
if(iter != doms.end())
|
||||
info.ins2idom[ins] = *iter;
|
||||
}
|
||||
instr2_doms[ins].insert(ins);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
dominator_info compute_dominator(const module& m)
|
||||
{
|
||||
return compute_dominator_generic(module_visitor{&m});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
203
docker/rocm/migraphx/dynamic_loader.cpp
Normal file
203
docker/rocm/migraphx/dynamic_loader.cpp
Normal file
@ -0,0 +1,203 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/manage_ptr.hpp>
|
||||
#include <migraphx/dynamic_loader.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/file_buffer.hpp>
|
||||
#include <migraphx/tmp_dir.hpp>
|
||||
#include <utility>
|
||||
|
||||
#ifdef _WIN32
|
||||
// cppcheck-suppress definePrefix
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <Windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifndef _WIN32
|
||||
|
||||
void check_load_error(bool flush = false)
|
||||
{
|
||||
char* error_msg = dlerror();
|
||||
if(not flush and error_msg != nullptr)
|
||||
MIGRAPHX_THROW("Dynamic loading or symbol lookup failed with " + std::string(error_msg));
|
||||
}
|
||||
|
||||
struct dynamic_loader_impl
|
||||
{
|
||||
dynamic_loader_impl() = default;
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
#endif
|
||||
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
|
||||
: handle(dlopen(p.string().c_str(), RTLD_GLOBAL | RTLD_NOW),
|
||||
manage_deleter<decltype(&dlclose), &dlclose>{}),
|
||||
temp(std::move(t))
|
||||
{
|
||||
check_load_error();
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
|
||||
{
|
||||
auto t = std::make_shared<tmp_dir>("dloader");
|
||||
auto f = t->path / "libtmp.so";
|
||||
write_buffer(f, image, size);
|
||||
return std::make_shared<dynamic_loader_impl>(f, t);
|
||||
}
|
||||
|
||||
std::shared_ptr<void> handle = nullptr;
|
||||
std::shared_ptr<tmp_dir> temp = nullptr;
|
||||
};
|
||||
|
||||
fs::path dynamic_loader::path(void* address)
|
||||
{
|
||||
fs::path p;
|
||||
Dl_info info;
|
||||
// Find the location of .so
|
||||
if(dladdr(address, &info) != 0)
|
||||
p = info.dli_fname;
|
||||
return p;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
struct dynamic_loader_impl
|
||||
{
|
||||
dynamic_loader_impl() = default;
|
||||
dynamic_loader_impl(const fs::path& p, tmp_dir t = {})
|
||||
: handle{LoadLibrary(p.string().c_str())}, temp{std::move(t)}
|
||||
{
|
||||
if(handle == nullptr)
|
||||
{
|
||||
MIGRAPHX_THROW("Error loading DLL: " + p.string() + " (" +
|
||||
std::to_string(GetLastError()) + ")");
|
||||
}
|
||||
}
|
||||
|
||||
dynamic_loader_impl(const dynamic_loader_impl&) = delete;
|
||||
dynamic_loader_impl& operator=(const dynamic_loader_impl&) = delete;
|
||||
|
||||
dynamic_loader_impl(dynamic_loader_impl&&) = default;
|
||||
|
||||
~dynamic_loader_impl()
|
||||
{
|
||||
if(handle != nullptr)
|
||||
{
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
}
|
||||
|
||||
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
|
||||
{
|
||||
auto t = tmp_dir{"migx-dynload"};
|
||||
auto f = t.path / "tmp.dll";
|
||||
write_buffer(f, image, size);
|
||||
return std::make_shared<dynamic_loader_impl>(f, std::move(t));
|
||||
}
|
||||
|
||||
HMODULE handle = nullptr;
|
||||
tmp_dir temp;
|
||||
};
|
||||
|
||||
fs::path dynamic_loader::path(void* address)
|
||||
{
|
||||
HMODULE module = nullptr;
|
||||
if(GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
|
||||
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
|
||||
static_cast<LPCSTR>(address),
|
||||
&module) == 0)
|
||||
{
|
||||
auto err = GetLastError();
|
||||
MIGRAPHX_THROW("Unable to obtain module handle, error = " + std::to_string(err));
|
||||
}
|
||||
TCHAR buffer[MAX_PATH];
|
||||
if(GetModuleFileName(module, buffer, sizeof(buffer)) == 0)
|
||||
{
|
||||
auto err = GetLastError();
|
||||
MIGRAPHX_THROW("Unable to read module file path, error = " + std::to_string(err));
|
||||
}
|
||||
if(GetLastError() == ERROR_INSUFFICIENT_BUFFER)
|
||||
{
|
||||
MIGRAPHX_THROW("Buffer too small (" + std::to_string(MAX_PATH) + ") to hold the path");
|
||||
}
|
||||
return {buffer};
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
|
||||
{
|
||||
try
|
||||
{
|
||||
return dynamic_loader{p};
|
||||
}
|
||||
catch(const std::exception&)
|
||||
{
|
||||
return nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
dynamic_loader::dynamic_loader(const fs::path& p) : impl(std::make_shared<dynamic_loader_impl>(p))
|
||||
{
|
||||
}
|
||||
|
||||
dynamic_loader::dynamic_loader(const char* image, std::size_t size)
|
||||
: impl(dynamic_loader_impl::from_buffer(image, size))
|
||||
{
|
||||
}
|
||||
|
||||
dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
|
||||
: impl(dynamic_loader_impl::from_buffer(buffer.data(), buffer.size()))
|
||||
{
|
||||
}
|
||||
|
||||
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
|
||||
{
|
||||
#ifndef _WIN32
|
||||
// flush any previous error messages
|
||||
check_load_error(true);
|
||||
void* symbol = dlsym(impl->handle.get(), name.c_str());
|
||||
if(symbol == nullptr)
|
||||
check_load_error();
|
||||
return {impl, symbol};
|
||||
#else
|
||||
FARPROC addr = GetProcAddress(impl->handle, name.c_str());
|
||||
if(addr == nullptr)
|
||||
MIGRAPHX_THROW("Symbol not found: " + name + " (" + std::to_string(GetLastError()) + ")");
|
||||
return {impl, reinterpret_cast<void*>(addr)};
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
66
docker/rocm/migraphx/eliminate_allocation.cpp
Normal file
66
docker/rocm/migraphx/eliminate_allocation.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_allocation.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/serialize.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void eliminate_allocation::apply(module& m) const
|
||||
{
|
||||
assert(alignment > 0);
|
||||
|
||||
std::size_t n = 0;
|
||||
std::vector<std::pair<instruction_ref, std::size_t>> allocs;
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != allocation_op)
|
||||
continue;
|
||||
allocs.emplace_back(ins, n);
|
||||
std::size_t size = ins->get_shape().bytes();
|
||||
std::size_t padding = (alignment - (size % alignment)) % alignment;
|
||||
n += size + padding;
|
||||
}
|
||||
if(n > 0)
|
||||
{
|
||||
auto mem = m.add_parameter("memory", shape{shape::int8_type, {n}});
|
||||
for(auto&& pp : allocs)
|
||||
{
|
||||
auto ins = pp.first;
|
||||
auto s = ins->get_shape();
|
||||
auto offset = pp.second;
|
||||
m.replace_instruction(
|
||||
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
76
docker/rocm/migraphx/eliminate_common_subexpression.cpp
Normal file
76
docker/rocm/migraphx/eliminate_common_subexpression.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_common_subexpression.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class Range>
|
||||
void cse_range(module& m, Range&& r)
|
||||
{
|
||||
std::unordered_multimap<std::string, instruction_ref> instructions;
|
||||
std::unordered_set<instruction_ref> processed_ins;
|
||||
for(auto ins : r)
|
||||
{
|
||||
// Skip dead instructions
|
||||
if(ins->outputs().empty())
|
||||
continue;
|
||||
|
||||
// Find instruction with the same name
|
||||
auto found_instructions = range(instructions.equal_range(ins->name()));
|
||||
for(const auto& pp : found_instructions)
|
||||
{
|
||||
auto eq = pp.second;
|
||||
if(contains(processed_ins, eq))
|
||||
continue;
|
||||
if(*eq != *ins)
|
||||
continue;
|
||||
m.replace_instruction(ins, eq);
|
||||
processed_ins.emplace(ins);
|
||||
std::vector<instruction_ref> outputs;
|
||||
std::copy_if(eq->outputs().begin(),
|
||||
eq->outputs().end(),
|
||||
std::back_inserter(outputs),
|
||||
[&](auto x) { return m.has_instruction(x); });
|
||||
|
||||
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
|
||||
return std::distance(eq, x) < std::distance(eq, y);
|
||||
});
|
||||
cse_range(m, outputs);
|
||||
}
|
||||
instructions.emplace(ins->name(), ins);
|
||||
}
|
||||
}
|
||||
|
||||
void eliminate_common_subexpression::apply(module& m) const { cse_range(m, iterator_for(m)); }
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
110
docker/rocm/migraphx/eliminate_concat.cpp
Normal file
110
docker/rocm/migraphx/eliminate_concat.cpp
Normal file
@ -0,0 +1,110 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <iterator>
|
||||
#include <migraphx/eliminate_concat.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/op/load.hpp>
|
||||
#include <migraphx/op/identity.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
#include <migraphx/dfor.hpp>
|
||||
#include <migraphx/tune_axis.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
void eliminate_concat::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
auto concat_op = concat_opt.get_concat(ins->get_operator());
|
||||
// Look for the concat operator
|
||||
if(not concat_op.has_value())
|
||||
continue;
|
||||
// If any inputs are builtin or context free then abort
|
||||
// If any inputs are used more than once, then abort since there could
|
||||
// be errors due to aliasing
|
||||
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto arg) {
|
||||
return arg->name().front() == '@' or
|
||||
(arg->get_operator().is_context_free() and
|
||||
not contains({"concat", "identity"}, arg->name())) or
|
||||
arg->outputs().size() > 1;
|
||||
}))
|
||||
continue;
|
||||
// We can only do this optimization when concat axis is either the leftmost
|
||||
// axis OR the sizes to the left of this axis are all equal to 1
|
||||
// Since we've already checked that the non-axis dimensions are identical
|
||||
// we only need to check the first input
|
||||
auto lens = ins->inputs().front()->get_shape().lens();
|
||||
std::size_t axis_index = tune_axis(lens.size(), concat_op->axis, concat_op->name());
|
||||
if(axis_index == 0 or
|
||||
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
|
||||
{
|
||||
// Last input should be an allocation
|
||||
auto last = ins->inputs().back();
|
||||
if(last->name() != concat_opt.allocate())
|
||||
continue;
|
||||
// Where are the allocations for the tensors to be concatenated?
|
||||
std::vector<instruction_ref> allocations;
|
||||
|
||||
std::transform(
|
||||
ins->inputs().begin(),
|
||||
std::prev(ins->inputs().end()),
|
||||
std::back_inserter(allocations),
|
||||
[&](instruction_ref x) { return instruction::get_output_alias(x, true); });
|
||||
|
||||
if(std::any_of(allocations.begin(), allocations.end(), [&](auto x) {
|
||||
return x->name() != concat_opt.allocate();
|
||||
}))
|
||||
continue;
|
||||
|
||||
// Need to sort the allocations, so that we know where to
|
||||
// insert the "super"-allocation
|
||||
auto sorted_allocations = allocations;
|
||||
std::sort(sorted_allocations.begin(),
|
||||
sorted_allocations.end(),
|
||||
[&](instruction_ref x, instruction_ref y) {
|
||||
return std::distance(m.begin(), x) < std::distance(m.begin(), y);
|
||||
});
|
||||
// Move "super" allocation to the front
|
||||
auto first = sorted_allocations.front();
|
||||
auto super = m.move_instruction(last, first);
|
||||
// Replace each allocation with a load
|
||||
std::size_t offset = 0;
|
||||
for(auto alloc : allocations)
|
||||
{
|
||||
op::load op{alloc->get_shape(), offset};
|
||||
m.replace_instruction(alloc, op, {super});
|
||||
offset += alloc->get_shape().bytes();
|
||||
}
|
||||
std::vector<instruction_ref> args = {super};
|
||||
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
|
||||
m.replace_instruction(ins, migraphx::make_op("identity"), args);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
195
docker/rocm/migraphx/eliminate_contiguous.cpp
Normal file
195
docker/rocm/migraphx/eliminate_contiguous.cpp
Normal file
@ -0,0 +1,195 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_contiguous.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/op/contiguous.hpp>
|
||||
#include <migraphx/op/identity.hpp>
|
||||
#include <migraphx/par_for.hpp>
|
||||
#include <utility>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS)
|
||||
|
||||
static bool try_compute_shape(instruction_ref ins,
|
||||
const std::vector<shape>& inputs,
|
||||
const std::vector<module_ref>& mods)
|
||||
{
|
||||
try
|
||||
{
|
||||
shape new_shape = ins->get_operator().compute_shape(inputs, mods);
|
||||
|
||||
// Cannot tell if a dynamic shape will need to be made contiguous
|
||||
if(new_shape.dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the output shape is a standard shape, no need to try its output
|
||||
if(new_shape.standard())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// if no changes for the shape, the contiguous can also be removed
|
||||
if(new_shape == ins->get_shape())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
auto outputs = ins->outputs();
|
||||
// If the current instruction has no output, it means it is the last
|
||||
// instruction and generates a non-standard output shape, and the last
|
||||
// output shape is different from the case with the contiguous operator
|
||||
if(outputs.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for(auto output : outputs)
|
||||
{
|
||||
auto args = output->inputs();
|
||||
std::vector<shape> input_shapes(args.size());
|
||||
std::transform(args.begin(), args.end(), input_shapes.begin(), [&](auto& arg) {
|
||||
return (arg == ins) ? new_shape : arg->get_shape();
|
||||
});
|
||||
|
||||
if(not try_compute_shape(output, input_shapes, output->module_inputs()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
|
||||
{
|
||||
std::cout << "Exception: " << e.what() << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
|
||||
{
|
||||
std::cout << "Unknown exception" << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool try_compute_shape(instruction_ref ins,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const std::vector<module_ref>& mods)
|
||||
{
|
||||
auto inputs = to_shapes(args);
|
||||
return try_compute_shape(ins, inputs, mods);
|
||||
}
|
||||
|
||||
template <class F>
|
||||
static void remove_contiguous(const std::string& op_name, module& m, F f)
|
||||
{
|
||||
auto last = std::prev(m.end());
|
||||
std::vector<instruction_ref> const_instructions;
|
||||
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
// return instruction should have inputs with standard shape
|
||||
if(ins->name() == "@return")
|
||||
continue;
|
||||
|
||||
if(ins != last and ins->outputs().empty())
|
||||
continue;
|
||||
|
||||
if(not f(ins))
|
||||
continue;
|
||||
|
||||
auto args = ins->inputs();
|
||||
auto mod_args = ins->module_inputs();
|
||||
|
||||
for(auto arg : ins->inputs())
|
||||
{
|
||||
if(arg->name() != op_name)
|
||||
continue;
|
||||
if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{}))
|
||||
{
|
||||
std::cout << "eliminate_contiguous: ";
|
||||
m.debug_print(ins);
|
||||
}
|
||||
auto prev = arg->inputs().front();
|
||||
// create copy of args each time as they are modified inside the loop
|
||||
auto new_args = ins->inputs();
|
||||
replace(new_args, arg, prev);
|
||||
if(try_compute_shape(ins, new_args, mod_args))
|
||||
{
|
||||
instruction::replace_argument(ins, arg, prev);
|
||||
}
|
||||
else if(prev->can_eval())
|
||||
{
|
||||
const_instructions.push_back(arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform static contiguous evaluations in parallel
|
||||
std::vector<argument> literals(const_instructions.size());
|
||||
par_for(const_instructions.size(), 1, [&](const auto i) {
|
||||
auto c = op::contiguous{};
|
||||
auto prev = const_instructions[i]->inputs().front();
|
||||
// compute the output contiguous shape from the previous instruction shape
|
||||
shape computed_shape = c.compute_shape({prev->get_shape()});
|
||||
const std::vector<argument>& prev_eval = {prev->eval()};
|
||||
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
|
||||
auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
|
||||
literals[i] = c.compute(co_shape, prev_eval);
|
||||
});
|
||||
|
||||
// Replace static contiguous operations with a literal
|
||||
for(size_t i = 0; i < const_instructions.size(); i++)
|
||||
{
|
||||
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
|
||||
m.replace_instruction(const_instructions[i], l);
|
||||
}
|
||||
}
|
||||
|
||||
void eliminate_contiguous::apply(module& m) const
|
||||
{
|
||||
// Skip contiguous from splits first
|
||||
remove_contiguous(op_name, m, [](auto ins) {
|
||||
if(ins->name() != "slice")
|
||||
return true;
|
||||
return (ins->inputs().front()->outputs().size() == 1);
|
||||
});
|
||||
remove_contiguous(op_name, m, [](auto) { return true; });
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
81
docker/rocm/migraphx/eliminate_convert.cpp
Normal file
81
docker/rocm/migraphx/eliminate_convert.cpp
Normal file
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
,*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_convert.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* Matches with some sequence of sequential convert instructions.
|
||||
* If input to the sequence of converts has the same shape as the last convert,
|
||||
* replace last convert with the input.
|
||||
* If input to the sequence is not the same shape as the last convert,
|
||||
* replace last convert with convert from the input to the last shape.
|
||||
*/
|
||||
struct find_nested_convert
|
||||
{
|
||||
auto matcher() const { return match::name("convert")(match::arg(0)(match::name("convert"))); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto matched_ins = mr.result;
|
||||
auto prev_convert = matched_ins->inputs().front();
|
||||
auto input = prev_convert->inputs().front();
|
||||
while(input->name() == "convert")
|
||||
{
|
||||
input = input->inputs().front();
|
||||
}
|
||||
if(matched_ins->get_shape() == input->get_shape())
|
||||
{
|
||||
m.replace_instruction(matched_ins, input);
|
||||
}
|
||||
else
|
||||
{
|
||||
m.replace_instruction(matched_ins, matched_ins->get_operator(), input);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct find_nop_converts
|
||||
{
|
||||
auto matcher() const { return match::name("convert")(match::same_shape(match::arg(0))); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
m.replace_instruction(ins, ins->inputs().front());
|
||||
}
|
||||
};
|
||||
|
||||
void eliminate_convert::apply(module& m) const
|
||||
{
|
||||
match::find_matches(m, find_nested_convert{}, find_nop_converts{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
127
docker/rocm/migraphx/eliminate_data_type.cpp
Normal file
127
docker/rocm/migraphx/eliminate_data_type.cpp
Normal file
@ -0,0 +1,127 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_data_type.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void insert_convert_to_supported_type(module& m,
|
||||
instruction_ref ins,
|
||||
migraphx::shape::type_t target_type,
|
||||
std::set<migraphx::shape::type_t> unsupported_types)
|
||||
{
|
||||
migraphx::shape::type_t orig_type = ins->get_shape().type();
|
||||
std::vector<instruction_ref> inputs = ins->inputs();
|
||||
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) {
|
||||
if(contains(unsupported_types, i->get_shape().type()))
|
||||
{
|
||||
return m.insert_instruction(
|
||||
ins,
|
||||
migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}),
|
||||
i);
|
||||
}
|
||||
else
|
||||
{
|
||||
return i;
|
||||
}
|
||||
});
|
||||
// if no change
|
||||
if(inputs == ins->inputs())
|
||||
return;
|
||||
auto op = ins->get_operator();
|
||||
auto attributes = op.attributes();
|
||||
if(attributes.contains("general_data_type"))
|
||||
{
|
||||
op = make_op(attributes["general_data_type"].to<std::string>(), op.to_value());
|
||||
}
|
||||
auto new_ins = m.insert_instruction(ins, op, inputs);
|
||||
if(orig_type == shape::tuple_type)
|
||||
{
|
||||
auto orig_outs = ins->outputs();
|
||||
if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) {
|
||||
return out_ins->name() == "get_tuple_elem";
|
||||
}))
|
||||
MIGRAPHX_THROW(
|
||||
"eliminate_data_type: Instruction with tuple output doesn't have all its "
|
||||
"usages as get_tuple_elem instruction");
|
||||
|
||||
std::transform(
|
||||
orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) {
|
||||
auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins);
|
||||
auto orig_out_type = out_ins->get_shape().type();
|
||||
if(contains(unsupported_types, orig_out_type))
|
||||
{
|
||||
auto gte_convert = m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins);
|
||||
return m.replace_instruction(out_ins, gte_convert);
|
||||
}
|
||||
else
|
||||
{
|
||||
return m.replace_instruction(out_ins, gte_ins);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
auto convert_back_ins = m.insert_instruction(
|
||||
ins,
|
||||
migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}),
|
||||
new_ins);
|
||||
m.replace_instruction(ins, convert_back_ins);
|
||||
}
|
||||
}
|
||||
|
||||
void eliminate_data_type::apply(module& m) const
|
||||
{
|
||||
static const std::vector<std::string> skip_op_names = {"convert",
|
||||
"get_tuple_elem",
|
||||
"if",
|
||||
"loop",
|
||||
"roialign",
|
||||
"nonmaxsuppression",
|
||||
"scatternd_add",
|
||||
"scatternd_mul",
|
||||
"scatternd_none",
|
||||
"select_module"};
|
||||
if(unsupported_types.empty())
|
||||
return;
|
||||
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name()[0] == '@')
|
||||
continue;
|
||||
if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name()))
|
||||
continue;
|
||||
if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name()))
|
||||
insert_convert_to_supported_type(m, ins, target_type, unsupported_types);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
70
docker/rocm/migraphx/eliminate_identity.cpp
Normal file
70
docker/rocm/migraphx/eliminate_identity.cpp
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_identity.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <utility>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void eliminate_identity::apply(module& m) const
|
||||
{
|
||||
auto last = std::prev(m.end());
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
// Skip the first instruction, since we always process the previous
|
||||
// instruction
|
||||
if(ins == m.begin())
|
||||
continue;
|
||||
const auto i = std::prev(ins);
|
||||
|
||||
if(i->name() == "identity")
|
||||
{
|
||||
m.replace_instruction(i, i->inputs().front());
|
||||
m.move_instruction(i, m.end());
|
||||
}
|
||||
if(ins == last)
|
||||
{
|
||||
if(ins->name() == "identity")
|
||||
{
|
||||
const instruction_ref& identity_input = ins->inputs().front();
|
||||
if(identity_input->outputs().size() == 1)
|
||||
{
|
||||
m.move_instruction(identity_input, last);
|
||||
// since this is the last instruction, removing it only
|
||||
// requires changing "last" and calling remove below
|
||||
last = std::prev(last);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
m.remove_instructions(std::next(last), m.end());
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
114
docker/rocm/migraphx/eliminate_pad.cpp
Normal file
114
docker/rocm/migraphx/eliminate_pad.cpp
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/eliminate_pad.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/op/convolution.hpp>
|
||||
#include <migraphx/op/im2col.hpp>
|
||||
#include <migraphx/op/pooling.hpp>
|
||||
#include <migraphx/op/pad.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
|
||||
{
|
||||
auto pad_op = any_cast<op::pad>(input->get_operator());
|
||||
|
||||
auto kdims = input->get_shape().lens().size() - 2;
|
||||
auto kdims_it = pad_op.pads.begin() + 2;
|
||||
|
||||
std::vector<size_t> pads_l(kdims_it, kdims_it + kdims);
|
||||
std::vector<size_t> pads_r(kdims_it + kdims + 2, pad_op.pads.end());
|
||||
|
||||
auto op = ins->get_operator();
|
||||
std::vector<size_t> padding(kdims * 2, 0);
|
||||
|
||||
std::transform(
|
||||
pads_l.begin(), pads_l.end(), padding.begin(), padding.begin(), std::plus<size_t>());
|
||||
std::transform(pads_r.begin(),
|
||||
pads_r.end(),
|
||||
padding.begin() + kdims,
|
||||
padding.begin() + kdims,
|
||||
std::plus<size_t>());
|
||||
|
||||
op.from_value({{"padding", padding}});
|
||||
|
||||
std::vector<instruction_ref> new_inputs{ins->inputs()};
|
||||
new_inputs.front() = input->inputs().front();
|
||||
|
||||
m.replace_instruction(ins, op, new_inputs);
|
||||
}
|
||||
|
||||
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
|
||||
{
|
||||
auto op = any_cast<op::pooling>(ins->get_operator());
|
||||
if(op.mode == op::pooling_mode::average)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto pad_op = any_cast<op::pad>(input->get_operator());
|
||||
|
||||
auto kdims = input->get_shape().lens().size() - 2;
|
||||
auto kdims_it = pad_op.pads.begin() + 2;
|
||||
|
||||
std::vector<size_t> pads_l(kdims_it, kdims_it + kdims);
|
||||
std::vector<size_t> pads_r(kdims_it + kdims + 2, pad_op.pads.end());
|
||||
|
||||
std::transform(
|
||||
pads_l.begin(), pads_l.end(), op.padding.begin(), op.padding.begin(), std::plus<size_t>());
|
||||
std::transform(pads_r.begin(),
|
||||
pads_r.end(),
|
||||
op.padding.begin() + kdims,
|
||||
op.padding.begin() + kdims,
|
||||
std::plus<size_t>());
|
||||
|
||||
std::vector<instruction_ref> new_inputs{ins->inputs()};
|
||||
new_inputs.front() = input->inputs().front();
|
||||
|
||||
m.replace_instruction(ins, op, new_inputs);
|
||||
}
|
||||
|
||||
void eliminate_pad::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
const std::string& op_name = ins->name();
|
||||
if(op_name != "convolution" and op_name != "im2col" and op_name != "pooling")
|
||||
continue;
|
||||
auto input = ins->inputs().front();
|
||||
if(input->name() != "pad")
|
||||
continue;
|
||||
if(op_name == "convolution" or op_name == "im2col")
|
||||
update_op(input, ins, m);
|
||||
else if(op_name == "pooling")
|
||||
update_pooling(input, ins, m);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
73
docker/rocm/migraphx/env.cpp
Normal file
73
docker/rocm/migraphx/env.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/env.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <cstdlib>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
bool enabled(const char* name)
|
||||
{
|
||||
auto e = env(name);
|
||||
if(e.empty())
|
||||
return false;
|
||||
return contains({"1", "enable", "enabled", "yes", "true"}, e.front());
|
||||
}
|
||||
|
||||
bool disabled(const char* name)
|
||||
{
|
||||
auto e = env(name);
|
||||
if(e.empty())
|
||||
return false;
|
||||
return contains({"0", "disable", "disabled", "no", "false"}, e.front());
|
||||
}
|
||||
|
||||
std::size_t value_of(const char* name, std::size_t fallback)
|
||||
{
|
||||
auto e = env(name);
|
||||
if(e.empty())
|
||||
return fallback;
|
||||
return std::stoul(e.front());
|
||||
}
|
||||
|
||||
std::string string_value_of(const char* name, std::string fallback)
|
||||
{
|
||||
auto e = env(name);
|
||||
if(e.empty())
|
||||
return fallback;
|
||||
return e.front();
|
||||
}
|
||||
|
||||
std::vector<std::string> env(const char* name)
|
||||
{
|
||||
auto* p = std::getenv(name);
|
||||
if(p == nullptr)
|
||||
return {};
|
||||
else
|
||||
return {{p}};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
85
docker/rocm/migraphx/file_buffer.cpp
Normal file
85
docker/rocm/migraphx/file_buffer.cpp
Normal file
@ -0,0 +1,85 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/file_buffer.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/fileutils.hpp>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class T>
|
||||
T generic_read_file(const fs::path& filename, size_t offset = 0, size_t nbytes = 0)
|
||||
{
|
||||
std::ifstream is(filename, std::ios::binary | std::ios::ate);
|
||||
if(not is.is_open())
|
||||
MIGRAPHX_THROW("Failure opening file: " + filename);
|
||||
if(nbytes == 0)
|
||||
{
|
||||
// if there is a non-zero offset and nbytes is not set,
|
||||
// calculate size of remaining bytes to read
|
||||
nbytes = is.tellg();
|
||||
if(offset > nbytes)
|
||||
MIGRAPHX_THROW("offset is larger than file size");
|
||||
nbytes -= offset;
|
||||
}
|
||||
if(nbytes < 1)
|
||||
MIGRAPHX_THROW("Invalid size for: " + filename);
|
||||
is.seekg(offset, std::ios::beg);
|
||||
|
||||
T buffer(nbytes, 0);
|
||||
if(not is.read(&buffer[0], nbytes))
|
||||
MIGRAPHX_THROW("Error reading file: " + filename);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::vector<char> read_buffer(const fs::path& filename, size_t offset, size_t nbytes)
|
||||
{
|
||||
return generic_read_file<std::vector<char>>(filename, offset, nbytes);
|
||||
}
|
||||
|
||||
std::string read_string(const fs::path& filename)
|
||||
{
|
||||
return generic_read_file<std::string>(filename);
|
||||
}
|
||||
|
||||
void write_string(const fs::path& filename, const std::string& buffer)
|
||||
{
|
||||
write_buffer(filename, buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
void write_buffer(const fs::path& filename, const char* buffer, std::size_t size)
|
||||
{
|
||||
std::ofstream os(filename, std::ios::out | std::ios::binary);
|
||||
os.write(buffer, size);
|
||||
}
|
||||
|
||||
void write_buffer(const fs::path& filename, const std::vector<char>& buffer)
|
||||
{
|
||||
write_buffer(filename, buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
70
docker/rocm/migraphx/fileutils.cpp
Normal file
70
docker/rocm/migraphx/fileutils.cpp
Normal file
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/fileutils.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
#ifdef _WIN32
|
||||
constexpr std::string_view executable_postfix{".exe"};
|
||||
constexpr std::string_view library_prefix{""};
|
||||
constexpr std::string_view library_postfix{".dll"};
|
||||
constexpr std::string_view static_library_postfix{".lib"};
|
||||
constexpr std::string_view object_file_postfix{".obj"};
|
||||
#else
|
||||
constexpr std::string_view executable_postfix{""};
|
||||
constexpr std::string_view library_prefix{"lib"};
|
||||
constexpr std::string_view library_postfix{".so"};
|
||||
constexpr std::string_view static_library_postfix{".a"};
|
||||
constexpr std::string_view object_file_postfix{".o"};
|
||||
#endif
|
||||
|
||||
fs::path make_executable_filename(std::string_view name)
|
||||
{
|
||||
return std::string{name}.append(executable_postfix);
|
||||
}
|
||||
|
||||
fs::path make_shared_object_filename(std::string_view name)
|
||||
{
|
||||
return std::string{library_prefix}.append(name).append(library_postfix);
|
||||
}
|
||||
|
||||
fs::path make_object_file_filename(std::string_view name)
|
||||
{
|
||||
return std::string{name}.append(object_file_postfix);
|
||||
}
|
||||
|
||||
fs::path make_static_library_filename(std::string_view name)
|
||||
{
|
||||
return std::string{library_prefix}.append(name).append(static_library_postfix);
|
||||
}
|
||||
|
||||
fs::path append_extension(const fs::path& path, std::string_view ext)
|
||||
{
|
||||
return fs::path{path}.replace_extension(path.extension().string().append(ext));
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
178
docker/rocm/migraphx/fp8_ocp_to_fnuz.cpp
Normal file
178
docker/rocm/migraphx/fp8_ocp_to_fnuz.cpp
Normal file
@ -0,0 +1,178 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/fp8_ocp_to_fnuz.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/match/dq_helpers.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
namespace {
|
||||
|
||||
using fp8::fp8e4m3fnuz;
|
||||
|
||||
std::unordered_set<std::string> get_quantizable_op_names()
|
||||
{
|
||||
static std::unordered_set<std::string> s = {"convolution", "dot"};
|
||||
return s;
|
||||
}
|
||||
|
||||
struct match_fp8ocp_convert_to_fp8fnuz
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto dq1 = match::arg(0)(
|
||||
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
|
||||
auto dq2 = match::arg(1)(
|
||||
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
|
||||
return match::name(get_quantizable_op_names())(dq1, dq2);
|
||||
}
|
||||
|
||||
static auto bit_cast_and_handle_specials(module& m,
|
||||
const instruction_ref dq,
|
||||
const instruction_ref x,
|
||||
const instruction_ref bits_0x80_lit,
|
||||
const instruction_ref bits_0x7f_lit,
|
||||
const instruction_ref bits_0xff_lit,
|
||||
const instruction_ref bits_0x00_lit)
|
||||
{
|
||||
auto x_lens = x->get_shape().lens();
|
||||
auto cast_input = m.insert_instruction(
|
||||
dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x);
|
||||
auto mb_bits_0x80_lit = m.insert_instruction(
|
||||
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit);
|
||||
auto mb_bits_0x7f_lit = m.insert_instruction(
|
||||
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit);
|
||||
auto mb_bits_0xff_lit = m.insert_instruction(
|
||||
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit);
|
||||
auto mb_zero_lit = m.insert_instruction(
|
||||
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit);
|
||||
// negative zero in fp8e4m3fn to zero in fp8e4m3fnuz
|
||||
// a == 0x80 ? 0x0 : a
|
||||
auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit);
|
||||
auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input);
|
||||
|
||||
// positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz
|
||||
// (a == 0x7f or a == 0xff) ? 0x80 : a
|
||||
auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit);
|
||||
|
||||
auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit);
|
||||
|
||||
auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff);
|
||||
ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Add the same broadcast instructions after adjusted scales or
|
||||
// adjusted zero points from after the originals. Similar to
|
||||
// propagate_quantized_ins in simplify_qdq.
|
||||
static auto propagate_broadcasts(module& m,
|
||||
const instruction_ref adj,
|
||||
const instruction_ref ori,
|
||||
const instruction_ref start,
|
||||
const instruction_ref insert_pt)
|
||||
{
|
||||
auto prev_ins = start;
|
||||
std::vector<instruction_ref> ins_between;
|
||||
// matcher skips continguous, multi/broadcasts and transposes, collect all those
|
||||
// instructions
|
||||
while(prev_ins != ori)
|
||||
{
|
||||
ins_between.push_back(prev_ins);
|
||||
prev_ins = prev_ins->inputs().front();
|
||||
}
|
||||
auto ret = adj;
|
||||
for(auto ins : reverse_iterator_for(ins_between))
|
||||
{
|
||||
ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret});
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static auto cast_to_fnuz(module& m,
|
||||
const instruction_ref dq,
|
||||
const instruction_ref input,
|
||||
const instruction_ref dq_scale,
|
||||
const instruction_ref dq_zp)
|
||||
{
|
||||
auto x = input;
|
||||
std::vector<fp8e4m3fnuz> bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())};
|
||||
auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80);
|
||||
|
||||
std::vector<fp8e4m3fnuz> bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())};
|
||||
auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f);
|
||||
|
||||
std::vector<fp8e4m3fnuz> bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())};
|
||||
auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff);
|
||||
|
||||
std::vector<fp8e4m3fnuz> bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())};
|
||||
auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00);
|
||||
|
||||
x = bit_cast_and_handle_specials(
|
||||
m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);
|
||||
auto adj_dq_zp = bit_cast_and_handle_specials(
|
||||
m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);
|
||||
|
||||
// adj_scale = 2 * scale
|
||||
auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}});
|
||||
two_lit = m.insert_instruction(
|
||||
dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit);
|
||||
auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit);
|
||||
|
||||
adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq);
|
||||
adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq);
|
||||
m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp);
|
||||
}
|
||||
|
||||
auto apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto dq1 = r.instructions["dq1"];
|
||||
auto dq2 = r.instructions["dq2"];
|
||||
auto scale1 = r.instructions["scale1"];
|
||||
auto scale2 = r.instructions["scale2"];
|
||||
auto zp1 = r.instructions["zp1"];
|
||||
auto zp2 = r.instructions["zp2"];
|
||||
|
||||
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fn_type};
|
||||
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
|
||||
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
|
||||
return;
|
||||
|
||||
cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1);
|
||||
cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
module_ref mm = &mpm.get_module();
|
||||
match::find_matches(*mm, match_fp8ocp_convert_to_fp8fnuz{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
40
docker/rocm/migraphx/fp_to_double.cpp
Normal file
40
docker/rocm/migraphx/fp_to_double.cpp
Normal file
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/fp_to_double.hpp>
|
||||
#include <migraphx/eliminate_data_type.hpp>
|
||||
#include <migraphx/eliminate_convert.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void fp_to_double::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
mpm.run_pass(eliminate_data_type{convert_fp_types, shape::type_t::double_type});
|
||||
mpm.run_pass(eliminate_convert{});
|
||||
mpm.run_pass(migraphx::dead_code_elimination{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
242
docker/rocm/migraphx/fuse_concat.cpp
Normal file
242
docker/rocm/migraphx/fuse_concat.cpp
Normal file
@ -0,0 +1,242 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/fuse_concat.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/algorithm.hpp>
|
||||
#include <migraphx/check_shapes.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
unsigned int get_noop_counter()
|
||||
{
|
||||
static unsigned int counter = 0;
|
||||
return counter++;
|
||||
}
|
||||
|
||||
struct fused_concat
|
||||
{
|
||||
int64_t axis = 0;
|
||||
|
||||
std::string name() const { return "fused_concat"; }
|
||||
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.axis, "axis"));
|
||||
}
|
||||
|
||||
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
|
||||
{
|
||||
check_shapes{inputs, *this}.same_ndims();
|
||||
// original concat can have multiple inputs. Let's say it has `n` input args.
|
||||
// Each of those `n` input args are converted into pointwise modules that take atleast 1
|
||||
// input parameter. Fused concat will have `n+1` module arguments. `n+1`th module is the
|
||||
// post pointwise module which can take 0 or more input arguments.
|
||||
if((inputs.size() + 1) < mods.size())
|
||||
MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules inputs parameters");
|
||||
auto input_iter = inputs.begin();
|
||||
std::vector<shape> concat_inputs;
|
||||
for(module_ref mod : range(mods.begin(), mods.end() - 1))
|
||||
{
|
||||
concat_inputs.push_back(*input_iter);
|
||||
input_iter += mod->get_parameter_names().size();
|
||||
}
|
||||
module_ref post_mod = mods.back();
|
||||
// post_mod has one input argument that is result of concat and will get generated from
|
||||
// pre-mods internally. Therefore deduct 1 from post_mod params while asserting.
|
||||
assert(input_iter + post_mod->get_parameter_names().size() - 1 == inputs.end());
|
||||
auto type = std::prev(post_mod->end())->get_shape().type();
|
||||
const auto& first_shape_lens = concat_inputs.front().lens();
|
||||
auto mismatch_it =
|
||||
std::find_if_not(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
|
||||
const auto& lens = s.lens();
|
||||
return std::equal(lens.begin(),
|
||||
lens.begin() + axis,
|
||||
first_shape_lens.begin(),
|
||||
first_shape_lens.begin() + axis) and
|
||||
std::equal(lens.begin() + axis + 1,
|
||||
lens.end(),
|
||||
first_shape_lens.begin() + axis + 1,
|
||||
first_shape_lens.end());
|
||||
});
|
||||
if(mismatch_it != concat_inputs.end())
|
||||
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis of " +
|
||||
std::to_string(axis) + ": {" + to_string_range(first_shape_lens) +
|
||||
"} != {" + to_string_range(mismatch_it->lens()) + "}");
|
||||
|
||||
std::size_t new_dim_axis = transform_accumulate(
|
||||
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
|
||||
return input.lens()[axis];
|
||||
});
|
||||
auto new_lens = concat_inputs.front().lens();
|
||||
new_lens[axis] = new_dim_axis;
|
||||
return shape::from_permutation(type, new_lens, find_permutation(inputs));
|
||||
}
|
||||
};
|
||||
MIGRAPHX_REGISTER_OP(fused_concat);
|
||||
|
||||
namespace {
|
||||
struct find_concat_pointwise
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto pointwise_used_once = match::name("pointwise")(match::used_once());
|
||||
return match::name("concat")(match::used_once(),
|
||||
match::any_of[match::inputs()](pointwise_used_once));
|
||||
}
|
||||
|
||||
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
||||
{
|
||||
auto concat_ins = r.result;
|
||||
|
||||
std::vector<instruction_ref> inputs;
|
||||
size_t num_noops = 0;
|
||||
for(auto input : concat_ins->inputs())
|
||||
{
|
||||
if(input->name() == "pointwise" and input->outputs().size() == 1)
|
||||
{
|
||||
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
|
||||
}
|
||||
else
|
||||
{
|
||||
num_noops++;
|
||||
inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
if(num_noops > std::max(size_t{1}, concat_ins->inputs().size() / 4))
|
||||
{
|
||||
return;
|
||||
}
|
||||
std::vector<module_ref> module_inputs;
|
||||
std::transform(concat_ins->inputs().begin(),
|
||||
concat_ins->inputs().end(),
|
||||
std::back_inserter(module_inputs),
|
||||
[&](instruction_ref input) {
|
||||
if(input->name() == "pointwise" and input->outputs().size() == 1)
|
||||
{
|
||||
auto* pm = input->module_inputs().front();
|
||||
return mpm.create_module("concat:" + pm->name(), *pm);
|
||||
}
|
||||
auto* pm = mpm.create_module("concat:noop" +
|
||||
std::to_string(get_noop_counter()));
|
||||
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
|
||||
pm->add_return({x});
|
||||
return pm;
|
||||
});
|
||||
auto* post_pm = mpm.create_module("noop:concat" + std::to_string(get_noop_counter()));
|
||||
auto x = post_pm->add_parameter("!x0", shape{concat_ins->get_shape().type()});
|
||||
post_pm->add_return({x});
|
||||
module_inputs.push_back(post_pm);
|
||||
mpm.get_module().replace_instruction(
|
||||
concat_ins,
|
||||
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
|
||||
inputs,
|
||||
module_inputs);
|
||||
}
|
||||
};
|
||||
|
||||
struct find_pointwise_concat_pointwise
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto pointwise = match::name("pointwise")(match::used_once());
|
||||
auto concat =
|
||||
match::name("concat")(match::used_once(), match::any_of[match::inputs()](pointwise));
|
||||
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
|
||||
}
|
||||
|
||||
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto concat_ins = r.instructions["concat"];
|
||||
|
||||
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
|
||||
ins->inputs().begin();
|
||||
std::vector<instruction_ref> inputs;
|
||||
for(auto input : concat_ins->inputs())
|
||||
{
|
||||
if(input->name() == "pointwise" and input->outputs().size() == 1)
|
||||
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
|
||||
else
|
||||
inputs.push_back(input);
|
||||
}
|
||||
std::copy_if(ins->inputs().begin(),
|
||||
ins->inputs().end(),
|
||||
std::back_inserter(inputs),
|
||||
[&](auto input) { return input != concat_ins; });
|
||||
|
||||
std::vector<module_ref> module_inputs;
|
||||
std::transform(concat_ins->inputs().begin(),
|
||||
concat_ins->inputs().end(),
|
||||
std::back_inserter(module_inputs),
|
||||
[&](instruction_ref input) {
|
||||
if(input->name() == "pointwise" and input->outputs().size() == 1)
|
||||
{
|
||||
auto* pm = input->module_inputs().front();
|
||||
return mpm.create_module("concat:" + pm->name(), *pm);
|
||||
}
|
||||
auto* pm = mpm.create_module("concat:noop" +
|
||||
std::to_string(get_noop_counter()));
|
||||
auto x = pm->add_parameter("x0", shape{input->get_shape().type()});
|
||||
pm->add_return({x});
|
||||
return pm;
|
||||
});
|
||||
|
||||
auto* post_pm = ins->module_inputs().front();
|
||||
auto* rm = mpm.create_module(post_pm->name() + ":concat", *post_pm);
|
||||
std::vector<std::string> names = rm->get_parameter_names();
|
||||
std::sort(names.begin(), names.end());
|
||||
auto concat_param_name = names[concat_arg];
|
||||
auto concat_param = rm->get_parameter(concat_param_name);
|
||||
auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
|
||||
rm->replace_instruction(concat_param, param);
|
||||
rm->remove_instruction(concat_param);
|
||||
|
||||
module_inputs.push_back(rm);
|
||||
mpm.get_module().replace_instruction(
|
||||
ins,
|
||||
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
|
||||
inputs,
|
||||
module_inputs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void fuse_concat::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
match::find_matches(mpm, find_pointwise_concat_pointwise{});
|
||||
mpm.run_pass(migraphx::dead_code_elimination{});
|
||||
match::find_matches(mpm, find_concat_pointwise{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
262
docker/rocm/migraphx/fuse_pointwise.cpp
Normal file
262
docker/rocm/migraphx/fuse_pointwise.cpp
Normal file
@ -0,0 +1,262 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/fuse_pointwise.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/eliminate_identity.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/param_utils.hpp>
|
||||
#include <migraphx/rewrite_reshapes.hpp>
|
||||
#include <iterator>
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static literal get_scalar(instruction_ref ins)
|
||||
{
|
||||
if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
|
||||
return get_scalar(ins->inputs().front());
|
||||
const auto& s = ins->get_shape();
|
||||
if(s.elements() != 1 and not(s.scalar()))
|
||||
return {};
|
||||
if(not ins->can_eval())
|
||||
return {};
|
||||
auto e = ins->eval();
|
||||
literal r{};
|
||||
// needed for bool as visit_at invokes as() which promotes bool to int8
|
||||
// Without this we'll break type checks for logical ops that are fused.
|
||||
if(e.get_shape().type() == shape::bool_type)
|
||||
{
|
||||
r = literal{e.at<bool>()};
|
||||
}
|
||||
else
|
||||
{
|
||||
e.visit_at([&](auto x) { r = literal{x}; });
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
static void create_pointwise_modules(module_pass_manager& mpm)
|
||||
{
|
||||
std::size_t n = 0;
|
||||
for(auto ins : iterator_for(mpm.get_module()))
|
||||
{
|
||||
if(not ins->get_operator().attributes().get("pointwise", false))
|
||||
continue;
|
||||
if(ins->get_operator().name() == "layout")
|
||||
continue;
|
||||
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
|
||||
pm->set_bypass();
|
||||
|
||||
std::unordered_map<instruction_ref, instruction_ref> param_map;
|
||||
std::vector<instruction_ref> pointwise_inputs;
|
||||
std::size_t i = 0;
|
||||
|
||||
for(auto input : ins->inputs())
|
||||
{
|
||||
if(contains(param_map, input))
|
||||
continue;
|
||||
auto scalar = get_scalar(input);
|
||||
if(scalar.empty())
|
||||
{
|
||||
pointwise_inputs.push_back(input);
|
||||
param_map[input] =
|
||||
pm->add_parameter(param_name(i), shape{input->get_shape().type()});
|
||||
i++;
|
||||
}
|
||||
else
|
||||
{
|
||||
param_map[input] = pm->add_literal(scalar);
|
||||
}
|
||||
}
|
||||
|
||||
// Don't create pointwise module if no inputs are detected
|
||||
if(pointwise_inputs.empty())
|
||||
continue;
|
||||
|
||||
std::vector<instruction_ref> inputs;
|
||||
std::transform(ins->inputs().begin(),
|
||||
ins->inputs().end(),
|
||||
std::back_inserter(inputs),
|
||||
[&](auto input) { return param_map[input]; });
|
||||
auto r = pm->add_instruction(ins->get_operator(), inputs);
|
||||
pm->add_return({r});
|
||||
|
||||
mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm});
|
||||
}
|
||||
}
|
||||
|
||||
static module::with_inputs append_pointwise_module(instruction_ref ins, instruction_ref output)
|
||||
{
|
||||
assert(contains(output->inputs(), ins));
|
||||
module pm = *ins->module_inputs().at(0);
|
||||
module_ref xm = output->module_inputs().at(0);
|
||||
|
||||
auto last = std::prev(pm.end());
|
||||
assert(last->name() == "@return");
|
||||
assert(last->inputs().size() == 1);
|
||||
|
||||
assert(pm.get_parameter_names().size() == ins->inputs().size());
|
||||
assert(xm->get_parameter_names().size() == output->inputs().size());
|
||||
|
||||
std::vector<instruction_ref> inputs = ins->inputs();
|
||||
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
||||
std::unordered_map<instruction_ref, instruction_ref> input_map;
|
||||
// Copy inputs to input_map
|
||||
for(auto i : range(inputs.size()))
|
||||
{
|
||||
auto input = inputs[i];
|
||||
auto param = pm.get_parameter(param_name(i));
|
||||
assert(param != pm.end());
|
||||
input_map[input] = param;
|
||||
}
|
||||
// Add the new parameter and additional inputs
|
||||
for(auto i : range(output->inputs().size()))
|
||||
{
|
||||
auto input = output->inputs()[i];
|
||||
auto param = xm->get_parameter(param_name(i));
|
||||
assert(param != xm->end());
|
||||
if(input == ins)
|
||||
{
|
||||
map_ins[param] = last->inputs().front();
|
||||
input_map[input] = map_ins[param];
|
||||
}
|
||||
// Avoid duplicate paramter inputs
|
||||
else if(contains(input_map, input))
|
||||
{
|
||||
map_ins[param] = input_map[input];
|
||||
}
|
||||
else
|
||||
{
|
||||
map_ins[param] =
|
||||
pm.add_parameter(param_name(inputs.size()), {input->get_shape().type()});
|
||||
inputs.push_back(input);
|
||||
input_map[input] = map_ins[param];
|
||||
}
|
||||
}
|
||||
pm.replace_return(pm.insert_instructions(last, xm, &map_ins));
|
||||
return {std::move(pm), inputs};
|
||||
}
|
||||
|
||||
static bool find_pointwise_modules(module_pass_manager& mpm)
|
||||
{
|
||||
bool changed = false;
|
||||
auto last = std::prev(mpm.get_module().end());
|
||||
for(auto ins : iterator_for(mpm.get_module()))
|
||||
{
|
||||
if(ins->name() != "pointwise")
|
||||
continue;
|
||||
if(ins->outputs().empty() and ins != last)
|
||||
continue;
|
||||
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
|
||||
return i->name() == "pointwise" and i->outputs().size() == 1;
|
||||
});
|
||||
if(it == ins->inputs().end())
|
||||
continue;
|
||||
auto input = *it;
|
||||
|
||||
auto fused = append_pointwise_module(input, ins);
|
||||
auto name = fused.mod.name();
|
||||
mpm.rename_module(name, name + ":" + ins->module_inputs().front()->name() + "-deleted");
|
||||
auto* new_pm = mpm.create_module(name, std::move(fused.mod));
|
||||
mpm.get_module().replace_instruction(ins, input->get_operator(), fused.inputs, {new_pm});
|
||||
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct pointwise_reshape : rewrite_reshapes_base
|
||||
{
|
||||
static std::string name() { return "pointwise"; }
|
||||
};
|
||||
|
||||
struct pointwise_broadcast_pointwise
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto broadcast_pointwise =
|
||||
match::name("multibroadcast")(
|
||||
match::used_once(),
|
||||
match::args(match::name("pointwise")(match::used_once()).bind("x")))
|
||||
.bind("broadcast");
|
||||
return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto broadcast_ins = r.instructions["broadcast"];
|
||||
auto x_ins = r.instructions["x"];
|
||||
|
||||
auto broadcast = broadcast_ins->get_operator();
|
||||
|
||||
auto x_inputs = x_ins->inputs();
|
||||
std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), [&](auto input) {
|
||||
return m.insert_instruction(broadcast_ins, broadcast, input);
|
||||
});
|
||||
|
||||
m.replace_instruction(
|
||||
broadcast_ins, x_ins->get_operator(), x_inputs, x_ins->module_inputs());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static void rewrite_broadcasts(module_pass_manager& mpm)
|
||||
{
|
||||
match::find_matches(mpm.get_module(), pointwise_broadcast_pointwise{});
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
}
|
||||
|
||||
void fuse_pointwise::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
mpm.run_pass(eliminate_identity{});
|
||||
create_pointwise_modules(mpm);
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
|
||||
{
|
||||
return;
|
||||
}
|
||||
for(int i = 0; i < 8; i++)
|
||||
{
|
||||
if(enable_rewrite_reshapes)
|
||||
mpm.run_pass(rewrite_reshapes<pointwise_reshape>{});
|
||||
if(enable_rewrite_broadcasts)
|
||||
rewrite_broadcasts(mpm);
|
||||
if(not find_pointwise_modules(mpm))
|
||||
break;
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
55
docker/rocm/migraphx/fuse_pointwise_reduce.cpp
Normal file
55
docker/rocm/migraphx/fuse_pointwise_reduce.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
*/
|
||||
#include <migraphx/fuse_pointwise_reduce.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/fuse_pointwise.hpp>
|
||||
#include <migraphx/fuse_reduce.hpp>
|
||||
#include <migraphx/split_reduce.hpp>
|
||||
#include <migraphx/env.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SPLIT_REDUCE_SIZE);
|
||||
|
||||
static std::size_t get_split_size(std::size_t default_split)
|
||||
{
|
||||
std::string value = string_value_of(MIGRAPHX_SPLIT_REDUCE_SIZE{});
|
||||
if(value.empty())
|
||||
return default_split;
|
||||
return std::stoul(value);
|
||||
}
|
||||
|
||||
void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = false});
|
||||
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = false});
|
||||
mpm.run_pass(fuse_pointwise{.enable_rewrite_reshapes = true});
|
||||
mpm.run_pass(fuse_reduce{.enable_rewrite_reshapes = true});
|
||||
mpm.run_pass(split_reduce{.split_size = get_split_size(split_size)});
|
||||
mpm.run_pass(fuse_pointwise{.enable_rewrite_broadcasts = true});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
440
docker/rocm/migraphx/fuse_reduce.cpp
Normal file
440
docker/rocm/migraphx/fuse_reduce.cpp
Normal file
@ -0,0 +1,440 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/fuse_reduce.hpp>
|
||||
#include <migraphx/check_shapes.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/eliminate_common_subexpression.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
#include <migraphx/rewrite_reshapes.hpp>
|
||||
#include <migraphx/param_utils.hpp>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
|
||||
|
||||
struct fused_reduce
|
||||
{
|
||||
std::vector<std::int64_t> axes{};
|
||||
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.axes, "axes"));
|
||||
}
|
||||
|
||||
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
|
||||
{
|
||||
if(mods.size() != 1)
|
||||
MIGRAPHX_THROW("should have one submodule.");
|
||||
const auto* sm = mods.front();
|
||||
if(sm->get_output_shapes().size() != 1)
|
||||
MIGRAPHX_THROW("Only one output supported");
|
||||
if(not sm->bypass())
|
||||
MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
|
||||
auto names = sm->get_parameter_names();
|
||||
check_shapes{inputs, *this}.has(names.size()).same_ndims();
|
||||
std::sort(names.begin(), names.end());
|
||||
auto shapes = sm->get_parameter_shapes();
|
||||
// Check dimension matches for each input
|
||||
if(not equal(names, inputs, [&](const auto& name, const auto& input) {
|
||||
return shapes.at(name).lens() == input.lens();
|
||||
}))
|
||||
MIGRAPHX_THROW("Input dimension does not match the submodule.");
|
||||
|
||||
return shape::from_permutation(sm->get_output_shapes().front().type(),
|
||||
sm->get_output_shapes().front().lens(),
|
||||
find_permutation(inputs));
|
||||
}
|
||||
|
||||
std::string name() const { return "fused_reduce"; }
|
||||
};
|
||||
MIGRAPHX_REGISTER_OP(fused_reduce);
|
||||
|
||||
/*
|
||||
* Predicate matcher checks that input and output shapes have the same rank. This is assumed
|
||||
* for broadcast instructions for these fusions.
|
||||
*/
|
||||
MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
|
||||
{
|
||||
auto input_shape = ins->inputs().front()->get_shape();
|
||||
auto output_shape = ins->get_shape();
|
||||
return input_shape.ndim() == output_shape.ndim();
|
||||
}
|
||||
|
||||
static auto
|
||||
insert_module_in_submodule(module_ref sm,
|
||||
instruction_ref ins,
|
||||
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
|
||||
module::inserter insert = nullptr)
|
||||
{
|
||||
assert(ins->module_inputs().size() == 1);
|
||||
return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
|
||||
}
|
||||
|
||||
static void create_reduce_modules(module_pass_manager& mpm)
|
||||
{
|
||||
std::size_t n = 0;
|
||||
for(auto ins : iterator_for(mpm.get_module()))
|
||||
{
|
||||
if(not ins->get_operator().attributes().get("reduce", false))
|
||||
continue;
|
||||
if(ins->inputs().size() != 1)
|
||||
continue;
|
||||
|
||||
auto* rm =
|
||||
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
|
||||
rm->set_bypass();
|
||||
|
||||
rm->add_return(rm->fuse({ins}));
|
||||
auto v = ins->get_operator().to_value();
|
||||
mpm.get_module().replace_instruction(
|
||||
ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
instruction_ref get_broadcast_output(instruction_ref broadcast)
|
||||
{
|
||||
if(broadcast->outputs().size() != 1)
|
||||
return broadcast;
|
||||
auto output = broadcast->outputs().front();
|
||||
if(output->name() == "contiguous")
|
||||
return get_broadcast_output(output);
|
||||
return output;
|
||||
}
|
||||
|
||||
MIGRAPHX_PRED_MATCHER(used_once_except_broadcast, instruction_ref ins)
|
||||
{
|
||||
if(ins->outputs().size() == 1)
|
||||
return true;
|
||||
if(ins->outputs().size() == 2)
|
||||
{
|
||||
auto is_broadcast = [](instruction_ref output) {
|
||||
return contains(output->name(), "broadcast");
|
||||
};
|
||||
auto broadcast = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
|
||||
if(broadcast == ins->outputs().end())
|
||||
return false;
|
||||
auto non_broadcast =
|
||||
std::find_if_not(ins->outputs().begin(), ins->outputs().end(), is_broadcast);
|
||||
if(non_broadcast == ins->outputs().end())
|
||||
return false;
|
||||
auto output = get_broadcast_output(*broadcast);
|
||||
return output == *non_broadcast;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
template <class... Ms>
|
||||
static auto match_broadcast(Ms... ms)
|
||||
{
|
||||
return match::skip(match::name("contiguous"))(
|
||||
match::name("multibroadcast")(
|
||||
match::arg(0)(ms...), match::used_once(), input_output_ndim_match())
|
||||
.bind("broadcast"))
|
||||
.bind("final_broadcast");
|
||||
}
|
||||
|
||||
template <class... Ms>
|
||||
static auto any_input(Ms... ms)
|
||||
{
|
||||
return match::any_of[match::inputs()](match::any(ms...).bind("input"));
|
||||
}
|
||||
|
||||
bool is_valid_broadcast(const instruction_ref b, const std::vector<size_t>& reduce_axes)
|
||||
{
|
||||
std::vector<size_t> broadcast_axes;
|
||||
auto bstrides = b->get_shape().strides();
|
||||
|
||||
for(size_t i = 0; i < bstrides.size(); ++i)
|
||||
{
|
||||
if(bstrides.at(i) == 0)
|
||||
broadcast_axes.push_back(i);
|
||||
}
|
||||
|
||||
return broadcast_axes == reduce_axes;
|
||||
}
|
||||
|
||||
template <class M>
|
||||
static auto match_broadcast_axes(M m)
|
||||
{
|
||||
return match::make_basic_fun_matcher(
|
||||
[=](match::matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
|
||||
optional<instruction_ref> result = m.match(ctx, ins);
|
||||
if(contains(ctx.instructions, "broadcast"))
|
||||
{
|
||||
instruction_ref reduce;
|
||||
if(ins->get_operator().name() == "fused_reduce")
|
||||
{
|
||||
reduce = ins;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(contains(ctx.instructions, "reduce"));
|
||||
reduce = ctx.instructions["reduce"];
|
||||
}
|
||||
auto axes = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
|
||||
auto broadcast = ctx.instructions["broadcast"];
|
||||
if(not is_valid_broadcast(broadcast, axes))
|
||||
return nullopt;
|
||||
}
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
static auto match_broadcastable_input(const std::string& op, const std::string& name)
|
||||
{
|
||||
auto match_op = match::name(op)(used_once_except_broadcast()).bind(name);
|
||||
auto match_op_input = any_input(match_op, match::used_once());
|
||||
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
|
||||
return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
|
||||
}
|
||||
|
||||
static void finalize_reduce_module(module_ref m)
|
||||
{
|
||||
eliminate_common_subexpression{}.apply(*m);
|
||||
dead_code_elimination{}.apply(*m);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct find_pointwise_reduce
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
// fused_reduce instruction with pointwise inputs.
|
||||
return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise"));
|
||||
}
|
||||
|
||||
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
||||
{
|
||||
auto reduce = r.result;
|
||||
auto input = r.instructions["pointwise"];
|
||||
const auto* pm = input->module_inputs().front();
|
||||
const auto* old_rm = reduce->module_inputs().front();
|
||||
|
||||
auto* rm = mpm.create_module(pm->name() + ":" + old_rm->name());
|
||||
rm->set_bypass();
|
||||
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
||||
// Insert pointwise
|
||||
auto rins = rm->fuse({input}, &map_ins).front();
|
||||
map_ins[input] = rins;
|
||||
|
||||
if(contains(r.instructions, "broadcast"))
|
||||
{
|
||||
auto broadcast = r.instructions["broadcast"];
|
||||
auto fbroadcast = r.instructions["final_broadcast"];
|
||||
map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front();
|
||||
if(fbroadcast != broadcast)
|
||||
map_ins[fbroadcast] = map_ins[broadcast];
|
||||
}
|
||||
|
||||
// Insert fused_reduce
|
||||
rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
|
||||
finalize_reduce_module(rm);
|
||||
|
||||
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
|
||||
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
|
||||
}
|
||||
};
|
||||
|
||||
struct find_reduce_pointwise
|
||||
{
|
||||
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("pointwise")(match_broadcastable_input("fused_reduce", "reduce"));
|
||||
}
|
||||
|
||||
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
||||
{
|
||||
auto pw = r.result;
|
||||
auto reduce = r.instructions["reduce"];
|
||||
auto input = r.instructions["input"];
|
||||
|
||||
const auto* pm = pw->module_inputs().front();
|
||||
const auto* old_rm = reduce->module_inputs().front();
|
||||
auto* rm = mpm.create_module(old_rm->name() + ":" + pm->name());
|
||||
rm->set_bypass();
|
||||
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
||||
// Copy module instructions
|
||||
insert_module_in_submodule(rm, reduce, &map_ins);
|
||||
if(contains(r.instructions, "broadcast"))
|
||||
{
|
||||
auto broadcast = r.instructions["broadcast"];
|
||||
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
|
||||
auto bout = rm->fuse({broadcast}, &map_ins);
|
||||
map_ins[input] = bout.front();
|
||||
}
|
||||
else
|
||||
{
|
||||
map_ins[input] = rm->get_returns().front();
|
||||
}
|
||||
|
||||
auto out = rm->fuse({pw}, &map_ins);
|
||||
rm->replace_return(out);
|
||||
finalize_reduce_module(rm);
|
||||
|
||||
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
|
||||
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
|
||||
}
|
||||
};
|
||||
|
||||
struct find_reduce_reduce
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce"));
|
||||
}
|
||||
|
||||
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
|
||||
{
|
||||
auto reduce1 = r.result;
|
||||
auto reduce2 = r.instructions["reduce"];
|
||||
auto input = r.instructions["input"];
|
||||
|
||||
if(reduce1->get_operator() != reduce2->get_operator())
|
||||
return;
|
||||
|
||||
const auto* rm1 = reduce1->module_inputs().front();
|
||||
const auto* rm2 = reduce2->module_inputs().front();
|
||||
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
|
||||
rm->set_bypass();
|
||||
|
||||
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
||||
// Copy reduce1 instructions
|
||||
insert_module_in_submodule(rm, reduce2, &map_ins);
|
||||
if(contains(r.instructions, "broadcast"))
|
||||
{
|
||||
auto broadcast = r.instructions["broadcast"];
|
||||
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
|
||||
auto bout = rm->fuse({broadcast}, &map_ins);
|
||||
map_ins[input] = bout.front();
|
||||
}
|
||||
else
|
||||
{
|
||||
map_ins[input] = rm->get_returns().front();
|
||||
}
|
||||
|
||||
auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
|
||||
rm->replace_return(out);
|
||||
finalize_reduce_module(rm);
|
||||
|
||||
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
|
||||
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
|
||||
}
|
||||
};
|
||||
|
||||
struct reduce_reshape : rewrite_reshapes_base
|
||||
{
|
||||
static std::string name() { return "fused_reduce"; }
|
||||
|
||||
template <class Transform>
|
||||
static auto transform_op(Transform t)
|
||||
{
|
||||
return [=](module& m,
|
||||
instruction_ref ins,
|
||||
const operation& op,
|
||||
const std::vector<instruction_ref>& inputs,
|
||||
const std::vector<module_ref>& mod_args) {
|
||||
auto new_op = t(op);
|
||||
return m.insert_instruction(ins, new_op, inputs, mod_args);
|
||||
};
|
||||
}
|
||||
|
||||
template <class AxesMap>
|
||||
static instruction_ref insert(module_pass_manager& mpm,
|
||||
instruction_ref ins,
|
||||
const std::vector<instruction_ref>& inputs,
|
||||
const AxesMap& am)
|
||||
{
|
||||
auto op = any_cast<fused_reduce>(ins->get_operator());
|
||||
std::vector<int64_t> axes;
|
||||
for(auto axis : op.axes)
|
||||
{
|
||||
auto new_axes = am.at(axis);
|
||||
axes.insert(axes.end(), new_axes.begin(), new_axes.end());
|
||||
}
|
||||
std::sort(axes.begin(), axes.end());
|
||||
auto dims = base_dims(inputs);
|
||||
auto* oldm = ins->module_inputs().front();
|
||||
auto* sm = mpm.create_module(oldm->name() + "_reshape");
|
||||
sm->set_bypass();
|
||||
auto outs = sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
|
||||
if(contains(sop.name(), "reduce"))
|
||||
return make_op(sop.name(), {{"axes", axes}});
|
||||
if(sop.name() == "multibroadcast")
|
||||
return make_op("multibroadcast", {{"out_lens", dims}});
|
||||
assert(sop.name() == "pointwise");
|
||||
return sop;
|
||||
}));
|
||||
sm->add_return(outs);
|
||||
return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
|
||||
}
|
||||
|
||||
static std::vector<std::size_t> base_dims(const std::vector<instruction_ref>& inputs)
|
||||
{
|
||||
auto input = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](auto i) {
|
||||
return i->get_shape().elements();
|
||||
}));
|
||||
return (*input)->get_shape().lens();
|
||||
}
|
||||
|
||||
static std::vector<std::size_t> base_dims(instruction_ref ins)
|
||||
{
|
||||
return base_dims(ins->inputs());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void fuse_reduce::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
if(enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}))
|
||||
return;
|
||||
create_reduce_modules(mpm);
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
for(int i = 0; i < 4; i++)
|
||||
{
|
||||
if(enable_rewrite_reshapes)
|
||||
mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
|
||||
match::find_matches(
|
||||
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
106
docker/rocm/migraphx/generate.cpp
Normal file
106
docker/rocm/migraphx/generate.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/generate.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
argument fill_argument(shape s, double value)
|
||||
{
|
||||
argument result;
|
||||
if(s.type() == shape::tuple_type)
|
||||
{
|
||||
std::vector<argument> sub_args;
|
||||
const auto& sub_ss = s.sub_shapes();
|
||||
std::transform(sub_ss.begin(), sub_ss.end(), std::back_inserter(sub_args), [&](auto ss) {
|
||||
return fill_argument(ss, value);
|
||||
});
|
||||
|
||||
result = argument(sub_args);
|
||||
}
|
||||
else
|
||||
{
|
||||
s.visit_type([&](auto as) {
|
||||
using type = typename decltype(as)::type;
|
||||
auto v = fill_tensor_data<type>(s, value);
|
||||
result = {s, v};
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
argument generate_argument(shape s, unsigned long seed, random_mode m)
|
||||
{
|
||||
argument result;
|
||||
if(s.type() == shape::tuple_type)
|
||||
{
|
||||
const auto& sub_ss = s.sub_shapes();
|
||||
std::vector<argument> sub_args;
|
||||
std::transform(sub_ss.begin(), sub_ss.end(), std::back_inserter(sub_args), [&](auto ss) {
|
||||
return generate_argument(ss, seed, m);
|
||||
});
|
||||
|
||||
result = argument(sub_args);
|
||||
}
|
||||
else
|
||||
{
|
||||
s.visit_type([&](auto as) {
|
||||
// we use char type to store bool type internally, so bool_type
|
||||
// needs special processing to generate data
|
||||
if(s.type() == shape::bool_type)
|
||||
{
|
||||
auto v = generate_tensor_data<bool>(s, seed, m);
|
||||
result = {s, v};
|
||||
}
|
||||
else
|
||||
{
|
||||
using type = typename decltype(as)::type;
|
||||
auto v = generate_tensor_data<type>(s, seed, m);
|
||||
result = {s, v};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
literal generate_literal(shape s, unsigned long seed)
|
||||
{
|
||||
literal result;
|
||||
s.visit_type([&](auto as) {
|
||||
using type = typename decltype(as)::type;
|
||||
auto v = generate_tensor_data<type>(s, seed);
|
||||
result = {s, reinterpret_cast<char*>(v.get())};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
// TODO: Move to literal.cpp
|
||||
literal abs(literal l)
|
||||
{
|
||||
return transform(std::move(l), [](auto x) { return std::fabs(x); });
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
68
docker/rocm/migraphx/inline_module.cpp
Normal file
68
docker/rocm/migraphx/inline_module.cpp
Normal file
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/inline_module.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static void inline_submodule(module& m, instruction_ref ins, bool cond)
|
||||
{
|
||||
const auto& mod_inputs = ins->module_inputs();
|
||||
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
|
||||
auto mod_outputs = m.insert_instructions(ins, smod);
|
||||
|
||||
auto ins_outputs = ins->outputs();
|
||||
assert(mod_outputs.size() >= ins_outputs.size());
|
||||
for(const auto& out : ins_outputs)
|
||||
{
|
||||
auto val = out->get_operator().to_value();
|
||||
assert(val.contains("index"));
|
||||
auto index = val.at("index").to<std::size_t>();
|
||||
m.replace_instruction(out, mod_outputs.at(index));
|
||||
}
|
||||
}
|
||||
|
||||
void inline_module::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "if")
|
||||
continue;
|
||||
|
||||
auto arg_cond = ins->inputs().front()->eval();
|
||||
if(not arg_cond.empty())
|
||||
{
|
||||
bool cond = arg_cond.at<bool>();
|
||||
inline_submodule(m, ins, cond);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
124
docker/rocm/migraphx/insert_pad.cpp
Normal file
124
docker/rocm/migraphx/insert_pad.cpp
Normal file
@ -0,0 +1,124 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/insert_pad.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/op/convolution.hpp>
|
||||
#include <migraphx/op/im2col.hpp>
|
||||
#include <migraphx/op/pooling.hpp>
|
||||
#include <migraphx/op/pad.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static void update_op(const instruction_ref& input, const instruction_ref& ins, module& m)
|
||||
{
|
||||
auto op = ins->get_operator();
|
||||
auto val = op.to_value();
|
||||
auto op_padding = val.at("padding").to_vector<size_t>();
|
||||
|
||||
// skip if shape is dynamic
|
||||
if(input->get_shape().dynamic())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto kdims = input->get_shape().lens().size() - 2;
|
||||
if(std::equal(op_padding.begin(),
|
||||
op_padding.begin() + kdims,
|
||||
op_padding.begin() + kdims,
|
||||
op_padding.end()))
|
||||
return;
|
||||
|
||||
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0);
|
||||
std::vector<size_t> pads_l(op_padding.begin(), op_padding.begin() + kdims);
|
||||
std::vector<size_t> pads_r(op_padding.begin() + kdims, op_padding.end());
|
||||
op_padding = std::vector<size_t>(kdims * 2, 0);
|
||||
op.from_value({{"padding", op_padding}});
|
||||
|
||||
std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
|
||||
std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
|
||||
|
||||
auto pad_op = m.insert_instruction(ins, op::pad{padding}, input);
|
||||
|
||||
auto new_inputs = ins->inputs();
|
||||
new_inputs.front() = pad_op;
|
||||
|
||||
m.replace_instruction(ins, op, new_inputs);
|
||||
}
|
||||
|
||||
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
|
||||
{
|
||||
auto op = any_cast<op::pooling>(ins->get_operator());
|
||||
if(op.mode == op::pooling_mode::average)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto kdims = input->get_shape().ndim() - 2;
|
||||
if(std::equal(op.padding.begin(),
|
||||
op.padding.begin() + kdims,
|
||||
op.padding.begin() + kdims,
|
||||
op.padding.end()))
|
||||
return;
|
||||
|
||||
std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
|
||||
std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
|
||||
std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
|
||||
op.padding = std::vector<size_t>(kdims * 2, 0);
|
||||
std::copy(pads_l.begin(), pads_l.end(), padding.begin() + 2);
|
||||
std::copy(pads_r.begin(), pads_r.end(), padding.begin() + kdims + 2 + 2);
|
||||
|
||||
float pad_val = 0.0f; // for the lpnorm
|
||||
if(op.mode == op::pooling_mode::max)
|
||||
{
|
||||
// maxpool uses lowest value for padding
|
||||
pad_val = std::numeric_limits<float>::lowest();
|
||||
}
|
||||
auto pad_op = m.insert_instruction(ins, op::pad{padding, pad_val}, input);
|
||||
|
||||
auto new_inputs = ins->inputs();
|
||||
new_inputs.front() = pad_op;
|
||||
|
||||
m.replace_instruction(ins, op, new_inputs);
|
||||
}
|
||||
|
||||
void insert_pad::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
const std::string& op_name = ins->name();
|
||||
if(not contains(ops, op_name))
|
||||
continue;
|
||||
auto input = ins->inputs().front();
|
||||
if(op_name == "convolution" or op_name == "im2col")
|
||||
update_op(input, ins, m);
|
||||
else if(op_name == "pooling")
|
||||
update_pooling(input, ins, m);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
564
docker/rocm/migraphx/instruction.cpp
Normal file
564
docker/rocm/migraphx/instruction.cpp
Normal file
@ -0,0 +1,564 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/builtin.hpp>
|
||||
#include <migraphx/erase.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/output_iterator.hpp>
|
||||
#include <queue>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class T>
|
||||
auto equal_to(const T& x)
|
||||
{
|
||||
return [&](const T& y) { return std::equal_to<T>{}(x, y); };
|
||||
}
|
||||
|
||||
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
|
||||
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
|
||||
{
|
||||
}
|
||||
|
||||
instruction::instruction(operation o,
|
||||
shape r,
|
||||
std::vector<instruction_ref> args,
|
||||
std::vector<module_ref> modules)
|
||||
: op(std::move(o)),
|
||||
result(std::move(r)),
|
||||
arguments(std::move(args)),
|
||||
module_args(std::move(modules))
|
||||
{
|
||||
}
|
||||
|
||||
instruction::instruction(literal l)
|
||||
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
|
||||
{
|
||||
}
|
||||
|
||||
struct replace_shape_order
|
||||
{
|
||||
instruction_ref start;
|
||||
|
||||
std::size_t location(instruction_ref x) const { return std::distance(start, x); }
|
||||
|
||||
bool operator()(instruction_ref x, instruction_ref y) const
|
||||
{
|
||||
return location(x) > location(y);
|
||||
}
|
||||
};
|
||||
|
||||
void instruction::replace(const shape& r)
|
||||
{
|
||||
if(r != result)
|
||||
{
|
||||
result = r;
|
||||
if(output.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto start = std::find_if(output.front()->inputs().begin(),
|
||||
output.front()->inputs().end(),
|
||||
[&](instruction_ref x) { return this == as_address(x); });
|
||||
assert(as_address(*start) == this);
|
||||
std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(
|
||||
output.begin(), output.end(), replace_shape_order{*start});
|
||||
while(not q.empty())
|
||||
{
|
||||
instruction_ref ins = q.top();
|
||||
q.pop();
|
||||
assert(ins->name() == "@return" or ins->name().front() != '@');
|
||||
shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
|
||||
if(new_r != ins->result)
|
||||
{
|
||||
ins->result = new_r;
|
||||
std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void instruction::replace(operation o)
|
||||
{
|
||||
normalized = false;
|
||||
op = std::move(o);
|
||||
recompute_shape();
|
||||
}
|
||||
|
||||
void instruction::recompute_shape() { replace(compute_shape(op, arguments, module_args)); }
|
||||
|
||||
void instruction::clear_arguments()
|
||||
{
|
||||
for(auto&& arg : arguments)
|
||||
{
|
||||
arg->remove_output(*this);
|
||||
}
|
||||
arguments.clear();
|
||||
module_args.clear();
|
||||
}
|
||||
|
||||
bool operator==(const instruction& i, instruction_ref ref)
|
||||
{
|
||||
return std::addressof(i) == std::addressof(*ref);
|
||||
}
|
||||
|
||||
bool instruction::valid(instruction_ref start, bool check_order) const
|
||||
{
|
||||
return valid() and std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
|
||||
auto self = std::find(i->outputs().begin(), i->outputs().end(), *this);
|
||||
bool ret = self != i->outputs().end();
|
||||
if(check_order)
|
||||
{
|
||||
// check arguments for this instruction before this instruction
|
||||
ret = ret and (std::distance(start, i) < std::distance(start, *self));
|
||||
}
|
||||
return ret;
|
||||
});
|
||||
}
|
||||
|
||||
bool instruction::valid() const
|
||||
{
|
||||
shape computed;
|
||||
if(op.name() == "@literal")
|
||||
{
|
||||
computed = lit.get_shape();
|
||||
}
|
||||
else if(op.name() == "@param")
|
||||
{
|
||||
computed = result;
|
||||
}
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
computed = compute_shape(op, arguments, module_args);
|
||||
}
|
||||
catch(migraphx::exception&)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return (result == computed) and
|
||||
std::all_of(output.begin(), output.end(), [&](instruction_ref i) {
|
||||
return std::find(i->inputs().begin(), i->inputs().end(), *this) != i->inputs().end();
|
||||
});
|
||||
}
|
||||
|
||||
shape instruction::get_shape() const { return result; }
|
||||
|
||||
const literal& instruction::get_literal() const
|
||||
{
|
||||
assert(op.name() == "@literal");
|
||||
return lit;
|
||||
}
|
||||
|
||||
const operation& instruction::get_operator() const { return op; }
|
||||
|
||||
std::string instruction::name() const { return op.name(); }
|
||||
|
||||
const std::vector<instruction_ref>& instruction::inputs() const { return arguments; }
|
||||
|
||||
const std::vector<module_ref>& instruction::module_inputs() const { return module_args; }
|
||||
|
||||
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
|
||||
|
||||
bool operator==(const instruction& x, const instruction& y)
|
||||
{
|
||||
if(not std::equal(x.arguments.begin(),
|
||||
x.arguments.end(),
|
||||
y.arguments.begin(),
|
||||
y.arguments.end(),
|
||||
std::equal_to<instruction_ref>{}))
|
||||
return false;
|
||||
if(std::tie(x.result, x.op, x.module_args) != std::tie(y.result, y.op, y.module_args))
|
||||
return false;
|
||||
if(x.name() == "@literal")
|
||||
return x.lit == y.lit;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool operator!=(const instruction& x, const instruction& y) { return not(x == y); }
|
||||
|
||||
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
|
||||
|
||||
bool operator!=(const instruction& i, instruction_ref ref) { return not(i == ref); }
|
||||
|
||||
bool operator!=(instruction_ref ref, const instruction& i) { return not(i == ref); }
|
||||
|
||||
void instruction::add_output(instruction_ref ins)
|
||||
{
|
||||
if(std::find_if(output.begin(), output.end(), equal_to(ins)) == output.end())
|
||||
output.push_back(ins);
|
||||
}
|
||||
|
||||
void instruction::backreference(instruction_ref ref)
|
||||
{
|
||||
for(auto&& arg : ref->inputs())
|
||||
arg->add_output(ref);
|
||||
}
|
||||
|
||||
void instruction::replace_argument(instruction_ref ins,
|
||||
instruction_ref old,
|
||||
instruction_ref new_ins)
|
||||
{
|
||||
ins->replace_argument(old, new_ins);
|
||||
backreference(ins);
|
||||
ins->recompute_shape();
|
||||
}
|
||||
|
||||
void instruction::replace_mod_argument(instruction_ref ins, module_ref old, module_ref new_mod)
|
||||
{
|
||||
ins->replace_mod_argument(old, new_mod);
|
||||
backreference(ins);
|
||||
ins->recompute_shape();
|
||||
}
|
||||
|
||||
void instruction::replace(instruction_ref ins,
|
||||
operation o,
|
||||
const shape& r,
|
||||
std::vector<instruction_ref> args)
|
||||
{
|
||||
ins->replace(std::move(o), r, std::move(args));
|
||||
backreference(ins);
|
||||
}
|
||||
|
||||
void instruction::replace(instruction_ref ins,
|
||||
operation o,
|
||||
const shape& r,
|
||||
std::vector<instruction_ref> args,
|
||||
std::vector<module_ref> module_args)
|
||||
{
|
||||
ins->replace(std::move(o), r, std::move(args), std::move(module_args));
|
||||
backreference(ins);
|
||||
}
|
||||
|
||||
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
|
||||
{
|
||||
normalized = false;
|
||||
op = std::move(o);
|
||||
replace(r);
|
||||
replace(std::move(args));
|
||||
}
|
||||
|
||||
void instruction::replace(operation o,
|
||||
const shape& r,
|
||||
std::vector<instruction_ref> args,
|
||||
std::vector<module_ref> mdl_args)
|
||||
{
|
||||
op = std::move(o);
|
||||
replace(r);
|
||||
replace(std::move(args), std::move(mdl_args));
|
||||
}
|
||||
|
||||
void instruction::replace_refs(
|
||||
instruction_ref ins,
|
||||
const std::unordered_map<instruction_ref, instruction_ref>& map_insts,
|
||||
const std::unordered_map<module_ref, module_ref>& map_mods)
|
||||
{
|
||||
const auto& args = ins->inputs();
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
if(contains(map_insts, arg))
|
||||
{
|
||||
instruction::replace_argument(ins, arg, map_insts.at(arg));
|
||||
}
|
||||
}
|
||||
|
||||
const auto& module_args = ins->module_inputs();
|
||||
if(module_args.empty())
|
||||
return;
|
||||
|
||||
for(const auto& mod : module_args)
|
||||
{
|
||||
if(contains(map_mods, mod))
|
||||
{
|
||||
instruction::replace_mod_argument(ins, mod, map_mods.at(mod));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void instruction::replace(std::vector<instruction_ref> args)
|
||||
{
|
||||
clear_arguments();
|
||||
arguments = std::move(args);
|
||||
}
|
||||
|
||||
void instruction::replace(std::vector<instruction_ref> args, std::vector<module_ref> mdl_args)
|
||||
{
|
||||
clear_arguments();
|
||||
arguments = std::move(args);
|
||||
module_args = std::move(mdl_args);
|
||||
}
|
||||
|
||||
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
|
||||
{
|
||||
assert(std::any_of(arguments.begin(), arguments.end(), equal_to(old)));
|
||||
std::replace_if(arguments.begin(), arguments.end(), equal_to(old), new_ins);
|
||||
old->remove_output(*this);
|
||||
}
|
||||
|
||||
void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
|
||||
{
|
||||
assert(std::any_of(module_args.begin(), module_args.end(), [&](auto i) { return i == old; }));
|
||||
std::replace(module_args.begin(), module_args.end(), old, new_mod);
|
||||
}
|
||||
|
||||
bool instruction::is_undefined() const
|
||||
{
|
||||
if(op.name() == "undefined")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(this->inputs().empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
|
||||
return arg->is_undefined();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
bool instruction::can_eval() const
|
||||
{
|
||||
if(op.name() == "@literal")
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if(is_context_free(op))
|
||||
{
|
||||
return std::all_of(
|
||||
this->inputs().begin(), this->inputs().end(), [](auto arg) { return arg->can_eval(); });
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
argument instruction::eval(bool check_eval) const
|
||||
{
|
||||
if(op.name() == "@literal")
|
||||
{
|
||||
return this->get_literal().get_argument();
|
||||
}
|
||||
if(is_context_free(op))
|
||||
{
|
||||
if(check_eval and not this->can_eval())
|
||||
return {};
|
||||
std::vector<argument> args;
|
||||
std::transform(this->inputs().begin(),
|
||||
this->inputs().end(),
|
||||
std::back_inserter(args),
|
||||
[](auto arg) { return arg->eval(false); });
|
||||
return normalized_operator().compute(result, args);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void instruction::finalize(context& ctx)
|
||||
{
|
||||
if(has_finalize(this->op))
|
||||
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
|
||||
}
|
||||
|
||||
void instruction::print(std::ostream& os,
|
||||
instruction_ref ins,
|
||||
const std::unordered_map<instruction_ref, std::string>& names)
|
||||
{
|
||||
os << names.at(ins) << " = ";
|
||||
|
||||
os << ins->get_operator();
|
||||
|
||||
if(ins->name() == "@literal")
|
||||
{
|
||||
if(ins->get_literal().get_shape().elements() > 10)
|
||||
os << "{ ... }";
|
||||
else
|
||||
os << "{" << ins->get_literal() << "}";
|
||||
}
|
||||
|
||||
if(not ins->inputs().empty())
|
||||
{
|
||||
char delim = '(';
|
||||
for(auto&& arg : ins->inputs())
|
||||
{
|
||||
std::string arg_name = contains(names, arg) ? names.at(arg) : "?";
|
||||
os << delim << arg_name;
|
||||
delim = ',';
|
||||
}
|
||||
os << ")";
|
||||
}
|
||||
|
||||
// print module inputs
|
||||
if(not ins->module_inputs().empty())
|
||||
{
|
||||
std::string delim = ", [";
|
||||
for(const const_module_ref& mod_arg : ins->module_inputs())
|
||||
{
|
||||
os << delim << mod_arg->name();
|
||||
delim = ", ";
|
||||
}
|
||||
os << "]";
|
||||
}
|
||||
|
||||
// skip return instruction shape
|
||||
if(ins->name() != "@return")
|
||||
os << " -> " << ins->get_shape();
|
||||
|
||||
// print tid
|
||||
if(ins->target_id != 0)
|
||||
os << ", target_id=" << ins->target_id;
|
||||
}
|
||||
|
||||
static void debug_name(std::ostream& os, const instruction& ins)
|
||||
{
|
||||
if(ins.name() == "@literal")
|
||||
{
|
||||
os << "@literal";
|
||||
if(ins.get_literal().get_shape().elements() > 10)
|
||||
os << "{ ... }";
|
||||
else
|
||||
os << "{" << ins.get_literal() << "}";
|
||||
}
|
||||
else
|
||||
{
|
||||
os << ins.get_operator();
|
||||
}
|
||||
}
|
||||
|
||||
void instruction::debug_print() const
|
||||
{
|
||||
debug_name(std::cout, *this);
|
||||
std::string delim = "(";
|
||||
for(auto arg : this->inputs())
|
||||
{
|
||||
std::cout << delim;
|
||||
debug_name(std::cout, *arg);
|
||||
delim = ", ";
|
||||
}
|
||||
if(not this->inputs().empty())
|
||||
std::cout << ")";
|
||||
std::cout << " -> " << this->get_shape() << std::endl;
|
||||
}
|
||||
|
||||
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
|
||||
{
|
||||
auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
|
||||
if(i < 0)
|
||||
return ins;
|
||||
if(shallow)
|
||||
return ins->inputs().at(i);
|
||||
return get_output_alias(ins->inputs().at(i));
|
||||
}
|
||||
|
||||
void instruction::set_normalized(bool value) { normalized = value; }
|
||||
|
||||
bool instruction::is_normalized() const { return normalized; }
|
||||
|
||||
bool instruction::need_normalization() const
|
||||
{
|
||||
return this->get_operator().need_normalization() and not normalized;
|
||||
}
|
||||
|
||||
operation instruction::normalized_operator() const
|
||||
{
|
||||
operation o = this->get_operator();
|
||||
if(this->need_normalization())
|
||||
{
|
||||
auto s = this->inputs().front()->get_shape();
|
||||
if(not normalize_attributes(o, s))
|
||||
return this->get_operator();
|
||||
}
|
||||
return o;
|
||||
}
|
||||
std::size_t instruction::get_target_id() const { return target_id; }
|
||||
|
||||
void instruction::set_target_id(std::size_t tid) { this->target_id = tid; }
|
||||
|
||||
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
|
||||
{
|
||||
std::vector<shape> shapes(args.size());
|
||||
std::transform(
|
||||
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
|
||||
return shapes;
|
||||
}
|
||||
|
||||
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
|
||||
{
|
||||
return op.compute_shape(to_shapes(args));
|
||||
}
|
||||
|
||||
shape compute_shape(const operation& op,
|
||||
const std::vector<instruction_ref>& args,
|
||||
const std::vector<module_ref>& mods)
|
||||
{
|
||||
if(mods.empty())
|
||||
{
|
||||
return op.compute_shape(to_shapes(args));
|
||||
}
|
||||
else
|
||||
{
|
||||
return op.compute_shape(to_shapes(args), mods);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs)
|
||||
{
|
||||
shape new_shape;
|
||||
try
|
||||
{
|
||||
new_shape = op.compute_shape(inputs);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
return {new_shape};
|
||||
}
|
||||
|
||||
migraphx::instruction* as_address(const instruction_ref& ins) noexcept
|
||||
{
|
||||
return std::addressof(*ins);
|
||||
}
|
||||
|
||||
bool reaches(instruction_ref start, instruction_ref end)
|
||||
{
|
||||
std::unordered_set<instruction_ref> visited;
|
||||
return fix<bool>([&](auto self, auto ins) -> bool {
|
||||
if(ins == start)
|
||||
return true;
|
||||
if(not visited.insert(ins).second)
|
||||
return false;
|
||||
return std::any_of(ins->inputs().begin(), ins->inputs().end(), self);
|
||||
})(end);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
177
docker/rocm/migraphx/json.cpp
Normal file
177
docker/rocm/migraphx/json.cpp
Normal file
@ -0,0 +1,177 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/serialize.hpp>
|
||||
#include <migraphx/argument.hpp>
|
||||
#include <migraphx/literal.hpp>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <migraphx/json.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
void value_to_json(const value& val, json& j);
|
||||
migraphx::value value_from_json(const json& j);
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
namespace nlohmann {
|
||||
template <>
|
||||
struct adl_serializer<migraphx::value>
|
||||
{
|
||||
static void to_json(json& j, const migraphx::value& val) { migraphx::value_to_json(val, j); }
|
||||
|
||||
static void from_json(const json& j, migraphx::value& val)
|
||||
{
|
||||
val = migraphx::value_from_json(j);
|
||||
}
|
||||
};
|
||||
} // namespace nlohmann
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
template <class T>
|
||||
void value_to_json(const T& x, json& j)
|
||||
{
|
||||
j = x;
|
||||
}
|
||||
|
||||
void value_to_json(const value::binary& x, json& j)
|
||||
{
|
||||
j = json::object();
|
||||
j["bytes"] = std::vector<int>(x.begin(), x.end());
|
||||
}
|
||||
|
||||
void value_to_json(const std::vector<value>& x, json& j)
|
||||
{
|
||||
for(const auto& v : x)
|
||||
{
|
||||
if(v.get_key().empty())
|
||||
{
|
||||
j.push_back(v);
|
||||
}
|
||||
else
|
||||
{
|
||||
j[v.get_key()] = v.without_key();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void value_to_json(std::nullptr_t&, json& j) { j = {}; }
|
||||
|
||||
void value_to_json(const value& val, json& j)
|
||||
{
|
||||
if(val.is_array())
|
||||
{
|
||||
j = json::array();
|
||||
}
|
||||
|
||||
if(val.is_object())
|
||||
{
|
||||
j = json::object();
|
||||
}
|
||||
|
||||
val.visit([&](auto v) { value_to_json(v, j); });
|
||||
}
|
||||
|
||||
migraphx::value value_from_json(const json& j)
|
||||
{
|
||||
migraphx::value val;
|
||||
json::value_t type = j.type();
|
||||
switch(type)
|
||||
{
|
||||
case json::value_t::null: val = nullptr; break;
|
||||
|
||||
case json::value_t::boolean: val = j.get<bool>(); break;
|
||||
|
||||
case json::value_t::number_float: val = j.get<double>(); break;
|
||||
|
||||
case json::value_t::number_integer: val = j.get<int64_t>(); break;
|
||||
|
||||
case json::value_t::number_unsigned: val = j.get<uint64_t>(); break;
|
||||
|
||||
case json::value_t::string: val = j.get<std::string>(); break;
|
||||
|
||||
case json::value_t::array:
|
||||
val = migraphx::value::array{};
|
||||
std::transform(j.begin(), j.end(), std::back_inserter(val), [&](const json& jj) {
|
||||
return jj.get<value>();
|
||||
});
|
||||
break;
|
||||
|
||||
case json::value_t::object:
|
||||
if(j.contains("bytes") and j.size() == 1)
|
||||
{
|
||||
val = migraphx::value::binary{j["bytes"].get<std::vector<std::uint8_t>>()};
|
||||
}
|
||||
else
|
||||
{
|
||||
val = migraphx::value::object{};
|
||||
for(const auto& item : j.items())
|
||||
{
|
||||
const auto& key = item.key();
|
||||
const json& jv = item.value();
|
||||
val[key] = jv.get<value>();
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case json::value_t::binary: MIGRAPHX_THROW("Convert JSON to Value: binary type not supported!");
|
||||
case json::value_t::discarded:
|
||||
MIGRAPHX_THROW("Convert JSON to Value: discarded type not supported!");
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
std::string to_json_string(const value& val)
|
||||
{
|
||||
json j = val;
|
||||
return j.dump();
|
||||
}
|
||||
|
||||
std::string to_pretty_json_string(const value& val, std::size_t indent)
|
||||
{
|
||||
json j = val;
|
||||
return j.dump(indent);
|
||||
}
|
||||
|
||||
migraphx::value from_json_string(const char* str, std::size_t size)
|
||||
{
|
||||
json j = json::parse(str, str + size);
|
||||
return j.get<value>();
|
||||
}
|
||||
migraphx::value from_json_string(const std::string& str)
|
||||
{
|
||||
json j = json::parse(str);
|
||||
return j.get<value>();
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
132
docker/rocm/migraphx/layout_convolution.cpp
Normal file
132
docker/rocm/migraphx/layout_convolution.cpp
Normal file
@ -0,0 +1,132 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/layout_convolution.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/permutation.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/eliminate_contiguous.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
namespace {
|
||||
std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
|
||||
{
|
||||
if(lc.channels_last)
|
||||
{
|
||||
std::vector<int64_t> perm(ins->get_shape().ndim());
|
||||
std::iota(perm.begin() + 1, perm.end() - 1, 2);
|
||||
perm.back() = 1;
|
||||
return perm;
|
||||
}
|
||||
return find_permutation(ins->inputs().front()->get_shape());
|
||||
}
|
||||
|
||||
bool skip_layout(const shape& s)
|
||||
{
|
||||
return s.ndim() == 1 or s.dynamic() or s.type() == shape::tuple_type;
|
||||
}
|
||||
|
||||
void preserve_output_layout(module& m)
|
||||
{
|
||||
auto last = std::prev(m.end());
|
||||
if(last->name() == "@return")
|
||||
{
|
||||
std::vector<instruction_ref> outputs;
|
||||
std::transform(last->inputs().begin(),
|
||||
last->inputs().end(),
|
||||
std::back_inserter(outputs),
|
||||
[&](instruction_ref ins) {
|
||||
if(skip_layout(ins->get_shape()))
|
||||
return ins;
|
||||
auto permutation = find_permutation(ins->get_shape());
|
||||
return m.insert_instruction(
|
||||
last, make_op("layout", {{"permutation", permutation}}), ins);
|
||||
});
|
||||
m.replace_return(outputs);
|
||||
}
|
||||
else if(not skip_layout(last->get_shape()))
|
||||
{
|
||||
auto permutation = find_permutation(last->get_shape());
|
||||
m.add_instruction(make_op("layout", {{"permutation", permutation}}), last);
|
||||
}
|
||||
}
|
||||
|
||||
void transform_convolutions(module& m, const layout_convolution& lc)
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(not contains({"convolution", "quant_convolution"}, ins->name()))
|
||||
continue;
|
||||
if(ins->get_shape().dynamic())
|
||||
continue;
|
||||
if(ins->get_shape().lens().size() != 4)
|
||||
continue;
|
||||
auto v = ins->get_operator().to_value();
|
||||
if(v.at("group").to<int>() > 1)
|
||||
continue;
|
||||
auto args = ins->inputs();
|
||||
auto perm = get_permutation(ins, lc);
|
||||
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
|
||||
return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i);
|
||||
});
|
||||
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
|
||||
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
|
||||
m.replace_instruction(ins, c);
|
||||
}
|
||||
}
|
||||
|
||||
void remove_layout(module& m)
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "layout")
|
||||
continue;
|
||||
if(ins->get_shape() != ins->inputs().front()->get_shape())
|
||||
continue;
|
||||
m.replace_instruction(ins, ins->inputs().front());
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void layout_convolution::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
preserve_output_layout(mpm.get_module());
|
||||
transform_convolutions(mpm.get_module(), *this);
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
mpm.run_pass(eliminate_contiguous{"contiguous"});
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
remove_layout(mpm.get_module());
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
71
docker/rocm/migraphx/lexing.cpp
Normal file
71
docker/rocm/migraphx/lexing.cpp
Normal file
@ -0,0 +1,71 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/lexing.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::function<const char*(const char*, const char*)> lex_equal(const std::string& s)
|
||||
{
|
||||
return [=](const char* start, const char* end) {
|
||||
auto n = end - start;
|
||||
if(n < s.size())
|
||||
return start;
|
||||
if(std::equal(start, start + s.size(), s.data()))
|
||||
return start + s.size();
|
||||
return start;
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<std::string_view>
|
||||
tokenize(const char* start, const char* end, const std::vector<lexer>& lexers)
|
||||
{
|
||||
std::vector<std::string_view> result;
|
||||
while(start != end)
|
||||
{
|
||||
bool error = true;
|
||||
for(const auto& l : lexers)
|
||||
{
|
||||
const auto* next = l(start, end);
|
||||
if(next != start)
|
||||
{
|
||||
result.emplace_back(start, next - start);
|
||||
start = next;
|
||||
error = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if(error)
|
||||
{
|
||||
MIGRAPHX_THROW("TOKENIZE: no token found!");
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
105
docker/rocm/migraphx/load_save.cpp
Normal file
105
docker/rocm/migraphx/load_save.cpp
Normal file
@ -0,0 +1,105 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/load_save.hpp>
|
||||
#include <migraphx/file_buffer.hpp>
|
||||
#include <migraphx/json.hpp>
|
||||
#include <migraphx/msgpack.hpp>
|
||||
#include <fstream>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
program load(const std::string& filename, const file_options& options)
|
||||
{
|
||||
return load_buffer(read_buffer(filename), options);
|
||||
}
|
||||
program load_buffer(const std::vector<char>& buffer, const file_options& options)
|
||||
{
|
||||
return load_buffer(buffer.data(), buffer.size(), options);
|
||||
}
|
||||
program load_buffer(const char* buffer, std::size_t size, const file_options& options)
|
||||
{
|
||||
program p;
|
||||
if(options.format == "msgpack")
|
||||
{
|
||||
p.from_value(from_msgpack(buffer, size));
|
||||
}
|
||||
else if(options.format == "json")
|
||||
{
|
||||
p.from_value(from_json_string(buffer, size));
|
||||
}
|
||||
else
|
||||
{
|
||||
MIGRAPHX_THROW("Unknown format: " + options.format);
|
||||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
void save(const program& p, const std::string& filename, const file_options& options)
|
||||
{
|
||||
write_buffer(filename, save_buffer(p, options));
|
||||
}
|
||||
|
||||
// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs
|
||||
void print_miopen_warning(const program& p)
|
||||
{
|
||||
auto mods = p.get_modules();
|
||||
if(std::any_of(mods.begin(), mods.end(), [](const auto* m) {
|
||||
return std::any_of(m->begin(), m->end(), [](const instruction& i) {
|
||||
return i.name() == "gpu::miopen_fusion";
|
||||
});
|
||||
}))
|
||||
{
|
||||
std::cout << "[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
|
||||
"are not stored inside serialized MIGraphX program. Consider serializing with "
|
||||
"MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
|
||||
<< std::endl;
|
||||
;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<char> save_buffer(const program& p, const file_options& options)
|
||||
{
|
||||
value v = p.to_value();
|
||||
print_miopen_warning(p);
|
||||
std::vector<char> buffer;
|
||||
if(options.format == "msgpack")
|
||||
{
|
||||
buffer = to_msgpack(v);
|
||||
}
|
||||
else if(options.format == "json")
|
||||
{
|
||||
std::string s = to_json_string(v);
|
||||
buffer = std::vector<char>(s.begin(), s.end());
|
||||
}
|
||||
else
|
||||
{
|
||||
MIGRAPHX_THROW("Unknown format: " + options.format);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
73
docker/rocm/migraphx/make_op.cpp
Normal file
73
docker/rocm/migraphx/make_op.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
operation make_op(const std::string& name) { return load_op(name); }
|
||||
|
||||
template <class F>
|
||||
operation make_op_generic(const std::string& name, F for_each)
|
||||
{
|
||||
auto op = load_op(name);
|
||||
// Merge values
|
||||
value w = op.to_value();
|
||||
for_each([&](const auto& key, const auto& x) {
|
||||
if(not w.contains(key))
|
||||
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
|
||||
MIGRAPHX_THROW("No key '" + key + "' in " + name);
|
||||
w.at(key) = x;
|
||||
});
|
||||
op.from_value(w);
|
||||
return op;
|
||||
}
|
||||
|
||||
operation make_op(const std::string& name,
|
||||
const std::initializer_list<std::pair<std::string, value>>& v)
|
||||
{
|
||||
return make_op_generic(name, [&](auto f) {
|
||||
for(auto&& [key, x] : v)
|
||||
f(key, x);
|
||||
});
|
||||
}
|
||||
|
||||
operation make_op_from_value(const std::string& name, const value& v)
|
||||
{
|
||||
if(not(v.is_object() or (v.empty() and v.is_array())))
|
||||
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
|
||||
return make_op_generic(name, [&](auto f) {
|
||||
for(auto&& x : v)
|
||||
f(x.get_key(), x.without_key());
|
||||
});
|
||||
}
|
||||
|
||||
operation make_json_op(const std::string& name, const std::string& s)
|
||||
{
|
||||
return make_op(name, from_json_string(convert_to_json(s)));
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
372
docker/rocm/migraphx/memory_coloring.cpp
Normal file
372
docker/rocm/migraphx/memory_coloring.cpp
Normal file
@ -0,0 +1,372 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/memory_coloring.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/liveness.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/algorithm.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);
|
||||
|
||||
using instruction_set = std::unordered_set<instruction_ref>;
|
||||
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;
|
||||
|
||||
// This will build the conflict table or interference graph. This is
|
||||
// essentially a map from one instruction to a set of instruction that are
|
||||
// used together. Each instruction will be the allocation instruction.
|
||||
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
|
||||
{
|
||||
instruction_set_map conflict_table;
|
||||
liveness(m, [&](auto ins, auto live_set) {
|
||||
// Skip variables that aren't allocations
|
||||
if(ins->name() != allocation_op)
|
||||
return;
|
||||
// Skip zero allocations
|
||||
if(ins->get_shape().bytes() == 0)
|
||||
return;
|
||||
conflict_table[ins];
|
||||
for(auto i : live_set)
|
||||
{
|
||||
if(i == ins)
|
||||
continue;
|
||||
// Skip variables that aren't allocations
|
||||
if(i->name() != allocation_op)
|
||||
continue;
|
||||
// Skip zero allocations
|
||||
if(i->get_shape().bytes() == 0)
|
||||
continue;
|
||||
conflict_table[i].insert(ins);
|
||||
conflict_table[ins].insert(i);
|
||||
}
|
||||
});
|
||||
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
|
||||
return pp.second.count(pp.first) == 0;
|
||||
}));
|
||||
return conflict_table;
|
||||
}
|
||||
|
||||
// Check if intervals overlap
|
||||
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
|
||||
{
|
||||
return std::max(x.first, y.first) < std::min(x.second, y.second);
|
||||
}
|
||||
|
||||
struct allocation_segment
|
||||
{
|
||||
using segment = std::pair<std::size_t, std::size_t>;
|
||||
std::unordered_map<instruction_ref, segment> ins2segment;
|
||||
|
||||
const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }
|
||||
|
||||
const segment* get_segment(instruction_ref ins) const
|
||||
{
|
||||
auto it = ins2segment.find(ins);
|
||||
if(it == ins2segment.end())
|
||||
return nullptr;
|
||||
return &it->second;
|
||||
}
|
||||
|
||||
// Remove segment for an instruction
|
||||
void remove(instruction_ref ins)
|
||||
{
|
||||
auto it = ins2segment.find(ins);
|
||||
if(it != ins2segment.end())
|
||||
{
|
||||
ins2segment.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t max()
|
||||
{
|
||||
std::size_t n = 0;
|
||||
for(auto&& pp : ins2segment)
|
||||
{
|
||||
auto seg = pp.second;
|
||||
n = std::max(n, seg.second);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
template <class Iterator>
|
||||
static bool overlaps(Iterator first, Iterator last, const segment& s)
|
||||
{
|
||||
return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
|
||||
}
|
||||
|
||||
static bool overlaps(const std::set<segment>& segments, const segment& s)
|
||||
{
|
||||
return overlaps(segments.begin(), segments.end(), s);
|
||||
}
|
||||
|
||||
static auto find_gap(const std::set<segment>& segments, std::size_t n)
|
||||
{
|
||||
std::size_t max_end = 0;
|
||||
return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
|
||||
if(x.second < max_end)
|
||||
return false;
|
||||
max_end = x.second;
|
||||
if(is_overlap(x, y))
|
||||
return false;
|
||||
assert(y.first >= x.second);
|
||||
auto k = y.first - x.second;
|
||||
return (k >= n);
|
||||
});
|
||||
}
|
||||
|
||||
static std::size_t max_type_size(const shape& s)
|
||||
{
|
||||
return std::accumulate(
|
||||
s.sub_shapes().begin(),
|
||||
s.sub_shapes().end(),
|
||||
s.type_size(),
|
||||
[](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
|
||||
}
|
||||
|
||||
static std::size_t compute_alignment(instruction_ref ins)
|
||||
{
|
||||
auto alignment = max_type_size(ins->get_shape());
|
||||
// A rough estimate for the total number of elements
|
||||
auto n = ins->get_shape().bytes() / alignment;
|
||||
// Check for vectorized alignment
|
||||
if(n > 4)
|
||||
{
|
||||
auto d = n % 4;
|
||||
if(d == 0)
|
||||
alignment *= 4;
|
||||
if(d == 2)
|
||||
alignment *= 2;
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
|
||||
static segment
|
||||
next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
|
||||
{
|
||||
assert(ins->get_shape().bytes() > 0);
|
||||
// Compute alignment
|
||||
std::size_t n = 1 + (ins->get_shape().bytes() - 1) / alignment;
|
||||
assert(n > 0);
|
||||
std::size_t start = 0;
|
||||
// Insert at end if it cant fit at the begining
|
||||
if(segments.empty() or segments.begin()->first <= n)
|
||||
{
|
||||
auto it = find_gap(segments, n);
|
||||
if(it == segments.end())
|
||||
it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
|
||||
return x.second < y.second;
|
||||
});
|
||||
if(it != segments.end())
|
||||
start = it->second;
|
||||
}
|
||||
auto s = segment{start, start + n};
|
||||
assert(not overlaps(segments, s));
|
||||
segments.insert(s);
|
||||
return s;
|
||||
}
|
||||
|
||||
static std::unordered_map<instruction_ref, int>
|
||||
create_allocation_index(const module& m, const instruction_set_map& conflict_table)
|
||||
{
|
||||
std::unordered_map<instruction_ref, int> result;
|
||||
int i = 0;
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(not contains(conflict_table, ins))
|
||||
continue;
|
||||
result[ins] = i++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Build the allocation_color class from the conflict_table
|
||||
static allocation_segment
|
||||
build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
|
||||
{
|
||||
allocation_segment as{};
|
||||
std::vector<instruction_ref> conflict_queue;
|
||||
// Add all allocations to the conflict_queue
|
||||
std::transform(conflict_table.begin(),
|
||||
conflict_table.end(),
|
||||
std::back_inserter(conflict_queue),
|
||||
[](auto&& pp) { return pp.first; });
|
||||
|
||||
auto alloc_index = create_allocation_index(m, conflict_table);
|
||||
|
||||
// Sort the conflict queue so we process the allocation with the most
|
||||
// number of adjacent allocations first
|
||||
std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
|
||||
return std::make_tuple(
|
||||
conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
|
||||
}));
|
||||
// Process the conflict_queue, we refer to the current allocation as
|
||||
// the parent and the adjacent allocations as children
|
||||
for(auto parent : conflict_queue)
|
||||
{
|
||||
// Sort children by size
|
||||
std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
|
||||
conflict_table.at(parent).end());
|
||||
std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
|
||||
return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
|
||||
}));
|
||||
assert(not contains(children, parent));
|
||||
// This set is to track the segments already processed
|
||||
std::set<segment> segments;
|
||||
// Add all segments for the children to the segments already processed
|
||||
transform_if(
|
||||
children.begin(),
|
||||
children.end(),
|
||||
std::inserter(segments, segments.begin()),
|
||||
[&](auto child) { return as.get_segment(child); },
|
||||
[&](auto child) { return *as.get_segment(child); });
|
||||
|
||||
assert(as.get_segment(parent) == nullptr);
|
||||
as.add_segment(parent, next_segment(segments, parent, alignment));
|
||||
}
|
||||
// Reduce the number of segments
|
||||
for(std::size_t n = 0; n < 3; n++)
|
||||
{
|
||||
for(auto parent : conflict_queue)
|
||||
{
|
||||
auto children = conflict_table.at(parent);
|
||||
// This set is to track the segments already processed
|
||||
std::set<segment> segments;
|
||||
// Add all segments for the children to the segments already processed
|
||||
transform_if(
|
||||
children.begin(),
|
||||
children.end(),
|
||||
std::inserter(segments, segments.begin()),
|
||||
[&](auto child) { return as.get_segment(child); },
|
||||
[&](auto child) { return *as.get_segment(child); });
|
||||
// Get the segment for the parent
|
||||
const auto* parent_segment = as.get_segment(parent);
|
||||
assert(parent_segment != nullptr);
|
||||
|
||||
auto s = next_segment(segments, parent, alignment);
|
||||
if(s != *parent_segment and s.second <= as.max())
|
||||
{
|
||||
as.add_segment(parent, s);
|
||||
}
|
||||
}
|
||||
}
|
||||
return as;
|
||||
}
|
||||
};
|
||||
|
||||
static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
|
||||
{
|
||||
std::size_t alignment = 1;
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != allocation_op)
|
||||
continue;
|
||||
alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
|
||||
void memory_coloring::apply(module& m) const
|
||||
{
|
||||
const std::size_t alignment = find_max_alignment(m, allocation_op);
|
||||
auto conflict_table = build_conflict_table(m, allocation_op);
|
||||
auto as = allocation_segment::build(m, conflict_table, alignment);
|
||||
|
||||
// All allocations should have a segment
|
||||
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
|
||||
return as.get_segment(pp.first);
|
||||
}));
|
||||
|
||||
// Adjacent allocations should not have overlapping segments
|
||||
assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
|
||||
auto* x = as.get_segment(pp.first);
|
||||
return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
|
||||
auto* y = as.get_segment(ins);
|
||||
assert(x and y);
|
||||
return is_overlap(*x, *y);
|
||||
});
|
||||
}));
|
||||
|
||||
// Print out segments
|
||||
if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
|
||||
{
|
||||
for(auto&& pp : conflict_table)
|
||||
{
|
||||
std::cout << "------- conflict -------" << std::endl;
|
||||
auto s1 = as.ins2segment.at(pp.first);
|
||||
std::cout << s1.first << ", " << s1.second << ": ";
|
||||
m.debug_print(pp.first);
|
||||
for(auto ins : pp.second)
|
||||
{
|
||||
auto s2 = as.ins2segment.at(ins);
|
||||
std::cout << s2.first << ", " << s2.second << ": ";
|
||||
m.debug_print(ins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Total memory
|
||||
std::size_t n = as.max() * alignment;
|
||||
|
||||
// Replace allocations
|
||||
auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
|
||||
for(auto&& [ins, seg] : as.ins2segment)
|
||||
{
|
||||
assert(ins->name() == allocation_op);
|
||||
auto s = ins->get_shape();
|
||||
std::size_t offset = seg.first * alignment;
|
||||
assert(offset < n);
|
||||
m.replace_instruction(
|
||||
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
|
||||
}
|
||||
|
||||
// Replace zero allocation
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != allocation_op)
|
||||
continue;
|
||||
assert(ins->get_shape().bytes() == 0);
|
||||
m.replace_instruction(
|
||||
ins, make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", 0}}), mem);
|
||||
}
|
||||
|
||||
// Remove scratch parameter if its not used
|
||||
if(mem->outputs().empty())
|
||||
{
|
||||
m.remove_instruction(mem);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
1544
docker/rocm/migraphx/module.cpp
Normal file
1544
docker/rocm/migraphx/module.cpp
Normal file
File diff suppressed because it is too large
Load Diff
256
docker/rocm/migraphx/msgpack.cpp
Normal file
256
docker/rocm/migraphx/msgpack.cpp
Normal file
@ -0,0 +1,256 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/msgpack.hpp>
|
||||
#include <migraphx/serialize.hpp>
|
||||
#include <msgpack.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
// Leave an extra byte for error checking
|
||||
constexpr std::size_t msgpack_size_limit = std::numeric_limits<uint32_t>::max() - 1;
|
||||
|
||||
template <class Range>
|
||||
std::size_t msgpack_chunk_size(const Range& r)
|
||||
{
|
||||
return 1 + (r.size() - 1) / msgpack_size_limit;
|
||||
}
|
||||
|
||||
template <class Iterator, class F>
|
||||
void msgpack_chunk_for_each(Iterator start, Iterator last, F f)
|
||||
{
|
||||
while(std::distance(start, last) > msgpack_size_limit)
|
||||
{
|
||||
auto next = std::next(start, msgpack_size_limit);
|
||||
f(start, next);
|
||||
start = next;
|
||||
}
|
||||
f(start, last);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
|
||||
namespace msgpack {
|
||||
MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
|
||||
{
|
||||
namespace adaptor {
|
||||
|
||||
template <>
|
||||
struct convert<migraphx::value>
|
||||
{
|
||||
const msgpack::object& operator()(const msgpack::object& o, migraphx::value& v) const
|
||||
{
|
||||
switch(o.type)
|
||||
{
|
||||
case msgpack::type::NIL: {
|
||||
v = nullptr;
|
||||
break;
|
||||
}
|
||||
case msgpack::type::BOOLEAN: {
|
||||
v = o.as<bool>();
|
||||
break;
|
||||
}
|
||||
case msgpack::type::POSITIVE_INTEGER: {
|
||||
v = o.as<std::uint64_t>();
|
||||
break;
|
||||
}
|
||||
case msgpack::type::NEGATIVE_INTEGER: {
|
||||
v = o.as<std::int64_t>();
|
||||
break;
|
||||
}
|
||||
case msgpack::type::FLOAT32:
|
||||
case msgpack::type::FLOAT64: {
|
||||
v = o.as<double>();
|
||||
break;
|
||||
}
|
||||
case msgpack::type::STR: {
|
||||
v = o.as<std::string>();
|
||||
break;
|
||||
}
|
||||
case msgpack::type::BIN: {
|
||||
// For backwards compatibility
|
||||
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
|
||||
break;
|
||||
}
|
||||
case msgpack::type::ARRAY: {
|
||||
if(o.via.array.size != 0 and o.via.array.ptr->type == msgpack::type::BIN)
|
||||
{
|
||||
auto bin = migraphx::value::binary{};
|
||||
std::for_each(
|
||||
o.via.array.ptr,
|
||||
o.via.array.ptr + o.via.array.size,
|
||||
[&](const msgpack::object& so) {
|
||||
bin.insert(bin.end(), so.via.bin.ptr, so.via.bin.ptr + so.via.bin.size);
|
||||
});
|
||||
v = bin;
|
||||
}
|
||||
else
|
||||
{
|
||||
migraphx::value r = migraphx::value::array{};
|
||||
std::for_each(
|
||||
o.via.array.ptr,
|
||||
o.via.array.ptr + o.via.array.size,
|
||||
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
|
||||
v = r;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case msgpack::type::MAP: {
|
||||
migraphx::value r = migraphx::value::object{};
|
||||
std::for_each(o.via.map.ptr,
|
||||
o.via.map.ptr + o.via.map.size,
|
||||
[&](const msgpack::object_kv& p) {
|
||||
r[p.key.as<std::string>()] = p.val.as<migraphx::value>();
|
||||
});
|
||||
v = r;
|
||||
break;
|
||||
}
|
||||
case msgpack::type::EXT: {
|
||||
MIGRAPHX_THROW("msgpack EXT type not supported.");
|
||||
}
|
||||
}
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct pack<migraphx::value::binary>
|
||||
{
|
||||
template <class Stream>
|
||||
packer<Stream>& operator()(msgpack::packer<Stream>& o,
|
||||
const migraphx::value::binary& x) const
|
||||
{
|
||||
const auto* data = reinterpret_cast<const char*>(x.data());
|
||||
auto size = x.size();
|
||||
o.pack_array(migraphx::msgpack_chunk_size(x));
|
||||
migraphx::msgpack_chunk_for_each(
|
||||
data, data + size, [&](const char* start, const char* last) {
|
||||
o.pack_bin(last - start);
|
||||
o.pack_bin_body(start, last - start);
|
||||
});
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct pack<migraphx::value>
|
||||
{
|
||||
template <class Stream>
|
||||
void write(msgpack::packer<Stream>& o, const std::nullptr_t&) const
|
||||
{
|
||||
o.pack_nil();
|
||||
}
|
||||
template <class Stream, class T>
|
||||
void write(msgpack::packer<Stream>& o, const T& x) const
|
||||
{
|
||||
o.pack(x);
|
||||
}
|
||||
template <class Stream>
|
||||
void write(msgpack::packer<Stream>& o, const std::vector<migraphx::value>& v) const
|
||||
{
|
||||
if(v.empty())
|
||||
{
|
||||
o.pack_array(0);
|
||||
return;
|
||||
}
|
||||
if(v.size() > migraphx::msgpack_size_limit)
|
||||
MIGRAPHX_THROW("Size is too large for msgpack");
|
||||
if(not v.front().get_key().empty())
|
||||
{
|
||||
o.pack_map(v.size());
|
||||
for(auto&& x : v)
|
||||
{
|
||||
o.pack(x.get_key());
|
||||
o.pack(x.without_key());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
o.pack_array(v.size());
|
||||
for(auto&& x : v)
|
||||
{
|
||||
o.pack(x);
|
||||
}
|
||||
}
|
||||
}
|
||||
template <class Stream>
|
||||
packer<Stream>& operator()(msgpack::packer<Stream>& o, const migraphx::value& v) const
|
||||
{
|
||||
v.visit_value([&](auto&& x) { this->write(o, x); });
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace adaptor
|
||||
} // MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
|
||||
} // namespace msgpack
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct vector_stream
|
||||
{
|
||||
std::vector<char> buffer{};
|
||||
vector_stream& write(const char* b, std::size_t n)
|
||||
{
|
||||
buffer.insert(buffer.end(), b, b + n);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
struct writer_stream
|
||||
{
|
||||
std::function<void(const char*, std::size_t)> writer;
|
||||
writer_stream& write(const char* b, std::size_t n)
|
||||
{
|
||||
writer(b, n);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer)
|
||||
{
|
||||
writer_stream ws{std::move(writer)};
|
||||
msgpack::pack(ws, v);
|
||||
}
|
||||
|
||||
std::vector<char> to_msgpack(const value& v)
|
||||
{
|
||||
vector_stream vs;
|
||||
msgpack::pack(vs, v);
|
||||
return vs.buffer;
|
||||
}
|
||||
value from_msgpack(const char* buffer, std::size_t size)
|
||||
{
|
||||
msgpack::object_handle oh = msgpack::unpack(buffer, size);
|
||||
return oh.get().as<value>();
|
||||
}
|
||||
value from_msgpack(const std::vector<char>& buffer)
|
||||
{
|
||||
return from_msgpack(buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
278
docker/rocm/migraphx/netron_output.cpp
Normal file
278
docker/rocm/migraphx/netron_output.cpp
Normal file
@ -0,0 +1,278 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/netron_output.hpp>
|
||||
#include <migraphx/json.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/serialize.hpp>
|
||||
#include <migraphx/base64.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
namespace {
|
||||
|
||||
// from https://onnx.ai/onnx/intro/concepts.html
|
||||
int get_onnx_type(shape::type_t s_type)
|
||||
{
|
||||
switch(s_type)
|
||||
{
|
||||
case shape::float_type: return 1;
|
||||
case shape::uint8_type: return 2;
|
||||
case shape::int8_type: return 3;
|
||||
case shape::uint16_type: return 4;
|
||||
case shape::int16_type: return 5;
|
||||
case shape::int32_type: return 6;
|
||||
case shape::int64_type: return 7;
|
||||
case shape::bool_type: return 9;
|
||||
case shape::half_type: return 10;
|
||||
case shape::double_type: return 11;
|
||||
case shape::uint32_type: return 12;
|
||||
case shape::uint64_type: return 13;
|
||||
case shape::bf16_type: return 16;
|
||||
case shape::fp8e4m3fn_type: return 17;
|
||||
case shape::fp8e4m3fnuz_type: return 18;
|
||||
case shape::fp8e5m2_type: return 19;
|
||||
case shape::fp8e5m2fnuz_type: return 20;
|
||||
case shape::tuple_type: return 0;
|
||||
}
|
||||
MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported");
|
||||
}
|
||||
|
||||
auto make_attribute(const migraphx::value& val)
|
||||
{
|
||||
value attribute = value(std::unordered_map<std::string, value>());
|
||||
attribute["name"] = val.get_key();
|
||||
auto val_string = val.to<std::string>();
|
||||
std::string sub_str = val.get_key() + ":";
|
||||
auto find_key = val_string.find(sub_str);
|
||||
if(find_key != std::string::npos)
|
||||
{
|
||||
val_string = val_string.substr(find_key + sub_str.length() + 1);
|
||||
}
|
||||
// TODO: doesn't work for some reason with Netron now
|
||||
// attribute["s"] = base64_encode(val_string);
|
||||
// attribute["type"] = "STRING";
|
||||
attribute["docString"] = val_string;
|
||||
return attribute;
|
||||
}
|
||||
|
||||
/// Returns a value with the JSON structure needed for a node
|
||||
auto make_onnx_json_node(instruction_ref ins,
|
||||
std::unordered_map<instruction_ref, std::string> ins_uids)
|
||||
{
|
||||
value node;
|
||||
// TODO add support for module inputs
|
||||
value input_arr = value({});
|
||||
for(instruction_ref input_ins : ins->inputs())
|
||||
{
|
||||
auto name = input_ins->name();
|
||||
if(name == "@literal" or name == "@param")
|
||||
{
|
||||
input_arr.push_back(ins_uids.at(input_ins));
|
||||
}
|
||||
// TODO make a better process for handling nodes to ignore
|
||||
else if(name.find("hip::hip_allocate_memory") != std::string::npos)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
input_arr.push_back(ins_uids.at(input_ins) + "->" + ins_uids.at(ins));
|
||||
}
|
||||
}
|
||||
value output_arr = value({});
|
||||
for(instruction_ref output_ins : ins->outputs())
|
||||
{
|
||||
if(output_ins->name() == "@return")
|
||||
{
|
||||
output_arr.push_back(ins_uids.at(output_ins));
|
||||
}
|
||||
else
|
||||
{
|
||||
output_arr.push_back(ins_uids.at(ins) + "->" + ins_uids.at(output_ins));
|
||||
}
|
||||
}
|
||||
node["input"] = input_arr;
|
||||
node["output"] = output_arr;
|
||||
node["name"] = ins_uids.at(ins);
|
||||
node["opType"] = ins->name();
|
||||
value op_attribute_arr = value({});
|
||||
auto op_value = ins->get_operator().to_value();
|
||||
std::for_each(op_value.begin(), op_value.end(), [&](auto v) {
|
||||
const std::string& attr_key = v.get_key();
|
||||
if(v.is_binary() or attr_key == "code_object")
|
||||
{
|
||||
return;
|
||||
}
|
||||
else if(attr_key == "symbol_name" or attr_key == "name")
|
||||
{
|
||||
node["opType"] = migraphx::from_value<std::string>(v);
|
||||
}
|
||||
else
|
||||
{
|
||||
op_attribute_arr.push_back(make_attribute(v));
|
||||
}
|
||||
});
|
||||
node["attribute"] = op_attribute_arr;
|
||||
return node;
|
||||
}
|
||||
|
||||
// ONNX graph constant data called "initializer"
|
||||
auto make_onnx_json_literal(instruction_ref ins,
|
||||
std::unordered_map<instruction_ref, std::string> ins_uids)
|
||||
{
|
||||
value lit;
|
||||
lit["dims"] = ins->get_shape().lens();
|
||||
lit["dataType"] = get_onnx_type(ins->get_shape().type());
|
||||
lit["name"] = ins_uids.at(ins);
|
||||
// ignoring literal data, setting to "NULL" in base64
|
||||
lit["rawData"] = "TlVMTA==";
|
||||
return lit;
|
||||
}
|
||||
|
||||
// TODO handle dynamic shapes
|
||||
// TODO handle subshapes
|
||||
auto make_onnx_json_shape(const shape& s)
|
||||
{
|
||||
value ret;
|
||||
value dim = value({});
|
||||
for(std::size_t len : s.lens())
|
||||
{
|
||||
// cppcheck-suppress useStlAlgorithm
|
||||
dim.push_back({{"dimValue", len}});
|
||||
}
|
||||
ret["dim"] = dim;
|
||||
return ret;
|
||||
}
|
||||
|
||||
// ONNX graph edges called "valueType"
|
||||
auto make_onnx_json_edge(instruction_ref ins,
|
||||
instruction_ref out_ins,
|
||||
std::unordered_map<instruction_ref, std::string> ins_uids)
|
||||
{
|
||||
value ret;
|
||||
shape ins_shape = ins->get_shape();
|
||||
ret["name"] = ins_uids.at(ins) + "->" + ins_uids.at(out_ins);
|
||||
value type = {{"tensorType",
|
||||
{{"elemType", get_onnx_type(ins_shape.type())},
|
||||
{"shape", make_onnx_json_shape(ins_shape)}}}};
|
||||
ret["type"] = type;
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto make_onnx_json_in_out(instruction_ref ins,
|
||||
std::unordered_map<instruction_ref, std::string> ins_uids)
|
||||
{
|
||||
value ret;
|
||||
shape ins_shape = ins->get_shape();
|
||||
ret["name"] = ins_uids.at(ins);
|
||||
value type = {{"tensorType",
|
||||
{{"elemType", get_onnx_type(ins_shape.type())},
|
||||
{"shape", make_onnx_json_shape(ins_shape)}}}};
|
||||
ret["type"] = type;
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::unordered_map<instruction_ref, std::string> make_ins_uids(const module& mod)
|
||||
{
|
||||
std::unordered_map<instruction_ref, std::string> ret;
|
||||
int count = 0;
|
||||
for(auto ins : iterator_for(mod))
|
||||
{
|
||||
std::string var_name;
|
||||
var_name = mod.name() + ":";
|
||||
var_name.append(ins->name() + ":");
|
||||
var_name.append("@" + std::to_string(count));
|
||||
count++;
|
||||
ret.emplace(ins, var_name);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
value make_graph(const module* mod)
|
||||
{
|
||||
value graph = {{"node", value({})},
|
||||
{"initializer", value({})},
|
||||
{"input", value({})},
|
||||
{"output", value({})},
|
||||
{"valueInfo", value({})}};
|
||||
auto ins_uids = make_ins_uids(*mod);
|
||||
for(auto ins = mod->begin(); ins != mod->end(); ++ins)
|
||||
{
|
||||
const auto& name = ins->name();
|
||||
if(name == "@literal")
|
||||
{
|
||||
graph["initializer"].push_back(make_onnx_json_literal(ins, ins_uids));
|
||||
}
|
||||
else if(name == "@param")
|
||||
{
|
||||
graph["input"].push_back(make_onnx_json_in_out(ins, ins_uids));
|
||||
}
|
||||
else if(name == "@return")
|
||||
{
|
||||
graph["output"].push_back(make_onnx_json_in_out(ins, ins_uids));
|
||||
}
|
||||
else if(name.find("hip::hip_allocate_memory") != std::string::npos)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
graph["node"].push_back(make_onnx_json_node(ins, ins_uids));
|
||||
const auto& outputs = ins->outputs();
|
||||
for(auto out_ins : outputs)
|
||||
{
|
||||
if(out_ins->name() != "@return")
|
||||
{
|
||||
graph["valueInfo"].push_back(make_onnx_json_edge(ins, out_ins, ins_uids));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string make_netron_output(const program& prog)
|
||||
{
|
||||
value output;
|
||||
auto prog_value = prog.to_value();
|
||||
// ONNX IR version 6
|
||||
// TODO: investigate sure how this affects things
|
||||
output["irVersion"] = 6;
|
||||
output["producerName"] = "AMDMIGraphX";
|
||||
output["producerVersion"] = prog_value.at("migraphx_version").to<std::string>();
|
||||
for(auto& mod : prog.get_modules())
|
||||
{
|
||||
auto graph = make_graph(mod);
|
||||
output["graph"] = graph;
|
||||
}
|
||||
return to_pretty_json_string(output, 4);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
283
docker/rocm/migraphx/normalize_attributes.cpp
Normal file
283
docker/rocm/migraphx/normalize_attributes.cpp
Normal file
@ -0,0 +1,283 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/normalize_attributes.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/op/normalize_attribute.hpp>
|
||||
#include <migraphx/op/common.hpp>
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* Parameters:
|
||||
* vec: the vector attribute to normalize
|
||||
* axes: the operator's axes attribute if it exists, empty otherwise
|
||||
* val: the normalize_axes key and options. Ex: normalize["axes"] =
|
||||
* value::array{normalize_attribute::include_min};
|
||||
* input_shape: input shape passed when calling
|
||||
* normalize_attributes(op&, input_shape)
|
||||
*
|
||||
* See normalize_attribute.hpp for explaining the options.
|
||||
*/
|
||||
template <class Message>
|
||||
auto tune_attribute(const std::vector<int64_t>& vec,
|
||||
const std::vector<int64_t>& axes,
|
||||
const value& val,
|
||||
const shape& input_shape,
|
||||
Message m)
|
||||
{
|
||||
std::vector<int64_t> result(vec);
|
||||
if(result.empty())
|
||||
{
|
||||
return result;
|
||||
};
|
||||
int64_t n_rank = input_shape.ndim();
|
||||
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
|
||||
if(contains(vec_attrs, op::normalize_attribute::use_output))
|
||||
{
|
||||
n_rank = n_rank + vec.size();
|
||||
}
|
||||
|
||||
std::vector<int64_t> max_vals(vec.size(), n_rank);
|
||||
|
||||
if(contains(vec_attrs, op::normalize_attribute::use_len))
|
||||
{
|
||||
if(input_shape.dynamic())
|
||||
{
|
||||
// return the unchanged `vec` if the dynamic_dimensions at `axes` are not fixed
|
||||
if(std::any_of(axes.begin(), axes.end(), [&](auto ax) {
|
||||
return not input_shape.dyn_dims().at(ax).is_fixed();
|
||||
}))
|
||||
{
|
||||
return vec;
|
||||
}
|
||||
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
|
||||
return input_shape.dyn_dims().at(i).max;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
|
||||
return input_shape.lens().at(i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if(contains(vec_attrs, op::normalize_attribute::clip_max))
|
||||
{
|
||||
if(contains(vec_attrs, op::normalize_attribute::include_max))
|
||||
{
|
||||
std::transform(result.begin(),
|
||||
result.end(),
|
||||
max_vals.begin(),
|
||||
result.begin(),
|
||||
[](auto v, auto mv) { return v > mv ? mv : v; });
|
||||
}
|
||||
else
|
||||
{
|
||||
std::transform(result.begin(),
|
||||
result.end(),
|
||||
max_vals.begin(),
|
||||
result.begin(),
|
||||
[](auto v, auto mv) { return v >= mv ? mv - 1 : v; });
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(contains(vec_attrs, op::normalize_attribute::include_max))
|
||||
{
|
||||
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
|
||||
{
|
||||
MIGRAPHX_THROW(m() + "value out of range!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
|
||||
{
|
||||
MIGRAPHX_THROW(m() + "value out of range!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> min_vals = max_vals;
|
||||
std::transform(min_vals.begin(), min_vals.end(), min_vals.begin(), [](auto v) { return -v; });
|
||||
if(contains(vec_attrs, op::normalize_attribute::clip_min))
|
||||
{
|
||||
if(contains(vec_attrs, op::normalize_attribute::include_min))
|
||||
{
|
||||
std::transform(result.begin(),
|
||||
result.end(),
|
||||
min_vals.begin(),
|
||||
result.begin(),
|
||||
[](auto v, auto mv) { return v < mv ? mv : v; });
|
||||
}
|
||||
else
|
||||
{
|
||||
std::transform(result.begin(),
|
||||
result.end(),
|
||||
min_vals.begin(),
|
||||
result.begin(),
|
||||
[](auto v, auto mv) { return v < mv + 1 ? mv + 1 : v; });
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(contains(vec_attrs, op::normalize_attribute::include_min))
|
||||
{
|
||||
if(not std::equal(
|
||||
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
|
||||
{
|
||||
MIGRAPHX_THROW(m() + "attribute out of range!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
|
||||
{
|
||||
MIGRAPHX_THROW(m() + "attribute out of range!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::transform(
|
||||
result.begin(), result.end(), max_vals.begin(), result.begin(), [](auto v, auto mv) {
|
||||
return v < 0 ? v + mv : v;
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
auto tune_pad_attribute(const value& val)
|
||||
{
|
||||
|
||||
std::vector<size_t> vec_attrs = val.to_vector<size_t>();
|
||||
std::vector<size_t> result(vec_attrs.begin(), vec_attrs.end());
|
||||
std::copy(vec_attrs.begin(), vec_attrs.end(), std::back_inserter(result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Assumptions:
|
||||
* Dimensions to pad start from the third dimension (index 2).
|
||||
* Called by compute_shape_op() with the shape of the first input.
|
||||
*/
|
||||
bool normalize_attributes(operation& op, const shape& input_shape)
|
||||
{
|
||||
bool tuned = false;
|
||||
auto attrs = op.attributes();
|
||||
auto val = op.to_value();
|
||||
if(attrs.contains("normalize_padding"))
|
||||
{
|
||||
bool use_auto_padding =
|
||||
(val.contains("padding_mode") and
|
||||
(val.at("padding_mode").to<int>() != migraphx::op::padding_mode_t::default_));
|
||||
if(not use_auto_padding)
|
||||
{
|
||||
auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
|
||||
auto padding_size = padding.size();
|
||||
auto padding_start = 2;
|
||||
if(padding_size == 2 * (input_shape.ndim() - padding_start))
|
||||
tuned = true;
|
||||
else if(padding_size != (input_shape.ndim() - padding_start))
|
||||
{
|
||||
MIGRAPHX_THROW("normalize_attributes: inconsistent padding vector size ");
|
||||
}
|
||||
else
|
||||
{
|
||||
auto result = tune_pad_attribute(padding);
|
||||
val["padding"] = result;
|
||||
op.from_value(val);
|
||||
tuned = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(not attrs.contains("normalize_axes"))
|
||||
{
|
||||
return tuned;
|
||||
}
|
||||
|
||||
auto attr_v = attrs.at("normalize_axes").without_key();
|
||||
for(const auto& rv : attr_v)
|
||||
{
|
||||
const auto& key = rv.get_key();
|
||||
if(val.contains(key))
|
||||
{
|
||||
auto message = [&] { return op.name() + ": " + key + ": "; };
|
||||
auto vv = val.at(key).without_key();
|
||||
if(vv.is_array())
|
||||
{
|
||||
std::vector<int64_t> axes;
|
||||
if(val.contains("axes"))
|
||||
{
|
||||
axes = val.at("axes").without_key().to_vector<int64_t>();
|
||||
}
|
||||
auto vec = vv.to_vector<int64_t>();
|
||||
auto result = tune_attribute(vec, axes, rv.without_key(), input_shape, message);
|
||||
val[key] = result;
|
||||
op.from_value(val);
|
||||
val = op.to_value();
|
||||
tuned = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto num = vv.to<int64_t>();
|
||||
auto result = tune_attribute({num}, {num}, rv.without_key(), input_shape, message);
|
||||
val[key] = result.front();
|
||||
op.from_value(val);
|
||||
val = op.to_value();
|
||||
tuned = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
MIGRAPHX_THROW("NORMALIZE_ATTR : op " + op.name() + " attribute \"" + key +
|
||||
"\" not exist!");
|
||||
}
|
||||
}
|
||||
|
||||
return tuned;
|
||||
}
|
||||
|
||||
std::vector<int64_t> normalize_axes(const std::vector<int64_t>& axes,
|
||||
const shape& input_shape,
|
||||
const value& attr_val,
|
||||
const std::string& prefix)
|
||||
{
|
||||
return tune_attribute(axes, {}, attr_val, input_shape, [&] { return prefix; });
|
||||
}
|
||||
|
||||
std::vector<int64_t> normalize_indices(const std::vector<int64_t>& indices,
|
||||
const std::vector<int64_t>& axes,
|
||||
const shape& input_shape,
|
||||
const value& attr_val,
|
||||
const std::string& prefix)
|
||||
{
|
||||
return tune_attribute(indices, axes, attr_val, input_shape, [&] { return prefix; });
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
57
docker/rocm/migraphx/normalize_ops.cpp
Normal file
57
docker/rocm/migraphx/normalize_ops.cpp
Normal file
@ -0,0 +1,57 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <unordered_set>
|
||||
#include <migraphx/normalize_attributes.hpp>
|
||||
#include <migraphx/normalize_ops.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/auto_any_cast.hpp>
|
||||
#include <migraphx/value.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void normalize_ops::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
auto inputs = ins->inputs();
|
||||
if(inputs.empty())
|
||||
continue;
|
||||
|
||||
auto s = inputs[0]->get_shape();
|
||||
migraphx::operation tuned_op = ins->get_operator();
|
||||
if(normalize_attributes(tuned_op, s))
|
||||
{
|
||||
m.replace_instruction(ins, tuned_op, inputs);
|
||||
ins->set_normalized();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
55
docker/rocm/migraphx/op_enums.cpp
Normal file
55
docker/rocm/migraphx/op_enums.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
//
|
||||
// Supporting functions for enum values used in operator parameters.
|
||||
// These values are declared as "enum class" and should include << streaming operators
|
||||
// to be able to write their values in human-readable format so users can
|
||||
// save and edit model files.
|
||||
//
|
||||
#include <sstream>
|
||||
#include <migraphx/op/common.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
namespace op {
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, pooling_mode v)
|
||||
{
|
||||
// the strings for the enum are the same as the values used for onnx parsing
|
||||
// but this enum is not onnx-specific: strings must be converted when parsing tf
|
||||
static const std::vector<std::string> pooling_mode_str = {"average", "max", "lpnorm"};
|
||||
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
|
||||
return os;
|
||||
}
|
||||
std::ostream& operator<<(std::ostream& os, rnn_direction v)
|
||||
{
|
||||
static const std::vector<std::string> rnn_direction_str = {
|
||||
"forward", "reverse", "bidirectional"};
|
||||
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
41
docker/rocm/migraphx/operation.cpp
Normal file
41
docker/rocm/migraphx/operation.cpp
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void migraphx_to_value(value& v, const operation& op)
|
||||
{
|
||||
v["name"] = op.name();
|
||||
v["operator"] = op.to_value();
|
||||
}
|
||||
void migraphx_from_value(const value& v, operation& op)
|
||||
{
|
||||
op = make_op(v.at("name").to<std::string>(), v.at("operator"));
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
55
docker/rocm/migraphx/optimize_module.cpp
Normal file
55
docker/rocm/migraphx/optimize_module.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/optimize_module.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/simplify_reshapes.hpp>
|
||||
#include <migraphx/simplify_algebra.hpp>
|
||||
#include <migraphx/eliminate_common_subexpression.hpp>
|
||||
#include <migraphx/eliminate_convert.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/propagate_constant.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void optimize_module::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
mpm.get_module().repeat_while_changes(2, [&] {
|
||||
// loop to further optimize after initial transformations
|
||||
mpm.get_module().repeat_while_changes(4, [&] {
|
||||
mpm.run_pass(simplify_reshapes{});
|
||||
mpm.run_pass(eliminate_convert{});
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
mpm.run_pass(simplify_algebra{});
|
||||
});
|
||||
mpm.run_pass(eliminate_common_subexpression{});
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
mpm.run_pass(propagate_constant{propagate_constant_skip_ops});
|
||||
mpm.run_pass(dead_code_elimination{});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
154
docker/rocm/migraphx/pad_calc.cpp
Normal file
154
docker/rocm/migraphx/pad_calc.cpp
Normal file
@ -0,0 +1,154 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/pad_calc.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void calculate_padding(int64_t idx,
|
||||
std::vector<int64_t>& pads,
|
||||
int64_t input_dim,
|
||||
int64_t stride,
|
||||
int64_t dilation,
|
||||
int64_t weight_dim,
|
||||
bool is_same_upper)
|
||||
{
|
||||
int64_t output_dim = (input_dim + stride - 1) / stride; // round up result
|
||||
int64_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
|
||||
int64_t pad =
|
||||
std::max(static_cast<int64_t>(0), (output_dim - 1) * stride + new_weight_dim - input_dim);
|
||||
auto pad_ndims = pads.size() / 2;
|
||||
|
||||
if(is_same_upper)
|
||||
{
|
||||
pads[idx] = pad / 2;
|
||||
pads[idx + pad_ndims] = pad - pad / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
pads[idx + pad_ndims] = pad / 2;
|
||||
pads[idx] = pad - pad / 2;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Given the input array dimensions; kernel (wei_lens); strides; and dilations,
|
||||
* calculate the padding value in each dimension.
|
||||
*
|
||||
*/
|
||||
std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
|
||||
const std::vector<std::size_t>& wei_lens,
|
||||
const std::vector<std::size_t>& strides,
|
||||
const std::vector<std::size_t>& dilations,
|
||||
bool use_upper)
|
||||
{
|
||||
std::vector<std::size_t> padding;
|
||||
assert(input_lens.size() >= 3);
|
||||
assert(input_lens.size() == wei_lens.size());
|
||||
std::size_t num_spatial_dims = input_lens.size() - 2;
|
||||
padding.resize(2 * num_spatial_dims);
|
||||
for(std::size_t i = 0; i < num_spatial_dims; i++)
|
||||
{
|
||||
std::ptrdiff_t input_dim = input_lens[i + 2];
|
||||
std::ptrdiff_t stride = strides[i];
|
||||
std::ptrdiff_t weight_dim = wei_lens[i + 2];
|
||||
std::ptrdiff_t dilation = dilations[i];
|
||||
std::ptrdiff_t output_dim = (input_dim + stride - 1) / stride; // round up result
|
||||
std::ptrdiff_t new_weight_dim = weight_dim + (weight_dim - 1) * (dilation - 1);
|
||||
std::size_t pad = std::max(static_cast<std::ptrdiff_t>(0),
|
||||
(output_dim - 1) * stride + new_weight_dim - input_dim);
|
||||
auto pad_ndims = padding.size() / 2;
|
||||
|
||||
if(use_upper)
|
||||
{
|
||||
padding[i] = pad / 2;
|
||||
padding[i + pad_ndims] = pad - pad / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
padding[i + pad_ndims] = pad / 2;
|
||||
padding[i] = pad - pad / 2;
|
||||
}
|
||||
}
|
||||
return padding;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the correct output shape for a convolution with
|
||||
* a given input size and other parameters.
|
||||
*
|
||||
*/
|
||||
shape compute_padded_shape(const shape& input,
|
||||
const shape& weights,
|
||||
const std::vector<std::size_t>& padding,
|
||||
const std::vector<std::size_t>& stride,
|
||||
const std::vector<std::size_t>& dilation)
|
||||
{
|
||||
const size_t num_spatial_dims = input.lens().size() - 2;
|
||||
|
||||
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
|
||||
// calculate the output shape of the convolution: ((W - K + 2P) / S) + 1
|
||||
for(size_t i = 0; i < num_spatial_dims; ++i)
|
||||
{
|
||||
auto padding_factor = padding[i] + padding[i + num_spatial_dims];
|
||||
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
|
||||
1,
|
||||
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
|
||||
padding_factor) /
|
||||
stride[i] +
|
||||
1)));
|
||||
}
|
||||
return input.with_lens(output_lens);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the correct output shape for a pooling with
|
||||
* a given input size and other parameters. This uses
|
||||
* the same formula for pooling that compute_padded_shape() uses
|
||||
* for convolutions, but takes slightly different inputs.
|
||||
*
|
||||
*/
|
||||
shape compute_padded_pool_shape(const shape& input,
|
||||
const shape& kernel,
|
||||
const std::vector<std::size_t>& padding,
|
||||
const std::vector<std::size_t>& stride,
|
||||
const std::vector<std::size_t>& dilation)
|
||||
{
|
||||
const size_t num_spatial_dims = input.lens().size() - 2;
|
||||
|
||||
std::vector<size_t> output_lens{input.lens()[0], input.lens()[1]};
|
||||
// calculate the output shape of the pooling: ((W - K + 2P) / S) + 1
|
||||
for(size_t i = 0; i < num_spatial_dims; ++i)
|
||||
{
|
||||
auto padding_factor = padding[i] + padding[i + num_spatial_dims];
|
||||
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
|
||||
1,
|
||||
(input.lens()[i + 2] - (1 + dilation[i] * (kernel.lens()[i] - 1)) + padding_factor) /
|
||||
stride[i] +
|
||||
1)));
|
||||
}
|
||||
return input.with_lens(output_lens);
|
||||
}
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
82
docker/rocm/migraphx/param_utils.cpp
Normal file
82
docker/rocm/migraphx/param_utils.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
*/
|
||||
#include <migraphx/param_utils.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/builtin.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <map>
|
||||
#include <cmath>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::string param_name(std::size_t i, const std::string& prefix)
|
||||
{
|
||||
if(i < 10)
|
||||
return prefix + std::to_string(i);
|
||||
const std::size_t max_digits = 5;
|
||||
if(i >= std::pow(10, max_digits))
|
||||
MIGRAPHX_THROW("Too many parameters.");
|
||||
std::size_t n = log10(i) + 1;
|
||||
return prefix + ":" + std::string(max_digits - n, '0') + std::to_string(i);
|
||||
}
|
||||
|
||||
void sort_params(std::vector<instruction_ref>& params)
|
||||
{
|
||||
std::sort(params.begin(), params.end(), by(std::less<>{}, [](instruction_ref ins) {
|
||||
const auto& param = any_cast<const builtin::param&>(ins->get_operator());
|
||||
return param.parameter;
|
||||
}));
|
||||
}
|
||||
|
||||
std::vector<instruction_ref>
|
||||
find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
|
||||
const_module_ref parent,
|
||||
const_module_ref sub)
|
||||
{
|
||||
std::vector<instruction_ref> result;
|
||||
std::map<std::string, instruction_ref> names;
|
||||
for(auto&& [input, param] : map_ins)
|
||||
{
|
||||
if(sub != nullptr and not sub->has_instruction(param))
|
||||
continue;
|
||||
if(param->name() != "@param")
|
||||
continue;
|
||||
if(parent != nullptr and not parent->has_instruction(input))
|
||||
continue;
|
||||
auto v = param->get_operator().to_value();
|
||||
auto name = v.at("parameter").to<std::string>();
|
||||
names[name] = input;
|
||||
}
|
||||
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
|
||||
return p.second;
|
||||
});
|
||||
assert(not sub or result.size() == sub->get_parameter_shapes().size());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
45
docker/rocm/migraphx/pass.cpp
Normal file
45
docker/rocm/migraphx/pass.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/pass.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/// Dummy pass for default return
|
||||
struct id_pass
|
||||
{
|
||||
std::string name() const { return "id"; }
|
||||
void apply(const module&) const {}
|
||||
};
|
||||
|
||||
pass enable_pass(bool enabled, pass p)
|
||||
{
|
||||
if(enabled)
|
||||
return p;
|
||||
return id_pass{};
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
209
docker/rocm/migraphx/pass_manager.cpp
Normal file
209
docker/rocm/migraphx/pass_manager.cpp
Normal file
@ -0,0 +1,209 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/target.hpp>
|
||||
#include <migraphx/env.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/time.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_PASSES);
|
||||
|
||||
static bool is_pass_disabled(const std::string& name)
|
||||
{
|
||||
static const auto passes = split_string(string_value_of(MIGRAPHX_DISABLE_PASSES{}, ""), ',');
|
||||
return contains(passes, name);
|
||||
}
|
||||
|
||||
void validate_pass(module& mod, const pass& p, tracer trace)
|
||||
{
|
||||
(void)mod;
|
||||
(void)p;
|
||||
(void)trace;
|
||||
#ifndef NDEBUG
|
||||
trace("Validate ...");
|
||||
auto invalid = mod.validate();
|
||||
if(invalid != mod.end())
|
||||
{
|
||||
auto index = std::distance(mod.begin(), invalid);
|
||||
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
|
||||
std::to_string(index) + ": " + invalid->name());
|
||||
}
|
||||
trace();
|
||||
#endif
|
||||
}
|
||||
void run_pass(program& prog, const pass& p, tracer trace)
|
||||
{
|
||||
trace("Pass: ", p.name());
|
||||
p.apply(prog);
|
||||
trace(prog);
|
||||
}
|
||||
|
||||
struct module_pm : module_pass_manager
|
||||
{
|
||||
module* mod = nullptr;
|
||||
module* root_mod = nullptr;
|
||||
tracer* t = nullptr;
|
||||
module* common_parent = nullptr;
|
||||
program* prog = nullptr;
|
||||
|
||||
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
|
||||
|
||||
module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
|
||||
: mod(pmod), root_mod(rmod), t(pt)
|
||||
{
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
void trace(Ts&&... xs) const
|
||||
{
|
||||
assert(t);
|
||||
(*t)(xs...);
|
||||
}
|
||||
|
||||
virtual module& get_module() override
|
||||
{
|
||||
assert(mod);
|
||||
return *mod;
|
||||
}
|
||||
|
||||
virtual module* create_module(const std::string& name) override
|
||||
{
|
||||
assert(prog);
|
||||
return prog->create_module(name);
|
||||
}
|
||||
|
||||
virtual module* create_module(const std::string& name, module m) override
|
||||
{
|
||||
assert(prog);
|
||||
return prog->create_module(name, std::move(m));
|
||||
}
|
||||
|
||||
virtual void rename_module(const std::string& old_name, const std::string& new_name) override
|
||||
{
|
||||
assert(prog);
|
||||
assert(mod);
|
||||
assert(
|
||||
any_of(mod->get_sub_modules(), [&](module_ref sm) { return sm->name() == old_name; }));
|
||||
prog->rename_module(old_name, new_name);
|
||||
}
|
||||
|
||||
virtual module* get_common_parent() override { return common_parent; }
|
||||
|
||||
virtual module* get_root_module() override
|
||||
{
|
||||
if(root_mod != nullptr)
|
||||
return root_mod;
|
||||
assert(prog);
|
||||
return prog->get_main_module();
|
||||
}
|
||||
|
||||
virtual void run_pass(const pass& p) override
|
||||
{
|
||||
if(is_pass_disabled(p.name()))
|
||||
return;
|
||||
trace("Pass: ", p.name());
|
||||
assert(mod);
|
||||
assert(mod->validate() == mod->end());
|
||||
if(enabled(MIGRAPHX_TIME_PASSES{}))
|
||||
{
|
||||
using milliseconds = std::chrono::duration<double, std::milli>;
|
||||
auto ms = time<milliseconds>([&] { p.apply(*this); });
|
||||
std::cout << p.name() << ": " << ms << "ms\n";
|
||||
}
|
||||
else
|
||||
{
|
||||
p.apply(*this);
|
||||
}
|
||||
trace(*mod);
|
||||
validate_pass(*mod, p, *t);
|
||||
}
|
||||
};
|
||||
|
||||
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
|
||||
|
||||
void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& passes, tracer trace)
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_PASSES{}))
|
||||
trace = tracer{std::cout};
|
||||
std::unordered_set<module_ref> visited;
|
||||
for(const auto& p : passes)
|
||||
{
|
||||
auto tree = prog.get_module_tree();
|
||||
std::vector<module_ref> sub_mods = root_mod->get_sub_modules();
|
||||
sub_mods.insert(sub_mods.begin(), root_mod);
|
||||
visited.clear();
|
||||
for(const auto& mod : reverse(sub_mods))
|
||||
{
|
||||
if(mod->bypass())
|
||||
continue;
|
||||
if(not visited.insert(mod).second)
|
||||
continue;
|
||||
module_pm mpm{mod, root_mod, &trace};
|
||||
mpm.prog = &prog;
|
||||
auto parents = range(tree.equal_range(mod));
|
||||
auto nparents = distance(parents);
|
||||
if(nparents == 0)
|
||||
mpm.common_parent = nullptr;
|
||||
else if(nparents == 1)
|
||||
mpm.common_parent = parents.begin()->second;
|
||||
else
|
||||
// Just set common parent to main module when there is muliple parents for now
|
||||
// TODO: Compute the common parent
|
||||
mpm.common_parent = prog.get_main_module();
|
||||
mpm.run_pass(p);
|
||||
}
|
||||
run_pass(prog, p, trace);
|
||||
}
|
||||
}
|
||||
|
||||
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_PASSES{}))
|
||||
trace = tracer{std::cout};
|
||||
for(const auto& p : passes)
|
||||
{
|
||||
module_pm{&mod, &mod, &trace}.run_pass(p);
|
||||
}
|
||||
}
|
||||
|
||||
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
|
||||
{
|
||||
run_passes(prog, prog.get_main_module(), passes, trace);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
88
docker/rocm/migraphx/permutation.cpp
Normal file
88
docker/rocm/migraphx/permutation.cpp
Normal file
@ -0,0 +1,88 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/permutation.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/algorithm.hpp>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
|
||||
{
|
||||
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
|
||||
}
|
||||
|
||||
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
|
||||
{
|
||||
return sort_permutation(permutation, std::less<>{});
|
||||
}
|
||||
|
||||
std::vector<int64_t> find_permutation(const shape& s)
|
||||
{
|
||||
std::vector<std::int64_t> result(s.lens().size());
|
||||
std::iota(result.begin(), result.end(), 0);
|
||||
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
|
||||
return std::make_tuple(s.strides()[x], s.lens()[x]);
|
||||
}));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
|
||||
{
|
||||
if(shapes.empty())
|
||||
return {};
|
||||
std::map<std::vector<int64_t>, std::size_t> count;
|
||||
for(auto&& s : shapes)
|
||||
{
|
||||
if(s.broadcasted())
|
||||
continue;
|
||||
count[find_permutation(s)]++;
|
||||
}
|
||||
if(count.empty())
|
||||
{
|
||||
std::vector<int64_t> r(shapes.front().lens().size());
|
||||
std::iota(r.begin(), r.end(), 0);
|
||||
return r;
|
||||
}
|
||||
auto it = std::max_element(
|
||||
count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }));
|
||||
assert(it != count.end());
|
||||
return it->first;
|
||||
}
|
||||
|
||||
std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
|
||||
{
|
||||
auto result = shapes;
|
||||
auto perm = find_permutation(shapes);
|
||||
std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
|
||||
return reorder_shape(s, perm);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
52
docker/rocm/migraphx/preallocate_param.cpp
Normal file
52
docker/rocm/migraphx/preallocate_param.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/preallocate_param.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void preallocate_param::apply(module& m) const
|
||||
{
|
||||
auto last = std::prev(m.end());
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "@param")
|
||||
continue;
|
||||
if(param != any_cast<builtin::param>(ins->get_operator()).parameter)
|
||||
continue;
|
||||
std::string id = m.name() + ":" + param;
|
||||
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
|
||||
m.replace_instruction(ins, r);
|
||||
m.move_instruction(ins, m.end());
|
||||
}
|
||||
m.remove_instructions(std::next(last), m.end());
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
461
docker/rocm/migraphx/process.cpp
Normal file
461
docker/rocm/migraphx/process.cpp
Normal file
@ -0,0 +1,461 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/env.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/process.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/tmp_dir.hpp>
|
||||
#include <migraphx/fileutils.hpp>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
#ifdef _WIN32
|
||||
// cppcheck-suppress definePrefix
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <Windows.h>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <optional>
|
||||
#endif
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE)
|
||||
|
||||
#ifndef _WIN32
|
||||
|
||||
std::function<void(const char*)> redirect_to(std::ostream& os)
|
||||
{
|
||||
return [&](const char* x) { os << x; };
|
||||
}
|
||||
|
||||
template <class F>
|
||||
int exec(const std::string& cmd, const char* type, F f)
|
||||
{
|
||||
int ec = 0;
|
||||
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
|
||||
std::cout << cmd << std::endl;
|
||||
auto closer = [&](FILE* stream) {
|
||||
auto status = pclose(stream);
|
||||
ec = WIFEXITED(status) ? WEXITSTATUS(status) : 0; // NOLINT
|
||||
};
|
||||
{
|
||||
// TODO: Use execve instead of popen
|
||||
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), type), closer); // NOLINT
|
||||
if(not pipe)
|
||||
MIGRAPHX_THROW("popen() failed: " + cmd);
|
||||
f(pipe.get());
|
||||
}
|
||||
return ec;
|
||||
}
|
||||
|
||||
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
|
||||
{
|
||||
return exec(cmd, "r", [&](FILE* f) {
|
||||
std::array<char, 128> buffer;
|
||||
while(fgets(buffer.data(), buffer.size(), f) != nullptr)
|
||||
std_out(buffer.data());
|
||||
});
|
||||
}
|
||||
|
||||
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
|
||||
{
|
||||
return exec(cmd, "w", [&](FILE* f) {
|
||||
std_in([&](const char* buffer, std::size_t n) { std::fwrite(buffer, 1, n, f); });
|
||||
});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
constexpr std::size_t MIGRAPHX_PROCESS_BUFSIZE = 4096;
|
||||
|
||||
enum class direction
|
||||
{
|
||||
input,
|
||||
output
|
||||
};
|
||||
|
||||
template <direction dir>
|
||||
class pipe
|
||||
{
|
||||
public:
|
||||
explicit pipe()
|
||||
{
|
||||
SECURITY_ATTRIBUTES attrs;
|
||||
attrs.nLength = sizeof(SECURITY_ATTRIBUTES);
|
||||
attrs.bInheritHandle = TRUE;
|
||||
attrs.lpSecurityDescriptor = nullptr;
|
||||
|
||||
if(CreatePipe(&m_read, &m_write, &attrs, 0) == FALSE)
|
||||
throw GetLastError();
|
||||
|
||||
if(dir == direction::output)
|
||||
{
|
||||
// Do not inherit the read handle for the output pipe
|
||||
if(SetHandleInformation(m_read, HANDLE_FLAG_INHERIT, 0) == 0)
|
||||
throw GetLastError();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Do not inherit the write handle for the input pipe
|
||||
if(SetHandleInformation(m_write, HANDLE_FLAG_INHERIT, 0) == 0)
|
||||
throw GetLastError();
|
||||
}
|
||||
}
|
||||
|
||||
pipe(const pipe&) = delete;
|
||||
pipe& operator=(const pipe&) = delete;
|
||||
|
||||
pipe(pipe&&) = default;
|
||||
|
||||
~pipe()
|
||||
{
|
||||
if(m_write != nullptr)
|
||||
{
|
||||
CloseHandle(m_write);
|
||||
}
|
||||
if(m_read != nullptr)
|
||||
{
|
||||
CloseHandle(m_read);
|
||||
}
|
||||
}
|
||||
|
||||
bool close_write_handle()
|
||||
{
|
||||
auto result = true;
|
||||
if(m_write != nullptr)
|
||||
{
|
||||
result = CloseHandle(m_write) == TRUE;
|
||||
m_write = nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool close_read_handle()
|
||||
{
|
||||
auto result = true;
|
||||
if(m_read != nullptr)
|
||||
{
|
||||
result = CloseHandle(m_read) == TRUE;
|
||||
m_read = nullptr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<bool, DWORD> read(LPVOID buffer, DWORD length) const
|
||||
{
|
||||
DWORD bytes_read;
|
||||
if(ReadFile(m_read, buffer, length, &bytes_read, nullptr) == FALSE and
|
||||
GetLastError() == ERROR_MORE_DATA)
|
||||
{
|
||||
return {true, bytes_read};
|
||||
}
|
||||
return {false, bytes_read};
|
||||
}
|
||||
|
||||
HANDLE get_read_handle() const { return m_read; }
|
||||
|
||||
bool write(LPCVOID buffer, DWORD length) const
|
||||
{
|
||||
DWORD bytes_written;
|
||||
return WriteFile(m_write, buffer, length, &bytes_written, nullptr) == TRUE;
|
||||
}
|
||||
|
||||
HANDLE get_write_handle() const { return m_write; }
|
||||
|
||||
private:
|
||||
HANDLE m_write = nullptr, m_read = nullptr;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
template <typename F>
|
||||
int exec(const std::string& cmd, const std::string& cwd, const std::string& args,
|
||||
const std::string& envs, F f)
|
||||
// clang-format on
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
|
||||
{
|
||||
std::cout << "[cwd=" << cwd << "]; cmd='" << cmd << "\'; args='" << args << "'; envs='"
|
||||
<< envs << "'\n";
|
||||
}
|
||||
|
||||
// See CreateProcess() WIN32 documentation for details.
|
||||
constexpr std::size_t CMDLINE_LENGTH = 32767;
|
||||
|
||||
// Build lpCommandLine parameter.
|
||||
std::string cmdline = quote_string(cmd);
|
||||
if(not args.empty())
|
||||
cmdline += " " + args;
|
||||
|
||||
// clang-format off
|
||||
if(cmdline.size() > CMDLINE_LENGTH)
|
||||
MIGRAPHX_THROW("Command line too long, required maximum " +
|
||||
std::to_string(CMDLINE_LENGTH) + " characters.");
|
||||
// clang-format on
|
||||
|
||||
if(cmdline.size() < CMDLINE_LENGTH)
|
||||
cmdline.resize(CMDLINE_LENGTH, '\0');
|
||||
|
||||
// Build lpEnvironment parameter.
|
||||
std::vector<TCHAR> environment{};
|
||||
if(not envs.empty())
|
||||
{
|
||||
std::istringstream iss{envs};
|
||||
std::string str;
|
||||
while(iss >> str)
|
||||
{
|
||||
environment.insert(environment.end(), str.begin(), str.end());
|
||||
environment.push_back('\0');
|
||||
}
|
||||
environment.push_back('\0');
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
STARTUPINFO info;
|
||||
PROCESS_INFORMATION process_info;
|
||||
|
||||
pipe<direction::input> input{};
|
||||
pipe<direction::output> output{};
|
||||
|
||||
ZeroMemory(&info, sizeof(STARTUPINFO));
|
||||
info.cb = sizeof(STARTUPINFO);
|
||||
info.hStdError = output.get_write_handle();
|
||||
info.hStdOutput = output.get_write_handle();
|
||||
info.hStdInput = input.get_read_handle();
|
||||
info.dwFlags |= STARTF_USESTDHANDLES;
|
||||
|
||||
ZeroMemory(&process_info, sizeof(process_info));
|
||||
|
||||
if(CreateProcess(cmd.c_str(),
|
||||
cmdline.data(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
TRUE,
|
||||
0,
|
||||
environment.empty() ? nullptr : environment.data(),
|
||||
cwd.empty() ? nullptr : static_cast<LPCSTR>(cwd.c_str()),
|
||||
&info,
|
||||
&process_info) == FALSE)
|
||||
{
|
||||
MIGRAPHX_THROW("Error creating process (" + std::to_string(GetLastError()) + ")");
|
||||
}
|
||||
|
||||
CloseHandle(process_info.hThread);
|
||||
|
||||
if(not output.close_write_handle())
|
||||
MIGRAPHX_THROW("Error closing STDOUT handle for writing (" +
|
||||
std::to_string(GetLastError()) + ")");
|
||||
|
||||
if(not input.close_read_handle())
|
||||
MIGRAPHX_THROW("Error closing STDIN handle for reading (" +
|
||||
std::to_string(GetLastError()) + ")");
|
||||
|
||||
f(input, output);
|
||||
|
||||
if(not input.close_write_handle())
|
||||
MIGRAPHX_THROW("Error closing STDIN handle for writing (" +
|
||||
std::to_string(GetLastError()) + ")");
|
||||
|
||||
WaitForSingleObject(process_info.hProcess, INFINITE);
|
||||
|
||||
DWORD status{};
|
||||
GetExitCodeProcess(process_info.hProcess, &status);
|
||||
|
||||
CloseHandle(process_info.hProcess);
|
||||
|
||||
return static_cast<int>(status);
|
||||
}
|
||||
// cppcheck-suppress catchExceptionByValue
|
||||
catch(DWORD error)
|
||||
{
|
||||
MIGRAPHX_THROW("Error spawning process (" + std::to_string(error) + ")");
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
int exec(const std::string& cmd, const std::string& cwd, const std::string& args,
|
||||
const std::string& envs, HANDLE std_out)
|
||||
{
|
||||
TCHAR buffer[MIGRAPHX_PROCESS_BUFSIZE];
|
||||
return (std_out == nullptr or std_out == INVALID_HANDLE_VALUE)
|
||||
? GetLastError() : exec(cmd, cwd, args, envs,
|
||||
[&](const pipe<direction::input>&, const pipe<direction::output>& out) {
|
||||
for(;;)
|
||||
{
|
||||
auto [more_data, bytes_read] = out.read(buffer, MIGRAPHX_PROCESS_BUFSIZE);
|
||||
if(bytes_read == 0)
|
||||
break;
|
||||
if(WriteFile(std_out, buffer, bytes_read, nullptr, nullptr) == FALSE)
|
||||
break;
|
||||
if(not more_data)
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
int exec(const std::string& cmd, const std::string& cwd, const std::string& args,
|
||||
const std::string& envs, std::function<void(process::writer)> std_in)
|
||||
{
|
||||
return exec(cmd, cwd, args, envs,
|
||||
[&](const pipe<direction::input>& input, const pipe<direction::output>&) {
|
||||
std_in([&](const char* buffer, std::size_t n) { input.write(buffer, n); });
|
||||
});
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
#endif
|
||||
|
||||
struct process_impl
|
||||
{
|
||||
std::string args{};
|
||||
std::string envs{};
|
||||
std::string command{};
|
||||
fs::path cwd{};
|
||||
|
||||
std::string get_command() const
|
||||
{
|
||||
std::string result;
|
||||
if(not cwd.empty())
|
||||
result += "cd " + cwd.string() + "; ";
|
||||
if(not envs.empty())
|
||||
result += envs + " ";
|
||||
result += command;
|
||||
if(not args.empty())
|
||||
result += " " + args;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
void check_exec(Ts&&... xs) const
|
||||
{
|
||||
int ec = migraphx::exec(std::forward<Ts>(xs)...);
|
||||
if(ec != 0)
|
||||
MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
|
||||
std::to_string(ec));
|
||||
}
|
||||
};
|
||||
|
||||
process::process(const std::string& cmd, const std::vector<std::string>& args)
|
||||
: impl(std::make_unique<process_impl>())
|
||||
{
|
||||
impl->command = cmd;
|
||||
if(not args.empty())
|
||||
impl->args = join_strings(args, " ");
|
||||
}
|
||||
|
||||
process::process(process&&) noexcept = default;
|
||||
|
||||
process& process::operator=(process rhs)
|
||||
{
|
||||
std::swap(impl, rhs.impl);
|
||||
return *this;
|
||||
}
|
||||
|
||||
process::~process() noexcept = default;
|
||||
|
||||
process& process::cwd(const fs::path& p)
|
||||
{
|
||||
impl->cwd = p;
|
||||
return *this;
|
||||
}
|
||||
|
||||
process& process::env(const std::vector<std::string>& envs)
|
||||
{
|
||||
if(not envs.empty())
|
||||
{
|
||||
impl->envs = join_strings(envs, " ");
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void process::read(const writer& output) const
|
||||
{
|
||||
#ifdef _WIN32
|
||||
// clang-format off
|
||||
constexpr std::string_view filename = "stdout";
|
||||
auto tmp = tmp_dir{};
|
||||
HANDLE handle = CreateFile((tmp.path / filename).string().c_str(),
|
||||
GENERIC_READ | GENERIC_WRITE,
|
||||
0,
|
||||
nullptr,
|
||||
CREATE_ALWAYS,
|
||||
FILE_ATTRIBUTE_NORMAL,
|
||||
nullptr);
|
||||
impl->check_exec(impl->command, impl->cwd.string(), impl->args, impl->envs,
|
||||
handle == nullptr or handle == INVALID_HANDLE_VALUE ?
|
||||
GetStdHandle(STD_OUTPUT_HANDLE) : handle);
|
||||
CloseHandle(handle);
|
||||
handle = CreateFile((tmp.path / filename).string().c_str(),
|
||||
GENERIC_READ | GENERIC_WRITE,
|
||||
0,
|
||||
nullptr,
|
||||
OPEN_EXISTING,
|
||||
FILE_ATTRIBUTE_NORMAL,
|
||||
nullptr);
|
||||
if(handle == nullptr or handle == INVALID_HANDLE_VALUE)
|
||||
MIGRAPHX_THROW("Unable to open file: " + (tmp.path / filename));
|
||||
auto size = GetFileSize(handle, nullptr);
|
||||
std::string result(size, '\0');
|
||||
if(ReadFile(handle, result.data(), size, nullptr, nullptr) == FALSE)
|
||||
MIGRAPHX_THROW("Failed reading file: " + (tmp.path / filename));
|
||||
CloseHandle(handle);
|
||||
// clang-format on
|
||||
#else
|
||||
std::stringstream ss;
|
||||
impl->check_exec(impl->get_command(), redirect_to(ss));
|
||||
auto result = ss.str();
|
||||
#endif
|
||||
output(result.data(), result.size());
|
||||
}
|
||||
|
||||
void process::exec()
|
||||
{
|
||||
#ifndef _WIN32
|
||||
impl->check_exec(impl->get_command(), redirect_to(std::cout));
|
||||
#else
|
||||
// clang-format off
|
||||
impl->check_exec(impl->command, impl->cwd.string(), impl->args, impl->envs,
|
||||
GetStdHandle(STD_OUTPUT_HANDLE));
|
||||
// clang-format on
|
||||
#endif
|
||||
}
|
||||
|
||||
void process::write(std::function<void(writer)> pipe_in)
|
||||
{
|
||||
#ifndef _WIN32
|
||||
impl->check_exec(impl->get_command(), std::move(pipe_in));
|
||||
#else
|
||||
// clang-format off
|
||||
impl->check_exec(impl->command, impl->cwd.string(),
|
||||
impl->args, impl->envs, std::move(pipe_in));
|
||||
// clang-format on
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
1331
docker/rocm/migraphx/program.cpp
Normal file
1331
docker/rocm/migraphx/program.cpp
Normal file
File diff suppressed because it is too large
Load Diff
55
docker/rocm/migraphx/promote_literals.cpp
Normal file
55
docker/rocm/migraphx/promote_literals.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/promote_literals.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void promote_literals::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
module& m = mpm.get_module();
|
||||
module_ref root_module = mpm.get_root_module();
|
||||
if(m == *root_module)
|
||||
return;
|
||||
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() == "@literal")
|
||||
{
|
||||
auto new_lit = root_module->add_literal(ins->get_literal());
|
||||
auto ins_outputs = ins->outputs();
|
||||
for(auto out_ins : ins_outputs)
|
||||
{
|
||||
out_ins->replace_argument(out_ins, ins, new_lit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
138
docker/rocm/migraphx/propagate_constant.cpp
Normal file
138
docker/rocm/migraphx/propagate_constant.cpp
Normal file
@ -0,0 +1,138 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/propagate_constant.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/literal.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/simple_par_for.hpp>
|
||||
#include <migraphx/env.hpp>
|
||||
#include <thread>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
|
||||
|
||||
bool skip_propagate(instruction_ref ins)
|
||||
{
|
||||
if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
|
||||
return skip_propagate(ins->inputs().front());
|
||||
if(ins->name() == "unpack_int4")
|
||||
return true;
|
||||
auto&& s = ins->get_shape();
|
||||
if(s.broadcasted() and s.element_space() < s.elements())
|
||||
return true;
|
||||
auto alias = instruction::get_output_alias(ins, true);
|
||||
if(alias != ins)
|
||||
return skip_propagate(alias);
|
||||
if(ins->is_undefined())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_const_ins(instruction_ref ins, const std::unordered_set<std::string>& skip_ops)
|
||||
{
|
||||
return ins->can_eval() and not skip_propagate(ins) and
|
||||
skip_ops.find(ins->name()) == skip_ops.end();
|
||||
}
|
||||
|
||||
argument as_packed(const argument& c)
|
||||
{
|
||||
if(c.get_shape().packed())
|
||||
return c;
|
||||
auto s = c.get_shape().with_lens(c.get_shape().lens());
|
||||
argument result;
|
||||
c.visit([&](auto x) { result = literal{s, x.begin(), x.end()}.get_argument(); });
|
||||
return result;
|
||||
}
|
||||
|
||||
void propagate_constant::apply(module& m) const
|
||||
{
|
||||
std::unordered_set<instruction_ref> const_instrs;
|
||||
auto last = std::prev(m.end());
|
||||
|
||||
// Find instructions that can be evaluated to a literal
|
||||
for(auto i : iterator_for(m))
|
||||
{
|
||||
const bool is_const = is_const_ins(i, skip_ops);
|
||||
if(is_const and i != last)
|
||||
continue;
|
||||
|
||||
if(i == last and is_const)
|
||||
{
|
||||
const_instrs.insert(i);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::copy_if(i->inputs().begin(),
|
||||
i->inputs().end(),
|
||||
std::inserter(const_instrs, const_instrs.begin()),
|
||||
[&](const instruction_ref ins) {
|
||||
return is_const_ins(ins, skip_ops) and ins->name() != "@literal";
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Compute literals in parallel
|
||||
std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
|
||||
std::vector<argument> literals(const_instrs_vec.size());
|
||||
std::size_t grainsize = 1;
|
||||
#if !MIGRAPHX_HAS_EXECUTORS
|
||||
std::size_t n = std::max<std::size_t>(2048 / std::thread::hardware_concurrency(), 1);
|
||||
grainsize = const_instrs_vec.size() / n;
|
||||
#endif
|
||||
simple_par_for(const_instrs_vec.size(), grainsize, [&](const auto i) {
|
||||
literals[i] = as_packed(const_instrs_vec[i]->eval());
|
||||
});
|
||||
|
||||
// Replace instructions in m
|
||||
for(size_t i = 0; i < const_instrs_vec.size(); i++)
|
||||
{
|
||||
if(not literals[i].empty())
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
|
||||
{
|
||||
std::cout << "Constant replace: " << std::endl;
|
||||
std::vector<instruction_ref> inss;
|
||||
fix([&](auto self, auto ins) {
|
||||
if(contains(inss, ins))
|
||||
return;
|
||||
for(auto input : ins->inputs())
|
||||
self(input);
|
||||
inss.push_back(ins);
|
||||
})(const_instrs_vec[i]);
|
||||
m.debug_print(inss);
|
||||
}
|
||||
assert(literals[i].get_shape().lens() == const_instrs_vec[i]->get_shape().lens());
|
||||
assert(literals[i].get_shape().bytes() <= const_instrs_vec[i]->get_shape().bytes());
|
||||
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
|
||||
m.replace_instruction(const_instrs_vec[i], l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
211
docker/rocm/migraphx/quantization.cpp
Normal file
211
docker/rocm/migraphx/quantization.cpp
Normal file
@ -0,0 +1,211 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/float_equal.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/quantization.hpp>
|
||||
#include <migraphx/truncate_float.hpp>
|
||||
#include <migraphx/quantize_8bits.hpp>
|
||||
#include <migraphx/quantize_int4.hpp>
|
||||
#include <migraphx/simplify_reshapes.hpp>
|
||||
#include <migraphx/simplify_qdq.hpp>
|
||||
#include <migraphx/eliminate_common_subexpression.hpp>
|
||||
#include <migraphx/optimize_module.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/op/capture.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/target.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/normalize_ops.hpp>
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_QUANTIZATION)
|
||||
|
||||
tracer quant_tracer()
|
||||
{
|
||||
if(enabled(MIGRAPHX_TRACE_QUANTIZATION{}))
|
||||
return tracer{std::cout};
|
||||
|
||||
return tracer{};
|
||||
};
|
||||
|
||||
// This function is to convert any instructions specified in the input
|
||||
// from double or float to float16 by inserting a convert operator.
|
||||
// For the conversion, there could be cases of overflowing or underflowing, but it
|
||||
// is uncommon. Run optimize_module() before converting to fp16 to const eval and fold in FP32 to
|
||||
// avoid loss of precision.
|
||||
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
|
||||
{
|
||||
run_passes(prog,
|
||||
{normalize_ops{},
|
||||
optimize_module{{"quantizelinear", "dequantizelinear"}},
|
||||
truncate_float_pass{ins_names, shape::half_type},
|
||||
optimize_module{{"quantizelinear", "dequantizelinear"}}},
|
||||
quant_tracer());
|
||||
}
|
||||
|
||||
void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)
|
||||
{
|
||||
run_passes(prog,
|
||||
{normalize_ops{},
|
||||
optimize_module{{"quantizelinear", "dequantizelinear"}},
|
||||
truncate_float_pass{ins_names, shape::bf16_type},
|
||||
optimize_module{{"quantizelinear", "dequantizelinear"}}},
|
||||
quant_tracer());
|
||||
}
|
||||
|
||||
void quantize_8bits(program& prog,
|
||||
const target& t,
|
||||
shape::type_t precision,
|
||||
const std::vector<parameter_map>& calibration,
|
||||
const std::unordered_set<std::string>& ins_names)
|
||||
{
|
||||
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
|
||||
// avoid loss of precision.
|
||||
run_passes(prog, {normalize_ops{}, optimize_module{}}, quant_tracer());
|
||||
|
||||
std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
|
||||
std::make_shared<std::vector<std::pair<float, float>>>();
|
||||
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
|
||||
std::map<shape::type_t, float> type_ranges = {{shape::type_t::int8_type, 127.0},
|
||||
{shape::type_t::fp8e4m3fnuz_type, 240.0},
|
||||
{shape::type_t::fp8e4m3fn_type, 448.0}};
|
||||
float quantized_range = type_ranges.at(precision);
|
||||
auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
|
||||
std::pair<float, float> param_pair{64.0f, 0.0f};
|
||||
// scale and shift is need for only int8 type, and we do not
|
||||
// consider shift, so set shift to 0
|
||||
std::vector<float> vec_val;
|
||||
argument arg = t.copy_from(args.front());
|
||||
arg.visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
|
||||
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
|
||||
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
|
||||
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
|
||||
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
|
||||
// if all values are 0, no need to do scaling
|
||||
if(float_equal(max_abs_vals->at(ins_index), 0.0f))
|
||||
{
|
||||
param_pair.first = 1.0f;
|
||||
}
|
||||
else
|
||||
{
|
||||
param_pair.first = quantized_range / max_abs_vals->at(ins_index);
|
||||
}
|
||||
quant_8bit_params->at(ins_index) = param_pair;
|
||||
};
|
||||
|
||||
// pass to add capture argument op
|
||||
std::size_t param_num = 0;
|
||||
run_passes(
|
||||
prog, {capture_arguments_pass{ins_names, calc_quant_params, ¶m_num}}, quant_tracer());
|
||||
quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
|
||||
max_abs_vals->resize(param_num, 0.0f);
|
||||
|
||||
// use the calibration data to compute the quantization scale
|
||||
auto capture_prog = prog;
|
||||
capture_prog.compile(t);
|
||||
|
||||
// use all calibration data to run the program to calculate the
|
||||
// quantization scale and shift
|
||||
for(auto&& arg : calibration)
|
||||
{
|
||||
parameter_map m;
|
||||
for(auto&& x : capture_prog.get_parameter_shapes())
|
||||
{
|
||||
if(arg.count(x.first) > 0)
|
||||
{
|
||||
assert(x.second == arg.at(x.first).get_shape());
|
||||
m[x.first] = t.copy_to(arg.at(x.first));
|
||||
}
|
||||
else
|
||||
{
|
||||
m[x.first] = t.allocate(x.second);
|
||||
}
|
||||
}
|
||||
capture_prog.eval(m);
|
||||
}
|
||||
|
||||
// print the quantization parameters in only the main module
|
||||
if(enabled(MIGRAPHX_8BITS_QUANTIZATION_PARAMS{}))
|
||||
{
|
||||
for(std::size_t i = 0; i < quant_8bit_params->size(); ++i)
|
||||
{
|
||||
auto param = quant_8bit_params->at(i);
|
||||
std::cout << "ins_index = " << i << ", scale = " << param.first
|
||||
<< ", shift = " << param.second << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
run_passes(prog,
|
||||
{quantize_8bits_pass{precision, *quant_8bit_params}, dead_code_elimination{}},
|
||||
quant_tracer());
|
||||
}
|
||||
|
||||
void quantize_int8(program& prog,
|
||||
const target& t,
|
||||
const std::vector<parameter_map>& calibration,
|
||||
const std::unordered_set<std::string>& ins_names)
|
||||
{
|
||||
std::unordered_set<std::string> op_names = {"convolution", "dot"};
|
||||
if(op_names != ins_names)
|
||||
{
|
||||
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
|
||||
}
|
||||
quantize_8bits(prog, t, shape::int8_type, calibration, ins_names);
|
||||
}
|
||||
|
||||
void quantize_int4_weights(program& prog)
|
||||
{
|
||||
run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}}, quant_tracer());
|
||||
}
|
||||
|
||||
void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
|
||||
{
|
||||
std::unordered_set<std::string> supported_ins_names;
|
||||
auto* mm = prog.get_main_module();
|
||||
for(auto ins : iterator_for(*mm))
|
||||
{
|
||||
if(ins->name() == "convert")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(not starts_with(ins->name(), "@"))
|
||||
{
|
||||
supported_ins_names.insert(ins->name());
|
||||
}
|
||||
}
|
||||
quantize_8bits(prog, t, shape::fp8e4m3fn_type, calibration, supported_ins_names);
|
||||
}
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
117
docker/rocm/migraphx/quantize_8bits.cpp
Normal file
117
docker/rocm/migraphx/quantize_8bits.cpp
Normal file
@ -0,0 +1,117 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/operation.hpp>
|
||||
#include <migraphx/float_equal.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/quantization.hpp>
|
||||
#include <migraphx/quantize_8bits.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/op/capture.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/target.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static std::vector<shape::type_t>& get_quantizable_type()
|
||||
{
|
||||
static std::vector<shape::type_t> quantable_types = {
|
||||
shape::float_type, shape::double_type, shape::half_type};
|
||||
return quantable_types;
|
||||
}
|
||||
|
||||
void quantize_8bits_pass::apply(module& m) const // NOLINT
|
||||
{
|
||||
const auto& quantizable_types = get_quantizable_type();
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "capture")
|
||||
continue;
|
||||
|
||||
auto op_val = ins->get_operator().to_value();
|
||||
assert(op_val.contains("ins_index"));
|
||||
|
||||
auto param_index = op_val.at("ins_index").to<std::size_t>();
|
||||
auto param = quant_params[param_index];
|
||||
|
||||
auto input = ins->inputs().front();
|
||||
auto s = input->get_shape();
|
||||
if(contains(quantizable_types, s.type()) and s.type() != precision)
|
||||
{
|
||||
auto zero_point =
|
||||
m.add_literal(migraphx::literal{migraphx::shape{precision}, {param.second}});
|
||||
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
|
||||
const auto& lens = s.lens();
|
||||
scale =
|
||||
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
|
||||
zero_point = m.insert_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
|
||||
auto q_in =
|
||||
m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
|
||||
auto dq_in =
|
||||
m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
|
||||
m.replace_instruction(ins, dq_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void capture_arguments_pass::apply(module& m) const // NOLINT
|
||||
{
|
||||
assert(param_index != nullptr);
|
||||
const auto& quantizable_types = get_quantizable_type();
|
||||
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if((not contains(ins_names, ins->name())) or (ins->name() == "convert"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto inputs = ins->inputs();
|
||||
std::vector<instruction_ref> new_args;
|
||||
for(auto input : inputs)
|
||||
{
|
||||
if(contains(quantizable_types, input->get_shape().type()))
|
||||
{
|
||||
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
|
||||
new_args.push_back(new_in);
|
||||
}
|
||||
else
|
||||
{
|
||||
new_args.push_back(input);
|
||||
}
|
||||
}
|
||||
m.replace_instruction(ins, ins->get_operator(), new_args);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
108
docker/rocm/migraphx/quantize_int4.cpp
Normal file
108
docker/rocm/migraphx/quantize_int4.cpp
Normal file
@ -0,0 +1,108 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/float_equal.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/quantize_int4.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/target.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static void int4_quantize_module(module& m)
|
||||
{
|
||||
std::vector<std::string> int4_instrs{"dot", "convolution"};
|
||||
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(not(contains(int4_instrs, ins->name())))
|
||||
continue;
|
||||
|
||||
if(ins->inputs().empty())
|
||||
continue;
|
||||
|
||||
auto s = ins->get_shape();
|
||||
|
||||
auto mod_inputs = ins->module_inputs();
|
||||
|
||||
// Convert each of the inputs that are fp32 or fp16 to int4
|
||||
auto inputs = ins->inputs();
|
||||
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto inp) {
|
||||
auto sh = inp->get_shape();
|
||||
if(sh.broadcasted())
|
||||
return inp;
|
||||
auto input_type = sh.type();
|
||||
if(input_type != shape::float_type and input_type != shape::half_type)
|
||||
return inp;
|
||||
auto lens = sh.lens();
|
||||
if(lens[lens.size() - 1] % 2)
|
||||
return inp; // even sized dimensions to pack
|
||||
|
||||
if(not inp->can_eval())
|
||||
return inp;
|
||||
|
||||
std::vector<float> val;
|
||||
inp->eval().visit([&](auto in_data) { val.assign(in_data.begin(), in_data.end()); });
|
||||
|
||||
auto [min, max] = std::minmax_element(val.begin(), val.end());
|
||||
*min = *min > 0 ? 0 : *min;
|
||||
*max = *max < 0 ? 0 : *max;
|
||||
float fscale4 = (*max - *min) / 15; // INT4 range is [0-15]
|
||||
int zp4 = float_equal(fscale4, 0) ? 0 : std::round(-*min / fscale4);
|
||||
|
||||
auto scale = m.add_literal(literal({s.type()}, {fscale4}));
|
||||
scale =
|
||||
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
|
||||
auto zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
|
||||
zp = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), zp);
|
||||
auto q_in = m.insert_instruction(ins, make_op("quantizelinear"), inp, scale, zp);
|
||||
|
||||
auto pk = m.insert_instruction(ins, make_op("pack_int4", {{"axis", -1}}), q_in);
|
||||
auto unpk = m.insert_instruction(ins, make_op("unpack_int4", {{"axis", -1}}), pk);
|
||||
|
||||
auto dq_scale = m.add_literal(literal({s.type()}, {fscale4}));
|
||||
dq_scale = m.insert_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
|
||||
|
||||
auto dq_zp = m.add_literal(literal{{shape::uint8_type}, {zp4}});
|
||||
dq_zp =
|
||||
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), dq_zp);
|
||||
|
||||
return m.insert_instruction(ins, make_op("dequantizelinear"), unpk, dq_scale, dq_zp);
|
||||
});
|
||||
|
||||
auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
|
||||
m.replace_instruction(ins, converted_ins);
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_int4_pass::apply(module& m) const { int4_quantize_module(m); }
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
153
docker/rocm/migraphx/reduce_dims.cpp
Normal file
153
docker/rocm/migraphx/reduce_dims.cpp
Normal file
@ -0,0 +1,153 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/reduce_dims.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
|
||||
{
|
||||
std::vector<std::size_t> new_lens;
|
||||
for(const auto& s : shapes)
|
||||
{
|
||||
assert(n < s.lens().size());
|
||||
if((n + 1) >= s.lens().size())
|
||||
return false;
|
||||
auto astride = s.strides()[n];
|
||||
auto alen = s.lens()[n];
|
||||
auto bstride = s.strides()[n + 1];
|
||||
auto blen = s.lens()[n + 1];
|
||||
|
||||
if(astride == bstride * blen or alen == 1)
|
||||
new_lens.push_back(alen * blen);
|
||||
}
|
||||
if(new_lens.size() != shapes.size())
|
||||
return false;
|
||||
std::size_t i = 0;
|
||||
for(auto& s : shapes)
|
||||
{
|
||||
auto lens = s.lens();
|
||||
auto strides = s.strides();
|
||||
lens.erase(lens.begin() + n);
|
||||
strides.erase(strides.begin() + n);
|
||||
lens[n] = new_lens[i];
|
||||
s = shape{s.type(), lens, strides};
|
||||
i++;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void reduce_dim1(std::vector<shape>& shapes)
|
||||
{
|
||||
if(std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
|
||||
return s.lens().size() < 2 or s.lens().back() != 1;
|
||||
}))
|
||||
return;
|
||||
for(auto& s : shapes)
|
||||
{
|
||||
auto lens = s.lens();
|
||||
auto strides = s.strides();
|
||||
lens.pop_back();
|
||||
strides.pop_back();
|
||||
s = shape{s.type(), lens, strides};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
|
||||
{
|
||||
while(reduce_dim(shapes, n) and n < shapes.size())
|
||||
{
|
||||
(void)n;
|
||||
}
|
||||
return n + 1;
|
||||
}
|
||||
void reduce_dim_all(std::vector<shape>& shapes)
|
||||
{
|
||||
std::size_t n = 0;
|
||||
while(n < shapes.front().lens().size() - 1)
|
||||
n = reduce_dim_all(shapes, n);
|
||||
reduce_dim1(shapes);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
|
||||
{
|
||||
return std::accumulate(
|
||||
shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
|
||||
std::vector<std::size_t> result;
|
||||
const auto* x = &s.lens();
|
||||
const auto* y = &lens;
|
||||
if(x->size() > y->size())
|
||||
std::swap(x, y);
|
||||
std::transform(
|
||||
x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) {
|
||||
return std::max(a, b);
|
||||
});
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
|
||||
{
|
||||
assert(s.lens().size() == lens.size());
|
||||
std::vector<std::size_t> rstrides(lens.size());
|
||||
std::size_t stride = 1;
|
||||
for(std::size_t i = lens.size() - 1; i < lens.size(); i--)
|
||||
{
|
||||
if(lens[i] == s.lens()[i])
|
||||
{
|
||||
rstrides[i] = stride;
|
||||
stride *= lens[i];
|
||||
}
|
||||
else if(lens[i] != 1 and s.lens()[i] != 1)
|
||||
{
|
||||
return shape{};
|
||||
}
|
||||
}
|
||||
return shape{s.type(), lens, rstrides};
|
||||
}
|
||||
|
||||
std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
|
||||
{
|
||||
if(shapes.empty())
|
||||
return {};
|
||||
auto result = shapes;
|
||||
auto base = base_lens(shapes);
|
||||
for(auto&& s : shapes)
|
||||
{
|
||||
if(s.lens().size() != base.size())
|
||||
return shapes;
|
||||
if(s.lens() == base)
|
||||
continue;
|
||||
auto mshape = mask_shape(s, base);
|
||||
if(mshape.lens().size() != base.size())
|
||||
return shapes;
|
||||
result.push_back(mshape);
|
||||
}
|
||||
reduce_dim_all(result);
|
||||
result.erase(result.begin() + shapes.size(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
65
docker/rocm/migraphx/register_op.cpp
Normal file
65
docker/rocm/migraphx/register_op.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/register_op.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::unordered_map<std::string, operation>& op_map()
|
||||
{
|
||||
static std::unordered_map<std::string, operation> m; // NOLINT
|
||||
return m;
|
||||
}
|
||||
|
||||
void register_op_init() { (void)op_map(); }
|
||||
|
||||
void register_op(const operation& op) { op_map()[op.name()] = op; }
|
||||
|
||||
void unregister_op(const std::string& op_name)
|
||||
{
|
||||
assert(op_map().count(op_name));
|
||||
op_map().erase(op_name);
|
||||
}
|
||||
|
||||
operation load_op(const std::string& name)
|
||||
{
|
||||
return at(op_map(), name, "Operator not found: " + name);
|
||||
}
|
||||
|
||||
bool has_op(const std::string& name) { return op_map().count(name) == 1; }
|
||||
|
||||
std::vector<std::string> get_operators()
|
||||
{
|
||||
std::vector<std::string> result;
|
||||
std::transform(op_map().begin(), op_map().end(), std::back_inserter(result), [&](auto&& p) {
|
||||
return p.first;
|
||||
});
|
||||
std::sort(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
101
docker/rocm/migraphx/register_target.cpp
Normal file
101
docker/rocm/migraphx/register_target.cpp
Normal file
@ -0,0 +1,101 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <migraphx/register_target.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/dynamic_loader.hpp>
|
||||
#include <migraphx/fileutils.hpp>
|
||||
#include <migraphx/version.h>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void store_target_lib(const dynamic_loader& lib)
|
||||
{
|
||||
static std::vector<dynamic_loader> target_loader;
|
||||
target_loader.emplace_back(lib);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, target>& target_map()
|
||||
{
|
||||
static std::unordered_map<std::string, target> m; // NOLINT
|
||||
return m;
|
||||
}
|
||||
|
||||
void register_target_init() { (void)target_map(); }
|
||||
|
||||
void unregister_target(const std::string& name)
|
||||
{
|
||||
assert(target_map().count(name));
|
||||
target_map().erase(name);
|
||||
}
|
||||
|
||||
void register_target(const target& t) { target_map()[t.name()] = t; }
|
||||
|
||||
target make_target(const std::string& name)
|
||||
{
|
||||
if(not contains(target_map(), name))
|
||||
{
|
||||
std::string so_major_version = "." + std::to_string(MIGRAPHX_SO_MAJOR_VERSION);
|
||||
auto target_name = make_shared_object_filename("migraphx_" + name);
|
||||
|
||||
// Try to load library with so_major_version appended to the name.
|
||||
// If library with so_major_version name is not found,
|
||||
// try loading the library without the so_major_version name appended.
|
||||
// For example, if "libmigraphx_ref.so.2010000" is not found,
|
||||
// try loading "libmigraphx_ref.so".
|
||||
try
|
||||
{
|
||||
// Default to loading shared libraries with
|
||||
// so_major_version appended.
|
||||
store_target_lib(dynamic_loader(target_name + so_major_version));
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
// Load the library without the so_major_version in the name.
|
||||
store_target_lib(dynamic_loader(target_name));
|
||||
}
|
||||
}
|
||||
const auto it = target_map().find(name);
|
||||
if(it == target_map().end())
|
||||
{
|
||||
MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported");
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> get_targets()
|
||||
{
|
||||
std::vector<std::string> result;
|
||||
std::transform(target_map().begin(),
|
||||
target_map().end(),
|
||||
std::back_inserter(result),
|
||||
[&](auto&& p) { return p.first; });
|
||||
std::sort(result.begin(), result.end());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
123
docker/rocm/migraphx/replace_allocate.cpp
Normal file
123
docker/rocm/migraphx/replace_allocate.cpp
Normal file
@ -0,0 +1,123 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/replace_allocate.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/op/allocate.hpp>
|
||||
#include <map>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
std::unordered_map<instruction_ref, std::string> create_output_names(const module& mod)
|
||||
{
|
||||
std::unordered_map<instruction_ref, std::string> mod_output_names{};
|
||||
auto last = std::prev(mod.end());
|
||||
if(last->name() == "@return")
|
||||
{
|
||||
const auto& prog_outputs = last->inputs();
|
||||
std::vector<instruction_ref> outputs_alias(prog_outputs.size());
|
||||
|
||||
std::transform(prog_outputs.begin(),
|
||||
prog_outputs.end(),
|
||||
outputs_alias.begin(),
|
||||
[](const auto& i) { return instruction::get_output_alias(i); });
|
||||
|
||||
std::size_t index = 0;
|
||||
for(auto ins : outputs_alias)
|
||||
{
|
||||
mod_output_names[ins] = mod.name() + ":#output_" + std::to_string(index++);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ins = instruction::get_output_alias(last);
|
||||
mod_output_names[ins] = "output";
|
||||
}
|
||||
return mod_output_names;
|
||||
}
|
||||
|
||||
void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model)
|
||||
{
|
||||
std::vector<instruction_ref> inputs = ins->inputs();
|
||||
std::vector<module_ref> mod_args = ins->module_inputs();
|
||||
|
||||
std::map<std::string, shape> name_shapes;
|
||||
for(const auto& smod : mod_args)
|
||||
{
|
||||
auto ps = smod->get_parameter_shapes();
|
||||
name_shapes.insert(ps.begin(), ps.end());
|
||||
}
|
||||
|
||||
for(const auto& pn : name_shapes)
|
||||
{
|
||||
const auto& s = pn.second;
|
||||
instruction_ref output{};
|
||||
output = mod.insert_instruction(ins, model.allocate(s));
|
||||
inputs.push_back(output);
|
||||
}
|
||||
|
||||
mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
|
||||
}
|
||||
|
||||
void replace_allocate::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
module& m = mpm.get_module();
|
||||
auto mod_output_names = create_output_names(m);
|
||||
bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
auto op = ins->get_operator();
|
||||
auto op_name = op.name();
|
||||
|
||||
// check if allocations from submodules need to be inserted
|
||||
// for now, only the "if" operator is affected
|
||||
if(op_name == "if")
|
||||
{
|
||||
insert_submod_allocations(ins, m, model);
|
||||
continue;
|
||||
}
|
||||
if(op_name != "allocate")
|
||||
continue;
|
||||
|
||||
auto s = ins->get_shape();
|
||||
if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
|
||||
{
|
||||
auto out_param = m.add_parameter(mod_output_names[ins], s);
|
||||
m.replace_instruction(ins, out_param);
|
||||
}
|
||||
else
|
||||
{
|
||||
m.replace_instruction(ins,
|
||||
make_op(model.name(), migraphx::value{{"shape", to_value(s)}}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
110
docker/rocm/migraphx/rewrite_gelu.cpp
Normal file
110
docker/rocm/migraphx/rewrite_gelu.cpp
Normal file
@ -0,0 +1,110 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/rewrite_gelu.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/match/gelu_erf.hpp>
|
||||
#include <migraphx/match/gelu_tanh.hpp>
|
||||
#include <migraphx/common.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* The replacement approximation is equivalent to:
|
||||
* GELU(x) ~= 0.5 * x * ( 1 + tanh( sqrt(2/M_PI) * (x + 0.044715 * x^3)))
|
||||
* You can rearrange to the form used in this by recognizing that
|
||||
* 1 + tanh(x) = (2) / (1 + exp(-2 * x)).
|
||||
* The fitting constant 0.044715 is from
|
||||
* A. Choudhury, ‘A simple approximation to the area under standard normal curve’, Mathematics and
|
||||
* Statistics, vol. 2, no. 3, pp. 147–149, 2014.
|
||||
*/
|
||||
void replace_with_tanh_exp_gelu(module& m, const match::matcher_result& r)
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto x = r.instructions["x"];
|
||||
double const0 = -2. * sqrt(M_2_PI);
|
||||
double const1 = 0.044715 * const0;
|
||||
auto lit0 = m.add_literal(literal{shape{x->get_shape().type()}, {const0}});
|
||||
auto lit1 = m.add_literal(literal{shape{x->get_shape().type()}, {const1}});
|
||||
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0}});
|
||||
auto xb = insert_common_op(m, ins, make_op("mul"), {x, lit1});
|
||||
auto a = m.insert_instruction(ins, make_op("mul"), x, xb);
|
||||
auto b = insert_common_op(m, ins, make_op("add"), {a, lit0});
|
||||
auto u = m.insert_instruction(ins, make_op("mul"), x, b);
|
||||
auto emu = m.insert_instruction(ins, make_op("exp"), u);
|
||||
auto c = insert_common_op(m, ins, make_op("add"), {one, emu});
|
||||
auto y = m.insert_instruction(ins, make_op("div"), x, c);
|
||||
m.replace_instruction(ins, y);
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds erfGELU blocks using the Gaussian distribution and replaces them with the tanh_exp
|
||||
* approximation if the data type is fp16. TODO consider also for fp8 datatype.
|
||||
*/
|
||||
struct find_gelu_erf
|
||||
{
|
||||
auto matcher() const { return match::any_of(match::gelu_erf(), match::gelu_tanh()); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto x = r.instructions["x"];
|
||||
auto input_type = x->get_shape().type();
|
||||
std::set<migraphx::shape::type_t> convert_types = {migraphx::shape::half_type};
|
||||
if(not contains(convert_types, input_type))
|
||||
return;
|
||||
|
||||
replace_with_tanh_exp_gelu(m, r);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Find tanhGELU blocks and replace them with a rearranged version that is less likely to overflow
|
||||
* and is more performant.
|
||||
*/
|
||||
struct find_tanh_fast_gelu
|
||||
{
|
||||
auto matcher() const { return match::gelu_tanh(); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
replace_with_tanh_exp_gelu(m, r);
|
||||
}
|
||||
};
|
||||
|
||||
void rewrite_gelu::apply(module& m) const
|
||||
{
|
||||
if(fast_math)
|
||||
{
|
||||
match::find_matches(m, find_gelu_erf{});
|
||||
}
|
||||
else
|
||||
{
|
||||
match::find_matches(m, find_tanh_fast_gelu{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
68
docker/rocm/migraphx/rewrite_low_precision.cpp
Normal file
68
docker/rocm/migraphx/rewrite_low_precision.cpp
Normal file
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/rewrite_low_precision.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/common.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct find_pow2_div
|
||||
{
|
||||
|
||||
auto pow2() const
|
||||
{
|
||||
auto pow2 = match::name("pow")(match::arg(0)(match::any().bind("x")),
|
||||
match::arg(1)(match::has_value(2.0f)));
|
||||
auto x_square =
|
||||
match::name("mul")(match::same_inputs(), match::arg(0)(match::any().bind("x")));
|
||||
return match::any_of(pow2, x_square);
|
||||
}
|
||||
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("div")(match::arg(0)(pow2()),
|
||||
match::arg(1)(match::is_constant().bind("n")));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto n = r.instructions["n"];
|
||||
auto x = r.instructions["x"];
|
||||
|
||||
if(x->get_shape().type() != migraphx::shape::half_type)
|
||||
return;
|
||||
auto x_div_n = m.insert_instruction(ins, make_op("div"), {x, n});
|
||||
auto mul = m.insert_instruction(ins, make_op("mul"), {x_div_n, x});
|
||||
m.replace_instruction(ins, mul);
|
||||
}
|
||||
};
|
||||
|
||||
void rewrite_low_precision::apply(module& m) const { match::find_matches(m, find_pow2_div{}); }
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
185
docker/rocm/migraphx/rewrite_pooling.cpp
Normal file
185
docker/rocm/migraphx/rewrite_pooling.cpp
Normal file
@ -0,0 +1,185 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/rewrite_pooling.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/op/pooling.hpp>
|
||||
#include <migraphx/op/reshape.hpp>
|
||||
#include <migraphx/op/reduce_mean.hpp>
|
||||
#include <migraphx/op/reduce_max.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
#include <migraphx/program.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static void replace_with_reduce(module& m, instruction_ref ins)
|
||||
{
|
||||
auto&& s = ins->inputs().front()->get_shape();
|
||||
auto&& op = any_cast<op::pooling>(ins->get_operator());
|
||||
auto lens = s.lens();
|
||||
std::vector<std::int64_t> axes(lens.size() - 2);
|
||||
std::iota(axes.begin(), axes.end(), 2);
|
||||
|
||||
// average pooling
|
||||
if(op.mode == op::pooling_mode::average)
|
||||
{
|
||||
m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
|
||||
}
|
||||
// max pooling
|
||||
else
|
||||
{
|
||||
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
|
||||
}
|
||||
}
|
||||
|
||||
static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
|
||||
{
|
||||
// TODO remove this when MIOpen supports dilated pooling
|
||||
auto&& s = ins->inputs().front()->get_shape();
|
||||
auto&& op = any_cast<op::pooling>(ins->get_operator());
|
||||
// Ignore N, C axes
|
||||
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};
|
||||
|
||||
bool default_padding =
|
||||
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
|
||||
|
||||
if(not default_padding)
|
||||
{
|
||||
for(size_t idx{0}; idx < op.padding.size(); ++idx)
|
||||
{
|
||||
// We need to pad both ends
|
||||
dims[idx] += op.padding.at(idx) * 2;
|
||||
}
|
||||
}
|
||||
std::vector<size_t> kernels = op.lengths;
|
||||
std::vector<size_t> strides = op.stride;
|
||||
std::vector<size_t> dilations = op.dilations;
|
||||
|
||||
std::vector<std::vector<int>> axis_indices;
|
||||
axis_indices.resize(dims.size());
|
||||
|
||||
for(auto idx{0}; idx < dims.size(); ++idx)
|
||||
{
|
||||
// Only consider if iw fits into the window
|
||||
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
|
||||
stride += strides.at(idx))
|
||||
{
|
||||
for(size_t step{0}; step < kernels.at(idx); ++step)
|
||||
{
|
||||
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto elements = ins->inputs().front();
|
||||
if(not default_padding)
|
||||
{
|
||||
// Pad supports asym, we need to provide both ends
|
||||
std::vector<size_t> padding(2 * s.lens().size(), 0);
|
||||
// Format will be e.g {N, C, P1, P2, N, C, P1, P2}
|
||||
for(size_t idx{0}; idx < op.padding.size(); ++idx)
|
||||
{
|
||||
// Ignore N, C axes
|
||||
padding.at(2 + idx) = op.padding.at(idx);
|
||||
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
|
||||
}
|
||||
|
||||
// Default value needed for Max pooling
|
||||
elements = m.insert_instruction(
|
||||
ins,
|
||||
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
|
||||
elements);
|
||||
}
|
||||
|
||||
for(auto idx{0}; idx < axis_indices.size(); ++idx)
|
||||
{
|
||||
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
|
||||
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
|
||||
elements = m.insert_instruction(
|
||||
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
|
||||
}
|
||||
|
||||
// Ignore padding
|
||||
std::vector<size_t> new_padding(kernels.size(), 0);
|
||||
// The kernel window elements are places next to each other. E.g. {x1, y1, x2, y2, ...}
|
||||
// We need to skip them to not overlap
|
||||
std::vector<size_t> new_strides(kernels);
|
||||
// Ignore dilations
|
||||
std::vector<size_t> new_dilations(kernels.size(), 1);
|
||||
m.replace_instruction(ins,
|
||||
make_op("pooling",
|
||||
{{"mode", op.mode},
|
||||
{"padding", new_padding},
|
||||
{"stride", new_strides},
|
||||
{"lengths", kernels},
|
||||
{"dilations", new_dilations}}),
|
||||
elements);
|
||||
}
|
||||
|
||||
void rewrite_pooling::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "pooling")
|
||||
continue;
|
||||
if(ins->inputs().empty())
|
||||
continue;
|
||||
auto&& s = ins->inputs().front()->get_shape();
|
||||
auto&& op = any_cast<op::pooling>(ins->get_operator());
|
||||
bool same_kernel_as_shape = std::equal(
|
||||
s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
|
||||
bool default_strides =
|
||||
std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
|
||||
bool default_padding =
|
||||
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
|
||||
bool default_dilations =
|
||||
std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
|
||||
if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
|
||||
{
|
||||
replace_with_reduce(m, ins);
|
||||
}
|
||||
else if(not default_dilations)
|
||||
{
|
||||
// Dilated AvgPool with padding is not supported
|
||||
if(not default_padding and op.mode == op::pooling_mode::average)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto size =
|
||||
std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
|
||||
// Can't handle too much size because of literal size
|
||||
if(size > 100000)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
replace_dilations_with_gather_pooling(m, ins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
124
docker/rocm/migraphx/rewrite_quantization.cpp
Normal file
124
docker/rocm/migraphx/rewrite_quantization.cpp
Normal file
@ -0,0 +1,124 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/rewrite_quantization.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/tune_axis.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/common.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
|
||||
|
||||
void apply_quantizelinear(module& m, instruction_ref ins)
|
||||
{
|
||||
assert(ins->name() == "quantizelinear");
|
||||
auto x = ins->inputs()[0];
|
||||
auto y_scale = ins->inputs()[1];
|
||||
|
||||
if(x->get_shape().type() != y_scale->get_shape().type())
|
||||
{
|
||||
x = m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
|
||||
}
|
||||
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
|
||||
auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
|
||||
|
||||
if(ins->inputs().size() == 3)
|
||||
{
|
||||
auto zero_point =
|
||||
m.insert_instruction(ins,
|
||||
make_op("convert", {{"target_type", y_scale->get_shape().type()}}),
|
||||
ins->inputs()[2]);
|
||||
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
|
||||
}
|
||||
|
||||
double max_quant = 0;
|
||||
double min_quant = 0;
|
||||
ins->get_shape().visit_type([&](auto qt) {
|
||||
max_quant = qt.max();
|
||||
min_quant = qt.min();
|
||||
});
|
||||
auto s = add_zero_point->get_shape();
|
||||
instruction_ref min_arg;
|
||||
instruction_ref max_arg;
|
||||
|
||||
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
|
||||
{
|
||||
std::vector<double> min_data(s.elements(), min_quant);
|
||||
std::vector<double> max_data(s.elements(), max_quant);
|
||||
min_arg = m.add_literal(literal(s, min_data));
|
||||
max_arg = m.add_literal(literal(s, max_data));
|
||||
}
|
||||
else
|
||||
{
|
||||
min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
|
||||
max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
|
||||
}
|
||||
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
|
||||
m.replace_instruction(
|
||||
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
|
||||
}
|
||||
|
||||
void apply_dequantizelinear(module& m, instruction_ref ins)
|
||||
{
|
||||
assert(ins->name() == "dequantizelinear");
|
||||
auto x_scale = ins->inputs()[1];
|
||||
auto x = m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", x_scale->get_shape().type()}}), ins->inputs()[0]);
|
||||
|
||||
if(ins->inputs().size() == 3)
|
||||
{
|
||||
auto x_zero_point =
|
||||
m.insert_instruction(ins,
|
||||
make_op("convert", {{"target_type", x_scale->get_shape().type()}}),
|
||||
ins->inputs()[2]);
|
||||
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
|
||||
}
|
||||
|
||||
m.replace_instruction(ins, make_op("mul"), x, x_scale);
|
||||
}
|
||||
|
||||
void rewrite_quantization::apply(module& m) const
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() == "quantizelinear")
|
||||
{
|
||||
apply_quantizelinear(m, ins);
|
||||
}
|
||||
|
||||
else if(ins->name() == "dequantizelinear")
|
||||
{
|
||||
apply_dequantizelinear(m, ins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
153
docker/rocm/migraphx/rewrite_reduce.cpp
Normal file
153
docker/rocm/migraphx/rewrite_reduce.cpp
Normal file
@ -0,0 +1,153 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
*/
|
||||
#include <migraphx/rewrite_reduce.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/common.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
namespace {
|
||||
struct find_softmax
|
||||
{
|
||||
auto matcher() const { return match::name("softmax"); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto op = ins->get_operator().to_value();
|
||||
auto axis = op["axis"].to<std::int64_t>();
|
||||
|
||||
auto input = ins->inputs().front();
|
||||
auto max = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {axis}}}), input);
|
||||
auto maxb = m.insert_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_lens", input->get_shape().lens()}}), max);
|
||||
auto sub = m.insert_instruction(ins, make_op("sub"), input, maxb);
|
||||
auto exp = m.insert_instruction(ins, make_op("exp"), sub);
|
||||
auto sum = m.insert_instruction(ins, make_op("reduce_sum", {{"axes", {axis}}}), exp);
|
||||
auto sumb = m.insert_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_lens", input->get_shape().lens()}}), sum);
|
||||
m.replace_instruction(ins, make_op("div"), exp, sumb);
|
||||
}
|
||||
};
|
||||
|
||||
struct find_reduce_mean_variance
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto reduce_mean = match::name("reduce_mean");
|
||||
auto skip_broadcasts_mean = match::skip_broadcasts(reduce_mean.bind("mean"));
|
||||
auto x_minus_mean = match::name("sub")(match::arg(0)(match::any().bind("x")),
|
||||
match::arg(1)(skip_broadcasts_mean));
|
||||
auto pow_x_minus_mean =
|
||||
match::name("pow")(match::arg(0)(x_minus_mean), match::arg(1)(match::has_value(2.0f)));
|
||||
auto mul_x_minus_mean =
|
||||
match::name("mul")(match::arg(0)(x_minus_mean), match::arg(1)(x_minus_mean));
|
||||
auto sqdiff = match::name("sqdiff")(
|
||||
match::either_arg(0, 1)(match::any().bind("x"), skip_broadcasts_mean));
|
||||
return reduce_mean(
|
||||
match::arg(0)(match::any_of(pow_x_minus_mean, mul_x_minus_mean, sqdiff)));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto x_ins = r.instructions["x"];
|
||||
auto mean = r.instructions["mean"];
|
||||
|
||||
if(ins->get_operator() != mean->get_operator())
|
||||
return;
|
||||
|
||||
if(mean->inputs().front() != x_ins)
|
||||
return;
|
||||
|
||||
auto x2 = m.insert_instruction(ins, make_op("mul"), x_ins, x_ins);
|
||||
auto mean_x2 = m.insert_instruction(ins, mean->get_operator(), x2);
|
||||
auto mean_x_2 = m.insert_instruction(ins, make_op("mul"), mean, mean);
|
||||
m.replace_instruction(ins, make_op("sub"), mean_x2, mean_x_2);
|
||||
}
|
||||
};
|
||||
|
||||
struct find_reduce_mean
|
||||
{
|
||||
auto matcher() const { return match::name("reduce_mean"); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto op = ins->get_operator().to_value();
|
||||
auto axes = op["axes"].to_vector<std::int64_t>();
|
||||
auto input = ins->inputs().front();
|
||||
|
||||
bool is_integral = false;
|
||||
double max_n = 0;
|
||||
std::size_t size = 0;
|
||||
input->get_shape().visit_type([&](auto t) {
|
||||
is_integral = t.is_integral();
|
||||
max_n = t.max();
|
||||
size = t.size();
|
||||
});
|
||||
|
||||
auto n = input->get_shape().elements() / ins->get_shape().elements();
|
||||
|
||||
// Convert accumulator to float if <= 8bit type or if < 3 bytes and n >= max_n /4
|
||||
if(size == 1 or (n >= max_n / 4 and size < 3))
|
||||
{
|
||||
shape::type_t t = is_integral ? shape::int32_type : shape::float_type;
|
||||
input = m.insert_instruction(ins, make_op("convert", {{"target_type", t}}), input);
|
||||
}
|
||||
|
||||
auto n_literal = m.add_literal(literal{{input->get_shape().type(), {1}}, {n}});
|
||||
if(is_integral)
|
||||
{
|
||||
auto reduce_sum =
|
||||
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), input);
|
||||
auto div = insert_common_op(m, ins, make_op("div"), {reduce_sum, n_literal});
|
||||
m.replace_instruction(
|
||||
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), div);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto new_input = insert_common_op(m, ins, make_op("div"), {input, n_literal});
|
||||
auto reduce_sum =
|
||||
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", axes}}), new_input);
|
||||
m.replace_instruction(
|
||||
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), reduce_sum);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void rewrite_reduce::apply(module& m) const
|
||||
{
|
||||
match::find_matches(m, find_softmax{}, find_reduce_mean_variance{});
|
||||
match::find_matches(m, find_reduce_mean{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
1443
docker/rocm/migraphx/rewrite_rnn.cpp
Normal file
1443
docker/rocm/migraphx/rewrite_rnn.cpp
Normal file
File diff suppressed because it is too large
Load Diff
633
docker/rocm/migraphx/schedule.cpp
Normal file
633
docker/rocm/migraphx/schedule.cpp
Normal file
@ -0,0 +1,633 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/schedule.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/iterator.hpp>
|
||||
#include <migraphx/dfor.hpp>
|
||||
#include <migraphx/simple_par_for.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/dom_info.hpp>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
#include <set>
|
||||
#include <deque>
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SCHEDULE)
|
||||
|
||||
auto get_inputs()
|
||||
{
|
||||
return [](auto i) { return i->inputs(); };
|
||||
}
|
||||
|
||||
auto get_outputs()
|
||||
{
|
||||
return [](auto i) { return i->outputs(); };
|
||||
}
|
||||
|
||||
struct stream_info
|
||||
{
|
||||
std::unordered_map<instruction_ref, std::size_t> ins2stream;
|
||||
std::unordered_map<instruction_ref, std::size_t> weights;
|
||||
std::unordered_map<instruction_ref, std::size_t> iweights;
|
||||
ins_dep_map mod_implicit_deps;
|
||||
|
||||
void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); }
|
||||
|
||||
void accumulate_weights(instruction_ref last, const schedule_model& model)
|
||||
{
|
||||
fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
|
||||
if(not contains(weights, ins))
|
||||
{
|
||||
std::size_t weight = 0;
|
||||
auto&& op = ins->get_operator();
|
||||
if(not is_context_free(op) and op.name()[0] != '@')
|
||||
weight = model.weight(op);
|
||||
// This will ensure a stream will be assigned to return
|
||||
if(op.name() == "@return")
|
||||
weight = 1;
|
||||
iweights[ins] = weight;
|
||||
auto inputs = ins->inputs();
|
||||
if(contains(mod_implicit_deps, ins))
|
||||
{
|
||||
const auto& impl_deps = mod_implicit_deps.at(ins);
|
||||
inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end());
|
||||
}
|
||||
|
||||
weights[ins] = std::accumulate(
|
||||
inputs.begin(), inputs.end(), weight, [&](std::size_t w, instruction_ref i) {
|
||||
return w + self(i);
|
||||
});
|
||||
}
|
||||
return weights[ins];
|
||||
})(last);
|
||||
}
|
||||
|
||||
template <class Compare>
|
||||
void sort_args_by_weight(std::vector<instruction_ref>& args, Compare compare) const
|
||||
{
|
||||
if(args.size() < 2)
|
||||
return;
|
||||
std::sort(args.begin(), args.end(), by(compare, [this](auto x) {
|
||||
return std::make_tuple(
|
||||
this->weights.at(x), x->inputs().size(), std::addressof(*x));
|
||||
}));
|
||||
}
|
||||
|
||||
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
|
||||
{
|
||||
if(args.size() < 2)
|
||||
{
|
||||
return args.end();
|
||||
}
|
||||
|
||||
const std::size_t min_partition_threshold = 2;
|
||||
sort_args_by_weight(args, std::greater<>{});
|
||||
|
||||
auto it = std::lower_bound(std::next(args.begin()),
|
||||
args.end(),
|
||||
min_partition_threshold,
|
||||
[&](auto i, std::size_t w) { return this->weights[i] > w; });
|
||||
assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
|
||||
assert(it == args.end() or std::prev(it) == args.begin() or
|
||||
this->weights[*std::prev(it)] > min_partition_threshold);
|
||||
return it;
|
||||
}
|
||||
|
||||
struct partition
|
||||
{
|
||||
std::size_t weight = 0;
|
||||
std::vector<instruction_ref> instructions{};
|
||||
|
||||
void add(instruction_ref ins, std::size_t w)
|
||||
{
|
||||
weight += w;
|
||||
instructions.push_back(ins);
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t assign_streams(module& m, std::size_t n)
|
||||
{
|
||||
assert(n > 0);
|
||||
partition critical;
|
||||
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
|
||||
partitions.reserve(weights.size());
|
||||
fix([&](auto self, auto ins, auto& part) {
|
||||
assert(not is_end(ins, m.end()));
|
||||
if(not m.has_instruction(ins))
|
||||
return;
|
||||
if(contains(partitions, ins))
|
||||
return;
|
||||
|
||||
// Add an entry so we know the instruction was visited
|
||||
partitions[ins];
|
||||
part.add(ins, this->iweights[ins]);
|
||||
|
||||
auto args = ins->inputs();
|
||||
auto threshold_it = this->sort_args(args);
|
||||
|
||||
if(not args.empty())
|
||||
{
|
||||
assert(threshold_it != args.begin());
|
||||
self(args.front(), part);
|
||||
for(auto i : range(std::next(args.begin()), threshold_it))
|
||||
{
|
||||
partitions[ins].emplace_back();
|
||||
self(i, partitions[ins].back());
|
||||
}
|
||||
for(auto i : range(threshold_it, args.end()))
|
||||
{
|
||||
self(i, part);
|
||||
}
|
||||
}
|
||||
// Sort instructions
|
||||
m.move_instruction(ins, m.end());
|
||||
})(std::prev(m.end()), critical);
|
||||
|
||||
// Set the critical partition to stream 0
|
||||
set_stream(critical, 0);
|
||||
if(n == 1)
|
||||
{
|
||||
// Assign streams for the other partitions
|
||||
for(auto&& ins_part : partitions)
|
||||
for(auto&& part : ins_part.second)
|
||||
set_stream(part, 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<std::size_t> streams(n - 1);
|
||||
// Assign streams for the other partitions
|
||||
for(auto&& ins_part : partitions)
|
||||
{
|
||||
std::sort(ins_part.second.begin(),
|
||||
ins_part.second.end(),
|
||||
by(std::greater<>{}, [](auto&& x) {
|
||||
return std::make_tuple(x.weight, x.instructions.size());
|
||||
}));
|
||||
for(auto&& part : ins_part.second)
|
||||
{
|
||||
auto stream =
|
||||
std::min_element(streams.begin(), streams.end()) - streams.begin();
|
||||
set_stream(part, stream + 1);
|
||||
streams[stream] += part.weight;
|
||||
}
|
||||
}
|
||||
return 1 + std::count_if(streams.begin(), streams.end(), [](auto x) { return x > 0; });
|
||||
}
|
||||
}
|
||||
|
||||
using weight_ins = std::pair<std::size_t, instruction_ref>;
|
||||
struct compare_weight_ins
|
||||
{
|
||||
bool operator()(const weight_ins& x, const weight_ins& y) const
|
||||
{
|
||||
return std::make_pair(x.first, std::addressof(*x.second)) <
|
||||
std::make_pair(y.first, std::addressof(*y.second));
|
||||
}
|
||||
};
|
||||
|
||||
void sort(module& m, std::size_t)
|
||||
{
|
||||
std::set<weight_ins, compare_weight_ins> children;
|
||||
std::unordered_map<instruction_ref, std::size_t> visited;
|
||||
auto last = std::prev(m.end());
|
||||
auto mw = this->weights.at(last);
|
||||
auto nw = mw / (m.size() + 1);
|
||||
auto add_child = [&](auto ins) {
|
||||
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
|
||||
auto w = x * this->iweights.at(ins);
|
||||
auto& v = visited[ins];
|
||||
auto it = children.find(std::make_pair(v * w, ins));
|
||||
if(it == children.end())
|
||||
{
|
||||
v++;
|
||||
children.insert(std::make_pair(v * w, ins));
|
||||
}
|
||||
};
|
||||
add_child(last);
|
||||
|
||||
while(not children.empty())
|
||||
{
|
||||
// Pop the first element
|
||||
auto top = children.begin()->second;
|
||||
children.erase(children.begin());
|
||||
m.move_instruction(top, m.begin());
|
||||
for(auto ins : top->inputs())
|
||||
{
|
||||
if(not m.has_instruction(ins))
|
||||
continue;
|
||||
add_child(ins);
|
||||
}
|
||||
|
||||
if(contains(mod_implicit_deps, top))
|
||||
{
|
||||
for(auto ins : mod_implicit_deps.at(top))
|
||||
{
|
||||
assert(m.has_instruction(ins));
|
||||
add_child(ins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// move dangling parameter to the front so as not be removed
|
||||
auto ins = std::next(last);
|
||||
while(ins != m.end())
|
||||
{
|
||||
auto next = std::next(ins);
|
||||
if(ins->name() == "@param")
|
||||
{
|
||||
m.move_instruction(ins, m.begin());
|
||||
}
|
||||
ins = next;
|
||||
}
|
||||
}
|
||||
|
||||
void set_stream(const partition& p, std::size_t n)
|
||||
{
|
||||
for(auto ins : p.instructions)
|
||||
if(iweights[ins] > 0)
|
||||
set_stream(ins, n);
|
||||
}
|
||||
|
||||
void set_stream(instruction_ref ins, std::size_t n)
|
||||
{
|
||||
assert(iweights[ins] > 0);
|
||||
ins2stream[ins] = n;
|
||||
}
|
||||
|
||||
std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
|
||||
|
||||
bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); }
|
||||
|
||||
template <class F>
|
||||
bool different(F f, std::size_t stream) const
|
||||
{
|
||||
bool result = false;
|
||||
f([&](auto s) {
|
||||
if(s != stream)
|
||||
{
|
||||
result = true;
|
||||
return false;
|
||||
}
|
||||
// cppcheck-suppress uselessAssignmentArg
|
||||
stream = s;
|
||||
return true;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class F>
|
||||
bool different(F f) const
|
||||
{
|
||||
bool result = false;
|
||||
f([&](auto s) {
|
||||
result = this->different(f, s);
|
||||
return false;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class Selector>
|
||||
auto get_streams_from(instruction_ref start, Selector select) const
|
||||
{
|
||||
return [=](auto f) {
|
||||
return fix<bool>([&](auto self, auto ins) {
|
||||
return all_of(select(ins), [&](auto i) {
|
||||
if(has_stream(i))
|
||||
return f(this->get_stream(i));
|
||||
else
|
||||
return self(i);
|
||||
});
|
||||
})(start);
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_set<std::size_t> get_streams(instruction_ref ins) const
|
||||
{
|
||||
if(has_stream(ins))
|
||||
return {get_stream(ins)};
|
||||
std::unordered_set<std::size_t> result;
|
||||
get_streams_from(ins, get_inputs())([&](auto s) {
|
||||
result.insert(s);
|
||||
return true;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
bool is_merge_point(instruction_ref ins, Ts... xs) const
|
||||
{
|
||||
return different(get_streams_from(ins, get_inputs()), xs...);
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
bool is_split_point(instruction_ref ins, Ts... xs) const
|
||||
{
|
||||
return different(get_streams_from(ins, get_outputs()), xs...);
|
||||
}
|
||||
|
||||
std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
|
||||
{
|
||||
std::vector<instruction_ref> result;
|
||||
std::unordered_map<std::size_t, instruction_ref> m;
|
||||
fix([&](auto self, auto ins) {
|
||||
for(auto i : ins->inputs())
|
||||
{
|
||||
if(iweights.at(i) == 0)
|
||||
{
|
||||
self(i);
|
||||
continue;
|
||||
}
|
||||
auto stream = this->get_stream(i);
|
||||
if(not contains(m, stream))
|
||||
m[stream] = i;
|
||||
else
|
||||
m[stream] = std::min(m[stream], i, by(std::less<>{}, [&](auto x) {
|
||||
return std::distance(x, start);
|
||||
}));
|
||||
}
|
||||
})(start);
|
||||
std::transform(
|
||||
m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.second; });
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
|
||||
find_concurrent_instructions(module& m) const
|
||||
{
|
||||
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
|
||||
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
|
||||
dominator_info di = compute_dominator(m);
|
||||
result.reserve(m.size());
|
||||
merge_from.reserve(m.size());
|
||||
for(auto ins : reverse_iterator_for(m))
|
||||
{
|
||||
for(auto&& arg : ins->outputs())
|
||||
{
|
||||
if(not m.has_instruction(arg))
|
||||
continue;
|
||||
if(is_merge_point(arg))
|
||||
merge_from[ins].insert(arg);
|
||||
merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
|
||||
}
|
||||
|
||||
if(is_split_point(ins))
|
||||
{
|
||||
erase_if(merge_from[ins],
|
||||
[&](auto merge) { return di.strictly_dominate(ins, merge); });
|
||||
}
|
||||
|
||||
auto streams = this->get_streams(ins);
|
||||
// Collect concur instructions for each merge point.
|
||||
for(const auto& merge : merge_from[ins])
|
||||
{
|
||||
for(auto stream : streams)
|
||||
{
|
||||
if(result[merge].size() <= stream)
|
||||
result[merge].resize(stream + 1);
|
||||
auto&& r = result[merge][stream];
|
||||
r.push_back(ins);
|
||||
// Copy inputs if they dont have a stream(and are not a builtin and context
|
||||
// free). Inputs without a stream can have a implicit dependency
|
||||
std::copy_if(ins->inputs().begin(),
|
||||
ins->inputs().end(),
|
||||
std::back_inserter(r),
|
||||
[&](auto x) {
|
||||
return not this->has_stream(x) and
|
||||
not is_context_free(x->get_operator()) and
|
||||
x->name().front() != '@';
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
|
||||
get_conflicts(module& m)
|
||||
{
|
||||
|
||||
using conflict_table_type =
|
||||
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
|
||||
conflict_table_type conflict_table;
|
||||
auto concur_ins = this->find_concurrent_instructions(m);
|
||||
|
||||
// Compute an index for each instruction
|
||||
std::unordered_map<instruction_ref, std::size_t> ins2index;
|
||||
std::size_t index_total = 0;
|
||||
for(auto ins : iterator_for(m))
|
||||
ins2index[ins] = index_total++;
|
||||
|
||||
std::vector<conflict_table_type> thread_conflict_tables(
|
||||
std::thread::hardware_concurrency());
|
||||
std::vector<instruction_ref> index_to_ins;
|
||||
index_to_ins.reserve(concur_ins.size());
|
||||
std::transform(concur_ins.begin(),
|
||||
concur_ins.end(),
|
||||
std::back_inserter(index_to_ins),
|
||||
[](auto&& it) { return it.first; });
|
||||
|
||||
simple_par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
|
||||
auto merge_first = index_to_ins[ins_index];
|
||||
assert(concur_ins.count(merge_first) > 0);
|
||||
auto& merge_second = concur_ins.at(merge_first);
|
||||
|
||||
// ensure there are enough elements for different threads
|
||||
assert(tid < thread_conflict_tables.size());
|
||||
auto& thrd_table = thread_conflict_tables.at(tid);
|
||||
|
||||
std::unordered_set<instruction_ref> checked_ins_set;
|
||||
auto range_i = range(merge_second.begin(), std::prev(merge_second.end()));
|
||||
for(auto it_i : iterator_for(range_i))
|
||||
{
|
||||
std::unordered_set<instruction_ref> ins1_set;
|
||||
std::copy_if(it_i->begin(),
|
||||
it_i->end(),
|
||||
std::inserter(ins1_set, ins1_set.end()),
|
||||
[&](auto i) { return not contains(checked_ins_set, i); });
|
||||
checked_ins_set.insert(ins1_set.begin(), ins1_set.end());
|
||||
|
||||
auto range_j = range(std::next(it_i), merge_second.end());
|
||||
std::unordered_set<instruction_ref> ins2_set;
|
||||
for(auto it_j : iterator_for(range_j))
|
||||
{
|
||||
std::copy_if(it_j->begin(),
|
||||
it_j->end(),
|
||||
std::inserter(ins2_set, ins2_set.end()),
|
||||
[&](auto i) { return not contains(checked_ins_set, i); });
|
||||
}
|
||||
|
||||
for(auto ins1 : ins1_set)
|
||||
{
|
||||
auto p1 = ins2index.at(ins1);
|
||||
for(auto ins2 : ins2_set)
|
||||
{
|
||||
if(ins1 == ins2)
|
||||
continue;
|
||||
auto p2 = ins2index.at(ins2);
|
||||
if(p2 > p1)
|
||||
thrd_table[ins2].insert(ins1);
|
||||
else
|
||||
thrd_table[ins1].insert(ins2);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// merge thread_conflict_tables together
|
||||
for(auto& tbl : thread_conflict_tables)
|
||||
{
|
||||
for(auto& it : tbl)
|
||||
{
|
||||
conflict_table[it.first].insert(it.second.begin(), it.second.end());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove instructions from the conflict table of an ealier instruction
|
||||
for(auto&& ip : conflict_table)
|
||||
{
|
||||
auto ins1 = ip.first;
|
||||
for(auto ins2 : ip.second)
|
||||
if(contains(conflict_table[ins2], ins1))
|
||||
conflict_table[ins2].erase(ins1);
|
||||
}
|
||||
|
||||
return conflict_table;
|
||||
}
|
||||
};
|
||||
|
||||
void schedule::apply(module& m) const
|
||||
{
|
||||
if(not enable)
|
||||
return;
|
||||
|
||||
stream_info si;
|
||||
si.calc_implicit_deps(m);
|
||||
auto last = std::prev(m.end());
|
||||
si.accumulate_weights(last, model);
|
||||
auto nstreams = si.assign_streams(m, model.concurrency());
|
||||
si.sort(m, model.concurrency());
|
||||
|
||||
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
|
||||
{
|
||||
m.annotate(std::cout, [&](auto ins) {
|
||||
if(ins->name() == "@param" and not contains(si.weights, ins))
|
||||
return;
|
||||
|
||||
std::cout << ":";
|
||||
std::cout << " weight=" << si.weights.at(ins);
|
||||
std::cout << " input={";
|
||||
si.get_streams_from(ins, get_inputs())([&](auto s) {
|
||||
std::cout << s << ",";
|
||||
return true;
|
||||
});
|
||||
std::cout << "}";
|
||||
if(si.has_stream(ins))
|
||||
std::cout << " stream=" << si.get_stream(ins);
|
||||
});
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
// No concurrency
|
||||
if(nstreams < 2)
|
||||
return;
|
||||
|
||||
// Schedule instructions
|
||||
std::size_t wait_id = 0;
|
||||
std::unordered_map<instruction_ref, std::size_t> ins2wait;
|
||||
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
|
||||
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
|
||||
ins2wait.reserve(m.size());
|
||||
ins2waited.reserve(m.size());
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
// Only schedule instructions that have a stream
|
||||
if(not si.has_stream(ins))
|
||||
continue;
|
||||
assert(si.weights[ins] > 0);
|
||||
// Schedule instruction on the stream
|
||||
auto stream = si.get_stream(ins);
|
||||
assert(stream < model.concurrency());
|
||||
model.sched(m, ins, stream);
|
||||
// Insert wait instructions
|
||||
if(si.is_merge_point(ins, stream))
|
||||
{
|
||||
for(auto i : si.get_recorded_instructions(ins))
|
||||
{
|
||||
if(not si.has_stream(i) or si.get_stream(i) == stream)
|
||||
continue;
|
||||
|
||||
// Create a new event if it hasn't been recorded
|
||||
if(not contains(ins2wait, i))
|
||||
{
|
||||
ins2wait[i] = wait_id;
|
||||
model.record(m, i, wait_id);
|
||||
wait_id++;
|
||||
}
|
||||
auto w = ins2wait.at(i);
|
||||
// If we already waited for the event on this stream then dont
|
||||
// insert another wait event
|
||||
if(not contains(waited_for[stream], w))
|
||||
model.wait(m, ins, w);
|
||||
// Store the event as waited
|
||||
waited_for[stream].insert(w);
|
||||
// Store all wait events that have been waited on prior to the recorded instruction
|
||||
waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end());
|
||||
}
|
||||
}
|
||||
// Store wait events that have already been waited on
|
||||
if(si.is_split_point(ins, stream))
|
||||
{
|
||||
ins2waited[ins] = waited_for[stream];
|
||||
}
|
||||
}
|
||||
|
||||
// Add memory conflicts
|
||||
auto conflict_table = si.get_conflicts(m);
|
||||
for(auto&& ip : conflict_table)
|
||||
{
|
||||
if(ip.second.empty())
|
||||
continue;
|
||||
std::vector<instruction_ref> args;
|
||||
args.push_back(ip.first);
|
||||
args.insert(args.end(), ip.second.begin(), ip.second.end());
|
||||
m.insert_instruction(std::next(ip.first), make_op("identity"), args);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
66
docker/rocm/migraphx/serialize.cpp
Normal file
66
docker/rocm/migraphx/serialize.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/serialize.hpp>
|
||||
#include <migraphx/argument.hpp>
|
||||
#include <migraphx/literal.hpp>
|
||||
#include <migraphx/context.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
template <class RawData>
|
||||
void raw_data_to_value(value& v, const RawData& rd)
|
||||
{
|
||||
value result;
|
||||
result["shape"] = migraphx::to_value(rd.get_shape());
|
||||
if(rd.get_shape().type() == shape::tuple_type)
|
||||
result["sub"] = migraphx::to_value(rd.get_sub_objects());
|
||||
else if(not rd.empty())
|
||||
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
|
||||
v = result;
|
||||
}
|
||||
|
||||
void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); }
|
||||
void migraphx_from_value(const value& v, literal& l)
|
||||
{
|
||||
auto s = migraphx::from_value<shape>(v.at("shape"));
|
||||
l = literal(s, v.at("data").get_binary().data());
|
||||
}
|
||||
|
||||
void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); }
|
||||
void migraphx_from_value(const value& v, argument& a)
|
||||
{
|
||||
if(v.contains("data"))
|
||||
{
|
||||
literal l = migraphx::from_value<literal>(v);
|
||||
a = l.get_argument();
|
||||
}
|
||||
else if(v.contains("sub"))
|
||||
{
|
||||
a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
857
docker/rocm/migraphx/shape.cpp
Normal file
857
docker/rocm/migraphx/shape.cpp
Normal file
@ -0,0 +1,857 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/serialize.hpp>
|
||||
#include <migraphx/permutation.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct shape_impl
|
||||
{
|
||||
static std::shared_ptr<shape_impl> default_shape()
|
||||
{
|
||||
static const std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
|
||||
return result;
|
||||
}
|
||||
|
||||
shape_impl() : m_type(shape::float_type) {}
|
||||
|
||||
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
|
||||
{
|
||||
assert(t != shape::tuple_type);
|
||||
}
|
||||
|
||||
shape_impl(shape::type_t t, std::vector<std::size_t> l)
|
||||
: m_type(t), m_lens(std::move(l)), m_standard(true)
|
||||
{
|
||||
assert(t != shape::tuple_type);
|
||||
this->calculate_strides();
|
||||
}
|
||||
|
||||
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
|
||||
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
|
||||
{
|
||||
assert(t != shape::tuple_type);
|
||||
assert(m_lens.size() == m_strides.size());
|
||||
|
||||
// Calculate standard shape flag for these lens/strides. Strides on size-1
|
||||
// axes are ignored to support an MLIR rule.
|
||||
std::vector<size_t> filtered_strides;
|
||||
for(size_t ind = 0; ind < m_strides.size(); ind++)
|
||||
if(m_lens[ind] != 1)
|
||||
filtered_strides.push_back(m_strides[ind]);
|
||||
m_standard = this->elements() == this->element_space() and not skips() and
|
||||
std::is_sorted(filtered_strides.rbegin(), filtered_strides.rend());
|
||||
}
|
||||
|
||||
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
|
||||
: m_type(t), m_dyn_dims(std::move(dims))
|
||||
{
|
||||
}
|
||||
|
||||
shape_impl(shape::type_t t,
|
||||
std::vector<std::size_t> mins,
|
||||
std::vector<std::size_t> maxes,
|
||||
std::vector<std::set<std::size_t>> optimals_list)
|
||||
: m_type(t)
|
||||
{
|
||||
if(optimals_list.empty())
|
||||
{
|
||||
for(size_t i = 0; i < mins.size(); ++i)
|
||||
{
|
||||
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
|
||||
for(size_t i = 0; i < mins.size(); ++i)
|
||||
{
|
||||
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
|
||||
|
||||
shape::type_t m_type;
|
||||
std::vector<std::size_t> m_lens = {};
|
||||
std::vector<std::size_t> m_strides = {};
|
||||
std::vector<shape> m_shapes = {};
|
||||
bool m_standard = false;
|
||||
|
||||
std::vector<shape::dynamic_dimension> m_dyn_dims = {};
|
||||
|
||||
void calculate_strides()
|
||||
{
|
||||
m_strides.clear();
|
||||
m_strides.resize(m_lens.size(), 0);
|
||||
if(m_strides.empty())
|
||||
return;
|
||||
m_strides.back() = 1;
|
||||
std::partial_sum(m_lens.rbegin(),
|
||||
m_lens.rend() - 1,
|
||||
m_strides.rbegin() + 1,
|
||||
std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
std::size_t element_space() const
|
||||
{
|
||||
if(not m_dyn_dims.empty())
|
||||
{
|
||||
auto maxes = max_lens();
|
||||
std::size_t max_val = std::numeric_limits<std::size_t>::max();
|
||||
|
||||
return std::accumulate(
|
||||
maxes.begin(), maxes.end(), std::size_t{1}, [&](std::size_t x, std::size_t y) {
|
||||
// overflow check and clip
|
||||
if(x != 0 and y > max_val / x)
|
||||
{
|
||||
return max_val;
|
||||
}
|
||||
return x * y;
|
||||
});
|
||||
}
|
||||
|
||||
assert(m_lens.size() == m_strides.size());
|
||||
if(m_lens.empty())
|
||||
return 0;
|
||||
return std::inner_product(m_lens.begin(),
|
||||
m_lens.end(),
|
||||
m_strides.begin(),
|
||||
std::size_t{0},
|
||||
std::plus<std::size_t>{},
|
||||
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
|
||||
1;
|
||||
}
|
||||
|
||||
std::size_t elements() const
|
||||
{
|
||||
if(not m_dyn_dims.empty())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: elements() called on dynamic shape");
|
||||
}
|
||||
|
||||
assert(m_lens.size() == m_strides.size());
|
||||
if(m_lens.empty())
|
||||
return 0;
|
||||
return std::accumulate(
|
||||
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
std::size_t get_index(size_t i) const
|
||||
{
|
||||
std::size_t result = 0;
|
||||
std::size_t s = 1;
|
||||
|
||||
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
|
||||
{
|
||||
std::size_t stride = m_strides[k];
|
||||
std::size_t len = m_lens[k];
|
||||
std::size_t idx = (i % (s * len)) / s;
|
||||
result += stride * idx;
|
||||
s *= len;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::size_t> min_lens() const
|
||||
{
|
||||
std::vector<std::size_t> ret(m_dyn_dims.size());
|
||||
std::transform(m_dyn_dims.cbegin(),
|
||||
m_dyn_dims.cend(),
|
||||
ret.begin(),
|
||||
[](const shape::dynamic_dimension& x) { return x.min; });
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<std::size_t> max_lens() const
|
||||
{
|
||||
std::vector<std::size_t> ret(m_dyn_dims.size());
|
||||
std::transform(m_dyn_dims.cbegin(),
|
||||
m_dyn_dims.cend(),
|
||||
ret.begin(),
|
||||
[](const shape::dynamic_dimension& x) { return x.max; });
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<std::set<std::size_t>> opt_lens() const
|
||||
{
|
||||
std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
|
||||
std::transform(m_dyn_dims.cbegin(),
|
||||
m_dyn_dims.cend(),
|
||||
ret.begin(),
|
||||
[](const shape::dynamic_dimension& x) { return x.optimals; });
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Does the shape skip over elements?
|
||||
bool skips() const
|
||||
{
|
||||
assert(m_lens.size() == m_strides.size());
|
||||
if(elements() == 1)
|
||||
return false;
|
||||
return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
|
||||
}
|
||||
|
||||
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
|
||||
};
|
||||
|
||||
std::string shape::to_sizes_string(const std::vector<shape>& shapes)
|
||||
{
|
||||
std::vector<std::string> sizes;
|
||||
std::transform(shapes.begin(), shapes.end(), std::back_inserter(sizes), [&](const shape& s) {
|
||||
std::string r = to_string_range(s.lens(), "x");
|
||||
if(not s.standard())
|
||||
r += ":" + to_string_range(s.strides(), "x");
|
||||
return r;
|
||||
});
|
||||
return join_strings(sizes, ", ");
|
||||
}
|
||||
|
||||
const std::vector<shape::type_t>& shape::types()
|
||||
{
|
||||
static const std::vector<shape::type_t> result = {
|
||||
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
|
||||
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string shape::name(shape::type_t t)
|
||||
{
|
||||
switch(t)
|
||||
{
|
||||
case tuple_type: return "tuple_type";
|
||||
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
|
||||
case x: return #x;
|
||||
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
|
||||
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
|
||||
}
|
||||
MIGRAPHX_THROW("Invalid type");
|
||||
}
|
||||
|
||||
std::string shape::cpp_type(shape::type_t t)
|
||||
{
|
||||
switch(t)
|
||||
{
|
||||
case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
|
||||
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
|
||||
case x: return #t;
|
||||
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
|
||||
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
|
||||
}
|
||||
MIGRAPHX_THROW("Invalid type");
|
||||
}
|
||||
|
||||
bool shape::is_integral(shape::type_t t)
|
||||
{
|
||||
bool result = false;
|
||||
visit(t, [&](auto as) { result = as.is_integral(); });
|
||||
return result;
|
||||
}
|
||||
|
||||
bool shape::is_compatible(const shape& actual, const shape& expected)
|
||||
{
|
||||
// Check subshapes
|
||||
if(expected.type() == shape::tuple_type)
|
||||
return migraphx::equal(actual.sub_shapes(), expected.sub_shapes(), &is_compatible);
|
||||
if(actual == expected)
|
||||
return true;
|
||||
if(actual.type() != expected.type())
|
||||
return false;
|
||||
// Only the expected can be dynamic
|
||||
if(expected.dynamic())
|
||||
return actual.ndim() == expected.ndim();
|
||||
if(actual.dynamic())
|
||||
return false;
|
||||
if(actual.lens() != expected.lens())
|
||||
return false;
|
||||
// Check strides from dimensions that are not 1
|
||||
return all_of(range(actual.lens().size()), [&](auto i) {
|
||||
if(actual.lens()[i] == 1)
|
||||
return true;
|
||||
return actual.strides()[i] == expected.strides()[i];
|
||||
});
|
||||
}
|
||||
|
||||
bool shape::is_unsigned(shape::type_t t)
|
||||
{
|
||||
bool result = false;
|
||||
visit(t, [&](auto as) { result = as.is_unsigned(); });
|
||||
return result;
|
||||
}
|
||||
|
||||
shape::shape() : impl(shape_impl::default_shape()) {}
|
||||
|
||||
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
|
||||
|
||||
shape::shape(type_t t, std::vector<std::size_t> l)
|
||||
: impl(std::make_shared<shape_impl>(t, std::move(l)))
|
||||
{
|
||||
}
|
||||
|
||||
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
|
||||
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
|
||||
{
|
||||
}
|
||||
|
||||
shape::shape(type_t t, std::initializer_list<std::size_t> d)
|
||||
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
|
||||
{
|
||||
}
|
||||
|
||||
shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
|
||||
: impl(std::make_shared<shape_impl>(t, std::move(dims)))
|
||||
{
|
||||
}
|
||||
|
||||
shape::shape(type_t t,
|
||||
std::vector<std::size_t> mins,
|
||||
std::vector<std::size_t> maxes,
|
||||
std::vector<std::set<std::size_t>> optimals_list)
|
||||
: impl(std::make_shared<shape_impl>(
|
||||
t, std::move(mins), std::move(maxes), std::move(optimals_list)))
|
||||
{
|
||||
}
|
||||
|
||||
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
|
||||
|
||||
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
|
||||
|
||||
shape shape::from_permutation(type_t t,
|
||||
const std::vector<std::size_t>& l,
|
||||
const std::vector<int64_t>& perm)
|
||||
{
|
||||
auto new_lens = reorder_dims(l, perm);
|
||||
shape result = reorder_shape({t, new_lens}, invert_permutation(perm));
|
||||
assert(result.lens() == l);
|
||||
return result;
|
||||
}
|
||||
|
||||
shape::type_t shape::type() const { return impl->m_type; }
|
||||
|
||||
const std::vector<std::size_t>& shape::lens() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
|
||||
}
|
||||
return impl->m_lens;
|
||||
}
|
||||
|
||||
const std::vector<std::size_t>& shape::strides() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
|
||||
}
|
||||
return impl->m_strides;
|
||||
}
|
||||
|
||||
std::size_t shape::ndim() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
return dyn_dims().size();
|
||||
}
|
||||
return lens().size();
|
||||
}
|
||||
|
||||
std::size_t shape::elements() const { return impl->elements(); }
|
||||
|
||||
std::size_t shape::bytes() const
|
||||
{
|
||||
if(this->sub_shapes().empty())
|
||||
{
|
||||
std::size_t n = 0;
|
||||
this->visit_type([&](auto as) { n = as.size(); });
|
||||
return n * this->element_space();
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::accumulate(this->sub_shapes().begin(),
|
||||
this->sub_shapes().end(),
|
||||
std::size_t{0},
|
||||
[&](auto x, auto y) { return x + y.bytes(); });
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t shape::type_size() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
if(this->sub_shapes().empty())
|
||||
this->visit_type([&](auto as) { n = as.size(); });
|
||||
return n;
|
||||
}
|
||||
|
||||
std::size_t shape::index(std::initializer_list<std::size_t> l) const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
|
||||
}
|
||||
assert(l.size() <= this->lens().size());
|
||||
assert(this->lens().size() == this->strides().size());
|
||||
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t shape::index(const std::vector<std::size_t>& l) const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
|
||||
}
|
||||
assert(l.size() <= this->lens().size());
|
||||
assert(this->lens().size() == this->strides().size());
|
||||
return std::inner_product(l.begin(), l.end(), this->strides().begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t shape::index(std::size_t i) const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
|
||||
}
|
||||
assert(this->lens().size() == this->strides().size());
|
||||
if(this->standard())
|
||||
return i;
|
||||
|
||||
return impl->get_index(i);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> shape::multi(std::size_t idx) const
|
||||
{
|
||||
assert(idx < elements());
|
||||
std::vector<std::size_t> indices(lens().size());
|
||||
multi_copy(idx, indices.data(), indices.data() + lens().size());
|
||||
return indices;
|
||||
}
|
||||
|
||||
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
|
||||
{
|
||||
size_t tidx = idx;
|
||||
(void)end;
|
||||
assert(idx < elements());
|
||||
assert(lens().size() <= (end - start));
|
||||
for(size_t ii = lens().size() - 1; ii > 0; ii--)
|
||||
{
|
||||
*(start + ii) = tidx % lens()[ii];
|
||||
tidx = tidx / lens()[ii];
|
||||
}
|
||||
*start = tidx;
|
||||
}
|
||||
|
||||
bool shape::multi_within_bounds(std::vector<std::size_t> multi) const
|
||||
{
|
||||
assert(this->lens().size() == multi.size());
|
||||
return std::equal(multi.begin(), multi.end(), this->lens().begin(), std::less<>{});
|
||||
}
|
||||
|
||||
bool shape::packed() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return this->sub_shapes().empty() and not impl->skips() and
|
||||
this->elements() == this->element_space();
|
||||
}
|
||||
|
||||
bool shape::transposed() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(this->broadcasted())
|
||||
{
|
||||
// TODO: Use a filter_iterator instead
|
||||
std::vector<std::size_t> s;
|
||||
s.reserve(this->strides().size());
|
||||
std::copy_if(this->strides().begin(),
|
||||
this->strides().end(),
|
||||
std::back_inserter(s),
|
||||
[](std::size_t x) { return x != 0; });
|
||||
return not std::is_sorted(s.rbegin(), s.rend());
|
||||
}
|
||||
else
|
||||
{
|
||||
return not std::is_sorted(this->strides().rbegin(), this->strides().rend());
|
||||
}
|
||||
}
|
||||
|
||||
bool shape::broadcasted() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
assert(this->lens().size() == this->strides().size());
|
||||
return std::any_of(
|
||||
this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
|
||||
}
|
||||
|
||||
bool shape::scalar() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
assert(this->lens().size() == this->strides().size());
|
||||
// if any stride > 0, then accumulate will return false
|
||||
return this->sub_shapes().empty() and
|
||||
std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
|
||||
}
|
||||
|
||||
bool shape::standard() const { return impl->m_standard; }
|
||||
|
||||
shape shape::normalize_standard() const
|
||||
{
|
||||
if(this->standard())
|
||||
return {this->type(), this->lens()};
|
||||
else
|
||||
return *this;
|
||||
}
|
||||
|
||||
shape shape::as_standard() const
|
||||
{
|
||||
if(not this->dynamic())
|
||||
return {this->type(), this->lens()};
|
||||
else
|
||||
return *this;
|
||||
}
|
||||
|
||||
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
|
||||
}
|
||||
assert(l.size() == this->lens().size());
|
||||
auto perm = find_permutation(*this);
|
||||
return shape::from_permutation(t, l, perm);
|
||||
}
|
||||
|
||||
shape shape::with_lens(const std::vector<std::size_t>& l) const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: with_lens() called on dynamic shape");
|
||||
}
|
||||
return this->with_lens(this->type(), l);
|
||||
}
|
||||
|
||||
shape shape::with_type(type_t t) const
|
||||
{
|
||||
auto c = impl->copy();
|
||||
c->m_type = t;
|
||||
return {c};
|
||||
}
|
||||
|
||||
shape shape::to_dynamic() const
|
||||
{
|
||||
if(not sub_shapes().empty())
|
||||
{
|
||||
std::vector<shape> subs;
|
||||
std::transform(sub_shapes().cbegin(),
|
||||
sub_shapes().cend(),
|
||||
std::back_inserter(subs),
|
||||
[](auto s) { return s.to_dynamic(); });
|
||||
return shape(subs);
|
||||
}
|
||||
if(this->dynamic())
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
return {type(), lens(), lens(), {}};
|
||||
}
|
||||
|
||||
shape shape::to_static(std::size_t x) const
|
||||
{
|
||||
if(not sub_shapes().empty())
|
||||
{
|
||||
std::vector<shape> subs;
|
||||
std::transform(sub_shapes().cbegin(),
|
||||
sub_shapes().cend(),
|
||||
std::back_inserter(subs),
|
||||
[&](auto s) { return s.to_static(x); });
|
||||
return shape(subs);
|
||||
}
|
||||
if(not this->dynamic())
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
auto static_lens = this->max_lens();
|
||||
std::transform(static_lens.begin(),
|
||||
static_lens.end(),
|
||||
this->dyn_dims().cbegin(),
|
||||
static_lens.begin(),
|
||||
[&](auto sl, auto dd) { return dd.is_fixed() ? sl : x; });
|
||||
return {type(), static_lens};
|
||||
}
|
||||
|
||||
std::size_t shape::element_space() const { return impl->element_space(); }
|
||||
|
||||
std::string shape::type_string() const { return name(this->type()); }
|
||||
|
||||
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
|
||||
|
||||
bool shape::any_of_dynamic() const
|
||||
{
|
||||
if(this->dynamic())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return std::any_of(this->sub_shapes().cbegin(), this->sub_shapes().cend(), [](auto s) {
|
||||
return s.any_of_dynamic();
|
||||
});
|
||||
}
|
||||
|
||||
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
|
||||
{
|
||||
if(not this->dynamic())
|
||||
{
|
||||
MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
|
||||
}
|
||||
return impl->m_dyn_dims;
|
||||
}
|
||||
|
||||
std::vector<std::size_t> shape::min_lens() const
|
||||
{
|
||||
return this->dynamic() ? impl->min_lens() : this->lens();
|
||||
}
|
||||
|
||||
std::vector<std::size_t> shape::max_lens() const
|
||||
{
|
||||
return this->dynamic() ? impl->max_lens() : this->lens();
|
||||
}
|
||||
|
||||
std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }
|
||||
|
||||
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
|
||||
|
||||
bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }
|
||||
|
||||
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
|
||||
{
|
||||
this->min += x;
|
||||
this->max += x;
|
||||
std::set<std::size_t> new_optimals;
|
||||
std::transform(this->optimals.begin(),
|
||||
this->optimals.end(),
|
||||
std::inserter(new_optimals, new_optimals.begin()),
|
||||
[&x](const auto& opt) { return (opt + x); });
|
||||
this->optimals = new_optimals;
|
||||
return *this;
|
||||
}
|
||||
|
||||
shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t& x)
|
||||
{
|
||||
assert(this->min >= x);
|
||||
assert(this->max >= x);
|
||||
this->min -= x;
|
||||
this->max -= x;
|
||||
std::set<std::size_t> new_optimals;
|
||||
std::transform(this->optimals.begin(),
|
||||
this->optimals.end(),
|
||||
std::inserter(new_optimals, new_optimals.begin()),
|
||||
[&x](const auto& opt) {
|
||||
assert(opt >= x);
|
||||
return (opt - x);
|
||||
});
|
||||
this->optimals = new_optimals;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
|
||||
{
|
||||
// don't check optimals if both are fixed
|
||||
return (x.min == y.min and x.max == y.max and
|
||||
((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals)));
|
||||
}
|
||||
|
||||
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
|
||||
{
|
||||
return not(x == y);
|
||||
}
|
||||
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
|
||||
{
|
||||
os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals) << "} ]";
|
||||
return os;
|
||||
}
|
||||
|
||||
bool operator==(const shape::dynamic_dimension& x, const std::size_t& y)
|
||||
{
|
||||
return x.min == y and x.max == y;
|
||||
}
|
||||
bool operator==(const std::size_t& x, const shape::dynamic_dimension& y) { return y == x; }
|
||||
bool operator!=(const shape::dynamic_dimension& x, const std::size_t& y) { return not(x == y); }
|
||||
bool operator!=(const std::size_t& x, const shape::dynamic_dimension& y) { return not(x == y); }
|
||||
|
||||
shape::dynamic_dimension operator+(const shape::dynamic_dimension& x, const std::size_t& y)
|
||||
{
|
||||
auto dd = x;
|
||||
return dd += y;
|
||||
}
|
||||
|
||||
shape::dynamic_dimension operator+(const std::size_t& x, const shape::dynamic_dimension& y)
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
shape::dynamic_dimension operator-(const shape::dynamic_dimension& x, const std::size_t& y)
|
||||
{
|
||||
auto dd = x;
|
||||
return dd -= y;
|
||||
}
|
||||
|
||||
bool operator==(const shape& x, const shape& y)
|
||||
{
|
||||
if(x.dynamic() and y.dynamic())
|
||||
{
|
||||
return x.impl == y.impl or (x.type() == y.type() and x.dyn_dims() == y.dyn_dims() and
|
||||
x.sub_shapes() == y.sub_shapes());
|
||||
}
|
||||
return x.impl == y.impl or
|
||||
(x.dynamic() == y.dynamic() and x.type() == y.type() and x.lens() == y.lens() and
|
||||
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
|
||||
}
|
||||
|
||||
bool operator!=(const shape& x, const shape& y) { return not(x == y); }
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const shape& x)
|
||||
{
|
||||
if(x.sub_shapes().empty())
|
||||
{
|
||||
if(x.dynamic())
|
||||
{
|
||||
os << "dynamic, ";
|
||||
os << x.type_string() << ", ";
|
||||
os << "{" << to_string_range(x.dyn_dims()) << "}";
|
||||
}
|
||||
else
|
||||
{
|
||||
os << x.type_string() << ", ";
|
||||
os << "{" << to_string_range(x.lens()) << "}, ";
|
||||
os << "{" << to_string_range(x.strides()) << "}";
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
os << "[" << to_string_range(x.sub_shapes()) << "]";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
shape::type_t shape::parse_type(const std::string& s)
|
||||
{
|
||||
static const std::unordered_map<std::string, shape::type_t> m = {
|
||||
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
|
||||
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
|
||||
tuple_type},
|
||||
{"tuple", tuple_type}};
|
||||
return m.at(s);
|
||||
}
|
||||
|
||||
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
|
||||
|
||||
std::vector<shape> flatten(const std::vector<shape>& shapes)
|
||||
{
|
||||
std::vector<shape> result;
|
||||
for(const auto& s : shapes)
|
||||
{
|
||||
if(s.type() == shape::tuple_type)
|
||||
{
|
||||
auto subs = flatten(s.sub_shapes());
|
||||
result.insert(result.end(), subs.begin(), subs.end());
|
||||
}
|
||||
else
|
||||
{
|
||||
result.push_back(s);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void migraphx_to_value(value& v, const shape& s)
|
||||
{
|
||||
value result;
|
||||
result["type"] = migraphx::to_value(s.type_string());
|
||||
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
|
||||
// avoid calling functions that will throw
|
||||
if(s.dynamic())
|
||||
{
|
||||
result["lens"] = {};
|
||||
result["strides"] = {};
|
||||
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
|
||||
}
|
||||
else
|
||||
{
|
||||
result["lens"] = migraphx::to_value(s.lens());
|
||||
result["strides"] = migraphx::to_value(s.strides());
|
||||
result["dynamic_dimensions"] = {};
|
||||
}
|
||||
v = result;
|
||||
}
|
||||
|
||||
void migraphx_from_value(const value& v, shape& s)
|
||||
{
|
||||
auto t = v.at("type").get_string();
|
||||
if(t == "tuple_type")
|
||||
{
|
||||
s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(v.at("dynamic_dimensions").empty())
|
||||
{
|
||||
s = shape{shape::parse_type(t),
|
||||
v.at("lens").to_vector<std::size_t>(),
|
||||
v.at("strides").to_vector<std::size_t>()};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_dd = v.at("dynamic_dimensions");
|
||||
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
|
||||
std::transform(
|
||||
v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
|
||||
return from_value<shape::dynamic_dimension>(x);
|
||||
});
|
||||
|
||||
s = shape{shape::parse_type(t), dyn_dims};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
1263
docker/rocm/migraphx/shape_transform_descriptor.cpp
Normal file
1263
docker/rocm/migraphx/shape_transform_descriptor.cpp
Normal file
File diff suppressed because it is too large
Load Diff
2035
docker/rocm/migraphx/simplify_algebra.cpp
Normal file
2035
docker/rocm/migraphx/simplify_algebra.cpp
Normal file
File diff suppressed because it is too large
Load Diff
722
docker/rocm/migraphx/simplify_dyn_ops.cpp
Normal file
722
docker/rocm/migraphx/simplify_dyn_ops.cpp
Normal file
@ -0,0 +1,722 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/simplify_dyn_ops.hpp>
|
||||
#include <migraphx/op/slice.hpp>
|
||||
#include <migraphx/op/onehot.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/literal.hpp>
|
||||
#include <migraphx/op/resize.hpp>
|
||||
#include <migraphx/common.hpp>
|
||||
#include <migraphx/tensor_view.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
/**
|
||||
* Convert broadcast_with_dims operators with a static input tensor and a constant `dims` input
|
||||
* into multibroadcast op with a static output shape attribute.
|
||||
*
|
||||
*/
|
||||
struct find_broadcast_with_dims_static
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("broadcast_with_dims")(match::nargs(2),
|
||||
match::arg(0)(match::static_shape()),
|
||||
match::arg(1)(match::is_constant()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto inputs = ins->inputs();
|
||||
|
||||
// read the values of arg(1) to create input to multibroadcast
|
||||
std::vector<size_t> sizes_vec;
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { sizes_vec.assign(output.begin(), output.end()); });
|
||||
|
||||
m.replace_instruction(
|
||||
ins, make_op("multibroadcast", {{"out_lens", sizes_vec}}), inputs.at(0));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert a Resize op. with Nearest mode to an implementation using Gather op.
|
||||
* From: resize[scales={...}/sizes={...},](static, constant)
|
||||
* To:
|
||||
* 0 = literal{ ... } computed_indices
|
||||
* ...
|
||||
* 2 = reshape[dims={45}](X) 1-dimensional
|
||||
* 3 = gather[axis=0](2,0)
|
||||
*
|
||||
* At the time of writing, this conversion is required for GPU targets because there
|
||||
* is not direct a GPU implementation of the Resize operation.
|
||||
* This matcher depends on a split_single_dyn_dim pass being run before it, which
|
||||
* will convert any dynamic-batch input to static inputs and make this conversion possible.
|
||||
*
|
||||
* At time of writing, Resize allows either 1 or 2 inputs
|
||||
* but the 1-input case is never created by Onnx parsing.
|
||||
*/
|
||||
struct find_resize_static
|
||||
{
|
||||
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("resize")(match::nargs(2),
|
||||
match::arg(0)(match::static_shape()),
|
||||
match::arg(1)(match::is_constant()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto inputs = ins->inputs();
|
||||
auto resize_op = any_cast<op::resize>(ins->get_operator());
|
||||
|
||||
auto in_lens = inputs.at(0)->get_shape().lens();
|
||||
std::vector<size_t> sizes_vec(inputs.at(0)->get_shape().ndim());
|
||||
std::vector<float> scales_vec(inputs.at(0)->get_shape().ndim());
|
||||
// populate both scales and sizes for the benefit of the algorithm.
|
||||
inputs.at(1)->eval().visit([&](auto input) {
|
||||
using type = typename decltype(input)::value_type;
|
||||
if constexpr(std::is_integral<type>{})
|
||||
{
|
||||
// read output sizes and use them to compute scales
|
||||
sizes_vec.assign(input.begin(), input.end());
|
||||
std::transform(
|
||||
input.begin(),
|
||||
input.end(),
|
||||
in_lens.begin(),
|
||||
scales_vec.begin(),
|
||||
[](auto sz, size_t in_len) { return static_cast<float>(sz) / in_len; });
|
||||
}
|
||||
else
|
||||
{
|
||||
// read scales and use them to compute output sizes
|
||||
scales_vec.assign(input.begin(), input.end());
|
||||
std::transform(
|
||||
input.begin(),
|
||||
input.end(),
|
||||
in_lens.begin(),
|
||||
sizes_vec.begin(),
|
||||
[](auto sz, size_t in_len) { return static_cast<size_t>(sz * in_len); });
|
||||
}
|
||||
});
|
||||
|
||||
auto in_s = inputs.at(0)->get_shape();
|
||||
shape out_s{in_s.type(), sizes_vec};
|
||||
|
||||
std::vector<int> ind(out_s.elements());
|
||||
|
||||
// map out_idx to in_idx
|
||||
auto nearest_op = op::resize::get_nearest_op(resize_op.nearest_mode);
|
||||
auto idx_op = op::resize::get_original_idx_op(resize_op.coordinate_transformation_mode);
|
||||
|
||||
shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
|
||||
std::vector<size_t> in_idx(out_idx_v.size());
|
||||
for(auto ii = 0; ii < in_lens.size(); ++ii)
|
||||
{
|
||||
auto idx_val = idx_op(in_lens[ii], sizes_vec[ii], out_idx_v[ii], scales_vec[ii]);
|
||||
in_idx[ii] = nearest_op(in_lens[ii], idx_val);
|
||||
}
|
||||
|
||||
ind[out_idx] = static_cast<int64_t>(in_s.index(in_idx));
|
||||
});
|
||||
|
||||
// reshape input to one-dimension
|
||||
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
|
||||
auto reshape_op = make_op("reshape", {{"dims", rsp_lens}});
|
||||
auto rsp = m.insert_instruction(ins, reshape_op, ins->inputs().at(0));
|
||||
|
||||
// Add our computed indices as a literal.
|
||||
// ins_ind is a multi dimensional index that will restore original rank
|
||||
shape ind_s{shape::int32_type, sizes_vec};
|
||||
auto ins_ind = m.add_literal(literal(ind_s, ind));
|
||||
m.replace_instruction(ins, make_op("gather", {{"axis", 0}}), rsp, ins_ind);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
|
||||
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
|
||||
* of the broadcasting operators.
|
||||
* From:
|
||||
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
|
||||
* To:
|
||||
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
|
||||
*/
|
||||
struct find_static_2in_broadcasts
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::broadcast(match::nargs(2),
|
||||
match::arg(0)(match::static_shape()),
|
||||
match::arg(1)(match::static_shape()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto out_lens = ins->get_shape().lens();
|
||||
auto broadcast_op = ins->get_operator();
|
||||
if(broadcast_op.name() == "broadcast")
|
||||
{
|
||||
broadcast_op.from_value({{"out_lens", out_lens}});
|
||||
}
|
||||
else
|
||||
{
|
||||
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
|
||||
}
|
||||
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
|
||||
* From:
|
||||
* slice(data, constant_input); two attributes set
|
||||
* To:
|
||||
* slice(data); slice.starts, slice.ends. slice.axes set
|
||||
*/
|
||||
struct find_const_2in_slice
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("slice")(match::nargs(2), match::arg(1)(match::is_constant()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto inputs = ins->inputs();
|
||||
auto slice_op = any_cast<op::slice>(ins->get_operator());
|
||||
auto set_attrs = slice_op.get_set_attributes();
|
||||
std::vector<int64_t> starts_vec;
|
||||
std::vector<int64_t> ends_vec;
|
||||
std::vector<int64_t> axes_vec;
|
||||
if(set_attrs == op::slice::ends_axes)
|
||||
{
|
||||
// slice(data, starts)
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
|
||||
ends_vec = slice_op.ends;
|
||||
axes_vec = slice_op.axes;
|
||||
}
|
||||
else if(set_attrs == op::slice::starts_axes)
|
||||
{
|
||||
// slice(data, ends)
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
|
||||
starts_vec = slice_op.starts;
|
||||
axes_vec = slice_op.axes;
|
||||
}
|
||||
else
|
||||
{
|
||||
// slice(data, axes)
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
|
||||
starts_vec = slice_op.starts;
|
||||
ends_vec = slice_op.ends;
|
||||
}
|
||||
m.replace_instruction(
|
||||
ins,
|
||||
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
|
||||
inputs.at(0));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
|
||||
* From:
|
||||
* slice(data, constant_input1, constant_input2); one attribute set
|
||||
* To:
|
||||
* slice(data); slice.starts, slice.ends. slice.axes set
|
||||
*/
|
||||
struct find_const_3in_slice
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("slice")(match::nargs(3),
|
||||
match::arg(1)(match::is_constant()),
|
||||
match::arg(2)(match::is_constant()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto inputs = ins->inputs();
|
||||
auto slice_op = any_cast<op::slice>(ins->get_operator());
|
||||
auto set_attrs = slice_op.get_set_attributes();
|
||||
std::vector<int64_t> starts_vec;
|
||||
std::vector<int64_t> ends_vec;
|
||||
std::vector<int64_t> axes_vec;
|
||||
if(set_attrs == op::slice::axes_only)
|
||||
{
|
||||
// slice(data, starts, ends)
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
|
||||
inputs.at(2)->eval().visit(
|
||||
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
|
||||
axes_vec = slice_op.axes;
|
||||
}
|
||||
else if(set_attrs == op::slice::ends_only)
|
||||
{
|
||||
// slice(data, starts, axes)
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
|
||||
inputs.at(2)->eval().visit(
|
||||
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
|
||||
ends_vec = slice_op.ends;
|
||||
}
|
||||
else
|
||||
{
|
||||
// slice(data, ends, axes)
|
||||
inputs.at(1)->eval().visit(
|
||||
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
|
||||
inputs.at(2)->eval().visit(
|
||||
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
|
||||
starts_vec = slice_op.starts;
|
||||
}
|
||||
m.replace_instruction(
|
||||
ins,
|
||||
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
|
||||
inputs.at(0));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
|
||||
* From:
|
||||
* slice(data, constant_starts, constant_ends, constant_axes)
|
||||
* To:
|
||||
* slice(data); slice.starts, slice.ends. slice.axes set
|
||||
*/
|
||||
struct find_const_4in_slice
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("slice")(match::nargs(4),
|
||||
match::arg(1)(match::is_constant()),
|
||||
match::arg(2)(match::is_constant()),
|
||||
match::arg(3)(match::is_constant()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto inputs = ins->inputs();
|
||||
argument starts_arg = inputs.at(1)->eval(false);
|
||||
argument ends_arg = inputs.at(2)->eval(false);
|
||||
argument axes_arg = inputs.at(3)->eval(false);
|
||||
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
|
||||
{
|
||||
std::vector<int64_t> starts_vec;
|
||||
std::vector<int64_t> ends_vec;
|
||||
std::vector<int64_t> axes_vec;
|
||||
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
|
||||
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
|
||||
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
|
||||
m.replace_instruction(
|
||||
ins,
|
||||
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
|
||||
inputs.at(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify dimensions_of to a literal when the input arugment has a static shape
|
||||
* or the dynamic dimensions from `start` to `end` are fixed.
|
||||
*/
|
||||
struct find_static_dimensions_of
|
||||
{
|
||||
auto matcher() const { return match::name("dimensions_of")(); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto ins = mr.result;
|
||||
auto input = ins->inputs().at(0);
|
||||
auto dimensions_of_value = ins->get_operator().to_value();
|
||||
auto start = dimensions_of_value.at("start").to<std::size_t>();
|
||||
auto end = dimensions_of_value.at("end").to<std::size_t>();
|
||||
if(input->get_shape().dynamic())
|
||||
{
|
||||
// check if dynamic dimensions from start to end are fixed
|
||||
auto dds = input->get_shape().dyn_dims();
|
||||
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
|
||||
return not dd.is_fixed();
|
||||
}))
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::size_t output_ndim = end - start;
|
||||
std::vector<int64_t> vec_shape(output_ndim);
|
||||
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
|
||||
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
|
||||
std::transform(input_lens.begin() + start,
|
||||
input_lens.begin() + end,
|
||||
vec_shape.begin(),
|
||||
[](auto i) { return int64_t(i); });
|
||||
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
|
||||
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
|
||||
m.replace_instruction(ins, lit_ins);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
|
||||
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
|
||||
* This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
|
||||
* From:
|
||||
* x = allocate(constant_output_dims) -> reshape(data, x)
|
||||
* To:
|
||||
* reshape(data); reshape.dims = constant_output_dims
|
||||
*/
|
||||
struct find_const_alloc_reshapes
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto const_alloc = match::arg(1)(match::name("allocate")(match::is_constant()));
|
||||
return match::name("reshape")(match::nargs(2), const_alloc);
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto reshape_ins = mr.result;
|
||||
auto reshape_inputs = reshape_ins->inputs();
|
||||
auto alloc_ins = reshape_inputs.at(1);
|
||||
argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false);
|
||||
std::vector<int64_t> output_dims_vec;
|
||||
output_dims_arg.visit(
|
||||
[&](auto output) { output_dims_vec.assign(output.begin(), output.end()); });
|
||||
m.replace_instruction(
|
||||
reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0));
|
||||
// have dead_code_elimination remove the previous allocate
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify allocate into fill operator that has constant output dimensions and constant value.
|
||||
* The allocate into fill instructions is what is produced when parsing the ONNX
|
||||
* ConstantOfShape operator. This replacement could be handled with propagate_constant, but
|
||||
* would rather have the simplification happen earlier during compiling.
|
||||
* This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
|
||||
* From:
|
||||
* x = allocate(constant_ouptut_dims) -> fill(constant_value, x)
|
||||
* To:
|
||||
* literal
|
||||
*/
|
||||
struct find_const_alloc_fill
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto const_alloc = match::arg(1)(match::name("allocate")(match::is_constant()));
|
||||
return match::name("fill")(match::arg(0)(match::is_constant()), const_alloc);
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto fill_ins = mr.result;
|
||||
auto fill_arg = fill_ins->eval(false);
|
||||
auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data());
|
||||
m.replace_instruction(fill_ins, l);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify broadcast_for_dot instructions with two static shaped arguments
|
||||
* From:
|
||||
* broadcast_for_dot(static_shape_arg, static_shape_arg)
|
||||
* To:
|
||||
* multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape
|
||||
*/
|
||||
struct find_static_broadcast_for_dot
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("broadcast_for_dot")(match::arg(0)(match::static_shape()),
|
||||
match::arg(1)(match::static_shape()));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto broadcast_for_dot_ins = mr.result;
|
||||
auto inputs = broadcast_for_dot_ins->inputs();
|
||||
auto s0 = inputs.at(0)->get_shape();
|
||||
auto s1 = inputs.at(1)->get_shape();
|
||||
auto l0_it = s0.lens().end() - 2;
|
||||
std::vector<std::size_t> l0_broadcasted_lens(s0.lens().begin(), l0_it);
|
||||
auto l1_it = s1.lens().begin() + s1.ndim() - 2;
|
||||
std::vector<std::size_t> l1_broadcasted_lens(s1.lens().begin(), l1_it);
|
||||
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
|
||||
output_lens.insert(output_lens.end(), l0_it, s0.lens().end());
|
||||
m.replace_instruction(broadcast_for_dot_ins,
|
||||
make_op("multibroadcast", {{"out_lens", output_lens}}),
|
||||
inputs.at(0));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Simplify onehot instructions with static shape `indices` input and
|
||||
* a compile-time constant `depth` attribute or input.
|
||||
* From:
|
||||
* onehot(static_shape_arg, constant_arg, values) or
|
||||
* onehot(static_shape_arg, values)
|
||||
* To:
|
||||
* A = literal(shape = onehot_output_shape, value = 0)
|
||||
* B = unsqueeze(literal(lens = indices_lens, strides = broadcasted scalar, value = 1),
|
||||
* axis=onehot_axis) C = scatter(A, unsqueeze(indices, axis=onehot_axis), B) diff = on_value -
|
||||
* off_value D = mul(diff, C); return = add(D, off_value);
|
||||
*
|
||||
* NOTE: It might be cleaner to use some form of `fill` instead of
|
||||
* (on_value - off_value) * mask + off_value when we have `fill` working
|
||||
* on the GPU.
|
||||
*/
|
||||
struct find_static_onehot
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
auto match_2_args = match::nargs(2)(match::arg(0)(match::static_shape()),
|
||||
match::arg(1)(match::static_shape()));
|
||||
auto match_3_args = match::nargs(3)(match::arg(0)(match::static_shape()),
|
||||
match::arg(1)(match::is_constant()),
|
||||
match::arg(2)(match::static_shape()));
|
||||
return match::name("onehot")(match::any_of(match_2_args, match_3_args));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto onehot_ins = mr.result;
|
||||
auto onehot_inputs = onehot_ins->inputs();
|
||||
auto onehot_op = any_cast<op::onehot>(onehot_ins->get_operator());
|
||||
auto indices_ins = onehot_inputs[0];
|
||||
shape indices_shape = indices_ins->get_shape();
|
||||
std::size_t depth_val;
|
||||
migraphx::instruction_ref values_ins;
|
||||
if(onehot_op.depth.has_value())
|
||||
{
|
||||
assert(onehot_inputs.size() == 2);
|
||||
depth_val = onehot_op.depth.value();
|
||||
values_ins = onehot_inputs[1];
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(onehot_inputs.size() == 3);
|
||||
auto depth_ins = onehot_inputs[1];
|
||||
depth_ins->eval().visit([&](auto d) { depth_val = d[0]; });
|
||||
values_ins = onehot_inputs[2];
|
||||
}
|
||||
shape values_shape = values_ins->get_shape();
|
||||
std::vector<std::size_t> static_output_lens = indices_shape.lens();
|
||||
auto normalized_axis =
|
||||
(onehot_op.axis < 0) ? onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis;
|
||||
static_output_lens.insert(static_output_lens.begin() + normalized_axis, depth_val);
|
||||
shape output_shape{values_shape.type(), static_output_lens};
|
||||
std::vector<float> zeros(output_shape.elements(), 0);
|
||||
auto zeros_lit = m.add_literal(literal(output_shape, zeros));
|
||||
auto unsqueeze_inds = m.insert_instruction(
|
||||
onehot_ins, migraphx::make_op("unsqueeze", {{"axes", {normalized_axis}}}), indices_ins);
|
||||
// broadcast the one scalar to the correct shape
|
||||
auto ones_lit = m.add_literal(literal(shape{values_shape.type(), {1}, {0}}, {1}));
|
||||
auto mb_ones = m.insert_instruction(
|
||||
onehot_ins,
|
||||
migraphx::make_op("multibroadcast", {{"out_lens", unsqueeze_inds->get_shape().lens()}}),
|
||||
ones_lit);
|
||||
auto mask = m.insert_instruction(
|
||||
onehot_ins,
|
||||
make_op("scatter_none", {{"axis", normalized_axis}, {"skip_out_of_bounds", true}}),
|
||||
zeros_lit,
|
||||
unsqueeze_inds,
|
||||
mb_ones);
|
||||
auto off_val =
|
||||
m.insert_instruction(onehot_ins,
|
||||
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
|
||||
values_ins);
|
||||
auto on_val =
|
||||
m.insert_instruction(onehot_ins,
|
||||
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}),
|
||||
values_ins);
|
||||
auto diff_val = m.insert_instruction(onehot_ins, make_op("sub"), on_val, off_val);
|
||||
auto mul_diff_mask = insert_common_op(m, onehot_ins, make_op("mul"), {diff_val, mask});
|
||||
auto mb_off_val = m.insert_instruction(
|
||||
onehot_ins, make_op("multibroadcast", {{"out_lens", output_shape.lens()}}), off_val);
|
||||
m.replace_instruction(onehot_ins, make_op("add"), mb_off_val, mul_diff_mask);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Go through `select_module` instructions and update the `output_dyn_shapes` attribute.
|
||||
* Checks the submodule output shapes and determines an appropriate `output_dyn_shapes` attribute.
|
||||
* This version ignores dynamic_dimension opt values.
|
||||
* Intended to be run after the other simplify_dyn_ops passes.
|
||||
*/
|
||||
struct simplify_select_module_output_shape
|
||||
{
|
||||
auto matcher() const { return match::name("select_module"); }
|
||||
|
||||
void apply(module& m, const match::matcher_result& mr) const
|
||||
{
|
||||
auto sm_ins = mr.result;
|
||||
auto sm_module_inputs = sm_ins->module_inputs();
|
||||
std::vector<std::vector<shape>> all_output_shapes(sm_module_inputs.size());
|
||||
std::transform(sm_module_inputs.begin(),
|
||||
sm_module_inputs.end(),
|
||||
all_output_shapes.begin(),
|
||||
[](auto submod) { return submod->get_output_shapes(); });
|
||||
// check that all of the submodules have the same number of outputs and all respective
|
||||
// outputs have the same rank and type
|
||||
auto shapes_ndim = get_shapes_ndim(all_output_shapes.front());
|
||||
auto shapes_types = get_shapes_types(all_output_shapes.front());
|
||||
if(std::any_of(
|
||||
all_output_shapes.begin() + 1, all_output_shapes.end(), [&](auto out_shapes) {
|
||||
bool same_types = get_shapes_types(out_shapes) == shapes_types;
|
||||
bool same_ndim = get_shapes_ndim(out_shapes) == shapes_ndim;
|
||||
return not same_types or not same_ndim;
|
||||
}))
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto num_out_shapes = shapes_ndim.size();
|
||||
std::vector<shape> dyn_shapes(num_out_shapes);
|
||||
auto num_submod = sm_module_inputs.size();
|
||||
// compare respective output shapes from each submodule to get a range for the output shape
|
||||
for(int i : range(num_out_shapes))
|
||||
{
|
||||
std::vector<shape> shapes_at_index(num_submod);
|
||||
std::transform(all_output_shapes.begin(),
|
||||
all_output_shapes.end(),
|
||||
shapes_at_index.begin(),
|
||||
[&](auto output_shapes) { return output_shapes.at(i); });
|
||||
dyn_shapes.at(i) = dyn_shape_from_shapes(shapes_at_index);
|
||||
}
|
||||
auto tuple_shape = shape{dyn_shapes};
|
||||
m.replace_instruction(
|
||||
sm_ins,
|
||||
make_op("select_module", {{"output_dyn_shapes", to_value(tuple_shape)}}),
|
||||
sm_ins->inputs(),
|
||||
sm_module_inputs);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> get_shapes_ndim(const std::vector<shape>& shapes) const
|
||||
{
|
||||
std::vector<std::size_t> ret(shapes.size());
|
||||
std::transform(
|
||||
shapes.cbegin(), shapes.cend(), ret.begin(), [](auto s) { return s.ndim(); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<shape::type_t> get_shapes_types(const std::vector<shape>& shapes) const
|
||||
{
|
||||
std::vector<shape::type_t> ret(shapes.size());
|
||||
std::transform(
|
||||
shapes.cbegin(), shapes.cend(), ret.begin(), [](auto s) { return s.type(); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculating an appropriate shape that encompasses all of the given vector of shapes.
|
||||
* Equivalent to creating a 2D matrix of shape lengths and do a reduce over each axis.
|
||||
* The shapes can be dynamic or static.
|
||||
* Assuming all shapes have the same ndim.
|
||||
*/
|
||||
shape dyn_shape_from_shapes(std::vector<shape> shape_vec) const
|
||||
{
|
||||
// making 2D matrices of min_lens and max_lens
|
||||
// specifically using uint64_t because we're going to put the values into a tensor_view
|
||||
// later
|
||||
std::vector<uint64_t> all_min_lens;
|
||||
std::vector<uint64_t> all_max_lens;
|
||||
for(const auto& s : shape_vec)
|
||||
{
|
||||
auto min_lens = s.min_lens();
|
||||
auto max_lens = s.max_lens();
|
||||
std::copy(min_lens.begin(), min_lens.end(), std::back_inserter(all_min_lens));
|
||||
std::copy(max_lens.begin(), max_lens.end(), std::back_inserter(all_max_lens));
|
||||
}
|
||||
assert(all_min_lens.size() == shape_vec.size() * shape_vec.front().ndim());
|
||||
assert(all_max_lens.size() == shape_vec.size() * shape_vec.front().ndim());
|
||||
auto num_rows = shape_vec.size();
|
||||
auto num_cols = shape_vec.front().ndim();
|
||||
shape tensor_shape{shape::uint64_type, {num_rows, num_cols}};
|
||||
auto min_lens_matrix = make_view(tensor_shape, all_min_lens.data());
|
||||
auto max_lens_matrix = make_view(tensor_shape, all_max_lens.data());
|
||||
|
||||
std::vector<uint64_t> mins(num_cols);
|
||||
std::vector<uint64_t> maxes(num_cols);
|
||||
// rearranging data into column vectors to reduce over
|
||||
// i = row, j = column
|
||||
for(int j : range(num_cols))
|
||||
{
|
||||
std::vector<uint64_t> reduce_min_vals(num_rows);
|
||||
std::vector<uint64_t> reduce_max_vals(num_rows);
|
||||
for(int i : range(num_rows))
|
||||
{
|
||||
reduce_min_vals.at(i) = min_lens_matrix(i, j);
|
||||
reduce_max_vals.at(i) = max_lens_matrix(i, j);
|
||||
}
|
||||
uint64_t max_int = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t min_val =
|
||||
std::accumulate(reduce_min_vals.begin(),
|
||||
reduce_min_vals.end(),
|
||||
max_int,
|
||||
[](uint64_t x, uint64_t y) { return x < y ? x : y; });
|
||||
uint64_t max_val = std::accumulate(
|
||||
reduce_max_vals.begin(), reduce_max_vals.end(), 0, [](uint64_t x, uint64_t y) {
|
||||
return x > y ? x : y;
|
||||
});
|
||||
mins.at(j) = min_val;
|
||||
maxes.at(j) = max_val;
|
||||
}
|
||||
// fixed output shape case
|
||||
if(mins == maxes)
|
||||
{
|
||||
return shape{shape_vec.front().type(), mins};
|
||||
}
|
||||
// dynamic output shape case
|
||||
return shape{shape_vec.front().type(), mins, maxes, {}};
|
||||
}
|
||||
};
|
||||
|
||||
void simplify_dyn_ops::apply(module& m) const
|
||||
{
|
||||
match::find_matches(m,
|
||||
find_broadcast_with_dims_static{},
|
||||
find_resize_static{},
|
||||
find_static_dimensions_of{},
|
||||
find_const_alloc_reshapes{},
|
||||
find_static_2in_broadcasts{},
|
||||
find_const_2in_slice{},
|
||||
find_const_3in_slice{},
|
||||
find_const_4in_slice{},
|
||||
find_const_alloc_fill{},
|
||||
find_static_broadcast_for_dot{},
|
||||
find_static_onehot{});
|
||||
match::find_matches(m, simplify_select_module_output_shape{});
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
463
docker/rocm/migraphx/simplify_qdq.cpp
Normal file
463
docker/rocm/migraphx/simplify_qdq.cpp
Normal file
@ -0,0 +1,463 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/simplify_qdq.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/shape.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/dead_code_elimination.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/op/convolution.hpp>
|
||||
#include <migraphx/op/quant_convolution.hpp>
|
||||
#include <migraphx/op/dot.hpp>
|
||||
#include <migraphx/op/quant_dot.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
#include <migraphx/fp8_types.hpp>
|
||||
#include <migraphx/match/dq_helpers.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
namespace {
|
||||
|
||||
std::unordered_set<std::string> get_quantizable_op_names()
|
||||
{
|
||||
static std::unordered_set<std::string> s = {"convolution", "dot"};
|
||||
return s;
|
||||
}
|
||||
|
||||
struct match_find_quantizable_ops
|
||||
{
|
||||
static bool
|
||||
is_valid_qparam(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
|
||||
{
|
||||
return qparam->get_shape().elements() == 1 or
|
||||
qparam->get_shape().elements() == lens.at(axis);
|
||||
}
|
||||
|
||||
static bool is_symmetric_zero_point(instruction_ref zp)
|
||||
{
|
||||
if(not zp->can_eval())
|
||||
return false;
|
||||
|
||||
bool all_zeros = false;
|
||||
zp->eval().visit([&](auto z) {
|
||||
all_zeros =
|
||||
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
|
||||
});
|
||||
return all_zeros;
|
||||
}
|
||||
|
||||
static auto
|
||||
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
|
||||
{
|
||||
if(qparam->get_shape().scalar())
|
||||
{
|
||||
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
|
||||
}
|
||||
else
|
||||
{
|
||||
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to insert quantized versions of any broadcasts and transpose ops that
|
||||
// occur between dequantizelinear and the quantized op
|
||||
static auto propagate_quantized_ins(module& m,
|
||||
const instruction_ref dqins,
|
||||
const instruction_ref qop_arg,
|
||||
bool is_fp16_model = false)
|
||||
{
|
||||
auto prev_ins = qop_arg;
|
||||
std::vector<instruction_ref> ins_inbetween;
|
||||
// matcher skips continguous, multi/broadcasts and transposes, collect all those
|
||||
// instructions
|
||||
while(prev_ins != dqins)
|
||||
{
|
||||
ins_inbetween.push_back(prev_ins);
|
||||
prev_ins = prev_ins->inputs().front();
|
||||
}
|
||||
auto qinp = dqins->inputs().front();
|
||||
for(auto ins : reverse_iterator_for(ins_inbetween))
|
||||
{
|
||||
if((*ins)->name() == "convert" and is_fp16_model)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
|
||||
}
|
||||
return qinp;
|
||||
}
|
||||
|
||||
auto matcher() const
|
||||
{
|
||||
auto dq1 = match::arg(0)(
|
||||
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
|
||||
auto dq2 = match::arg(1)(
|
||||
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
|
||||
return match::name(get_quantizable_op_names())(dq1, dq2);
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto qop = r.result;
|
||||
auto dq1 = r.instructions["dq1"];
|
||||
auto dq2 = r.instructions["dq2"];
|
||||
auto scale1 = r.instructions["scale1"];
|
||||
auto scale2 = r.instructions["scale2"];
|
||||
auto zp1 = r.instructions["zp1"];
|
||||
auto zp2 = r.instructions["zp2"];
|
||||
// Only INT8 or FP8 type currently supported
|
||||
std::set<migraphx::shape::type_t> supported_types = fp8_types{}.get();
|
||||
supported_types.insert(migraphx::shape::int8_type);
|
||||
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
|
||||
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
|
||||
return;
|
||||
|
||||
// Propagate q1 and q2 through any broadcasts and transposes before qop
|
||||
auto qop_args = qop->inputs();
|
||||
bool is_fp16_model = false;
|
||||
if(dq1->get_shape().type() != qop->get_shape().type() and
|
||||
qop->get_shape().type() == migraphx::shape::half_type)
|
||||
{
|
||||
assert(dq1->get_shape().type() == migraphx::shape::float_type);
|
||||
is_fp16_model = true;
|
||||
}
|
||||
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model);
|
||||
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model);
|
||||
auto arg1_lens = qop_args[0]->get_shape().lens();
|
||||
auto arg2_lens = qop_args[1]->get_shape().lens();
|
||||
instruction_ref dq;
|
||||
instruction_ref out_scale;
|
||||
instruction_ref out_zp;
|
||||
if(qop->name() == "convolution")
|
||||
{
|
||||
auto conv_val = qop->get_operator().to_value();
|
||||
dq = m.insert_instruction(
|
||||
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
|
||||
auto out_lens = dq->get_shape().lens();
|
||||
|
||||
// Ensure input and weight quantization paramaters are of a proper form
|
||||
// Input is of shape [n, c, x1, ..., xn]. Only scalar quantization allowed
|
||||
// Weight is of shape [k, c, y1, ... , yn]. Valid quantization axis is k
|
||||
|
||||
if(not(scale1->get_shape().elements() == 1 and zp1->get_shape().elements() == 1 and
|
||||
is_valid_qparam(scale2, arg2_lens, 0) and is_valid_qparam(zp2, arg2_lens, 0)))
|
||||
return;
|
||||
|
||||
// This implementation supports affine quantization for both input and weight
|
||||
// In practice, weight is quantized symmetrically
|
||||
|
||||
auto s1_bcast =
|
||||
m.insert_instruction(qop, qparam_broadcast_op(scale1, out_lens, 1), scale1);
|
||||
auto s2_bcast =
|
||||
m.insert_instruction(qop, qparam_broadcast_op(scale2, out_lens, 1), scale2);
|
||||
|
||||
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
|
||||
|
||||
// Compute the zero-point terms; initialize as 0 and add relevant terms
|
||||
auto zero_lit = m.add_literal(literal{shape{dq->get_shape().type()}, {0}});
|
||||
out_zp = m.insert_instruction(
|
||||
qop, make_op("multibroadcast", {{"out_lens", dq->get_shape().lens()}}), zero_lit);
|
||||
|
||||
auto inp_zp_bc = m.insert_instruction(qop, qparam_broadcast_op(zp1, arg1_lens, 1), zp1);
|
||||
auto w_zp_bc = m.insert_instruction(qop, qparam_broadcast_op(zp2, arg2_lens, 0), zp2);
|
||||
|
||||
if(not is_symmetric_zero_point(zp1))
|
||||
{
|
||||
auto out_zp_1 = m.insert_instruction(
|
||||
qop, migraphx::make_op("quant_convolution", conv_val), inp_zp_bc, qop_args[1]);
|
||||
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_1);
|
||||
}
|
||||
|
||||
if(not is_symmetric_zero_point(zp2))
|
||||
{
|
||||
auto out_zp_2 = m.insert_instruction(
|
||||
qop, migraphx::make_op("quant_convolution", conv_val), qop_args[0], w_zp_bc);
|
||||
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_2);
|
||||
}
|
||||
|
||||
if(not is_symmetric_zero_point(zp1) and not is_symmetric_zero_point(zp2))
|
||||
{
|
||||
auto out_zp_3 = m.insert_instruction(
|
||||
qop, migraphx::make_op("quant_convolution", conv_val), inp_zp_bc, w_zp_bc);
|
||||
out_zp = m.insert_instruction(qop, migraphx::make_op("sub"), out_zp, out_zp_3);
|
||||
}
|
||||
}
|
||||
else if(qop->name() == "dot")
|
||||
{
|
||||
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
|
||||
auto out_lens = dq->get_shape().lens();
|
||||
|
||||
// For (..., M, N) x (..., N, K) dot, valid quantization axes are M for input1 and K for
|
||||
// input 2
|
||||
if(not(is_valid_qparam(scale1, out_lens, out_lens.size() - 2) and
|
||||
is_valid_qparam(zp1, out_lens, out_lens.size() - 2) and
|
||||
is_valid_qparam(scale2, out_lens, out_lens.size() - 1) and
|
||||
is_valid_qparam(zp2, out_lens, out_lens.size() - 1)))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// This implementation supports both arguments being per-axis affine quantized
|
||||
// In practice, inputs are per-tensor affine and weights are per-axis symmetric
|
||||
|
||||
auto s1_bcast = m.insert_instruction(
|
||||
qop, qparam_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
|
||||
auto s2_bcast = m.insert_instruction(
|
||||
qop, qparam_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
|
||||
|
||||
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
|
||||
|
||||
// Compute the zero-point terms; initialize as 0 and add relevant terms
|
||||
auto zero_lit = m.add_literal(literal{shape{dq->get_shape().type()}, {0}});
|
||||
out_zp = m.insert_instruction(
|
||||
qop, make_op("multibroadcast", {{"out_lens", dq->get_shape().lens()}}), zero_lit);
|
||||
|
||||
auto zp1_bc = m.insert_instruction(
|
||||
qop, qparam_broadcast_op(zp1, arg1_lens, arg1_lens.size() - 2), zp1);
|
||||
auto zp2_bc = m.insert_instruction(
|
||||
qop, qparam_broadcast_op(zp2, arg2_lens, arg2_lens.size() - 1), zp2);
|
||||
|
||||
if(not is_symmetric_zero_point(zp1))
|
||||
{
|
||||
auto out_zp_1 =
|
||||
m.insert_instruction(qop, migraphx::make_op("quant_dot"), zp1_bc, qop_args[1]);
|
||||
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_1);
|
||||
}
|
||||
|
||||
if(not is_symmetric_zero_point(zp2))
|
||||
{
|
||||
auto out_zp_2 =
|
||||
m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args[0], zp2_bc);
|
||||
out_zp = m.insert_instruction(qop, migraphx::make_op("add"), out_zp, out_zp_2);
|
||||
}
|
||||
|
||||
if(not is_symmetric_zero_point(zp1) and not is_symmetric_zero_point(zp2))
|
||||
{
|
||||
auto out_zp_3 =
|
||||
m.insert_instruction(qop, migraphx::make_op("quant_dot"), zp1_bc, zp2_bc);
|
||||
out_zp = m.insert_instruction(qop, migraphx::make_op("sub"), out_zp, out_zp_3);
|
||||
}
|
||||
}
|
||||
|
||||
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale, out_zp);
|
||||
if(is_fp16_model)
|
||||
{
|
||||
dq = m.insert_instruction(
|
||||
qop, make_op("convert", {{"target_type", migraphx::shape::half_type}}), dq);
|
||||
}
|
||||
m.replace_instruction(qop, dq);
|
||||
}
|
||||
};
|
||||
|
||||
bool compare_literals(instruction_ref ins1, instruction_ref ins2)
|
||||
{
|
||||
if(ins1->name() == "broadcast" or ins1->name() == "multibroadcast")
|
||||
ins1 = ins1->inputs().front();
|
||||
auto x = ins1->eval();
|
||||
if(x.empty())
|
||||
return false;
|
||||
auto literal1 = ins1->get_literal();
|
||||
if(ins2->name() == "broadcast" or ins2->name() == "multibroadcast")
|
||||
ins2 = ins2->inputs().front();
|
||||
auto y = ins2->eval();
|
||||
if(y.empty())
|
||||
return false;
|
||||
auto literal2 = ins2->get_literal();
|
||||
|
||||
bool diff_shapes_equal_vals = false;
|
||||
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
|
||||
diff_shapes_equal_vals =
|
||||
std::all_of(l1.begin() + 1,
|
||||
l1.end(),
|
||||
[&](auto v) {
|
||||
return ((float_equal(v, l1.front())) or
|
||||
(std::isinf(static_cast<double>(l1.front())) and
|
||||
std::isinf(static_cast<double>(v))));
|
||||
}) and
|
||||
std::all_of(l2.begin(), l2.end(), [&](auto v) {
|
||||
return ((float_equal(v, l1.front())) or
|
||||
(std::isinf(static_cast<double>(l1.front())) and
|
||||
std::isinf(static_cast<double>(v))));
|
||||
});
|
||||
});
|
||||
|
||||
return (x == y) or diff_shapes_equal_vals;
|
||||
}
|
||||
|
||||
template <class Iterator>
|
||||
bool precedes(Iterator x, Iterator y, Iterator last)
|
||||
{
|
||||
auto r = range(std::next(x), last);
|
||||
return any_of(iterator_for(r), [&](auto it) { return it == y; });
|
||||
}
|
||||
|
||||
struct match_qlinear_reused
|
||||
{
|
||||
auto matcher() const
|
||||
{
|
||||
return match::name("quantizelinear")(
|
||||
match::used_once(), match::arg(0)(match::none_of(match::used_once()).bind("x")));
|
||||
}
|
||||
|
||||
void apply(module& m, const match::matcher_result& r) const
|
||||
{
|
||||
auto ins = r.result;
|
||||
auto x_ins = r.instructions["x"];
|
||||
assert(ins != x_ins);
|
||||
|
||||
auto dq_inputs = ins->inputs();
|
||||
dq_inputs[0] = ins;
|
||||
auto outputs = x_ins->outputs();
|
||||
if(outputs.size() != 2)
|
||||
return;
|
||||
for(auto output : outputs)
|
||||
{
|
||||
if(output->name() == "quantizelinear")
|
||||
continue;
|
||||
if(not output->get_operator().attributes().contains("pointwise"))
|
||||
continue;
|
||||
if(not precedes(ins, output, m.end()))
|
||||
continue;
|
||||
auto dq = m.insert_instruction(std::next(ins), make_op("dequantizelinear"), dq_inputs);
|
||||
instruction::replace_argument(output, x_ins, dq);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool is_same_value(instruction_ref a, instruction_ref b)
|
||||
{
|
||||
if(a == b)
|
||||
return true;
|
||||
return compare_literals(a, b);
|
||||
}
|
||||
|
||||
bool is_same_scale_zero(instruction_ref a, instruction_ref b)
|
||||
{
|
||||
if(a->inputs().size() != b->inputs().size())
|
||||
return false;
|
||||
if(not is_same_value(a->inputs().at(1), b->inputs().at(1)))
|
||||
return false;
|
||||
if(a->inputs().size() == 2)
|
||||
return true;
|
||||
return is_same_value(a->inputs().at(2), b->inputs().at(2));
|
||||
}
|
||||
|
||||
// When an unpack instruction is inserted, its original input must be an int4/uint4.
|
||||
// Therefore check for an unpack_int4 operator -- while ignoring out shape related ops.
|
||||
bool is_any_input_int4(instruction_ref a)
|
||||
{
|
||||
static std::set<std::string> ign = {"unsqueeze",
|
||||
"broadcast",
|
||||
"multibroadcast",
|
||||
"contiguous",
|
||||
"transpose",
|
||||
"reshape",
|
||||
"convert"};
|
||||
return std::any_of(a->inputs().begin(), a->inputs().end(), [](auto i) {
|
||||
while(ign.find(i->name()) != ign.end())
|
||||
i = i->inputs()[0];
|
||||
return i->name() == "unpack_int4";
|
||||
});
|
||||
}
|
||||
|
||||
void remove_qdq_pairs(module& m)
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
auto args = ins->inputs();
|
||||
for(auto&& arg : args)
|
||||
{
|
||||
if(arg->name() == "dequantizelinear")
|
||||
{
|
||||
auto q = arg->inputs().front();
|
||||
if((q->name() == "quantizelinear") and is_same_scale_zero(arg, q))
|
||||
{
|
||||
instruction::replace_argument(ins, arg, q->inputs().front());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void remove_zero_point(module& m)
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "dequantizelinear")
|
||||
continue;
|
||||
if(ins->inputs().size() != 3)
|
||||
continue;
|
||||
auto zp = ins->inputs().at(2);
|
||||
if(not zp->can_eval())
|
||||
continue;
|
||||
auto a = zp->eval();
|
||||
bool is_zero = false;
|
||||
a.visit([&](auto t) {
|
||||
is_zero = std::all_of(t.begin(), t.end(), [](auto x) { return float_equal(x, 0); });
|
||||
});
|
||||
if(not is_zero)
|
||||
continue;
|
||||
m.replace_instruction(ins, ins->get_operator(), ins->inputs().at(0), ins->inputs().at(1));
|
||||
}
|
||||
}
|
||||
|
||||
void add_int4_pack_unpack_pair(module& m)
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
if(ins->name() != "dequantizelinear")
|
||||
continue;
|
||||
|
||||
for(auto&& inp : ins->inputs())
|
||||
{
|
||||
if((inp->name() == "quantizelinear") and is_any_input_int4(inp))
|
||||
{
|
||||
auto pk = m.insert_instruction(ins, make_op("pack_int4"), inp);
|
||||
auto unpk = m.insert_instruction(ins, make_op("unpack_int4"), pk);
|
||||
instruction::replace_argument(ins, inp, unpk);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void simplify_qdq::apply(module& m) const
|
||||
{
|
||||
// first step: add pack/unpack pair between qdq for int4 weights
|
||||
add_int4_pack_unpack_pair(m);
|
||||
match::find_matches(m, match_find_quantizable_ops{});
|
||||
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
|
||||
remove_qdq_pairs(m);
|
||||
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
|
||||
match::find_matches(m, match_qlinear_reused{});
|
||||
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
|
||||
remove_zero_point(m);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
1217
docker/rocm/migraphx/simplify_reshapes.cpp
Normal file
1217
docker/rocm/migraphx/simplify_reshapes.cpp
Normal file
File diff suppressed because it is too large
Load Diff
243
docker/rocm/migraphx/split_reduce.cpp
Normal file
243
docker/rocm/migraphx/split_reduce.cpp
Normal file
@ -0,0 +1,243 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
*/
|
||||
#include <migraphx/split_reduce.hpp>
|
||||
#include <migraphx/dom_info.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/liveness.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/check_shapes.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
#include <migraphx/register_op.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/fuse_pointwise.hpp>
|
||||
#include <migraphx/algorithm.hpp>
|
||||
#include <migraphx/param_utils.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct split_fused_reduce
|
||||
{
|
||||
std::vector<std::int64_t> axes{};
|
||||
std::string assign = "assign_none";
|
||||
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.axes, "axes"), f(self.assign, "assign"));
|
||||
}
|
||||
|
||||
value attributes() const { return {{"prefill", 0}}; }
|
||||
|
||||
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
|
||||
{
|
||||
if(mods.size() != 1)
|
||||
MIGRAPHX_THROW("should have one submodule.");
|
||||
const auto* sm = mods.front();
|
||||
auto names = sm->get_parameter_names();
|
||||
check_shapes{inputs, *this}.has(names.size()).same_ndims();
|
||||
|
||||
auto result =
|
||||
sm->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
|
||||
if(result.size() == 1)
|
||||
return result.front();
|
||||
return shape{result};
|
||||
}
|
||||
|
||||
std::string name() const { return "split_fused_reduce"; }
|
||||
};
|
||||
MIGRAPHX_REGISTER_OP(split_fused_reduce);
|
||||
|
||||
static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); }
|
||||
|
||||
namespace {
|
||||
struct splitter
|
||||
{
|
||||
const_module_ref rm;
|
||||
|
||||
bool strictly_dominate(instruction_ref a, instruction_ref b)
|
||||
{
|
||||
if(not dom.has_value())
|
||||
dom = compute_dominator(*rm);
|
||||
return dom->strictly_dominate(a, b);
|
||||
}
|
||||
|
||||
std::vector<instruction_ref> find_splits() const
|
||||
{
|
||||
std::vector<instruction_ref> result;
|
||||
copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins) {
|
||||
return is_reduce(*ins);
|
||||
});
|
||||
if(result.size() > 2)
|
||||
return {};
|
||||
// Only handle reduce_sum for now
|
||||
// TODO: Support other reduction types
|
||||
if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
|
||||
return ins->name() == "reduce_sum";
|
||||
}))
|
||||
return {};
|
||||
if(result.size() < 2)
|
||||
return result;
|
||||
if(reaches(result[0], result[1]))
|
||||
return {};
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<instruction_ref> find_alive(const std::vector<instruction_ref>& splits)
|
||||
{
|
||||
std::vector<instruction_ref> result;
|
||||
bool stop = false;
|
||||
liveness(*rm, [&](auto rins, const auto& live_set) {
|
||||
if(stop)
|
||||
return;
|
||||
if(rins == rm->begin())
|
||||
return;
|
||||
// We want to know what instructions are live after the split instruction
|
||||
auto ins = instruction::get_output_alias(std::prev(rins));
|
||||
if(not contains(splits, ins))
|
||||
return;
|
||||
std::copy_if(live_set.begin(),
|
||||
live_set.end(),
|
||||
std::back_inserter(result),
|
||||
[&](instruction_ref live) {
|
||||
if(live->name() == "@param")
|
||||
return false;
|
||||
if(contains(splits, live))
|
||||
return false;
|
||||
if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) {
|
||||
return this->strictly_dominate(live, split);
|
||||
}))
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
stop = true;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
std::optional<dominator_info> dom = std::nullopt;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static std::string assign_op(const std::vector<instruction_ref>& splits)
|
||||
{
|
||||
static std::unordered_map<std::string, std::string> m = {
|
||||
{"reduce_sum", "assign_add"},
|
||||
{"reduce_mean", "assign_add"},
|
||||
{"reduce_prod", "assign_mul"},
|
||||
{"reduce_max", "assign_max"},
|
||||
{"reduce_min", "assign_min"},
|
||||
};
|
||||
return m.at(splits.front()->name());
|
||||
}
|
||||
|
||||
static std::vector<instruction_ref>
|
||||
insert_module_inline(module& m, instruction_ref ins, const module::with_inputs& mwi)
|
||||
{
|
||||
auto param_map = mwi.mod.get_ins_param_map(mwi.inputs, true);
|
||||
return m.insert_instructions(ins, &mwi.mod, ¶m_map);
|
||||
}
|
||||
|
||||
static std::size_t get_reduce_size(const_module_ref rm)
|
||||
{
|
||||
auto ins = std::find_if(rm->begin(), rm->end(), &is_reduce);
|
||||
assert(ins != rm->end());
|
||||
return ins->inputs().front()->get_shape().elements() / ins->get_shape().elements();
|
||||
}
|
||||
|
||||
void split_reduce::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
for(auto ins : iterator_for(mpm.get_module()))
|
||||
{
|
||||
if(ins->name() != "fused_reduce")
|
||||
continue;
|
||||
auto* rm = ins->module_inputs().front();
|
||||
if(get_reduce_size(rm) < split_size)
|
||||
continue;
|
||||
splitter s{rm};
|
||||
auto splits = s.find_splits();
|
||||
if(splits.empty())
|
||||
continue;
|
||||
// Only use split reduce with float for now
|
||||
// TODO: Support other data types
|
||||
if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) {
|
||||
return contains({shape::float_type, shape::half_type}, split->get_shape().type());
|
||||
}))
|
||||
continue;
|
||||
auto v = ins->get_operator().to_value();
|
||||
auto axes = v["axes"].to_vector<std::int64_t>();
|
||||
|
||||
auto alive = s.find_alive(splits);
|
||||
|
||||
std::array<module::with_inputs, 2> mods;
|
||||
if(not alive.empty())
|
||||
{
|
||||
auto mods3 = rm->split(ins->inputs(), alive, splits);
|
||||
auto r = insert_module_inline(mpm.get_module(), ins, mods3[0]);
|
||||
mods3[1].replace(alive, r);
|
||||
mods3[2].replace(alive, r);
|
||||
mods = {std::move(mods3[1]), std::move(mods3[2])};
|
||||
}
|
||||
else
|
||||
{
|
||||
mods = rm->split(ins->inputs(), splits);
|
||||
}
|
||||
|
||||
auto* splitm = mpm.create_module(rm->name() + "_split", std::move(mods[0].mod));
|
||||
splitm->set_bypass();
|
||||
|
||||
// Insert split reduce
|
||||
auto split_reduce = mpm.get_module().insert_instruction(
|
||||
ins,
|
||||
make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign_op(splits)}}),
|
||||
mods[0].inputs,
|
||||
{splitm});
|
||||
|
||||
std::vector<instruction_ref> split_reduce_each;
|
||||
if(splits.size() == 1)
|
||||
{
|
||||
split_reduce_each = {split_reduce};
|
||||
}
|
||||
else
|
||||
{
|
||||
transform(range(splits.size()), std::back_inserter(split_reduce_each), [&](auto i) {
|
||||
return mpm.get_module().insert_instruction(
|
||||
ins, make_op("get_tuple_elem", {{"index", i}}), split_reduce);
|
||||
});
|
||||
}
|
||||
|
||||
mods[1].replace(splits, split_reduce_each);
|
||||
auto replaced = insert_module_inline(mpm.get_module(), ins, mods[1]);
|
||||
assert(replaced.size() == 1);
|
||||
mpm.get_module().replace_instruction(ins, replaced.front());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
173
docker/rocm/migraphx/split_single_dyn_dim.cpp
Normal file
173
docker/rocm/migraphx/split_single_dyn_dim.cpp
Normal file
@ -0,0 +1,173 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/split_single_dyn_dim.hpp>
|
||||
#include <migraphx/module.hpp>
|
||||
#include <migraphx/pass_manager.hpp>
|
||||
#include <migraphx/functional.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/matcher.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct dynamic_dimensions_check
|
||||
{
|
||||
std::string dyn_param_str;
|
||||
shape::dynamic_dimension dd;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns value if the parameters contain non-fixed dynamic_dimensions that are the same between
|
||||
* all of the dynamic shape parameters.
|
||||
* In other words, each parameter can have one non-fixed dynamic_dimension `x` where `x` is the same
|
||||
* between all of the parameters with a non-fixed dynamic_dimension.
|
||||
* Returns the parameters and the dynamic dimension in a vector of dynamic_dimensions_check objects.
|
||||
*/
|
||||
optional<std::vector<dynamic_dimensions_check>>
|
||||
has_one_unique_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
|
||||
{
|
||||
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
|
||||
std::vector<std::decay_t<decltype(param_shapes)>::value_type> dyn_params{};
|
||||
std::copy_if(
|
||||
param_shapes.begin(), param_shapes.end(), std::back_inserter(dyn_params), is_dynamic);
|
||||
if(dyn_params.empty())
|
||||
return std::nullopt;
|
||||
std::vector<dynamic_dimensions_check> ret{};
|
||||
// get non-fixed dynamic_dimension from all parameters
|
||||
for(const auto& param : dyn_params)
|
||||
{
|
||||
const auto& dds = param.second.dyn_dims();
|
||||
auto num_non_fixed = std::count_if(dds.cbegin(), dds.cend(), [&](auto dd) {
|
||||
if(not dd.is_fixed())
|
||||
{
|
||||
ret.push_back(dynamic_dimensions_check{param.first, dd});
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
// catch more than one non-fixed dynamic_dimension
|
||||
if(num_non_fixed > 1)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
if(ret.empty())
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
// check all the same dynamic_dimension
|
||||
bool same_dd =
|
||||
std::all_of(ret.begin() + 1, ret.end(), [&](auto ddc) { return ddc.dd == ret.at(0).dd; });
|
||||
if(same_dd)
|
||||
{
|
||||
return ret;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check the parameters in std::vector<dynamic_dimensions_check> object to see if any of the
|
||||
* parameters outputs to a select_module operator.
|
||||
*/
|
||||
bool any_sm_next(const_module_ref mm, const std::vector<dynamic_dimensions_check>& ddcs)
|
||||
{
|
||||
for(const auto& ddc : ddcs)
|
||||
{
|
||||
auto p_outputs = mm->get_parameter(ddc.dyn_param_str)->outputs();
|
||||
bool is_sm_next = std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
|
||||
return ins->name() == "select_module";
|
||||
});
|
||||
if(is_sm_next)
|
||||
{
|
||||
return true;
|
||||
};
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
|
||||
* and `loop` instructions, depending on how the submodules for those
|
||||
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
|
||||
* instructions. Skips if the dynamic parameter outputs to a select_module operator.
|
||||
*/
|
||||
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
|
||||
{
|
||||
module_ref mm = &mpm.get_module();
|
||||
auto param_names = mm->get_parameter_names();
|
||||
auto param_shapes = mm->get_parameter_shapes();
|
||||
optional<std::vector<dynamic_dimensions_check>> dd_check_vec =
|
||||
has_one_unique_dyn_dim(param_shapes);
|
||||
if(dd_check_vec.has_value() and not any_sm_next(mm, dd_check_vec.value()))
|
||||
{
|
||||
// all dynamic dimension objects should be the same for all parameters in dd_check_vec
|
||||
auto dyn_dim = dd_check_vec->at(0).dd;
|
||||
// create submodules for each dimension size
|
||||
std::vector<module_ref> submodules;
|
||||
for(size_t dim_size : migraphx::range(dyn_dim.min, dyn_dim.max + 1))
|
||||
{
|
||||
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
|
||||
// instruction map for new static shaped submodule parameters
|
||||
std::unordered_map<instruction_ref, instruction_ref> map_ins;
|
||||
for(const auto& dd_check : dd_check_vec.value())
|
||||
{
|
||||
// create static shape using dim_size
|
||||
const auto& dyn_param = mm->get_parameter(dd_check.dyn_param_str);
|
||||
auto dyn_param_shape = mm->get_parameter_shape(dd_check.dyn_param_str);
|
||||
auto static_shape = dyn_param_shape.to_static(dim_size);
|
||||
map_ins[dyn_param] = submod->add_parameter(dd_check.dyn_param_str, static_shape);
|
||||
}
|
||||
auto outputs = submod->add_instructions(mm, &map_ins);
|
||||
submod->add_return({outputs});
|
||||
submodules.push_back(submod);
|
||||
}
|
||||
// sort parameters by name for consistency (vs. parameter order attr)
|
||||
std::sort(param_names.begin(), param_names.end());
|
||||
// redirect to select_module operator and return
|
||||
std::vector<instruction_ref> sm_inputs;
|
||||
std::transform(param_names.cbegin(),
|
||||
param_names.cend(),
|
||||
std::back_inserter(sm_inputs),
|
||||
[&](auto pn) { return mm->get_parameter(pn); });
|
||||
auto output_shapes = mm->get_output_shapes();
|
||||
migraphx::shape out_attr = migraphx::shape{output_shapes};
|
||||
auto sm_ins = mm->add_instruction(
|
||||
migraphx::make_op("select_module",
|
||||
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
|
||||
sm_inputs,
|
||||
submodules);
|
||||
std::vector<instruction_ref> outputs(output_shapes.size());
|
||||
for(size_t i = 0; i < output_shapes.size(); ++i)
|
||||
{
|
||||
outputs.at(i) =
|
||||
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", i}}), sm_ins);
|
||||
}
|
||||
mm->replace_return(outputs);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
112
docker/rocm/migraphx/sqlite.cpp
Normal file
112
docker/rocm/migraphx/sqlite.cpp
Normal file
@ -0,0 +1,112 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/sqlite.hpp>
|
||||
#include <migraphx/manage_ptr.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <sqlite3.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
using sqlite3_ptr = MIGRAPHX_MANAGE_PTR(sqlite3*, sqlite3_close);
|
||||
|
||||
struct sqlite_impl
|
||||
{
|
||||
sqlite3* get() const { return ptr.get(); }
|
||||
void open(const fs::path& p, int flags)
|
||||
{
|
||||
sqlite3* ptr_tmp = nullptr;
|
||||
int rc = sqlite3_open_v2(p.string().c_str(), &ptr_tmp, flags, nullptr);
|
||||
ptr = sqlite3_ptr{ptr_tmp};
|
||||
if(rc != 0)
|
||||
MIGRAPHX_THROW("error opening " + p.string() + ": " + error_message());
|
||||
}
|
||||
|
||||
template <class F>
|
||||
void exec(const char* sql, F f)
|
||||
{
|
||||
// cppcheck-suppress constParameterPointer
|
||||
auto callback = [](void* obj, auto... xs) -> int {
|
||||
try
|
||||
{
|
||||
const auto* g = static_cast<const F*>(obj);
|
||||
(*g)(xs...);
|
||||
return 0;
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
int rc = sqlite3_exec(get(), sql, callback, &f, nullptr);
|
||||
if(rc != 0)
|
||||
MIGRAPHX_THROW(error_message());
|
||||
}
|
||||
|
||||
std::string error_message() const
|
||||
{
|
||||
std::string msg = "sqlite3: ";
|
||||
return msg + sqlite3_errmsg(get());
|
||||
}
|
||||
sqlite3_ptr ptr;
|
||||
};
|
||||
|
||||
sqlite sqlite::read(const fs::path& p)
|
||||
{
|
||||
sqlite r;
|
||||
r.impl = std::make_shared<sqlite_impl>();
|
||||
r.impl->open(p, SQLITE_OPEN_READONLY);
|
||||
return r;
|
||||
}
|
||||
|
||||
sqlite sqlite::write(const fs::path& p)
|
||||
{
|
||||
sqlite r;
|
||||
r.impl = std::make_shared<sqlite_impl>();
|
||||
// Using '+' instead of bitwise '|' to avoid compilation warning
|
||||
r.impl->open(p, SQLITE_OPEN_READWRITE + SQLITE_OPEN_CREATE);
|
||||
return r;
|
||||
}
|
||||
|
||||
std::vector<std::unordered_map<std::string, std::string>> sqlite::execute(const std::string& s)
|
||||
{
|
||||
std::vector<std::unordered_map<std::string, std::string>> result;
|
||||
impl->exec(s.c_str(), [&](int n, char** texts, char** names) {
|
||||
std::unordered_map<std::string, std::string> row;
|
||||
row.reserve(n);
|
||||
std::transform(
|
||||
names,
|
||||
names + n,
|
||||
texts,
|
||||
std::inserter(row, row.begin()),
|
||||
[&](const char* name, const char* text) { return std::make_pair(name, text); });
|
||||
result.push_back(row);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
37
docker/rocm/migraphx/target.cpp
Normal file
37
docker/rocm/migraphx/target.cpp
Normal file
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/target.hpp>
|
||||
#include <migraphx/register_target.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }
|
||||
void migraphx_from_value(const value& v, target& t)
|
||||
{
|
||||
t = make_target(v.at("name").to<std::string>());
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
108
docker/rocm/migraphx/tmp_dir.cpp
Normal file
108
docker/rocm/migraphx/tmp_dir.cpp
Normal file
@ -0,0 +1,108 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/tmp_dir.hpp>
|
||||
#include <migraphx/env.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/process.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#ifdef _WIN32
|
||||
// cppcheck-suppress definePrefix
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <Windows.h>
|
||||
#undef getpid
|
||||
// cppcheck-suppress [definePrefix, defineUpperCase]
|
||||
#define getpid _getpid
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#include <sys/types.h>
|
||||
#endif
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_SAVE_TEMP_DIR)
|
||||
|
||||
std::string random_string(std::string::size_type length)
|
||||
{
|
||||
static const std::string& chars = "0123456789"
|
||||
"abcdefghijklmnopqrstuvwxyz"
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
|
||||
std::mt19937 rg{std::random_device{}()};
|
||||
std::uniform_int_distribution<std::string::size_type> pick(0, chars.length() - 1);
|
||||
|
||||
std::string str(length, 0);
|
||||
std::generate(str.begin(), str.end(), [&] { return chars[pick(rg)]; });
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
std::string unique_string(const std::string& prefix)
|
||||
{
|
||||
auto pid = getpid();
|
||||
auto tid = std::this_thread::get_id();
|
||||
auto clk = std::chrono::steady_clock::now().time_since_epoch().count();
|
||||
std::stringstream ss;
|
||||
ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
tmp_dir::tmp_dir(std::string_view prefix)
|
||||
: path(fs::temp_directory_path() /
|
||||
unique_string(prefix.empty() ? "migraphx" : "migraphx-" + std::string{prefix}))
|
||||
{
|
||||
fs::create_directories(this->path);
|
||||
}
|
||||
|
||||
void tmp_dir::execute(std::string_view cmd, const std::vector<std::string>& args) const
|
||||
{
|
||||
process{cmd, args}.cwd(this->path).exec();
|
||||
}
|
||||
|
||||
tmp_dir::~tmp_dir()
|
||||
{
|
||||
if(not enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{}))
|
||||
{
|
||||
constexpr int max_retries_count = 5;
|
||||
for([[maybe_unused]] auto count : range(max_retries_count))
|
||||
{
|
||||
std::error_code ec;
|
||||
fs::remove_all(path, ec);
|
||||
if(not ec)
|
||||
break;
|
||||
std::cerr << "Failed to remove " << path << ": " << ec.message() << std::endl;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(125));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
103
docker/rocm/migraphx/truncate_float.cpp
Normal file
103
docker/rocm/migraphx/truncate_float.cpp
Normal file
@ -0,0 +1,103 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <migraphx/float_equal.hpp>
|
||||
#include <migraphx/instruction_ref.hpp>
|
||||
#include <migraphx/truncate_float.hpp>
|
||||
#include <migraphx/program.hpp>
|
||||
#include <migraphx/instruction.hpp>
|
||||
#include <migraphx/iterator_for.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/ranges.hpp>
|
||||
#include <migraphx/target.hpp>
|
||||
#include <migraphx/make_op.hpp>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
static void
|
||||
quantize_module(module& m, const std::vector<std::string>& ins_names, shape::type_t float_type)
|
||||
{
|
||||
for(auto ins : iterator_for(m))
|
||||
{
|
||||
// instructions are not in the set to be quantized
|
||||
if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
|
||||
continue;
|
||||
|
||||
// skip return and convert instructions
|
||||
if(contains({"@return", "convert"}, ins->name()))
|
||||
continue;
|
||||
|
||||
if(ins->inputs().empty())
|
||||
continue;
|
||||
|
||||
auto mod_inputs = ins->module_inputs();
|
||||
auto s = ins->get_shape();
|
||||
// Convert each of the inputs that are floating point to float type
|
||||
auto inputs = ins->inputs();
|
||||
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
|
||||
auto input_type = input->get_shape().type();
|
||||
if(input_type != shape::float_type and input_type != shape::double_type)
|
||||
return input;
|
||||
return m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", float_type}}), input);
|
||||
});
|
||||
|
||||
// Insert quantized ins
|
||||
auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
|
||||
|
||||
// tuple can't be directly converted, get_tuple_elem needs conversion
|
||||
if(ins->get_shape().type() == shape::tuple_type)
|
||||
{
|
||||
auto outputs = ins->outputs();
|
||||
std::transform(
|
||||
outputs.begin(), outputs.end(), outputs.begin(), [&](const auto gte_ins) {
|
||||
auto gte_ins_float_type =
|
||||
m.insert_instruction(ins, gte_ins->get_operator(), converted_ins);
|
||||
// Convert back to output type after quantizing
|
||||
auto gte_converted = m.insert_instruction(
|
||||
ins,
|
||||
make_op("convert", {{"target_type", gte_ins->get_shape().type()}}),
|
||||
gte_ins_float_type);
|
||||
// Replace output instruction
|
||||
return m.replace_instruction(gte_ins, gte_converted);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Convert back to original type after quantizing
|
||||
if(mod_inputs.empty())
|
||||
{
|
||||
converted_ins = m.insert_instruction(
|
||||
ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
|
||||
}
|
||||
// Replace original instruction
|
||||
m.replace_instruction(ins, converted_ins);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void truncate_float_pass::apply(module& m) const { quantize_module(m, ins_names, float_type); }
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
573
docker/rocm/migraphx/value.cpp
Normal file
573
docker/rocm/migraphx/value.cpp
Normal file
@ -0,0 +1,573 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <migraphx/cloneable.hpp>
|
||||
#include <migraphx/errors.hpp>
|
||||
#include <migraphx/stringutils.hpp>
|
||||
#include <migraphx/value.hpp>
|
||||
#include <migraphx/optional.hpp>
|
||||
#include <migraphx/hash.hpp>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
struct value_base_impl : cloneable<value_base_impl>
|
||||
{
|
||||
virtual value::type_t get_type() { return value::null_type; }
|
||||
#define MIGRAPHX_VALUE_GENERATE_BASE_FUNCTIONS(vt, cpp_type) \
|
||||
virtual const cpp_type* if_##vt() const { return nullptr; }
|
||||
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_BASE_FUNCTIONS)
|
||||
virtual std::vector<value>* if_array() { return nullptr; }
|
||||
virtual std::unordered_map<std::string, std::size_t>* if_object() { return nullptr; }
|
||||
virtual value_base_impl* if_value() const { return nullptr; }
|
||||
value_base_impl() = default;
|
||||
value_base_impl(const value_base_impl&) = default;
|
||||
value_base_impl& operator=(const value_base_impl&) = default;
|
||||
virtual ~value_base_impl() override {}
|
||||
};
|
||||
|
||||
#define MIGRAPHX_VALUE_GENERATE_BASE_TYPE(vt, cpp_type) \
|
||||
struct vt##_value_holder : value_base_impl::share \
|
||||
{ \
|
||||
vt##_value_holder(cpp_type d) : data(std::move(d)) {} \
|
||||
virtual value::type_t get_type() override { return value::vt##_type; } \
|
||||
virtual const cpp_type* if_##vt() const override { return &data; } \
|
||||
cpp_type data; \
|
||||
};
|
||||
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_BASE_TYPE)
|
||||
|
||||
struct array_value_holder : value_base_impl::derive<array_value_holder>
|
||||
{
|
||||
array_value_holder() {}
|
||||
array_value_holder(std::vector<value> d) : data(std::move(d)) {}
|
||||
virtual value::type_t get_type() override { return value::array_type; }
|
||||
virtual std::vector<value>* if_array() override { return &data; }
|
||||
std::vector<value> data;
|
||||
};
|
||||
|
||||
struct object_value_holder : value_base_impl::derive<object_value_holder>
|
||||
{
|
||||
object_value_holder() {}
|
||||
object_value_holder(std::vector<value> d, std::unordered_map<std::string, std::size_t> l)
|
||||
: data(std::move(d)), lookup(std::move(l))
|
||||
{
|
||||
}
|
||||
virtual value::type_t get_type() override { return value::object_type; }
|
||||
virtual std::vector<value>* if_array() override { return &data; }
|
||||
virtual std::unordered_map<std::string, std::size_t>* if_object() override { return &lookup; }
|
||||
std::vector<value> data;
|
||||
std::unordered_map<std::string, std::size_t> lookup;
|
||||
};
|
||||
|
||||
value::value(const value& rhs) : x(rhs.x ? rhs.x->clone() : nullptr), key(rhs.key) {}
|
||||
value& value::operator=(value rhs)
|
||||
{
|
||||
std::swap(rhs.x, x);
|
||||
if(not rhs.key.empty())
|
||||
std::swap(rhs.key, key);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void set_vector(std::shared_ptr<value_base_impl>& x,
|
||||
const std::vector<value>& v,
|
||||
bool array_on_empty = true)
|
||||
{
|
||||
if(v.empty())
|
||||
{
|
||||
if(array_on_empty)
|
||||
x = std::make_shared<array_value_holder>();
|
||||
else
|
||||
x = std::make_shared<object_value_holder>();
|
||||
return;
|
||||
}
|
||||
if(v.front().get_key().empty())
|
||||
{
|
||||
x = std::make_shared<array_value_holder>(v);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::unordered_map<std::string, std::size_t> lookup;
|
||||
std::size_t i = 0;
|
||||
for(auto&& e : v)
|
||||
{
|
||||
lookup[e.get_key()] = i;
|
||||
i++;
|
||||
}
|
||||
x = std::make_shared<object_value_holder>(v, lookup);
|
||||
}
|
||||
}
|
||||
|
||||
value::value(const std::initializer_list<value>& i) : x(nullptr)
|
||||
{
|
||||
if(i.size() == 2 and i.begin()->is_string() and i.begin()->get_key().empty())
|
||||
{
|
||||
key = i.begin()->get_string();
|
||||
auto r = (i.begin() + 1)->x;
|
||||
x = r ? r->clone() : nullptr;
|
||||
return;
|
||||
}
|
||||
set_vector(x, std::vector<value>(i.begin(), i.end()));
|
||||
}
|
||||
|
||||
value::value(const std::vector<value>& v, bool array_on_empty) : x(nullptr)
|
||||
{
|
||||
set_vector(x, v, array_on_empty);
|
||||
}
|
||||
|
||||
value::value(const std::unordered_map<std::string, value>& m)
|
||||
: value(std::vector<value>(m.begin(), m.end()), false)
|
||||
{
|
||||
}
|
||||
|
||||
value::value(const std::string& pkey, const std::vector<value>& v, bool array_on_empty)
|
||||
: x(nullptr), key(pkey)
|
||||
{
|
||||
set_vector(x, v, array_on_empty);
|
||||
}
|
||||
|
||||
value::value(const std::string& pkey, const std::unordered_map<std::string, value>& m)
|
||||
: value(pkey, std::vector<value>(m.begin(), m.end()), false)
|
||||
{
|
||||
}
|
||||
|
||||
value::value(const std::string& pkey, std::nullptr_t) : x(nullptr), key(pkey) {}
|
||||
|
||||
value::value(std::nullptr_t) : x(nullptr) {}
|
||||
|
||||
value::value(const std::string& pkey, const value& rhs)
|
||||
: x(rhs.x ? rhs.x->clone() : nullptr), key(pkey)
|
||||
{
|
||||
}
|
||||
|
||||
value::value(const std::string& pkey, const char* i) : value(pkey, std::string(i)) {}
|
||||
value::value(const char* i) : value(std::string(i)) {}
|
||||
|
||||
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
|
||||
value::value(cpp_type i) : x(std::make_shared<vt##_value_holder>(std::move(i))) {} \
|
||||
value::value(const std::string& pkey, cpp_type i) \
|
||||
: x(std::make_shared<vt##_value_holder>(std::move(i))), key(pkey) \
|
||||
{ \
|
||||
} \
|
||||
value& value::operator=(cpp_type rhs) \
|
||||
{ \
|
||||
x = std::make_shared<vt##_value_holder>(std::move(rhs)); \
|
||||
return *this; \
|
||||
} \
|
||||
bool value::is_##vt() const { return x ? x->get_type() == vt##_type : false; } \
|
||||
const cpp_type& value::get_##vt() const \
|
||||
{ \
|
||||
auto* r = this->if_##vt(); \
|
||||
assert(r); \
|
||||
return *r; \
|
||||
} \
|
||||
const cpp_type* value::if_##vt() const { return x ? x->if_##vt() : nullptr; }
|
||||
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
|
||||
|
||||
value& value::operator=(const char* c)
|
||||
{
|
||||
*this = std::string{c};
|
||||
return *this;
|
||||
}
|
||||
|
||||
value& value::operator=(std::nullptr_t)
|
||||
{
|
||||
x = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
value& value::operator=(const std::initializer_list<value>& i)
|
||||
{
|
||||
value rhs = i;
|
||||
std::swap(rhs.x, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool value::is_array() const { return x ? x->get_type() == array_type : false; }
|
||||
const std::vector<value>& value::value::get_array() const
|
||||
{
|
||||
const auto* r = this->if_array();
|
||||
assert(r);
|
||||
return *r;
|
||||
}
|
||||
const std::vector<value>* value::if_array() const { return x ? x->if_array() : nullptr; }
|
||||
|
||||
bool value::is_object() const { return x ? x->get_type() == object_type : false; }
|
||||
const std::vector<value>& value::get_object() const
|
||||
{
|
||||
const auto* r = this->if_object();
|
||||
assert(r);
|
||||
return *r;
|
||||
}
|
||||
const std::vector<value>* value::if_object() const
|
||||
{
|
||||
auto* r = x ? x->if_array() : nullptr;
|
||||
assert(r == nullptr or
|
||||
std::none_of(r->begin(), r->end(), [](auto&& v) { return v.get_key().empty(); }));
|
||||
return r;
|
||||
}
|
||||
|
||||
bool value::is_null() const { return x == nullptr; }
|
||||
|
||||
const std::string& value::get_key() const { return key; }
|
||||
|
||||
std::vector<value>* if_array_impl(const std::shared_ptr<value_base_impl>& x)
|
||||
{
|
||||
if(x == nullptr)
|
||||
return nullptr;
|
||||
return x->if_array();
|
||||
}
|
||||
|
||||
std::vector<value>& get_array_impl(const std::shared_ptr<value_base_impl>& x)
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
assert(a);
|
||||
return *a;
|
||||
}
|
||||
|
||||
std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
MIGRAPHX_THROW("Expected an array or object");
|
||||
return *a;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key, T* end)
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
return end;
|
||||
auto* lookup = x->if_object();
|
||||
if(lookup == nullptr)
|
||||
return end;
|
||||
auto it = lookup->find(key);
|
||||
if(it == lookup->end())
|
||||
return end;
|
||||
return std::addressof((*a)[it->second]);
|
||||
}
|
||||
|
||||
value* value::find(const std::string& pkey) { return find_impl(x, pkey, this->end()); }
|
||||
|
||||
const value* value::find(const std::string& pkey) const { return find_impl(x, pkey, this->end()); }
|
||||
bool value::contains(const std::string& pkey) const
|
||||
{
|
||||
const auto* it = find(pkey);
|
||||
if(it == nullptr)
|
||||
return false;
|
||||
if(it == end())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
std::size_t value::size() const
|
||||
{
|
||||
const auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
return 0;
|
||||
return a->size();
|
||||
}
|
||||
bool value::empty() const { return size() == 0; }
|
||||
const value* value::data() const
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
return nullptr;
|
||||
return a->data();
|
||||
}
|
||||
value* value::data()
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
return nullptr;
|
||||
return a->data();
|
||||
}
|
||||
value* value::begin()
|
||||
{
|
||||
// cppcheck-suppress assertWithSideEffect
|
||||
assert(data() or empty());
|
||||
return data();
|
||||
}
|
||||
const value* value::begin() const
|
||||
{
|
||||
assert(data() or empty());
|
||||
return data();
|
||||
}
|
||||
value* value::end() { return begin() + size(); }
|
||||
const value* value::end() const { return begin() + size(); }
|
||||
|
||||
value& value::front()
|
||||
{
|
||||
assert(this->size() > 0);
|
||||
return *begin();
|
||||
}
|
||||
const value& value::front() const
|
||||
{
|
||||
assert(this->size() > 0);
|
||||
return *begin();
|
||||
}
|
||||
value& value::back()
|
||||
{
|
||||
assert(this->size() > 0);
|
||||
return *std::prev(end());
|
||||
}
|
||||
const value& value::back() const
|
||||
{
|
||||
assert(this->size() > 0);
|
||||
return *std::prev(end());
|
||||
}
|
||||
value& value::at(std::size_t i)
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
MIGRAPHX_THROW("Not an array");
|
||||
return a->at(i);
|
||||
}
|
||||
const value& value::at(std::size_t i) const
|
||||
{
|
||||
auto* a = if_array_impl(x);
|
||||
if(a == nullptr)
|
||||
MIGRAPHX_THROW("Not an array");
|
||||
return a->at(i);
|
||||
}
|
||||
value& value::at(const std::string& pkey)
|
||||
{
|
||||
auto* r = find(pkey);
|
||||
if(r == nullptr)
|
||||
MIGRAPHX_THROW("Not an object");
|
||||
if(r == end())
|
||||
MIGRAPHX_THROW("Key not found: " + pkey);
|
||||
return *r;
|
||||
}
|
||||
const value& value::at(const std::string& pkey) const
|
||||
{
|
||||
const auto* r = find(pkey);
|
||||
if(r == nullptr)
|
||||
MIGRAPHX_THROW("Not an object for field: " + pkey);
|
||||
if(r == end())
|
||||
MIGRAPHX_THROW("Key not found: " + pkey);
|
||||
return *r;
|
||||
}
|
||||
value& value::operator[](std::size_t i)
|
||||
{
|
||||
assert(i < this->size());
|
||||
return *(begin() + i);
|
||||
}
|
||||
const value& value::operator[](std::size_t i) const
|
||||
{
|
||||
assert(i < this->size());
|
||||
return *(begin() + i);
|
||||
}
|
||||
value& value::operator[](const std::string& pkey) { return *emplace(pkey, nullptr).first; }
|
||||
|
||||
void value::clear() { get_array_throw(x).clear(); }
|
||||
void value::resize(std::size_t n)
|
||||
{
|
||||
if(not is_array())
|
||||
MIGRAPHX_THROW("Expected an array.");
|
||||
get_array_impl(x).resize(n);
|
||||
}
|
||||
void value::resize(std::size_t n, const value& v)
|
||||
{
|
||||
if(not is_array())
|
||||
MIGRAPHX_THROW("Expected an array.");
|
||||
get_array_impl(x).resize(n, v);
|
||||
}
|
||||
|
||||
std::pair<value*, bool> value::insert(const value& v)
|
||||
{
|
||||
if(v.key.empty())
|
||||
{
|
||||
if(not x)
|
||||
x = std::make_shared<array_value_holder>();
|
||||
get_array_impl(x).push_back(v);
|
||||
assert(this->if_array());
|
||||
return std::make_pair(&back(), true);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(not x)
|
||||
x = std::make_shared<object_value_holder>();
|
||||
auto p = x->if_object()->emplace(v.key, get_array_impl(x).size());
|
||||
if(p.second)
|
||||
get_array_impl(x).push_back(v);
|
||||
assert(this->if_object());
|
||||
return std::make_pair(&get_array_impl(x)[p.first->second], p.second);
|
||||
}
|
||||
}
|
||||
value* value::insert(const value* pos, const value& v)
|
||||
{
|
||||
assert(v.key.empty());
|
||||
if(not x)
|
||||
x = std::make_shared<array_value_holder>();
|
||||
auto&& a = get_array_impl(x);
|
||||
auto it = a.insert(a.begin() + (pos - begin()), v);
|
||||
return std::addressof(*it);
|
||||
}
|
||||
|
||||
value value::without_key() const
|
||||
{
|
||||
value result = *this;
|
||||
result.key = "";
|
||||
return result;
|
||||
}
|
||||
|
||||
value value::with_key(const std::string& pkey) const
|
||||
{
|
||||
value result = *this;
|
||||
result.key = pkey;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
const T& compare_decay(const T& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
int compare_decay(std::nullptr_t) { return 0; }
|
||||
|
||||
template <class F>
|
||||
bool compare(const value& x, const value& y, F f)
|
||||
{
|
||||
bool result = false;
|
||||
x.visit_value([&](auto&& a) {
|
||||
y.visit_value([&](auto&& b) {
|
||||
if constexpr(std::is_same<decltype(a), decltype(b)>{})
|
||||
result = f(std::forward_as_tuple(x.get_key(), compare_decay(a)),
|
||||
std::forward_as_tuple(y.get_key(), compare_decay(b)));
|
||||
else
|
||||
assert(false); // NOLINT
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
value::type_t value::get_type() const
|
||||
{
|
||||
if(not x)
|
||||
return null_type;
|
||||
return x->get_type();
|
||||
}
|
||||
|
||||
bool operator==(const value& x, const value& y)
|
||||
{
|
||||
if(x.get_type() != y.get_type())
|
||||
return false;
|
||||
return compare(x, y, std::equal_to<>{});
|
||||
}
|
||||
bool operator!=(const value& x, const value& y) { return not(x == y); }
|
||||
bool operator<(const value& x, const value& y)
|
||||
{
|
||||
if(x.get_type() != y.get_type())
|
||||
return x.get_type() < y.get_type();
|
||||
return compare(x, y, std::less<>{});
|
||||
}
|
||||
bool operator<=(const value& x, const value& y) { return not(x > y); }
|
||||
bool operator>(const value& x, const value& y) { return y < x; }
|
||||
bool operator>=(const value& x, const value& y) { return not(x < y); }
|
||||
|
||||
void print_value(std::ostream& os, std::nullptr_t) { os << "null"; }
|
||||
|
||||
template <class T>
|
||||
void print_value(std::ostream& os, const T& x)
|
||||
{
|
||||
os << x;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
void print_value(std::ostream& os, const std::pair<T, U>& x)
|
||||
{
|
||||
os << x.first;
|
||||
os << ": ";
|
||||
print_value(os, x.second);
|
||||
}
|
||||
|
||||
void print_value(std::ostream& os, const std::vector<value>& x)
|
||||
{
|
||||
os << "{";
|
||||
os << to_string_range(x);
|
||||
os << "}";
|
||||
}
|
||||
|
||||
void print_value(std::ostream& os, const value::binary& x) { os << x; }
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const value& d)
|
||||
{
|
||||
d.visit([&](auto&& y) { print_value(os, y); });
|
||||
return os;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::size_t value_hash(const std::string& key, const T& x)
|
||||
{
|
||||
std::size_t h = hash_value(key);
|
||||
hash_combine(h, x);
|
||||
return h;
|
||||
}
|
||||
|
||||
std::size_t value_hash(const std::string& key, std::nullptr_t) { return hash_value(key); }
|
||||
|
||||
std::size_t value_hash(const std::string& key, const std::vector<value>& x)
|
||||
{
|
||||
std::size_t h = hash_value(key);
|
||||
for(const auto& v : x)
|
||||
hash_combine(h, v);
|
||||
return h;
|
||||
}
|
||||
std::size_t value_hash(const std::string& key, const value::binary& x)
|
||||
{
|
||||
std::size_t h = hash_value(key);
|
||||
for(const auto& v : x)
|
||||
hash_combine(h, v);
|
||||
return h;
|
||||
}
|
||||
|
||||
std::size_t value::hash() const
|
||||
{
|
||||
std::size_t h = 0;
|
||||
this->visit_value([&](const auto& a) { h = value_hash(this->get_key(), a); });
|
||||
return h;
|
||||
}
|
||||
|
||||
void value::debug_print(bool show_type) const
|
||||
{
|
||||
if(show_type)
|
||||
{
|
||||
switch(get_type())
|
||||
{
|
||||
#define MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(vt, cpp_type) \
|
||||
case vt##_type: std::cout << #vt << ": "; break;
|
||||
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE)
|
||||
MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(null, )
|
||||
MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(array, )
|
||||
MIGRAPHX_VALUE_GENERATE_TYPE_STRING_CASE(object, )
|
||||
}
|
||||
}
|
||||
std::cout << *this << std::endl;
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
110
docker/rocm/migraphx/verify_args.cpp
Normal file
110
docker/rocm/migraphx/verify_args.cpp
Normal file
@ -0,0 +1,110 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <migraphx/verify_args.hpp>
|
||||
|
||||
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_DUMP_DIFF);
|
||||
|
||||
namespace migraphx {
|
||||
inline namespace MIGRAPHX_INLINE_NS {
|
||||
|
||||
bool verify_args(const std::string& name,
|
||||
const argument& target_arg,
|
||||
const verify::expected<argument>& ref_arg,
|
||||
verify::tolerance tols)
|
||||
{
|
||||
bool passed = true;
|
||||
visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
|
||||
double rms_error;
|
||||
passed =
|
||||
verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
|
||||
if(not passed)
|
||||
{
|
||||
// TODO: Check for nans
|
||||
std::cout << "FAILED: " << name << std::endl;
|
||||
std::cout << "RMS Error: " << rms_error << std::endl;
|
||||
if(ref.size() < 32 or enabled(MIGRAPHX_VERIFY_DUMP_DIFF{}))
|
||||
std::cout << "ref:" << ref << std::endl;
|
||||
if(target.size() < 32 or enabled(MIGRAPHX_VERIFY_DUMP_DIFF{}))
|
||||
std::cout << "target:" << target << std::endl;
|
||||
if(verify::range_zero(ref))
|
||||
std::cout << "Ref data is all zeros" << std::endl;
|
||||
if(verify::range_zero(target))
|
||||
std::cout << "Target data is all zeros" << std::endl;
|
||||
|
||||
auto mxdiff = verify::max_diff(ref, target);
|
||||
std::cout << "Max diff: " << mxdiff << std::endl;
|
||||
|
||||
auto idx = verify::mismatch_idx(ref, target, float_equal);
|
||||
if(idx < verify::range_distance(ref))
|
||||
{
|
||||
std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
auto ref_nan_idx = find_idx(ref, verify::not_finite);
|
||||
if(ref_nan_idx >= 0)
|
||||
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
|
||||
<< ref[ref_nan_idx] << std::endl;
|
||||
|
||||
auto target_nan_idx = find_idx(target, verify::not_finite);
|
||||
if(target_nan_idx >= 0)
|
||||
std::cout << "Non finite number found in target at " << target_nan_idx << ": "
|
||||
<< target[target_nan_idx] << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(verify::range_zero(ref))
|
||||
std::cout << "Ref data is all zeros" << std::endl;
|
||||
if(verify::range_zero(target))
|
||||
std::cout << "Target data is all zeros" << std::endl;
|
||||
|
||||
auto ref_nan_idx = find_idx(ref, verify::not_finite);
|
||||
if(ref_nan_idx >= 0)
|
||||
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
|
||||
<< ref[ref_nan_idx] << std::endl;
|
||||
|
||||
auto target_nan_idx = find_idx(target, verify::not_finite);
|
||||
if(target_nan_idx >= 0)
|
||||
std::cout << "Non finite number found in target at " << target_nan_idx << ": "
|
||||
<< target[target_nan_idx] << std::endl;
|
||||
}
|
||||
});
|
||||
return passed;
|
||||
}
|
||||
|
||||
bool verify_args_with_tolerance(const std::string& name,
|
||||
const argument& target_arg,
|
||||
const verify::expected<argument>& ref_arg,
|
||||
std::size_t tolerance)
|
||||
{
|
||||
double rms_tol = 0.001;
|
||||
target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
|
||||
verify::tolerance tols{rms_tol};
|
||||
return verify_args(name, target_arg, ref_arg, tols);
|
||||
}
|
||||
|
||||
} // namespace MIGRAPHX_INLINE_NS
|
||||
} // namespace migraphx
|
||||
33
docker/rocm/migraphx/version.h.in
Normal file
33
docker/rocm/migraphx/version.h.in
Normal file
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* The MIT License (MIT)
|
||||
*
|
||||
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*/
|
||||
// clang-format off
|
||||
#define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
|
||||
#define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@
|
||||
#define MIGRAPHX_VERSION_PATCH @PROJECT_VERSION_PATCH@
|
||||
#define MIGRAPHX_VERSION_TWEAK "@PROJECT_VERSION_TWEAK@"
|
||||
#define MIGRAPHX_SO_MAJOR_VERSION \
|
||||
@PROJECT_VERSION_MAJOR@ * 1000 * 1000 + \
|
||||
@PROJECT_VERSION_MINOR@ * 1000 + \
|
||||
@PROJECT_VERSION_PATCH@
|
||||
// clang-format on
|
||||
Loading…
Reference in New Issue
Block a user