Update migraphx_py.cpp

This commit is contained in:
WhiteWolf84 2025-02-05 01:51:16 +01:00 committed by GitHub
parent c031a9185f
commit b0d9db1f4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
@ -25,12 +25,10 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/autocast_fp8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
@ -42,15 +40,12 @@
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/version.h>
#include <migraphx/iterator_for.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#endif #endif
using half = migraphx::half; using half = half_float::half;
namespace py = pybind11; namespace py = pybind11;
#ifdef __clang__ #ifdef __clang__
@ -149,61 +144,6 @@ struct npy_format_descriptor<half>
static constexpr auto name() { return _("half"); } static constexpr auto name() { return _("half"); }
}; };
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
static std::string format()
{
// TODO: no standard format in numpy for fp8
return "z";
}
static constexpr auto name() { return _("fp8e4m3fnuz"); }
};
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e5m2fnuz>
{
static std::string format()
{
// TODO: no standard format in numpy for fp8
return "z";
}
static constexpr auto name() { return _("fp8e5m2fnuz"); }
};
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fn>
{
static std::string format()
{
// TODO: no standard format in numpy for fp8
return "z";
}
static constexpr auto name() { return _("fp8e4m3fn"); }
};
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e5m2>
{
static std::string format()
{
// TODO: no standard format in numpy for fp8
return "z";
}
static constexpr auto name() { return _("fp8e5m2"); }
};
template <>
struct npy_format_descriptor<migraphx::bf16>
{
static std::string format()
{
// TODO: no standard format in numpy for bf16
return "z";
}
static constexpr auto name() { return _("bf16"); }
};
} // namespace detail } // namespace detail
} // namespace pybind11 } // namespace pybind11
@ -261,52 +201,10 @@ py::buffer_info to_buffer_info(T& x)
return b; return b;
} }
py::object to_py_object(const migraphx::value& val)
{
py::object result;
val.visit_value([&](const auto& x) {
if constexpr(std::is_same<std::decay_t<decltype(x)>, std::vector<migraphx::value>>{})
{
if(val.is_object())
{
py::dict py_dict;
for(const auto& item : x)
{
py_dict[py::str(item.get_key())] = to_py_object(item.without_key());
}
result = py_dict;
}
else
{
py::list py_list;
for(const auto& item : x)
{
py_list.append(to_py_object(item));
}
result = py_list;
}
}
else
{
result = py::cast(x);
}
});
return result;
}
migraphx::shape to_shape(const py::buffer_info& info) migraphx::shape to_shape(const py::buffer_info& info)
{ {
migraphx::shape::type_t t; migraphx::shape::type_t t;
std::size_t n = 0; std::size_t n = 0;
// Unsupported pybuffer types lead to undefined behaviour when comparing with migraphx type enum
if(info.format == "z")
{
MIGRAPHX_THROW(
"MIGRAPHX PYTHON: Unsupported data type. For fp8 and bf16 literals try using "
"migraphx.generate_argument with migraphx.add_literal");
}
visit_types([&](auto as) { visit_types([&](auto as) {
if(info.format == py::format_descriptor<decltype(as())>::format() or if(info.format == py::format_descriptor<decltype(as())>::format() or
(info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or (info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or
@ -417,12 +315,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::instruction_ref>(m, "instruction_ref") py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); }) .def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); }) .def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
.def("inputs", [](migraphx::instruction_ref i) { return i->inputs(); })
.def("name", [](migraphx::instruction_ref i) { return i->name(); })
.def(py::hash(py::self))
.def(py::self == py::self)
.def(py::self != py::self);
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module") py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; }) .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
@ -437,12 +330,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("op"), py::arg("op"),
py::arg("args"), py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{}) py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_literal",
[](migraphx::module& mm, migraphx::argument a) {
return mm.add_literal(a.get_shape(), a.data());
},
py::arg("data"))
.def( .def(
"add_literal", "add_literal",
[](migraphx::module& mm, py::buffer data) { [](migraphx::module& mm, py::buffer data) {
@ -464,14 +351,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
return mm.add_return(args); return mm.add_return(args);
}, },
py::arg("args")) py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); }) .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
.def(
"__iter__",
[](const migraphx::module& mm) {
auto r = migraphx::iterator_for(mm);
return py::make_iterator(r.begin(), r.end());
},
py::keep_alive<0, 1>());
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def(py::init([]() { return migraphx::program(); })) .def(py::init([]() { return migraphx::program(); }))
@ -530,12 +410,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true}; migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
return p.eval(pm, exec_env); return p.eval(pm, exec_env);
}) })
.def("to_py",
[](const migraphx::program& p) {
std::stringstream ss;
p.print_py(ss);
return ss.str();
})
.def("sort", &migraphx::program::sort) .def("sort", &migraphx::program::sort)
.def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
.def("__eq__", std::equal_to<migraphx::program>{}) .def("__eq__", std::equal_to<migraphx::program>{})
@ -551,10 +425,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
} }
return migraphx::make_op(name, v); return migraphx::make_op(name, v);
})) }))
.def("name", &migraphx::operation::name) .def("name", &migraphx::operation::name);
.def("values", [](const migraphx::operation& operation) -> py::object {
return to_py_object(operation.to_value());
});
py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode") py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
.value("average", migraphx::op::pooling_mode::average) .value("average", migraphx::op::pooling_mode::average)
@ -601,8 +472,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
map_dyn_input_dims, map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error,
int64_t max_loop_iterations, int64_t max_loop_iterations) {
int64_t limit_max_iterations) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value; options.default_dyn_dim_value = default_dyn_dim_value;
@ -611,7 +481,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations; options.max_loop_iterations = max_loop_iterations;
options.limit_max_iterations = limit_max_iterations;
return migraphx::parse_onnx(filename, options); return migraphx::parse_onnx(filename, options);
}, },
"Parse onnx file", "Parse onnx file",
@ -623,8 +492,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(), std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10, py::arg("max_loop_iterations") = 10);
py::arg("limit_max_iterations") = std::numeric_limits<uint16_t>::max());
m.def( m.def(
"parse_onnx_buffer", "parse_onnx_buffer",
@ -635,8 +503,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>> std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims, map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error) {
const std::string& external_data_path) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value; options.default_dyn_dim_value = default_dyn_dim_value;
@ -644,7 +511,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options.map_dyn_input_dims = map_dyn_input_dims; options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
options.external_data_path = external_data_path;
return migraphx::parse_onnx_buffer(onnx_buffer, options); return migraphx::parse_onnx_buffer(onnx_buffer, options);
}, },
"Parse onnx file", "Parse onnx file",
@ -655,8 +521,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("map_dyn_input_dims") = py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(), std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false);
py::arg("external_data_path") = "");
m.def( m.def(
"load", "load",
@ -689,15 +554,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
a.fill(values.begin(), values.end()); a.fill(values.begin(), values.end());
return a; return a;
}); });
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def(
"generate_argument",
[](const migraphx::shape& s, unsigned long seed) {
return migraphx::generate_argument(s, seed);
},
py::arg("s"),
py::arg("seed") = 0);
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value")); m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16", m.def("quantize_fp16",
&migraphx::quantize_fp16, &migraphx::quantize_fp16,
@ -708,23 +565,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("prog"), py::arg("prog"),
py::arg("t"), py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{}, py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"}); py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
m.def("quantize_fp8",
&migraphx::quantize_fp8,
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{});
m.def(
"autocast_fp8",
[](migraphx::program& prog) {
migraphx::run_passes(*prog.get_main_module(), {migraphx::autocast_fp8_pass{}});
},
"Auto-convert FP8 parameters and return values to Float for MIGraphX Program",
py::arg("prog"));
m.def("quantize_bf16",
&migraphx::quantize_bf16,
py::arg("prog"),
py::arg("ins_names") = std::vector<std::string>{"all"});
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
@ -736,14 +577,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
#ifdef VERSION_INFO #ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO; m.attr("__version__") = VERSION_INFO;
#else #else
auto version_string = std::to_string(MIGRAPHX_VERSION_MAJOR) + "." + m.attr("__version__") = "dev";
std::to_string(MIGRAPHX_VERSION_MINOR) + "." +
std::to_string(MIGRAPHX_VERSION_PATCH) + ".dev";
std::string tweak(MIGRAPHX_VERSION_TWEAK);
if(not tweak.empty())
version_string += "+" + tweak;
m.attr("__version__") = version_string;
#endif #endif
} }